
Attention Mechanism - The Idea That Changed Everything

Reading time: ~40 min | Interview relevance: Critical | Roles: MLE, AI Eng, Research Engineer, NLP Engineer, LLM Engineer

The Real Interview Moment

You are in an OpenAI research engineer interview. The interviewer writes on the whiteboard:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

And asks: "Derive this from first principles. What problem does it solve? Why do we divide by $\sqrt{d_k}$? What would happen if we removed the scaling? Now extend it to multi-head attention and explain why multiple heads help."

This is the single most important equation in modern AI. Every Transformer, every LLM, every Vision Transformer, and most modern diffusion models use this formula. Interviewers at Google, Meta, Anthropic, and every serious AI lab expect you to know it cold. That means not just the formula, but also the motivation, the derivation, the numerical issues, and the design decisions.

Candidates who can recite the formula but cannot explain the scaling factor get a "lean no-hire." Candidates who can derive QKV from the information retrieval analogy, explain the softmax temperature effect, and discuss attention patterns get a "strong hire." This page takes you from zero to that level.

What You Will Master

  • Explain the seq2seq bottleneck that motivated attention
  • Derive Bahdanau (additive) attention and Luong (multiplicative) attention from scratch
  • Build up to self-attention from encoder-decoder attention
  • Derive the QKV formulation with clear intuition for each component
  • Prove why we scale by $\sqrt{d_k}$ using variance analysis
  • Extend to multi-head attention and explain the benefits
  • Distinguish self-attention from cross-attention with concrete examples
  • Analyze attention patterns and what they reveal about model behavior
  • Solve interview problems on attention complexity, masking, and design tradeoffs

Self-Assessment: Where Are You Now?

Scale: 1 = Cannot, 2 = Vaguely, 3 = Can Explain, 4 = Can Derive, 5 = Can Teach

| Skill | Your Score (1-5) |
|---|---|
| Explain the seq2seq bottleneck | ___ |
| Derive Bahdanau attention | ___ |
| Write the scaled dot-product attention formula | ___ |
| Explain why we divide by sqrt(d_k) | ___ |
| Derive multi-head attention | ___ |
| Explain Q, K, V with information retrieval analogy | ___ |
| Distinguish self-attention from cross-attention | ___ |
| Compute attention complexity (time and space) | ___ |

Target: All 4s and 5s before your interview.

Part 1 - The Problem: Seq2Seq Bottleneck

What Went Wrong with Encoder-Decoder RNNs

In the original seq2seq model (Sutskever et al., 2014), the encoder reads the entire input sequence and compresses it into a single fixed-size vector $c = h_T^{\text{enc}}$. The decoder then generates the output conditioned only on this vector.

[Figure: Seq2Seq Encoder-Decoder Bottleneck: All Information in One Vector]

The problem: For a sentence like "The cat sat on the mat and watched the birds fly over the garden fence," the entire meaning must be compressed into a vector of (say) 512 dimensions. As sentences get longer, the quality degrades sharply - the BLEU score for translation dropped significantly beyond 20-30 tokens.

The insight: When translating word $t$ of the output, the decoder does not need equal access to all encoder states. It needs to focus on the relevant parts. When producing a French verb, the decoder should attend primarily to the corresponding English verb, not the entire sentence.

60-Second Answer

"The attention mechanism was introduced to solve the information bottleneck in seq2seq models. Instead of compressing the entire input into one vector, attention lets the decoder look at all encoder hidden states and dynamically compute a weighted combination based on relevance to the current decoding step. This gives the decoder direct access to any part of the input, regardless of sequence length. The weighted combination is called the context vector, and the weights are called attention weights - they form a probability distribution over the input positions."

Part 2 - Bahdanau Attention (Additive)

The Mechanism (2014)

Bahdanau et al. introduced the first attention mechanism for neural machine translation. At each decoder step $t$:

Step 1: Compute alignment scores.

$$e_{t,i} = a(s_{t-1}, h_i) = v^T \tanh(W_s s_{t-1} + W_h h_i)$$

Where $s_{t-1}$ is the decoder state, $h_i$ is the $i$-th encoder hidden state, and $v$, $W_s$, $W_h$ are learned parameters.

Step 2: Convert scores to weights.

$$\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^{T_x} \exp(e_{t,j})}$$

The softmax ensures the weights sum to 1 - they form a probability distribution over the input.

Step 3: Compute context vector.

$$c_t = \sum_{i=1}^{T_x} \alpha_{t,i} \cdot h_i$$

Step 4: Update decoder state.

$$s_t = \text{RNN}(s_{t-1}, [y_{t-1}; c_t])$$
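Steps 1 to 3 can be sketched in NumPy for a single decoder step. This is a minimal illustration, not the paper's code. The function name `bahdanau_step`, the choice to stack the encoder states $h_i$ as rows of `H`, and the toy dimensions are all assumptions. The Step 4 RNN update is left out.

```python
import numpy as np

def bahdanau_step(s_prev, H, W_s, W_h, v):
    """One decoder step of additive (Bahdanau-style) attention.

    s_prev: previous decoder state s_{t-1}, shape (d_s,)
    H:      encoder states h_1..h_Tx stacked as rows, shape (T_x, d_h)
    W_s:    (d_a, d_s), W_h: (d_a, d_h), v: (d_a,)
    Returns attention weights alpha (T_x,) and context vector c (d_h,).
    """
    # Step 1: e_{t,i} = v^T tanh(W_s s_{t-1} + W_h h_i) for every i at once
    e = np.tanh(s_prev @ W_s.T + H @ W_h.T) @ v   # (T_x,)
    # Step 2: softmax over input positions (max subtracted for stability)
    w = np.exp(e - e.max())
    alpha = w / w.sum()
    # Step 3: context vector c_t = sum_i alpha_{t,i} h_i
    c = alpha @ H                                  # (d_h,)
    return alpha, c

rng = np.random.default_rng(0)
T_x, d_h, d_s, d_a = 6, 8, 8, 10
H = rng.normal(size=(T_x, d_h))
s_prev = rng.normal(size=d_s)
W_s = rng.normal(size=(d_a, d_s))
W_h = rng.normal(size=(d_a, d_h))
v = rng.normal(size=d_a)

alpha, c = bahdanau_step(s_prev, H, W_s, W_h, v)
print(alpha.shape, c.shape, round(float(alpha.sum()), 6))  # (6,) (8,) 1.0
```

`H @ W_h.T` does not depend on $t$, so an implementation would typically compute it once per source sentence and reuse it at every decoder step.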

Why "Additive"?

The score function uses addition inside the tanh: $W_s s_{t-1} + W_h h_i$. The decoder state and encoder state are projected to the same space and added. This is an additive alignment model.

Complexity: $O(T_x \cdot d)$ per decoder step, where $d$ is the alignment hidden dimension. This assumes the $W_h h_i$ projections are computed once per source sentence and reused. The $v^T \tanh(\ldots)$ computation requires evaluating a small neural network for every encoder-decoder position pair.

[Figure: Bahdanau Attention - Score, Softmax, Weighted Context at Each Decoder Step]

Part 3 - Luong Attention (Multiplicative)

The Simplified Mechanism (2015)

Luong et al. simplified the alignment function to a dot product (or bilinear form):

Dot-product scoring:

$$e_{t,i} = s_t^T h_i$$

General (bilinear) scoring:

$$e_{t,i} = s_t^T W_a h_i$$

Concat scoring (same additive form as Bahdanau, but using the current state $s_t$):

$$e_{t,i} = v^T \tanh(W_a [s_t; h_i])$$
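The three scoring functions can be written side by side. This is a minimal NumPy sketch. `luong_scores`, its `mode` argument, and the row-stacked `H` are illustrative names, not an existing API.

```python
import numpy as np

def luong_scores(s_t, H, mode="dot", W_a=None, v=None):
    """Alignment scores e_{t,i} for every encoder state h_i (rows of H).

    mode="dot":     e_i = s_t^T h_i              (requires d_s == d_h)
    mode="general": e_i = s_t^T W_a h_i          W_a: (d_s, d_h)
    mode="concat":  e_i = v^T tanh(W_a [s_t; h_i])  W_a: (d_a, d_s + d_h), v: (d_a,)
    """
    if mode == "dot":
        return H @ s_t
    if mode == "general":
        # s_t^T W_a h_i == h_i^T (W_a^T s_t), computed for all i in one matmul
        return H @ (W_a.T @ s_t)
    if mode == "concat":
        # Pair s_t with every h_i by repeating it alongside each row of H
        S = np.broadcast_to(s_t, (H.shape[0], s_t.shape[0]))
        return np.tanh(np.concatenate([S, H], axis=1) @ W_a.T) @ v
    raise ValueError(f"unknown mode: {mode}")

rng = np.random.default_rng(1)
H = rng.normal(size=(5, 4))
s_t = rng.normal(size=4)
dot = luong_scores(s_t, H, "dot")
gen = luong_scores(s_t, H, "general", W_a=np.eye(4))
print(np.allclose(dot, gen))  # True: "general" with W_a = I reduces to "dot"
```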

Bahdanau vs Luong Comparison

| Aspect | Bahdanau (Additive) | Luong (Multiplicative) |
|---|---|---|
| Score function | $v^T \tanh(W_s s + W_h h)$ | $s^T W_a h$ or $s^T h$ |
| Decoder state used | $s_{t-1}$ (previous) | $s_t$ (current) |
| Parameters | $W_s$, $W_h$, $v$ | $W_a$ (or none for dot) |
| Complexity per score | $O(d)$ with tanh | $O(d)$ (dot) or $O(d^2)$ (general) |
| Training speed | Slower (extra tanh layer per score) | Faster (matrix multiplications only) |
| Performance | Similar | Similar |
| Historical importance | First attention paper | Simplified attention, closer to modern form |

Common Trap

Do NOT mix up which decoder state is used. Bahdanau uses $s_{t-1}$ (the PREVIOUS decoder state) because the context is computed BEFORE the current decoder step. Luong uses $s_t$ (the CURRENT decoder state) because the context is computed AFTER. This is a detail interviewers love to test.

Part 4 - Self-Attention: From Encoder-Decoder to Within a Sequence

The Conceptual Leap

In encoder-decoder attention, the decoder attends to the encoder. The decoder state is the "query" and the encoder states are the "keys" and "values." But what if we want positions within the SAME sequence to attend to each other?

Self-attention: Every position in a sequence attends to every other position in the same sequence. Position $i$ computes a weighted combination of all positions, with weights determined by how "relevant" each position is to position $i$.

This is the core idea behind Transformers. Instead of processing a sequence one step at a time (like an RNN), every position can directly interact with every other position in a single layer.

The Information Retrieval Analogy

Self-attention is best understood through the lens of information retrieval:

  • Query (Q): "What am I looking for?" - what information the current position needs
  • Key (K): "What do I contain?" - a summary of what each position offers
  • Value (V): "What do I actually provide?" - the information to be aggregated

The analogy to a database: you have a query, you match it against keys to find relevance scores, and you retrieve the corresponding values. The difference from a hard database lookup is that attention is a soft lookup - every key contributes, weighted by its relevance score.
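The hard-versus-soft distinction fits in a few lines. In this sketch the keys, values, and query are toy numbers chosen for illustration. The hard lookup returns a single value. The soft lookup returns a blend in which every value gets a nonzero weight.

```python
import numpy as np

def hard_lookup(q, K, V):
    """Database-style lookup: return only the value whose key best matches q."""
    return V[np.argmax(K @ q)]

def soft_lookup(q, K, V):
    """Attention-style lookup: every value contributes, weighted by softmax(K q)."""
    scores = K @ q
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V, w

K = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])  # three keys
V = np.array([[10.0], [20.0], [30.0]])              # their values
q = np.array([1.0, 0.2])                            # query closest to key 0

print(hard_lookup(q, K, V))  # [10.]
out, w = soft_lookup(q, K, V)
print(w.round(3), out.round(3))  # all three weights are nonzero
```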

[Figure: Self-Attention: Query, Key, Value Projections - Attention(Q,K,V) = softmax(QK^T / √d_k) · V]

Why Separate Q, K, V?

A common interview question: "Why not just use the same projection for all three?"

Using the same vector for queries and keys would force a position's "what I'm looking for" to be identical to "what I advertise." Consider the sentence "The cat chased the dog." The word "chased" wants to find its subject (query: "who is doing the chasing?") and also needs to advertise itself as a verb (key: "I am an action"). These are different roles - separating Q and K allows the model to learn different roles for the same position.

Similarly, separating K from V allows the model to match on one criterion but retrieve different information. A key might encode syntactic information (for matching), while the value encodes semantic information (for aggregation).
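Tying the query and key projections has one concrete consequence: the score matrix becomes symmetric. Below is a small NumPy check with random inputs. The sizes and weights are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 16, 8
X = rng.normal(size=(n, d_model))
W_shared = rng.normal(size=(d_model, d_k))
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))

# Tied projection: Q = K = X W, so S = X W W^T X^T is symmetric (S_ij == S_ji)
S_tied = (X @ W_shared) @ (X @ W_shared).T
# Separate projections: S = X W_q W_k^T X^T is not symmetric in general
S_sep = (X @ W_q) @ (X @ W_k).T

print(np.allclose(S_tied, S_tied.T))  # True
print(np.allclose(S_sep, S_sep.T))    # False for these random weights
```

With a shared projection, the raw score "chased" assigns to "cat" must equal the raw score "cat" assigns to "chased". The row-wise softmax can still give different weights in the two directions, but the underlying compatibility scores are forced to be mutual.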

Interviewer's Perspective

When I ask "why separate Q, K, V?", the strongest candidates give the information retrieval analogy AND a concrete linguistic example. The weakest say "it is just how Transformers work" or "it gives more parameters." The medium-quality answer mentions the database analogy but cannot give a concrete example of why separation matters.

Part 5 - Scaled Dot-Product Attention: The Complete Derivation

The Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Where:

  • $Q \in \mathbb{R}^{n \times d_k}$ - queries ($n$ positions, each of dimension $d_k$)
  • $K \in \mathbb{R}^{m \times d_k}$ - keys ($m$ positions, each of dimension $d_k$)
  • $V \in \mathbb{R}^{m \times d_v}$ - values ($m$ positions, each of dimension $d_v$)
  • Output $\in \mathbb{R}^{n \times d_v}$

Step-by-Step Computation

Step 1: Compute raw scores.

$$S = QK^T \in \mathbb{R}^{n \times m}$$

$S_{ij}$ is the dot product between query $i$ and key $j$ - measuring how much position $i$ should attend to position $j$.

Step 2: Scale.

$$S' = \frac{S}{\sqrt{d_k}}$$

Step 3: Apply softmax row-wise.

$$A = \text{softmax}(S') \in \mathbb{R}^{n \times m}$$

Each row sums to 1. $A_{ij}$ is the attention weight from position $i$ to position $j$.

Step 4: Weighted combination.

$$\text{Output} = AV \in \mathbb{R}^{n \times d_v}$$

Each output position is a weighted combination of all value vectors.
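The four steps map directly onto a few lines of NumPy. This is a minimal single-head, unbatched sketch, and the function names are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max leaves softmax unchanged but avoids overflow in exp
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n, d_k), K: (m, d_k), V: (m, d_v) -> output (n, d_v), weights (n, m)."""
    d_k = Q.shape[-1]
    S = Q @ K.T                     # Step 1: raw scores, (n, m)
    S_scaled = S / np.sqrt(d_k)     # Step 2: scale
    A = softmax(S_scaled, axis=-1)  # Step 3: row-wise softmax, rows sum to 1
    return A @ V, A                 # Step 4: weighted combination, (n, d_v)

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 4))
K = rng.normal(size=(5, 4))
V = rng.normal(size=(5, 2))
out, A = scaled_dot_product_attention(Q, K, V)
print(out.shape, A.shape)               # (3, 2) (3, 5)
print(np.allclose(A.sum(axis=1), 1.0))  # True
```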

Why Scale by $\sqrt{d_k}$? - The Variance Argument

This is one of the most commonly asked questions in AI interviews. Here is the full derivation.

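Assume the elements of $Q$ and $K$ are independent random variables with mean 0 and variance 1. This is a simplifying assumption. It holds only roughly in practice (for example, near initialization with normalized inputs), but it captures how the scale of the scores grows with dimension.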

The dot product between a query vector $q$ and a key vector $k$, both of dimension $d_k$, is:

$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

Each $q_i k_i$ is a product of two independent random variables with mean 0 and variance 1:

$$E[q_i k_i] = E[q_i] \cdot E[k_i] = 0$$

$$\text{Var}(q_i k_i) = E[q_i^2 k_i^2] - (E[q_i k_i])^2 = E[q_i^2] E[k_i^2] - 0 = 1 \cdot 1 = 1$$

Since the sum of $d_k$ independent terms with variance 1 has variance $d_k$:

$$\text{Var}(q \cdot k) = d_k$$

So the dot product has standard deviation $\sqrt{d_k}$. For $d_k = 64$, the raw dot products have standard deviation 8. For $d_k = 512$, the standard deviation is ~22.6.

The problem with large variance: When the dot products are large in magnitude, the softmax function enters a region where its gradients are extremely small. Consider softmax with inputs $[10, -10]$: the output is approximately $[1, 0]$, and the gradient is nearly zero for both entries. The attention becomes a nearly hard one-hot selection, and gradients vanish.

By dividing by $\sqrt{d_k}$, we normalize the dot products to have variance 1 (and standard deviation 1). This keeps the softmax in a well-behaved regime where it produces soft distributions with meaningful gradients.

$$\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\text{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1$$
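The variance argument is easy to check numerically. This Monte Carlo sketch assumes, as the derivation does, i.i.d. standard normal components. The trial count and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, trials = 512, 10_000
q = rng.standard_normal((trials, d_k))
k = rng.standard_normal((trials, d_k))

dots = np.einsum("ij,ij->i", q, k)  # one dot product q.k per trial
scaled = dots / np.sqrt(d_k)

print(f"unscaled: mean={dots.mean():.2f}, var={dots.var():.1f}")      # var close to d_k = 512
print(f"scaled:   mean={scaled.mean():.3f}, var={scaled.var():.3f}")  # var close to 1
```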
Instant Rejection

If asked "why do we divide by sqrt(d_k)?" and you answer "to prevent overflow" or "to normalize the output," that is a weak answer that misses the point. The precise answer is: the dot product has variance $d_k$ (grows with dimension), which pushes softmax into saturation where gradients vanish. Dividing by $\sqrt{d_k}$ restores unit variance, keeping softmax in a well-conditioned regime. Interviewers at top labs expect the variance derivation.

What Happens Without Scaling?

For $d_k = 512$:

  • Raw dot products: mean 0, std ~22.6
  • With scaling: mean 0, std 1
  • Without scaling: softmax outputs are nearly one-hot (the maximum score dominates), gradients are near zero for non-maximum positions, and training becomes unstable

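This is one reason additive attention (Bahdanau) is less sensitive to the problem. The tanh bounds each hidden unit to $[-1, 1]$, so the score $v^T \tanh(\cdot)$ is bounded by $\|v\|_1$, a learned scale. It does not grow with the variance of a $d_k$-term dot product.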

Part 6 - Multi-Head Attention

The Motivation

A single attention head computes one set of attention weights. But a sentence like "The animal didn't cross the street because it was too tired" contains multiple types of relationships:

  • Syntactic: "it" refers to "animal" (coreference)
  • Semantic: "tired" modifies "animal" (attribute)
  • Structural: "because" links "didn't cross" and "was too tired" (causal)

A single attention head produces one attention distribution per position, so it has to blend all of these relationships into a single weighted average. Multi-head attention runs multiple attention operations in parallel, each in a different subspace, allowing different heads to learn different types of relationships.

The Equations

Given input $X \in \mathbb{R}^{n \times d_{\text{model}}}$:

For each head $i$ (where $i = 1, \ldots, h$):

$$Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V$$

$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$

Where $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$.

Concatenate and project:

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

Where $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$.

Dimension Split

Typically, $d_k = d_v = d_{\text{model}} / h$. For a model with $d_{\text{model}} = 512$ and $h = 8$ heads:

  • Each head works in $d_k = d_v = 64$ dimensions
  • The total computation is roughly the same as a single head with $d_k = 512$

| Parameter | Single-Head ($d_k = 512$) | Multi-Head (8 heads, $d_k = 64$ each) |
|---|---|---|
| $W^Q$ | 512 x 512 = 262K | 8 x (512 x 64) = 262K |
| $W^K$ | 512 x 512 = 262K | 8 x (512 x 64) = 262K |
| $W^V$ | 512 x 512 = 262K | 8 x (512 x 64) = 262K |
| $W^O$ | N/A | 512 x 512 = 262K |
| Total | 786K | 1,049K |

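As tabulated, the multi-head version has ~33% more parameters, and the difference comes entirely from $W^O$. The standard Transformer formulation gives a single-head layer an output projection too. With one, both columns total $4 \times 512^2 \approx 1{,}049$K. The benefit of multiple heads therefore comes from splitting the same budget into subspaces, not from extra parameters (see Problem 2).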
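The equations and the parameter count can be checked with a compact NumPy sketch. It assumes $d_k = d_v = d_{\text{model}} / h$ and stores every head's projections in one $d_{\text{model}} \times d_{\text{model}}$ matrix per role. Column block $i$ of that matrix acts as $W_i^Q$, and likewise for $K$ and $V$. The function name and the scaled random initialization are illustrative choices.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, h):
    """X: (n, d_model); W_q, W_k, W_v, W_o: (d_model, d_model); h heads.

    Column block i of W_q (width d_k = d_model // h) plays the role of W_i^Q,
    and likewise for W_k and W_v, so one matmul projects all heads at once.
    """
    n, d_model = X.shape
    d_k = d_model // h

    def split(M):  # (n, d_model) -> (h, n, d_k)
        return M.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    A = softmax(Q @ K.transpose(0, 2, 1) / np.sqrt(d_k))  # (h, n, n): one pattern per head
    heads = A @ V                                          # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)  # Concat(head_1, ..., head_h)
    return concat @ W_o                                    # (n, d_model)

rng = np.random.default_rng(0)
n, d_model, h = 4, 512, 8
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v, W_o = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

out = multi_head_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)                                  # (4, 512)
print(sum(W.size for W in (W_q, W_k, W_v, W_o)))  # 1048576
```

The reshape-and-transpose step lets an implementation avoid a Python loop over heads. One matmul produces every head's projection, and splitting the result into $h$ column blocks is what gives each head its own subspace.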

[Figure: Multi-Head Attention - Parallel Subspace Projections Concatenated via W_O]

Interviewer's Perspective

The key insight I look for: multi-head attention is not just "more parameters." It is about different subspaces. Each head projects Q, K, V into a different subspace, allowing it to attend to different patterns. This is analogous to having multiple feature detectors in a CNN. Candidates who understand the subspace argument get strong marks. Those who just say "more heads = more capacity" get moderate marks.

Part 7 - Cross-Attention vs Self-Attention

Definitions

Self-attention: Q, K, and V all come from the same sequence. Every position attends to every position in the same sequence.

$$Q = X W^Q, \quad K = X W^K, \quad V = X W^V$$

Where $X$ is the same input for all three.

Cross-attention: Q comes from one sequence (typically the decoder), while K and V come from a different sequence (typically the encoder).

$$Q = X_{\text{dec}} W^Q, \quad K = X_{\text{enc}} W^K, \quad V = X_{\text{enc}} W^V$$
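Self- and cross-attention can share one function because only the inputs differ. In this sketch the two calls reuse the same weights purely for brevity. A real Transformer layer has separate weights for each attention block. The sizes are arbitrary.

```python
import numpy as np

def attend(X_q, X_kv, W_q, W_k, W_v):
    """Single-head attention: queries from X_q, keys and values from X_kv."""
    Q, K, V = X_q @ W_q, X_kv @ W_k, X_kv @ W_v
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    A /= A.sum(axis=-1, keepdims=True)
    return A @ V

rng = np.random.default_rng(0)
d_model, d_k = 16, 8
# Shared weights for both calls, for brevity only
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
X_enc = rng.normal(size=(7, d_model))  # 7 source tokens
X_dec = rng.normal(size=(3, d_model))  # 3 target tokens so far

self_out = attend(X_enc, X_enc, W_q, W_k, W_v)   # self-attention: one X for Q, K, V
cross_out = attend(X_dec, X_enc, W_q, W_k, W_v)  # cross: Q from decoder, K, V from encoder
print(self_out.shape, cross_out.shape)  # (7, 8) (3, 8)
```

The output always has one row per query. Cross-attention therefore returns a decoder-length sequence even though it reads from a 7-token source.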

Where Each Appears

| Context | Type | Q Source | K, V Source | Example |
|---|---|---|---|---|
| Transformer encoder | Self-attention | Encoder input | Encoder input | Word-to-word relationships |
| Transformer decoder (masked) | Causal self-attention | Decoder input | Decoder input | Autoregressive generation |
| Transformer decoder (cross) | Cross-attention | Decoder states | Encoder states | Attending to source in translation |
| BERT | Self-attention | Input tokens | Input tokens | Bidirectional understanding |
| GPT | Causal self-attention | Input tokens | Input tokens | Unidirectional generation |
| Stable Diffusion | Cross-attention | Image features | Text embeddings | Text-guided image generation |
| Vision Transformer (ViT) | Self-attention | Image patches | Image patches | Patch-to-patch relationships |

Causal Masking

In autoregressive models (GPT, decoder of Transformer), position $t$ must NOT attend to positions $> t$ (it cannot see the future). This is enforced by a causal mask:

$$M_{ij} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{if } i < j \end{cases}$$

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$

The $-\infty$ values become 0 after softmax, effectively preventing attention to future positions.
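Here is a minimal NumPy sketch of the mask and its effect. The random matrix `S` stands in for $QK^T / \sqrt{d_k}$, and $T = 4$ is arbitrary.

```python
import numpy as np

T = 4
# M_ij = 0 where i >= j (past and present), -inf where i < j (future)
M = np.triu(np.full((T, T), -np.inf), k=1)

rng = np.random.default_rng(0)
S = rng.normal(size=(T, T))  # stand-in for QK^T / sqrt(d_k)
Z = S + M
A = np.exp(Z - Z.max(axis=-1, keepdims=True))  # exp(-inf) is exactly 0
A /= A.sum(axis=-1, keepdims=True)

print(np.allclose(np.triu(A, k=1), 0.0))  # True: no weight on future positions
print(np.allclose(A.sum(axis=-1), 1.0))   # True: each row is still a distribution
```

The mask `M` is strictly upper-triangular, and the resulting weight matrix `A` is lower-triangular.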

Common Trap

Do NOT say "causal attention uses a triangular matrix." Be precise: the MASK is upper-triangular (with $-\infty$ above the diagonal), which produces a LOWER-triangular attention weight matrix (after softmax). The attention weights at position $i$ are nonzero only for positions $\leq i$. Getting the direction wrong signals confusion about the mechanism.

Part 8 - Attention Patterns and Interpretability

Common Attention Patterns

Research on attention visualization has revealed several recurring patterns:

1. Diagonal pattern: Position $i$ attends most strongly to position $i$. This is a "copy" or "identity" pattern - the output at each position is primarily its own value.

2. Offset diagonal: Position $i$ attends to position $i-1$ or $i+1$. This captures local (neighboring token) relationships, similar to a 1D convolution.

3. Columnar pattern: All positions attend to a specific position (e.g., the [CLS] token in BERT, or the period at the end of a sentence). This creates a "global summary" position.

4. Block pattern: Tokens within the same clause or phrase attend to each other, forming attention blocks. This captures syntactic structure.

5. Sparse/specialized pattern: Only a few Q-K pairs have high attention, with most weights near zero. This often appears in later layers.

What Attention Is NOT

Instant Rejection

Attention weights are NOT the same as feature importance or explanation. Attention shows what the model looked at, not what caused the output. Jain and Wallace (2019) showed that alternative attention distributions can produce the same output, and Wiegreffe and Pinter (2019) showed the relationship is complex. If an interviewer asks "can we use attention weights to explain model predictions?", the answer is "attention weights provide suggestive evidence about what the model uses, but they are not reliable explanations because alternative weight configurations can produce the same output."

Part 9 - Complexity Analysis

Time and Space Complexity

| Operation | Time | Space |
|---|---|---|
| $QK^T$ | $O(n \cdot m \cdot d_k)$ | $O(n \cdot m)$ |
| Softmax | $O(n \cdot m)$ | $O(n \cdot m)$ |
| $AV$ | $O(n \cdot m \cdot d_v)$ | $O(n \cdot d_v)$ |
| Total | $O(n \cdot m \cdot d)$ | $O(n \cdot m)$ |

For self-attention where $n = m = T$ (sequence length):

  • Time: $O(T^2 \cdot d)$
  • Space: $O(T^2)$ for the attention matrix

This quadratic scaling in sequence length is the fundamental limitation of standard attention. For $T = 8192$ with $d = 128$ (one layer, batch size 1):

  • Attention matrix: $8192^2 \approx 67\text{M}$ entries per head
  • With 32 heads: ~2.1 billion entries
  • Memory: ~8.6 GB in float32
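The arithmetic above fits in a small helper. It counts only the score matrices for one layer (batch size 1 by default) and ignores everything else in memory. The function name is hypothetical.

```python
def attention_matrix_bytes(seq_len, num_heads, bytes_per_elem, batch=1):
    """Bytes needed to materialize the (seq_len x seq_len) score matrix for
    every head in one layer. Ignores Q, K, V, weights, and other activations."""
    return batch * num_heads * seq_len ** 2 * bytes_per_elem

print(round(attention_matrix_bytes(8192, 32, 4) / 1e9, 1))   # float32 -> 8.6 (GB)
print(round(attention_matrix_bytes(32000, 32, 2) / 1e9, 1))  # float16 -> 65.5 (GB)
```

The second call reproduces the float16 estimate used in Problem 3.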

Comparison with Other Approaches

| Mechanism | Time per Layer | Sequential Ops | Max Path Length |
|---|---|---|---|
| Self-attention | $O(T^2 \cdot d)$ | $O(1)$ | $O(1)$ |
| RNN | $O(T \cdot d^2)$ | $O(T)$ | $O(T)$ |
| CNN (kernel $k$) | $O(T \cdot k \cdot d^2)$ | $O(1)$ | $O(\log_k T)$ (dilated) |
| Sparse attention | $O(T \cdot \sqrt{T} \cdot d)$ | $O(1)$ | $O(\sqrt{T})$ |
| Linear attention | $O(T \cdot d^2)$ | $O(1)$ | $O(1)$ |

Self-attention wins on sequential operations and path length but loses on time/memory for very long sequences. This motivates efficient attention variants (Flash Attention, sparse attention, linear attention - covered in the Transformer Architecture page).

Company Variation

At Google and Anthropic, interviewers expect you to know the complexity analysis cold and be able to discuss efficient attention variants (Flash Attention, sliding window, sparse attention). At startups, the focus is more on "when does the quadratic cost become a problem in practice?" The answer: typically around $T > 8\text{K}$ for standard GPU memory, though Flash Attention extends this significantly.

Part 10 - Historical Context and Evolution

The Attention Timeline

| Year | Paper | Contribution |
|---|---|---|
| 2014 | Bahdanau et al. | First attention for seq2seq (additive) |
| 2015 | Luong et al. | Simplified to multiplicative/dot-product |
| 2015 | Xu et al. | Visual attention for image captioning |
| 2016 | Yang et al. | Hierarchical attention for document classification |
| 2017 | Vaswani et al. | Self-attention + multi-head in Transformers |
| 2018 | Devlin et al. | Bidirectional self-attention in BERT |
| 2019 | Child et al. | Sparse Transformer (efficient attention) |
| 2020 | Katharopoulos et al. | Linear attention |
| 2022 | Dao et al. | Flash Attention (memory-efficient) |
| 2023 | Dao | Flash Attention 2 (better parallelism and work partitioning) |
| 2024 | Shah et al. | Flash Attention 3 (asynchrony and low precision on Hopper GPUs) |

The Pattern

Each advance addresses a specific limitation:

  • Bahdanau → Luong: simplification (remove tanh, faster)
  • RNN attention → self-attention: remove recurrence, enable parallelism
  • Single head → multi-head: capture diverse relationship types
  • Standard attention → Flash Attention: reduce memory from $O(T^2)$ to $O(T)$ without approximation

Practice Problems

Problem 1: Scaling Factor Derivation

Prove that for random vectors $q, k \in \mathbb{R}^{d_k}$ with i.i.d. components from $\mathcal{N}(0, 1)$, the dot product $q \cdot k$ has mean 0 and variance $d_k$. Then explain what happens to softmax when the variance is too large.

Hint 1 -- Direction

The dot product is a sum of $d_k$ terms. Each term is a product of two independent standard normal random variables. Use the properties of normal random variables to compute the mean and variance of each term.

Hint 2 -- Insight

If $q_i, k_i \sim \mathcal{N}(0, 1)$ independently, then $q_i k_i$ has mean 0 (by independence) and variance 1 (since $\text{Var}(q_i k_i) = E[q_i^2]E[k_i^2] = 1$). The sum of $d_k$ independent terms with variance 1 has variance $d_k$ by additivity of variance for independent random variables. When $d_k = 512$, the standard deviation is ~22.6, so most dot products fall roughly between $-45$ and $+45$ (two standard deviations). Softmax with inputs of magnitude 45 produces outputs that are essentially one-hot.

Hint 3 -- Full Solution + Rubric

Mean:

$$E[q \cdot k] = E\left[\sum_{i=1}^{d_k} q_i k_i\right] = \sum_{i=1}^{d_k} E[q_i k_i] = \sum_{i=1}^{d_k} E[q_i]E[k_i] = 0$$

Variance:

$$\text{Var}(q_i k_i) = E[(q_i k_i)^2] - (E[q_i k_i])^2 = E[q_i^2]E[k_i^2] - 0 = 1 \cdot 1 = 1$$

$$\text{Var}(q \cdot k) = \text{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

(The last step uses independence - the covariance terms are zero.)

Effect on softmax:

For large $d_k$, dot products have large magnitude. Consider $\text{softmax}([z, -z])$ for large $z$:

$$\text{softmax}([z, -z]) = \left[\frac{e^z}{e^z + e^{-z}}, \frac{e^{-z}}{e^z + e^{-z}}\right] \approx [1, 0] \text{ for large } z$$

The Jacobian of softmax approaches zero as inputs become extreme (the output saturates). This means:

  1. Gradients through the softmax become near-zero
  2. Attention becomes a nearly hard argmax selection
  3. Training becomes unstable because the model cannot smoothly adjust attention weights
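The saturation shows up directly in the softmax Jacobian $J = \operatorname{diag}(p) - pp^T$. The NumPy sketch below evaluates it on $[z, -z]$ for growing $z$. The values of $z$ are arbitrary, except that 45 matches the $2\sigma$ magnitude for $d_k = 512$.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def softmax_jacobian(z):
    # d softmax_i / d z_j = p_i (delta_ij - p_j), i.e. J = diag(p) - p p^T
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

for z in [0.5, 2.0, 10.0, 45.0]:
    x = np.array([z, -z])
    J = softmax_jacobian(x)
    print(f"z={z:>5}: softmax={softmax(x).round(4)}, max|J|={np.abs(J).max():.2e}")
```

As $z$ grows, the largest Jacobian entry collapses toward zero. That collapse is the "near-zero gradient" in point 1.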

Dividing by $\sqrt{d_k}$ restores unit variance: $\text{Var}(q \cdot k / \sqrt{d_k}) = 1$.

Scoring Rubric:

  • Strong Hire: Derives mean and variance correctly with proper independence arguments, explains softmax saturation with gradient vanishing, connects to training instability.
  • Lean Hire: Gets the variance formula correct but cannot fully explain the softmax implications.
  • No Hire: Cannot derive the variance or gives a hand-wavy explanation without math.

Problem 2: Multi-Head Attention Design

You are designing a Transformer for a specific task. You have $d_{\text{model}} = 768$ and must choose the number of heads.

(a) What are the tradeoffs between 4, 8, 12, and 24 heads? (b) What is the minimum $d_k$ you would consider? Why? (c) If you increase heads from 12 to 24, does the parameter count change?

Hint 1 -- Direction

Each head operates in $d_k = d_{\text{model}} / h$ dimensions. More heads means each head works in a lower-dimensional space. Consider what "low-dimensional attention" means for the types of patterns a head can learn.

Hint 2 -- Insight

With $h = 24$ and $d_{\text{model}} = 768$, each head has $d_k = 32$. This means each head's attention is computed in a 32-dimensional space - enough for simple pattern matching but potentially too low for complex relational reasoning. The parameter count for Q, K, V projections remains the same ($3 \times d_{\text{model}}^2$) regardless of head count; only the structure changes. The output projection $W^O$ also stays $d_{\text{model}} \times d_{\text{model}}$.

Hint 3 -- Full Solution + Rubric

(a) Tradeoffs:

| Heads | $d_k$ | Per-Head Capacity | Diversity | Notes |
|---|---|---|---|---|
| 4 | 192 | High (rich subspace) | Low (few heads) | Each head is powerful but fewer specializations |
| 8 | 96 | Moderate | Moderate | Good balance, common in base models |
| 12 | 64 | Moderate | High | Default for BERT-base, GPT-2 |
| 24 | 32 | Low | Very high | May lose per-head expressiveness |

More heads = more diverse attention patterns but less capacity per head. Empirically, 8-16 heads work well for most tasks. Very high head counts (>16) show diminishing returns unless the model dimension is also very large.

(b) Minimum $d_k$: Generally $d_k \geq 32$ is recommended. Below 32, the key space is too small for the dot product to distinguish between many different patterns. With $d_k = 16$, a head can only represent $\sim 16$ independent "matching criteria," which is often insufficient. The original Transformer paper used $d_k = 64$.

(c) Parameter count with 12 vs 24 heads:

For Q projections: $W^Q = [W_1^Q, W_2^Q, \ldots, W_h^Q]$, with the per-head matrices concatenated along the column dimension. The result is $d_{\text{model}} \times d_{\text{model}}$ regardless of $h$.

Total attention parameters: $4 \times d_{\text{model}}^2$ (Q, K, V, O), which is $4 \times 768^2 = 2{,}359{,}296$ parameters (ignoring biases) - the same for both 12 and 24 heads.

The parameter count does NOT change with the number of heads. What changes is the structure: the same total parameters are divided differently among more or fewer heads. This is a crucial insight - multi-head attention is a different decomposition of the same parameter budget, not more parameters.
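One quick way to check the invariance is to count the weights for each head count. This sketch assumes $d_k = d_v = d_{\text{model}}/h$ and can optionally add one bias vector per projection. The function name is hypothetical.

```python
def attention_param_count(d_model, num_heads, bias=False):
    """Parameters in W^Q, W^K, W^V (each d_model x h*d_k) and W^O (h*d_v x d_model),
    assuming d_k = d_v = d_model // num_heads."""
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_head = d_model // num_heads
    qkv = 3 * d_model * (num_heads * d_head)  # per-head blocks side by side
    out = (num_heads * d_head) * d_model
    biases = 4 * d_model if bias else 0       # one bias vector per projection
    return qkv + out + biases

for h in (4, 8, 12, 24):
    print(h, attention_param_count(768, h))  # 2359296 for every h
```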

Scoring Rubric:

  • Strong Hire: Correctly analyzes the capacity-diversity tradeoff, identifies that parameter count is invariant to head count, provides minimum dkd_k with reasoning, mentions empirical findings.
  • Lean Hire: Knows the dimension split formula but thinks more heads = more parameters.
  • No Hire: Cannot explain the relationship between head count and per-head dimension, or believes head count changes total parameter count.

Problem 3: Attention Complexity Challenge

You need to process documents of 32,000 tokens with a standard Transformer. Your GPU has 24GB of memory.

(a) Estimate the memory required for the attention matrix alone (32 heads, float16). (b) Can you fit this computation? If not, what are your options? (c) How does Flash Attention help, and what is its complexity?

Hint 1 -- Direction

The attention matrix for one head is $T \times T$. With 32 heads and float16 (2 bytes per element), compute the total memory. Remember that you also need to store the input/output activations, not just the attention matrix.

Hint 2 -- Insight

$32000^2 \times 32 \times 2$ bytes $\approx 65.5$ GB for the attention matrices alone. This exceeds 24GB by nearly 3x. Options include Flash Attention (reduces memory to $O(T)$ by never materializing the full attention matrix), sliding window attention (local context), sparse attention, or chunked processing. Flash Attention is the standard solution - it achieves exact attention with $O(T)$ memory by tiling the computation.

Hint 3 -- Full Solution + Rubric

(a) Memory estimation:

Attention matrix per head: $32000 \times 32000 = 1.024 \times 10^9$ elements

Total for 32 heads: $32 \times 1.024 \times 10^9 = 3.28 \times 10^{10}$ elements

In float16 (2 bytes): $3.28 \times 10^{10} \times 2 = 6.55 \times 10^{10}$ bytes = ~65.5 GB

This is for the attention matrix ALONE, not counting Q, K, V matrices, model weights, or activations.

(b) Cannot fit. Options:

  1. Flash Attention - Never materializes the full $T \times T$ attention matrix. Computes attention in blocks (tiles), keeping only a running softmax accumulator. Memory: $O(T)$ instead of $O(T^2)$. Exact computation (no approximation).

  2. Sliding window attention - Each token attends only to a local window of $w$ tokens. Memory: $O(T \cdot w)$. Approximate but captures local context well. Used in Mistral, Longformer.

  3. Sparse attention - Combine local windows with global tokens (every $k$-th position attends to all). Memory: $O(T \sqrt{T})$ or similar. Used in BigBird, Sparse Transformer.

  4. Chunked processing - Split the document into overlapping chunks, process each independently, merge results. Memory: $O(C^2)$ where $C$ is chunk size. Loses cross-chunk dependencies.

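  5. Gradient checkpointing - Recompute activations during the backward pass instead of storing them, typically at the cost of roughly one extra forward pass (~33% more compute). This reduces the activation memory kept across layers. However, each layer still materializes its $T \times T$ matrix during the forward pass, so checkpointing alone does not solve this problem.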

(c) Flash Attention:

  • Key idea: Compute attention in tiles that fit in GPU SRAM (fast on-chip memory). Use the online softmax trick to compute exact softmax without materializing the full attention matrix.
  • Memory complexity: $O(T)$ - stores only the output, not the intermediate attention matrix.
  • Time complexity: Same $O(T^2 \cdot d)$ FLOPs, but significantly faster wall-clock time due to reduced memory I/O (the bottleneck is HBM reads/writes, not FLOPs).
  • Result: Exact attention, 2-4x faster than standard implementation, ~10-20x less memory.

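With Flash Attention, the 32K-token attention computation needs memory that scales as $O(T \cdot d)$ rather than $O(T^2)$. Assuming 32 heads of dimension 128 in float16, Q, K, V, and the output are about 0.26 GB each. That comes to roughly 1 to 2 GB in total instead of 65.5 GB.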
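The online-softmax idea behind Flash Attention can be illustrated on the CPU. This NumPy sketch is not Flash Attention itself. The real method is a fused GPU kernel that also tiles the queries, keeps tiles in SRAM, and has a matching backward pass. The sketch only shows how a running max and a running denominator let exact attention be computed one key/value block at a time. `tiled_attention` and `block` are hypothetical names.

```python
import numpy as np

def tiled_attention(Q, K, V, block=64):
    """Exact softmax(Q K^T / sqrt(d)) V computed over key/value blocks.

    Only an (n, block) slice of scores exists at any time. A running row max `m`
    and running denominator `l` let earlier partial sums be rescaled (online softmax).
    """
    n, d = Q.shape
    m = np.full(n, -np.inf)          # running max of scores per query row
    l = np.zeros(n)                  # running sum of exp(score - m)
    acc = np.zeros((n, V.shape[1]))  # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)            # scores for this tile only
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)            # rescale old sums to the new max
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        acc = acc * scale[:, None] + P @ Vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(256, 32)) for _ in range(3))

# Reference: materializes the full 256 x 256 weight matrix
S = Q @ K.T / np.sqrt(32)
A = np.exp(S - S.max(axis=1, keepdims=True))
reference = (A / A.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention(Q, K, V), reference))  # True
```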

Scoring Rubric:

  • Strong Hire: Computes memory correctly, identifies Flash Attention as the primary solution, explains its tiling mechanism and the distinction between memory and compute complexity, mentions practical alternatives like sliding window.
  • Lean Hire: Gets the memory estimate right and knows Flash Attention exists but cannot explain how it works.
  • No Hire: Cannot estimate memory or suggests reducing model size as the only solution.

Problem 4: Self-Attention vs Cross-Attention

Design the attention pattern for a text-to-image model. You have text tokens $T = [t_1, \ldots, t_n]$ and image patches $I = [p_1, \ldots, p_m]$.

(a) Which interactions need self-attention and which need cross-attention? (b) Write the Q, K, V source for each attention layer. (c) What happens if you accidentally swap Q and K in cross-attention?

Hint 1 -- Direction

Text tokens need to understand each other (self-attention), image patches need to understand each other (self-attention), and image patches need to be conditioned on the text (cross-attention). Think about which sequence provides the queries and which provides the keys/values.

Hint 2 -- Insight

In the standard text-to-image architecture (like Stable Diffusion's U-Net), the image model has self-attention layers (patches attend to patches) and cross-attention layers (patches query the text). In cross-attention, Q comes from image patches and K, V come from text. This means image patches decide what text to look at. If you swap Q and K, the text tokens would decide what image patches to attend to, which inverts the conditioning direction and produces poor results.

Hint 3 -- Full Solution + Rubric

(a) Attention interactions:

| Layer | Type | Purpose |
|---|---|---|
| Text encoder self-attention | Self-attention over text | Text tokens understand context |
| Image self-attention | Self-attention over patches | Patches maintain spatial coherence |
| Image-text cross-attention | Cross-attention | Patches query relevant text tokens |

(b) Q, K, V sources:

| Layer | Q Source | K Source | V Source |
|---|---|---|---|
| Text self-attention | Text tokens | Text tokens | Text tokens |
| Image self-attention | Image patches | Image patches | Image patches |
| Image-text cross-attention | Image patches | Text tokens | Text tokens |

(c) Swapping Q and K in cross-attention:

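Interpreting the swap as exchanging the roles of the two sequences (Q = text, K = image, V = image): text tokens would query the image, and the output would be a text-length sequence of image-derived features. (Swapping only Q and K while keeping V = text would fail with a shape error whenever $n \neq m$, because the $n \times m$ attention weights cannot multiply an $n \times d$ value matrix; K and V must come from the same sequence.) This is:

  1. Wrong dimensions: output is $n \times d$ (text length) instead of $m \times d$ (image length)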
  2. Wrong conditioning: instead of "image attends to relevant text" (what we want), we get "text attends to relevant image patches" (useful for captioning, not generation)
  3. Specifically, this would produce a text representation enriched with image features - the pattern used in image captioning, not image generation

The asymmetry of Q vs K,V is fundamental: Q determines the output length and who is "asking," while K,V determine what is being "looked at."
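The shape asymmetry can be checked directly. In this minimal sketch the learned projections $W_Q, W_K, W_V$ are omitted (an assumption to keep the shapes visible), the sizes are arbitrary, and `attend` is a hypothetical single-head helper rather than a library function.

```python
import numpy as np


def attend(Q, K, V):
    """Single-head scaled dot-product attention, no mask, no projections."""
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    W = np.exp(S - S.max(axis=-1, keepdims=True))
    W /= W.sum(axis=-1, keepdims=True)
    return W @ V


rng = np.random.default_rng(0)
n, m, d = 7, 64, 32  # hypothetical sizes: 7 text tokens, 64 image patches
text = rng.standard_normal((n, d))
patches = rng.standard_normal((m, d))

gen_out = attend(patches, text, text)     # Q = image, K/V = text: conditioning for generation
cap_out = attend(text, patches, patches)  # Q = text, K/V = image: captioning-style direction
print(gen_out.shape, cap_out.shape)       # (64, 32) (7, 32)

try:
    attend(text, patches, text)           # literal Q<->K swap: K and V index different sequences
except ValueError as err:
    print("shape error:", err)
```

In both valid calls the output length follows the query sequence, which is exactly why the query side decides the direction of conditioning.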

Scoring Rubric:

  • Strong Hire: Correctly identifies all three attention types, specifies Q/K/V sources precisely, explains the consequence of swapping Q and K with a concrete understanding of the dimension and conditioning implications.
  • Lean Hire: Gets the Q/K/V sources right but cannot explain the swapping consequence clearly.
  • No Hire: Confuses self-attention and cross-attention, or cannot specify which sequence provides Q vs K.

Problem 5: Attention from Scratch

Implement scaled dot-product attention in NumPy. Handle an optional causal mask.

Hint 1 -- Direction

The implementation follows the formula directly: compute $QK^T$, scale, optionally add the mask, apply softmax, multiply by $V$. The causal mask should be $-\infty$ (or a very large negative number) strictly above the diagonal and 0 on and below it.

Hint 2 -- Insight

Be careful with numerical stability in softmax. Subtract the maximum value before exponentiating to prevent overflow. The causal mask sets future positions to $-\infty$ BEFORE softmax, which makes them contribute zero attention weight.

Hint 3 -- Full Solution + Rubric
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [batch, n, d_k] (queries)
    K: [batch, m, d_k] (keys)
    V: [batch, m, d_v] (values)
    mask: optional additive mask, [n, m] or [batch, n, m]:
          0 where attention is allowed, a large negative value (e.g. -1e9) where masked

    Returns: output [batch, n, d_v], attention_weights [batch, n, m]
    """
    d_k = Q.shape[-1]

    # Step 1: Compute raw scores
    scores = np.matmul(Q, K.transpose(0, 2, 1))  # [batch, n, m]

    # Step 2: Scale
    scores = scores / np.sqrt(d_k)

    # Step 3: Apply mask (if provided)
    if mask is not None:
        scores = scores + mask  # masked positions become very negative

    # Step 4: Softmax (numerically stable)
    scores_max = np.max(scores, axis=-1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Step 5: Weighted combination
    output = np.matmul(attention_weights, V)  # [batch, n, d_v]

    return output, attention_weights


def create_causal_mask(n):
    """Create an [n, n] additive mask: -1e9 above the diagonal, 0 on and below it."""
    mask = np.triu(np.full((n, n), -1e9), k=1)
    return mask

Key implementation details:

  • Numerical stability: subtract max before exp (log-sum-exp trick)
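  • Causal mask uses a large negative number ($-10^9$) rather than actual $-\infty$ to avoid NaN: if every position in a row is masked (as can happen with padding masks), $-\infty - (-\infty)$ in the max-subtraction step produces NaN
  • The mask is ADDED to scores (additive masking), not multiplied
  • k=1 in np.triu means the diagonal is 0 (not masked) - position $i$ CAN attend to itself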
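The NaN issue in the second bullet is easy to reproduce. This minimal sketch uses a hypothetical padding mask (not the causal mask, which always leaves the diagonal unmasked) in which one query row is fully masked, and applies the same max-subtraction softmax with $-\infty$ versus $-10^9$.

```python
import numpy as np


def stable_softmax(x):
    """Row-wise softmax with the max-subtraction trick."""
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


scores = np.array([[0.5, 1.0, -0.3],
                   [0.2, 0.1, 0.4]])
# Hypothetical padding mask (True = masked): row 1 is a padding query, so every key is masked.
pad = np.array([[False, False, True],
                [True, True, True]])

with np.errstate(invalid="ignore"):  # silence the -inf - (-inf) warning for the demo
    w_inf = stable_softmax(np.where(pad, -np.inf, scores))
w_big = stable_softmax(np.where(pad, -1e9, scores))

print(w_inf[1])  # all NaN: the row max is -inf, and -inf - (-inf) = nan
print(w_big[1])  # finite, uniform weights of 1/3
```

For rows that keep at least one unmasked key, both choices give the same weights; the large finite constant only matters for fully masked rows, whose outputs a model would typically ignore anyway.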

Scoring Rubric:

  • Strong Hire: Correct implementation with numerical stability, proper mask handling, correct dimensions, clean code with comments.
  • Lean Hire: Correct basic implementation but missing numerical stability or wrong mask direction.
  • No Hire: Cannot implement softmax correctly, wrong matrix dimensions, or applies mask after softmax.

Interview Cheat Sheet

| Concept | Key Formula | One-Liner | Red Flag |
|---|---|---|---|
| Scaled dot-product | $\text{softmax}(QK^T/\sqrt{d_k})V$ | Soft database lookup | Cannot write this formula |
| Scaling factor | Divide by $\sqrt{d_k}$ | Keeps softmax in good gradient regime | "To prevent overflow" |
| Variance argument | $\text{Var}(q \cdot k) = d_k$ | Dot product variance grows with dimension | Cannot derive this |
| Multi-head | $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$ | Different heads, different patterns | "More heads = more params" |
| Self-attention | Q, K, V from same sequence | Every position attends to every position | Confusing with cross-attention |
| Cross-attention | Q from one seq, K/V from another | Decoder queries encoder | Swapping Q/K source |
| Causal mask | Upper triangular $-\infty$ | Prevent attending to future | "Mask applied after softmax" |
| Bahdanau | $v^T \tanh(W_s s + W_h h)$ | First attention (additive) | Not knowing historical context |
| Luong | $s^T W_a h$ or $s^T h$ | Simplified attention (multiplicative) | Confusing with Bahdanau |
| Complexity | $O(T^2 d)$ time, $O(T^2)$ space | Quadratic in sequence length | "Attention is $O(T)$" |
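The variance argument can be checked numerically. This sketch assumes query and key components are i.i.d. standard normal (the same assumption the derivation makes); the sample sizes, the list of $d_k$ values, and the 8-key softmax are arbitrary choices for illustration. The second part shows the practical consequence: without scaling, softmax over a handful of keys collapses toward a one-hot vector, where gradients are tiny.

```python
import numpy as np

rng = np.random.default_rng(0)
num_pairs = 5_000
results = {}

for d_k in (16, 64, 256, 1024):
    # Components drawn i.i.d. N(0, 1): the assumption behind Var(q . k) = d_k.
    q = rng.standard_normal((num_pairs, d_k))
    k = rng.standard_normal((num_pairs, d_k))
    dots = np.sum(q * k, axis=1)
    results[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(f"d_k={d_k:5d}  var(q.k)={results[d_k][0]:8.1f}  "
          f"var(q.k / sqrt(d_k))={results[d_k][1]:.2f}")


def mean_max_weight(d_k, scale, n_keys=8, trials=200):
    """Average largest softmax weight for one random query against n_keys random keys."""
    total = 0.0
    for _ in range(trials):
        q = rng.standard_normal(d_k)
        keys = rng.standard_normal((n_keys, d_k))
        s = keys @ q
        if scale:
            s = s / np.sqrt(d_k)
        w = np.exp(s - s.max())
        total += (w / w.sum()).max()
    return total / trials


print("unscaled:", mean_max_weight(1024, scale=False))  # typically close to 1 (near one-hot)
print("scaled:  ", mean_max_weight(1024, scale=True))
```

The unscaled variance tracks $d_k$ while the scaled variance stays near 1, which is the whole justification for dividing by $\sqrt{d_k}$ rather than by $d_k$.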

Spaced Repetition Checkpoints

Day 0 -- Initial Learning

  • Read this entire page
  • Write the scaled dot-product attention formula from memory
  • Derive the variance of the dot product and explain the scaling factor
  • Complete the self-assessment

Day 3 -- First Recall

  • Without notes, write the multi-head attention equations (projections, concat, output)
  • Give the "60-Second Answer" out loud, timed
  • Draw the information flow diagram for self-attention on a 4-token sequence

Day 7 -- Connections

  • Explain the progression from Bahdanau to self-attention to multi-head (5-minute verbal explanation)
  • Do Practice Problem 1 (scaling derivation) on paper without hints
  • Compare self-attention vs cross-attention with 3 concrete examples from different domains

Day 14 -- Application

  • Do Practice Problem 3 (complexity challenge) under timed conditions (10 minutes)
  • Implement scaled dot-product attention from scratch in NumPy without reference
  • Explain to an imaginary interviewer how attention patterns differ across layers

Day 21 -- Mock Interview

  • Have someone ask: "Derive scaled dot-product attention from the seq2seq bottleneck problem, and extend to multi-head"
  • Time yourself: derivation in under 5 minutes, multi-head extension in under 3 minutes
  • Do all 5 practice problems in sequence under timed conditions (50 minutes total)

Key Takeaways

  1. Attention is a soft, differentiable lookup mechanism. Queries match against keys to determine relevance, and values are aggregated according to those relevance scores. This simple idea - originally designed to fix the seq2seq bottleneck - became the foundation of all modern language and vision models.

  2. The scaling factor is not a detail - it is essential. Without $\sqrt{d_k}$ scaling, dot products grow with dimension, softmax saturates, and gradients vanish. Understanding the variance argument is table stakes for AI interviews.

  3. Multi-head attention is about subspaces, not parameter count. Different heads project into different subspaces, enabling the model to capture different types of relationships simultaneously. With the standard choice $d_k = d_v = d_{model}/h$, the total parameter count is invariant to head count.

  4. Cross-attention vs self-attention is an asymmetry worth understanding deeply. Who provides the queries determines the output shape and the direction of information flow. This distinction is critical for understanding Transformers, text-to-image models, and multimodal architectures.
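A quick count backs up takeaway 3. This sketch assumes the standard choice $d_k = d_v = d_{model}/h$ and counts only the four projection matrices (optionally their biases); `mha_param_count` is a hypothetical helper, and $d_{model} = 512$ is an arbitrary example size.

```python
def mha_param_count(d_model, num_heads, bias=False):
    """Parameters in one multi-head attention block (W_Q, W_K, W_V, W^O),
    assuming d_k = d_v = d_model / num_heads."""
    if d_model % num_heads:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    qkv = num_heads * 3 * d_model * d_k   # per-head query/key/value projections
    out = (num_heads * d_k) * d_model     # output projection W^O
    total = qkv + out
    if bias:
        total += num_heads * 3 * d_k + d_model  # one bias per projected unit
    return total


for h in (1, 8, 16, 64):
    print(h, mha_param_count(512, h))  # 1048576 for every h, i.e. 4 * 512 * 512
```

Adding heads slices the same $d_{model} \times d_{model}$ projections into narrower pieces; it changes what the heads can represent, not how many weights there are.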

© 2026 EngineersOfAI. All rights reserved.