Multi-Head Attention
Reading time: ~35 min · Interview relevance: Essential · Target roles: ML Engineer, AI Engineer, Research Engineer
The Meeting Room Problem
David's team at a financial NLP startup had built a solid sentence embedder using single-head self-attention. It worked well for most tasks, but struggled on a specific type of query: "Who authorized the transaction, and when?"
The problem wasn't the individual words - "who" and "when" were represented correctly. The problem was that a single attention head could only ask one "question" per forward pass. When the model tried to simultaneously capture "what entity performed an action" and "at what point in time," it couldn't. It would attend to the agent when it needed the time, or vice versa.
David's team tried adding more model depth - extra layers, wider projections. It helped slightly, but the fundamental issue remained. A single attention head learns a single type of relationship per layer. If the task requires simultaneously tracking multiple relationship types (agent, patient, time, location, modality), one head cannot capture all of them.
When they switched to multi-head attention and visualized the attention patterns across 8 heads, the diagnosis was immediate. Head 3 had learned to track syntactic subjects. Head 6 tracked temporal expressions. Head 7 tracked possessive relationships. The model was naturally decomposing the problem into independent subproblems and solving each in parallel.
This is what multi-head attention does: it runs several attention operations simultaneously, each learning to focus on a different type of relationship. The outputs are concatenated and projected back, giving a final representation that incorporates multiple perspectives on the same input.
Why Single-Head Attention Is Insufficient
In the previous lesson, we implemented single-head attention. Each layer computes one attention function:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
The problem: $Q$, $K$, and $V$ are all projected from the same $d_{\text{model}}$-dimensional space into a single $d_k$-dimensional subspace. The model has one set of linear projections, meaning it can only compute one type of query-key relationship per layer.
Language requires simultaneously reasoning about multiple relationship types:
- Syntactic: What is the subject? What is the object?
- Semantic: What is the meaning? What is the context?
- Coreference: Which pronoun refers to which entity?
- Positional: What is near? What is far?
A single attention head picks one "focus" per layer. To capture all these relationships, you'd need either many layers (depth) or a better mechanism. Multi-head attention is that mechanism: run $h$ attention operations in parallel, each in a lower-dimensional subspace, then combine.
The Multi-Head Attention Formula
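$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$
where each head is:
$$\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$
The parameters:
- $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$ - query projection for head $i$
- $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ - key projection for head $i$
- $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ - value projection for head $i$
- $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ - output projection
In the original transformer: $d_{\text{model}} = 512$, $h = 8$, $d_k = d_v = d_{\text{model}}/h = 64$.
The key observation: each head has its own projection matrices. Head 1 has $W_1^Q, W_1^K, W_1^V$. Head 2 has $W_2^Q, W_2^K, W_2^V$. They learn completely different projections - different linear subspaces - and thus different attention patterns.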
The Efficient Implementation
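The naive implementation runs $h$ separate attention operations. The efficient implementation fuses them into a single batch operation using reshaping tricks.
Instead of running 8 separate attentions, we:
- Project $X$ to $Q$ of shape $(\text{batch}, n, d_{\text{model}})$ in one shot with a $d_{\text{model}} \times d_{\text{model}}$ matrix (likewise for $K$ and $V$)
- Reshape to $(\text{batch}, n, h, d_k)$, then transpose to $(\text{batch}, h, n, d_k)$
- Run batched attention across all heads simultaneously
- Transpose and reshape back to $(\text{batch}, n, d_{\text{model}})$
This is identical mathematically but runs as a single batched matrix multiplication, which is what makes it GPU-efficient.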
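The equivalence can be checked numerically. The following is a minimal NumPy sketch, assuming head $i$ owns the $i$-th block of $d_k$ columns of each packed weight matrix; the function names `naive_heads` and `fused_heads` are illustrative, and the output projection is left out to keep the focus on the reshaping.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def naive_heads(X, W_q, W_k, W_v, h):
    """Run h separate attentions, one per column block of the packed weights."""
    d_k = W_q.shape[1] // h
    outs = []
    for i in range(h):
        cols = slice(i * d_k, (i + 1) * d_k)
        Q, K, V = X @ W_q[:, cols], X @ W_k[:, cols], X @ W_v[:, cols]
        A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))
        outs.append(A @ V)                       # (batch, seq, d_k)
    return np.concatenate(outs, axis=-1)         # (batch, seq, d_model)

def fused_heads(X, W_q, W_k, W_v, h):
    """Same computation with one projection per matrix and a reshape/transpose."""
    b, n, d = X.shape
    d_k = d // h

    def split(M):
        # (batch, seq, d_model) -> (batch, h, seq, d_k)
        return (X @ M).reshape(b, n, h, d_k).transpose(0, 2, 1, 3)

    Q, K, V = split(W_q), split(W_k), split(W_v)
    A = softmax(Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k))
    return (A @ V).transpose(0, 2, 1, 3).reshape(b, n, d)

rng = np.random.default_rng(0)
X = rng.normal(size=(2, 5, 16))
W_q, W_k, W_v = (rng.normal(size=(16, 16)) * 0.1 for _ in range(3))
same = np.allclose(naive_heads(X, W_q, W_k, W_v, 4),
                   fused_heads(X, W_q, W_k, W_v, 4))
print(same)  # True
```

The loop version is easier to read against the formula; the fused version is the one worth running.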
Parameter Count Analysis
For a single multi-head attention block with $d_{\text{model}} = 512$, $h = 8$, $d_k = 64$:
| Matrix | Shape | Parameters |
|---|---|---|
| $W^Q$ (all heads combined) | $512 \times 512$ | 262,144 |
| $W^K$ (all heads combined) | $512 \times 512$ | 262,144 |
| $W^V$ (all heads combined) | $512 \times 512$ | 262,144 |
| $W^O$ | $512 \times 512$ | 262,144 |
| Total | - | 1,048,576 (~1M) |
Note that combining all heads' $W_i^Q$ matrices into one $d_{\text{model}} \times d_{\text{model}}$ matrix is equivalent to having $h$ separate $d_{\text{model}} \times d_k$ matrices - the parameter count is the same. The efficiency comes from doing it as a single matmul.
For GPT-3 with $d_{\text{model}} = 12288$, $h = 96$, each attention block has $4 \times 12288^2 \approx 604\text{M}$ parameters. Multiply by 96 layers: attention alone accounts for ~58B of the 175B total parameters.
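The arithmetic above can be reproduced in a few lines. This is a small sketch, assuming bias-free projections and $h \cdot d_k = d_{\text{model}}$; the helper name `attention_params` is made up for illustration.

```python
def attention_params(d_model: int, num_layers: int = 1) -> int:
    """Parameters in W_Q, W_K, W_V, W_O (no biases), summed over layers."""
    return 4 * d_model * d_model * num_layers

per_block_base = attention_params(512)
per_block_gpt3 = attention_params(12288)
total_gpt3 = attention_params(12288, num_layers=96)

print(f"{per_block_base:,}")   # 1,048,576
print(f"{per_block_gpt3:,}")   # 603,979,776
print(f"{total_gpt3:,}")       # 57,982,058,496
```

The number of heads does not appear in the function at all, which is the point: under these assumptions, splitting into more heads changes how the matrices are sliced, not how large they are.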
NumPy Implementation from Scratch
```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    x_shifted = x - x.max(axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)


def multi_head_attention_numpy(
    X: np.ndarray,
    W_q: np.ndarray,
    W_k: np.ndarray,
    W_v: np.ndarray,
    W_o: np.ndarray,
    num_heads: int,
    mask: np.ndarray = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Multi-head self-attention from scratch.

    Args:
        X: (batch, seq, d_model)
        W_q, W_k, W_v: (d_model, d_model) -- all heads packed together
        W_o: (d_model, d_model)
        num_heads: h
        mask: (batch, 1, seq, seq) or None

    Returns:
        output: (batch, seq, d_model)
        all_attn_weights: (batch, num_heads, seq, seq)
    """
    batch, seq, d_model = X.shape
    d_k = d_model // num_heads

    # === Step 1: Project to Q, K, V (all heads at once) ===
    # Shape after projection: (batch, seq, d_model)
    Q_full = X @ W_q  # (batch, seq, d_model)
    K_full = X @ W_k
    V_full = X @ W_v

    # === Step 2: Split into heads ===
    # Reshape: (batch, seq, d_model) -> (batch, seq, num_heads, d_k)
    # Transpose: -> (batch, num_heads, seq, d_k)
    Q = Q_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)
    K = K_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)
    V = V_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)

    # === Step 3: Scaled dot-product attention for all heads in parallel ===
    # Q: (batch, num_heads, seq, d_k)
    # K^T: (batch, num_heads, d_k, seq)
    # scores: (batch, num_heads, seq, seq)
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    attn_weights = softmax(scores, axis=-1)  # (batch, num_heads, seq, seq)

    # (batch, num_heads, seq, seq) @ (batch, num_heads, seq, d_k)
    # -> (batch, num_heads, seq, d_k)
    context = np.matmul(attn_weights, V)

    # === Step 4: Concatenate heads ===
    # Transpose: (batch, num_heads, seq, d_k) -> (batch, seq, num_heads, d_k)
    # Reshape: -> (batch, seq, d_model)
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)

    # === Step 5: Output projection ===
    output = context @ W_o  # (batch, seq, d_model)
    return output, attn_weights


# Test
np.random.seed(42)
batch, seq, d_model, num_heads = 2, 6, 16, 4
d_k = d_model // num_heads  # 4

# Initialize weights
scale = 0.1
W_q = np.random.randn(d_model, d_model) * scale
W_k = np.random.randn(d_model, d_model) * scale
W_v = np.random.randn(d_model, d_model) * scale
W_o = np.random.randn(d_model, d_model) * scale
X = np.random.randn(batch, seq, d_model)

output, weights = multi_head_attention_numpy(X, W_q, W_k, W_v, W_o, num_heads)
print(f"Input shape: {X.shape}")          # (2, 6, 16)
print(f"Output shape: {output.shape}")    # (2, 6, 16)
print(f"Weights shape: {weights.shape}")  # (2, 4, 6, 6)

# Each head's attention should sum to 1 per query
print(f"\nAttention weights sum (head 0, item 0):")
print(np.round(weights[0, 0].sum(axis=-1), 4))  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```
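The `mask` argument is not exercised by the test above. As a hedged illustration of how a causal mask would interact with the `np.where(mask, scores, -1e9)` step, the sketch below builds a lower-triangular boolean mask (assuming `True` means "may attend") and applies it to all-zero scores so the effect is easy to read off; it is a standalone example, not part of the function.

```python
import numpy as np

seq = 4
# Lower-triangular mask, broadcastable to (batch, num_heads, seq, seq)
causal_mask = np.tril(np.ones((seq, seq), dtype=bool))[None, None, :, :]

scores = np.zeros((1, 2, seq, seq))             # stand-in for QK^T / sqrt(d_k)
masked = np.where(causal_mask, scores, -1e9)    # same masking step as above

# Numerically stable softmax over the key axis
shifted = masked - masked.max(axis=-1, keepdims=True)
exp = np.exp(shifted)
weights = exp / exp.sum(axis=-1, keepdims=True)

print(np.round(weights[0, 0], 2))
# [[1.   0.   0.   0.  ]
#  [0.5  0.5  0.   0.  ]
#  [0.33 0.33 0.33 0.  ]
#  [0.25 0.25 0.25 0.25]]
```

Each query position spreads its weight only over itself and earlier positions, and every row still sums to 1.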
PyTorch Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Production-grade multi-head attention implementation.
    Matches the original Vaswani et al. (2017) formulation.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Combined projections - more efficient than separate per-head matrices
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(
        self,
        query: torch.Tensor,        # (batch, seq_q, d_model)
        key: torch.Tensor,          # (batch, seq_k, d_model)
        value: torch.Tensor,        # (batch, seq_k, d_model)
        mask: torch.Tensor = None,  # broadcastable to (batch, num_heads, seq_q, seq_k)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch = query.size(0)

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)"""
            seq = x.size(1)
            return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

        # Project to Q, K, V and split into heads
        Q = split_heads(self.W_q(query))  # (batch, h, seq_q, d_k)
        K = split_heads(self.W_k(key))    # (batch, h, seq_k, d_k)
        V = split_heads(self.W_v(value))  # (batch, h, seq_k, d_k)

        # Scaled dot-product attention across all heads in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values
        context = torch.matmul(attn_weights, V)  # (batch, h, seq_q, d_k)

        # Concatenate heads: (batch, h, seq_q, d_k) -> (batch, seq_q, d_model)
        context = context.transpose(1, 2).contiguous().view(batch, -1, self.d_model)

        # Output projection
        output = self.W_o(context)
        return output, attn_weights  # return weights for inspection


# === Test and verification ===
torch.manual_seed(42)
d_model, num_heads = 64, 8
d_k = d_model // num_heads  # 8 per head
mha = MultiHeadAttention(d_model, num_heads)

# Self-attention: query=key=value (same input)
batch, seq = 3, 10
x = torch.randn(batch, seq, d_model)
out, weights = mha(x, x, x)
print(f"Output shape: {out.shape}")       # (3, 10, 64)
print(f"Weights shape: {weights.shape}")  # (3, 8, 10, 10)

# Verify attention weights sum to 1 (before dropout in eval mode)
mha.eval()
_, weights_eval = mha(x, x, x)
sums = weights_eval[0, :, 0, :].sum(dim=-1)
print(f"Attention sums (head x, query 0): {torch.round(sums, decimals=4)}")
# Should all be ~1.0

# Cross-attention: query from decoder, key/value from encoder
seq_decoder, seq_encoder = 7, 12
query = torch.randn(batch, seq_decoder, d_model)
encoder_output = torch.randn(batch, seq_encoder, d_model)
out_cross, weights_cross = mha(query, encoder_output, encoder_output)
print(f"\nCross-attention output: {out_cross.shape}")     # (3, 7, 64)
print(f"Cross-attention weights: {weights_cross.shape}")  # (3, 8, 7, 12)
```
Head Specialization: What Different Heads Learn
One of the most fascinating properties of multi-head attention is that different heads spontaneously learn to track different linguistic phenomena, without any explicit supervision.
Research by Voita et al. (2019) ("Analyzing Multi-Head Self-Attention") analyzed heads in a 6-layer, 8-head transformer trained for translation and found:
Positional heads: Attend to adjacent tokens (offset by 1 or 2 positions). These learn local patterns - essentially functioning like n-gram detectors.
Syntactic heads: Attend to syntactically related words (subject-verb pairs, noun-adjective pairs). When visualized, these heads show clear syntactic structure.
Rare word heads: Attend disproportionately to low-frequency, semantically important tokens. These heads seem to boost signal from content words.
The critical finding: Only a small subset of heads are crucial. Voita et al. found that pruning 38 of the 48 encoder heads (about 80%) from a trained English-Russian model cost only 0.15 BLEU. Most heads are redundant given sufficient model capacity.
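To make "pruning a head" concrete, here is a deliberately simplified sketch: a fixed 0/1 gate per head, applied to the per-head outputs before concatenation. This is not the procedure Voita et al. used to decide which heads to remove; it only shows the ablation mechanics, and `apply_head_gates` is a hypothetical helper.

```python
import numpy as np

def apply_head_gates(context: np.ndarray, gates: np.ndarray) -> np.ndarray:
    """
    context: (batch, heads, seq, d_k) per-head attention outputs
    gates:   (heads,) with 1.0 to keep a head and 0.0 to ablate it
    Returns the concatenated heads, shape (batch, seq, heads * d_k).
    """
    gated = context * gates[None, :, None, None]
    b, h, n, d_k = gated.shape
    return gated.transpose(0, 2, 1, 3).reshape(b, n, h * d_k)

rng = np.random.default_rng(0)
context = rng.normal(size=(1, 4, 3, 2))    # 4 heads, d_k = 2
gates = np.array([1.0, 0.0, 1.0, 0.0])     # keep heads 0 and 2 only
merged = apply_head_gates(context, gates)

print(merged.shape)  # (1, 3, 8)
```

Columns belonging to heads 1 and 3 are zero in `merged`, so the output projection receives no signal from them. Comparing task metrics with and without a given gate is one simple way to estimate how much a head matters.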
Choosing the Number of Heads
Real model configurations:
| Model | $d_{\text{model}}$ | Heads ($h$) | $d_k$ per head |
|---|---|---|---|
| Transformer-base | 512 | 8 | 64 |
| Transformer-large | 1024 | 16 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-2 | 768 | 12 | 64 |
| GPT-3 | 12,288 | 96 | 128 |
| LLaMA-2 70B | 8,192 | 64 | 128 |
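The pattern: $d_k$ per head is typically 64 or 128, and the number of heads scales with $d_{\text{model}}$. The $d_k = 64$ choice is empirically good - small enough that each head has a focused subspace, large enough to represent meaningful relationships.
Increasing heads without increasing $d_{\text{model}}$ means smaller $d_k$ per head. This limits what each head can represent. In practice, keep $d_k$ near the 64-128 range that the models above use.
More heads means more types of relationships can be captured. With $d_k = d_{\text{model}}/h$ this costs no extra parameters: the $h$ sets of projection matrices are slices of the same $d_{\text{model}} \times d_{\text{model}}$ matrices, so the trade-off is head count against per-head width. The output projection $W^O$ learns to weight the contributions of each head appropriately.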
:::tip Grouped Query Attention (GQA)
Modern large models (LLaMA-2, Mistral, Gemma) use Grouped Query Attention: instead of having $h$ separate key/value projections, they share K and V projections across groups of query heads. This reduces the key-value cache size during inference significantly. LLaMA-2 70B uses GQA with 8 K/V heads shared across 64 query heads.
:::
Multi-Query Attention (MQA)
An even more aggressive sharing: all query heads share a single K and V projection (Shazeer, 2019). Used in PaLM and Falcon.
Standard MHA: h query heads, h key heads, h value heads
GQA: h query heads, g key heads, g value heads (g < h)
MQA: h query heads, 1 key head, 1 value head
The motivation is purely inference-time efficiency: the key-value cache (KV cache) during generation stores K and V tensors for all previous tokens, for all heads, for all layers. With MQA, the KV cache is $h$ times smaller, enabling either longer context or larger batch sizes at the same memory budget.
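The three variants differ only in how many K/V heads exist, so one function can cover all of them. Below is a minimal NumPy sketch, assuming consecutive query heads share a K/V head; `grouped_query_attention` is an illustrative name, and the projections that would produce `Q`, `K`, and `V` are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(Q, K, V):
    """
    Q:    (batch, h, seq, d_k) query heads
    K, V: (batch, g, seq, d_k) key/value heads, with h divisible by g
    g == h is standard MHA, 1 < g < h is GQA, g == 1 is MQA.
    """
    h, g = Q.shape[1], K.shape[1]
    assert h % g == 0, "query heads must divide evenly into K/V groups"
    # Each K/V head is repeated so that h // g consecutive query heads share it
    K = np.repeat(K, h // g, axis=1)
    V = np.repeat(V, h // g, axis=1)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V

rng = np.random.default_rng(0)
batch, h, seq, d_k = 1, 8, 5, 4
Q = rng.normal(size=(batch, h, seq, d_k))
K_gqa = rng.normal(size=(batch, 2, seq, d_k))   # g = 2
V_gqa = rng.normal(size=(batch, 2, seq, d_k))
K_mqa, V_mqa = K_gqa[:, :1], V_gqa[:, :1]       # g = 1

out_gqa = grouped_query_attention(Q, K_gqa, V_gqa)
out_mqa = grouped_query_attention(Q, K_mqa, V_mqa)
print(out_gqa.shape, out_mqa.shape)  # (1, 8, 5, 4) (1, 8, 5, 4)
```

Only the un-repeated `K` and `V` would need to be cached during generation, which is where the memory saving comes from; the repeat here is a simple way to express the sharing, not an optimized kernel.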
Production Engineering Notes
KV Cache During Generation
When generating text autoregressively, each new token must attend over all previous tokens. A naive implementation recomputes K and V for every previous token at each step, so step $t$ redoes $t$ projections and generating $n$ tokens performs $O(n^2)$ of them in total.
The KV cache stores K and V tensors from all previous steps. For each new token, you compute Q, K, and V for that token only, append the new K and V to the cache, and attend to the cached K, V. This reduces generation to $O(n)$ attention work per step.
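A short sketch can confirm that cached decoding produces the same outputs as recomputing full causal attention. It assumes a single head and no batch dimension to keep the code small; `causal_attention` and `cached_decode` are illustrative names.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(X, W_q, W_k, W_v):
    """Full recomputation: every position attends to itself and earlier ones."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    n = X.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -1e9)
    return softmax(scores) @ V

def cached_decode(X, W_q, W_k, W_v):
    """Process one token at a time, appending its K and V rows to a cache."""
    K_cache, V_cache, outs = [], [], []
    for x in X:                                  # x: (d_model,)
        K_cache.append(x @ W_k)                  # only the new token is projected
        V_cache.append(x @ W_v)
        K, V = np.stack(K_cache), np.stack(V_cache)   # (t, d_k)
        q = x @ W_q
        w = softmax(K @ q / np.sqrt(q.shape[-1]))     # (t,) weights over the cache
        outs.append(w @ V)
    return np.stack(outs)

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 8))                      # 6 tokens, d_model = 8
W_q, W_k, W_v = (rng.normal(size=(8, 4)) * 0.1 for _ in range(3))
full = causal_attention(X, W_q, W_k, W_v)
cached = cached_decode(X, W_q, W_k, W_v)
print(np.allclose(full, cached))  # True
```

A real implementation would preallocate the cache as a tensor rather than growing Python lists, but the data flow is the same.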
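Memory cost: For GPT-3 ($d_{\text{model}} = 12288$, $h = 96$, 96 layers), the KV cache for a 4096-token sequence in float16 is:
$$2 \times 96 \times 4096 \times 12288 \times 2\ \text{bytes} \approx 18\ \text{GiB}$$
The leading 2 counts K and V; the trailing 2 is bytes per float16 value.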
This is why batch size during inference is severely memory-limited, and why MQA/GQA matter.
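The same estimate as a reusable calculation. This is a back-of-the-envelope sketch, assuming float16 storage and ignoring any framework overhead; `kv_cache_bytes` is a hypothetical helper, and the MQA line describes a hypothetical GPT-3-sized model with a single K/V head, not a released configuration.

```python
def kv_cache_bytes(num_layers, seq_len, kv_heads, d_k, bytes_per_value=2, batch=1):
    """K and V tensors for every layer, token, and K/V head."""
    return 2 * num_layers * batch * seq_len * kv_heads * d_k * bytes_per_value

GIB = 2**30

# GPT-3-sized MHA: 96 layers, 96 K/V heads of width 128 (96 * 128 = 12288)
mha_bytes = kv_cache_bytes(num_layers=96, seq_len=4096, kv_heads=96, d_k=128)
# Hypothetical MQA variant of the same model: one shared K/V head
mqa_bytes = kv_cache_bytes(num_layers=96, seq_len=4096, kv_heads=1, d_k=128)

print(f"MHA: {mha_bytes / GIB:.2f} GiB")   # MHA: 18.00 GiB
print(f"MQA: {mqa_bytes / GIB:.4f} GiB")   # MQA: 0.1875 GiB
```

The cache scales linearly with `batch`, so the MHA figure is paid once per concurrent sequence.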
Flash Attention Compatibility
Flash Attention 2 (Dao, 2023) supports multi-head attention natively. In PyTorch 2.0+:
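```python
import torch
import torch.nn.functional as F

batch, heads, seq, d_k = 2, 8, 16, 64
query = torch.randn(batch, heads, seq, d_k)  # (batch, heads, seq_q, d_k)
key = torch.randn(batch, heads, seq, d_k)    # (batch, heads, seq_k, d_k)
value = torch.randn(batch, heads, seq, d_k)  # (batch, heads, seq_k, d_v)

# This uses Flash Attention automatically when possible
output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True,  # for decoder
)
print(output.shape)  # torch.Size([2, 8, 16, 64])
```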
The key constraint: Flash Attention requires contiguous tensors and specific dtypes (float16/bfloat16). Always profile before assuming it's being used.
Common Mistakes
:::danger Not reshaping before the output projection
After concatenating heads, the tensor shape is (batch, seq, num_heads * d_k). If num_heads * d_k != d_model (which should never happen, but often does in buggy implementations with wrong dimension choices), the output projection W_o will have the wrong shape and the model will silently use the wrong dimensions. Always assert d_model == num_heads * d_k at init time.
:::
:::danger Transposing the wrong dimensions when splitting heads
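The reshape sequence is: (batch, seq, d_model) → reshape to (batch, seq, h, d_k) → transpose to (batch, h, seq, d_k). A common mistake is transposing dimensions 1 and 2 on the wrong tensor shape, or forgetting .contiguous() before calling .view() after the transpose. Without .contiguous(), .view() throws a RuntimeError (.reshape() copies the data instead). The silent failure is a different one: reshaping (batch, h, seq, d_k) straight to (batch, seq, d_model) without transposing back runs without error but mixes heads and positions.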
:::
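The silent version of this bug is easy to reproduce. The sketch below uses NumPy, where `reshape` never complains about memory layout, and tags every element with its head index so the mix-up is visible; the tiny shapes are chosen only for readability.

```python
import numpy as np

batch, h, seq, d_k = 1, 2, 3, 2

# Per-head outputs: head 0 is all zeros, head 1 is all ones
context = np.zeros((batch, h, seq, d_k))
context[:, 1] = 1.0

# Correct: move the head axis next to d_k, then merge the two
correct = context.transpose(0, 2, 1, 3).reshape(batch, seq, h * d_k)

# Wrong: same target shape, but the transpose was skipped
wrong = context.reshape(batch, seq, h * d_k)

print(correct[0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 1. 1.]]
print(wrong[0])
# [[0. 0. 0. 0.]
#  [0. 0. 1. 1.]
#  [1. 1. 1. 1.]]
```

Both results have shape `(1, 3, 4)`, so nothing downstream fails. In the wrong version, position 0 contains only head 0's values and position 2 only head 1's, which is why a shape check alone does not catch this.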
:::warning Attention in inference: forgetting to disable dropout
Multi-head attention often includes attention dropout (applied to weights after softmax). In training mode this is correct. In evaluation/inference mode, model.eval() disables it - but if you manually set model.train() for some part of inference (e.g., for MC Dropout uncertainty estimation), remember that attention patterns will be randomly dropped.
:::
:::tip Debugging with head patterns
If your model is underperforming, visualize attention weights per head before diving into hyperparameters. A healthy multi-head attention should show diverse patterns across heads. If all heads look identical, the model has collapsed to using a single subspace - try adjusting initialization or adding dropout. If all heads are uniform (no clear pattern), the learning rate may be too high or the model isn't training.
:::
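Alongside visual inspection, a rough numeric check for "all heads look identical" is the pairwise similarity between heads' attention maps. The sketch below uses cosine similarity on flattened maps, which is one reasonable choice among several; `head_similarity` is a hypothetical helper, and the toy maps stand in for weights returned by a real model.

```python
import numpy as np

def head_similarity(attn: np.ndarray) -> np.ndarray:
    """
    attn: (heads, seq_q, seq_k) attention weights for one example.
    Returns a (heads, heads) matrix of cosine similarities between
    the flattened attention maps.
    """
    flat = attn.reshape(attn.shape[0], -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return flat @ flat.T

seq = 4
uniform = np.full((seq, seq), 1.0 / seq)   # attends everywhere equally
diagonal = np.eye(seq)                     # attends only to itself
attn = np.stack([uniform, uniform, diagonal])

sim = head_similarity(attn)
print(np.round(sim, 2))
# [[1.  1.  0.5]
#  [1.  1.  0.5]
#  [0.5 0.5 1. ]]
```

Off-diagonal values close to 1 across most head pairs would be a hint that heads have collapsed onto the same pattern.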
Interview Q&A
Q1: Why use multiple attention heads instead of one big attention head with the same parameter count?
Answer: It's not just about parameter count - it's about expressivity in multiple subspaces.
With a single head of dimension $d_{\text{model}}$, you get one attention pattern per layer. The model must choose: does it compute syntactic relationships? Semantic relationships? Coreference? It can only pick one "question" to ask per layer.
With $h$ heads each of dimension $d_k = d_{\text{model}}/h$, you get $h$ different attention patterns. Each head's projection matrices ($W_i^Q$, $W_i^K$, $W_i^V$) project the input into a different subspace. Gradients during training specialize these subspaces for different types of relationships.
The parameter count is the same: $h$ heads of dimension $d_k$ vs 1 head of dimension $d_{\text{model}}$. But the functional capacity is completely different - you get $h$ independent views of the same sequence rather than one.
Empirically: the ablations in Vaswani et al. (2017) found single-head attention 0.9 BLEU worse than the best setting, while quality also dropped off with too many heads (32 in their base model).
Q2: What does the output projection do? Why is it necessary?
Answer: After concatenating $h$ attention heads, you have a tensor of shape $(\text{batch}, n, h \cdot d_v)$. The output projection $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ projects this back to $(\text{batch}, n, d_{\text{model}})$.
It does two things:
- Dimensionality restoration: The residual connection requires the same dimension as the input, so you must project back to $d_{\text{model}}$.
- Linear mixing of head outputs: The heads computed parallel, potentially contradictory, views of the input. The output projection learns to combine these - weighting some heads more than others and synthesizing a coherent output.
Without $W^O$, you'd simply concatenate the heads and add them to the residual. The output projection gives the model freedom to learn non-trivial combinations of the per-head outputs.
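One way to see the "mixing" role is that projecting the concatenation is the same as giving each head its own slice of the output matrix and summing the results. The following NumPy sketch checks that identity on random data; the shapes are arbitrary and the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
h, seq, d_k = 4, 5, 8
d_model = h * d_k

heads = [rng.normal(size=(seq, d_k)) for _ in range(h)]   # per-head outputs
W_o = rng.normal(size=(d_model, d_model))                 # output projection

# Standard formulation: concatenate, then project
concat_then_project = np.concatenate(heads, axis=-1) @ W_o

# Equivalent view: head i is projected by its own d_k rows of W_o, then summed
sum_of_projections = sum(
    head @ W_o[i * d_k:(i + 1) * d_k, :] for i, head in enumerate(heads)
)

print(np.allclose(concat_then_project, sum_of_projections))  # True
```

Read this way, each head writes into the full $d_{\text{model}}$-dimensional output through its own learned linear map, and the layer output is the sum of those contributions.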
Q3: In production, GPT-3 uses 96 attention heads. Is there evidence that all 96 are doing different things?
Answer: Not all 96 heads are necessary or doing unique things. Research consistently shows attention head redundancy:
- Michel et al. (2019) showed that a large percentage of attention heads can be removed at test time without significantly hurting performance, and that some layers can be reduced to a single head.
- Voita et al. (2019) found that most crucial heads fell into a small number of functional categories (positional, syntactic, rare-word).
In large models like GPT-3, most heads likely serve similar functions across different layers. The value of having 96 heads is not that each is unique - it's that having many heads gives the model spare capacity during training to discover useful specializations, and the redundancy provides robustness.
Practically: for inference optimization, you can prune attention heads. For fine-tuning small datasets, reducing heads can reduce overfitting. For most production applications, the pre-trained head structure should be kept intact.
Q4: Explain the computational and memory cost of multi-head attention for a 4096-token context.
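Answer: For a single MHA block with $h$ heads, $d_{\text{model}}$ dimension, sequence length $n$:
Compute: $O(n^2 d_{\text{model}})$ for the attention scores. For $n = 4096$, $d_{\text{model}} = 4096$: roughly $4096^3 \approx 69$ billion operations per layer. Across 32 layers: ~2.2 trillion operations. That's just the score computation - the value aggregation, the projections, and the FFN add more on top.
Memory: The attention weight matrix alone is $h \times n \times n$ float16 values. For batch=1, h=32, n=4096: $32 \times 4096^2 \times 2$ bytes $\approx 1$ GB just for attention weights, per layer. This is why Flash Attention (which avoids materializing this matrix) is not optional for long contexts.
KV Cache: For autoregressive generation with 4096 context, storing K and V for all $L$ layers: $2 \times L \times n \times d_{\text{model}} \times 2$ bytes. For GPT-3 ($L = 96$, $d_{\text{model}} = 12288$): ~18 GiB. This dominates inference memory for long contexts.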
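The compute and attention-matrix figures can be reproduced directly. This is a rough sketch that counts one operation per multiply-add in $QK^T$ and assumes float16 weights with a batch of 1; `attention_score_costs` is a made-up helper name.

```python
def attention_score_costs(n: int, d_model: int, h: int, bytes_per_value: int = 2):
    """
    Returns (score_ops, attn_matrix_bytes) for one layer and one sequence.
    score_ops:         multiply-adds in QK^T summed over all heads (n * n * d_model)
    attn_matrix_bytes: size of the materialized (h, n, n) attention weights
    """
    score_ops = n * n * d_model
    attn_matrix_bytes = h * n * n * bytes_per_value
    return score_ops, attn_matrix_bytes

ops, attn_bytes = attention_score_costs(n=4096, d_model=4096, h=32)
num_layers = 32

print(f"{ops / 1e9:.1f} billion ops per layer")                 # 68.7 billion ops per layer
print(f"{ops * num_layers / 1e12:.1f} trillion ops in total")   # 2.2 trillion ops in total
print(f"{attn_bytes / 2**30:.1f} GiB of attention weights")     # 1.0 GiB of attention weights
```

Doubling `n` quadruples both numbers, which is the practical meaning of the quadratic cost.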
Q5: What is the difference between self-attention and multi-head attention? Can you have multi-head cross-attention?
Answer:
- Self-attention refers to where Q, K, V come from (same sequence). It's a structural choice.
- Multi-head attention refers to running $h$ parallel attention operations. It's an architectural choice.
These are orthogonal concepts that are combined:
- Multi-head self-attention (MHSA): Used in encoder layers and decoder masked layers. Q, K, V all from same sequence, run in parallel heads.
- Multi-head cross-attention: Used in encoder-decoder decoder layers. Q from decoder, K/V from encoder output, run in parallel heads.
Yes: in the standard transformer, cross-attention is multi-head. The heads in cross-attention can learn different ways to align source and target - some syntactic, others semantic, others positional patterns between languages.
