Jan 22, 2026
Continuing our journey through the Deep Learning series: previously we looked at the RNN architecture, and now we explore the next evolution in sequence modeling: Long Short-Term Memory (LSTM) networks. Imagine you’re reading a captivating mystery novel. As you progress through the chapters, you naturally remember important details from earlier pages: the suspicious character introduced in chapter 2, the clue hidden in chapter 5, and the red herring from chapter 8. This ability to retain and use information from the distant past while processing new information is exactly what makes you a good reader, and it is what makes Long Short-Term Memory (LSTM) networks so powerful in the world of artificial intelligence.
Before diving into LSTMs, let’s recap Recurrent Neural Networks (RNNs). Unlike traditional neural networks that process inputs independently, RNNs are designed to work with sequences. Think of them as networks that handle data arriving one piece at a time, like words in a sentence or stock prices over time.
The key innovation of RNNs is their ability to maintain a “hidden state”, a kind of memory that gets updated as they process each new input. This memory allows them to consider previous information when making decisions about current inputs.
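To make the hidden-state idea concrete, here is a minimal sketch of the recurrence in NumPy. The weights and sizes are illustrative, not taken from any trained model: the point is simply that the new hidden state is computed from the current input together with the previous hidden state.

```python
import numpy as np

# Hypothetical tiny RNN: 3-dimensional inputs, 4-dimensional hidden state
rng = np.random.default_rng(0)
W_xh = rng.normal(scale=0.1, size=(4, 3))  # input-to-hidden weights
W_hh = rng.normal(scale=0.1, size=(4, 4))  # hidden-to-hidden weights
b_h = np.zeros(4)

def rnn_step(x_t, h_prev):
    """One RNN update: the new hidden state mixes the current
    input with the previous hidden state (the network's 'memory')."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b_h)

h = np.zeros(4)                      # start with an empty memory
for x_t in rng.normal(size=(5, 3)):  # a sequence of 5 inputs
    h = rnn_step(x_t, h)             # memory is updated at each step

print(h.shape)  # (4,)
```

After the loop, `h` summarizes the whole sequence seen so far, which is exactly what lets an RNN condition its decisions on previous inputs.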
However, basic RNNs have a critical flaw: short-term memory loss. Just like a person with amnesia might forget conversations from a few minutes ago, RNNs struggle to remember information from more than a few steps back in a sequence.
This happens due to a technical issue called the “vanishing gradient problem”. During training, the network learns by calculating gradients (essentially, learning signals) and propagating them backward through time. But these gradients get smaller and smaller as they travel back, eventually becoming so tiny that they can’t effectively update the earlier parts of the network.
Think of it like playing “telephone” with a hundred people: by the time the message reaches the end, it’s often completely different from what was originally said.
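A toy calculation makes the vanishing gradient visible. Under the simplifying assumptions of a scalar recurrent weight and a tanh activation (this is a sketch, not a real training run), the backpropagated signal is a product of per-step derivatives, each well below 1, so it shrinks geometrically with sequence length:

```python
import numpy as np

# The signal propagated back through T tanh steps is a product of
# per-step derivatives d h_t / d h_{t-1} = w * tanh'(z_t).
# With tanh'(z) <= 1 and a recurrent weight w < 1, this product
# decays geometrically, so early time steps receive almost no signal.
w = 0.5          # hypothetical recurrent weight
h = 0.0
signal = 1.0
grads = []
for t in range(50):
    h = np.tanh(w * h + 1.0)     # forward step
    signal *= w * (1.0 - h**2)   # chain rule: d h_t / d h_{t-1}
    grads.append(abs(signal))

print(f"gradient after 1 step:   {grads[0]:.3f}")
print(f"gradient after 50 steps: {grads[-1]:.2e}")  # vanishingly small
```

After 50 steps the learning signal is many orders of magnitude smaller than after one step, which is why a plain RNN effectively cannot learn dependencies that far back.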
Long Short-Term Memory networks, introduced by Hochreiter and Schmidhuber in 1997, solve this problem by introducing a sophisticated memory management system. Instead of having simple memory that gets overwritten at each step, LSTMs have a “memory cell” with three specialized gates that control what information to remember, forget, and output.
To understand LSTMs intuitively, let’s use a simple analogy: a conveyor belt system in a factory. This isn’t just any conveyor belt; it’s equipped with three intelligent controllers: one that removes items that are no longer needed, one that decides which new items to place on the belt, and one that chooses which items to hand off at each station.
As items (information) flow along the belt (through time), each controller examines the current item and the belt’s contents, making intelligent decisions about what to keep, add, or output.
Let’s break down an LSTM cell into its core components:
Previous Hidden State (h_{t-1}) ──┐
                                  │
Current Input (x_t) ──────────────┼──► [LSTM Cell] ──► Current Hidden State (h_t)
                                  │         │
Previous Cell State (C_{t-1}) ────┘         └──► Current Cell State (C_t)
The cell state is like the main conveyor belt: it carries information through time. Unlike simple RNN hidden states, the cell state has a more direct path through the network, making it easier for gradients to flow backward during training.
The forget gate looks at the previous hidden state and current input to decide what information should be discarded from the cell state. It outputs values between 0 and 1 for each piece of information, where 0 means “completely forget” and 1 means “completely keep.”
The input gate has two parts: a sigmoid layer that decides which values to update, and a tanh layer that creates a vector of candidate values that could be added to the state. Together, they determine what new information should be stored in the cell state.
The output gate decides what parts of the cell state should be output as the hidden state. It filters the cell state to produce the final output for this time step.
Let’s dive into the equations that make LSTMs work. Don’t worry, we’ll keep the math intuitive! Here are the key equations that govern an LSTM cell:
\(f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)\)
The sigmoid function (σ) acts like a dimmer switch, outputting values between 0 and 1. This gate looks at the previous hidden state and current input to decide what to forget.
\(i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)\) \(\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)\)
The input gate \(i_t\) decides which values to update, while the candidate values \(\tilde{C}_t\) represent new information that could be stored. The tanh function outputs values between -1 and 1, providing the actual content to potentially add.
\(C_t = f_t * C_{t-1} + i_t * \tilde{C}_t\)
This is where the magic happens! The new cell state combines two terms: the old cell state \(C_{t-1}\) scaled by the forget gate (what we chose to keep), plus the candidate values \(\tilde{C}_t\) scaled by the input gate (what we chose to add).
\(o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)\) \(h_t = o_t * \tanh(C_t)\)
The output gate determines what parts of the cell state to output as the hidden state.
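To tie the five equations together, here is a single LSTM step written out directly in NumPy. The sizes and random weights are hypothetical and purely for illustration; each line maps one-to-one to an equation above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Illustrative sizes: hidden size 4, input size 3. Each weight matrix
# maps the concatenated vector [h_{t-1}, x_t] (size 4 + 3 = 7) to the
# hidden size, matching the W_* in the equations.
rng = np.random.default_rng(0)
H, X = 4, 3
W_f, W_i, W_C, W_o = (rng.normal(scale=0.1, size=(H, H + X)) for _ in range(4))
b_f = b_i = b_C = b_o = np.zeros(H)

def lstm_step(x_t, h_prev, C_prev):
    z = np.concatenate([h_prev, x_t])    # [h_{t-1}, x_t]
    f_t = sigmoid(W_f @ z + b_f)         # forget gate
    i_t = sigmoid(W_i @ z + b_i)         # input gate
    C_tilde = np.tanh(W_C @ z + b_C)     # candidate values
    C_t = f_t * C_prev + i_t * C_tilde   # cell state update
    o_t = sigmoid(W_o @ z + b_o)         # output gate
    h_t = o_t * np.tanh(C_t)             # hidden state
    return h_t, C_t

h, C = np.zeros(H), np.zeros(H)
for x_t in rng.normal(size=(5, X)):      # run a short sequence
    h, C = lstm_step(x_t, h, C)

print(h.shape, C.shape)  # (4,) (4,)
```

Note how `C_t` is updated additively (a gated sum) rather than being rewritten from scratch; this is the direct pathway that lets gradients flow back through time.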
Let’s implement a practical example: predicting the next number in a simple arithmetic sequence. We’ll create an LSTM that can learn patterns like “2, 4, 6, 8, ?” and predict the next number (10).
We’ll train our LSTM on arithmetic sequences:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
class SimpleLSTM(nn.Module):
"""
A simple LSTM model for sequence prediction.
This model consists of:
1. An LSTM layer that processes sequences
2. A fully connected layer that maps LSTM output to predictions
"""
def __init__(self, input_size=1, hidden_size=50, num_layers=1, output_size=1):
super(SimpleLSTM, self).__init__()
# Store parameters
self.hidden_size = hidden_size
self.num_layers = num_layers
# LSTM layer: processes sequences and maintains memory
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
# Fully connected layer: maps LSTM output to final prediction
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
"""
Forward pass through the network.
Args:
x: Input sequences of shape (batch_size, sequence_length, input_size)
Returns:
predictions: Output predictions of shape (batch_size, output_size)
"""
# Initialize hidden state and cell state
batch_size = x.size(0)
h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size)
# Pass through LSTM
# lstm_out shape: (batch_size, sequence_length, hidden_size)
# hidden shape: (num_layers, batch_size, hidden_size)
lstm_out, (hidden, cell) = self.lstm(x, (h0, c0))
# Use the output from the last time step
# lstm_out[:, -1, :] selects the last time step for each sequence in the batch
last_output = lstm_out[:, -1, :]
# Pass through fully connected layer to get final prediction
predictions = self.fc(last_output)
return predictions
def create_sequences(data, seq_length):
"""
Create input-output sequence pairs for training.
Args:
data: List of numbers
seq_length: Length of input sequences
Returns:
X: Input sequences
y: Target outputs (next number in sequence)
"""
X, y = [], []
for i in range(len(data) - seq_length):
# Input: sequence of seq_length numbers
sequence = data[i:i + seq_length]
# Output: the next number
target = data[i + seq_length]
X.append(sequence)
y.append(target)
return np.array(X), np.array(y)
def generate_arithmetic_data(num_sequences=1000, seq_length=5):
"""
Generate arithmetic sequences for training.
Creates sequences like:
- [1, 2, 3, 4, 5] with step=1
- [2, 4, 6, 8, 10] with step=2
- [5, 10, 15, 20, 25] with step=5
"""
all_sequences = []
for _ in range(num_sequences):
# Random starting point and step size
start = np.random.randint(1, 10)
step = np.random.randint(1, 6)
# Generate arithmetic sequence
sequence = [start + i * step for i in range(seq_length + 1)]
all_sequences.extend(sequence)
return all_sequences
def train_model(model, train_loader, num_epochs=100):
"""
Train the LSTM model.
"""
# Loss function: Mean Squared Error for regression
criterion = nn.MSELoss()
# Optimizer: Adam with learning rate 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Track training progress
train_losses = []
print("Starting training...")
print(f"Training on {len(train_loader.dataset)} sequences for {num_epochs} epochs")
for epoch in range(num_epochs):
model.train() # Set model to training mode
epoch_loss = 0.0
for batch_X, batch_y in train_loader:
# Reset gradients
optimizer.zero_grad()
# Forward pass
predictions = model(batch_X)
# Calculate loss
loss = criterion(predictions.squeeze(), batch_y)
# Backward pass
loss.backward()
# Update weights
optimizer.step()
epoch_loss += loss.item()
# Calculate average loss for this epoch
avg_loss = epoch_loss / len(train_loader)
train_losses.append(avg_loss)
# Print progress every 20 epochs
if (epoch + 1) % 20 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')
print("Training completed!")
return train_losses
def test_model(model, test_cases):
"""
Test the trained model on specific sequences.
"""
model.eval() # Set model to evaluation mode
print("\n" + "="*50)
print("TESTING THE TRAINED MODEL")
print("="*50)
with torch.no_grad(): # Disable gradient computation for testing
for i, (sequence, expected) in enumerate(test_cases, 1):
# Prepare input (add batch and feature dimensions)
input_seq = torch.FloatTensor(sequence).unsqueeze(0).unsqueeze(-1)
# Get prediction
prediction = model(input_seq)
predicted_value = prediction.item()
print(f"\nTest Case {i}:")
print(f"Input sequence: {sequence}")
print(f"Expected next number: {expected}")
print(f"Model prediction: {predicted_value:.2f}")
print(f"Error: {abs(predicted_value - expected):.2f}")
# Main execution
if __name__ == "__main__":
print("LSTM Sequence Prediction Tutorial")
print("="*40)
# 1. Generate training data
print("\n1. Generating training data...")
raw_data = generate_arithmetic_data(num_sequences=500, seq_length=5)
# Create input-output pairs
sequence_length = 4 # Use 4 numbers to predict the 5th
X, y = create_sequences(raw_data, sequence_length)
print(f"Created {len(X)} training sequences")
print(f"Each input sequence has {sequence_length} numbers")
print(f"Sample input: {X[0]} → target: {y[0]}")
# 2. Prepare data for PyTorch
print("\n2. Preparing data for PyTorch...")
# Convert to PyTorch tensors and add feature dimension
X_tensor = torch.FloatTensor(X).unsqueeze(-1) # Shape: (num_samples, seq_length, 1)
y_tensor = torch.FloatTensor(y) # Shape: (num_samples,)
# Create dataset and data loader
dataset = TensorDataset(X_tensor, y_tensor)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
print(f"Data prepared: {X_tensor.shape} input, {y_tensor.shape} targets")
# 3. Initialize model
print("\n3. Initializing LSTM model...")
model = SimpleLSTM(input_size=1, hidden_size=50, num_layers=1, output_size=1)
print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")
print("Model architecture:")
print(model)
# 4. Train the model
print("\n4. Training the model...")
train_losses = train_model(model, train_loader, num_epochs=100)
# 5. Test the model
test_cases = [
([1, 2, 3, 4], 5), # Simple counting
([2, 4, 6, 8], 10), # Even numbers
([5, 10, 15, 20], 25), # Multiples of 5
([3, 6, 9, 12], 15), # Multiples of 3
([1, 3, 5, 7], 9), # Odd numbers
]
test_model(model, test_cases)
# 6. Visualize training progress
print("\n6. Visualizing training progress...")
plt.figure(figsize=(10, 6))
plt.plot(train_losses)
plt.title('Training Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
print("\nTutorial completed!")
print("The LSTM has learned to predict the next number in arithmetic sequences.")
Now that we’ve seen how LSTMs work in practice, let’s discuss their advantages and disadvantages:
Long-term Memory: LSTMs can effectively capture dependencies across long sequences, remembering important information from much earlier time steps.
Vanishing Gradient Solution: The cell state pathway allows gradients to flow more easily during backpropagation, solving the vanishing gradient problem that plagues basic RNNs.
Selective Memory: The gate mechanism allows the network to selectively remember, forget, and output information, making it highly flexible for different sequence modeling tasks.
Versatility: LSTMs work well across many domains: natural language processing, speech recognition, time series forecasting, and more.
Well-Established: Extensive research and proven track record in many applications, with lots of resources and pre-trained models available.
Computational Complexity: LSTMs are significantly more computationally expensive than simple RNNs due to their complex gate structure.
Training Time: The additional parameters and computations mean longer training times, especially on large datasets.
Memory Requirements: LSTMs require more memory to store the additional parameters and intermediate states.
Hyperparameter Sensitivity: LSTMs have many hyperparameters (hidden size, number of layers, dropout rates) that need careful tuning.
Sequential Processing: Unlike transformers, LSTMs process sequences step-by-step, making them harder to parallelize during training.
Newer Alternatives: More recent architectures like Transformers often outperform LSTMs on many tasks, especially in natural language processing.
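The cost difference is easy to verify: because the forget, input, and output gates plus the candidate layer each carry their own weights and biases, an LSTM layer has about four times the parameters of a plain RNN with the same hidden size. A quick check with PyTorch (using the same sizes as the tutorial model above):

```python
import torch.nn as nn

# Same configuration as the tutorial's SimpleLSTM: input size 1, hidden size 50.
rnn = nn.RNN(input_size=1, hidden_size=50, num_layers=1, batch_first=True)
lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=1, batch_first=True)

rnn_params = sum(p.numel() for p in rnn.parameters())
lstm_params = sum(p.numel() for p in lstm.parameters())

print(f"RNN:   {rnn_params} parameters")
print(f"LSTM:  {lstm_params} parameters")
print(f"ratio: {lstm_params / rnn_params:.1f}x")  # 4.0x
```

The 4x factor comes directly from the four weight sets (three gates plus the candidate layer), and it translates into proportionally more compute and memory per time step.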
You’ve just learned about one of the most important architectures in deep learning. LSTMs revolutionized sequence modeling by solving the memory problem that plagued early neural networks, paving the way for many of the AI applications we use today.
It’s important to remember that the field of AI is constantly evolving. While LSTMs are powerful, newer architectures like Transformers are pushing the boundaries even further. The key takeaway is that understanding the fundamentals helps you appreciate these more advanced concepts.
Keep experimenting, keep learning, and most importantly, keep building! There’s always something new and exciting to discover.
For hands-on practice, check out the companion notebooks - Understanding LSTM Networks with PyTorch