
Multi-Head Attention

Reading time: ~35 min · Interview relevance: Essential · Target roles: ML Engineer, AI Engineer, Research Engineer


The Meeting Room Problem

David's team at a financial NLP startup had built a solid sentence embedder using single-head self-attention. It worked well for most tasks, but struggled on a specific type of query: "Who authorized the transaction, and when?"

The problem wasn't the individual words - "who" and "when" were represented correctly. The problem was that a single attention head could only ask one "question" per layer. When the model tried to simultaneously capture "what entity performed an action" and "at what point in time," it couldn't. It would attend to the agent when it needed the time, or vice versa.

David's team tried adding more capacity - extra layers, wider projections. It helped slightly, but the fundamental issue remained. A single attention head learns a single type of relationship per layer. If the task requires simultaneously tracking multiple relationship types (agent, patient, time, location, modality), one head cannot capture all of them.

When they switched to multi-head attention, the difference was immediately visible. They visualized the attention patterns across 8 heads: Head 3 had learned to track syntactic subjects, Head 6 tracked temporal expressions, and Head 7 tracked possessive relationships. The model was naturally decomposing the problem into separate subproblems and solving each in parallel.

This is what multi-head attention does: it runs several attention operations simultaneously, each learning to focus on a different type of relationship. The outputs are concatenated and projected back, giving a final representation that incorporates multiple perspectives on the same input.


Why Single-Head Attention Is Insufficient

In the previous lesson, we implemented single-head attention. Each layer computes one attention function:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

The problem: $Q$, $K$, and $V$ are each projected from the same $d_{model}$-dimensional space into a single $d_k$-dimensional subspace. With one set of linear projections, the layer produces one attention distribution per query, so it can only compute one type of query-key relationship per layer.

Language requires simultaneously reasoning about multiple relationship types:

  • Syntactic: What is the subject? What is the object?
  • Semantic: What is the meaning? What is the context?
  • Coreference: Which pronoun refers to which entity?
  • Positional: What is near? What is far?

A single attention head picks one "focus" per layer. To capture all these relationships, you'd need either many layers (depth) or a better mechanism. Multi-head attention is that mechanism: run $h$ attention operations in parallel, each in a lower-dimensional subspace, then combine.


The Multi-Head Attention Formula

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where each head $i$ is:

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

The parameters:

  • $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$ - query projection for head $i$
  • $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$ - key projection for head $i$
  • $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ - value projection for head $i$
  • $W^O \in \mathbb{R}^{h d_v \times d_{model}}$ - output projection

In the original transformer: $d_k = d_v = d_{model} / h = 512 / 8 = 64$.

The key observation: each head has its own projection matrices. Head 1 has $W_1^Q, W_1^K, W_1^V$; head 2 has $W_2^Q, W_2^K, W_2^V$. Because these are learned independently (and initialized differently), the heads project into different linear subspaces and can learn different attention patterns.
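To make the formula concrete, here is a minimal NumPy sketch that follows it literally: a Python loop over heads, each with its own $W_i^Q, W_i^K, W_i^V$, followed by concatenation and $W^O$. It is unbatched and unmasked for readability; the function name `multi_head_naive` and the random weights are illustrative only.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # shift for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_naive(X, Wq_heads, Wk_heads, Wv_heads, W_o):
    """Literal MultiHead: one attention call per head, then concat and W^O.

    X: (seq, d_model); Wq_heads/Wk_heads/Wv_heads: lists of h matrices (d_model, d_k).
    W_o: (h * d_v, d_model).
    """
    heads = []
    for Wq, Wk, Wv in zip(Wq_heads, Wk_heads, Wv_heads):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv                   # each (seq, d_k)
        weights = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))  # (seq, seq), rows sum to 1
        heads.append(weights @ V)                          # (seq, d_v)
    return np.concatenate(heads, axis=-1) @ W_o            # (seq, d_model)


rng = np.random.default_rng(0)
seq, d_model, h = 5, 16, 4
d_k = d_model // h
Wq_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
Wk_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
Wv_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_o = rng.normal(size=(h * d_k, d_model))
X = rng.normal(size=(seq, d_model))

out = multi_head_naive(X, Wq_heads, Wk_heads, Wv_heads, W_o)
print(out.shape)  # (5, 16)
```

This loop is the version the next section replaces with a single fused computation.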


The Efficient Implementation

The naive implementation runs hh separate attention operations. The efficient implementation fuses them into a single batch operation using reshaping tricks.

Instead of running 8 separate $(batch, seq, 64) \to (batch, seq, 64)$ attentions, we:

  1. Project to $(batch, seq, h \times d_k)$ in one shot with a $(d_{model}, h \times d_k)$ matrix
  2. Reshape to $(batch, seq, h, d_k)$, then transpose to $(batch, h, seq, d_k)$
  3. Run batched attention across all heads simultaneously
  4. Transpose and reshape back

This is mathematically identical, but it replaces many small matrix multiplications with a few large batched ones, which GPUs execute far more efficiently.
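A quick NumPy check of steps 1, 2, and 4, with arbitrary toy sizes: after the fused projection plus reshape and transpose, head $i$ is exactly the projection with the $i$-th block of $d_k$ columns of the packed $W^Q$, and transposing back recovers the packed result.

```python
import numpy as np

rng = np.random.default_rng(1)
batch, seq, d_model, h = 2, 5, 16, 4
d_k = d_model // h

X = rng.normal(size=(batch, seq, d_model))
W_q = rng.normal(size=(d_model, h * d_k))  # all heads' W_i^Q packed column-wise

# Steps 1-2: one matmul, then reshape + transpose to (batch, h, seq, d_k)
Q = (X @ W_q).reshape(batch, seq, h, d_k).transpose(0, 2, 1, 3)

# Head i equals projecting with columns i*d_k:(i+1)*d_k of the packed matrix
for i in range(h):
    Q_i = X @ W_q[:, i * d_k:(i + 1) * d_k]  # (batch, seq, d_k)
    assert np.allclose(Q[:, i], Q_i)

# Step 4: transpose back and reshape recovers the packed projection
Q_back = Q.transpose(0, 2, 1, 3).reshape(batch, seq, h * d_k)
assert np.allclose(Q_back, X @ W_q)
print("fused projection matches per-head slices")
```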


Parameter Count Analysis

For a single multi-head attention block with $d_{model} = 512$, $h = 8$, $d_k = d_v = 64$:

| Matrix | Shape | Parameters |
|---|---|---|
| $W^Q$ (all heads combined) | $(512, 512)$ | 262,144 |
| $W^K$ (all heads combined) | $(512, 512)$ | 262,144 |
| $W^V$ (all heads combined) | $(512, 512)$ | 262,144 |
| $W^O$ | $(512, 512)$ | 262,144 |
| Total | - | 1,048,576 (~1M) |

Note that combining all heads' $W^Q$ matrices into one $(d_{model}, d_{model})$ matrix is equivalent to having $h$ separate $(d_{model}, d_k)$ matrices - the parameter count is the same. The efficiency comes from doing it as a single matmul.

For GPT-3 with $d_{model} = 12288$, $h = 96$, each attention block has $4 \times 12288^2 \approx 604\text{M}$ parameters. Multiplied by 96 layers, attention alone accounts for about 58B of the 175B total parameters.
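The same arithmetic as a small Python helper. It is a sketch that assumes $d_k = d_v = d_{model}/h$ and bias-free projections; the `mha_params` name and its `bias` flag are hypothetical, added only to show where bias terms would go.

```python
def mha_params(d_model: int, num_heads: int, bias: bool = False) -> int:
    """Parameters in one MHA block, assuming d_k = d_v = d_model // num_heads."""
    d_k = d_model // num_heads
    proj_in = 3 * d_model * (num_heads * d_k)   # W^Q, W^K, W^V (all heads packed)
    proj_out = (num_heads * d_k) * d_model      # W^O
    biases = (3 * num_heads * d_k + d_model) if bias else 0
    return proj_in + proj_out + biases


print(mha_params(512, 8))           # 1048576
print(mha_params(12288, 96))        # 603979776
print(mha_params(12288, 96) * 96)   # 57982058496
# At fixed d_model, the head count does not change the total:
print({h: mha_params(512, h) for h in (1, 8, 16)})
```

The last line is the reason the head count is a modeling choice rather than a parameter-budget choice: splitting $d_{model}$ into more heads reuses the same weights.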


NumPy Implementation from Scratch

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    x_shifted = x - x.max(axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / exp_x.sum(axis=axis, keepdims=True)


def multi_head_attention_numpy(
    X: np.ndarray,
    W_q: np.ndarray,
    W_k: np.ndarray,
    W_v: np.ndarray,
    W_o: np.ndarray,
    num_heads: int,
    mask: np.ndarray = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Multi-head self-attention from scratch.

    Args:
        X: (batch, seq, d_model)
        W_q, W_k, W_v: (d_model, d_model) -- all heads packed together
        W_o: (d_model, d_model)
        num_heads: h
        mask: boolean (batch, 1, seq, seq), True = may attend; or None

    Returns:
        output: (batch, seq, d_model)
        all_attn_weights: (batch, num_heads, seq, seq)
    """
    batch, seq, d_model = X.shape
    d_k = d_model // num_heads

    # === Step 1: Project to Q, K, V (all heads at once) ===
    # Shape after projection: (batch, seq, d_model)
    Q_full = X @ W_q  # (batch, seq, d_model)
    K_full = X @ W_k
    V_full = X @ W_v

    # === Step 2: Split into heads ===
    # Reshape: (batch, seq, d_model) -> (batch, seq, num_heads, d_k)
    # Transpose: -> (batch, num_heads, seq, d_k)
    Q = Q_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)
    K = K_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)
    V = V_full.reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)

    # === Step 3: Scaled dot-product attention for all heads in parallel ===
    # Q: (batch, num_heads, seq, d_k)
    # K^T: (batch, num_heads, d_k, seq)
    # scores: (batch, num_heads, seq, seq)
    scores = np.matmul(Q, K.transpose(0, 1, 3, 2)) / np.sqrt(d_k)

    if mask is not None:
        # Masked positions get a large negative score -> ~0 weight after softmax
        scores = np.where(mask, scores, -1e9)

    attn_weights = softmax(scores, axis=-1)  # (batch, num_heads, seq, seq)

    # (batch, num_heads, seq, seq) @ (batch, num_heads, seq, d_k)
    # -> (batch, num_heads, seq, d_k)
    context = np.matmul(attn_weights, V)

    # === Step 4: Concatenate heads ===
    # Transpose: (batch, num_heads, seq, d_k) -> (batch, seq, num_heads, d_k)
    # Reshape: -> (batch, seq, d_model)
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)

    # === Step 5: Output projection ===
    output = context @ W_o  # (batch, seq, d_model)

    return output, attn_weights


# Test
np.random.seed(42)
batch, seq, d_model, num_heads = 2, 6, 16, 4
d_k = d_model // num_heads  # 4

# Initialize weights
scale = 0.1
W_q = np.random.randn(d_model, d_model) * scale
W_k = np.random.randn(d_model, d_model) * scale
W_v = np.random.randn(d_model, d_model) * scale
W_o = np.random.randn(d_model, d_model) * scale

X = np.random.randn(batch, seq, d_model)

output, weights = multi_head_attention_numpy(X, W_q, W_k, W_v, W_o, num_heads)

print(f"Input shape: {X.shape}")         # (2, 6, 16)
print(f"Output shape: {output.shape}")   # (2, 6, 16)
print(f"Weights shape: {weights.shape}")  # (2, 4, 6, 6)

# Each head's attention should sum to 1 per query
print("\nAttention weights sum (head 0, item 0):")
print(np.round(weights[0, 0].sum(axis=-1), 4))  # [1. 1. 1. 1. 1. 1.]
```

PyTorch Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Production-grade multi-head attention implementation.
    Matches the original Vaswani et al. (2017) formulation.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Combined projections - more efficient than separate per-head matrices
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(
        self,
        query: torch.Tensor,        # (batch, seq_q, d_model)
        key: torch.Tensor,          # (batch, seq_k, d_model)
        value: torch.Tensor,        # (batch, seq_k, d_model)
        mask: torch.Tensor = None,  # broadcastable to (batch, num_heads, seq_q, seq_k); 0 = masked
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch = query.size(0)

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            """(batch, seq, d_model) -> (batch, num_heads, seq, d_k)"""
            seq = x.size(1)
            return x.view(batch, seq, self.num_heads, self.d_k).transpose(1, 2)

        # Project to Q, K, V and split into heads
        Q = split_heads(self.W_q(query))  # (batch, h, seq_q, d_k)
        K = split_heads(self.W_k(key))    # (batch, h, seq_k, d_k)
        V = split_heads(self.W_v(value))  # (batch, h, seq_k, d_k)

        # Scaled dot-product attention across all heads in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Fully masked rows give NaN after softmax; map them to zero weight
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values
        context = torch.matmul(attn_weights, V)  # (batch, h, seq_q, d_k)

        # Concatenate heads: (batch, h, seq_q, d_k) -> (batch, seq_q, d_model)
        context = context.transpose(1, 2).contiguous().view(batch, -1, self.d_model)

        # Output projection
        output = self.W_o(context)

        return output, attn_weights  # return weights for inspection


# === Test and verification ===
torch.manual_seed(42)

d_model, num_heads = 64, 8
d_k = d_model // num_heads  # 8 per head

mha = MultiHeadAttention(d_model, num_heads)

# Self-attention: query=key=value (same input)
batch, seq = 3, 10
x = torch.randn(batch, seq, d_model)
out, weights = mha(x, x, x)

print(f"Output shape: {out.shape}")       # (3, 10, 64)
print(f"Weights shape: {weights.shape}")  # (3, 8, 10, 10)

# Verify attention weights sum to 1 (eval mode disables dropout)
mha.eval()
_, weights_eval = mha(x, x, x)
sums = weights_eval[0, :, 0, :].sum(dim=-1)
print(f"Attention sums (all heads, query 0): {torch.round(sums, decimals=4)}")
# Should all be ~1.0

# Cross-attention: query from decoder, key/value from encoder
seq_decoder, seq_encoder = 7, 12
query = torch.randn(batch, seq_decoder, d_model)
encoder_output = torch.randn(batch, seq_encoder, d_model)
out_cross, weights_cross = mha(query, encoder_output, encoder_output)
print(f"\nCross-attention output: {out_cross.shape}")     # (3, 7, 64)
print(f"Cross-attention weights: {weights_cross.shape}")  # (3, 8, 7, 12)
```

Head Specialization: What Different Heads Learn

One of the most fascinating properties of multi-head attention is that different heads spontaneously learn to track different linguistic phenomena, without any explicit supervision.

Voita et al. (2019), "Analyzing Multi-Head Self-Attention," studied the encoder heads of a 6-layer, 8-head transformer trained for machine translation and found:

Positional heads: Attend to an adjacent token (one position to the left or right). These learn local patterns - essentially functioning like n-gram detectors.

Syntactic heads: Attend to syntactically related words (subject-verb pairs, noun-adjective pairs). When visualized, these heads show clear syntactic structure.

Rare word heads: Attend disproportionately to low-frequency, semantically important tokens. These heads seem to boost signal from content words.
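One simple way to flag a positional head, loosely following the kind of criterion Voita et al. describe (how often a head's maximum attention lands on one fixed relative position), is to count argmax hits per offset. This NumPy sketch uses a hypothetical `positional_head_score` helper on a hand-built toy head; on a real model you would average over many sentences and compare the score against a high threshold of your choosing.

```python
import numpy as np


def positional_head_score(attn: np.ndarray, offset: int) -> float:
    """Fraction of queries whose largest attention weight sits at position q + offset.

    attn: (seq, seq) attention weights for one head (rows = queries).
    Queries where q + offset falls outside the sequence are skipped.
    """
    seq = attn.shape[0]
    queries = [q for q in range(seq) if 0 <= q + offset < seq]
    hits = sum(int(attn[q].argmax() == q + offset) for q in queries)
    return hits / len(queries)


# Toy head that always attends to the previous token (token 0 attends to itself)
seq = 6
prev_token_head = np.zeros((seq, seq))
prev_token_head[0, 0] = 1.0
for q in range(1, seq):
    prev_token_head[q, q - 1] = 1.0

print(positional_head_score(prev_token_head, offset=-1))  # 1.0
print(positional_head_score(prev_token_head, offset=+1))  # 0.0
```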

The critical finding: only a small subset of heads is crucial. On English-Russian WMT translation, Voita et al. pruned 38 of the 48 encoder heads (about 80%) at a cost of only 0.15 BLEU. Many heads can be removed from a trained model with little loss.
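Mechanically, "removing a head" can be simulated with a 0/1 gate on each head's output before the concatenation. The NumPy sketch below assumes a hypothetical `head_mask` vector; it illustrates what pruning does to the forward pass, not the gate-learning procedure Voita et al. use to decide which heads to prune.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mha_with_head_mask(X, W_q, W_k, W_v, W_o, num_heads, head_mask):
    """Multi-head self-attention with a 0/1 gate per head (hypothetical helper).

    X: (seq, d_model); W_*: (d_model, d_model), heads packed; head_mask: (num_heads,).
    """
    seq, d_model = X.shape
    d_k = d_model // num_heads

    def split(M):  # (seq, d_model) -> (num_heads, seq, d_k)
        return M.reshape(seq, num_heads, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    context = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)) @ V  # (h, seq, d_k)
    context = context * head_mask[:, None, None]                    # gate each head
    return context.transpose(1, 0, 2).reshape(seq, d_model) @ W_o


rng = np.random.default_rng(0)
seq, d_model, h = 5, 16, 4
d_k = d_model // h
X = rng.normal(size=(seq, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))

full = mha_with_head_mask(X, W_q, W_k, W_v, W_o, h, np.ones(h))
pruned = mha_with_head_mask(X, W_q, W_k, W_v, W_o, h, np.array([1.0, 1.0, 0.0, 1.0]))
print(np.abs(full - pruned).max())  # how much removing head 2 changes the output

# Same effect as zeroing the rows of W^O that read head 2's output
W_o_pruned = W_o.copy()
W_o_pruned[2 * d_k:3 * d_k] = 0.0
assert np.allclose(pruned, mha_with_head_mask(X, W_q, W_k, W_v, W_o_pruned, h, np.ones(h)))
```

The final check is why a pruned head can be physically deleted: its slice of $W^Q$, $W^K$, $W^V$ and its rows of $W^O$ no longer affect the output.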


Choosing the Number of Heads

Real model configurations:

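| Model | $d_{model}$ | Heads ($h$) | $d_k$ per head |
|---|---|---|---|
| Transformer-base | 512 | 8 | 64 |
| Transformer-large | 1024 | 16 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-2 | 768 | 12 | 64 |
| GPT-3 | 12,288 | 96 | 128 |
| LLaMA-2 70B | 8,192 | 64 | 128 |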

The pattern: $d_k$ per head is typically 64 or 128, and the number of heads scales with $d_{model}$. The $d_k = 64$ choice is empirically good - small enough that each head has a focused subspace, large enough to represent meaningful relationships.

Increasing heads without increasing $d_{model}$ means smaller $d_k$ per head, which limits what each head can represent. As a rule of thumb, keep $d_k \geq 32$.

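At fixed $d_{model}$, more heads do not mean more parameters: the projections stay $d_{model} \times d_{model}$ and are simply split into more, narrower subspaces. What grows with $h$ is the number of $n \times n$ attention maps to compute, normalize, and store. The output projection $W^O$ learns to weight the contributions of each head appropriately.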

:::tip Grouped Query Attention (GQA) Many recent large models (for example LLaMA-2 70B and Mistral 7B) use Grouped Query Attention: instead of having $h$ separate key/value projections, they share each K/V head across a group of query heads. With $g$ K/V heads, this shrinks the key-value cache during inference by a factor of $h/g$. LLaMA-2 70B uses GQA with 8 K/V heads shared across 64 query heads, an 8× reduction. :::


Multi-Query Attention (MQA)

An even more aggressive form of sharing: all query heads share a single K and V head (Shazeer, 2019). Used in PaLM and Falcon-7B.

  • Standard MHA: $h$ query heads, $h$ key heads, $h$ value heads
  • GQA: $h$ query heads, $g$ key heads, $g$ value heads ($g < h$)
  • MQA: $h$ query heads, 1 key head, 1 value head

The motivation is purely inference-time efficiency: the key-value cache (KV cache) during generation stores K and V tensors for all previous tokens, for all heads, for all layers. With MQA, the KV cache is $h$ times smaller, enabling either longer context or larger batch sizes at the same memory budget.
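All three variants fit in one function if K and V simply carry fewer heads than Q. This NumPy sketch (hypothetical `grouped_query_attention`, toy sizes) repeats each K/V head across its group of query heads for clarity; optimized implementations typically avoid materializing these copies, and only the $g$ K/V heads need to be cached.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def grouped_query_attention(Q, K, V):
    """Q: (h, seq, d_k); K, V: (g, seq, d_k) with h % g == 0.

    g == h is standard MHA, 1 < g < h is GQA, g == 1 is MQA.
    Each run of h // g consecutive query heads reads the same K/V head.
    """
    h, g = Q.shape[0], K.shape[0]
    assert h % g == 0, "query heads must split evenly into K/V groups"
    K = np.repeat(K, h // g, axis=0)  # (h, seq, d_k): copies, not new parameters
    V = np.repeat(V, h // g, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V        # (h, seq, d_k)


rng = np.random.default_rng(0)
h, g, seq, d_k = 8, 2, 5, 4
Q = rng.normal(size=(h, seq, d_k))
K = rng.normal(size=(g, seq, d_k))
V = rng.normal(size=(g, seq, d_k))

out = grouped_query_attention(Q, K, V)
print(out.shape)  # (8, 5, 4)
print(f"K/V heads cached: {g} of {h} -> cache is {h // g}x smaller")
```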


Production Engineering Notes

KV Cache During Generation

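When generating text autoregressively, each new token must attend to all previous tokens. A naive implementation reruns the model over the whole prefix at every step, recomputing K and V for every earlier token (and attention for every earlier query): step $t$ costs $O(t^2)$ attention work, so generating $n$ tokens costs $O(n^3)$ in total.

The KV cache stores the K and V tensors from all previous steps. For each new token, you compute Q, K, and V only for that token, append its K and V to the cache, and attend from its single query to the cached keys and values. Each step then costs $O(t)$ attention work, or $O(n^2)$ over the whole generation.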

Memory cost: For GPT-3 ($d_{model} = 12288$, $h = 96$, $d_k = 128$, 96 layers) in fp16, the KV cache for one 4096-token sequence is $2 \times 96 \times 96 \times 4096 \times (12288/96) \times 2\ \text{bytes} \approx 19.3\ \text{GB}$ (18 GiB), where the factors are K and V, layers, heads, tokens, $d_k$, and bytes per value.

This is why batch size during inference is severely memory-limited, and why MQA/GQA matter.
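The cache formula as a hypothetical Python helper, `kv_cache_bytes`, assuming fp16/bf16 storage (2 bytes per value). GPT-3 itself uses standard MHA; the MQA and 8-K/V-head GQA lines reuse its shapes purely as what-if comparisons.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """K and V for every layer, K/V head, and cached token."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem


gpt3 = kv_cache_bytes(n_layers=96, n_kv_heads=96, d_head=128, seq_len=4096)
print(gpt3 / 1e9, gpt3 / 2**30)  # ~19.33 GB, 18.0 GiB per sequence

# Hypothetical: same shapes with MQA (1 K/V head) or GQA (8 K/V heads)
print(kv_cache_bytes(96, 1, 128, 4096) / 2**30)  # 0.1875 GiB
print(kv_cache_bytes(96, 8, 128, 4096) / 2**30)  # 1.5 GiB
```

Because the cache grows linearly with `batch`, the same budget that holds one GPT-3-sized MHA sequence could hold 96 sequences under the MQA variant.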

Flash Attention Compatibility

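Flash Attention 2 (Dao, 2023) supports multi-head attention natively. In PyTorch 2.0+, `F.scaled_dot_product_attention` dispatches to a fused kernel such as Flash Attention when the inputs allow it: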

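```python
import torch
import torch.nn.functional as F

batch, heads, seq, d_k = 2, 8, 128, 64
query, key, value = (torch.randn(batch, heads, seq, d_k) for _ in range(3))

# Uses a fused kernel (e.g. Flash Attention) when device, dtype and shapes allow it
output = F.scaled_dot_product_attention(
    query,  # (batch, heads, seq_q, d_k)
    key,    # (batch, heads, seq_k, d_k)
    value,  # (batch, heads, seq_k, d_v)
    is_causal=True,  # causal mask for decoder self-attention
)
```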

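The key constraint: the Flash Attention backend only runs on supported CUDA GPUs with float16/bfloat16 inputs and supported head dimensions and layouts; otherwise PyTorch falls back to another backend without raising an error. Always profile before assuming it's being used.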


Common Mistakes

:::danger Mismatched head dimensions at the output projection After concatenating heads, the tensor shape is (batch, seq, num_heads * d_k). If d_model is not divisible by num_heads, d_k = d_model // num_heads rounds down and num_heads * d_k < d_model. Depending on how the projections are sized, the model then either fails at the split-heads reshape or silently runs attention in a smaller space than intended. Always assert d_model == num_heads * d_k (equivalently, d_model % num_heads == 0) at init time. :::

:::danger Transposing the wrong dimensions when splitting heads The reshape sequence is: (batch, seq, d_model) → reshape to (batch, seq, h, d_k) → transpose to (batch, h, seq, d_k). Merging heads reverses it: transpose back to (batch, seq, h, d_k), then reshape to (batch, seq, d_model). After a transpose the tensor is non-contiguous, so .view() raises a RuntimeError unless you call .contiguous() first (.reshape() copies when needed). The silent bug is skipping the transpose: reshaping (batch, h, seq, d_k) straight to (batch, seq, d_model) runs without error but mixes tokens and heads together. :::
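A tiny NumPy demonstration of that silent bug, with toy sizes: every element of a fake per-head context tensor is labeled `100 * head + 10 * token + feature`, so you can read off where each value ends up after merging heads with and without the transpose.

```python
import numpy as np

batch, h, seq, d_k = 1, 2, 3, 2
# Label every element: 100 * head + 10 * token + feature
context = np.array([[[[100 * i + 10 * t + f for f in range(d_k)]
                      for t in range(seq)] for i in range(h)]])  # (batch, h, seq, d_k)

right = context.transpose(0, 2, 1, 3).reshape(batch, seq, h * d_k)
wrong = context.reshape(batch, seq, h * d_k)  # skips the transpose; no error raised

print(right[0, 0].tolist())  # [0, 1, 100, 101] -> token 0: head 0 features, then head 1
print(wrong[0, 0].tolist())  # [0, 1, 10, 11]   -> tokens 0 and 1 of head 0 mixed together
```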

:::warning Attention in inference: forgetting to disable dropout Multi-head attention often includes attention dropout (applied to weights after softmax). In training mode this is correct. In evaluation/inference mode, model.eval() disables it - but if you manually set model.train() for some part of inference (e.g., for MC Dropout uncertainty estimation), remember that attention patterns will be randomly dropped. :::

:::tip Debugging with head patterns If your model is underperforming, visualize attention weights per head before diving into hyperparameters. A healthy multi-head attention layer usually shows diverse patterns across heads. If all heads look identical, the heads may have collapsed onto the same subspace - try adjusting initialization or adding dropout. If all heads are near-uniform (no clear pattern), the model may not be training properly, for example because the learning rate is too high. :::
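The two symptoms in the tip can be quantified roughly: per-head entropy (close to $\log(\text{seq})$ means near-uniform attention) and cosine similarity between heads' attention maps (close to 1 means duplicated heads). A sketch with a hypothetical `head_diagnostics` helper on hand-built toy maps; what counts as "too similar" or "too uniform" is a judgment call.

```python
import numpy as np


def head_diagnostics(attn: np.ndarray):
    """attn: (h, seq, seq) attention weights for one example (rows sum to 1).

    Returns mean entropy per head (nats) and the pairwise cosine similarity
    between the heads' flattened attention maps.
    """
    h = attn.shape[0]
    entropy = -(attn * np.log(attn + 1e-12)).sum(axis=-1).mean(axis=-1)  # (h,)
    flat = attn.reshape(h, -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    similarity = flat @ flat.T  # (h, h)
    return entropy, similarity


seq = 5
uniform = np.full((seq, seq), 1.0 / seq)  # no clear pattern: entropy = log(seq)
diagonal = np.eye(seq)                    # each token attends only to itself
attn = np.stack([uniform, diagonal, diagonal])  # heads 1 and 2 are duplicates

entropy, similarity = head_diagnostics(attn)
print(entropy)           # head 0 near log(5), heads 1 and 2 near 0
print(similarity[1, 2])  # near 1 -> duplicated heads
```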


Interview Q&A

Q1: Why use multiple attention heads instead of one big attention head with the same parameter count?

Answer: It's not just about parameter count - it's about expressivity in multiple subspaces.

With a single head of dimension $d_k = d_{model}$, you get one attention pattern per layer. The model must choose: does it compute syntactic relationships? Semantic relationships? Coreference? It can only pick one "question" to ask per layer.

With $h$ heads each of dimension $d_k = d_{model}/h$, you get $h$ different attention patterns. Each head's projection matrices ($W_i^Q$, $W_i^K$, $W_i^V$) project the input into a different subspace. Gradients during training specialize these subspaces for different types of relationships.

The parameter count is the same: $h$ heads of dimension $d_{model}/h$ vs 1 head of dimension $d_{model}$. But the functional capacity is completely different - you get $h$ independent views of the same sequence rather than one.

Empirically: in the ablation of Vaswani et al. (2017), single-head attention was 0.9 BLEU worse than the best setting on their English-German development set; 8 and 16 heads performed about equally well, and quality dropped again at 32 heads.

Q2: What does the output projection $W^O$ do? Why is it necessary?

Answer: After concatenating $h$ attention heads, you have a tensor of shape $(batch, seq, h \cdot d_v)$. The output projection $W^O \in \mathbb{R}^{h d_v \times d_{model}}$ projects this back to $(batch, seq, d_{model})$.

It does two things:

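  1. Dimensionality restoration: The residual connection requires an output of dimension $d_{model}$. When $h \cdot d_v \neq d_{model}$, $W^O$ is what maps back; in the standard setup $h \cdot d_v = d_{model}$, so the shapes already match and the mixing role below is what matters.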

  2. Linear mixing of head outputs: The heads compute $h$ parallel, potentially conflicting, views of the input. The output projection learns to combine them, weighting and mixing features from different heads into a single coherent output. It is one fixed linear map, applied identically at every position.

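Without $W^O$, you'd simply concatenate the heads and add them to the residual, so each head could only write to its own $d_v$-wide slice of the residual stream. The output projection lets every head influence every output dimension, giving the model freedom to learn non-trivial combinations of the per-head outputs.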
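A quick NumPy check of the mixing view, with arbitrary toy sizes: concatenating heads and multiplying by $W^O$ equals projecting each head with its own block of $d_v$ rows of $W^O$ and summing. With an identity in place of $W^O$, head $i$'s output stays confined to its own slice.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, h, d_v, d_model = 4, 3, 2, 6
heads = [rng.normal(size=(seq, d_v)) for _ in range(h)]  # per-head outputs
W_o = rng.normal(size=(h * d_v, d_model))

# Concat-then-project ...
out_concat = np.concatenate(heads, axis=-1) @ W_o
# ... equals each head projected by its own row block of W^O, then summed
out_sum = sum(heads[i] @ W_o[i * d_v:(i + 1) * d_v] for i in range(h))
assert np.allclose(out_concat, out_sum)

# With no mixing (identity), head 0 only appears in columns 0:d_v
out_identity = np.concatenate(heads, axis=-1) @ np.eye(h * d_v)
assert np.allclose(out_identity[:, 0:d_v], heads[0])
print("W^O = per-head projections summed")
```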

Q3: In production, GPT-3 uses 96 attention heads. Is there evidence that all 96 are doing different things?

Answer: Not all 96 heads are necessary or doing unique things. Research consistently shows attention head redundancy:

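  • Michel et al. (2019) showed that a large fraction of heads can be removed at test time without a noticeable accuracy drop (roughly 40% of BERT's heads on their benchmark and 20% for a WMT translation model), and that some layers can even be reduced to a single head.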
  • Voita et al. (2019) found that most crucial heads fell into a small number of functional categories (positional, syntactic, rare-word).

In large models like GPT-3, most heads likely serve similar functions across different layers. The value of having 96 heads is not that each is unique - it's that having many heads gives the model spare capacity during training to discover useful specializations, and the redundancy provides robustness.

Practically: for inference optimization, you can prune attention heads. When fine-tuning on small datasets, reducing the number of heads may reduce overfitting. For most production applications, keep the pre-trained head structure intact.

Q4: Explain the computational and memory cost of multi-head attention for a 4096-token context.

Answer: For a single MHA block with $h$ heads, model dimension $d_{model}$, and sequence length $n$:

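Compute: The quadratic part of attention is $QK^T$ plus the weighted sum over $V$, each costing $n^2 d_{model}$ multiply-accumulates (MACs) summed over heads. For $n = 4096$, $d_{model} = 4096$: $2 \times 4096^3 \approx 137$ billion MACs, or about 275 GFLOPs per layer, and about 8.8 TFLOPs across 32 layers. This excludes the Q/K/V/O projections ($4 n d_{model}^2$ MACs) and a 4×-expansion FFN ($8 n d_{model}^2$ MACs), which at this length together cost about six times more than the quadratic term; the quadratic term only dominates once $n$ exceeds roughly $6 d_{model}$.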

Memory: The attention weight matrix alone is $(batch \times h \times n \times n)$ float16 values. For batch=1, $h=32$, $n=4096$: $32 \times 4096^2 \times 2$ bytes = 1 GiB just for attention weights, per layer. This is why Flash Attention, which never materializes this matrix, is standard practice for long contexts.

KV Cache: For autoregressive generation with a 4096-token context, storing K and V for all layers in fp16 takes $2 \times \text{layers} \times n \times d_{model} \times 2\ \text{bytes}$: 2 GiB per sequence for the 32-layer, $d_{model} = 4096$ model above, and about 18 GiB (19.3 GB) for GPT-3. At large batch sizes this can rival or exceed the memory used by the weights themselves.
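The compute and attention-weight memory figures above as a hypothetical Python helper, `attention_costs`. It counts only the two quadratic matmuls and one layer's materialized attention weights at batch size 1, using 1 MAC = 2 FLOPs; projections, FFN, and softmax are deliberately left out.

```python
def attention_costs(n, d_model, n_heads, n_layers, bytes_per_elem=2):
    """Rough costs of the quadratic attention terms for one sequence."""
    macs_scores = n * n * d_model  # Q @ K^T, summed over heads (h * n^2 * d_k)
    macs_values = n * n * d_model  # attention weights @ V
    flops_layer = 2 * (macs_scores + macs_values)
    attn_matrix_bytes = n_heads * n * n * bytes_per_elem  # one layer's (h, n, n) weights
    return flops_layer, flops_layer * n_layers, attn_matrix_bytes


per_layer, total, attn_bytes = attention_costs(n=4096, d_model=4096, n_heads=32, n_layers=32)
print(per_layer / 1e9)     # ~274.9 GFLOPs per layer
print(total / 1e12)        # ~8.8 TFLOPs across 32 layers
print(attn_bytes / 2**30)  # 1.0 GiB of attention weights per layer
```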

Q5: What is the difference between self-attention and multi-head attention? Can you have multi-head cross-attention?

Answer:

  • Self-attention describes where Q, K, and V come from: all from the same sequence.
  • Multi-head attention describes how attention is computed: several attention operations run in parallel, each in its own subspace.

These are orthogonal concepts that are combined:

  • Multi-head self-attention (MHSA): Used in encoder layers and in the masked self-attention of decoder layers. Q, K, V all come from the same sequence, run in $h$ parallel heads.
  • Multi-head cross-attention: Used in the decoder layers of encoder-decoder models. Q comes from the decoder, K/V from the encoder output, run in $h$ parallel heads.

Yes. In standard encoder-decoder transformers, cross-attention is multi-head. The cross-attention heads can learn different ways to align source and target; for example, some may track syntactic alignment, others semantic alignment or positional patterns between languages.


© 2026 EngineersOfAI. All rights reserved.