Attention Is All You Need

Reading time: ~35 min · Interview relevance: Essential · Target roles: ML Engineer, AI Engineer, Research Engineer


The Production Incident That Changed NLP

It was early 2017, and the Google Brain team was frustrated. Not with a failed experiment - with success. Their best translation model, a stacked LSTM encoder-decoder with attention, was translating sentences remarkably well. The problem was speed.

Every sentence had to be processed sequentially, left to right, one word at a time. To translate "The European Economic Community was created in 1957," the model had to read "The", update its hidden state, read "European", update again, and so on through the entire sequence before it could start generating the translation. This meant no GPU parallelization during training - the bottleneck was not memory or compute capacity, it was sequential dependency.

The team had tried increasing the LSTM depth - stacking 8 layers, then 16. Each additional layer made the models better but training slower and harder to parallelize. They had tried better initialization, better gradient clipping, learning rate schedules. The models worked. They just couldn't scale.

The team had been looking at attention mechanisms - the small addition that let encoder-decoders "look back" at source tokens while generating. A question took shape: what if attention wasn't just a supplement to the RNN? What if it was the whole computation? What if you threw away the recurrence entirely and used only attention?

It was a risky idea. Recurrence was the standard. LSTMs had been invented two decades earlier and had dominated sequence modeling for years. Removing them entirely was not obviously better - it was a research bet. The paper lists all eight authors as equal contributors, and its contribution note credits Jakob Uszkoreit with proposing to replace RNNs with self-attention, Ashish Vaswani and Illia Polosukhin with designing and implementing the first Transformer models, and Noam Shazeer with proposing scaled dot-product attention and multi-head attention. The team ran experiments. The results were not marginal improvements - they were decisive. On English-to-German translation, the transformer scored 28.4 BLEU, more than 2 points above the previous best results, ensembles included. On English-to-French, it reached 41.8 BLEU, a new single-model state of the art, after training for 3.5 days on eight GPUs - a small fraction of the training cost of the best earlier models.

They titled the paper "Attention Is All You Need." The provocative name was the point. They had replaced an entire computational paradigm with a single mechanism. That was June 2017. Nearly every LLM in production today is a direct descendant of what they published.


Why This Exists - The RNN Problem

To understand why the transformer matters, you need to understand what it replaced and why that failed.

The Seq2Seq Framework

By 2015, the dominant approach for sequence-to-sequence tasks (translation, summarization, question answering) was the encoder-decoder RNN. The setup:

  1. An encoder RNN reads the input sequence one token at a time, updating a hidden state $h_t$ at each step
  2. After reading the entire input, it produces a "context vector" - the final hidden state $h_T$
  3. A decoder RNN takes that context vector and generates the output sequence, one token at a time

This worked for short sequences. For long ones, it catastrophically failed. The entire meaning of a 50-word sentence had to be compressed into a single fixed-size vector. Try compressing a paragraph into one sentence and you'll feel the problem intuitively.

The Attention Addition (2015)

Bahdanau et al. (2015) added attention to the decoder: instead of using only the final context vector, the decoder at each output step could "look at" all encoder hidden states and compute a weighted sum. This was a huge improvement.

But a fundamental problem remained: sequential computation.

The Sequential Processing Bottleneck

With RNNs - whether plain RNNs, LSTMs, or GRUs - computing $h_t$ requires $h_{t-1}$. There is no way around this. Step $t$ cannot start until step $t-1$ finishes. This means:

  • A sequence of length $n$ requires $n$ sequential steps
  • You cannot parallelize the forward pass across a sequence
  • GPUs, which are built for massive parallelism, sit largely idle during training
  • Training on long sequences is slow, regardless of hardware investment

For a sequence of length 512, your GPU utilization drops dramatically. For length 1024, it's even worse. This is not an implementation problem - it is a fundamental property of recurrence.
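The dependency chain can be made concrete with a toy scalar recurrence. This is only an illustrative sketch: the update rule and the weights `w_h` and `w_x` are made-up stand-ins for a real RNN cell, not anything from the paper. The point is that the loop body for step t reads the value written by step t-1, so the iterations cannot be run side by side.

```python
import math


def rnn_forward(inputs, w_h=0.5, w_x=1.0):
    """Toy scalar RNN: h_t = tanh(w_h * h_{t-1} + w_x * x_t).

    w_h and w_x are arbitrary illustrative weights, not learned values.
    """
    h = 0.0
    states = []
    for x in inputs:
        # Step t depends on h from step t-1, so this loop is inherently serial.
        h = math.tanh(w_h * h + w_x * x)
        states.append(h)
    return states


states = rnn_forward([0.1, -0.3, 0.7, 0.2])
print(len(states))  # one sequential step per input token
```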

Vanishing Gradients

The second failure mode: gradients vanishing over long sequences. When you backpropagate through 100 steps of an LSTM, the gradients must flow backward through each gate at each time step. Despite the clever gating mechanisms in LSTMs, this signal degrades. In practice, LSTMs struggle to capture dependencies between tokens that are more than ~100 positions apart.

This is one reason RNN-based models could lose track of a sentence's subject by the time they reached a distant verb.


The Transformer Proposal: Replace Recurrence With Attention

The key insight of Vaswani et al. (2017) was deceptively simple:

For translation, what you actually need is to compute relationships between all pairs of tokens in the input. You do not need to process them sequentially to do this.

If you have a sequence of $n$ tokens, you can compute an $(n \times n)$ attention matrix in a single parallel operation. Every token can "look at" every other token simultaneously. No sequential dependency. Full GPU utilization.
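To make the all-pairs computation tangible, here is a minimal pure-Python sketch of scaled dot-product attention. It assumes the queries, keys, and values are given directly as lists of vectors (a real transformer first applies learned projections), and the helper names `softmax` and `self_attention` are illustrative. Each row of the score matrix is independent of the others, which is what lets a GPU compute them all at once.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(Q, K, V):
    """Scaled dot-product attention on lists of vectors.

    Q, K: n vectors of size d_k. V: n vectors of size d_v.
    Returns the (n x n) weight matrix and the n output vectors.
    """
    d_k = len(K[0])
    # Score every query against every key: an (n x n) matrix.
    scores = [
        [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        for q in Q
    ]
    weights = [softmax(row) for row in scores]
    # Each output vector is a weighted sum of all value vectors.
    outputs = [
        [sum(w * v[j] for w, v in zip(row, V)) for j in range(len(V[0]))]
        for row in weights
    ]
    return weights, outputs


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy token vectors
weights, outputs = self_attention(X, X, X)
print(len(weights), len(weights[0]))  # n rows, n columns
```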

The transformer does this with three innovations:

  1. Self-attention: Every token in the input attends to every other token simultaneously
  2. Multi-head attention: Run several attention operations in parallel, each learning different relationships
  3. Positional encoding: Since there's no sequential processing, inject position information explicitly
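The third item can be sketched with the sinusoidal encoding from the paper, where even dimensions use a sine and odd dimensions a cosine of the same angle. The function name and the list-of-lists layout are assumptions made for illustration; real implementations use tensors.

```python
import math


def positional_encoding(max_len, d_model):
    """Sinusoidal table: PE[pos][2i] = sin(pos / 10000^(2i/d_model)),
    PE[pos][2i+1] = cos(pos / 10000^(2i/d_model))."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):  # i walks over the even indices (i = 2k)
            angle = pos / (10000 ** (i / d_model))
            pe[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(angle)
    return pe


pe = positional_encoding(max_len=4, d_model=6)
print(pe[0])  # position 0: sines are 0.0, cosines are 1.0
```

Each position gets a distinct vector, and it is added to the token embedding before the first layer.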

The Architecture: A Walking Tour

The original transformer is an encoder-decoder architecture with 6 stacked layers in each half. Let's walk through it top-down.

The Encoder Stack

Each of the 6 encoder layers contains:

  1. Multi-head self-attention: Each token attends to all other tokens in the input sequence. This is how "bank" learns to attend to "river" vs "money" depending on context.

  2. Add & Norm (residual + layer norm): The output of attention is added to the input (residual connection), then normalized. The residual path keeps gradients flowing through the stack, and layer normalization keeps activation scales in check; together they make the stacked layers trainable.

  3. Feed-forward network (FFN): A two-layer MLP applied independently to each position. Dimensions: $d_{model} \to 4 \times d_{model} \to d_{model}$. For $d_{model} = 512$, this is $512 \to 2048 \to 512$.

The encoder produces a sequence of contextualized representations - one per input token.

The Decoder Stack

Each decoder layer has three sublayers:

  1. Masked multi-head self-attention: The decoder attends to its own previously generated tokens. The "masked" part means it cannot look at future tokens - this enforces autoregressive generation.

  2. Multi-head cross-attention: The decoder queries the encoder output. Queries come from the decoder, keys and values come from the encoder. This is how translation works - "where in the source sentence should I look while generating this output word?"

  3. Feed-forward network: Same structure as the encoder FFN.

The Three Types of Attention

The paper introduces three attention variants:

| Type | Query source | Key/Value source | Use case |
| --- | --- | --- | --- |
| Self-attention (encoder) | Encoder tokens | Encoder tokens | Source context |
| Masked self-attention (decoder) | Decoder tokens | Decoder tokens (past only) | Target history |
| Cross-attention | Decoder tokens | Encoder output | Source-target alignment |
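The "past only" restriction in masked self-attention can be sketched in a few lines. This is an illustrative pure-Python version, with hypothetical helper names: blocked positions get a score of -inf before the softmax, so they end up with a weight of exactly zero.

```python
import math


def causal_mask(n):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_softmax(scores, mask):
    """Row-wise softmax with blocked positions set to -inf beforehand."""
    weights = []
    for row, allowed in zip(scores, mask):
        masked = [s if ok else float("-inf") for s, ok in zip(row, allowed)]
        m = max(masked)  # finite here: the diagonal is always allowed
        exps = [math.exp(s - m) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [[0.0] * 4 for _ in range(4)]  # uniform scores, for illustration
weights = masked_softmax(scores, causal_mask(4))
for row in weights:
    print(row)
```

With uniform scores, each row spreads its weight evenly over the positions it is allowed to see and gives nothing to later positions.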

Model Dimensions

The original transformer (base model) uses:

  • $d_{model} = 512$: embedding dimension
  • $d_{ff} = 2048$: feed-forward inner dimension ($4 \times d_{model}$)
  • $h = 8$: number of attention heads
  • $d_k = d_v = 64$: dimension per head ($512 / 8$)
  • $N = 6$: number of encoder/decoder layers
  • Vocabulary: 37,000 BPE tokens (shared encoder/decoder)
  • Total parameters: ~65M

The "large" variant (called "big" in the paper) uses $d_{model} = 1024$, $d_{ff} = 4096$, $h = 16$, giving ~213M parameters.
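As a sanity check, the parameter totals can be roughly reproduced from the dimensions above. The sketch below is a back-of-the-envelope estimate under stated assumptions: every linear layer has a bias, one embedding matrix is shared across encoder input, decoder input, and the output projection, and the vocabulary is 37,000 for both variants. The function names are illustrative.

```python
def linear_params(d_in, d_out, bias=True):
    """Weights plus optional bias of one dense layer."""
    return d_in * d_out + (d_out if bias else 0)


def transformer_param_estimate(d_model=512, d_ff=2048, n_layers=6, vocab=37000):
    attn = 4 * linear_params(d_model, d_model)  # W_q, W_k, W_v, W_o
    ffn = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
    layer_norm = 2 * d_model  # gain and bias
    encoder_layer = attn + ffn + 2 * layer_norm
    decoder_layer = 2 * attn + ffn + 3 * layer_norm  # self- plus cross-attention
    embeddings = vocab * d_model  # assumed: a single shared matrix
    return n_layers * (encoder_layer + decoder_layer) + embeddings


base = transformer_param_estimate()
big = transformer_param_estimate(d_model=1024, d_ff=4096)
print(f"base: {base / 1e6:.1f}M, big: {big / 1e6:.1f}M")
```

Under these assumptions the estimates land within a few percent of the reported ~65M and ~213M; the remaining gap comes from details such as vocabulary size and bias terms that the sketch does not try to match.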


Why This Was Controversial

The ML community had good reasons to be skeptical in 2017:

"RNNs process sequences - attention is just a helper." The dominant view was that sequential structure was fundamentally necessary for language. Attention was seen as an enhancement, not a replacement.

"Attention is O(n²) - it won't scale." Computing all-pairs attention on a sequence of length $n$ requires $O(n^2)$ operations. For sequences of length 512, this is 262,144 pairs. Critics predicted this would be a bottleneck.

"You need memory to carry information forward." RNNs had hidden states that could theoretically carry information across the whole sequence. Pure attention computes fresh at every layer - where does the "memory" live?

In practice, all three concerns turned out to be wrong, or at least secondary:

  • Language is about relationships between tokens, not sequential state - attention captures this more directly
  • For typical sequence lengths (under 2048 tokens), $O(n^2)$ attention is usually faster than sequential RNN computation in wall-clock time, because the work parallelizes across positions on a GPU
  • The feed-forward layers act as a form of memory, and residual connections preserve information across layers

Impact on the Field

The transformer's dominance happened quickly:

  • 2018: BERT (Devlin et al.) - encoder-only transformer, pretrained on masked language modeling. Achieved state of the art on 11 NLP benchmarks simultaneously.
  • 2018: GPT (Radford et al.) - decoder-only transformer, language modeling pretraining. Showed the power of generative pretraining.
  • 2019: GPT-2 - showed that scaling transformers produced dramatic improvements.
  • 2019: T5 (Raffel et al.) - encoder-decoder transformer, framed everything as text-to-text.
  • 2020: GPT-3 - 175B parameters, few-shot learning, the moment that convinced the industry.
  • 2022: ChatGPT - GPT-3.5 + RLHF, the public moment.
  • 2023+: GPT-4, Claude, Gemini, LLaMA - all transformers, all descended from the 2017 paper.

A Minimal Transformer in PyTorch

Here is a complete, minimal transformer encoder implementation. This is not toy code - it captures the essential structure:

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, d_k)"""
        batch, seq, _ = x.shape
        x = x.view(batch, seq, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch = query.size(0)

        # Project and split into heads
        Q = self.split_heads(self.W_q(query))  # (batch, heads, seq_q, d_k)
        K = self.split_heads(self.W_k(key))    # (batch, heads, seq_k, d_k)
        V = self.split_heads(self.W_v(value))  # (batch, heads, seq_k, d_v)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask must broadcast to (batch, heads, seq_q, seq_k); 0 = blocked.
            # A row with every position blocked would produce NaN after softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)  # (batch, heads, seq_q, d_v)

        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch, -1, self.d_model)
        return self.W_o(context)


class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention with residual connection and layer norm
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))  # post-norm style

        # Feed-forward with residual
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x


class TransformerEncoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Fixed (non-learned) table; as a buffer it follows the module across devices
        self.register_buffer(
            "pos_encoding",
            self._create_positional_encoding(max_seq_len, d_model),
            persistent=False,
        )
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.d_model = d_model

    def _create_positional_encoding(self, max_len: int, d_model: int) -> torch.Tensor:
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        seq_len = x.size(1)
        # Scale embeddings, add positional encoding, then apply dropout to the sum
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.dropout(x + self.pos_encoding[:, :seq_len, :])

        for layer in self.layers:
            x = layer(x, mask)

        return self.norm(x)


# Quick test
if __name__ == "__main__":
    model = TransformerEncoder(vocab_size=10000, d_model=512, num_heads=8, num_layers=6)
    tokens = torch.randint(0, 10000, (2, 32))  # batch=2, seq_len=32
    output = model(tokens)
    print(f"Input shape: {tokens.shape}")    # (2, 32)
    print(f"Output shape: {output.shape}")   # (2, 32, 512)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {total_params:,}")   # ~24M for this config
```

Production Engineering Notes

Training Stability

The original transformer uses a warmup learning rate schedule: $\text{lr} = d_{model}^{-0.5} \cdot \min(\text{step}^{-0.5}, \text{step} \cdot \text{warmup}^{-1.5})$ with warmup = 4000. The rate rises linearly for the first 4000 steps, then decays in proportion to the inverse square root of the step number. This matters: without warmup, training of the original post-norm transformer is often unstable in the first few thousand steps.
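The schedule is short enough to write out directly. The function name `noam_lr` is a community nickname rather than anything used in the paper, and clamping `step` to at least 1 is an added guard against division by zero.

```python
def noam_lr(step, d_model=512, warmup=4000):
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # guard: step 0 would raise on 0 ** -0.5
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 2000, 4000, 16000):
    print(step, noam_lr(step))
```

The two branches of the `min` meet at `step == warmup`, which is where the rate peaks; halfway through warmup the rate is half the peak, and at four times the warmup step it has decayed back to half the peak.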

Dropout Placement

The original paper applies dropout to:

  1. Each sublayer output before residual addition
  2. The sum of embeddings and positional encodings

This is important for regularization and is often omitted in naive implementations.

Precision and Numerical Stability

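The scaling by $1/\sqrt{d_k}$ is essential for well-behaved softmax inputs. If the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance $d_k$, so for $d_k = 64$ the unscaled scores have a standard deviation of about 8. Logits that large push the softmax toward a near-one-hot distribution with vanishing gradients, and in float16 very large scores can also overflow. Masked positions filled with -inf are fine on their own: they become exactly zero after the softmax, as long as each row keeps at least one unmasked position.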
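The spread of unscaled scores is easy to check empirically. The sketch below assumes query and key components drawn independently from a standard normal distribution, which is an idealization of real activations; the function name, trial count, and seed are arbitrary choices.

```python
import math
import random


def dot_product_std(d_k, trials=2000, seed=0):
    """Empirical standard deviation of q.k for random unit-variance vectors."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    variance = sum((x - mean) ** 2 for x in dots) / trials
    return math.sqrt(variance)


raw_std = dot_product_std(64)
scaled_std = raw_std / math.sqrt(64)
print(raw_std, scaled_std)  # close to sqrt(64) = 8 before scaling, close to 1 after
```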

Flash Attention

Standard attention materializes the full $(n \times n)$ attention matrix in memory. For sequences of length 4096 in float32, this is $4096^2 \times 4$ bytes = 64 MiB per head per batch element. Flash Attention (Dao et al., 2022) computes attention in tiles without ever storing the full matrix, recomputing pieces in the backward pass, which reduces the extra memory from $O(n^2)$ to $O(n)$ while producing exact results. It is now the standard in production systems.
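The reason tiling can stay exact is the "online softmax" trick: keep a running maximum, a running denominator, and a running weighted sum, and rescale them whenever a new block raises the maximum. The sketch below shows this for a single query in pure Python. It illustrates the idea only and is not the actual Flash Attention kernel; the function names and the block size are assumptions.

```python
import math


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_full(q, K, V):
    """Reference: materialize every score for this query, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [_dot(q, k) * scale for k in K]
    m = max(scores)
    probs = [math.exp(s - m) for s in scores]
    denom = sum(probs)
    return [sum(p * v[j] for p, v in zip(probs, V)) / denom for j in range(len(V[0]))]


def attention_tiled(q, K, V, block=2):
    """Same result, but only one block of scores exists at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        scores = [_dot(q, k) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # exp(-inf) == 0.0 on the first block
        probs = [math.exp(s - new_max) for s in scores]
        denom = denom * rescale + sum(probs)
        acc = [a * rescale + sum(p * v[j] for p, v in zip(probs, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / denom for a in acc]


q = [1.0, 0.5]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
full = attention_full(q, K, V)
tiled = attention_tiled(q, K, V, block=2)
print(full, tiled)
```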


Common Mistakes

:::danger Removing the scaling factor

If you implement attention without the $1/\sqrt{d_k}$ scaling, your model will appear to train (loss decreases) but the attention weights will be extremely peaky - the model effectively learns to use only one or two source tokens regardless of context. This is subtle and hard to diagnose. Always scale.

:::

:::danger Wrong mask orientation

Padding masks and causal masks have different shapes and semantics. A padding mask blocks specific positions along the key dimension and differs per sequence in the batch. A causal mask is a lower-triangular (query × key) matrix that blocks every key position later than the query position, and it is the same for every sequence. Mixing them up causes the model to either ignore padding tokens correctly but attend to future tokens, or vice versa. Test both on a toy example before trusting your implementation.

:::
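A toy check of both masks might look like the sketch below. It uses nested lists instead of tensors, follows the convention that True means "may attend" (matching a `mask == 0` fill for blocked positions), and all three helper names are illustrative.

```python
def padding_mask(lengths, max_len):
    """Shape (batch, seq_k): True for real tokens, False for padding."""
    return [[j < n for j in range(max_len)] for n in lengths]


def causal_mask(seq_len):
    """Shape (seq_q, seq_k): True where key j is not after query i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def combine(pad_row, causal):
    """Broadcast one sequence's padding row over all queries, then AND."""
    return [[c and p for c, p in zip(c_row, pad_row)] for c_row in causal]


pad = padding_mask([3, 2], max_len=4)  # two sequences, padded to length 4
mask0 = combine(pad[0], causal_mask(4))  # first sequence: 3 real tokens
for row in mask0:
    print(row)
```

The last key column is False in every row (padding), and everything above the diagonal is False (future positions), so both constraints are visible in one small matrix.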

:::warning Post-norm vs pre-norm

The original paper uses post-norm (normalize after adding the residual). Modern LLMs use pre-norm (normalize the input before the sublayer). Pre-norm is more stable for very deep networks. If you are implementing a transformer much deeper than about 12 layers, prefer pre-norm; deep post-norm stacks tend to need careful warmup and tuning to train stably.

:::
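The difference between the two orderings fits in a few lines. This sketch uses plain lists, a layer norm without learned gain and bias, and an arbitrary callable as the sublayer; all names are illustrative.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def post_norm_block(x, sublayer):
    """Original paper: LayerNorm(x + Sublayer(x))."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def pre_norm_block(x, sublayer):
    """Common in modern LLMs: x + Sublayer(LayerNorm(x))."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def zero_sublayer(v):
    """A sublayer that contributes nothing, to expose the residual path."""
    return [0.0] * len(v)


x = [1.0, 2.0, 4.0]
print(pre_norm_block(x, zero_sublayer))   # the input passes through unchanged
print(post_norm_block(x, zero_sublayer))  # the input is renormalized anyway
```

With pre-norm the residual path is a pure identity, so the input (and its gradient) can pass straight through a stack of blocks; with post-norm every block rescales the stream, which is one intuition for why very deep post-norm stacks are harder to train.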

:::tip The paper is 11 pages and worth reading

Despite its impact, "Attention Is All You Need" is remarkably readable. The architecture sections are clear, the ablation studies in Section 6 are genuinely informative, and the results speak for themselves. Read the actual paper - not a blog post summary.

:::


Interview Q&A

Q1: Why did the transformer replace RNNs? What were the specific limitations?

Answer: Three core problems with RNNs:

  1. Sequential dependency: Computing hidden state $h_t$ requires $h_{t-1}$. This means you cannot parallelize across sequence positions - GPUs sit idle. Training a 512-token sequence requires 512 sequential operations.

  2. Vanishing/exploding gradients: Even with LSTM gating, gradients degrade over long sequences. Dependencies between tokens more than ~100 positions apart are difficult to learn.

  3. Information bottleneck: In encoder-decoder architectures, the entire input had to be compressed into a single context vector. This fails for long sequences.

The transformer addresses all three: attention is computed in parallel across all positions, any two positions are connected by a constant-length path within a single layer (so gradients do not have to travel through a long chain of time steps), and every token has direct access to every other token.

Q2: Explain the three types of attention in the original transformer.

Answer:

  • Encoder self-attention: Q, K, V all come from the encoder input. Each source token can attend to all other source tokens. Builds contextual representations.
  • Decoder masked self-attention: Q, K, V come from the decoder. Masked to prevent attending to future positions - this enforces autoregressive generation (can only condition on what's been generated so far).
  • Decoder cross-attention: Q comes from the decoder, K and V come from the encoder output. This is how the decoder "looks at" the source sequence while generating.

Q3: What is the computational complexity of self-attention and why does it matter?

Answer: Self-attention is $O(n^2 d)$ in time and $O(n^2)$ in space, where $n$ is sequence length and $d$ is the dimension.

This quadratic scaling is a serious production constraint:

  • For $n = 512$: 262K attention pairs per layer - fine
  • For $n = 4096$: 16.7M pairs per layer - memory-intensive
  • For $n = 100000$: 10B pairs - impractical with standard attention

This is why long-context LLMs (128K, 1M tokens) require special techniques: Flash Attention (IO-aware tiling), sparse attention patterns (Longformer), or linear attention approximations.
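The pair counts above are plain arithmetic, and it can help to see them next to the memory a single score matrix would need. The sketch assumes one float32 score (4 bytes) per pair, for one attention head and one sequence; the function names are illustrative.

```python
def attention_pairs(n):
    """Number of (query, key) pairs for a sequence of length n."""
    return n * n


def score_matrix_bytes(n, bytes_per_element=4):
    """Memory for one (n x n) score matrix; float32 assumed by default."""
    return attention_pairs(n) * bytes_per_element


for n in (512, 4096, 100_000):
    mib = score_matrix_bytes(n) / 2**20
    print(f"n={n}: {attention_pairs(n):,} pairs, {mib:,.0f} MiB per head")
```

Multiply by the number of heads, layers, and sequences in a batch to see why the full matrix is avoided at long context lengths.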

Q4: The paper uses residual connections and layer normalization. Are both necessary? Can you use one without the other?

Answer: They serve different purposes and work synergistically:

Residual connections (skip connections) solve gradient flow - the gradient has a direct path from output to input, bypassing the sublayer. Without them, deep transformers (6+ layers) would face severe vanishing gradient problems.

Layer normalization controls activation scale - it normalizes activations to have mean 0 and variance 1 across the feature dimension. Without it, activations drift in scale across layers, causing training instability.

You technically could use either alone, but in practice:

  • Residuals without LayerNorm: unstable for deep networks, activations diverge
  • LayerNorm without residuals: gradient flow issues in very deep networks

Modern LLMs use pre-norm (LN before the sublayer) rather than post-norm (LN after) - pre-norm is more stable for 100+ layer networks.

Q5: The original transformer was designed for translation. What design choices reflect that, and how were they changed for LLMs?

Answer:

Translation-specific choices:

  • Encoder-decoder architecture: naturally maps to source-to-target transformation
  • Bidirectional encoder: source language can be read in full before encoding
  • Shared vocabulary between encoder and decoder

Changes for LLMs:

  • Architecture: Moved to decoder-only (GPT) or encoder-only (BERT). Encoder-decoder survives in T5 and sequence-to-sequence models.
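  • Scale: Original 65M parameters (base model). GPT-3 is 175B; parameter counts for many recent frontier models have not been published.
  • Context length: The original was trained on sentence-level translation pairs, far shorter than today's contexts. Modern LLMs support 128K-1M tokens.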
  • Positional encoding: Original sinusoidal encoding. Modern models use RoPE (Rotary Position Embedding) which handles longer contexts better.
  • Activation: Original ReLU. Modern: GeLU, SwiGLU.
  • Normalization: Original post-norm. Modern: pre-norm (Pre-LN), RMSNorm.
  • Tokenization: Original BPE with fixed vocabulary. Modern: byte-level BPE (no unknown tokens), larger vocabularies (50K-128K).

Q6: How would you explain the transformer to a senior engineer who knows deep learning but not NLP?

Answer: "Think of it as a differentiable database query engine.

You have a sequence of tokens. For each token, you compute three vectors: a Query (what am I looking for?), a Key (what do I contain?), and a Value (what information do I carry?).

To generate the output representation for any given token, you compute the dot product between that token's Query and every token's Key (its own included). This gives you a score for 'how relevant is this token to me?'. Scale the scores, normalize with softmax to get weights that sum to 1, then take a weighted sum of all Values.

Do this for all tokens simultaneously - it's just a few matrix multiplications - and every token gets an output that is a weighted mix of all tokens' values, with mixing weights determined by relevance.

Stack 96 of these layers (as in GPT-3) with feed-forward networks and residual connections between each, train on 300B tokens of text, and you get a model that can write code, translate languages, and pick up new tasks from a handful of examples."
