
Batch Normalization - The Technique That Enabled Modern Deep Learning

Reading time: ~45 min | Interview relevance: High | Roles: MLE, AI Eng, Research Engineer, Computer Vision Engineer, LLM Engineer

The Real Interview Moment

You are in a Google Brain interview. The interviewer writes the batch normalization formula on the whiteboard and asks: "Walk me through every term. Then tell me - at inference time, where do the mean and variance come from? And finally: the original paper says BatchNorm works because it reduces 'internal covariate shift.' Do you buy that explanation?"

You explain the formula and the inference procedure, then she probes deeper: "Santurkar et al. showed that BatchNorm does not actually reduce internal covariate shift. So why does it work? If the original motivation is wrong, what is the real reason BatchNorm helps training?"

This is a deceptively deep question. The BatchNorm paper is one of the most cited in deep learning, but its original explanation has been largely debunked. Candidates who parrot "it reduces internal covariate shift" without knowing the current understanding reveal that they have not kept up with the field. Candidates who can explain the loss landscape smoothing hypothesis and discuss when to use BatchNorm vs LayerNorm vs GroupNorm vs RMSNorm get a "strong hire."

What You Will Master

  • Derive the complete BatchNorm formula: normalize, scale, shift
  • Explain training vs inference behavior and running statistics
  • Debunk the internal covariate shift explanation
  • Explain the real reason BatchNorm works (loss landscape smoothing)
  • Compare BatchNorm, LayerNorm, GroupNorm, InstanceNorm, and RMSNorm
  • Know when to use each normalization technique
  • Discuss BatchNorm's interaction with residual connections, dropout, and learning rate

Self-Assessment: Where Are You Now?

Rate yourself on each skill: 1 - Cannot, 2 - Vaguely, 3 - Can Explain, 4 - Can Derive, 5 - Can Teach.

| Skill | Your Score (1-5) |
|---|---|
| Write the BatchNorm formula | ___ |
| Explain learnable parameters $\gamma$ and $\beta$ | ___ |
| Explain training vs inference difference | ___ |
| Define internal covariate shift | ___ |
| Explain why ICS is not the real reason | ___ |
| Describe loss landscape smoothing | ___ |
| Compare BatchNorm vs LayerNorm | ___ |
| Explain when to use GroupNorm | ___ |
| Describe RMSNorm and why LLMs use it | ___ |
| Discuss BatchNorm + residual connections | ___ |

Target: All 4s and 5s before your interview.

Part 1 - The Problem: Training Deep Networks Is Hard

Before Batch Normalization

Training deep networks in 2015 was painful:

  1. Careful initialization required. Bad initialization → dead neurons or exploding activations.
  2. Tiny learning rates. Large learning rates caused divergence.
  3. Saturating activations. Sigmoid and tanh saturated easily, killing gradients.
  4. Slow convergence. Training took days or weeks with careful hyperparameter tuning.

Batch normalization addressed all of these by normalizing the inputs to each layer, ensuring they stay in a well-behaved range throughout training.

The Original Motivation: Internal Covariate Shift

The BatchNorm paper (Ioffe & Szegedy, 2015) motivated the technique as follows:

Internal covariate shift (ICS): As the parameters of layer $l-1$ change during training, the distribution of inputs to layer $l$ also changes. Layer $l$ must continuously adapt to these shifting input distributions, which slows training.

The analogy: imagine trying to learn to hit a baseball, but the pitcher changes speed and angle every time you swing. You would learn much faster if the pitches were consistent.

[Figure: Internal Covariate Shift: Without vs With BatchNorm]

Common Trap

Many candidates state that BatchNorm works because it "reduces internal covariate shift." While this was the original motivation, it has been largely debunked. Santurkar et al. (2018) showed that BatchNorm does not significantly reduce ICS, and artificially injecting ICS after BatchNorm layers does not hurt performance. If you cite ICS as the reason BatchNorm works without mentioning the counterevidence, you reveal outdated knowledge. Always present the modern understanding alongside the original claim.

Part 2 - The BatchNorm Formula

Step-by-Step Derivation

Given a mini-batch of $m$ activations at a particular layer: $\{x_1, x_2, \ldots, x_m\}$

Step 1: Compute batch statistics

$$\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i$$

$$\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2$$

Step 2: Normalize

$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$

The $\epsilon$ (typically $10^{-5}$) prevents division by zero.

After this step, $\hat{x}$ has mean 0 and variance 1 within the batch.

Step 3: Scale and shift (affine transformation)

$$y_i = \gamma \hat{x}_i + \beta$$

Where $\gamma$ and $\beta$ are learnable parameters (one per channel/feature).

Why the Learnable Parameters?

This is a critical interview question. If we only normalize, we constrain the network's representational power. The identity mapping cannot be represented if the normalization forces mean 0, variance 1.

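The learnable $\gamma$ and $\beta$ allow the network to undo the normalization if that is optimal. If the network learns $\gamma = \sqrt{\sigma_B^2 + \epsilon} \approx \sigma_B$ and $\beta = \mu_B$, then:

$$y_i = \sqrt{\sigma_B^2 + \epsilon} \cdot \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \mu_B = x_i$$

The network can recover the identity transformation, so adding BatchNorm does not reduce representational power. (Strictly, $\gamma$ and $\beta$ are fixed parameters while $\mu_B$ and $\sigma_B$ change from batch to batch, so the recovery is exact only to the extent that batch statistics match the population statistics.)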

import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """
    Batch normalization forward pass.

    x: input activations, shape (batch_size, features)
    gamma: scale parameter, shape (features,)
    beta: shift parameter, shape (features,)

    Returns: normalized output, cache for backward pass
    """
    # Step 1: Batch mean and variance
    mu = x.mean(axis=0)  # (features,)
    var = x.var(axis=0)  # (features,)

    # Step 2: Normalize
    x_hat = (x - mu) / np.sqrt(var + eps)  # (batch_size, features)

    # Step 3: Scale and shift
    y = gamma * x_hat + beta  # (batch_size, features)

    cache = (x, x_hat, mu, var, gamma, eps)
    return y, cache


# Example: normalize a batch of 32 samples with 64 features
batch_size, features = 32, 64
x = np.random.randn(batch_size, features) * 5 + 3  # Mean ≈ 3, Std ≈ 5

gamma = np.ones(features)  # Initialize to 1 (identity scale)
beta = np.zeros(features)  # Initialize to 0 (identity shift)

y, cache = batch_norm_forward(x, gamma, beta)

print(f"Input - Mean: {x.mean():.2f}, Std: {x.std():.2f}")
print(f"Output - Mean: {y.mean():.4f}, Std: {y.std():.4f}")
# Output - Mean: ≈0.0000, Std: ≈1.0000

# Now show that gamma and beta can undo normalization:
# per feature, set gamma = sqrt(var + eps) and beta = mean
gamma_undo = np.sqrt(x.var(axis=0) + 1e-5)
beta_undo = x.mean(axis=0)
y_undo, _ = batch_norm_forward(x, gamma_undo, beta_undo)
print(f"Undone - Mean: {y_undo.mean():.2f}, Std: {y_undo.std():.2f}")
print(f"Recovers x: {np.allclose(y_undo, x)}")  # True (up to floating-point error)

Backward Pass

The gradient computation through BatchNorm is more complex than most layers because the normalization creates dependencies between samples in the batch:

$$\frac{\partial \mathcal{L}}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i} \cdot \hat{x}_i$$

$$\frac{\partial \mathcal{L}}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \mathcal{L}}{\partial y_i}$$

$$\frac{\partial \mathcal{L}}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma_B^2 + \epsilon}} \left( \frac{\partial \mathcal{L}}{\partial y_i} - \frac{1}{m}\sum_j \frac{\partial \mathcal{L}}{\partial y_j} - \frac{\hat{x}_i}{m} \sum_j \frac{\partial \mathcal{L}}{\partial y_j} \hat{x}_j \right)$$

The gradient with respect to $x_i$ depends on the entire batch (through $\mu_B$ and $\sigma_B^2$), which makes BatchNorm behave differently from most layers during backpropagation.
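The three gradient formulas above can be checked numerically. The following is a minimal NumPy sketch, assuming a simple test loss $\mathcal{L} = \sum_i w_i \cdot y_i$ with random weights w (so that $\partial \mathcal{L} / \partial y = w$). The names bn_forward and bn_backward are illustrative helpers, not a library API. The analytic $\partial \mathcal{L} / \partial x$ is compared against central finite differences.

```python
import numpy as np


def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm; returns output and what the backward pass needs."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, (x_hat, var, gamma, eps)


def bn_backward(dy, cache):
    """Gradients from the three formulas; dy = dL/dy, shape (m, features)."""
    x_hat, var, gamma, eps = cache
    m = dy.shape[0]
    dgamma = (dy * x_hat).sum(axis=0)
    dbeta = dy.sum(axis=0)
    dx = (gamma / np.sqrt(var + eps)) * (
        dy - dy.mean(axis=0) - x_hat * (dy * x_hat).sum(axis=0) / m
    )
    return dx, dgamma, dbeta


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))
gamma = rng.normal(size=3)
beta = rng.normal(size=3)
w = rng.normal(size=(8, 3))  # test loss L = sum(w * y), so dL/dy = w

_, cache = bn_forward(x, gamma, beta)
dx, dgamma, dbeta = bn_backward(w, cache)


def loss(x_):
    return np.sum(w * bn_forward(x_, gamma, beta)[0])


# Central finite differences as an independent check of dL/dx
h = 1e-6
dx_num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    dx_num[idx] = (loss(xp) - loss(xm)) / (2 * h)

print("max |dx - dx_num|:", np.abs(dx - dx_num).max())
# Per feature, dx sums to ~0 over the batch: shifting every sample equally
# is removed by the mean subtraction, so it cannot change the loss
print("sum of dx over the batch:", dx.sum(axis=0))
```

The column sums of dx being zero is a direct consequence of the batch coupling described above: BatchNorm's output is unchanged if every sample in the batch is shifted by the same amount.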

Part 3 - Training vs Inference: The Critical Difference

The Problem

During training, BatchNorm uses the current mini-batch's mean and variance. But during inference, we typically process a single example at a time - there is no batch to compute statistics from.

The Solution: Running (Exponential Moving Average) Statistics

During training, BatchNorm maintains running estimates of the population mean and variance:

$$\mu_{\text{running}} \leftarrow (1 - \alpha) \cdot \mu_{\text{running}} + \alpha \cdot \mu_B$$

$$\sigma^2_{\text{running}} \leftarrow (1 - \alpha) \cdot \sigma^2_{\text{running}} + \alpha \cdot \sigma^2_B$$

Where $\alpha$ is the momentum (0.1 by default in PyTorch; Keras/TensorFlow's default momentum=0.99 weights the old running value instead, which corresponds to $\alpha = 0.01$ in this notation).

During inference, these running statistics are used instead of batch statistics:

$$\hat{x} = \frac{x - \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}$$

$$y = \gamma \hat{x} + \beta$$

This can be fused into a single affine transformation for efficiency:

$$y = \frac{\gamma}{\sqrt{\sigma^2_{\text{running}} + \epsilon}} \cdot x + \left(\beta - \frac{\gamma \cdot \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}\right)$$
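Because the fused form is just a per-feature scale and offset, it can be folded into the weights of a preceding linear (or convolutional) layer at inference time, a common deployment optimization. The following is a minimal NumPy sketch, assuming a linear layer $z = Wx + b$ followed by BatchNorm in eval mode. All parameter values are random placeholders, and linear_then_bn is a hypothetical reference function.

```python
import numpy as np

rng = np.random.default_rng(0)
in_f, out_f, eps = 5, 4, 1e-5

# A linear layer followed by BatchNorm in inference mode (random placeholders)
W = rng.normal(size=(out_f, in_f))
b = rng.normal(size=out_f)
gamma = rng.normal(size=out_f)
beta = rng.normal(size=out_f)
running_mean = rng.normal(size=out_f)
running_var = rng.uniform(0.5, 2.0, size=out_f)


def linear_then_bn(x):
    z = x @ W.T + b
    return gamma * (z - running_mean) / np.sqrt(running_var + eps) + beta


# Fold: y = s * z + (beta - s * running_mean), with s = gamma / sqrt(var + eps)
scale = gamma / np.sqrt(running_var + eps)
W_fused = W * scale[:, None]                 # scale each output row of W
b_fused = scale * (b - running_mean) + beta  # absorb the bias and the shift

x = rng.normal(size=(3, in_f))
y_ref = linear_then_bn(x)
y_fused = x @ W_fused.T + b_fused
print(np.allclose(y_ref, y_fused))  # True
```

After folding, the BatchNorm layer disappears entirely from the inference graph; the same algebra applies per output channel for a convolution.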

import numpy as np


class BatchNorm1D:
    """Complete BatchNorm with training/inference modes."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def forward(self, x):
        if self.training:
            # Use batch statistics (biased variance) to normalize
            mu = x.mean(axis=0)
            var = x.var(axis=0)

            # Update running statistics; like PyTorch, store the
            # unbiased (ddof=1) variance estimate in running_var
            self.running_mean = ((1 - self.momentum) * self.running_mean
                                 + self.momentum * mu)
            self.running_var = ((1 - self.momentum) * self.running_var
                                + self.momentum * x.var(axis=0, ddof=1))
        else:
            # Use running statistics
            mu = self.running_mean
            var = self.running_var

        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta

    def eval(self):
        """Switch to inference mode."""
        self.training = False

    def train(self):
        """Switch to training mode."""
        self.training = True


# Demonstrate training vs inference
bn = BatchNorm1D(num_features=64)

# Training: process many batches, building up running stats
for _ in range(100):
    x_batch = np.random.randn(32, 64) * 3 + 2  # Mean ≈ 2, Std ≈ 3
    _ = bn.forward(x_batch)

print(f"Running mean ≈ {bn.running_mean.mean():.2f} (true: 2.00)")
print(f"Running var ≈ {bn.running_var.mean():.2f} (true: 9.00)")

# Inference: single sample using running stats
bn.eval()
x_single = np.random.randn(1, 64) * 3 + 2
y_single = bn.forward(x_single)
print(f"Inference output mean: {y_single.mean():.4f}")

Instant Rejection

If asked "What happens at inference time?" and you say "Same as training - normalize using the batch," you fail. At inference time: (1) there may be no batch (single example), (2) even with a batch, a sample's output should not depend on which other examples happen to share its batch, and (3) the running statistics accumulated during training are used instead. Forgetting to call model.eval() in PyTorch is one of the most common deployment bugs: the model keeps normalizing with the statistics of whatever batch is being served (and keeps overwriting its running statistics) instead of using the learned running statistics.
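The single-example case makes the failure easy to see. With one sample, the batch mean equals the sample itself and the batch variance is zero, so a train-mode BatchNorm outputs $\beta$ regardless of the input. The following is a minimal NumPy sketch; the two helper functions and the running-statistic values are hypothetical illustrations, not PyTorch internals.

```python
import numpy as np


def bn_train_mode(x, gamma, beta, eps=1e-5):
    # Normalizes with the statistics of the batch it is given (the bug at inference)
    mu, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def bn_eval_mode(x, gamma, beta, running_mean, running_var, eps=1e-5):
    # Normalizes with statistics accumulated during training (correct)
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta


rng = np.random.default_rng(0)
features = 4
gamma = rng.uniform(0.5, 2.0, size=features)
beta = rng.normal(size=features)
running_mean = np.full(features, 2.0)  # hypothetical learned statistics
running_var = np.full(features, 9.0)

x_single = rng.normal(2.0, 3.0, size=(1, features))  # one inference request
y_bug = bn_train_mode(x_single, gamma, beta)
y_ok = bn_eval_mode(x_single, gamma, beta, running_mean, running_var)

print(y_bug)  # equals beta: the input has been erased
print(y_ok)   # depends on the input, as intended
```

With small (but larger than one) serving batches the effect is less dramatic but still present: each prediction depends on its batch-mates.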

Part 4 - Why BatchNorm Really Works

The Original Explanation: Internal Covariate Shift (Debunked)

Ioffe & Szegedy (2015) argued that BatchNorm works by reducing internal covariate shift (ICS) - the change in the distribution of layer inputs caused by parameter updates in preceding layers.

The Debunking: Santurkar et al. (2018)

The paper "How Does Batch Normalization Help Optimization?" (Santurkar et al., NeurIPS 2018) showed:

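  1. BatchNorm does not significantly reduce ICS. They measured how much layer input distributions change during training with and without BatchNorm and found that BatchNorm networks did not show meaningfully less ICS; by some of their measures, ICS was even higher with BatchNorm.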

  2. Artificially injecting ICS does not hurt. They added random noise with time-varying mean and variance to activations after each BatchNorm layer (increasing ICS), and these networks still trained about as well as networks with standard BatchNorm.

  3. The real effect is on the loss landscape.

The Real Reason: Loss Landscape Smoothing

Santurkar et al. showed that BatchNorm makes the loss landscape significantly smoother. Informally, for nearby points $x_1, x_2$ in parameter space, BatchNorm reduces:

$$\|\nabla \mathcal{L}(x_1) - \nabla \mathcal{L}(x_2)\|$$

This is quantified by the $\beta$-smoothness of the loss (here $\beta$ is a smoothness constant, unrelated to the BatchNorm shift parameter):

$$\|\nabla \mathcal{L}(x_1) - \nabla \mathcal{L}(x_2)\| \leq \beta \|x_1 - x_2\|$$

BatchNorm reduces $\beta$, meaning:

  • Gradients are more predictive - they point in useful directions for longer
  • The loss function has fewer sharp minima and cliffs
  • Larger learning rates can be used safely
  • Optimization converges faster

[Figure: BatchNorm Loss Landscape Smoothing Effect]

Other Contributing Factors

Beyond loss landscape smoothing, BatchNorm helps through several additional mechanisms:

  1. Implicit regularization. Because each sample's normalization depends on the other samples in the batch, BatchNorm adds noise to the training process (similar to dropout). This provides regularization and can sometimes replace dropout entirely.

  2. Gradient magnitude stabilization. By normalizing activations, BatchNorm prevents gradients from becoming extremely large or small, allowing training with higher learning rates.

  3. Reduced sensitivity to initialization. Because BatchNorm normalizes each layer's inputs, the network is less sensitive to the initial weight values.

  4. Decoupling layer interactions. Normalization reduces the extent to which one layer's updates affect other layers, making optimization more modular.

60-Second Answer

"The original BatchNorm paper claimed it works by reducing internal covariate shift - the change in layer input distributions during training. Santurkar et al. (2018) debunked this. The real reason BatchNorm works is that it smooths the loss landscape: gradients become more predictive, the loss surface has fewer sharp cliffs, and this allows larger learning rates and faster convergence. Additionally, BatchNorm provides implicit regularization through batch-dependent noise, stabilizes gradient magnitudes, and reduces sensitivity to initialization."

Part 5 - Where BatchNorm Fails

Small Batch Sizes

BatchNorm estimates population statistics from the mini-batch. With small batches (roughly $m < 16$), these estimates are noisy, degrading performance. In some domains, memory constraints force very small batch sizes:

  • Object detection: Large images → small batch sizes (2-4)
  • Video processing: Temporal data consumes memory
  • Medical imaging: 3D volumes are enormous
  • Distributed training: Even with a large global batch (or gradient accumulation), each GPU's BatchNorm only sees its own per-GPU micro-batch, which may be small
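The noise in the batch statistics can be quantified: for independent samples, the batch mean has standard deviation $\sigma / \sqrt{m}$, so halving the batch size inflates it by about 41%. The following is a minimal NumPy simulation, assuming Gaussian activations with a hypothetical population mean of 2 and standard deviation of 3; batch_mean_spread is an illustrative helper.

```python
import numpy as np

rng = np.random.default_rng(0)
population_mean, population_std = 2.0, 3.0


def batch_mean_spread(batch_size, trials=10_000):
    # Draw many mini-batches and measure how much their means scatter
    batches = rng.normal(population_mean, population_std,
                         size=(trials, batch_size))
    return batches.mean(axis=1).std()


for m in (2, 4, 16, 64):
    print(f"m={m:3d}  std of batch mean: {batch_mean_spread(m):.3f}  "
          f"(theory {population_std / np.sqrt(m):.3f})")
```

The batch variance estimate is noisy in the same way, and every activation in the batch is normalized with these noisy values.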

Sequence Models (RNNs, Transformers)

BatchNorm normalizes across the batch dimension, which is problematic for sequence models:

  • Different sequences in a batch have different lengths
  • Statistics should not be shared across time steps with different semantics
  • The position-dependent statistics make it difficult to apply to variable-length inputs

This is why Transformers use LayerNorm, not BatchNorm.

Online/Streaming Inference

BatchNorm requires maintaining running statistics, which assumes stationary data distributions. In streaming or online settings where the data distribution shifts over time, the running statistics become stale.

Part 6 - The Normalization Zoo: Modern Alternatives

The Normalization Landscape

The key difference between normalization techniques is which dimensions they normalize over:

[Figure: Normalization Zoo: BatchNorm, LayerNorm, InstanceNorm, GroupNorm]

For a feature tensor with shape $(N, C, H, W)$:

| Technique | Normalize Over | Statistics Per | Learnable Params |
|---|---|---|---|
| BatchNorm | $N, H, W$ | Channel | $2C$ ($\gamma$, $\beta$ per channel) |
| LayerNorm | $C, H, W$ | Sample | $2C$ ($\gamma$, $\beta$ per feature) |
| InstanceNorm | $H, W$ | Sample, Channel | $2C$ |
| GroupNorm | $C/G, H, W$ | Sample, Group | $2C$ |
| RMSNorm | $C$ (no mean subtraction) | Sample | $C$ ($\gamma$ only) |

LayerNorm: The Transformer Standard

LayerNorm normalizes across features (not across the batch), making it independent of batch size:

$$\hat{x}_i = \frac{x_i - \mu_x}{\sqrt{\sigma_x^2 + \epsilon}}, \quad y_i = \gamma_i \hat{x}_i + \beta_i$$

Where $\mu_x$ and $\sigma_x^2$ are computed across the feature dimension for each sample independently.

def layer_norm(x, gamma, beta, eps=1e-5):
    """
    Layer normalization.
    x: (batch_size, features)
    Normalizes across features (axis=1), independently per sample.
    """
    mu = x.mean(axis=-1, keepdims=True)  # Per-sample mean
    var = x.var(axis=-1, keepdims=True)  # Per-sample variance
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

def batch_norm(x, gamma, beta, eps=1e-5):
    """
    Batch normalization.
    x: (batch_size, features)
    Normalizes across batch (axis=0), independently per feature.
    """
    mu = x.mean(axis=0, keepdims=True)  # Per-feature mean
    var = x.var(axis=0, keepdims=True)  # Per-feature variance
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta


# Key difference in normalization axes:
x = np.random.randn(4, 8) # 4 samples, 8 features

# BatchNorm: each feature has mean 0, var 1 across the batch
y_bn = batch_norm(x, np.ones(8), np.zeros(8))
print(f"BatchNorm - feature 0 mean across batch: {y_bn[:, 0].mean():.4f}")
print(f"BatchNorm - feature 0 var across batch: {y_bn[:, 0].var():.4f}")

# LayerNorm: each sample has mean 0, var 1 across features
y_ln = layer_norm(x, np.ones(8), np.zeros(8))
print(f"LayerNorm - sample 0 mean across features: {y_ln[0, :].mean():.4f}")
print(f"LayerNorm - sample 0 var across features: {y_ln[0, :].var():.4f}")

Why Transformers use LayerNorm:

  1. No batch dependence - works with batch size 1 (autoregressive generation)
  2. No running statistics needed - same computation at train and inference time
  3. Handles variable sequence lengths naturally
  4. Each token is normalized independently

GroupNorm: The Small-Batch Solution

GroupNorm (Wu & He, 2018) divides channels into groups and normalizes within each group:

$$\text{GroupNorm}(x) = \frac{x - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}}, \quad \text{where } g = \text{group containing channel } c$$

With $G$ groups of $C/G$ channels each.

  • $G = C$ → InstanceNorm (one channel per group)
  • $G = 1$ → LayerNorm (all channels in one group)

GroupNorm is batch-size independent and works well with small batches, making it ideal for object detection and segmentation.
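The two limiting cases can be checked directly. The following is a minimal NumPy sketch, assuming $(N, C, H, W)$ inputs with $C$ divisible by $G$ and a per-channel affine $\gamma$, $\beta$ applied after normalization. The helper names group_norm, instance_norm, and layer_norm_chw are illustrative, and the last two omit the affine step.

```python
import numpy as np


def group_norm(x, num_groups, gamma, beta, eps=1e-5):
    """GroupNorm for x of shape (N, C, H, W); gamma, beta of shape (C,)."""
    N, C, H, W = x.shape
    assert C % num_groups == 0, "C must be divisible by the number of groups"
    xg = x.reshape(N, num_groups, C // num_groups, H, W)
    # One mean/variance per (sample, group), over the group's channels and H, W
    mu = xg.mean(axis=(2, 3, 4), keepdims=True)
    var = xg.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((xg - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)
    return gamma.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)


def instance_norm(x, eps=1e-5):
    mu = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def layer_norm_chw(x, eps=1e-5):
    mu = x.mean(axis=(1, 2, 3), keepdims=True)
    var = x.var(axis=(1, 2, 3), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 4, 4))
ones, zeros = np.ones(8), np.zeros(8)

print(np.allclose(group_norm(x, 8, ones, zeros), instance_norm(x)))   # G = C
print(np.allclose(group_norm(x, 1, ones, zeros), layer_norm_chw(x)))  # G = 1
```

Since the batch axis is never reduced over, running group_norm on a single sample gives the same result as running it on the full batch and slicing out that sample.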

RMSNorm: The LLM Favorite

RMSNorm (Zhang & Sennrich, 2019) simplifies LayerNorm by removing the mean centering:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \quad \text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$$

No $\beta$ parameter and no mean subtraction.

Why RMSNorm is increasingly preferred in LLMs (LLaMA, Gemma, etc.):

  1. Faster. Dropping the mean computation, the subtraction, and the $\beta$ addition removes one reduction and several elementwise operations from every normalization call; the wall-clock saving depends on the model and the kernel implementation
  2. Equally effective. Zhang & Sennrich (2019) found that the re-centering (mean subtraction) in LayerNorm contributes little; the re-scaling does most of the work
  3. Simpler implementation. One learnable parameter per feature instead of two

def rms_norm(x, gamma, eps=1e-8):
    """
    RMSNorm - used in LLaMA, Gemma, and modern LLMs.
    Simpler and faster than LayerNorm.
    """
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * gamma


# Compare LayerNorm vs RMSNorm
x = np.random.randn(4, 768) # 4 tokens, hidden size 768
gamma = np.ones(768)
beta = np.zeros(768)

y_ln = layer_norm(x, gamma, beta)
y_rms = rms_norm(x, gamma)

print(f"LayerNorm output stats - mean: {y_ln.mean():.4f}, std: {y_ln.std():.4f}")
print(f"RMSNorm output stats - mean: {y_rms.mean():.4f}, std: {y_rms.std():.4f}")
# RMSNorm does not center to zero mean, but this rarely matters in practice
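One way to see exactly what RMSNorm gives up: when a token's activations already have zero mean, $\text{RMS}(x)$ equals their standard deviation, and RMSNorm coincides with LayerNorm with $\beta = 0$. The two only differ when the input carries a mean offset. The following is a minimal NumPy sketch with both norms re-defined locally (same $\epsilon$ for both) and a random $\gamma$. It illustrates where the two differ; it does not by itself show that RMSNorm is equally effective.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def rms_norm(x, gamma, eps=1e-5):
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms


rng = np.random.default_rng(0)
d = 768
gamma, beta = rng.uniform(0.5, 1.5, size=d), np.zeros(d)

x = rng.normal(size=(4, d))
x_centered = x - x.mean(axis=-1, keepdims=True)  # force zero mean per token

same_if_centered = np.allclose(layer_norm(x_centered, gamma, beta),
                               rms_norm(x_centered, gamma))
same_if_shifted = np.allclose(layer_norm(x + 3.0, gamma, beta),
                              rms_norm(x + 3.0, gamma))
print(same_if_centered)  # True
print(same_if_shifted)   # False: only LayerNorm removes the +3 offset
```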

Comparison Table

| Technique | Batch-Size Dependent? | Running Stats? | Best For | Used In |
|---|---|---|---|---|
| BatchNorm | Yes | Yes | ConvNets, large batches | ResNet, EfficientNet |
| LayerNorm | No | No | Transformers, NLP | BERT, GPT-2 |
| GroupNorm | No | No | ConvNets, small batches | Detection (e.g., Mask R-CNN with GN), segmentation |
| InstanceNorm | No | No | Style transfer | StyleGAN, AdaIN |
| RMSNorm | No | No | LLMs (cheaper than LayerNorm) | LLaMA, Gemma, Mistral |

Company Variation

At vision companies (Tesla, Apple Vision), expect BatchNorm and GroupNorm questions. At LLM companies (OpenAI, Anthropic, Meta AI), expect LayerNorm and RMSNorm questions. At Google, expect knowledge of all variants; BatchNorm itself was introduced by Google researchers.

Part 7 - BatchNorm's Effects on Training

Higher Learning Rates

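BatchNorm's most practical benefit is allowing much higher learning rates. As a rough rule of thumb for SGD-trained ConvNets, learning rates above ~0.01 often cause divergence without BatchNorm, while learning rates of 0.1 or higher usually train reliably with it.

| Configuration | Max Stable Learning Rate | Convergence Speed |
|---|---|---|
| No normalization | ~0.01 | Slow |
| BatchNorm | ~0.1 | 5-10x faster |
| BatchNorm + warmup | ~0.3 | 10-14x faster |

These figures are indicative only and vary with architecture, optimizer, and data. For reference, Ioffe & Szegedy (2015) report that BatchNorm matched their baseline image classification model's accuracy with 14 times fewer training steps.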

Regularization Effect

BatchNorm adds noise because each sample's normalization depends on the other samples in the mini-batch. Different batches produce different normalizations for the same input. This stochasticity acts as a regularizer, similar to dropout.

Consequence: Models with BatchNorm often need less (or no) dropout.

Interaction with Weight Decay

An interesting property: BatchNorm makes the network invariant to the scale of the weights that feed into it. If you multiply all weights of the layer before a BatchNorm by a constant $c > 0$ (ignoring $\epsilon$), the batch mean and the batch standard deviation both scale by $c$:

$$\text{BN}(cWx) = \gamma \frac{cWx - c\,\mu_{Wx}}{c\,\sigma_{Wx}} + \beta = \gamma \frac{Wx - \mu_{Wx}}{\sigma_{Wx}} + \beta = \text{BN}(Wx)$$

The scaling $c$ cancels out in the normalization. This means weight decay on weights followed by BatchNorm does not regularize the function in the traditional sense; it mainly changes the effective learning rate. Because the gradient with respect to scale-invariant weights shrinks as $1/\|W\|$, reducing the weight norm increases the effective step size of the normalized weights.
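Both claims (the output invariance and the $1/c$ scaling of the gradient) can be checked numerically. The following is a minimal NumPy sketch, assuming $\epsilon = 0$ so that the invariance is exact, $\gamma = 1$, $\beta = 0$, and a hypothetical mean-squared-error loss against random targets. Gradients are taken by finite differences rather than backpropagation.

```python
import numpy as np


def bn(z, eps=0.0):
    # Batch-normalize each column of z (gamma = 1, beta = 0)
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))       # batch of 32 inputs
W = rng.normal(size=(10, 5))        # weights feeding a BatchNorm layer
target = rng.normal(size=(32, 5))   # hypothetical regression targets


def loss(W_):
    return np.mean((bn(x @ W_) - target) ** 2)


def numerical_grad(W_, h=1e-6):
    g = np.zeros_like(W_)
    for idx in np.ndindex(W_.shape):
        Wp, Wm = W_.copy(), W_.copy()
        Wp[idx] += h
        Wm[idx] -= h
        g[idx] = (loss(Wp) - loss(Wm)) / (2 * h)
    return g


c = 10.0
print(np.allclose(bn(x @ (c * W)), bn(x @ W)))  # True: output unchanged
g1, gc = numerical_grad(W), numerical_grad(c * W)
print(np.allclose(gc, g1 / c, atol=1e-6))       # True: gradient shrinks by 1/c
```

So scaling the weights up by $c$ leaves the function alone but divides the gradient by $c$; the relative update size therefore falls as $1/c^2$, which is the effective-learning-rate interaction with weight decay described above.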

Part 8 - Pre-Norm vs Post-Norm in Transformers

This is a direct extension of the BatchNorm discussion to the Transformer architecture, and is frequently asked in LLM interviews.

Post-Norm (Original Transformer)

$$x_{l+1} = \text{LayerNorm}(x_l + \text{Sublayer}(x_l))$$

The LayerNorm is applied after the residual addition.

Pre-Norm (GPT-2, LLaMA, modern LLMs)

$$x_{l+1} = x_l + \text{Sublayer}(\text{LayerNorm}(x_l))$$

The LayerNorm is applied before the sublayer, and the skip connection is completely clean.

| Property | Post-Norm | Pre-Norm |
|---|---|---|
| Training stability | Requires careful warmup | More stable, easier to train |
| Gradient flow | LN on skip path (less clean) | Skip path is pure identity |
| Final performance | Slightly better (when tuned) | Slightly worse (but easier) |
| Used in | Original Transformer, BERT | GPT-2, LLaMA, most modern LLMs |
| Learning rate sensitivity | High (needs warmup) | Low (more forgiving) |

The connection to ResNet v1 vs v2 is direct: post-norm is like ResNet v1 (operations on the skip path), and pre-norm is like ResNet v2 (clean skip path). The same principle - keeping the identity shortcut unobstructed - applies.
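A clean skip path can be made concrete. The following is a minimal NumPy sketch of the two block layouts, assuming a LayerNorm without learnable parameters and a placeholder sublayer that stands in for attention or the MLP. The function names are illustrative. When the sublayer contributes nothing, a pre-norm block is exactly the identity, while a post-norm block still pushes the signal through LayerNorm.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def post_norm_block(x, sublayer):
    # Original Transformer: LayerNorm after the residual addition
    return layer_norm(x + sublayer(x))


def pre_norm_block(x, sublayer):
    # GPT-2 / LLaMA style: LayerNorm only on the sublayer branch
    return x + sublayer(layer_norm(x))


def zero_sublayer(h):
    # Placeholder sublayer that contributes nothing
    return np.zeros_like(h)


rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=4.0, size=(3, 16))  # 3 tokens, width 16

print(np.allclose(pre_norm_block(x, zero_sublayer), x))   # True
print(np.allclose(post_norm_block(x, zero_sublayer), x))  # False
```

Stacking many pre-norm blocks keeps this property: the residual stream $x_l$ reaches the output (and gradients reach early layers) through additions only.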

Part 9 - Placement of BatchNorm

Where Exactly Does BatchNorm Go?

In ConvNets, BatchNorm is typically placed between the convolution and the activation function:

$$\text{Conv} \rightarrow \text{BatchNorm} \rightarrow \text{ReLU}$$

Why before ReLU? If placed after ReLU, BatchNorm would see only non-negative values (ReLU clips negatives), so the mean it subtracts is always positive. Normalizing before ReLU keeps the inputs to ReLU centered around zero (at least while $\beta$ stays near its initial value of 0), so roughly half the units are active.

Should Bias Be Used with BatchNorm?

No. The convolutional layer's bias $b$ is absorbed by BatchNorm's $\beta$ parameter. Adding $b$ shifts the batch mean by exactly $b$ ($\mu_{Wx+b} = \mu_{Wx} + b$) and leaves the variance unchanged, so:

$$\text{BN}(Wx + b) = \gamma \frac{(Wx + b) - (\mu_{Wx} + b)}{\sigma_{Wx}} + \beta = \gamma \frac{Wx - \mu_{Wx}}{\sigma_{Wx}} + \beta = \text{BN}(Wx)$$

The bias $b$ only shifts the mean, which BatchNorm subtracts right back out. The $\beta$ parameter takes over the role of $b$. Using both is redundant.

# In PyTorch, convolutions before BatchNorm should have bias=False
# nn.Conv2d(64, 128, 3, padding=1, bias=False) # No bias!
# nn.BatchNorm2d(128)
# nn.ReLU()

# This saves parameters and computation
total_params_with_bias = 64 * 128 * 3 * 3 + 128 # weights + bias
total_params_without = 64 * 128 * 3 * 3 # weights only
print(f"With bias: {total_params_with_bias:,}")
print(f"Without bias: {total_params_without:,}")
print(f"Saved: {total_params_with_bias - total_params_without} params per layer")
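The cancellation is easy to confirm numerically. The following is a minimal NumPy sketch, with a fully connected layer standing in for the convolution (the same argument applies per channel). Shapes and values are arbitrary, and bn is an illustrative helper with $\gamma = 1$, $\beta = 0$.

```python
import numpy as np


def bn(z, eps=1e-5):
    # Batch-normalize each column of z (gamma = 1, beta = 0)
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(16, 8))
W = rng.normal(size=(8, 4))
b = rng.normal(size=4) * 10  # even a large bias makes no difference

print(np.allclose(bn(x @ W + b), bn(x @ W)))  # True
```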

Part 10 - The Mathematical Connection: Why Normalization Helps Optimization

Lipschitz Smoothness

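A function $f$ is $L$-smooth (its gradient is $L$-Lipschitz) if:

$$\|\nabla f(x) - \nabla f(y)\| \leq L \|x - y\| \quad \forall x, y$$

A smaller $L$ means the gradient changes more slowly - the loss landscape is smoother.

Santurkar et al. proved bounds (with respect to a layer's activations, under some simplifying assumptions) showing that BatchNorm improves the Lipschitzness of the loss, so gradient magnitudes are bounded more tightly, and reduces the curvature of the loss along the gradient direction. Informally, with $g$ the gradient and $H$ the Hessian of the loss (the actual bound involves scaling by $\gamma$ and $\sigma$ plus additional terms):

$$g_{\text{BN}}^\top H_{\text{BN}}\, g_{\text{BN}} \lesssim g^\top H\, g$$

This is a statement about the curvature in the direction a gradient step actually moves, not a global bound on the largest Hessian eigenvalue.

This means:

  • The curvature the optimizer encounters along its step direction is smaller
  • Gradient descent steps are more reliable
  • Larger step sizes (learning rates) are safe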

Gradient Predictiveness

Beyond smoothness, BatchNorm makes gradients more predictive - the gradient at the current point more accurately predicts the gradient at nearby points. This means gradient descent moves in productive directions for longer before needing to recompute.

Formally, any $L$-smooth loss satisfies:

$$\left|\mathcal{L}(x + \eta d) - \mathcal{L}(x) - \eta \nabla \mathcal{L}(x)^\top d\right| \leq \frac{L \eta^2}{2} \|d\|^2$$

For a smaller $L$, the linear approximation of the loss (which gradient descent relies on) is accurate over a larger neighborhood, which is why BatchNorm's smoothing makes gradient steps more predictive.
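For a quadratic loss $f(x) = \tfrac{1}{2} x^\top A x$ the smoothness constant is exactly the largest eigenvalue of $A$, which makes the bound easy to verify numerically. The following is a minimal NumPy sketch with a random positive-definite $A$; it is a toy stand-in for a real loss and illustrates the inequality, not BatchNorm itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
M = rng.normal(size=(n, n))
A = M @ M.T + np.eye(n)          # symmetric positive definite
L = np.linalg.eigvalsh(A).max()  # smoothness constant of f(x) = 0.5 x^T A x


def f(x):
    return 0.5 * x @ A @ x


def grad_f(x):
    return A @ x


x = rng.normal(size=n)
results = []
for eta in (0.01, 0.1, 1.0):
    d = rng.normal(size=n)
    # Error of the first-order (linear) approximation at step size eta
    error = abs(f(x + eta * d) - f(x) - eta * grad_f(x) @ d)
    bound = L * eta ** 2 / 2 * (d @ d)
    results.append((error, bound))
    print(f"eta={eta}: error={error:.6f}  bound={bound:.6f}")
```

For this quadratic the error is exactly $\tfrac{1}{2}\eta^2 d^\top A d$, so the bound is tight only when $d$ points along the top eigenvector; shrinking $L$ (a smoother loss) shrinks the bound for every step size.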

Part 11 - BatchNorm in Convolutional Networks vs Fully Connected Networks

Spatial BatchNorm for ConvNets

For fully connected layers, BatchNorm normalizes each feature independently across the batch. For convolutional layers, the situation is different: each filter produces a 2D feature map, and we want the same filter to behave consistently regardless of spatial position.

Convolutional BatchNorm normalizes across the batch AND spatial dimensions, but separately for each channel:

For a feature tensor with shape $(N, C, H, W)$:

$$\mu_c = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w}$$

$$\sigma_c^2 = \frac{1}{N \cdot H \cdot W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{n,c,h,w} - \mu_c)^2$$

This means:

  • Each channel gets ONE mean and ONE variance (not one per spatial location)
  • Statistics are computed over $N \times H \times W$ values (not just $N$)
  • Learnable parameters: $2C$ total ($\gamma_c$ and $\beta_c$ per channel)

import numpy as np

def spatial_batch_norm(x, gamma, beta, eps=1e-5):
    """
    BatchNorm for convolutional layers.
    x: (batch, channels, height, width)
    gamma, beta: (channels,)
    """
    N, C, H, W = x.shape

    # Compute mean and var per channel, across batch AND spatial dims
    mu = x.mean(axis=(0, 2, 3), keepdims=True)  # (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)  # (1, C, 1, 1)

    # Normalize
    x_hat = (x - mu) / np.sqrt(var + eps)

    # Scale and shift (broadcast gamma and beta)
    gamma = gamma.reshape(1, C, 1, 1)
    beta = beta.reshape(1, C, 1, 1)
    return gamma * x_hat + beta


# Example: 8 images, 64 channels, 32x32 spatial
x = np.random.randn(8, 64, 32, 32) * 5 + 3
gamma = np.ones(64)
beta = np.zeros(64)

y = spatial_batch_norm(x, gamma, beta)

# Each channel should have mean ≈ 0, var ≈ 1 across batch and spatial dims
print(f"Channel 0 - mean: {y[:, 0, :, :].mean():.6f}, "
      f"var: {y[:, 0, :, :].var():.6f}")
print(f"Channel 31 - mean: {y[:, 31, :, :].mean():.6f}, "
      f"var: {y[:, 31, :, :].var():.6f}")

# Number of values used to estimate statistics per channel:
print(f"Statistics computed from 8 x 32 x 32 = {8 * 32 * 32} values per channel")
# Far more than the N=8 values used in FC BatchNorm (though neighboring pixels are correlated)

Why Spatial Statistics Matter

This spatial pooling is important for two reasons:

  1. Better statistics. For a $32 \times 32$ feature map with batch size 8, each channel's statistics are estimated from $8 \times 32 \times 32 = 8192$ values instead of just 8. Neighboring pixels are correlated, so this is not equivalent to 8192 independent samples, but it still makes convolutional BatchNorm much more stable than FC BatchNorm at the same batch size.

  2. Translation equivariance. By sharing statistics across spatial positions, BatchNorm preserves the convolutional network's translation equivariance: the same filter should behave the same way regardless of where in the image it operates.

Part 12 - Synchronized BatchNorm for Distributed Training

The Problem

In distributed training with data parallelism, each GPU processes a subset of the batch. If each GPU computes its own BatchNorm statistics, the effective batch size per GPU may be small (e.g., 4 images per GPU even with a total batch of 128).

Synchronized BatchNorm

Synchronized BatchNorm (SyncBN) aggregates statistics across all GPUs before normalizing:

$$\mu_{\text{global}} = \frac{1}{K} \sum_{k=1}^{K} \mu_k, \quad \sigma^2_{\text{global}} = \frac{1}{K} \sum_{k=1}^{K} \left(\sigma_k^2 + (\mu_k - \mu_{\text{global}})^2\right)$$

Where $K$ is the number of GPUs and $\mu_k$, $\sigma_k^2$ are per-GPU statistics. This form assumes every GPU holds the same number of samples; with unequal shards, each term is weighted by its shard size.

This requires an all-reduce communication step at every BatchNorm layer, which adds latency. SyncBN is essential for tasks like object detection and segmentation where per-GPU batch sizes are very small.
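The aggregation formula is the law of total variance for equally sized shards. The following is a minimal single-process NumPy simulation, assuming $K = 8$ simulated GPUs with 2 samples each. The "all-reduce" here is just an average over the stacked per-GPU statistics; real SyncBN implementations communicate equivalent per-GPU quantities rather than running this exact code.

```python
import numpy as np

rng = np.random.default_rng(0)
K, per_gpu, C = 8, 2, 4  # 8 simulated GPUs, 2 samples each, 4 channels
full_batch = rng.normal(3.0, 2.0, size=(K * per_gpu, C))
shards = np.split(full_batch, K)  # what each GPU sees

# Each GPU computes local statistics (biased variance)
mu_k = np.stack([s.mean(axis=0) for s in shards])  # (K, C)
var_k = np.stack([s.var(axis=0) for s in shards])  # (K, C)

# Combine across GPUs (what the all-reduce enables), equal shard sizes assumed
mu_global = mu_k.mean(axis=0)
var_global = (var_k + (mu_k - mu_global) ** 2).mean(axis=0)

print(np.allclose(mu_global, full_batch.mean(axis=0)))  # True
print(np.allclose(var_global, full_batch.var(axis=0)))  # True
```

The $(\mu_k - \mu_{\text{global}})^2$ term matters: simply averaging the per-GPU variances would underestimate the variance of the full batch.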

| Configuration | Per-GPU Batch | Effective BN Batch | Performance |
|---|---|---|---|
| Standard BN, 8 GPUs | 2 | 2 (each GPU independent) | Poor |
| SyncBN, 8 GPUs | 2 | 16 (synchronized) | Good |
| GroupNorm, 8 GPUs | 2 | N/A (batch-independent) | Good |

60-Second Answer

"In distributed training, standard BatchNorm computes statistics independently per GPU, which can give noisy estimates when per-GPU batch sizes are small. Synchronized BatchNorm aggregates statistics across all GPUs via all-reduce, giving the same result as single-GPU training with the full batch. The tradeoff is additional communication overhead. For tasks like detection where batch sizes per GPU are typically 1-4, SyncBN or GroupNorm is essential."

Part 13 - Common BatchNorm Interview Gotchas

Gotcha 1: BatchNorm with Dropout

BatchNorm and dropout interact poorly. Dropout randomly zeros out activations during training. With the usual inverted dropout (kept units scaled by $1/(1-p)$), the mean of each activation is preserved but its variance is larger during training than at inference, when dropout is turned off. A BatchNorm layer placed after dropout therefore accumulates running statistics under training-time noise that no longer match the activations it sees at inference.

Solutions:

  • Use BatchNorm without dropout (BatchNorm's regularization effect often replaces the need for dropout)
  • Place dropout after BatchNorm, not before
  • Use consistent dropout placement and be aware of the interaction
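The variance mismatch is easy to reproduce. The following is a minimal NumPy sketch, assuming inverted dropout with $p = 0.5$ applied to activations with mean 1 and variance 1 (arbitrary illustrative values). The mean survives, but the training-time variance comes out at about three times the inference-time variance, so running statistics collected after dropout would be off.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 0.5                                    # drop probability
a = rng.normal(1.0, 1.0, size=1_000_000)  # activations feeding dropout

# Inverted dropout at training time: zero with prob p, scale kept units by 1/(1-p)
mask = rng.random(a.shape) >= p
a_train = a * mask / (1 - p)
a_eval = a  # dropout is the identity at inference

print(f"mean  train={a_train.mean():.3f}  eval={a_eval.mean():.3f}")
print(f"var   train={a_train.var():.3f}  eval={a_eval.var():.3f}")
```

Analytically, $E[a_{\text{train}}^2] = E[a^2]/(1-p) = 4$ here, so the training variance is $4 - 1 = 3$ versus 1 at inference.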

Gotcha 2: Frozen BatchNorm

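In transfer learning, when you freeze early layers of a pre-trained model, you should also freeze the BatchNorm layers (set them to eval mode). If BatchNorm remains in training mode while the rest of the network is frozen, it will normalize with the new data's batch statistics and update its running statistics toward the new data distribution, but the frozen weights expect the original statistics. In PyTorch, setting requires_grad=False on a BatchNorm layer freezes $\gamma$ and $\beta$ but not the running statistics; those only stop updating when the module is in eval mode.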

Gotcha 3: BatchNorm in GANs

Training GANs with BatchNorm can cause issues because a sample's output depends on the rest of its batch. The discriminator's BatchNorm statistics for a batch of real images differ from those for a batch of generated images, and the discriminator can exploit this batch-level signal, leaking information about whether an image is real or fake.

Solutions: Use Spectral Normalization or Instance Normalization instead of BatchNorm in the discriminator.

Practice Problems

Problem 1: Training vs Inference Bug

A model works well during training but performs terribly at inference time. The model uses BatchNorm. What is the most likely bug, and how do you fix it?

Hint

The most likely bug is failing to switch the model to evaluation mode (model.eval() in PyTorch). During training, BatchNorm uses batch statistics (mean and variance of the current mini-batch). During inference, it should use the running statistics accumulated during training. If the model remains in training mode during inference, it uses the (incorrect) statistics of whatever single sample or small batch is being processed, leading to poor normalization.

Problem 2: Small Batch Problem

You are training an object detection model with batch size 2 due to memory constraints. BatchNorm performs poorly. What normalization technique would you use instead, and why?

Hint

Use GroupNorm (Wu & He, 2018). With batch size 2, BatchNorm's estimate of the batch mean and variance is extremely noisy because it rests on only 2 samples. GroupNorm normalizes within groups of channels for each sample independently, so it needs no batch statistics at all. With 32 groups on a 256-channel feature map, each group pools 8 channels across all H × W spatial positions, which is enough values for stable statistics. Wu & He motivated GroupNorm by exactly this small-batch regime and evaluated it on COCO object detection and segmentation as well as ImageNet classification.
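A minimal NumPy sketch of GroupNorm's forward pass follows. It omits the affine $\gamma$ / $\beta$ for brevity and assumes an (N, C, H, W) layout with C divisible by the number of groups. The function name is hypothetical; PyTorch's built-in module is `nn.GroupNorm(num_groups=32, num_channels=256)`. Because every statistic is computed inside a single sample, the output for one image does not depend on what else is in the batch.

```python
import numpy as np

def group_norm(x: np.ndarray, num_groups: int = 32, eps: float = 1e-5) -> np.ndarray:
    """Normalize each group of C // num_groups channels per sample over (channels, H, W)."""
    n, c, h, w = x.shape
    assert c % num_groups == 0, "channels must be divisible by num_groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)     # shape (N, G, 1, 1, 1)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    g = (g - mean) / np.sqrt(var + eps)
    return g.reshape(n, c, h, w)

rng = np.random.default_rng(0)
batch = rng.normal(5.0, 2.0, size=(2, 256, 16, 16))   # batch size 2, as in the problem

out_pair = group_norm(batch)
out_single = group_norm(batch[:1])                     # same image, batch size 1

# Each sample's output is identical regardless of batch composition
print(np.allclose(out_pair[0], out_single[0]))         # True
```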

Problem 3: Normalization Axes

For a tensor of shape (Batch=32, Channels=128, Height=14, Width=14), write the shape of the mean and variance computed by (a) BatchNorm, (b) LayerNorm, (c) InstanceNorm, (d) GroupNorm with 32 groups.

Hint

(a) BatchNorm: normalize over (N, H, W). Mean/var shape: (128,), one per channel, shared across the batch and spatial dims.
(b) LayerNorm: normalize over (C, H, W). Mean/var shape: (32,), one per sample, shared across channels and spatial dims.
(c) InstanceNorm: normalize over (H, W). Mean/var shape: (32, 128), one per sample per channel.
(d) GroupNorm with 32 groups: normalize over (C/G = 4, H, W). Mean/var shape: (32, 32), one per sample per group.
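The shapes can be checked mechanically. Below is a NumPy sketch on random data, assuming an (N, C, H, W) layout; the variance has the same shape as the mean in each case. PyTorch's `nn.LayerNorm` normalizes over whatever trailing dimensions you pass as `normalized_shape`, so (C, H, W) here corresponds to `normalized_shape=(128, 14, 14)`.

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((32, 128, 14, 14))   # (N, C, H, W)
G = 32                                                             # GroupNorm groups

bn_mean = x.mean(axis=(0, 2, 3))                                   # BatchNorm: over N, H, W
ln_mean = x.mean(axis=(1, 2, 3))                                   # LayerNorm: over C, H, W
in_mean = x.mean(axis=(2, 3))                                      # InstanceNorm: over H, W
gn_mean = x.reshape(32, G, 128 // G, 14, 14).mean(axis=(2, 3, 4))  # GroupNorm: over C/G, H, W

for name, m in [("BatchNorm", bn_mean), ("LayerNorm", ln_mean),
                ("InstanceNorm", in_mean), ("GroupNorm", gn_mean)]:
    print(f"{name:12s} {m.shape}")
# BatchNorm    (128,)
# LayerNorm    (32,)
# InstanceNorm (32, 128)
# GroupNorm    (32, 32)
```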

Problem 4: Normalization Choice for LLMs

Explain why modern LLMs (LLaMA, Mistral) use RMSNorm instead of LayerNorm. Derive the computational savings.

Hint

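RMSNorm removes two things from LayerNorm: (1) the mean computation and subtraction (re-centering), and (2) the $\beta$ (shift) parameter. Re-centering costs an extra reduction over all features plus a subtraction, and both are memory-bandwidth-bound operations on GPUs.

For a hidden dimension $d = 4096$, LayerNorm requires:
  • compute the mean: d additions, 1 division
  • compute the variance: d subtractions, d multiplications, d additions, 1 division
  • normalize: d subtractions, d divisions, 1 square root
  • scale and shift: d multiplications, d additions

That is roughly 8d elementwise operations. RMSNorm requires:
  • compute the RMS: d multiplications, d additions, 1 division, 1 square root
  • normalize: d divisions
  • scale: d multiplications

That is roughly 4d. By this count RMSNorm does about half the elementwise work of LayerNorm. The figure is about 4d vs 7d if LayerNorm reuses the centered values from the variance step. RMSNorm also has d fewer learnable parameters per layer, since each feature loses its $\beta$. Because these kernels are memory-bound, wall-clock savings are smaller than the operation count suggests. Zhang & Sennrich (2019) report comparable quality to LayerNorm with running-time reductions of 7% to 64% on different models.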
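A small NumPy sketch shows exactly what is dropped. The helpers `layer_norm` and `rms_norm` are hypothetical names, $\epsilon = 10^{-6}$ is chosen arbitrarily, and $\gamma = 1$, $\beta = 0$. On zero-mean inputs the RMS equals the standard deviation, so the two normalizations coincide. On inputs with a nonzero mean they differ, because RMSNorm never subtracts that mean.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    mean = x.mean(axis=-1, keepdims=True)                    # re-centering statistic
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)     # second reduction, over centered values
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def rms_norm(x, gamma, eps=1e-6):
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)   # single statistic, no centering
    return gamma * x / rms                                        # scale only, no beta

d = 4096
rng = np.random.default_rng(0)
gamma, beta = np.ones(d), np.zeros(d)

x = rng.standard_normal((2, d)) + 3.0                 # inputs with a nonzero mean
x_centered = x - x.mean(axis=-1, keepdims=True)       # same inputs, re-centered

same_when_centered = np.allclose(layer_norm(x_centered, gamma, beta), rms_norm(x_centered, gamma))
same_with_offset = np.allclose(layer_norm(x, gamma, beta), rms_norm(x, gamma))
print(same_when_centered)   # True
print(same_with_offset)     # False
```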

Problem 5: The Identity Argument

Prove that BatchNorm with learnable $\gamma$ and $\beta$ can represent the identity function. Then explain why this is important for the representational power argument.

Hint

If $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$ (or $\gamma = \sigma_B$ when $\epsilon$ is ignored), then

$$y = \gamma \cdot \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta = (x - \mu_B) + \mu_B = x.$$

This proves BatchNorm can learn the identity function. Strictly, $\mu_B$ and $\sigma_B$ change from batch to batch while $\gamma$ and $\beta$ are fixed parameters. The identity is therefore exact at inference, where $\gamma$ and $\beta$ are set from the running statistics; during training it is approximate. This matters because BatchNorm then never reduces the representational power of the network: it can effectively "turn itself off" if normalization is unhelpful for a particular layer. Through the learned $\gamma$ and $\beta$, the network chooses how much normalization each layer keeps.
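The identity argument can be checked numerically for a fixed batch. The NumPy sketch below uses arbitrary random data and sets $\gamma$ and $\beta$ from that batch's own statistics, which is exactly the choice in the proof.

```python
import numpy as np

eps = 1e-5
rng = np.random.default_rng(0)
x = rng.normal(4.0, 2.5, size=(256, 8))        # a mini-batch of 256 samples, 8 features

mu = x.mean(axis=0)
var = x.var(axis=0)                            # biased batch variance, as used in normalization
x_hat = (x - mu) / np.sqrt(var + eps)

# Choose the learnable parameters to undo the normalization
gamma = np.sqrt(var + eps)
beta = mu
y = gamma * x_hat + beta

print(np.allclose(y, x))   # True
```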

Interview Cheat Sheet

| Question | Key Points |
|---|---|
| "Write the BatchNorm formula" | $\hat{x} = (x - \mu_B) / \sqrt{\sigma_B^2 + \epsilon}$, then $y = \gamma\hat{x} + \beta$. $\gamma$ and $\beta$ are learnable. |
| "Why learnable gamma and beta?" | Lets the network undo normalization if needed. Can represent the identity: $\gamma = \sqrt{\sigma^2 + \epsilon}$, $\beta = \mu$. |
| "Training vs inference?" | Training: batch statistics. Inference: running (EMA) statistics. Forgetting model.eval() is a common bug. |
| "Why does BatchNorm work?" | NOT internal covariate shift (debunked). Real reason: it smooths the loss landscape, enabling larger learning rates. |
| "What is internal covariate shift?" | The change in layer input distributions during training. It was the original motivation, but Santurkar et al. showed it is not the mechanism. |
| "BatchNorm vs LayerNorm?" | BN: across the batch, per feature. LN: across features, per sample. LN works with batch size 1 and is used in Transformers. |
| "When does BatchNorm fail?" | Small batch sizes, sequence models, variable-length inputs, streaming inference. |
| "What is RMSNorm?" | LayerNorm without mean centering or the $\beta$ parameter. Cheaper; used in LLaMA/Mistral. |
| "Why no bias before BatchNorm?" | Mean subtraction cancels any constant bias, and BN's $\beta$ supplies the shift instead. The bias is a redundant parameter. |
| "Pre-norm vs post-norm?" | Pre-norm: LN before the sublayer, leaving a clean skip connection; more stable. Post-norm: LN after the residual addition; can reach slightly better final performance but is harder to train. |

Spaced Repetition Checkpoints

Day 0 (Today)

  • Write the BatchNorm formula from memory (normalize, scale, shift)
  • Explain training vs inference behavior
  • Know that ICS is debunked - loss landscape smoothing is the real reason

Day 3

  • Derive the backward pass for BatchNorm
  • Compare BatchNorm vs LayerNorm normalization axes
  • Explain why bias is redundant before BatchNorm

Day 7

  • Explain GroupNorm, InstanceNorm, and RMSNorm
  • Know when to use each normalization technique
  • Discuss pre-norm vs post-norm in Transformers

Day 14

  • Mock interview: answer all 10 cheat sheet questions
  • Explain the loss landscape smoothing argument with mathematical detail
  • Discuss BatchNorm's interaction with weight decay and dropout

Day 21

  • Full 20-minute paper discussion on BatchNorm
  • Handle follow-ups on normalization in modern LLMs (RMSNorm, QK-norm)
  • Debug a "training works but inference fails" scenario live

Next Steps

You now understand normalization - from the original BatchNorm to modern RMSNorm in LLMs. Combined with residual connections from the previous chapter, you know the two foundational techniques that made deep learning practical. Next, explore Chapter 8: Adam Optimizer - the third pillar of modern training, and why adaptive learning rates revolutionized optimization.

© 2026 EngineersOfAI. All rights reserved.