LSTM and GRU Deep Dive
It is 2:47 AM in a document review center in Chicago. A legal AI system is processing a 400-page merger agreement for a Fortune 100 acquisition. The system's task is straightforward on paper: flag every clause that conflicts with a clause introduced earlier in the document. Section 4.2 defines the indemnification ceiling. Section 312.7 - more than three hundred pages later - references a carve-out that quietly overrides it. The vanilla RNN powering the system cannot hold that information across three hundred pages of intervening text. It has no mechanism to keep a piece of information alive while hundreds of irrelevant tokens wash over it. By the time it reaches section 312.7, section 4.2 is gone. The conflict goes undetected. The acquisition closes. Six months later, during integration, the legal team discovers a $400M liability that should have been caught in review.
The details are illustrative, but the failure is not hypothetical. It is a class of failure that teams ran into repeatedly whenever they applied vanilla RNNs to long-form contract analysis. The fundamental problem is architectural, not a matter of training data or compute. A standard RNN compresses everything it has seen into a single fixed-size hidden state vector. As the sequence grows longer, earlier information gets overwritten by later information through repeated matrix multiplications and squashing nonlinearities. The architecture has no way to say "keep this, discard that." It has no doors, no valves, no selectivity. Everything flows through at the same rate, and important long-range signals decay to noise.
The fix already existed in a 1997 paper by Sepp Hochreiter and Jürgen Schmidhuber, titled "Long Short-Term Memory," which had sat mostly dormant for over a decade. It proposed an architecture with explicit gating mechanisms: learnable controls over what information to store, what to discard, and what to surface as output. When the deep learning renaissance arrived around 2012, practitioners rediscovered LSTM and found that it addressed exactly this failure. Legal AI teams rebuilt their systems around LSTM-based encoders. The long-range dependency problem was not completely solved (nothing in machine learning is), but it was brought under control to the point where the systems became commercially viable.
This lesson is about understanding exactly why vanilla RNNs fail and how LSTM and its lighter-weight successor GRU were designed to fix those failures at the architectural level. The goal is not a high-level conceptual gloss but the actual equations, the gradient flow dynamics, and the engineering decisions you need to make when deploying these architectures in production. You will understand what each gate does, why it exists, and what goes wrong when you misconfigure it. By the end, you will be able to implement an LSTM cell from scratch in NumPy, deploy a bidirectional LSTM in PyTorch, and answer the interview questions that separate candidates who have used these models from candidates who understand them.
The document review system was eventually rebuilt. The new version correctly flags the section 4.2 / section 312.7 conflict because the LSTM's cell state acts as a persistent memory that the network can choose to maintain across very long spans of intervening text. The forget gate learns to keep the indemnification ceiling in memory because, during training on thousands of similar contracts, it learned that such clauses almost always matter later. This is the power of learned selectivity.
Why This Exists - The Vanishing Gradient Problem Demands a Solution
If you have read the previous lesson on vanilla RNNs, you have seen why the vanishing gradient problem arises. If you have not, here is the precise argument.
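A vanilla RNN updates its hidden state at each time step as:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h)$$

When you backpropagate through a sequence of length $T$, the gradient of a loss computed at the final step $T$ with respect to a hidden state at time step $k$ involves a chain of Jacobian multiplications across every step from $T$ down to $k+1$:

$$\frac{\partial \mathcal{L}}{\partial h_k} = \frac{\partial \mathcal{L}}{\partial h_T} \prod_{t=k+1}^{T} \frac{\partial h_t}{\partial h_{t-1}}$$

Each factor $\frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(1 - h_t^2)\, W_{hh}$ is the Jacobian of the tanh activation composed with the recurrent weight matrix $W_{hh}$. Its spectral norm (largest singular value) bounds how much a gradient can grow or shrink per step. Because the tanh derivative $1 - h_t^2$ is at most 1, if the largest singular value of $W_{hh}$ is less than 1, the product of the $T - k$ factors shrinks exponentially toward zero - this is vanishing gradients. If it is greater than 1, the product can instead grow exponentially - this is exploding gradients.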
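To see the argument numerically, here is a minimal NumPy sketch that builds a toy vanilla RNN and tracks the spectral norm of $\partial h_t / \partial h_0$ as $t$ grows. Everything about the setup is a hypothetical choice for illustration: the helper name `jacobian_product_norms`, a random orthogonal $W_{hh}$ scaled so all of its singular values equal `w_scale`, and random noise standing in for the input term $W_{xh} x_t + b_h$.

```python
import numpy as np


def jacobian_product_norms(w_scale: float, hidden: int = 32, steps: int = 50, seed: int = 0) -> list:
    """Spectral norm of dh_t/dh_0 at every step t for a toy vanilla RNN.

    Hypothetical setup: W_hh is a random orthogonal matrix times w_scale,
    so every singular value of W_hh equals w_scale.
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((hidden, hidden)))  # orthogonal matrix
    W_hh = w_scale * q
    h = np.zeros(hidden)
    J = np.eye(hidden)  # dh_0/dh_0
    norms = []
    for _ in range(steps):
        # Random noise replaces the input term W_xh x_t + b_h
        h = np.tanh(W_hh @ h + rng.standard_normal(hidden))
        # Chain rule: dh_t/dh_0 = diag(1 - h_t^2) W_hh (dh_{t-1}/dh_0)
        J = np.diag(1.0 - h ** 2) @ W_hh @ J
        norms.append(np.linalg.norm(J, 2))  # largest singular value
    return norms


for scale in (0.5, 1.0):
    norms = jacobian_product_norms(scale)
    print(f"w_scale={scale}: ||dh_10/dh_0||={norms[9]:.2e}, ||dh_50/dh_0||={norms[-1]:.2e}")
```

With `w_scale=0.5` each factor has norm at most 0.5, so after 50 steps the norm is below $0.5^{50} \approx 10^{-15}$. With `w_scale=1.0` the norm still cannot grow, because the tanh derivative is at most 1.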
In practice, both happen. Exploding gradients can be managed with gradient clipping. Vanishing gradients cannot be fixed by any training trick applied to the vanilla RNN architecture. The architecture itself is the problem. The information highway from early time steps to late time steps is a series of tanh squashing functions and matrix multiplications that destroy signal. No learning rate schedule, no optimizer, no regularization technique fixes this. You need a different architecture.
The specific failure mode is dependency length. Vanilla RNNs can learn dependencies across 10-20 time steps with careful tuning. Beyond that, the gradients become too small to drive meaningful weight updates. For tasks where relevant context spans hundreds of tokens - legal documents, scientific papers, long-form conversations, multi-minute audio - vanilla RNNs simply cannot learn the right relationships.
The LSTM solution was to introduce an explicit, separate memory pathway - the cell state - that flows through time with additive updates. Along this path there is no weight-matrix multiplication and no squashing nonlinearity, only an element-wise scaling by a learned gate. Addition does not cause exponential decay: a gradient flowing backward through an addition passes through unchanged in magnitude. This single architectural insight - separating memory (cell state) from working state (hidden state) and updating the memory additively - is the core reason LSTM works where vanilla RNNs fail.
Historical Context - From a Rejected Paper to the Backbone of Modern NLP
Sepp Hochreiter was a diploma student at the Technical University of Munich in 1991 when he wrote his diploma thesis analyzing why deep networks failed to learn long-range dependencies. His analysis was rigorous and largely correct. Four years later, he was a PhD student under Jürgen Schmidhuber at IDSIA in Lugano, Switzerland. Together they developed the Long Short-Term Memory architecture as a direct response to the gradient flow problem Hochreiter had identified.
The paper - "Long Short-Term Memory" - was submitted to Neural Computation and published in 1997. The reception was muted. The deep learning field was in a trough. Backpropagation was considered by many researchers to be fundamentally limited. Support Vector Machines were ascendant. The paper attracted a small community of readers but had almost no industrial impact for over a decade.
The LSTM architecture sat largely dormant until 2007 when Alex Graves and Jürgen Schmidhuber demonstrated that LSTM with Connectionist Temporal Classification (CTC) could achieve state-of-the-art results on handwriting recognition without explicit segmentation. This was a significant result, but the field's attention was elsewhere.
The inflection point came in 2012. After AlexNet demonstrated that deep networks trained on GPUs could win ImageNet by a large margin, the research community's interest in deep learning reignited. Researchers began systematically evaluating architectures that had been languishing. LSTM was rediscovered and found to work extremely well for language modeling, speech recognition, and machine translation. By 2015, Google's voice search had moved its acoustic models to LSTM-based recurrent networks.
By 2014, LSTM was everywhere. Ilya Sutskever, Oriol Vinyals, and Quoc Le published "Sequence to Sequence Learning with Neural Networks" - a paper that stacked four layers of LSTM to perform English-to-French translation and achieved results competitive with phrase-based systems. That same year, Kyunghyun Cho, Yoshua Bengio, and collaborators published "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" (Cho et al. 2014), which introduced the Gated Recurrent Unit (GRU) as a simpler alternative to LSTM. The GRU reduced four gates to two and eliminated the cell state entirely, achieving comparable performance with fewer parameters on many tasks.
The period from 2014 to 2017 was the golden age of LSTM. Every major NLP system - sentiment analysis, named entity recognition, machine translation, question answering - was built on LSTM. Google Translate switched from phrase-based to LSTM sequence-to-sequence in September 2016, a change that reduced translation errors by approximately 60% overnight on some language pairs.
Then, in 2017, Vaswani et al. published "Attention Is All You Need," introducing the Transformer architecture. For tasks where data was plentiful and computational budgets were large, Transformers outperformed LSTMs because they could attend to any position in the sequence in parallel rather than processing it sequentially. LSTM's sequential nature - which was a feature for learning ordered dependencies - became a liability for parallelization on modern hardware.
But LSTM did not disappear. It remains the preferred architecture for many tasks that require streaming inference, edge deployment with constrained memory, or processing very long sequences where quadratic attention costs are prohibitive. As of 2024, LSTM-based models are reported to remain in production at companies including Apple (on-device Siri) and Google (on-device speech), and in countless specialized industrial applications.
Core Concept - The Memory Cell Analogy
Before any equations, build the intuition with a concrete mental model.
Imagine you are a meticulous reader working through a long legal contract. You carry a notepad with you. When you start reading, the notepad is blank.
As you read, you make decisions on every page:
- Should I cross out anything on my notepad? Maybe the information from page 3 has now been superseded by a new clause on page 47. You draw a line through it.
- Is there anything on this page worth adding? You write down the key claim numbers, the dollar figures, the parties being referenced.
- What is most relevant to report right now? Based on what is on your notepad and what you just read, you compose a brief summary of the current state of affairs.
Your notepad is the LSTM cell state. Your brief summary is the hidden state. The three decisions - what to erase, what to add, what to report - are the forget gate, the input gate, and the output gate. The fourth component, sometimes called the input modulation gate (or candidate cell state), proposes the candidate information that could be written to the notepad; the input gate then decides how much of it actually gets written.
This is not a perfect analogy, but it captures the essential behavior. The cell state is the long-term memory. The hidden state is the short-term working output. The three gates are learned sigmoid-activated linear transformations that produce values between 0 and 1 - 0 means "block completely," 1 means "pass completely through." The candidate uses tanh instead and produces values between -1 and 1. The network learns during training what to block and what to pass for the task at hand.
Gate-by-Gate Breakdown - The Four Gates of LSTM
An LSTM cell takes three inputs at each time step $t$: the current input $x_t$, the previous hidden state $h_{t-1}$, and the previous cell state $C_{t-1}$. It produces two outputs: the new hidden state $h_t$ and the new cell state $C_t$.
The notation convention: $[h_{t-1}, x_t]$ means the concatenation of the two vectors. $\sigma$ is the sigmoid function. $\tanh$ is the hyperbolic tangent. $\odot$ is element-wise multiplication (the Hadamard product).
The Forget Gate
Why it comes first: Before deciding what to remember from the current input, the cell needs to decide what to let go of from the previous memory. The forget gate applies a multiplicative mask to the previous cell state.

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

- $W_f$ has shape (hidden_size, hidden_size + input_size)
- $f_t$ has shape (hidden_size,), with each element in $(0, 1)$
- Element $j$ of $f_t$ close to 0 means "forget cell state dimension $j$"
- Element $j$ of $f_t$ close to 1 means "preserve cell state dimension $j$"
Concrete example: in a language model, when the model encounters the word "he" after a paragraph about a female character, the forget gate values for the dimensions that store the old subject's gender should drop toward 0, clearing that information before the new subject is stored. The network learns this behavior from training on text where pronoun agreement must be maintained.
The Input Gate
Why it exists: After deciding what to forget, the cell must decide which dimensions of the new candidate memory are worth writing at all.

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

This gate also produces values in $(0, 1)$. Element $j$ of $i_t$ controls how much of the candidate value (computed next) gets added to cell state dimension $j$.
The Candidate Cell State
Why it exists: You need to compute what you might write to the cell state before deciding how much of it to actually write.

$$g_t = \tanh(W_g \cdot [h_{t-1}, x_t] + b_g)$$

The $\tanh$ here squashes candidate values to $(-1, 1)$, which keeps each write to the cell state bounded.
Some textbooks call this $\tilde{C}_t$ (candidate cell state). Others call it $g_t$, as this lesson does, to distinguish it more clearly from the actual cell state $C_t$. It is the raw material that the input gate then filters.
The Cell State Update
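Why this is the key equation: The cell state update is the step that largely solves the vanishing gradient problem. Note that the update is additive, not a matrix multiplication followed by a nonlinearity applied to the entire state.

$$C_t = f_t \odot C_{t-1} + i_t \odot g_t$$

When you backpropagate through this equation along the direct cell-state path, the gradient of $C_t$ with respect to $C_{t-1}$ is:

$$\frac{\partial C_t}{\partial C_{t-1}} = \operatorname{diag}(f_t)$$

That is just the forget gate values. It is not a product of weight matrices and tanh Jacobians as in a vanilla RNN. (The gates also depend on $h_{t-1}$, and therefore indirectly on $C_{t-1}$, but those paths add to the direct one rather than replacing it.) If the network learns to keep $f_t$ close to 1 for relevant dimensions, gradients flow backward almost unchanged. This is the LSTM constant error carousel: the cell state provides an approximately constant gradient highway. In the original 1997 LSTM, which had no forget gate, this self-connection was fixed at exactly 1; the forget gate was added later by Gers, Schmidhuber, and Cummins (2000).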
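A quick numerical check of this derivative, as a minimal NumPy sketch: the gate activations are hypothetical fixed values (a forget gate held at 0.99 and random numbers standing in for $i_t \odot g_t$), so the only way $C_T$ depends on $C_0$ is the direct cell-state path. A finite-difference estimate of $\partial C_T / \partial C_0$ should then match the product of the forget gate values.

```python
import numpy as np

T, H = 100, 4
rng = np.random.default_rng(0)


def unroll_cell_state(C0: np.ndarray, f: np.ndarray, ig: np.ndarray) -> np.ndarray:
    """Apply C_t = f_t * C_{t-1} + (i_t * g_t) at every step, with gate values held fixed."""
    C = C0
    for t in range(f.shape[0]):
        C = f[t] * C + ig[t]
    return C


# Hypothetical fixed gate activations (not computed from any input)
f = np.full((T, H), 0.99)                 # forget gate held near 1
ig = rng.uniform(-1.0, 1.0, size=(T, H))  # stands in for i_t * g_t

# Analytic direct-path gradient: product of the diag(f_t) factors
analytic = np.prod(f, axis=0)

# Central finite difference; each dimension evolves independently, so
# nudging all of C_0 at once still yields the per-dimension derivative
eps = 1e-6
C0 = np.zeros(H)
numeric = (unroll_cell_state(C0 + eps, f, ig) - unroll_cell_state(C0 - eps, f, ig)) / (2 * eps)

print(analytic[0], numeric[0])  # both about 0.366
print(0.5 ** T)                 # what a forget gate stuck at 0.5 would leave
```

Over 100 steps, a forget gate of 0.99 keeps about 37% of the gradient ($0.99^{100} \approx 0.366$), while a gate stuck at 0.5 would leave about $7.9 \times 10^{-31}$.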
The Output Gate
Why it exists: Just because something is stored in the cell state does not mean it is relevant to the current output. The output gate filters which aspects of the cell state get exposed as the hidden state.

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$

$$h_t = o_t \odot \tanh(C_t)$$

The $\tanh$ on $C_t$ squashes cell state values to $(-1, 1)$ before the output gate masks them. The resulting $h_t$ is the hidden state passed to the next time step and also the output of the LSTM cell at time $t$.
Complete Forward Pass Summary
At each time step $t$, given $x_t$, $h_{t-1}$, $C_{t-1}$:
```
f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # forget gate
i_t = σ(W_i · [h_{t-1}, x_t] + b_i) # input gate
g_t = tanh(W_g · [h_{t-1}, x_t] + b_g) # candidate cell
C_t = f_t ⊙ C_{t-1} + i_t ⊙ g_t # cell state update
o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # output gate
h_t = o_t ⊙ tanh(C_t) # hidden state output
```
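The four weight matrices $W_f$, $W_i$, $W_g$, $W_o$ are each of shape (hidden_size, hidden_size + input_size). In practice, they are usually stacked into one fused matrix so that a single matrix multiplication computes all four pre-activations, which are then split by gate for computational efficiency. PyTorch does this with a slightly different layout: it keeps the input-to-hidden and hidden-to-hidden parts separate, as weight_ih of shape (4 * hidden_size, input_size) and weight_hh of shape (4 * hidden_size, hidden_size), plus two bias vectors bias_ih and bias_hh, with the gate blocks stacked in the order input, forget, cell (candidate), output.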
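To make the fused layout concrete, here is a minimal sketch that recomputes one step of `torch.nn.LSTMCell` by hand from its stacked parameters. It relies on the gate order PyTorch documents for `weight_ih` and `weight_hh` (input, forget, cell, output); the sizes and random inputs are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 3, 5
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(2, input_size)  # (batch, input_size)
h_prev = torch.randn(2, hidden_size)
c_prev = torch.randn(2, hidden_size)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h_prev, c_prev))

    # One fused pre-activation for all four gates: (batch, 4 * hidden_size)
    z = x @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
    # Split in PyTorch's chunk order: input, forget, cell (candidate), output
    z_i, z_f, z_g, z_o = z.chunk(4, dim=1)
    i_t, f_t, o_t = torch.sigmoid(z_i), torch.sigmoid(z_f), torch.sigmoid(z_o)
    g_t = torch.tanh(z_g)
    c_t = f_t * c_prev + i_t * g_t  # cell state update
    h_t = o_t * torch.tanh(c_t)     # hidden state output

print(torch.allclose(h_t, h_ref, atol=1e-5), torch.allclose(c_t, c_ref, atol=1e-5))  # True True
```

If the chunks were split in any other order, the manual result would no longer match the reference, which makes this a handy sanity check before touching individual gate parameters.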
GRU - Simplifying LSTM with Two Gates
The Gated Recurrent Unit, introduced by Cho et al. in "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" (2014), asks: can we achieve similar performance with a simpler architecture?
The GRU makes two structural changes compared to LSTM:
- Eliminate the cell state. GRU merges the cell state and hidden state into a single hidden state. The hidden state carries both long-term and short-term information.
- Replace four gates with two gates. The reset gate and the update gate replace forget, input, cell update, and output.
GRU Equations
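- $z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$ - update gate
- $r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$ - reset gate
- $\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h)$ - candidate hidden state
- $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ - hidden state update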
The update gate $z_t$ controls how much of the previous hidden state to carry forward versus how much to update from the candidate. When $z_t$ is 1, the new hidden state is entirely the candidate. When $z_t$ is 0, the hidden state is entirely preserved from the previous step. Notice the symmetry: $(1 - z_t)$ weights the old state, $z_t$ weights the new candidate. The network cannot increase both simultaneously - it trades off memory and update. (Cho et al.'s paper and PyTorch's nn.GRU use the mirrored convention, in which $z_t$ weights the previous state; the behavior is identical with the roles of 0 and 1 swapped.)
The reset gate $r_t$ controls how much of the previous hidden state to use when computing the candidate. When $r_t$ is 0, the candidate is computed as if the previous hidden state did not exist - it becomes a fresh representation based only on $x_t$. This allows the GRU to "reset" its memory when it encounters a new topic or context.
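Putting the GRU equations together, here is a minimal NumPy sketch of one GRU step. It assumes the concatenated-input form with separate matrices `W_z`, `W_r`, `W_h` and the convention $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$; the function name `gru_cell`, the random weights, and the sizes are illustrative choices, not a library API. The second call pins $z_t$ near 0 with a large negative update-gate bias to show the state being carried forward unchanged.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    z = np.exp(-np.abs(x))  # never overflows
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))


def gru_cell(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    """One GRU step (hypothetical helper, concatenated-input form)."""
    combined = np.concatenate([h_prev, x_t])
    z_t = sigmoid(W_z @ combined + b_z)                 # update gate
    r_t = sigmoid(W_r @ combined + b_r)                 # reset gate
    reset_combined = np.concatenate([r_t * h_prev, x_t])
    h_cand = np.tanh(W_h @ reset_combined + b_h)        # candidate hidden state
    return (1.0 - z_t) * h_prev + z_t * h_cand          # trade off old state vs candidate


rng = np.random.default_rng(0)
input_size, hidden_size = 3, 4
shape = (hidden_size, hidden_size + input_size)
W_z, W_r, W_h = (rng.normal(0.0, 0.5, shape) for _ in range(3))
x_t, h_prev = rng.standard_normal(input_size), rng.standard_normal(hidden_size)
zeros = np.zeros(hidden_size)

# Ordinary step: z_t lands somewhere in (0, 1)
h_t = gru_cell(x_t, h_prev, W_z, W_r, W_h, zeros, zeros, zeros)
# Pin z_t near 0 with a large negative update-gate bias: the old state is copied forward
h_copy = gru_cell(x_t, h_prev, W_z, W_r, W_h, np.full(hidden_size, -50.0), zeros, zeros)
print(h_t.shape, np.allclose(h_copy, h_prev))  # (4,) True
```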
GRU vs LSTM - When to Choose Which
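GRU has fewer parameters: 3 gate weight blocks vs 4 in LSTM, so at the same input and hidden size its recurrent layers have 25% fewer parameters. On datasets with fewer than approximately 10,000 training sequences, this can matter significantly for regularization. GRU also trains faster, because each step computes three gate blocks instead of four.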
LSTM tends to outperform GRU on tasks requiring very precise long-range memory - tasks where the network must hold specific information for hundreds of steps and then use it exactly. The separate cell state and output gate give LSTM more expressive control over what to surface. The performance difference is often small (1-3% on standard benchmarks), which is why GRU gained wide adoption.
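The 3-versus-4 ratio is easy to check with arithmetic. The helper below is a hypothetical sketch that counts one direction of one recurrent layer the way PyTorch lays it out (an input matrix, a recurrent matrix, and two bias vectors per gate block); the sizes 128 and 256 match the classifier examples later in this lesson.

```python
def recurrent_layer_params(input_size: int, hidden_size: int, num_gate_blocks: int) -> int:
    """Parameters in one direction of one PyTorch-style recurrent layer (hypothetical helper).

    Each gate block has an input matrix (hidden x input), a recurrent matrix
    (hidden x hidden), and two bias vectors (bias_ih and bias_hh).
    """
    per_block = hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size
    return num_gate_blocks * per_block


lstm = recurrent_layer_params(128, 256, num_gate_blocks=4)  # input, forget, cell, output
gru = recurrent_layer_params(128, 256, num_gate_blocks=3)   # reset, update, new
print(lstm, gru, gru / lstm)  # 395264 296448 0.75
```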
Chung et al. (2014), "Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling," compared LSTM, GRU, and vanilla tanh RNNs on polyphonic music and raw speech signal modeling. Their findings, together with common practitioner guidance:
- Neither GRU nor LSTM consistently dominated across the tasks they tested
- GRU is often preferred when training data is limited and training speed matters
- LSTM is often preferred when task accuracy is paramount and resources allow
- Both gated units outperformed the vanilla tanh RNN overall
NumPy from Scratch - LSTM Cell Forward Pass
The following implementation computes a single LSTM cell forward pass in NumPy. This is not production code - it is designed to make every operation explicit and traceable.
```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable sigmoid.

    np.where evaluates both branches for every element, so both must be safe:
    exp(-|x|) is always in (0, 1] and never overflows.
    """
    z = np.exp(-np.abs(x))
    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))


class LSTMCell:
    """
    Single LSTM cell - pure NumPy, explicit gate-by-gate.

    Parameters
    ----------
    input_size : int - dimensionality of input x_t
    hidden_size : int - dimensionality of hidden state h_t and cell state C_t
    seed : int - random seed for reproducibility
    """

    def __init__(self, input_size: int, hidden_size: int, seed: int = 42):
        rng = np.random.default_rng(seed)
        self.hidden_size = hidden_size
        concat_size = hidden_size + input_size
        # Xavier/Glorot normal initialization - keeps activations in a healthy range.
        # Each gate matrix has fan_in = concat_size and fan_out = hidden_size.
        scale = np.sqrt(2.0 / (concat_size + hidden_size))
        # Weight matrices: shape (hidden_size, hidden_size + input_size)
        # PyTorch keeps separate input-to-hidden and hidden-to-hidden matrices;
        # one matrix applied to the concatenation [h_prev, x_t] is mathematically
        # equivalent and easier to read.
        self.W_f = rng.normal(0, scale, (hidden_size, concat_size))  # forget
        self.W_i = rng.normal(0, scale, (hidden_size, concat_size))  # input
        self.W_g = rng.normal(0, scale, (hidden_size, concat_size))  # candidate
        self.W_o = rng.normal(0, scale, (hidden_size, concat_size))  # output
        # Biases: shape (hidden_size,)
        # Initialize forget gate bias to 1.0 - critical for gradient flow.
        # See production notes section for why this matters.
        self.b_f = np.ones(hidden_size)  # forget gate bias: 1.0
        self.b_i = np.zeros(hidden_size)
        self.b_g = np.zeros(hidden_size)
        self.b_o = np.zeros(hidden_size)

    def forward(
        self,
        x_t: np.ndarray,
        h_prev: np.ndarray,
        C_prev: np.ndarray,
    ) -> tuple:
        """
        Single LSTM cell forward pass.

        Parameters
        ----------
        x_t : (input_size,) - current input
        h_prev : (hidden_size,) - previous hidden state
        C_prev : (hidden_size,) - previous cell state

        Returns
        -------
        h_t : (hidden_size,) - new hidden state
        C_t : (hidden_size,) - new cell state
        cache : dict of intermediate values (needed for backprop)
        """
        # Concatenate previous hidden state and current input
        # Shape: (hidden_size + input_size,)
        combined = np.concatenate([h_prev, x_t])
        # --- Forget Gate ---
        # Decides what fraction of each cell state dimension to keep
        f_t = sigmoid(self.W_f @ combined + self.b_f)
        # --- Input Gate ---
        # Decides which candidate values to write to cell state
        i_t = sigmoid(self.W_i @ combined + self.b_i)
        # --- Candidate Cell State ---
        # The raw new information, squashed to (-1, 1)
        g_t = np.tanh(self.W_g @ combined + self.b_g)
        # --- Cell State Update ---
        # Additive highway - this is where gradient flow is preserved
        C_t = f_t * C_prev + i_t * g_t
        # --- Output Gate ---
        # Decides which parts of cell state to expose as hidden state
        o_t = sigmoid(self.W_o @ combined + self.b_o)
        # --- Hidden State ---
        h_t = o_t * np.tanh(C_t)
        cache = {
            "combined": combined,
            "f_t": f_t,
            "i_t": i_t,
            "g_t": g_t,
            "o_t": o_t,
            "C_prev": C_prev,
            "C_t": C_t,
            "h_t": h_t,
        }
        return h_t, C_t, cache


def run_lstm_on_sequence(
    lstm_cell: LSTMCell,
    X: np.ndarray,
) -> tuple:
    """
    Run an LSTM cell over an entire sequence.

    Parameters
    ----------
    lstm_cell : LSTMCell instance
    X : (seq_len, input_size) - the input sequence

    Returns
    -------
    all_hidden : (seq_len, hidden_size) - hidden state at every step
    C_T : (hidden_size,) - final cell state
    """
    seq_len, _ = X.shape
    hidden_size = lstm_cell.hidden_size
    # Initialize hidden and cell state to zeros
    h = np.zeros(hidden_size)
    C = np.zeros(hidden_size)
    all_hidden = np.zeros((seq_len, hidden_size))
    for t in range(seq_len):
        h, C, _ = lstm_cell.forward(X[t], h, C)
        all_hidden[t] = h
    return all_hidden, C


# --- Demo ---
if __name__ == "__main__":
    INPUT_SIZE = 8
    HIDDEN_SIZE = 16
    SEQ_LEN = 50
    cell = LSTMCell(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE)
    # Simulate a sequence: e.g., word embeddings of length 8
    rng = np.random.default_rng(0)
    X = rng.standard_normal((SEQ_LEN, INPUT_SIZE))
    all_h, final_C = run_lstm_on_sequence(cell, X)
    print(f"Hidden states shape : {all_h.shape}")  # (50, 16)
    print(f"Final cell state : {final_C.shape}")  # (16,)
    print(f"h_T mean : {all_h[-1].mean():.4f}")
    print(f"C_T mean : {final_C.mean():.4f}")
    # Inspect gate values at a single step to verify bounds
    h0 = np.zeros(HIDDEN_SIZE)
    C0 = np.zeros(HIDDEN_SIZE)
    _, _, cache = cell.forward(X[0], h0, C0)
    print("\nGate value ranges at t=0:")
    print(f" forget gate f_t : min={cache['f_t'].min():.3f}, max={cache['f_t'].max():.3f}")
    print(f" input gate i_t : min={cache['i_t'].min():.3f}, max={cache['i_t'].max():.3f}")
    print(f" output gate o_t : min={cache['o_t'].min():.3f}, max={cache['o_t'].max():.3f}")
    print(f" candidate g_t : min={cache['g_t'].min():.3f}, max={cache['g_t'].max():.3f}")
```
Run this and you will see the hidden state and cell state shapes, and that the gate values are correctly bounded in their expected ranges. The forget gate initialized with bias 1.0 will produce values centered around $\sigma(1) \approx 0.73$ rather than 0.5, biasing the network toward remembering - this is the intended behavior early in training.
PyTorch Implementation - Production-Grade LSTM
Basic nn.LSTM Usage
```python
import torch
import torch.nn as nn

# Define a 2-layer bidirectional LSTM
lstm = nn.LSTM(
    input_size=128,  # embedding dimension
    hidden_size=256,  # hidden state size per direction
    num_layers=2,  # stacked LSTM layers
    batch_first=True,  # input shape: (batch, seq_len, input_size)
    bidirectional=True,  # forward + backward pass
    dropout=0.3,  # applied between layers (not on last layer)
)

# For bidirectional LSTM with hidden_size=256:
# output shape : (batch, seq_len, 256 * 2) = (batch, seq_len, 512)
# h_n shape : (num_layers * 2, batch, 256) = (4, batch, 256)
# c_n shape : (num_layers * 2, batch, 256) = (4, batch, 256)
batch_size = 32
seq_len = 100
x = torch.randn(batch_size, seq_len, 128)  # (batch, seq, input)
output, (h_n, c_n) = lstm(x)
print(f"Output shape : {output.shape}")  # (32, 100, 512)
print(f"h_n shape : {h_n.shape}")  # (4, 32, 256)
print(f"c_n shape : {c_n.shape}")  # (4, 32, 256)

# Extract the final hidden state (last layer, forward + backward)
# h_n is ordered layer by layer, with directions inside each layer:
# h_n[-2] is the last layer's forward state, h_n[-1] its backward state
h_forward = h_n[-2]  # (batch, 256)
h_backward = h_n[-1]  # (batch, 256)
final_hidden = torch.cat([h_forward, h_backward], dim=-1)  # (batch, 512)
```
Handling Variable-Length Sequences with Packed Sequences
In real workloads, sequences in a batch have different lengths. Padding short sequences with zeros is wasteful and corrupts LSTM computation - the LSTM processes padding tokens and updates its hidden state based on them. PyTorch's packed sequence API fixes this.
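A minimal PyTorch sketch of the problem and the fix, assuming a toy batch of two sequences (lengths 5 and 2) and an untrained single-layer LSTM. Running the zero-padded batch directly lets the LSTM consume three padding steps for the short sequence; packing it with `pack_padded_sequence` (here with `enforce_sorted=False`, so no manual sorting is needed) makes the final hidden state match running the short sequence on its own.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4, batch_first=True)

long_seq = torch.randn(5, 3)
short_seq = torch.randn(2, 3)
# Zero-pad the short sequence to length 5 and batch the two together
padded = torch.zeros(2, 5, 3)
padded[0] = long_seq
padded[1, :2] = short_seq
lengths = torch.tensor([5, 2])

with torch.no_grad():
    # Reference: the short sequence on its own, no padding at all
    _, (h_alone, _) = lstm(short_seq.unsqueeze(0))
    # Naive: the LSTM also steps through the 3 padding positions
    _, (h_naive, _) = lstm(padded)
    # Packed: padding positions are skipped; h_n is returned in the original batch order
    packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    _, (h_packed, _) = lstm(packed)

print(torch.allclose(h_packed[0, 1], h_alone[0, 0], atol=1e-5))  # True
print((h_naive[0, 1] - h_alone[0, 0]).abs().max().item())         # nonzero: padding changed the state
```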
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class SequenceClassifier(nn.Module):
    """
    LSTM-based classifier that handles variable-length sequences correctly.
    Uses packed sequences to avoid processing padding tokens.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        num_classes: int,
        num_layers: int = 2,
        dropout: float = 0.3,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size * self.num_directions, num_classes),
        )

    def forward(
        self,
        token_ids: torch.Tensor,  # (batch, max_seq_len) - padded
        lengths: torch.Tensor,  # (batch,) - actual lengths, sorted descending
    ) -> torch.Tensor:
        # Embed tokens: (batch, max_seq_len, embed_dim)
        embedded = self.embedding(token_ids)
        # Pack the sequence - removes padding from computation
        # enforce_sorted=True requires lengths to be in descending order
        packed = pack_padded_sequence(
            embedded,
            lengths.cpu(),
            batch_first=True,
            enforce_sorted=True,
        )
        # Run LSTM on packed sequence
        packed_output, (h_n, c_n) = self.lstm(packed)
        # Unpack if you need per-step outputs (e.g., for sequence labeling)
        # output: (batch, max_seq_len, hidden * num_directions)
        output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)
        # For classification, use the final hidden state
        # h_n shape: (num_layers * num_directions, batch, hidden_size)
        if self.bidirectional:
            # Concatenate final forward and backward hidden states (last layer)
            h_forward = h_n[-2]  # last layer, forward direction
            h_backward = h_n[-1]  # last layer, backward direction
            final_h = torch.cat([h_forward, h_backward], dim=-1)
        else:
            final_h = h_n[-1]  # last layer
        # Classify: (batch, num_classes)
        logits = self.classifier(final_h)
        return logits


def train_epoch(
    model: SequenceClassifier,
    dataloader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> float:
    model.train()
    total_loss = 0.0
    criterion = nn.CrossEntropyLoss()
    for batch in dataloader:
        token_ids, lengths, labels = batch
        # Sort by length descending - required for pack_padded_sequence
        lengths, sort_idx = lengths.sort(descending=True)
        token_ids = token_ids[sort_idx].to(device)
        labels = labels[sort_idx].to(device)
        optimizer.zero_grad()
        logits = model(token_ids, lengths)
        loss = criterion(logits, labels)
        loss.backward()
        # Gradient clipping - essential for LSTM stability (after backward, before step)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


# Instantiate the model
model = SequenceClassifier(
    vocab_size=30000,
    embed_dim=128,
    hidden_size=256,
    num_classes=5,
    num_layers=2,
    dropout=0.3,
    bidirectional=True,
)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
```
GRU Implementation for Comparison
```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class GRUClassifier(nn.Module):
    """
    GRU-based classifier - simpler than LSTM, often comparable performance.
    Key difference: GRU has no cell state c_n, only hidden state h_n.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_size: int,
        num_classes: int,
        num_layers: int = 2,
        dropout: float = 0.3,
        bidirectional: bool = True,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size * self.num_directions, num_classes),
        )

    def forward(
        self,
        token_ids: torch.Tensor,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        embedded = self.embedding(token_ids)
        packed = pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=True
        )
        # GRU returns (output, h_n) - no cell state
        packed_output, h_n = self.gru(packed)
        if self.bidirectional:
            final_h = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        else:
            final_h = h_n[-1]
        return self.classifier(final_h)


# Compare parameter counts (SequenceClassifier is defined in the previous block)
lstm_model = SequenceClassifier(
    vocab_size=30000, embed_dim=128, hidden_size=256,
    num_classes=5, num_layers=2, bidirectional=True,
)
gru_model = GRUClassifier(
    vocab_size=30000, embed_dim=128, hidden_size=256,
    num_classes=5, num_layers=2, bidirectional=True,
)
lstm_params = sum(p.numel() for p in lstm_model.parameters())
gru_params = sum(p.numel() for p in gru_model.parameters())
print(f"LSTM parameters : {lstm_params:,}")
print(f"GRU parameters : {gru_params:,}")
print(f"Reduction : {(lstm_params - gru_params) / lstm_params:.1%}")
# Expected: LSTM 6,210,053 / GRU 5,618,181 / Reduction 9.5%.
# The recurrent layers alone shrink by exactly 25% (3 gate blocks instead of 4);
# the overall reduction is smaller because both models share the same
# 3.84M-parameter embedding table.
```
Production Engineering Notes
Gradient Clipping is Non-Negotiable
Even though LSTM largely solved vanishing gradients, it did not solve exploding gradients. The forget gate can allow gradients to accumulate over many steps, and on long sequences with large learning rates, gradient norms can spike dramatically. Always clip gradients when training LSTMs.
```python
# After loss.backward(), before optimizer.step():
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```
Gradient norm clipping was popularized for RNNs by Pascanu et al. (2013), "On the difficulty of training recurrent neural networks." max_norm=1.0 is a common starting point, and values between 0.5 and 5.0 are typical depending on the task. If your training loss shows periodic spikes, gradient norms are likely the cause - log them to verify.
```python
# How to monitor gradient norm during training (after loss.backward())
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        total_norm += p.grad.detach().norm(2).item() ** 2
total_norm = total_norm ** 0.5
# Log total_norm to your experiment tracker before clipping
```
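`torch.nn.utils.clip_grad_norm_` also returns the total norm it measured before clipping, so you can log that value instead of computing it separately. The sketch below uses an arbitrary toy LSTM and a deliberately inflated loss (both assumptions for illustration) to check that the returned value matches the manual loop and that the norm is at most `max_norm` afterward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)  # toy model
x = torch.randn(4, 30, 8)
output, _ = model(x)
loss = 100.0 * (output ** 2).sum()  # inflated on purpose so clipping kicks in
loss.backward()


def grad_norm(params) -> float:
    """Global L2 norm over all gradients, same computation as the monitoring loop."""
    return sum(p.grad.detach().norm(2).item() ** 2 for p in params if p.grad is not None) ** 0.5


manual = grad_norm(model.parameters())
# clip_grad_norm_ returns the total norm measured *before* clipping
returned = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
after = grad_norm(model.parameters())
print(f"manual={manual:.4f}  returned={returned:.4f}  after clipping={after:.4f}")
```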
Forget Gate Bias Initialization
Initialize the forget gate bias to 1.0, not 0.0. This is critical for early training stability. A bias of 0.0 means $f_t \approx \sigma(0) = 0.5$ at initialization, so the network starts by halving every cell state dimension on every step. A bias of 1.0 means $f_t \approx \sigma(1) \approx 0.73$, biasing toward remembering. The network will learn to adjust from there. Jozefowicz et al. (2015), "An Empirical Exploration of Recurrent Network Architectures," validated this empirically across a range of tasks.
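The difference compounds over time. As a rough illustration, assume the weighted term $W_f \cdot [h_{t-1}, x_t]$ is zero, so the bias alone sets $f_t$; then a fraction $f_t^{k}$ of a cell state value survives $k$ steps. This small Python sketch computes that fraction after 10 steps: about 0.001 with bias 0.0 versus about 0.04 with bias 1.0.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


STEPS = 10
for bias in (0.0, 1.0):
    # Assumption: W_f · [h_{t-1}, x_t] is zero, so the bias alone sets f_t
    f = sigmoid(bias)
    print(f"bias={bias}: f_t={f:.3f}, fraction surviving {STEPS} steps={f ** STEPS:.4f}")
```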
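PyTorch's nn.LSTM does not do this for you: it initializes all weights and biases from $\mathcal{U}(-\sqrt{k}, \sqrt{k})$ with $k = 1/\text{hidden\_size}$, so forget gate biases start near 0. PyTorch also adds two bias vectors (bias_ih and bias_hh) inside every gate, so the effective forget bias is their sum. Fix this after instantiation: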
```python
import torch
import torch.nn as nn


def init_lstm_forget_bias(lstm: nn.LSTM, value: float = 1.0) -> None:
    """
    Set the effective LSTM forget gate bias (bias_ih + bias_hh) to `value`.
    PyTorch stacks each bias vector in four chunks: [b_i, b_f, b_g, b_o]
    (input, forget, cell, output). The forget gate is the second chunk:
    indices [hidden_size : 2*hidden_size].
    Because the two biases are summed, `value` goes into bias_ih and the
    forget chunk of bias_hh is zeroed; filling both would give 2 * value.
    """
    for name, param in lstm.named_parameters():
        if name.startswith("bias"):
            n = param.shape[0]
            start = n // 4  # forget gate starts at 1/4
            end = n // 2  # forget gate ends at 2/4
            with torch.no_grad():
                param[start:end].fill_(value if name.startswith("bias_ih") else 0.0)


lstm = nn.LSTM(input_size=128, hidden_size=256, num_layers=2, batch_first=True)
init_lstm_forget_bias(lstm, value=1.0)
```
When LSTMs Beat Transformers in Production
Streaming and online inference: Transformers with full attention over previous context require all previous tokens (or their cached keys and values) to be available. For real-time streaming tasks - speech recognition frame by frame, live sensor data, trading systems - the LSTM processes each new input in $O(1)$ time and memory with respect to context length. The Transformer requires $O(n)$ memory and computation per step for context length $n$.
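A minimal PyTorch sketch of streaming inference, assuming a hypothetical stream of 50 feature frames and an untrained LSTM: each call consumes a single frame and carries forward only the fixed-size `(h, c)` state, yet the outputs match running the whole sequence at once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
frames = torch.randn(1, 50, 8)  # hypothetical stream of 50 feature frames

with torch.no_grad():
    full_output, _ = lstm(frames)  # offline: whole sequence at once

    state = None  # nn.LSTM treats a missing state as zeros
    streamed = []
    for t in range(frames.shape[1]):
        # One frame in; the state stays (1, 1, 16) no matter how long the stream gets
        out, state = lstm(frames[:, t : t + 1, :], state)
        streamed.append(out)
    streamed = torch.cat(streamed, dim=1)

print(torch.allclose(streamed, full_output, atol=1e-5))  # True
```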
Edge and on-device deployment: Transformer models are large. LSTM inference requires a fixed, small amount of memory per step regardless of context length, which is why Apple's on-device Siri reportedly uses LSTM-based acoustic models (as of 2023 reports). On microcontrollers with around 4MB of RAM, small recurrent models such as LSTMs and GRUs are often the only viable sequence architectures.
Very long sequences with limited compute: The quadratic attention complexity of Transformers means that processing a 10,000-token document requires 100x more attention computation than a 1,000-token document. An LSTM processes it in $O(n)$ time, linear in the sequence length. For tasks like whole-genome sequence analysis or multi-hour audio recordings, this difference is decisive.
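The scaling argument is simple arithmetic. This sketch makes two simplifying assumptions: attention cost is proportional to the number of query-key pairs ($n^2$), and recurrent cost is proportional to the number of steps ($n$). Under those assumptions, the relative cost of a longer document is:

```python
def relative_cost(n_long: int, n_short: int) -> dict[str, float]:
    """Cost ratio of a long vs short sequence under n^2 (attention) and n (recurrence) scaling."""
    ratio = n_long / n_short
    return {"attention": ratio ** 2, "recurrent": ratio}


print(relative_cost(10_000, 1_000))  # {'attention': 100.0, 'recurrent': 10.0}
```

This counts only the attention score computation. A Transformer's per-token feed-forward layers scale linearly, like the LSTM.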
Low-latency requirements with tight SLAs: A two-layer bidirectional LSTM on a CPU can process short sequences with sub-millisecond latency. A Transformer of comparable accuracy is typically 5-20x slower on CPU because its attention mechanism does not benefit from simple loop-based optimization the way LSTM does.
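Latency figures like these depend heavily on model size, thread count, and hardware, so they are worth measuring on the target machine. Below is a minimal CPU timing sketch, assuming a batch of one. median_latency_ms is a hypothetical helper, and the model sizes are placeholders.

```python
import statistics
import time

import torch
import torch.nn as nn


@torch.no_grad()
def median_latency_ms(model: nn.Module, x: torch.Tensor,
                      warmup: int = 10, runs: int = 100) -> float:
    """Median wall-clock time of model(x) in milliseconds (hypothetical helper)."""
    model.eval()
    for _ in range(warmup):          # let allocators and thread pools settle
        model(x)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        model(x)
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(timings)


lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(1, 32, 64)           # one short sequence of 32 steps
print(f"BiLSTM median latency: {median_latency_ms(lstm, x):.3f} ms")
```

For a like-for-like comparison, the same helper can time a Transformer encoder on an input of the same shape.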
Layer Count and Stacking
A two-layer stack is the most common production LSTM configuration. Single-layer LSTMs underfit on complex tasks. Three or more layers rarely provide meaningful gains and slow training significantly. When three layers do help, add aggressive dropout (0.3-0.5) between layers to prevent overfitting.
Bidirectional vs Unidirectional
Bidirectional LSTMs process the sequence forward and backward simultaneously, allowing each hidden state to incorporate context from both past and future tokens. For classification and labeling tasks (sentiment analysis, NER, POS tagging), bidirectional almost always outperforms unidirectional. For generation tasks (language modeling, next-step prediction), bidirectional cannot be used - there is no future context available at inference time.
Common Mistakes
:::danger Forgetting to sort sequences before pack_padded_sequence
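pack_padded_sequence with enforce_sorted=True (the default) requires the batch to be sorted by sequence length in descending order. If the lengths are not sorted, PyTorch raises a RuntimeError. The more dangerous failure is silent: you sort the lengths but forget to reorder the inputs or labels the same way, and the model trains on mismatched data. Reorder every batch-aligned tensor with the same index before packing:

```python
# Correct pattern - sort the lengths, then reorder every batch-aligned tensor identically
lengths, sort_idx = lengths.sort(descending=True)
token_ids = token_ids[sort_idx]
labels = labels[sort_idx]
```

To avoid the sorting requirement entirely, pass enforce_sorted=False. PyTorch then sorts internally and returns h_n, c_n, and the output of pad_packed_sequence in the original batch order. This is slightly slower but eliminates the bug class entirely.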
:::
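Here is a small check of the enforce_sorted=False route. It packs an unsorted batch, then compares each sequence's final hidden state against a run of that sequence alone with no padding. It assumes batch_first=True and a lengths tensor on the CPU, which is where pack_padded_sequence expects a tensor of lengths to live.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)

lengths = torch.tensor([2, 4, 1])            # deliberately NOT sorted
padded = torch.zeros(3, 4, 3)                # (batch, max_len, input_size)
for b, n in enumerate(lengths.tolist()):
    padded[b, :n] = torch.randn(n, 3)        # real tokens; the rest stays padding

packed = pack_padded_sequence(padded, lengths, batch_first=True,
                              enforce_sorted=False)
_, (h_n, _) = lstm(packed)                   # h_n comes back in ORIGINAL batch order

# Reference: run each sequence on its own, without any padding
reference = torch.stack([
    lstm(padded[b:b + 1, :n])[1][0][-1, 0]   # top-layer h after the last real token
    for b, n in enumerate(lengths.tolist())
])
```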
:::danger Processing padding tokens without packed sequences
If you call lstm(padded_input) without packing, the LSTM processes every padding token as a real input. On sequences padded to length 512 with actual lengths of 20-30 tokens, you are spending over 90% of your LSTM computation on padding. Worse, the LSTM's hidden state is updated by these padding inputs, and the hidden state you extract at output[:, -1, :] is the state after processing hundreds of padding tokens - not after the actual last real token.
Always use pack_padded_sequence in production, or extract the hidden state at the correct position:

```python
import torch


def get_last_real_hidden(
    output: torch.Tensor,   # (batch, seq_len, hidden_size), batch_first=True
    lengths: torch.Tensor,  # (batch,) - actual sequence lengths
) -> torch.Tensor:          # (batch, hidden_size)
    """Extract the hidden state at the last REAL token, not the last padded position.

    Assumes a unidirectional LSTM; in a bidirectional one, the backward
    direction's summary of the sequence sits at position 0 instead.
    """
    batch_size, _, hidden_size = output.shape
    # lengths - 1 gives the 0-indexed position of the last real token
    idx = lengths.to(device=output.device, dtype=torch.long) - 1
    idx = idx.view(batch_size, 1, 1).expand(batch_size, 1, hidden_size)
    # gather along dim 1 selects output[b, idx[b], :] for each batch element b
    return output.gather(1, idx).squeeze(1)
```
:::
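For a unidirectional LSTM, the two approaches should agree. Gathering the padded output at lengths - 1, as get_last_real_hidden does, gives the same vectors as the top layer of h_n from a packed run. The sketch restates the gather inline so it runs on its own, and it also shows what naive indexing returns after pad_packed_sequence. Sizes are arbitrary.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=6, hidden_size=10, num_layers=2, batch_first=True)
x = torch.randn(4, 7, 6)                   # (batch, max_len, input_size)
lengths = torch.tensor([7, 5, 3, 2])       # sorted descending, so the default is fine

packed_out, (h_n, _) = lstm(pack_padded_sequence(x, lengths, batch_first=True))
output, _ = pad_packed_sequence(packed_out, batch_first=True)  # (4, 7, 10)

# Same gather as get_last_real_hidden: pick output[b, lengths[b] - 1, :]
idx = (lengths - 1).view(-1, 1, 1).expand(-1, 1, output.shape[2])
last_real = output.gather(1, idx).squeeze(1)   # (4, 10)

# Naive indexing of the last padded position: zeros for every shorter sequence
naive = output[:, -1, :]
```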
:::warning Not applying gradient clipping
Training an LSTM without gradient clipping often results in loss spikes, NaN losses, or training divergence - especially on long sequences or with large learning rates. Call torch.nn.utils.clip_grad_norm_ after loss.backward() and before every optimizer.step(). It takes one line and prevents hours of debugging.
:::
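The following minimal training step shows where clipping goes: after backward() and before step(). The model, loss, and max_norm=1.0 are placeholder choices for this sketch, not tuned recommendations. clip_grad_norm_ returns the total gradient norm measured before clipping, which is the value worth logging.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)


def train_step(x: torch.Tensor, y: torch.Tensor, max_norm: float = 1.0) -> float:
    optimizer.zero_grad()
    output, _ = lstm(x)
    pred = head(output[:, -1, :]).squeeze(-1)   # last step of an unpadded batch
    loss = nn.functional.mse_loss(pred, y)
    loss.backward()
    # Rescale all gradients so their combined L2 norm is at most max_norm
    total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
    optimizer.step()
    return float(total_norm)                     # pre-clipping norm, for logging


x, y = torch.randn(4, 50, 8), torch.randn(4)
print(f"grad norm before clipping: {train_step(x, y):.4f}")
```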
:::warning Using the default forget gate bias initialization
PyTorch does not special-case the forget gate. Like every other LSTM weight and bias, both forget-gate bias vectors (bias_ih and bias_hh) are drawn from a small uniform distribution centered on 0. An effective forget gate bias near zero means the network initially discards roughly 50% of its memory at every step, making early gradient signals for long-range dependencies very weak. Initialize the effective forget gate bias to 1.0 with init_lstm_forget_bias, shown in the Forget Gate Bias Initialization section above.
:::
:::warning Using bidirectional LSTM for autoregressive generation
A bidirectional LSTM cannot be used for tasks that generate sequences step by step - language modeling, text generation, time-series forecasting. The backward direction requires future tokens, which do not exist at generation time. Using bidirectional for generation is a shape-compatible error that PyTorch will not catch. Under teacher forcing, where the full target sequence is provided, the model will appear to train very well: the backward direction sees the very tokens it is supposed to predict, so the training loss reflects leaked labels. The bug only manifests at inference time, when there are no future tokens from which to compute the backward hidden states.
:::
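The leak is easy to demonstrate. Change only the final token of a sequence and see which earlier outputs move. In a unidirectional LSTM, the output at position 0 is untouched. In a bidirectional one, the backward half of the output at position 0 changes, and that is exactly the future information a next-token objective must not see. Sizes and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 8
uni = nn.LSTM(input_size=4, hidden_size=H, batch_first=True)
bi = nn.LSTM(input_size=4, hidden_size=H, batch_first=True, bidirectional=True)

x = torch.randn(1, 5, 4)
x_changed = x.clone()
x_changed[0, -1] += 1.0                       # perturb only the LAST token

with torch.no_grad():
    uni_delta = (uni(x)[0] - uni(x_changed)[0])[0, 0].abs().max()
    bi_diff = bi(x)[0] - bi(x_changed)[0]     # (1, 5, 2 * H)
    fwd_delta = bi_diff[0, 0, :H].abs().max() # forward half at position 0
    bwd_delta = bi_diff[0, 0, H:].abs().max() # backward half at position 0

print(uni_delta.item(), fwd_delta.item(), bwd_delta.item())
```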
:::warning Stacking too many LSTM layers without dropout
Three or more LSTM layers overfit aggressively on datasets smaller than roughly 100K sequences. Each additional layer adds parameters and depth, and nothing counteracts that extra capacity unless you add dropout between layers. PyTorch's nn.LSTM dropout parameter applies dropout only between layers, not after the final layer - add an explicit nn.Dropout after the LSTM output if you need it on the final layer's output.
:::
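The sketch below shows a stacked encoder that applies dropout in both places: between LSTM layers via nn.LSTM's dropout argument, and on the final layer's output via an explicit nn.Dropout. The class name and the 0.4 rate are illustrative choices, not values from a specific benchmark.

```python
import torch
import torch.nn as nn


class StackedLSTMEncoder(nn.Module):
    """Hypothetical encoder: stacked LSTM with inter-layer and output dropout."""

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 3, p: float = 0.4):
        super().__init__()
        # nn.LSTM's `dropout` acts on the outputs of every layer EXCEPT the last;
        # it is meaningless (and warns) with a single layer, hence the guard.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
                            dropout=p if num_layers > 1 else 0.0, batch_first=True)
        # Covers the final layer's output, which nn.LSTM's dropout never touches
        self.out_dropout = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(x)
        return self.out_dropout(output)       # (batch, seq_len, hidden_size)


enc = StackedLSTMEncoder(input_size=32, hidden_size=64)
y = enc(torch.randn(2, 10, 32))
```

Calling enc.eval() at inference time disables both dropout sites at once.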
Interview Q&A
Q1: Explain exactly why LSTM solves the vanishing gradient problem that vanilla RNNs have.
The vanishing gradient problem in vanilla RNNs arises because backpropagation through time requires multiplying together a chain of Jacobians - one per time step. Each Jacobian involves the derivative of the tanh activation composed with the weight matrix. The spectral norm of each Jacobian tends to be less than 1, so the product of many such terms decays exponentially to zero.
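LSTM introduces the cell state, which has a fundamentally different gradient dynamic. The cell state update is:

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

Along this direct path, the Jacobian of $c_t$ with respect to $c_{t-1}$ is simply $\operatorname{diag}(f_t)$ - the forget gate vector. There is no Jacobian of a nonlinearity applied to the entire state vector, and no full weight matrix. The gates themselves depend on $h_{t-1}$, which opens additional indirect paths, but the direct path is the one that carries long-range gradient. The gradient flowing backward through the cell state is multiplied elementwise by the forget gate at each step. If the network learns $f_t \approx 1$ for relevant dimensions, the gradient flows back with near-constant magnitude. This is the "constant error carousel" of Hochreiter and Schmidhuber's design: in the original 1997 LSTM, the cell's self-connection had a fixed weight of 1.0, and the forget gate (added by Gers, Schmidhuber, and Cummins in 2000) makes that weight learnable.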
The key distinction: vanilla RNNs have multiplicative interactions with nonlinearities at every step. LSTM's cell state has additive interactions, and the multiplicative factor (the forget gate) is learned to be close to 1 when information needs to be preserved over long ranges.
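The contrast can be checked numerically with autograd. This is a deliberately simplified sketch, not a full LSTM. The vanilla RNN path is $h_t = \tanh(W h_{t-1})$ with $W$ rescaled to spectral norm 0.9. The LSTM path keeps only the direct cell-state route, with a constant forget gate of 0.99 and a constant stand-in for the $i_t \odot \tilde{c}_t$ write.

```python
import torch

torch.manual_seed(0)
T, H = 100, 16

# Vanilla RNN path: h_t = tanh(W h_{t-1}), with W rescaled to spectral norm 0.9
W = torch.randn(H, H)
W = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)
h0 = torch.randn(H, requires_grad=True)
h = h0
for _ in range(T):
    h = torch.tanh(W @ h)
h.sum().backward()
rnn_grad_norm = h0.grad.norm().item()         # bounded above by 4 * 0.9**100

# LSTM direct cell-state path: c_t = f * c_{t-1} + write, with f held at 0.99
f = torch.full((H,), 0.99)
write = torch.full((H,), 0.1)                 # stand-in for i_t * c~_t, held constant
c0 = torch.randn(H, requires_grad=True)
c = c0
for _ in range(T):
    c = f * c + write
c.sum().backward()
lstm_grad = c0.grad                           # every entry is 0.99**100

print(f"RNN  |dL/dh_0| = {rnn_grad_norm:.2e}")
print(f"LSTM dL/dc_0 per entry = {lstm_grad[0].item():.3f}")  # 0.366
```

Each RNN step multiplies the gradient by a matrix of norm at most 0.9, so 100 steps shrink it below $0.9^{100} \approx 2.7 \times 10^{-5}$ per unit. The cell-state path loses only the forget-gate factor.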
Q2: What is the difference between the hidden state and the cell state in an LSTM? Why do we need both?
The cell state is the long-term memory. It is designed to flow through time with minimal interaction - only scaled by the forget gate and incremented by the gated candidate. It does not directly feed into any nonlinearity that would squash its gradients over long distances.
The hidden state is the short-term working output. It is computed as $h_t = o_t \odot \tanh(c_t)$ - a filtered, squashed version of the cell state. It is what gets passed to the next layer and to any downstream network (classifier, decoder, etc.).
The reason for having both is that you need two distinct behaviors: a gradient highway for learning long-range dependencies (cell state), and a rich expressive output at each step (hidden state). The output gate controls which aspects of the cell state are surfaced as hidden state output at each step. This means the network can maintain information in the cell state without necessarily outputting it as the hidden state - it can "hold" information until it is contextually relevant to surface.
In GRU, these two are merged into a single hidden state. GRU achieves similar long-range behavior through its update gate, which interpolates between the previous hidden state $h_{t-1}$ and a new candidate state. The $h_{t-1}$ term is scaled only by the gate, which allows linear flow-through. The tradeoff is that GRU has less independent control over what is stored versus what is output at each step.
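One way to see the two states side by side is to recompute a single LSTM step by hand from an nn.LSTMCell's own parameters. This sketch assumes PyTorch's documented chunk order for the stacked gate weights, [input, forget, cell candidate, output]. It checks that the cell state is the additive update and that the hidden state is the output-gated tanh of the new cell state.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=5, hidden_size=7)
x = torch.randn(3, 5)
h_prev, c_prev = torch.randn(3, 7), torch.randn(3, 7)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h_prev, c_prev))

    gates = (x @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)                 # PyTorch's gate order
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)

    c = f * c_prev + i * g      # cell state: additive, gated update
    h = o * torch.tanh(c)       # hidden state: output-gated, squashed view of c
```

Because $o_t < 1$ and $|\tanh| < 1$, every entry of h lies strictly inside (-1, 1). The cell state c has no such bound.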
Q3: When would you choose a Transformer over an LSTM, and when would you go the other direction?
Choose a Transformer when:
- You have large amounts of training data (tens of thousands to millions of labeled examples)
- The task benefits from full bidirectional context (BERT-style classification)
- Sequence lengths are moderate (up to a few thousand tokens)
- You have GPU or TPU resources for both training and inference
- Parallelism during training is important - Transformer training is highly parallelizable; LSTM is sequential
Choose LSTM when:
- You are deploying on edge hardware with strict memory constraints
- You need streaming inference where each new input must be processed in $O(1)$ time and memory, without a growing KV cache
- Your sequences are extremely long (tens of thousands of tokens) where attention cost is prohibitive
- Your dataset is small (fewer than roughly 10K examples) where Transformer overparameterization hurts generalization
- Your task is inherently autoregressive and low-latency (real-time speech, sensor processing, embedded systems)
A practical heuristic from production experience: if you have GPU inference budget and more than 50K labeled examples, start with a pre-trained Transformer. If you need sub-millisecond CPU inference or are processing streams, start with LSTM. The performance gap between the two on many classification tasks is 2-5%, which is often not worth the 10x inference cost of a Transformer.
Q4: What is the purpose of the forget gate bias initialization to 1.0, and what happens if you skip it?
Without the forget gate bias initialized to 1.0, the forget gate starts near $\sigma(0) = 0.5$ - the network begins by discarding approximately half of its cell state at every step. This means that in early training, when the network is trying to learn which long-range signals matter, the gradient signal from a dependency $k$ steps back is attenuated along the cell-state path by a factor of roughly $0.5^k$. On a sequence of 100 steps, that is $0.5^{100} \approx 7.9 \times 10^{-31}$ - effectively zero.
Initializing to 1.0 means the forget gate starts near $\sigma(1) \approx 0.73$. The network begins with a bias toward remembering, and gradient signals from earlier time steps are attenuated far less: after 10 steps, $0.73^{10} \approx 0.04$ versus $0.5^{10} \approx 0.001$. That early signal is what lets the network learn to push the forget gate toward 1 on dimensions that must persist, and to forget selectively elsewhere.
In practice, training without this initialization often leads to models that fail to learn any long-range dependencies, producing behavior similar to a short-memory vanilla RNN despite being an LSTM. The model will still converge on local patterns and achieve reasonable short-range accuracy - making the failure mode hard to detect without examining gradient norms at different sequence distances. Jozefowicz et al. (2015) validated this initialization across multiple benchmarks and found consistent improvements.
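The attenuation factors above are easy to tabulate. This sketch assumes a constant forget-gate value equal to the sigmoid of the bias, ignoring the input-dependent part of the pre-activation. It reports how much of a cell-state signal survives $k$ steps back along the cell path.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def retention(bias: float, steps: int) -> float:
    """Fraction of a cell-state signal surviving `steps` multiplications by sigmoid(bias)."""
    return sigmoid(bias) ** steps


for bias in (0.0, 1.0):
    row = ", ".join(f"k={k}: {retention(bias, k):.1e}" for k in (10, 50, 100))
    print(f"bias={bias}: f={sigmoid(bias):.2f} | {row}")
```

A bias of 1.0 does not make long-range memory free. It gives the gradient a head start of $(0.73/0.5)^k$ that training then builds on by pushing relevant forget gates closer to 1.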
Q5: Explain packed sequences in PyTorch - what problem they solve and how they work mechanically.
In a training batch, sequences typically have different lengths. The standard approach is to pad shorter sequences with a special padding token so all sequences in the batch have the same length. This enables tensor operations on the batch.
The problem: if you pass the padded batch directly to nn.LSTM, the LSTM processes every padding token as if it were a real input. This wastes computation proportional to the amount of padding, and - more critically - it corrupts the hidden state. After processing padding tokens, the LSTM's hidden state reflects the network's response to zeros, not the end of the real sequence. If you extract h_n (the final hidden state), it is the state after processing all the padding, not after the last real token.
Packed sequences solve this by storing only the real tokens. The batch is sorted by length in descending order, and the real tokens are laid out time step by time step in a single flat tensor. PyTorch's PackedSequence format stores two things: (1) the flat data tensor containing all real tokens across all timesteps, and (2) a batch_sizes tensor indicating how many real tokens exist at each timestep across the batch.
The LSTM processes this packed format efficiently. At each timestep, only the sequences that have not yet ended participate in the computation. The result is mechanically correct - the final hidden state reflects the network state after the last real token in each sequence, not after padding.
Mechanically: pack_padded_sequence takes a padded tensor of shape (batch, max_len, input_size) and a lengths vector, and reorders and flattens the data. pad_packed_sequence takes the packed output and reconstructs a padded tensor, which you need if you want per-step outputs for tasks like sequence labeling or named entity recognition.
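The packed layout is easiest to see on a tiny batch. In this sketch, each real token is a 1-dimensional feature equal to a distinct integer, so the ordering of packed.data can be read directly. The lengths 3, 2, 1 are already sorted in descending order.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Three sequences padded to length 3; 0 marks padding. Shape (batch, max_len, 1)
padded = torch.tensor([[[1.], [2.], [3.]],
                       [[4.], [5.], [0.]],
                       [[6.], [0.], [0.]]])
lengths = torch.tensor([3, 2, 1])

packed = pack_padded_sequence(padded, lengths, batch_first=True)
print(packed.data.flatten().tolist())   # [1.0, 4.0, 6.0, 2.0, 5.0, 3.0]  (time-major)
print(packed.batch_sizes.tolist())      # [3, 2, 1]  sequences active at t = 0, 1, 2

unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
```

No padding value appears in packed.data. pad_packed_sequence restores the original padded tensor and the lengths.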
Q6: How does a bidirectional LSTM work, and for what tasks is it appropriate?
A bidirectional LSTM runs two independent LSTM networks over the sequence: one processes tokens from left to right (forward), the other from right to left (backward). At each time step $t$, the bidirectional LSTM produces a hidden state that is the concatenation of the forward LSTM's $\overrightarrow{h}_t$ (which has seen tokens $0$ through $t$) and the backward LSTM's $\overleftarrow{h}_t$ (which has seen tokens $t$ through $T-1$), where $T$ is the sequence length.
In PyTorch, this is handled transparently by setting bidirectional=True. The output shape becomes (batch, seq_len, 2 * hidden_size) and h_n has shape (num_layers * 2, batch, hidden_size) where even indices are forward layers and odd indices are backward layers.
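The shapes and the direction layout of h_n can be checked directly. This sketch uses an unpadded batch, so the forward direction's final state lines up with the last time step and the backward direction's final state lines up with the first. The view call follows the (num_layers, num_directions, batch, hidden_size) separation described in the PyTorch nn.LSTM documentation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_layers, H, batch = 2, 16, 3
bilstm = nn.LSTM(input_size=8, hidden_size=H, num_layers=num_layers,
                 batch_first=True, bidirectional=True)
x = torch.randn(batch, 12, 8)                  # (batch, seq_len, input_size)
output, (h_n, c_n) = bilstm(x)

print(output.shape)   # torch.Size([3, 12, 32])  -> last dim is 2 * hidden_size
print(h_n.shape)      # torch.Size([4, 3, 16])   -> num_layers * 2, batch, hidden

h_by_dir = h_n.view(num_layers, 2, batch, H)   # [layer, direction, batch, hidden]
fwd_last = h_by_dir[-1, 0]                     # top layer, forward: ends at t = T-1
bwd_last = h_by_dir[-1, 1]                     # top layer, backward: ends at t = 0
```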
The result: every position in the sequence has access to the full left and right context at that position. For tasks like named entity recognition, part-of-speech tagging, or sentence classification, this full context dramatically improves performance - the model knows what comes before and after every word simultaneously.
The constraint: bidirectional LSTM cannot be used for autoregressive generation (language modeling, time series forecasting step by step) because the backward pass requires future tokens that are not available at inference time. It is appropriate for encoder-side tasks (encoding a full sequence into a representation) but not for decoder-side tasks (generating sequence elements one at a time). Confusing these use cases is one of the most common mistakes engineers make when transitioning from Transformer encoders to recurrent architectures.
Summary
LSTM and GRU are not historical curiosities. They are production-grade tools that remain the right choice for a specific set of constraints: streaming inference, edge deployment, very long sequences, and small datasets where Transformer overparameterization is a liability.
The architectural insight at the heart of LSTM - separating memory (cell state) from output (hidden state) and connecting them through additive updates with learned gating - is one of the most elegant solutions in the history of neural architecture design. Hochreiter and Schmidhuber's 1997 paper spent a decade waiting for the field to catch up. When it did, LSTM powered the transformation of speech recognition, machine translation, and NLP from academic curiosities into products used by billions of people.
The four gates - forget, input, candidate, output - each serve a distinct purpose. The forget gate clears stale memory. The input gate controls what new information gets written. The candidate provides the raw new information. The output gate controls what portion of the memory is exposed as the working state. Understanding each gate's purpose, not just its equation, is what separates engineers who deploy LSTMs correctly from engineers who tune hyperparameters randomly until training converges.
GRU simplifies this to two gates and merges cell and hidden state. It achieves comparable performance with fewer parameters on most tasks, making it the practical default when LSTM's additional expressiveness is not clearly needed.
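The parameter difference follows from the gate count. Per layer, PyTorch's LSTM stacks four gate blocks and its GRU three, each with an input matrix, a recurrent matrix, and two bias vectors. Here is a quick check with the same (arbitrarily chosen) sizes:

```python
import torch.nn as nn


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


input_size, hidden_size, num_layers = 128, 256, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
gru = nn.GRU(input_size, hidden_size, num_layers=num_layers)

n_lstm, n_gru = count_params(lstm), count_params(gru)
print(n_lstm, n_gru, n_gru / n_lstm)   # 921600 691200 0.75
```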
In production: always clip gradients, always initialize forget gate biases to 1.0, always use packed sequences for variable-length inputs, and always verify that your "final hidden state" is actually extracted from the last real token - not the last padding token. These four practices prevent the most common and the most expensive LSTM bugs in production systems.
