
Layer Normalization and Residual Connections

Reading time: ~30 min · Interview relevance: High · Target roles: ML Engineer, AI Engineer, Research Engineer


The Training Collapse Nobody Could Explain

The year was 2018. A team at a major research lab was training a 24-layer transformer for machine translation. The model was larger than anything they'd attempted before. The first few hundred steps looked perfect - loss dropping cleanly, attention patterns forming.

Then, around step 2,000, the loss spiked to infinity and the training crashed. They checked for bugs. Restarted with different seeds. Tried smaller learning rates. Every time, the model would progress for a while and then catastrophically diverge. The loss would suddenly jump from 3.2 to NaN and never recover.

The culprit, which took two weeks to diagnose, was activation drift. As the input signal flowed through 24 layers, the scale of activations accumulated. Layer 1 might output activations with standard deviation ~1. After 10 layers without normalization, the standard deviation might be 100. The softmax in attention was receiving inputs of magnitude 1,000 - producing near-one-hot distributions with near-zero gradients. Backpropagation was flowing through 24 layers of this - the gradient was vanishing before it reached layer 1.

The fix was two mechanisms working together: layer normalization (keep the scale of activations controlled) and residual connections (give gradients a direct highway from loss to any layer). These are not optional stability tricks - they are the reason transformers can be trained at all at depth beyond 6 layers.


The Deep Network Problem: Vanishing and Exploding Gradients

Training a neural network requires backpropagation: computing $\partial L / \partial W$ for every weight $W$ by applying the chain rule backward through the network.

For a network with $L$ layers, the gradient of the loss with respect to layer 1's weights involves multiplying $L$ Jacobian matrices:

$$\frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial h_L} \cdot \frac{\partial h_L}{\partial h_{L-1}} \cdot \ldots \cdot \frac{\partial h_2}{\partial h_1} \cdot \frac{\partial h_1}{\partial W_1}$$

If these Jacobians have singular values less than 1 (very common), the product shrinks exponentially - vanishing gradients. Early layers receive near-zero gradient and learn nothing. If singular values are greater than 1, gradients explode.

For a 96-layer network (GPT-3), this is 96 matrix multiplications. Without careful architectural choices, training is essentially impossible.
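The exponential shrinkage or growth is easy to reproduce numerically. The sketch below is an illustration only: it replaces the real layer Jacobians with random Gaussian matrices (an assumption, as is the helper name `gradient_norm_through_depth`), scaled so that one chain-rule step multiplies the gradient norm by roughly `scale`.

```python
import torch

torch.manual_seed(0)


def gradient_norm_through_depth(num_layers: int, scale: float, dim: int = 64) -> float:
    """Norm of a unit gradient after flowing backward through `num_layers`
    random Jacobians whose entries have std = scale / sqrt(dim)."""
    grad = torch.ones(dim) / dim ** 0.5  # unit-norm gradient at the output
    for _ in range(num_layers):
        jacobian = torch.randn(dim, dim) * scale / dim ** 0.5
        grad = jacobian.T @ grad  # one chain-rule step toward the input
    return grad.norm().item()


for scale in (0.5, 1.0, 2.0):
    norm = gradient_norm_through_depth(96, scale)
    print(f"scale={scale}: gradient norm after 96 layers = {norm:.3e}")
```

With `scale=0.5` the norm collapses by many orders of magnitude, and with `scale=2.0` it blows up by a similar factor. Only a scale very close to 1 keeps the gradient usable across 96 layers.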

Two solutions:

  1. Residual connections: Add a direct gradient highway that bypasses each sublayer
  2. Layer normalization: Control activation scale at each layer

Residual Connections

He et al. (2015) introduced residual connections (skip connections) in ResNet for image classification. The core idea: instead of each layer learning a full transformation $H(x)$, it learns a residual $F(x)$ on top of the identity, so that $H(x) = x + F(x)$:

$$\text{output} = x + F(x)$$

In a transformer, this is applied to every sublayer:

$$x = x + \text{MultiHeadAttention}(x)$$

$$x = x + \text{FeedForward}(x)$$

Why does this help gradients? The gradient of $x + F(x)$ with respect to $x$ is:

$$\frac{\partial (x + F(x))}{\partial x} = 1 + \frac{\partial F(x)}{\partial x}$$

The "+1" means there is always a direct gradient path from the output of any block to the input - it doesn't have to flow through $F(x)$. Even if $\partial F(x)/\partial x$ is near zero (vanishing), the gradient still reaches the earlier layer through the identity path.

In practice, for a 96-layer transformer, the identity terms of all 96 blocks chain together, so there is a direct gradient path from the loss to every layer that bypasses all the sublayers in between. This "gradient highway" interpretation is why deep networks became trainable.

Additionally, residual connections serve a functional role: the network doesn't need to re-learn everything from scratch at each layer. Layer $l$ adds a refinement $F(x)$ to the representation from layer $l-1$. This makes each layer's task easier - "improve this representation" rather than "compute everything from scratch."
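A quick way to see the effect is to measure how much gradient reaches the input of a deep stack with and without the skip connection. This is a minimal sketch, not a transformer: the layers are plain `Linear` + `Tanh` stand-ins with PyTorch's default initialization, and `input_grad_norm` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def input_grad_norm(depth: int, use_residual: bool, dim: int = 64) -> float:
    """Gradient norm at the input of a `depth`-layer stack."""
    torch.manual_seed(0)  # same weights and inputs for both variants
    layers = [nn.Sequential(nn.Linear(dim, dim), nn.Tanh()) for _ in range(depth)]
    x = torch.randn(8, dim, requires_grad=True)
    h = x
    for layer in layers:
        # Residual variant keeps the identity path: h + F(h)
        h = h + layer(h) if use_residual else layer(h)
    h.sum().backward()
    return x.grad.norm().item()


plain = input_grad_norm(depth=50, use_residual=False)
residual = input_grad_norm(depth=50, use_residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {residual:.3e}")
```

In the plain stack the gradient at the input is vanishingly small after 50 layers. In the residual stack it stays at a usable magnitude because the identity term carries it past every sublayer.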


Layer Normalization

Batch Normalization (Ioffe & Szegedy, 2015) was the standard for image models: normalize activations across the batch dimension. For a feature $j$ at layer $l$:

$$\hat{x}_j = \frac{x_j - \mu_B}{\sigma_B}$$

where $\mu_B$ and $\sigma_B$ are computed over the batch.

Why batch norm fails for transformers:

  1. Variable sequence lengths: padding means batch statistics are distorted
  2. Small batch sizes: batch statistics are noisy with batch size 1-4 (common in LLM training)
  3. Autoregressive inference: at generation time, batch size is often 1, making batch stats meaningless

Ba et al. (2016) introduced Layer Normalization: normalize across the feature dimension (not the batch dimension):

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu$ and $\sigma$ are computed across the $d_{model}$ features of a single token (not across the batch), and $\gamma, \beta \in \mathbb{R}^{d_{model}}$ are learned scale and shift parameters.

This means each token's representation is normalized independently of other tokens and other batch items. Batch size doesn't matter. Variable lengths don't matter. Inference with a single token works identically.
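The per-token behaviour can be checked directly. The sketch below recomputes the formula by hand, compares it with `nn.LayerNorm`, and contrasts it with `nn.BatchNorm1d` in training mode, whose output for a sample changes when the rest of the batch changes. Tensor shapes and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
layer_norm = nn.LayerNorm(d_model)  # gamma = 1, beta = 0 at initialization
x = torch.randn(4, 10, d_model)     # (batch, seq, d_model)

with torch.no_grad():
    # Manual LayerNorm: statistics over the feature dimension only
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    manual = (x - mu) / torch.sqrt(var + layer_norm.eps)
    matches_module = torch.allclose(manual, layer_norm(x), atol=1e-5)

    # A token normalized alone gives the same result as inside the batch
    alone = layer_norm(x[0, 3])
    same_in_batch = torch.allclose(alone, layer_norm(x)[0, 3], atol=1e-5)

    # BatchNorm (training mode): the same sample changes with its batch mates
    batch_norm = nn.BatchNorm1d(d_model)
    feats = x[:, 0, :]                      # (batch, d_model)
    bn_full = batch_norm(feats)[0]
    bn_half = batch_norm(feats[:2])[0]
    bn_depends_on_batch = not torch.allclose(bn_full, bn_half, atol=1e-5)

print(matches_module, same_in_batch, bn_depends_on_batch)
```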


Pre-Norm vs Post-Norm

The original transformer (Vaswani et al., 2017) uses Post-Norm: apply LayerNorm after the sublayer and after adding the residual:

$$x = \text{LayerNorm}(x + \text{Sublayer}(x))$$

Modern LLMs use Pre-Norm: apply LayerNorm to the input before the sublayer, then add the residual:

$$x = x + \text{Sublayer}(\text{LayerNorm}(x))$$

The difference is subtle but critical for training stability.

Why Pre-Norm is preferred for deep networks:

In Post-Norm, the sum $x + \text{Sublayer}(x)$ goes through LayerNorm at every block, so the residual path itself is renormalized each time. There is no unmodified identity path from the output back to the early layers: every skip connection is rescaled by a LayerNorm Jacobian. In deep stacks this leaves gradients poorly scaled at initialization and can cause instability unless the learning rate is warmed up carefully.

In Pre-Norm, the residual path ($x$ directly) is never normalized. The gradient flowing through the residual connection is preserved at its original scale throughout the entire depth. This is why Pre-Norm allows stable training of 100+ layer models with far less reliance on complex learning rate warmup schedules.

Empirically: models with Pre-Norm can be trained with higher learning rates, converge faster, and are more stable than Post-Norm at equal depth (Xiong et al., 2020, "On Layer Normalization in the Transformer Architecture").

import torch
import torch.nn as nn


class PostNormBlock(nn.Module):
    """Original transformer style: x = LayerNorm(x + Sublayer(x))"""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        sublayer_out = self.sublayer(x, **kwargs)
        return self.norm(x + sublayer_out)  # Norm AFTER residual addition


class PreNormBlock(nn.Module):
    """Modern LLM style: x = x + Sublayer(LayerNorm(x))"""

    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        normed = self.norm(x)  # Norm BEFORE sublayer
        sublayer_out = self.sublayer(normed, **kwargs)
        return x + sublayer_out  # Direct residual (no norm on the skip path)
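One visible consequence of the two layouts is what happens to the scale of the hidden state. The sketch below is an illustration under simplifying assumptions: each sublayer is a single randomly initialized `nn.Linear` rather than attention or an FFN, and `run_stack` is a hypothetical helper. It pushes the same input through 12 blocks of each kind.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, depth = 64, 12
x = torch.randn(2, 5, d_model)


def run_stack(x: torch.Tensor, pre_norm: bool) -> torch.Tensor:
    """Apply `depth` freshly initialized blocks in Pre-Norm or Post-Norm layout."""
    h = x
    for _ in range(depth):
        sublayer = nn.Linear(d_model, d_model)  # stand-in sublayer (assumption)
        norm = nn.LayerNorm(d_model)
        if pre_norm:
            h = h + sublayer(norm(h))   # skip path is never normalized
        else:
            h = norm(h + sublayer(h))   # skip path is renormalized every block
    return h


with torch.no_grad():
    pre_out = run_stack(x, pre_norm=True)
    post_out = run_stack(x, pre_norm=False)

print(f"Pre-Norm output std:  {pre_out.std():.3f}")   # grows with depth
print(f"Post-Norm output std: {post_out.std():.3f}")  # pinned near 1
```

In Pre-Norm the residual stream accumulates each sublayer's contribution, so its scale grows with depth. That is one reason Pre-Norm models usually apply a final normalization before the output head. In Post-Norm the scale is reset at every block, which is the rescaling of the skip path described above.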

RMSNorm: Simpler and Faster

Zhang & Sennrich (2019) proposed Root Mean Square Layer Normalization (RMSNorm):

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma$$

where $\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}$ and $\gamma \in \mathbb{R}^d$ is a learned scale.

Differences from LayerNorm:

  1. No centering: RMSNorm does not subtract the mean $\mu$ - only scales by RMS
  2. No bias: RMSNorm has no $\beta$ parameter
  3. Faster: computing mean is $O(d)$ extra operations - small but measurable at scale

Why no centering? The hypothesis is that the centering (subtracting mean) in LayerNorm is not the important part - the rescaling is. Empirically, RMSNorm achieves comparable quality to LayerNorm while being faster.
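The relationship between the two can be made concrete: on input that is already zero-mean, the variance equals the mean square, so LayerNorm (without its learned parameters) and RMS normalization coincide. The sketch below checks this numerically. `rms_normalize` is a hypothetical helper, and both sides use the same `eps` so they match.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-6
x = torch.randn(3, 8) + 2.0  # deliberately non-zero mean


def rms_normalize(t: torch.Tensor) -> torch.Tensor:
    """Divide by the root mean square over the last dimension (no centering)."""
    return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)


layer_normed = F.layer_norm(x, (x.shape[-1],), eps=eps)  # no gamma / beta
centered = x - x.mean(-1, keepdim=True)

same_after_centering = torch.allclose(layer_normed, rms_normalize(centered), atol=1e-5)
same_without_centering = torch.allclose(layer_normed, rms_normalize(x), atol=1e-5)
print(same_after_centering, same_without_centering)
```

The outputs differ only when the input has a non-zero mean, which is exactly the part RMSNorm chooses not to correct.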

Used by: LLaMA, LLaMA-2, LLaMA-3, Mistral, Gemma, and many other modern open-source LLMs.

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Used by LLaMA, Mistral, Gemma.

    Reference: Zhang & Sennrich (2019)
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # learned scale (gamma)
        # No bias (beta) - key difference from LayerNorm

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize by RMS."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast to float32 for numerical stability, then cast back
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


# Compare LayerNorm vs RMSNorm
d_model = 4096
batch, seq = 4, 256  # kept small so the comparison runs quickly on CPU

x = torch.randn(batch, seq, d_model, device='cpu')

layer_norm = nn.LayerNorm(d_model)
rms_norm = RMSNorm(d_model)

# Correctness: both should normalize scale
with torch.no_grad():
    ln_out = layer_norm(x)
    rms_out = rms_norm(x)
    print(f"LayerNorm output std: {ln_out.std():.4f}")  # ~1.0
    print(f"RMSNorm output std: {rms_out.std():.4f}")  # ~1.0 for zero-mean input

# Parameter count comparison
ln_params = sum(p.numel() for p in layer_norm.parameters())
rms_params = sum(p.numel() for p in rms_norm.parameters())
print(f"\nLayerNorm parameters: {ln_params}")  # 2 * d_model (weight + bias)
print(f"RMSNorm parameters: {rms_params}")  # d_model (weight only)

# Note on LLaMA checkpoint compatibility:
# LLaMA uses: model.layers.N.input_layernorm.weight and
# model.layers.N.post_attention_layernorm.weight
# Both are RMSNorm weight tensors of shape (d_model,)

The Complete Transformer Block

Putting it together: a modern Pre-Norm transformer block with RMSNorm (as used in LLaMA):

class LLaMATransformerBlock(nn.Module):
    """
    A single LLaMA-style transformer block.
    Pre-norm, RMSNorm, SwiGLU FFN, RoPE attention.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int):
        super().__init__()

        # Attention with pre-norm
        self.input_layernorm = RMSNorm(d_model)
        self.attention = MultiHeadAttentionWithRoPE(d_model, num_heads)  # From Lesson 04

        # FFN with pre-norm
        self.post_attention_layernorm = RMSNorm(d_model)
        self.mlp = FFNWithSwiGLU(d_model, d_ff)  # From Lesson 05

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor = None,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # Attention sublayer (Pre-Norm)
        residual = x
        x = self.input_layernorm(x)  # Norm first
        x = self.attention(x, x, x, mask=mask, position_ids=position_ids)
        x = residual + x  # Residual: direct path preserved

        # FFN sublayer (Pre-Norm)
        residual = x
        x = self.post_attention_layernorm(x)  # Norm first
        x = self.mlp(x)
        x = residual + x  # Residual: direct path preserved

        return x


# Contrast with original Post-Norm Transformer block:
class OriginalTransformerBlock(nn.Module):
    """
    Original Vaswani et al. (2017) transformer block.
    Post-norm (LayerNorm after sublayer + residual).
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int):
        super().__init__()
        self.attention = MultiHeadAttentionBasic(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FFNWithReLU(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        # Attention + post-norm
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + attn_out)  # Norm AFTER addition

        # FFN + post-norm
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)  # Norm AFTER addition

        return x

Mathematical Analysis of Gradient Flow

With Pre-Norm residual connections, the gradient of the loss $L$ with respect to the input $x$ at layer $l$ satisfies:

$$\frac{\partial L}{\partial x_l} = \frac{\partial L}{\partial x_L} \cdot \prod_{k=l}^{L-1} \frac{\partial x_{k+1}}{\partial x_k}$$

For Pre-Norm blocks: $x_{k+1} = x_k + f_k(\text{LN}(x_k))$

$$\frac{\partial x_{k+1}}{\partial x_k} = I + \frac{\partial f_k(\text{LN}(x_k))}{\partial x_k}$$

The identity matrix $I$ in this sum is the residual gradient path. Even if $\frac{\partial f_k}{\partial x_k} \approx 0$ (vanishing gradient through the sublayer), the total derivative is still approximately $I$ - the gradient flows unimpeded through the residual path.

For a 96-layer network, the gradient from layer 96 to layer 1 flows through 96 residual connections. The residual path contributes a gradient of magnitude approximately $\left\|\frac{\partial L}{\partial x_{96}}\right\|$ regardless of what happens inside the sublayers.
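The $I + \partial f_k / \partial x_k$ structure can be inspected numerically with `torch.autograd.functional.jacobian`. The block below is a toy stand-in, not a real transformer sublayer: `pre_norm_block` is a hypothetical function that applies LayerNorm, one ReLU layer, and an output projection `w_out`. Setting `w_out` to zero switches the sublayer off entirely.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d = 8
w_in = torch.randn(d, d) / d ** 0.5
w_out_zero = torch.zeros(d, d)
w_out_small = 0.01 * torch.randn(d, d)


def pre_norm_block(x: torch.Tensor, w_out: torch.Tensor) -> torch.Tensor:
    """x + f(LN(x)) with a toy one-hidden-layer sublayer f."""
    hidden = torch.relu(F.layer_norm(x, (d,)) @ w_in)
    return x + hidden @ w_out


x = torch.randn(d)
jac_zero = jacobian(lambda t: pre_norm_block(t, w_out_zero), x)
jac_small = jacobian(lambda t: pre_norm_block(t, w_out_small), x)

identity = torch.eye(d)
print(torch.allclose(jac_zero, identity))     # sublayer off: exactly I
print((jac_small - identity).abs().max())     # small sublayer: I plus a small term
```

Whatever the sublayer contributes is added to the identity rather than replacing it, which is the property the product formula above relies on.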


Production Engineering Notes

Gradient Norm Monitoring

A critical production practice: monitor the gradient norm during training. A healthy training run shows:

  • Gradient norm gradually decreasing or stabilizing
  • No sudden spikes (would indicate numerical instability)
  • No trend toward zero (would indicate vanishing gradients)

# Log gradient norm every N steps
# (call after loss.backward() and before optimizer.step())
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Returns the total norm BEFORE clipping - log this
wandb.log({"grad_norm": grad_norm.item()})

Gradient clipping (max_norm=1.0) is standard for transformers. Without it, occasional large gradient steps (gradient explosion) can destabilize training even with residuals and LayerNorm.
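For context, here is where the clipping call sits in a full training step. This is a minimal sketch on a toy model with synthetic data. The model, the input scale, and the `grad_norm_history` list (used in place of an experiment logger) are all illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_norm = 1.0
grad_norm_history = []   # pre-clip norms: this is the series to plot
post_clip_norms = []     # norms actually applied by the optimizer

for step in range(5):
    inputs = torch.randn(8, 16) * 100.0  # large inputs to provoke large gradients
    targets = torch.randn(8, 1)

    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()

    # Clip in place; the return value is the total norm before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    grad_norm_history.append(total_norm.item())

    clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
    post_clip_norms.append(clipped.item())

    optimizer.step()

print("pre-clip: ", [f"{n:.2f}" for n in grad_norm_history])
print("post-clip:", [f"{n:.2f}" for n in post_clip_norms])
```

Logging the pre-clip value matters: the post-clip norm is capped at `max_norm` by construction, so it hides exactly the spikes you want to see.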

Initialization Matters

For deep Pre-Norm transformers, residual connections must be initialized carefully. If sublayer outputs are too large at initialization, the residuals compound and early training diverges.

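Wang et al. (2022) proposed DeepNorm in the "DeepNet" paper: up-scale the skip connection by $\alpha$ and down-scale the initialization of selected sublayer weights by $\beta$, where $\alpha$ and $\beta$ depend on depth. The paper reports stable training of 1,000-layer transformers with this scheme, and Microsoft's later MAGNETO work builds on the same initialization analysis.

Common practice: initialize output projection weights of attention and FFN with scale $\sigma / \sqrt{2L}$ where $L$ is the number of layers. Each block adds two sublayer outputs to the residual stream, so this keeps the accumulated variance of the $2L$ residual contributions roughly constant at initialization instead of growing with depth.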
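A sketch of that scaled initialization, assuming a base std of 0.02 (the value quoted later in this lesson) and a hypothetical helper name `init_residual_output_projection`. In a real model you would call it on the attention output projection and the second FFN linear of every block.

```python
import math

import torch
import torch.nn as nn


def init_residual_output_projection(
    linear: nn.Linear, num_layers: int, base_std: float = 0.02
) -> None:
    """Initialize a projection that writes into the residual stream
    with std = base_std / sqrt(2 * num_layers)."""
    std = base_std / math.sqrt(2 * num_layers)
    nn.init.normal_(linear.weight, mean=0.0, std=std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)


torch.manual_seed(0)
num_layers = 48
out_proj = nn.Linear(256, 256)
init_residual_output_projection(out_proj, num_layers)

expected_std = 0.02 / math.sqrt(2 * num_layers)
print(f"expected std: {expected_std:.5f}")
print(f"actual std:   {out_proj.weight.std().item():.5f}")
```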


Common Mistakes

:::danger Applying LayerNorm on the wrong dimension

nn.LayerNorm(d_model) normalizes over the last dimension - correct for transformers where the last dimension is features. nn.LayerNorm([seq, d_model]) would normalize over both sequence and feature dimensions - wrong. Always confirm you're normalizing over features only.

:::
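The difference shows up in the per-token statistics. In this sketch the first token of every sequence is shifted by a constant (an artificial choice to make the effect obvious), and both variants are applied to the same tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq, d_model = 2, 6, 32
x = torch.randn(batch, seq, d_model)
x[:, 0] += 10.0  # give the first token a very different mean

with torch.no_grad():
    per_token = nn.LayerNorm(d_model)(x)            # correct: features only
    per_sequence = nn.LayerNorm([seq, d_model])(x)  # wrong: sequence and features

# Largest absolute per-token mean after normalization
per_token_mean = per_token.mean(dim=-1).abs().max().item()
per_sequence_mean = per_sequence.mean(dim=-1).abs().max().item()
print(f"features only:       {per_token_mean:.6f}")     # ~0 for every token
print(f"sequence + features: {per_sequence_mean:.6f}")  # shifted token is still off-center
```

With the wrong `normalized_shape`, tokens are no longer normalized independently: one unusual token changes the statistics applied to every other token in the sequence.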

:::danger Forgetting the residual in your custom block

If you implement a transformer block and forget to add the residual (x = sublayer_output instead of x = x + sublayer_output), the model will still train - but poorly. Loss will converge more slowly, deep models will exhibit gradient issues, and pretrained checkpoints loaded into the block will not reproduce their expected outputs. Unit test: the output of a fresh (random weights) Pre-Norm block should be close to its input when sublayer outputs are small.

:::
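One way to write that unit test is to force the sublayer output to exactly zero and assert that the block becomes the identity. `TinyPreNormBlock` below is a hypothetical stand-in for your own block. The same check applies to any Pre-Norm block whose last sublayer operation is a linear projection.

```python
import torch
import torch.nn as nn


class TinyPreNormBlock(nn.Module):
    """Hypothetical Pre-Norm block with an FFN sublayer, for illustration."""

    def __init__(self, d_model: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.up = nn.Linear(d_model, 4 * d_model)
        self.down = nn.Linear(4 * d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.down(torch.relu(self.up(self.norm(x))))


def test_block_is_identity_when_sublayer_is_zero() -> None:
    torch.manual_seed(0)
    block = TinyPreNormBlock(32)
    # Zero the output projection so the sublayer contributes nothing
    nn.init.zeros_(block.down.weight)
    nn.init.zeros_(block.down.bias)
    x = torch.randn(4, 10, 32)
    with torch.no_grad():
        out = block(x)
    # Passes only if the residual addition is present
    assert torch.allclose(out, x)


test_block_is_identity_when_sublayer_is_zero()
print("residual check passed")
```

If the residual addition were missing, the output would be all zeros and the assertion would fail immediately.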

:::warning RMSNorm uses float32 internally

LLaMA's RMSNorm implementation casts to float32 internally for the normalization computation, then casts back. If you implement RMSNorm in float16 throughout, you'll get numerical errors: the largest finite float16 value is 65504, so $x_i^2$ overflows as soon as $|x_i|$ exceeds about 255, very small squares underflow to zero, and the sum $\sum x_i^2$ can overflow for large $d_{model}$ even when each term fits. Always cast to float32 for the normalization step.

:::
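The overflow threshold is easy to demonstrate. This sketch squares the same values in float16 and in float32. The value 300 is an arbitrary example just above the threshold.

```python
import torch

x_fp16 = torch.tensor([300.0, -300.0], dtype=torch.float16)

# 300^2 = 90000 exceeds the float16 range, so the square overflows to inf
squares_fp16 = x_fp16 * x_fp16

# The same computation after casting up stays finite
squares_fp32 = x_fp16.float() * x_fp16.float()

print(torch.finfo(torch.float16).max)  # largest finite float16 value
print(squares_fp16)
print(squares_fp32)
```

Once the mean of squares is `inf`, `rsqrt` returns 0 and the normalized activations are silently wiped out, which is why the reference implementation does this step in float32.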

:::tip Post-norm models need learning rate warmup; pre-norm can start higher

Post-norm transformers are sensitive to learning rate - they need careful warmup (4000+ steps). Pre-norm transformers are more stable and can tolerate higher peak learning rates and shorter warmup. If you're seeing training instability and your model uses post-norm, switching to pre-norm is often the fix.

:::
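A linear warmup can be expressed with `torch.optim.lr_scheduler.LambdaLR`. The sketch below is one simple variant, not the schedule from any particular paper: `warmup_factor`, the peak learning rate, and the 100-step warmup are illustrative assumptions. In practice it is usually combined with a decay phase.

```python
import torch


def warmup_factor(step: int, warmup_steps: int) -> float:
    """Multiplier that ramps linearly from 1/warmup_steps up to 1.0."""
    return min(1.0, (step + 1) / warmup_steps)


peak_lr, warmup_steps = 3e-4, 100
params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameter
optimizer = torch.optim.AdamW(params, lr=peak_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: warmup_factor(step, warmup_steps)
)

lrs = []
for step in range(150):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used for this step
    optimizer.step()   # a real loop would compute a loss and call backward() first
    scheduler.step()

print(f"step 0:   {lrs[0]:.2e}")
print(f"step 99:  {lrs[99]:.2e}")
print(f"step 149: {lrs[149]:.2e}")
```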


Interview Q&A

Q1: What problem do residual connections solve, and how do they solve it?

Answer: Residual connections solve the vanishing gradient problem in deep networks.

The problem: in backpropagation, gradients must flow backward through every layer. Each layer's Jacobian (the matrix of partial derivatives) often has singular values below 1. Multiplying 96 such matrices together - the gradient product for a 96-layer network - produces a value exponentially close to zero. Early layers receive near-zero gradient and effectively don't train.

The solution: instead of learning a direct transformation $H(x)$, learn a residual $F(x)$ where the full transformation is $H(x) = x + F(x)$. The gradient of the loss with respect to the input of a residual block is:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial (x + F(x))} \cdot \left(1 + \frac{\partial F(x)}{\partial x}\right)$$

The "+1" ensures there is always a direct gradient path - even if $\partial F/\partial x \approx 0$. Practically, with 96 residual blocks, the identity terms chain together into a direct gradient path from the loss to every layer. The training signal reaches every layer.

Secondary benefit: residual blocks let each layer learn incremental improvements. Layer $l$ doesn't need to recompute the full representation - it adds a correction to the previous layer's output.

Q2: What is the difference between layer normalization and batch normalization, and why does layer norm work better for transformers?

Answer: The difference is which dimension is normalized:

  • Batch norm: normalizes across the batch dimension - for each feature, computes mean and variance across all samples in the batch
  • Layer norm: normalizes across the feature dimension - for each sample (token), computes mean and variance across all features

Why batch norm fails for transformers:

  1. Variable lengths: sequences are padded to equal length; padding tokens distort batch statistics
  2. Small batches: LLM training often uses micro-batch size 1-4 per GPU; batch statistics are extremely noisy
  3. Autoregressive inference: generating one token at a time often means batch size 1 - statistics from a single sample are meaningless, so the model has to fall back on running averages from training that may not match the current input

Layer norm avoids all these issues. Each token is normalized independently of other tokens and batch items. It works identically during training with batch size 128 and during inference with batch size 1.

Q3: What is Pre-Norm vs Post-Norm? Which should you use and why?

Answer:

  • Post-Norm (original 2017 transformer): $x = \text{LayerNorm}(x + \text{Sublayer}(x))$
  • Pre-Norm (modern LLMs): $x = x + \text{Sublayer}(\text{LayerNorm}(x))$

The difference is where LayerNorm is applied relative to the residual addition.

Pre-Norm advantages:

  1. The residual path ($x$ directly) is never passed through normalization - gradient flows through residuals at full scale
  2. More stable at depth - can train 100+ layer models without the learning rate sensitivity of Post-Norm
  3. Higher peak learning rates possible (faster training)
  4. Xiong et al. (2020) showed that Pre-LN has well-behaved gradients at initialization, whereas Post-LN has large gradients near the output layer, which is why it depends on warmup

Post-Norm characteristics:

  • The final representation is always normalized - the output distribution is more controlled
  • Sometimes achieves slightly better final quality on short training runs
  • Less stable - requires careful learning rate warmup

Recommendation: Use Pre-Norm for any model deeper than 12 layers. Use it by default for LLMs. Post-Norm is mainly of historical interest.

Q4: What is RMSNorm? Why are modern LLMs switching to it?

Answer: RMSNorm (Root Mean Square Layer Normalization) normalizes by RMS instead of by standard deviation: $\text{RMSNorm}(x) = x / \text{RMS}(x) \cdot \gamma$, where $\text{RMS}(x) = \sqrt{(1/d)\sum x_i^2}$.

Compared to LayerNorm, RMSNorm:

  • Removes the centering step: doesn't subtract the mean $\mu$
  • Removes the bias parameter $\beta$
  • Is faster: computing mean is an extra $O(d)$ reduction operation. For a 4096-dim model with 96 layers, this is a measurable overhead at scale.

Why switch? Empirically, the centering in LayerNorm (subtracting mean) turns out to be largely unnecessary. The important part is the scale normalization. RMSNorm achieves comparable quality (Zhang & Sennrich, 2019), with fewer parameters and faster computation.

Adopted by LLaMA (Meta), Mistral, Gemma (Google), and many other open-source LLMs released since 2023.

Q5: If you're debugging a 48-layer transformer that is producing NaN loss after 1000 steps of training, where would you look first?

Answer: Systematic diagnosis:

  1. Check Pre-Norm vs Post-Norm: Post-Norm is more susceptible to explosion in deep networks. If Post-Norm, try Pre-Norm.

  2. Gradient norm monitoring: If you were logging gradient norm (which you should be), look at the trend just before the NaN. Gradients spiking to large values (10+) before the NaN indicate gradient explosion - add gradient clipping, or lower max_norm if clipping is already in place.

  3. Learning rate warmup: Deep transformers need warmup. If the run crashed at step 1000 and had no warmup schedule (the original transformer ramped up over its first 4000 steps), add one.

  4. Attention score overflow: In float16, attention scores can overflow without the $1/\sqrt{d_k}$ scaling. Check the scale factor is correctly applied.

  5. LayerNorm numerical stability: RMSNorm in float16 can produce NaN if the RMS computation overflows. Ensure it internally casts to float32.

  6. Weight initialization: Check that the initialization scale is appropriate. Too-large initial weights → immediate explosion. Standard transformer init uses std = 0.02, with the residual output projections initialized to std = 0.02 / sqrt(2 * num_layers) to prevent compounding.

  7. Data: NaN in loss can also be caused by NaN or inf in the training data - corrupted embeddings from OOV tokens, or corrupted data samples. Always sanitize training data.
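Several of these checks come down to finding the first place a non-finite value appears. The helper below is a minimal sketch of that idea. `find_non_finite` is a hypothetical name, and the NaN is injected by hand purely to show the output. The same `torch.isfinite(...).all()` test can be applied to each input batch before the forward pass.

```python
import torch
import torch.nn as nn


def find_non_finite(model: nn.Module) -> list[str]:
    """Return the names of parameters whose values or gradients contain NaN/inf."""
    bad = []
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            bad.append(f"{name} (weight)")
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(f"{name} (grad)")
    return bad


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model(torch.randn(3, 4)).sum().backward()

print(find_non_finite(model))                 # [] on a healthy step

model[1].weight.grad[0, 0] = float("nan")     # inject a fault for the demo
print(find_non_finite(model))                 # ['1.weight (grad)']
```

Running a check like this after `loss.backward()` on the step where the loss first goes bad narrows the search from "somewhere in 48 layers" to a specific tensor.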

