Feed-Forward Layers
Reading time: ~30 min · Interview relevance: High · Target roles: ML Engineer, AI Engineer, Research Engineer
The Hidden Engineer
In 2021, researchers at Tel Aviv University and the Allen Institute for AI were trying to understand where knowledge lives in a transformer language model. The question was simple: when a model like GPT-2 correctly completes "The capital of France is ___", where in the network is the fact "Paris" stored?
The intuitive answer was attention, since attention is the mechanism that retrieves information by matching queries against stored representations. But the evidence pointed elsewhere. When researchers analyzed what the feed-forward layers respond to, and when later work used activation patching to trace factual recall in GPT-style models (Meng et al., 2022), they found that factual associations were not concentrated in the attention layers. They were predominantly in the feed-forward layers.
Geva et al. (2021) published "Transformer Feed-Forward Layers Are Key-Value Memories" and showed that each FFN neuron can be read as a key-value pair. The first layer's weights act as "keys" that detect patterns in the input. The second layer's weights act as "values" that, when their key fires, push the output toward the associated tokens. Attention contextualizes; the FFN memorizes.
This was a significant reframing. Most explanations of transformers focus on attention as the interesting part and treat the FFN as "just an MLP between attention layers." In reality, the FFN appears to be where much of the model's factual knowledge lives. When you fine-tune a model to update its knowledge, the FFN weights are a primary site of the change. When a model hallucinates a fact, one plausible mechanism is that the wrong FFN neurons fired.
Understanding the feed-forward layer is not just about architecture. It's about understanding where an LLM's knowledge actually lives.
Why Feed-Forward Layers Exist
After multi-head attention, each token has a contextualized representation: it has "seen" the tokens it can attend to and aggregated information from them. But this representation is still a weighted sum of value vectors. The softmax makes the mixing weights a nonlinear function of the input, yet the output is linear in the values, and attention applies no elementwise nonlinearity to the mixed features.
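To give the model elementwise nonlinear processing and the ability to compute complex functions of each token's representation, each transformer block includes a position-wise feed-forward network. This is a two-layer MLP applied independently to each position:

$$\text{FFN}(x) = \max(0,\, xW_1 + b_1)\,W_2 + b_2, \qquad W_1 \in \mathbb{R}^{d_{model} \times d_{ff}},\quad W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$$

The original paper uses ReLU activation (the $\max(0, \cdot)$ above); most modern models use GeLU or SwiGLU instead.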
"Position-wise" means the same FFN (same weights) is applied to each position independently. There is no mixing across positions in the FFN - that's attention's job. The FFN processes each token's representation individually.
The Expansion Factor: Why 4x?
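The FFN in the original transformer expands from $d_{model} = 512$ to $d_{ff} = 2048$ (4×) and then projects back:

$$\mathbb{R}^{512} \xrightarrow{\;W_1\;} \mathbb{R}^{2048} \xrightarrow{\;\text{ReLU}\;} \mathbb{R}^{2048} \xrightarrow{\;W_2\;} \mathbb{R}^{512}$$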
Why expand? The nonlinear activation (ReLU) is what gives the FFN its expressive power, and each hidden unit computes one nonlinear feature of the input. The number of such features is set by the intermediate width $d_{ff}$, so expanding to a larger intermediate dimension and then contracting creates a "project up, activate, project back" pattern that:
- Increases representation capacity: More neurons = more patterns that can be stored
- Creates a bottleneck: The compression from $d_{ff}$ back to $d_{model}$ forces the model to distill information
- Enables selective activation: With ReLU, roughly half of the neurons are active for a given input at initialization, and trained models are often considerably sparser. Different inputs activate different subsets; this is how the FFN selectively retrieves knowledge
The 4× factor is a robust empirical default rather than a derived optimum. Too small ($d_{ff} \le 2\,d_{model}$) and the FFN becomes a bottleneck. Too large ($d_{ff} \ge 8\,d_{model}$ without other changes) and the extra parameters give diminishing returns compared with spending them elsewhere.
Real model FFN dimensions:
| Model | $d_{model}$ | $d_{ff}$ | Ratio | Activation |
|---|---|---|---|---|
| Transformer-base | 512 | 2,048 | 4× | ReLU |
| BERT-base | 768 | 3,072 | 4× | GeLU |
| GPT-2 | 768 | 3,072 | 4× | GeLU |
| GPT-3 | 12,288 | 49,152 | 4× | GeLU |
| LLaMA-2 7B | 4,096 | 11,008 | ~2.7× | SwiGLU |
| LLaMA-2 70B | 8,192 | 28,672 | 3.5× | SwiGLU |
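Note: LLaMA uses a smaller ratio because SwiGLU uses 3 weight matrices instead of 2, so the intermediate dimension is scaled by 2/3 to keep the parameter count close to a 4× FFN with 2 matrices: $3 \cdot \tfrac{8}{3} d_{model}^2 = 8\, d_{model}^2 = 2 \cdot 4\, d_{model}^2$. LLaMA-2 70B applies an extra multiplier, so its 3.5× FFN is larger than this parameter-matched default.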
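The sizing rule can be reproduced in a few lines. This is a minimal sketch assuming it mirrors the logic in Meta's LLaMA reference code (start from 4×, scale by 2/3, optionally apply a per-model multiplier, round up to a hardware-friendly multiple); the `multiple_of=4096` and `ffn_dim_multiplier=1.3` values for the 70B model are assumptions taken from its published config, and both function names are hypothetical.

```python
from typing import Optional


def llama_style_ffn_dim(d_model: int, multiple_of: int = 256,
                        ffn_dim_multiplier: Optional[float] = None) -> int:
    """Hypothetical helper: LLaMA-style SwiGLU hidden size (assumed rule)."""
    hidden = 4 * d_model                # start from the classic 4x width
    hidden = int(2 * hidden / 3)        # 3 matrices instead of 2 -> scale by 2/3
    if ffn_dim_multiplier is not None:  # optional per-model adjustment
        hidden = int(ffn_dim_multiplier * hidden)
    # Round UP to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def ffn_params(d_model: int, d_ff: int, n_matrices: int) -> int:
    """Weight count of a bias-free FFN with n_matrices projections."""
    return n_matrices * d_model * d_ff


d7 = llama_style_ffn_dim(4096)
d70 = llama_style_ffn_dim(8192, multiple_of=4096, ffn_dim_multiplier=1.3)
print(d7, d70)  # 11008 28672
# SwiGLU at 11008 vs. a 2-matrix FFN at 4x: within ~1% of each other
print(ffn_params(4096, d7, 3) / ffn_params(4096, 4 * 4096, 2))  # 1.0078125
```

Both results match the table above, and the ratio shows why the 7B model's "~2.7×" is effectively parameter-matched to a classic 4× FFN.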
The FFN as Key-Value Memory
Geva et al. (2021) formalized this view: treat the FFN as a collection of key-value memories.
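For a two-layer FFN with activation $f$ (biases omitted):

$$\text{FFN}(x) = f(x W_1)\, W_2 = \sum_{i=1}^{d_{ff}} f(x \cdot k_i)\, v_i$$

- Each column $k_i$ of $W_1$ is a key: a pattern in the input space. The activation $m_i = f(x \cdot k_i)$ tells you how strongly the input matches key $k_i$.
- Each row $v_i$ of $W_2$ is a value: the output associated with that key.
- The FFN output is a weighted sum of values, where the weights are the activation strengths $m_i$.
This means: when the input pattern-matches key $k_i$, the corresponding value $v_i$ (row $i$ of $W_2$) is added to the output with weight $m_i$.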
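The equivalence between the matrix form and the key-value form can be checked numerically. A minimal sketch with random weights and arbitrary toy sizes (no trained model involved): it computes the FFN output once as a matrix product and once as an explicit sum over memories, then compares them.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 16, 64
W1 = torch.randn(d_model, d_ff)  # column i is key k_i
W2 = torch.randn(d_ff, d_model)  # row i is value v_i
x = torch.randn(d_model)

# Matrix form (biases omitted, as in the key-value view)
ffn_out = F.relu(x @ W1) @ W2

# Memory form: m_i = relu(x . k_i); output = sum_i m_i * v_i
coeffs = torch.empty(d_ff)
memory_out = torch.zeros(d_model)
for i in range(d_ff):
    k_i, v_i = W1[:, i], W2[i]
    coeffs[i] = F.relu(x @ k_i)         # how strongly x matches key i
    memory_out += coeffs[i] * v_i       # add value i, weighted by the match

print(torch.allclose(ffn_out, memory_out, rtol=1e-4, atol=1e-4))  # True
print("active memories:", int((coeffs > 0).sum()), "of", d_ff)
print("strongest matches:", torch.topk(coeffs, k=3).indices.tolist())
```

With ReLU, memories whose key does not match (negative dot product) contribute exactly nothing, which is the "selective retrieval" described above.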
Geva et al. showed that specific FFN neurons consistently fire for specific, human-interpretable input patterns (shallower patterns in lower layers, more semantic ones in upper layers), and that their values promote plausible next tokens. For example, a neuron that fires when the input contains "Eiffel Tower" might have a corresponding value that promotes "Paris" and "France" in the output distribution.
This is why fact editing research focuses on the FFN: methods such as ROME and MEMIT change what a model "knows" about a fact by editing the FFN weights of specific layers, not the attention patterns.
Activation Functions: ReLU to SwiGLU
ReLU (Original Transformer)

$$\text{ReLU}(x) = \max(0, x)$$

Simple, fast, well-understood. At initialization, roughly half of the neurons are inactive (output 0) for a given input. This sparsity is a feature: it creates selectivity. The neurons that are inactive for input A may be active for input B.
Problem: "dying ReLU". A neuron whose pre-activation is negative for every input always outputs 0, receives zero gradient, and therefore never recovers or contributes anything. This can waste capacity.
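The zero-gradient mechanism behind dying ReLU is easy to reproduce. A minimal sketch, assuming a toy `nn.Linear` layer whose first bias is forced far negative: this is an artificial setup chosen only to make the effect deterministic, whereas in real training neurons drift into this regime through large updates.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
with torch.no_grad():
    # Artificially push neuron 0's pre-activation far below zero for typical inputs
    layer.bias[0] = -100.0

x = torch.randn(256, 4)
h = torch.relu(layer(x))
h.sum().backward()

print(bool((h[:, 0] == 0).all()))                # True: neuron 0 never fires on this batch
print(layer.weight.grad[0].abs().sum().item())   # 0.0: no gradient reaches its weights
print(layer.weight.grad[1:].abs().sum().item() > 0)  # True: live neurons still learn
```

Because ReLU's derivative is exactly 0 for negative inputs, nothing in backpropagation can move the dead neuron's weights back into the active region.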
GeLU (BERT, GPT-2, GPT-3)

$$\text{GeLU}(x) = x \cdot \Phi(x)$$

where $\Phi$ is the cumulative distribution function of the standard normal distribution.
GeLU smoothly approximates ReLU with a stochastic interpretation: it multiplies the input $x$ by the probability that a standard normal sample is less than $x$. This gives a smooth curve that approaches 0 for large negative inputs (dipping slightly below 0 just left of the origin) and is approximately linear for large positive inputs, with a smooth transition around 0.
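The formula can be checked directly against PyTorch. A minimal sketch: it writes $\Phi$ via the error function, compares it with `F.gelu`, and compares it with the tanh approximation (the `approximate="tanh"` argument assumes PyTorch 1.12 or later).

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=801)

# GeLU(x) = x * Phi(x), with the standard normal CDF written via erf
phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
gelu_manual = x * phi

print(torch.allclose(gelu_manual, F.gelu(x), atol=1e-6))                      # exact (erf) form
print(torch.allclose(gelu_manual, F.gelu(x, approximate="tanh"), atol=1e-3))  # tanh approximation
print(round(gelu_manual.min().item(), 3))  # -0.17: the small negative dip near x = -0.75
```

The tanh form is the variant used in the original GPT-2 code; it stays within about 1e-3 of the exact curve on this range.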
Empirically: GeLU tends to match or modestly outperform ReLU in transformers, though the gap is small and setup-dependent. Its smoothness is commonly credited with aiding gradient flow.
SwiGLU (LLaMA, PaLM)
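Noam Shazeer (2020) introduced a gated variant called SwiGLU (Swish-Gated Linear Unit):

$$\text{FFN}_{\text{SwiGLU}}(x) = \big(\text{Swish}(x W_{gate}) \odot x W_{up}\big)\, W_{down}$$

where $\text{Swish}(z) = z \cdot \sigma(z)$ (also called SiLU) and $\odot$ is element-wise multiplication.
Instead of one expansion matrix $W_1$, SwiGLU uses two expansion matrices $W_{gate}$ and $W_{up}$:
- $W_{gate}$ produces the "gate": after applying Swish, this controls how much of each feature passes through
- $W_{up}$ produces the "content": the actual values that get gated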
This is loosely analogous to attention within the FFN: the input-dependent gate decides which "memories" are relevant.
Why SwiGLU outperforms GeLU: At equal parameter count, SwiGLU achieves lower perplexity. Commonly offered explanations are better gradient flow through the gate and more selective, input-dependent activation. Shazeer's experiments, run on a T5-style model, showed consistent gains in pre-training perplexity and on most fine-tuning benchmarks.
Trade-off: SwiGLU uses 3 weight matrices ($W_{gate}$, $W_{up}$, $W_{down}$) vs ReLU/GeLU's 2 ($W_1$, $W_2$). For the same parameter budget, you must reduce the intermediate dimension to compensate, hence LLaMA's $d_{ff} \approx \tfrac{8}{3} d_{model}$ instead of $4\, d_{model}$.
```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FFNWithReLU(nn.Module):
    """Original transformer FFN - simple 2-layer MLP with ReLU."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.W2 = nn.Linear(d_ff, d_model)  # contract: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.W2(self.dropout(F.relu(self.W1(x))))


class FFNWithGeLU(nn.Module):
    """BERT/GPT-2 style FFN with GeLU activation."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff)
        self.W2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.W2(self.dropout(F.gelu(self.W1(x))))


class FFNWithSwiGLU(nn.Module):
    """
    LLaMA/PaLM style FFN with SwiGLU activation.
    Uses 3 matrices and typically d_ff = 8/3 * d_model (approx 2.7x)
    to match the parameter count of a standard 4x FFN.
    Reference: Shazeer (2020) "GLU Variants Improve Transformer"
    """

    def __init__(self, d_model: int, d_ff: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        # Default: 8/3 * d_model, rounded UP to a multiple of 256 for efficient matmuls
        if d_ff is None:
            d_ff = int(8 * d_model / 3)
            d_ff = (d_ff + 255) // 256 * 256
        # Gate projection (Swish is applied to its output)
        self.W_gate = nn.Linear(d_model, d_ff, bias=False)
        # Content projection
        self.W_up = nn.Linear(d_model, d_ff, bias=False)
        # Output projection
        self.W_down = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.W_gate(x))  # Swish activation (SiLU = Sigmoid Linear Unit = Swish)
        up = self.W_up(x)              # Content
        fused = gate * up              # Gating: element-wise multiplication
        return self.W_down(self.dropout(fused))


# Compare parameter counts
d_model = 4096
relu_ffn = FFNWithReLU(d_model, d_ff=4 * d_model)
gelu_ffn = FFNWithGeLU(d_model, d_ff=4 * d_model)
swiglu_ffn = FFNWithSwiGLU(d_model)  # d_ff defaults to 11008 (~2.7x)

relu_params = sum(p.numel() for p in relu_ffn.parameters())
gelu_params = sum(p.numel() for p in gelu_ffn.parameters())
swiglu_params = sum(p.numel() for p in swiglu_ffn.parameters())
print(f"ReLU FFN params (4x): {relu_params:,}")         # 2*4096*16384 weights + biases = 134,238,208
print(f"GeLU FFN params (4x): {gelu_params:,}")         # same as ReLU
print(f"SwiGLU FFN params (~2.7x): {swiglu_params:,}")  # 3*4096*11008 = 135,266,304, ~same total

# Test forward pass
batch, seq = 2, 32
x = torch.randn(batch, seq, d_model)
out_relu = relu_ffn(x)
out_gelu = gelu_ffn(x)
out_swiglu = swiglu_ffn(x)
print(f"\nOutput shapes (should all be {x.shape}):")
print(f"ReLU: {out_relu.shape}")
print(f"GeLU: {out_gelu.shape}")
print(f"SwiGLU: {out_swiglu.shape}")

# Activation sparsity check (what fraction of neurons output exactly 0?)
with torch.no_grad():
    x_test = torch.randn(100, d_model)
    # ReLU activations
    relu_acts = F.relu(relu_ffn.W1(x_test))
    sparsity = (relu_acts == 0).float().mean()
    print(f"\nReLU sparsity: {sparsity:.2%}")  # ~50% zeros with untrained (random) weights
    # GeLU activations (smooth, not truly sparse)
    gelu_acts = F.gelu(gelu_ffn.W1(x_test))
    near_zero = (gelu_acts.abs() < 0.01).float().mean()
    print(f"GeLU near-zero: {near_zero:.2%}")  # far fewer near-zero values
```
Mixture of Experts: Sparse FFN
Mixture of Experts (MoE) generalizes the FFN: instead of one FFN applied to every token, use $N$ "expert" FFNs and route each token to only $k$ of them (typically $k = 1$ or $2$, out of $N = 8$ or $64$).
The Mixtral 8×7B model uses 8 FFN experts per layer and activates 2 per token. Each expert is the size of a Mistral 7B FFN, while attention and embedding weights are shared, so the total is ~47B parameters rather than 8 × 7B = 56B, and each token uses only ~13B active parameters. Mistral AI reports that it matches or exceeds GPT-3.5 on most standard benchmarks, at the per-token compute cost of a much smaller dense model.
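The total-versus-active arithmetic can be reproduced from the model's shape. A back-of-the-envelope sketch: the config values below are assumptions taken from the publicly released Mixtral 8×7B configuration (d_model 4096, expert d_ff 14336, 32 layers, 32 query heads sharing 8 key/value heads, 32k vocabulary, separate embedding and output matrices), and `moe_param_counts` is a hypothetical helper.

```python
def moe_param_counts(d_model=4096, d_ff=14336, n_layers=32, n_heads=32,
                     n_kv_heads=8, vocab=32000, n_experts=8, top_k=2):
    """Hypothetical helper: (total, active-per-token) params of a Mixtral-like MoE."""
    head_dim = d_model // n_heads
    # W_Q and W_O are d_model x d_model; W_K and W_V are smaller under grouped-query attention
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim
    expert = 3 * d_model * d_ff   # one SwiGLU expert: gate, up, down
    router = d_model * n_experts  # per-layer router
    norms = 2 * d_model           # two RMSNorm weight vectors per block
    # Shared by every token: attention, routers, norms, embeddings, LM head, final norm
    shared = n_layers * (attn + router + norms) + 2 * vocab * d_model + d_model
    total = shared + n_layers * n_experts * expert
    active = shared + n_layers * top_k * expert
    return total, active


total, active = moe_param_counts()
print(f"total ≈ {total / 1e9:.1f}B, active per token ≈ {active / 1e9:.1f}B")
# total ≈ 46.7B, active per token ≈ 12.9B
```

The experts account for about 45B of the total, which is why routing each token to 2 of 8 cuts the per-token count so sharply.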
Why does MoE work? The intuition is that different tokens benefit from different types of processing, so routing tokens to specialized experts lets each expert specialize and uses the same compute budget more efficiently. "Medical expert" or "code expert" is a useful mental model, but analyses of trained MoE models (including the Mixtral paper) often find that routing follows syntax and token type more than topic.
The challenge: expert load balancing. Without intervention, the router tends to send most tokens to a few experts, which then receive most of the training signal and become even more attractive: a collapse. MoE training therefore adds auxiliary load-balancing losses to keep all experts in use.
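One common formulation, in the style of the Switch Transformer's auxiliary loss, multiplies each expert's share of routed tokens by its mean router probability. A minimal sketch assuming top-1 dispatch for counting (the `MoEFFN` class in this section does not compute any such loss, and `load_balancing_loss` is a hypothetical name):

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Switch-style auxiliary loss: num_experts * sum_i f_i * P_i."""
    probs = torch.softmax(router_logits, dim=-1)                 # (tokens, E)
    top1 = probs.argmax(dim=-1)                                   # chosen expert per token
    f = F.one_hot(top1, num_experts).float().mean(dim=0)         # fraction of tokens per expert
    p = probs.mean(dim=0)                                         # mean router prob per expert
    return num_experts * torch.sum(f * p)                         # gradient flows through p


N, T = 8, 64
# Balanced: each token strongly prefers a different expert, spread evenly
balanced = 5.0 * F.one_hot(torch.arange(T) % N, N).float()
# Collapsed: every token strongly prefers expert 0
collapsed = 5.0 * F.one_hot(torch.zeros(T, dtype=torch.long), N).float()

print(load_balancing_loss(balanced, N).item())   # ~1.0, the minimum for uniform usage
print(load_balancing_loss(collapsed, N).item())  # much larger (~7.6)
```

The loss is minimized at 1.0 when tokens and probability mass are spread uniformly; it is added to the language-modeling loss with a small coefficient.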
```python
import torch
import torch.nn as nn


class MoEFFN(nn.Module):
    """
    Simplified Mixture of Experts FFN.
    Each expert is an independent FFN. Top-k routing.
    (No load-balancing loss is computed here.)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Router: maps token -> expert weights
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # N independent expert FFNs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.SiLU(),
                nn.Linear(d_ff, d_model, bias=False),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, d_model = x.shape
        x_flat = x.reshape(-1, d_model)  # (batch*seq, d_model); reshape also handles non-contiguous x

        # Router: compute expert weights and select top-k
        router_logits = self.router(x_flat)  # (batch*seq, num_experts)
        router_weights = torch.softmax(router_logits, dim=-1)

        # Select top-k experts per token
        top_k_weights, top_k_indices = torch.topk(router_weights, self.top_k, dim=-1)
        # Normalize the top-k weights to sum to 1
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Apply each expert to the tokens routed to it
        output = torch.zeros_like(x_flat)
        for expert_idx in range(self.num_experts):
            # Find which tokens were routed to this expert (and at what k-slot)
            for k_pos in range(self.top_k):
                mask = top_k_indices[:, k_pos] == expert_idx
                if not mask.any():
                    continue
                expert_output = self.experts[expert_idx](x_flat[mask])
                output[mask] += top_k_weights[mask, k_pos].unsqueeze(-1) * expert_output

        return output.view(batch, seq, d_model)
```
Production Engineering Notes
FFN Dominates Parameter Count
Attention receives most of the conceptual focus, but the FFN usually dominates parameter count:
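For a transformer block with $d_{model} = 4096$ and $d_{ff} = 4\,d_{model} = 16{,}384$ (standard multi-head attention, biases ignored):
- Attention: $4\, d_{model}^2 = 4 \times 4096^2 \approx 67.1\text{M}$ params ($W_Q, W_K, W_V, W_O$)
- FFN: $2\, d_{model}\, d_{ff} = 8\, d_{model}^2 \approx 134.2\text{M}$ params ($W_1, W_2$)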
FFN is 2× the attention parameter count in this configuration. For 32 layers:
- Attention: ~2.1B params
- FFN: ~4.3B params
This means that for parameter-efficient fine-tuning (LoRA), targeting both attention and FFN layers gives much better coverage than attention alone. Many setups default to LoRA on attention projections only; adding the FFN projections often improves results noticeably, especially for knowledge-heavy tasks.
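The counts above, and the cost of widening LoRA coverage, are simple arithmetic. A back-of-the-envelope sketch assuming the same dense block (biases ignored), rank $r = 8$, and the usual LoRA parameterization in which a rank-$r$ adapter on a $d_{in} \times d_{out}$ matrix adds $r(d_{in} + d_{out})$ parameters; the helper names are hypothetical.

```python
def block_params(d_model: int, d_ff: int) -> dict:
    """Weights of one dense block (biases ignored)."""
    return {
        "attention": 4 * d_model * d_model,  # W_Q, W_K, W_V, W_O
        "ffn": 2 * d_model * d_ff,           # W_1, W_2
    }


def lora_params(shapes, rank: int) -> int:
    """A rank-r adapter on a (d_in, d_out) matrix adds r * (d_in + d_out) params."""
    return sum(rank * (d_in + d_out) for d_in, d_out in shapes)


d_model, d_ff, n_layers, r = 4096, 16384, 32, 8
p = block_params(d_model, d_ff)
print({k: f"{n_layers * v / 1e9:.2f}B" for k, v in p.items()})
# {'attention': '2.15B', 'ffn': '4.29B'}

attn_qv = [(d_model, d_model)] * 2                                # LoRA on W_Q, W_V only
all_linear = [(d_model, d_model)] * 4 + [(d_model, d_ff), (d_ff, d_model)]
print(n_layers * lora_params(attn_qv, r))     # 4194304
print(n_layers * lora_params(all_linear, r))  # 18874368
```

Even with every linear layer adapted, the adapters stay under 0.3% of the roughly 6.4B block weights, so covering the FFN is cheap relative to what it reaches.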
Memory vs Compute Trade-offs
For inference optimization:
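- Quantization (INT8, INT4): weight quantization is routinely applied to the FFN matrices. They hold most of the weights, so they give most of the memory savings, and per-channel or per-group scales keep the quantization error manageable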
- Pruning: FFN neurons can be pruned based on activation frequency. Neurons that rarely fire (low activation across many samples) can often be removed without significant quality loss
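- FFN caching: because the FFN is position-wise, a prefix's FFN outputs depend only on that prefix. In a standard decoder those outputs are already folded into the keys and values of later layers, so prefix (KV) caching also avoids recomputing the prefix's FFNs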
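Activation-frequency pruning can be sketched for a ReLU FFN: measure how often each neuron fires on a calibration set, then drop the matching rows of $W_1$ and columns of $W_2$. A minimal sketch under stated assumptions: random weights, 16 neurons made dead on purpose via their biases, random inputs standing in for real calibration data, and an arbitrary 1% threshold. For SwiGLU models, activation magnitude rather than "firing" would be the natural statistic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 16, 64
W1, W2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)
with torch.no_grad():
    W1.bias[:16] = -50.0  # make 16 neurons effectively dead, for illustration only

calib = torch.randn(1024, d_model)  # stand-in for a real calibration set

with torch.no_grad():
    acts = torch.relu(W1(calib))                # (N, d_ff)
    fire_rate = (acts > 0).float().mean(dim=0)  # fraction of inputs where each neuron fires
    keep = fire_rate > 0.01                     # threshold is an arbitrary choice
    n_keep = int(keep.sum())

    # Smaller FFN: keep rows of W1 (weight is d_ff x d_model) and columns of W2
    W1_p, W2_p = nn.Linear(d_model, n_keep), nn.Linear(n_keep, d_model)
    W1_p.weight.copy_(W1.weight[keep])
    W1_p.bias.copy_(W1.bias[keep])
    W2_p.weight.copy_(W2.weight[:, keep])
    W2_p.bias.copy_(W2.bias)

    before = W2(torch.relu(W1(calib)))
    after = W2_p(torch.relu(W1_p(calib)))

print(n_keep, torch.allclose(before, after, atol=1e-5))  # 48 True
```

Here the pruned neurons contribute exactly zero, so outputs are unchanged; with real models, rarely firing neurons contribute a little, and quality should be re-checked after pruning.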
Common Mistakes
:::danger Using bias=True in LLaMA-style FFNs
Modern LLMs (LLaMA, Mistral, Gemma) use bias=False in their FFN layers. Adding bias terms increases parameters slightly, but the bigger practical problem is that the pretrained checkpoints contain no bias tensors: a strict load_state_dict fails, and a non-strict load silently leaves the biases at their random initialization. Always check the reference implementation's bias settings when replicating a model.
:::
:::warning Wrong SwiGLU dimension calculation
If you implement SwiGLU with the standard 4× intermediate dimension (instead of the adjusted ~2.7×), you'll have ~50% more parameters in the FFN than intended. LLaMA-2 7B's FFN uses d_ff = int(8/3 * d_model) rounded up to a multiple of 256 (11,008), not 4 * d_model. Check your parameter count against published model specs.
:::
:::tip Profiling FFN compute
The FFN's matrix multiplications (x @ W1 and h @ W2, or three projections with SwiGLU) are typically the most time-consuming operations in transformer inference, especially for large batch sizes. Profile with torch.profiler to verify before optimizing. For long sequences, attention's cost grows with context length and may become the bottleneck even with Flash Attention, but for typical generation batch sizes (1-32) and moderate context lengths, the FFN dominates.
:::
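Before profiling, a rough FLOP count shows where the crossover between FFN and attention lies. A back-of-the-envelope sketch assuming standard multi-head attention (no grouped-query attention), a 4× ReLU/GeLU FFN, and 2 FLOPs per multiply-add; it ignores softmax, normalization, and the memory-bandwidth effects that dominate small-batch decoding, and `decode_flops_per_token` is a hypothetical helper.

```python
def decode_flops_per_token(d_model: int, d_ff: int, context_len: int) -> dict:
    """Rough forward FLOPs for one new token in one block (2 FLOPs per multiply-add)."""
    return {
        "attn_projections": 2 * 4 * d_model * d_model,         # W_Q, W_K, W_V, W_O
        "attn_scores_and_mix": 2 * 2 * context_len * d_model,  # q . K^T and weights . V
        "ffn": 2 * 2 * d_model * d_ff,                         # W_1 and W_2
    }


for n in (1_024, 16_384, 65_536):
    f = decode_flops_per_token(4096, 16384, n)
    print(n, {k: f"{v / 1e6:.0f}M" for k, v in f.items()})
```

Under these assumptions, the context-dependent attention work equals the FFN work exactly when `context_len == d_ff` (16,384 tokens here); below that, the FFN is the larger cost per token.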
Interview Q&A
Q1: What is the purpose of the feed-forward layer in a transformer? Why is it needed if we already have attention?
Answer: Attention and FFN serve different computational roles:
Attention is a weighted averaging operation: its output is always a linear combination of value vectors, with mixing weights set by a softmax. It is good at routing information from relevant positions, but apart from that softmax weighting it applies no nonlinearity to the combined information.
The FFN provides nonlinear processing and pointwise computation. After attention aggregates context, the FFN processes each token's representation independently with a two-layer MLP. The nonlinearity (ReLU, GeLU, or SwiGLU) allows the model to compute complex, nonlinear transformations.
Equally important: the FFN appears to be where much of the factual knowledge is stored (Geva et al., 2021). Its weights act as a key-value memory: specific patterns in the input activate specific neurons, which retrieve associated knowledge. Attention routes information; the FFN retrieves facts and applies transformations.
Without the FFN, transformers would be dramatically less capable: apart from the softmax weighting and normalization, each layer could only form linear mixtures of projections of its inputs.
Q2: Why does the FFN expand to 4× the model dimension? What would happen with 2× or 8×?
Answer: The expansion creates a large intermediate space where nonlinear computation happens. More neurons = more patterns that can be stored.
- Too small (1× or 2×): The bottleneck limits the number of key-value memories the FFN can store, and capacity in other components (attention, embeddings) goes underused without a capable FFN.
- 4×: The default from the original paper. It has held up well as a trade-off between capacity and cost, but it is a convention rather than a derived optimum; the original paper's own ablations varied $d_{ff}$ and found a wider FFN helped, at a higher parameter count.
- Too large (8×+): Without changing other dimensions, this creates a parameter imbalance: the FFN becomes very large relative to the embedding dimension, and returns diminish because the bottleneck is now the $d_{model}$-dimensional representation.
The 4× rule is flexible in modern architectures. LLaMA uses ~2.7× with SwiGLU (which has 3 matrices), keeping roughly the parameter count of a 4× two-matrix FFN.
Q3: Explain SwiGLU and why it outperforms ReLU/GeLU.
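Answer: SwiGLU is a gated activation: $\text{FFN}_{\text{SwiGLU}}(x) = \big(\text{Swish}(x W_{gate}) \odot x W_{up}\big)\, W_{down}$, with $\text{Swish}(z) = z \cdot \sigma(z)$.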
Instead of applying a single nonlinearity to the expanded features, SwiGLU uses a learned gate: one linear projection creates the gate (through Swish activation), another creates the content, and they're multiplied element-wise.
Why it outperforms:
- Better gradient flow: The multiplicative gate gives gradients a path through the content branch whenever the gate is open, and Swish is smooth and nonzero for most negative inputs, so gradients are rarely cut off entirely the way ReLU's hard zero cuts them off.
- Adaptive computation: The gate is input-dependent: it can "open" specific features for specific inputs. This is more flexible than a fixed nonlinearity.
- Empirical results: Shazeer (2020) reported lower pre-training log-perplexity for GLU variants, with SwiGLU among the best, than for ReLU/GeLU FFNs at a matched parameter budget, plus better results on most fine-tuning tasks. PaLM and LLaMA later adopted SwiGLU at much larger scale.
Shazeer himself offers no theoretical explanation: the paper attributes the success of these architectures, tongue in cheek, "as all else, to divine benevolence."
Q4: What is Mixture of Experts and when would you use it?
Answer: MoE replaces the single FFN in each transformer block with $N$ expert FFNs, routing each token to the top-$k$ experts (typically $k = 1$ or $2$) based on a learned router.
When to use:
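- You want the quality of a much larger dense model at the per-token compute of a smaller one (Mixtral 8×7B, with ~13B active parameters, is reported to match or beat GPT-3.5 on many benchmarks)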
- You have multiple specialized tasks that benefit from specialized processing
- You have the infrastructure to distribute experts across devices (MoE requires expert parallelism)
Why it works: with $k = 2$ of $N = 8$ experts, only $k/N = 1/4$ of the expert parameters are active per token. You get 4× the FFN parameters for the compute of the active experts: conditional computation.
Challenges:
- Expert collapse: without a load-balancing loss, the router tends to send most tokens to a few experts
- Communication overhead: in distributed training, tokens may need to be routed to different devices
- Evaluation complexity: active parameters ≠ total parameters - both metrics matter
MoE is increasingly common in frontier models: GPT-4 is reportedly MoE, Mixtral 8×7B and 8×22B are open MoE models, and Google describes Gemini 1.5 as a sparse MoE Transformer.
Q5: Given that FFN weights store factual knowledge, what does this mean for fine-tuning and knowledge editing?
Answer: Several practical implications:
Fine-tuning for knowledge: When you fine-tune to add new facts (e.g., updating a model's knowledge cutoff), the key-value view suggests that much of the change lands in the FFN weights, with attention playing more of a routing role. This means LoRA applied to FFN layers matters for knowledge-intensive fine-tuning, not just LoRA on attention layers.
Targeted knowledge editing: Methods like ROME (Rank-One Model Editing) and MEMIT directly modify specific FFN weights to change factual associations without full fine-tuning. If you want to change "The Eiffel Tower is in ___" from "Paris" to "Berlin" (hypothetically), you can identify and surgically edit the specific FFN weights responsible.
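The core idea of a rank-one edit fits in a few lines: change one matrix so a chosen key maps to a new value while keys orthogonal to it are unaffected. A minimal sketch in column-vector convention with random data; `k_star` and `v_star` are hypothetical stand-ins for the key activation of the fact and the desired new output. ROME's published update additionally weights the edit by an estimate of the key covariance and finds `k_star` and `v_star` by optimization, so this is the underlying linear-algebra step, not ROME itself.

```python
import torch

torch.manual_seed(0)
d_in, d_out = 32, 16
W = torch.randn(d_out, d_in)   # a value-writing matrix (column-vector convention)
k_star = torch.randn(d_in)     # hypothetical key activation for the fact being edited
v_star = torch.randn(d_out)    # hypothetical desired new output for that key

# Minimal-norm rank-one update so that W_new @ k_star == v_star
residual = v_star - W @ k_star
W_new = W + torch.outer(residual, k_star) / (k_star @ k_star)

print(torch.allclose(W_new @ k_star, v_star, atol=1e-4))  # True: association rewritten

# Keys orthogonal to k_star are mapped exactly as before
u = torch.randn(d_in)
u_perp = u - (u @ k_star) / (k_star @ k_star) * k_star
print(torch.allclose(W_new @ u_perp, W @ u_perp, atol=1e-4))  # True
```

Keys that overlap with `k_star` do shift, which is one reason real editing methods use key statistics to limit collateral changes.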
Catastrophic forgetting: When you fine-tune on new data, you overwrite FFN weights, and old facts stored in those weights may be forgotten. This is one reason retrieval-augmented generation (RAG) often outperforms fine-tuning for up-to-date knowledge: RAG doesn't overwrite weights.
Hallucinations: A plausible mechanism for some hallucinations is incorrect or weakly stored facts in FFN weights: the FFN fires on patterns similar to (but not the same as) a real fact and retrieves a plausible but wrong association. Retrieval grounding helps because, instead of relying on FFN memory, you inject the correct fact via context.
