
Self-Attention Mechanism

Reading time: ~40 min · Interview relevance: Essential · Target roles: ML Engineer, AI Engineer, Research Engineer


The Bug That Wasn't a Bug

Maria had been at the NLP team for three months when her manager handed her a debugging task. The production translation model was producing wrong pronoun resolutions in German. Sentences like "The trophy didn't fit in the suitcase because it was too big" were being translated with "it" sometimes referring to the trophy and sometimes the suitcase, seemingly at random.

She spent a week in the LSTM internals. She added logging, visualized hidden states, traced gradient flow. She could see the model struggling - the hidden state at the pronoun position had to carry information about both "trophy" and "suitcase" from earlier in the sentence, and by the time it reached the pronoun, those representations had been overwritten by everything that came between. The information was simply gone.

Her manager had an idea: use attention weights to diagnose the problem. They visualized which input tokens the decoder was "attending to" when it generated the pronoun translation. The attention was diffuse - spread across 15 tokens, with no clear preference for either noun. The model had no principled way to look back at the original nouns and evaluate their size relative to the container.

When the team migrated to a transformer the following year, Maria ran the same diagnostic. She visualized the attention weights at the layer handling pronoun coreference. One head had learned, with no explicit supervision, to attend sharply from pronouns to their antecedents. When processing "it", that head would attend strongly to "trophy" or "suitcase" depending on context. The model had learned the coreference pattern that the LSTM had failed to carry forward.

That's the promise of self-attention. It doesn't have to carry information forward through time - it reaches back and fetches it directly.


Why This Exists - The Long-Range Dependency Problem

Language has dependencies that span arbitrary distances:

  • "The keys to the cabinet are on the table" - subject-verb agreement across a prepositional phrase
  • "Alice told Bob that she believed he was wrong" - coreference spanning 6 words
  • "The company that the investors who lost money in the previous year funded was acquired" - nested clauses spanning the whole sentence

In RNNs, information about a token at position 1 must flow through all intermediate hidden states to influence a prediction at position 100. With each step, the signal degrades. Backpropagation compounds this - gradients must flow back through all those same steps, and they vanish.

Self-attention solves this by creating direct connections between all pairs of tokens. The number of operations separating two tokens in self-attention is always 1, regardless of their distance in the sequence. This is why transformers handle long-range dependencies so much better.


The Core Idea: Attention as Weighted Averaging

Before the formula, the intuition.

Suppose you're reading the sentence "The bank along the river had steep slopes." You want to understand what "bank" means. You look at every other word in the sentence and ask: "How relevant are you to understanding what 'bank' means in this context?"

  • "river": highly relevant - suggests riverbank, not financial institution
  • "steep": moderately relevant - consistent with riverbank
  • "along": relevant - spatial relationship
  • "The", "had": not relevant

You weight each word by its relevance, then combine their meaning. The result is an updated representation of "bank" that incorporates context from the relevant words, especially "river".

Self-attention is exactly this - formalized and made differentiable so gradients can flow through it.
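The intuition above can be sketched in a few lines of NumPy. This is only an illustration: the 2-d word vectors and the relevance scores are made up by hand (in a real model both come from learned parameters), and the two axes are an assumed "geography" vs "finance" reading.

```python
import numpy as np

# Hypothetical 2-d "meaning" vectors: axis 0 ~ geography, axis 1 ~ finance.
words = ["The", "bank", "along", "the", "river", "had", "steep", "slopes"]
vectors = np.array([
    [0.0, 0.0],  # The
    [0.5, 0.5],  # bank (ambiguous between the two senses)
    [0.3, 0.0],  # along
    [0.0, 0.0],  # the
    [1.0, 0.0],  # river
    [0.0, 0.0],  # had
    [0.8, 0.0],  # steep
    [0.9, 0.0],  # slopes
])

# Hand-picked relevance of each word to "bank" (higher = more relevant).
relevance = np.array([0.0, 1.0, 1.0, 0.0, 3.0, 0.0, 2.0, 2.0])

# Softmax turns the scores into non-negative weights that sum to 1.
weights = np.exp(relevance - relevance.max())
weights /= weights.sum()

# The contextual representation of "bank" is the weighted average of all vectors.
bank_in_context = weights @ vectors

for word, weight in zip(words, weights):
    print(f"{word:>6}: {weight:.3f}")
print("bank before:", vectors[1], "after:", np.round(bank_in_context, 3))
```

After mixing, the "geography" component of "bank" outweighs the "finance" component, which is the disambiguation the paragraph describes.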


Queries, Keys, and Values

The mathematical machinery of attention is built on three concepts borrowed from information retrieval:

Query (Q): What are you looking for? A vector representing the "question" that the current token is asking of the rest of the sequence.

Key (K): What do you contain? A vector representing the "label" or "topic" of each token. When a query is matched against keys, it determines relevance.

Value (V): What information do you carry? A vector representing the actual content that should be returned when a token is selected.

Think of it like a library search: your search term is the Query, the book's catalog entry is the Key, and the book's contents are the Value. You search by Query against Keys to find relevance scores, then retrieve the Values of relevant books.

In a transformer, each token $i$ computes three vectors from its embedding $x_i$:

$$Q_i = x_i W^Q, \quad K_i = x_i W^K, \quad V_i = x_i W^V$$

where $W^Q, W^K \in \mathbb{R}^{d_{model} \times d_k}$ and $W^V \in \mathbb{R}^{d_{model} \times d_v}$ are learned weight matrices.

The key word is learned. The model learns what to query for, what keys to advertise, and what values to carry. Different layers learn different types of relationships.
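As a quick shape check, the projections can be written out directly. The sizes below are arbitrary illustrative choices, and the weights are random stand-ins for what training would learn.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k, d_v = 4, 8, 3, 5  # illustrative sizes, not from any real model

X = rng.standard_normal((n, d_model))      # one embedding x_i per row
W_Q = rng.standard_normal((d_model, d_k))  # random stand-ins for learned weights
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

# Stacking all tokens as rows computes every Q_i, K_i, V_i in one matmul each.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)

# Row i of Q is exactly x_i W^Q, matching the per-token formula.
print(np.allclose(Q[2], X[2] @ W_Q))
```

Queries and keys must share the dimension $d_k$ because they are compared by dot product; values can use a different dimension $d_v$.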


The Scaled Dot-Product Attention Formula

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Let's unpack each piece:

$QK^T$: The dot product of queries with keys. For a sequence of $n$ tokens:

  • $Q \in \mathbb{R}^{n \times d_k}$
  • $K^T \in \mathbb{R}^{d_k \times n}$
  • $QK^T \in \mathbb{R}^{n \times n}$

Entry $(i,j)$ in this matrix is the dot product of token $i$'s query with token $j$'s key - a scalar measuring how relevant token $j$ is to token $i$.

$/ \sqrt{d_k}$: Divides by the square root of the key dimension. We'll explain why in a moment.

$\text{softmax}(\cdot)$: Applied row-wise. Converts the raw scores for each query into a probability distribution over all keys. Now each row sums to 1 - these are the attention weights.

$\cdot V$: Multiplies the attention weights by the values matrix $V \in \mathbb{R}^{n \times d_v}$. This computes a weighted sum of value vectors, where the weights come from the attention distribution.

The output has shape $\mathbb{R}^{n \times d_v}$ - one output vector per token, where each output is a context-aware mixture of all value vectors.
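To make the shapes concrete, here is a tiny hand-sized instance of the formula with $n = 2$, $d_k = 2$, $d_v = 3$. The matrices are arbitrary values chosen for readability, not outputs of a trained model.

```python
import numpy as np

Q = np.array([[1.0, 0.0], [0.0, 1.0]])            # (n=2, d_k=2)
K = np.array([[1.0, 0.0], [0.0, 1.0]])            # (n=2, d_k=2)
V = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # (n=2, d_v=3)

scores = Q @ K.T                                  # (n, n) raw relevance
scaled = scores / np.sqrt(Q.shape[-1])            # divide by sqrt(d_k)

# Row-wise softmax: each row becomes a distribution over keys.
weights = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

output = weights @ V                              # (n, d_v) mixture of value rows

print("weights:\n", np.round(weights, 3))
print("output:\n", np.round(output, 3))
```

Each query here matches its own key best, so each output row is pulled toward the corresponding row of `V` while still mixing in the other one.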


Why Scale by $\sqrt{d_k}$?

This is a subtle but important design choice, and a common interview question.

For two random vectors $q$ and $k$ of dimension $d_k$, with every component drawn independently from a distribution with mean 0 and variance 1, the dot product $q \cdot k$ is a sum of $d_k$ independent terms, each with mean 0 and variance 1. It therefore has:

  • Mean: 0
  • Variance: $d_k$
  • Standard deviation: $\sqrt{d_k}$

So as $d_k$ grows, the dot products grow in magnitude proportionally to $\sqrt{d_k}$.
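The variance claim is easy to check empirically. The sketch below assumes standard normal components (one distribution that satisfies the mean-0, variance-1 condition) and an arbitrary sample count.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples = 20_000
results = {}

for d_k in (16, 64, 256):
    q = rng.standard_normal((num_samples, d_k))
    k = rng.standard_normal((num_samples, d_k))
    dots = (q * k).sum(axis=1)                   # one dot product per sample
    raw_var = dots.var()
    scaled_var = (dots / np.sqrt(d_k)).var()     # variance after the 1/sqrt(d_k) scaling
    results[d_k] = (raw_var, scaled_var)
    print(f"d_k={d_k:>3}  var(q.k)={raw_var:8.2f}  var(q.k/sqrt(d_k))={scaled_var:.3f}")
```

The unscaled variance tracks $d_k$, while the scaled variance stays close to 1 for every dimension.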

Why does this matter? Consider what softmax does with large inputs:

$$\text{softmax}([10, 0, 0, 0]) \approx [0.9999, 0.0000, 0.0000, 0.0000]$$

For $d_k = 64$, unscaled dot products have a standard deviation of $\sqrt{64} = 8$, so scores of magnitude 8 or more are common. With 64 tokens, one dominant token pushes its softmax weight to near 1 and all others to near 0. This creates a near-one-hot attention distribution - attention collapses onto a single token, which dramatically limits the model's expressive power.

More critically: the gradient of softmax is approximately 0 when its inputs have very large magnitude. Peaky softmax outputs mean vanishing gradients through the attention operation, making learning effectively stop.
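The vanishing-gradient claim can be seen from the softmax Jacobian, $\partial p_i / \partial x_j = p_i(\delta_{ij} - p_j)$. The sketch below evaluates it for the same input pattern at increasing magnitudes; the helper names and the chosen scales are illustrative.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def softmax_jacobian(x):
    # d p_i / d x_j = p_i * (delta_ij - p_j)
    p = softmax(x)
    return np.diag(p) - np.outer(p, p)

base = np.array([1.0, 0.0, 0.0, 0.0])
largest_entry = {}
for scale in (1, 5, 20):
    J = softmax_jacobian(scale * base)
    largest_entry[scale] = np.abs(J).max()
    print(f"scale={scale:>2}  max |dp/dx| = {largest_entry[scale]:.2e}")
```

As the dominant score grows, every entry of the Jacobian shrinks toward zero, so almost no gradient reaches the scores.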

Dividing by $\sqrt{d_k}$ keeps the inputs to softmax in a reasonable range (unit variance), allowing gradients to flow and the model to learn distributed attention patterns.

:::tip Why not divide by d_k?

Dividing by $d_k$ instead of $\sqrt{d_k}$ would over-correct - the scaled scores would have variance $1/d_k$, so the softmax would be nearly uniform and lose the ability to focus attention. The square root is the scaling that maintains unit variance.

:::
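A small experiment compares the three options (no scaling, $\sqrt{d_k}$, and $d_k$) on random scores. It uses entropy as a rough measure of how spread out the attention distribution is; the random inputs and the sizes are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, n = 64, 16
q = rng.standard_normal(d_k)
K = rng.standard_normal((n, d_k))
raw_scores = K @ q                      # one unscaled score per key

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def entropy(p, eps=1e-12):
    return float(-(p * np.log(p + eps)).sum())

entropies = {}
for name, divisor in [("none", 1.0), ("sqrt(d_k)", np.sqrt(d_k)), ("d_k", float(d_k))]:
    p = softmax(raw_scores / divisor)
    entropies[name] = entropy(p)
    print(f"{name:>9}: max weight={p.max():.3f}  entropy={entropies[name]:.3f}")

print(f"uniform entropy (upper bound): {np.log(n):.3f}")
```

Without scaling the distribution is sharply peaked, dividing by $d_k$ leaves it almost uniform, and $\sqrt{d_k}$ sits between the two.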


Attention Weights: What They Mean

The attention weight $\alpha_{ij}$ (the $(i,j)$ element of the softmax output) tells you: when computing the output for token $i$, how much weight is placed on token $j$'s value vector?

If $\alpha_{1,2} = 0.8$, token 2's value vector contributes to token 1's output with weight 0.8. If $\alpha_{7,7} = 0.4$, token 7 places 40% of its attention weight on itself.

In a trained model, these weights capture semantically meaningful patterns:

  • A verb token attends strongly to its subject
  • A pronoun attends to its antecedent
  • An adjective attends to the noun it modifies
  • An end-of-sentence token attends broadly to the full sentence

These patterns emerge from training - they are not programmed in.


Implementation from Scratch in NumPy

import numpy as np
from typing import Optional


def scaled_dot_product_attention(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    mask: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute scaled dot-product attention.

    Args:
        Q: Queries, shape (batch, seq_q, d_k)
        K: Keys, shape (batch, seq_k, d_k)
        V: Values, shape (batch, seq_k, d_v)
        mask: Optional boolean mask, shape broadcastable to (batch, seq_q, seq_k).
            True where attention is ALLOWED, False where it is masked out.

    Returns:
        output: shape (batch, seq_q, d_v)
        attn_weights: shape (batch, seq_q, seq_k)
    """
    d_k = Q.shape[-1]

    # Step 1: Compute raw attention scores - (batch, seq_q, seq_k)
    # Q @ K^T: for each query, score against all keys
    scores = np.matmul(Q, K.transpose(0, 2, 1))  # (batch, seq_q, seq_k)

    # Step 2: Scale by 1/sqrt(d_k) to keep variance stable
    scores = scores / np.sqrt(d_k)

    # Step 3: Apply mask (set masked positions to a large negative value before softmax)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)

    # Step 4: Softmax row-wise - convert scores to probabilities
    # Numerically stable: subtract max before exp
    scores_shifted = scores - scores.max(axis=-1, keepdims=True)
    exp_scores = np.exp(scores_shifted)
    attn_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    # Step 5: Weighted sum of values
    output = np.matmul(attn_weights, V)  # (batch, seq_q, d_v)

    return output, attn_weights


def self_attention_layer(
    X: np.ndarray,
    W_q: np.ndarray,
    W_k: np.ndarray,
    W_v: np.ndarray,
    W_o: np.ndarray,
    mask: Optional[np.ndarray] = None,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Single self-attention layer.

    Args:
        X: Input, shape (batch, seq, d_model)
        W_q, W_k: Weight matrices, shape (d_model, d_k)
        W_v: Weight matrix, shape (d_model, d_v)
        W_o: Output projection, shape (d_v, d_model)
        mask: Optional boolean mask, True where attention is ALLOWED

    Returns:
        output: shape (batch, seq, d_model)
        attn_weights: shape (batch, seq, seq)
    """
    # Project inputs to Q, K, V
    Q = np.matmul(X, W_q)  # (batch, seq, d_k)
    K = np.matmul(X, W_k)  # (batch, seq, d_k)
    V = np.matmul(X, W_v)  # (batch, seq, d_v)

    # Compute attention
    context, weights = scaled_dot_product_attention(Q, K, V, mask)

    # Project back to d_model
    output = np.matmul(context, W_o)  # (batch, seq, d_model)
    return output, weights


# Demonstration
np.random.seed(42)

batch, seq, d_model, d_k, d_v = 2, 5, 8, 4, 4

# Random weights (in real training, these are learned)
W_q = np.random.randn(d_model, d_k) * 0.1
W_k = np.random.randn(d_model, d_k) * 0.1
W_v = np.random.randn(d_model, d_v) * 0.1
W_o = np.random.randn(d_v, d_model) * 0.1

# Random input sequence
X = np.random.randn(batch, seq, d_model)

output, attn_weights = self_attention_layer(X, W_q, W_k, W_v, W_o)

print(f"Input shape: {X.shape}")  # (2, 5, 8)
print(f"Output shape: {output.shape}")  # (2, 5, 8)
print(f"Attention weights: {attn_weights.shape}")  # (2, 5, 5)
print("\nAttention weights for first item, first token:")
print(np.round(attn_weights[0, 0, :], 3))  # Should sum to 1.0
print(f"Sum: {attn_weights[0, 0, :].sum():.4f}")  # 1.0000
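The demonstration above never passes a mask. As a sketch of how the `mask` argument is typically built, the block below combines a causal mask with a padding mask using broadcasting. It re-implements a compact attention helper so that it runs on its own; the helper name and the example `lengths` are assumptions for illustration.

```python
import numpy as np

def masked_attention(Q, K, V, mask):
    # Same steps as the full implementation, condensed for a self-contained demo.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    scores = np.where(mask, scores, -1e9)            # True = allowed
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
batch, seq, d_k = 2, 5, 4
Q = rng.standard_normal((batch, seq, d_k))
K = rng.standard_normal((batch, seq, d_k))
V = rng.standard_normal((batch, seq, d_k))

# Causal mask: query i may see keys 0..i. Shape (1, seq, seq).
causal = np.tril(np.ones((seq, seq), dtype=bool))[None, :, :]

# Padding mask: in sequence b, only the first lengths[b] keys are real tokens.
lengths = np.array([5, 3])                           # hypothetical sequence lengths
padding = (np.arange(seq)[None, :] < lengths[:, None])[:, None, :]  # (batch, 1, seq)

mask = causal & padding                              # broadcasts to (batch, seq, seq)
output, weights = masked_attention(Q, K, V, mask)

print(mask.shape)
print(np.round(weights[1], 3))                       # columns 3 and 4 are zero
```

Note that query rows belonging to padding positions still produce outputs here; they are normally ignored downstream, for example by masking the loss.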

PyTorch Implementation

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """
    Single-head self-attention layer.

    For production use, prefer MultiHeadAttention (Lesson 03).
    This is the building block - understand this first.
    """

    def __init__(self, d_model: int, d_k: Optional[int] = None):
        super().__init__()
        self.d_k = d_k if d_k is not None else d_model

        self.W_q = nn.Linear(d_model, self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, self.d_k, bias=False)
        self.W_o = nn.Linear(self.d_k, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (batch, seq, d_model)
            mask: boolean, broadcastable to (batch, seq, seq) - True where attention is ALLOWED

        Returns:
            output: (batch, seq, d_model)
            attn_weights: (batch, seq, seq)
        """
        Q = self.W_q(x)  # (batch, seq, d_k)
        K = self.W_k(x)  # (batch, seq, d_k)
        V = self.W_v(x)  # (batch, seq, d_k)

        # Scaled dot-product
        scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(~mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)

        # Handle case where all positions are masked (softmax of all -inf = nan)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

        context = torch.bmm(attn_weights, V)
        output = self.W_o(context)

        return output, attn_weights


def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Create a causal (autoregressive) mask.
    Token i can attend to tokens 0..i (not i+1..n-1).
    Returns a boolean tensor where True = allowed.
    """
    # Lower triangular matrix: True where j <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


def create_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Create a padding mask from sequence lengths.
    In sequence b, key position j can be attended to only if j < lengths[b].
    Returns shape (batch, 1, max_len) - broadcast over query positions.
    """
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
    mask = positions < lengths.unsqueeze(1)  # (batch, max_len)
    return mask.unsqueeze(1)  # (batch, 1, max_len)


# Smoke test on random inputs
if __name__ == "__main__":
    torch.manual_seed(42)

    d_model = 64
    attn = SelfAttention(d_model, d_k=32)

    # Simulate a batch of 2 sequences, length 6
    batch, seq = 2, 6
    x = torch.randn(batch, seq, d_model)

    # Test 1: no mask (encoder self-attention)
    output, weights = attn(x)
    print("=== Encoder Self-Attention (no mask) ===")
    print(f"Output shape: {output.shape}")  # (2, 6, 64)
    print(f"Weights shape: {weights.shape}")  # (2, 6, 6)
    print(f"Weights sum per query: {weights[0].sum(dim=-1)}")  # All should be ~1.0

    # Test 2: causal mask (decoder self-attention)
    causal = create_causal_mask(seq).unsqueeze(0)  # (1, 6, 6)
    output_causal, weights_causal = attn(x, mask=causal)
    print("\n=== Decoder Self-Attention (causal mask) ===")
    # The upper triangle of weights should be 0 (future positions masked)
    print("Weights for first item (upper triangle should be 0):")
    print(torch.round(weights_causal[0], decimals=3))

    # Test 3: check parameter count
    total = sum(p.numel() for p in attn.parameters())
    print(f"\nParameter count for d_model=64, d_k=32: {total}")
    # W_q: 64*32=2048, W_k: 64*32=2048, W_v: 64*32=2048, W_o: 32*64=2048
    # Total: 8192

Visualizing Attention Weights

Understanding what attention weights look like is important for both debugging and interpretation.

import torch
import matplotlib.pyplot as plt
import seaborn as sns


def visualize_attention(
    tokens: list[str],
    attn_weights: torch.Tensor,
    head_idx: int = 0,
    title: str = "Attention Weights",
):
    """
    Visualize attention weights as a heatmap.
    attn_weights: (batch, heads, seq, seq) or (batch, seq, seq)
    Rows are queries, columns are keys.
    """
    if attn_weights.dim() == 4:
        weights = attn_weights[0, head_idx].detach().cpu().numpy()
    else:
        weights = attn_weights[0].detach().cpu().numpy()

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        weights,
        xticklabels=tokens,
        yticklabels=tokens,
        cmap="Blues",
        vmin=0, vmax=1,
        annot=True, fmt=".2f",
    )
    plt.title(title)
    plt.xlabel("Key (attended to)")
    plt.ylabel("Query (attending from)")
    plt.tight_layout()
    plt.show()


# Example: manually crafted attention for illustration
tokens = ["The", "bank", "by", "the", "river"]
# Attention weights showing "bank" (position 1) attending to "river" (position 4)
example_weights = torch.tensor([[
    [0.6, 0.1, 0.1, 0.1, 0.1],  # "The" attends mostly to itself
    [0.1, 0.2, 0.1, 0.1, 0.5],  # "bank" attends strongly to "river"
    [0.2, 0.1, 0.4, 0.1, 0.2],  # "by" fairly distributed
    [0.5, 0.1, 0.1, 0.2, 0.1],  # "the" attends to "The"
    [0.1, 0.1, 0.1, 0.1, 0.6],  # "river" attends to itself
]])

visualize_attention(tokens, example_weights, title="Simulated attention")

print("Simulated attention: 'bank' attends to 'river' (0.5 weight)")
print("This is how context disambiguation works in practice")

Computational Complexity

| Operation | Time | Space |
|---|---|---|
| $QK^T$ (score matrix) | $O(n^2 d_k)$ | $O(n^2)$ |
| Softmax | $O(n^2)$ | $O(n^2)$ |
| $\text{weights} \cdot V$ | $O(n^2 d_v)$ | $O(n d_v)$ |
| Total | $O(n^2 d)$ | $O(n^2)$ |
| RNN comparison | $O(n d^2)$ | $O(nd)$ |

The quadratic $O(n^2)$ memory requirement is the key constraint. For $n = 4096$ and float16 values, a single attention matrix is $4096^2 \times 2 \text{ bytes} \approx 33.5$ MB - per layer, per batch element, per head.
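The arithmetic is worth scripting once. The helper below is a hypothetical back-of-the-envelope calculator, and the 32-layer, 32-head configuration is an assumed example rather than any specific model. It counts only the score matrices, as if all of them were materialized at the same time.

```python
def attention_matrix_bytes(
    seq_len: int,
    bytes_per_element: int = 2,   # float16
    num_heads: int = 1,
    num_layers: int = 1,
    batch_size: int = 1,
) -> int:
    """Bytes needed to hold the n x n attention matrices (nothing else)."""
    return seq_len ** 2 * bytes_per_element * num_heads * num_layers * batch_size

single = attention_matrix_bytes(4096)
print(f"one matrix:            {single / 1e6:.1f} MB")

# Hypothetical model: 32 layers x 32 heads, batch size 1.
full = attention_matrix_bytes(4096, num_heads=32, num_layers=32)
print(f"32 layers x 32 heads:  {full / 1e9:.2f} GB")

# Doubling the sequence length quadruples the cost.
print(attention_matrix_bytes(8192) / single)
```

The last line is the important one: memory scales with the square of the sequence length, independent of the model's hidden size.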

This is why the community has developed:

  • Flash Attention: computes attention in tiles to avoid materializing the full matrix - same result, $O(n)$ memory
  • Sparse attention: only compute attention for a subset of pairs (Longformer, BigBird)
  • Linear attention: approximate attention with $O(n)$ complexity using kernel tricks
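The tiling idea behind Flash Attention rests on an "online softmax": keep a running maximum and running normalizer per query, and rescale the accumulated output whenever the maximum changes. The NumPy sketch below illustrates only that numerical trick, tiling over keys; the real kernel also tiles queries and fuses everything in on-chip GPU memory, which NumPy cannot express. Function names and the tile size are assumptions.

```python
import numpy as np

def naive_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])               # full (n, n) matrix
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, tile=4):
    n, d_v = Q.shape[0], V.shape[-1]
    scale = 1.0 / np.sqrt(Q.shape[-1])
    running_max = np.full((n, 1), -np.inf)           # max score seen so far, per query
    running_sum = np.zeros((n, 1))                   # softmax denominator so far
    acc = np.zeros((n, d_v))                         # unnormalized output so far

    for start in range(0, K.shape[0], tile):
        K_tile, V_tile = K[start:start + tile], V[start:start + tile]
        s = (Q @ K_tile.T) * scale                   # only an (n, tile) block
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)   # rescale old statistics
        p = np.exp(s - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ V_tile
        running_max = new_max

    return acc / running_sum

rng = np.random.default_rng(0)
Q = rng.standard_normal((10, 8))
K = rng.standard_normal((10, 8))
V = rng.standard_normal((10, 6))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V)))
```

The two functions agree to floating-point tolerance, which is why Flash Attention is an exact method rather than an approximation.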

Production Engineering Notes

Numerical Stability in Float16

When running inference in 16-bit precision (FP16 or BF16), attention scores can cause numerical problems. FP16 overflows above 65,504; BF16 keeps the FP32 exponent range but has fewer mantissa bits, so it loses precision instead. In practice:

  1. Always compute the softmax in float32, then cast back: attn_weights = F.softmax(scores.float(), dim=-1).to(scores.dtype)
  2. Ensure the scaling factor $1/\sqrt{d_k}$ is applied - without it, scores for $d_k = 64$ are 8× larger than intended, which pushes them toward the FP16 overflow range
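The overflow is easy to reproduce with NumPy's `float16`. This sketch uses a deliberately naive softmax (no max subtraction) to show the failure, then the float32 version recommended above; the score values are arbitrary.

```python
import numpy as np

scores = np.array([12.0, 0.0, 0.0], dtype=np.float16)

# Naive softmax in float16: exp(12) is about 162755, above the float16 max of 65504.
with np.errstate(over="ignore", invalid="ignore"):
    naive_exp = np.exp(scores)                 # first entry overflows to inf
    naive = naive_exp / naive_exp.sum()        # inf / inf -> nan

# Stable version: upcast, subtract the max, then cast the result back.
scores32 = scores.astype(np.float32)
stable = np.exp(scores32 - scores32.max())
stable = (stable / stable.sum()).astype(np.float16)

print("naive: ", naive)
print("stable:", stable)
```

Subtracting the row maximum alone already prevents the overflow; doing the arithmetic in float32 additionally protects the small weights from rounding error.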

Attention Dropout

Many transformer implementations apply dropout to the attention weights (after softmax, before multiplying with V). This randomly zeros out some attention connections during training, acting as regularization. Common values: 0.1–0.3. At inference time, dropout is disabled.
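A minimal sketch of attention dropout in NumPy, assuming the usual "inverted dropout" convention (surviving weights are scaled by $1/(1-p)$ so the expected value is unchanged). The function name is hypothetical.

```python
import numpy as np

def attention_dropout(attn_weights, p, rng, training=True):
    """Randomly zero attention weights during training (inverted dropout)."""
    if not training or p == 0.0:
        return attn_weights                    # inference: weights pass through untouched
    keep = rng.random(attn_weights.shape) >= p
    return attn_weights * keep / (1.0 - p)

rng = np.random.default_rng(0)
weights = np.full((4, 4), 0.25)                # uniform attention for illustration

dropped = attention_dropout(weights, p=0.1, rng=rng)
print(np.round(dropped, 3))
print("row sums:", np.round(dropped.sum(axis=-1), 3))
```

One consequence worth knowing: after dropout, individual rows no longer sum exactly to 1 during training; only the expectation is preserved.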

Flash Attention in PyTorch

PyTorch 2.0+ exposes fused attention kernels, including a Flash Attention backend, through F.scaled_dot_product_attention():

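import torch
import torch.nn.functional as F

# Illustrative inputs, shape (batch, heads, seq, head_dim)
query = torch.randn(2, 8, 128, 64)
key = torch.randn(2, 8, 128, 64)
value = torch.randn(2, 8, 128, 64)
mask = None  # optional boolean mask (True = attend) or additive float mask
training = False  # inside an nn.Module, use self.training

# PyTorch's optimized implementation - use this in production
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=mask,
    dropout_p=0.1 if training else 0.0,
    is_causal=False,  # Set True for decoder causal attention (then leave attn_mask=None)
)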

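This dispatches to a Flash Attention kernel when the inputs allow it (for example, a supported device and dtype). That kernel avoids materializing the full $n \times n$ score matrix, which substantially reduces memory for long sequences.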
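A useful sanity check when adopting the fused function is to compare it against a manual implementation on small inputs. The sketch below assumes PyTorch 2.0 or later and uses arbitrary tensor sizes; on CPU it exercises whichever backend PyTorch selects, not necessarily the Flash kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 6, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(2, 4, 6, 8)
v = torch.randn(2, 4, 6, 8)

# Manual causal attention, following the formula step by step.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.tril(torch.ones(6, 6, dtype=torch.bool))
scores = scores.masked_fill(~causal, float("-inf"))
manual = F.softmax(scores, dim=-1) @ v

# Fused implementation with the built-in causal flag.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

The fused function applies the $1/\sqrt{d_k}$ scaling itself by default, so the inputs should not be pre-scaled.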


Common Mistakes

:::danger Computing attention in the wrong dimension

A very common implementation bug: transposing $K$ along the wrong dimension. The attention score matrix requires $Q \cdot K^T$, where $K^T$ is transposed in the last two dimensions (the sequence and key-dim dimensions), NOT the batch dimension. Wrong: K.transpose(0, 1). Correct: K.transpose(-2, -1) or K.transpose(1, 2) for 3D tensors.

:::
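The NumPy analogue of this bug, using `np.swapaxes` in place of `transpose`, shows what typically happens: a shape error. The sizes are arbitrary; note that if the dimensions happen to line up (for example, batch size equal to sequence length and $d_k$), the wrong version can run silently and produce garbage instead.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_k = 2, 5, 4
Q = rng.standard_normal((batch, seq, d_k))
K = rng.standard_normal((batch, seq, d_k))

# Correct: swap the last two axes, keys become (batch, d_k, seq).
correct = Q @ np.swapaxes(K, -2, -1)
print("correct shape:", correct.shape)

# Wrong: swapping batch and sequence gives (seq, batch, d_k), which no longer lines up.
error_message = None
try:
    wrong = Q @ np.swapaxes(K, 0, 1)
except ValueError as err:
    error_message = str(err)
print("wrong version raised:", error_message is not None)
```

Using negative axis indices (`-2, -1`) is the safer habit because it keeps working when a heads dimension is added in front.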

:::danger Masking after softmax instead of before

Masking attention weights AFTER softmax is wrong - you'd need to renormalize. The mask must be applied as large negative values BEFORE softmax so those positions compute to near-zero after exponentiation. If you zero out after softmax, the attention distribution no longer sums to 1.

:::
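The difference shows up immediately in the row sums. The scores and mask below are arbitrary illustrative values.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[2.0, 1.0, 0.5, 3.0]])
mask = np.array([[True, True, True, False]])      # the last key must be ignored

# Wrong: zero out after softmax. The masked key already took its share of probability.
wrong = softmax(scores) * mask

# Right: mask before softmax, so the remaining keys share all of the probability.
right = softmax(np.where(mask, scores, -np.inf))

print("wrong:", np.round(wrong, 3), "sum =", round(float(wrong.sum()), 3))
print("right:", np.round(right, 3), "sum =", round(float(right.sum()), 3))
```

In the wrong version the masked key had the largest score, so most of the probability mass simply disappears from the row.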

:::warning Attention weights becoming NaN after masking

If all positions in a query's attention row are masked (e.g., a padding token attending to a fully-padded sequence), softmax of all -inf produces nan. Always add nan handling: torch.nan_to_num(attn_weights, nan=0.0).

:::
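The same failure can be reproduced in NumPy, where `np.nan_to_num` plays the role of `torch.nan_to_num`. The scores and mask are made-up values in which the second query has every key masked.

```python
import numpy as np

scores = np.array([[0.3, -1.2, 0.8],
                   [0.5, 0.1, -0.4]])
mask = np.array([[True, True, False],
                 [False, False, False]])          # second query: everything masked

masked = np.where(mask, scores, -np.inf)
with np.errstate(invalid="ignore"):
    shifted = masked - masked.max(axis=-1, keepdims=True)  # -inf - (-inf) = nan
    exp_scores = np.exp(shifted)
    weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

cleaned = np.nan_to_num(weights, nan=0.0)

print("raw:    ", weights)
print("cleaned:", cleaned)
```

After cleaning, the fully masked row is all zeros, so that query contributes a zero context vector instead of propagating NaN through the rest of the network.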

:::tip Attention heads vs attention layers

Self-attention heads (in multi-head attention) are parallel attention operations within a single layer. Attention layers are separate transformer blocks stacked on top of each other. The largest GPT-3 model (175B parameters) has 96 attention layers, each with 96 heads - a total of 9,216 individual attention heads.

:::


Interview Q&A

Q1: Walk me through the scaled dot-product attention formula from first principles.

Answer: Start with the problem: we want token $i$'s output representation to incorporate information from all other tokens, weighted by relevance.

Step 1: Project the input token vectors into three spaces using learned matrices: $Q = XW^Q$, $K = XW^K$, $V = XW^V$. Each token now has a query (what it's looking for), a key (what it offers), and a value (what information it carries).

Step 2: Compute relevance scores. Token $i$'s query $q_i$ is dot-producted with every token $j$'s key $k_j$. The dot product $q_i \cdot k_j$ measures alignment - if $i$ is "looking for" what $j$ "offers", the score is high. This gives us an $(n \times n)$ score matrix.

Step 3: Scale by $1/\sqrt{d_k}$. Without this, for large $d_k$, dot products become large in magnitude, softmax becomes near-one-hot, and gradients vanish.

Step 4: Softmax row-wise. Each row of scores becomes a probability distribution over all tokens. These are the attention weights - how much weight token $i$ gives to each other token.

Step 5: Weighted sum of values. Multiply attention weights by $V$. Token $i$'s output is a mixture of all value vectors, mixed by relevance.

Q2: What does each attention weight in the softmax output represent? What are the practical implications?

Answer: $\alpha_{ij}$ (attention weight from query $i$ to key $j$) represents the fraction of information in token $i$'s output that comes from token $j$'s value vector. Since weights are a probability distribution (sum to 1), they partition the "attention budget" across all tokens.

Practical implications:

  • If $\alpha_{ij} = 0.9$ for one $j$, that head is effectively doing hard attention - looking at almost one token. This can be good (coreference resolution) or indicate an attention head has become degenerate.
  • Uniform distribution ($\alpha_{ij} = 1/n$) means the token is averaging over everything - no focused attention. This can indicate the head hasn't learned to differentiate.
  • For debugging: unexpected attention patterns (e.g., all attention to [CLS], or to padding tokens) are a red flag that something is wrong with masking or training.
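One way to turn these observations into a number is the entropy of each attention row: near 0 for hard attention, near $\log n$ for uniform attention. The helper below is a hypothetical diagnostic, and the two weight matrices are synthetic examples of the extremes described above.

```python
import numpy as np

def row_entropy(attn_weights, eps=1e-12):
    """Shannon entropy (in nats) of each attention row."""
    return -(attn_weights * np.log(attn_weights + eps)).sum(axis=-1)

n = 8

# Synthetic "hard attention" head: 0.9 on one token, the rest spread evenly.
peaked = np.full((n, n), 0.1 / (n - 1))
np.fill_diagonal(peaked, 0.9)

# Synthetic "diffuse" head: uniform over all tokens.
uniform = np.full((n, n), 1.0 / n)

max_entropy = np.log(n)
peaked_score = row_entropy(peaked).mean() / max_entropy
uniform_score = row_entropy(uniform).mean() / max_entropy

print(f"peaked head:  normalized entropy = {peaked_score:.3f}")
print(f"uniform head: normalized entropy = {uniform_score:.3f}")
```

Tracking the normalized entropy per head over training can flag heads that collapse toward either extreme, though a low or high value alone does not prove a head is useless.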

Q3: Why is self-attention better than RNNs for long-range dependencies?

Answer: Three reasons:

  1. Path length: In an RNN, information from the token at position 1 must flow through $n-1$ sequential hidden state updates to reach position $n$. Each step risks information loss. In self-attention, there is a direct path of length 1 from any token to any other token - just one attention operation.

  2. Gradient flow: Backpropagation follows the same path as information flow. In RNNs, gradients must flow back through all $n$ time steps and tend to vanish (or explode) exponentially with distance. In self-attention, the gradient path from output token $n$ to input token 1 is through one attention operation - the gradient flows directly.

  3. Parallelism: RNNs must process tokens sequentially (each step depends on the previous). Self-attention computes all relationships in a single matrix multiplication, enabling full GPU parallelism.

The tradeoff: self-attention is $O(n^2)$ in computation and memory with respect to sequence length, while RNNs are $O(n)$. For very long sequences (tens of thousands of tokens), this becomes a bottleneck that calls for memory-efficient exact kernels like Flash Attention or approximations like sparse attention.

Q4: What is the difference between self-attention and cross-attention?

Answer:

  • Self-attention: Q, K, and V all come from the same sequence. "How should each token in sequence A attend to other tokens in sequence A?" Used in both encoder and decoder layers.
  • Cross-attention: Q comes from one sequence, K and V come from another sequence. "How should each token in sequence B attend to tokens in sequence A?" Used in encoder-decoder models where the decoder queries the encoder output.

The mathematical operation is identical - just the source of Q, K, V differs. In cross-attention for translation, the decoder's current token (Q) learns to attend to the most relevant source tokens (K, V from encoder), which is how the model learns alignment between languages.
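The "same operation, different inputs" point can be shown with one function called two ways. The sequences and weights below are random placeholders, and sharing one set of projection matrices between the two calls is a simplification for brevity; real models use separate parameters for each attention block.

```python
import numpy as np

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
d_model, d_k = 8, 4
source = rng.standard_normal((7, d_model))  # sequence A, e.g. encoder output (n=7)
target = rng.standard_normal((3, d_model))  # sequence B, e.g. decoder states (m=3)
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

# Self-attention: Q, K, V all come from the same sequence.
self_out, self_weights = attention(source @ W_Q, source @ W_K, source @ W_V)

# Cross-attention: Q from sequence B, K and V from sequence A.
cross_out, cross_weights = attention(target @ W_Q, source @ W_K, source @ W_V)

print("self-attention weights: ", self_weights.shape)
print("cross-attention weights:", cross_weights.shape)
```

The cross-attention weight matrix is rectangular: one row per target token, one column per source token, and the output has one vector per target token.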

Q5: How would you debug unexpected model behavior using attention weights?

Answer: Attention visualization is a powerful diagnostic tool:

  1. Visualize weight heatmaps: Plot the $n \times n$ attention matrix for each head. Look for patterns: diagonal (each token attends to itself), off-diagonal clusters (attending to specific positions), or uniform (diffuse attention).

  2. Check for attention sinks: In some models, [CLS], [SEP], or certain delimiters accumulate nearly all attention weight - a known phenomenon called "attention sinks" (Xiao et al., 2023). If you see this, it may indicate a head has become non-functional.

  3. Layer-by-layer analysis: Early layers often show local/syntactic attention patterns. Later layers show more semantic, long-range patterns. If the pattern is reversed or uniform throughout, that's a signal.

  4. Head specialization: Different heads often learn different functions (syntax, coreference, local context). Heads that have collapsed to near-uniform attention or near-degenerate one-hot attention may be pruned with little accuracy loss (Michel et al., 2019 found that a large fraction of attention heads can be removed at test time with minimal accuracy loss).

  5. Compare with expected behavior: For a classification task, the [CLS] token should accumulate information from the full sequence by the last layer. If it only attends to the first few tokens, the model has a problem.


© 2026 EngineersOfAI. All rights reserved.