
Weight Initialization

The Real Interview Moment

You are training a 20-layer MLP on a tabular dataset for anomaly detection. The 5-layer version trains fine - loss drops smoothly, converges in 30 epochs. You add more layers to improve representational capacity. With 20 layers, the loss does not decrease at all. Epoch 1, epoch 10, epoch 50 - flat. You have not changed the optimizer, learning rate, or data. The only change is network depth.

You add a gradient monitoring hook and print the gradient norm at each layer. Layer 20 (output): 0.14. Layer 15: 0.003. Layer 10: 1.2e-8. Layer 5: 3.1e-19. Layer 1: 1.7e-37.

Early layers receive gradient signals roughly 36 orders of magnitude smaller than the output layer. They are not learning - they are effectively invisible to backpropagation. The initialization created this disaster before training even began. The weights were too large, causing activations to blow up in the forward pass and sending the network into saturated territory where gradients are near zero. Fifteen additional layers multiplied that near-zero gradient until it vanished entirely.

This lesson explains the mathematical cause of that failure, derives the correct initialization for each activation function, and shows you how to measure initialization quality before training begins.

Why This Exists: The Variance Problem

Consider a single linear layer with no activation: $\mathbf{z} = \mathbf{W}\mathbf{x}$, where $\mathbf{W} \in \mathbb{R}^{n_{\text{out}} \times n_{\text{in}}}$.

If inputs $x_i$ are i.i.d. with mean 0 and variance $\sigma_x^2$, and weights $w_{ij}$ are i.i.d. with mean 0 and variance $\sigma_w^2$ (independent of the inputs), then the variance of each output is:

$$
\text{Var}(z_k) = \text{Var}\left(\sum_i w_{ki} x_i\right) = \sum_i \text{Var}(w_{ki}) \text{Var}(x_i) = n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_x^2
$$

(The first step uses the fact that the variance of a sum of independent terms is the sum of their variances. The second uses $\text{Var}(wx) = \text{Var}(w)\,\text{Var}(x)$ for independent zero-mean $w$ and $x$.)
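The following NumPy sketch is a quick Monte Carlo check of this formula. The widths, $\sigma_w = 0.05$, and $\sigma_x = 2$ are arbitrary values chosen for illustration. The prediction is $500 \times 0.0025 \times 4 = 5$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, batch = 500, 500, 2048
sigma_w, sigma_x = 0.05, 2.0

x = rng.normal(0.0, sigma_x, size=(batch, n_in))   # inputs: mean 0, variance sigma_x^2
W = rng.normal(0.0, sigma_w, size=(n_out, n_in))   # weights: mean 0, variance sigma_w^2
z = x @ W.T                                        # linear layer, no activation

predicted = n_in * sigma_w**2 * sigma_x**2         # 500 * 0.0025 * 4 = 5.0
empirical = float(z.var())
print(f"predicted Var(z) = {predicted:.3f}, empirical = {empirical:.3f}")
```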

Through $L$ such layers (assuming every layer has the same fan-in $n_{\text{in}}$ and weight variance $\sigma_w^2$):

$$
\text{Var}(\mathbf{z}^{(L)}) = (n_{\text{in}} \cdot \sigma_w^2)^L \cdot \text{Var}(\mathbf{x})
$$

If $n_{\text{in}} \cdot \sigma_w^2 > 1$: variance explodes exponentially with depth. In a deep enough network, activations overflow to inf and the loss becomes NaN.

If $n_{\text{in}} \cdot \sigma_w^2 < 1$: variance collapses exponentially. Activations shrink toward zero and gradients vanish.

The critical target: $n_{\text{in}} \cdot \sigma_w^2 = 1$, which means $\sigma_w^2 = 1/n_{\text{in}}$.
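The exponential behavior is easy to reproduce. The sketch below is an illustrative NumPy experiment, not code from the original lesson. It stacks 20 purely linear 256-wide layers with $\sigma_w^2 = g/n_{\text{in}}$ for $g \in \{2, 1, 0.5\}$. Theory predicts variance ratios of roughly $2^{20} \approx 10^6$, $1$, and $0.5^{20} \approx 10^{-6}$.

```python
import numpy as np


def variance_ratio_after_depth(gain: float, n: int = 256, depth: int = 20,
                               batch: int = 512, seed: int = 0) -> float:
    """Push a batch through `depth` purely linear n-by-n layers with
    sigma_w^2 = gain / n_in, and return Var(output) / Var(input)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, n))
    a = x
    for _ in range(depth):
        W = rng.normal(0.0, np.sqrt(gain / n), size=(n, n))
        a = a @ W.T  # no activation: the product n_in * sigma_w^2 = gain per layer
    return float(a.var() / x.var())


ratios = {gain: variance_ratio_after_depth(gain) for gain in (2.0, 1.0, 0.5)}
for gain, ratio in ratios.items():
    print(f"n_in * sigma_w^2 = {gain}: Var ratio after 20 layers = {ratio:.3e}")
```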

The entire theory of weight initialization is built on making this product equal to 1 (or close to it) under the specific activation function being used.

Zero Initialization: Why It Completely Fails

The most natural initialization, $\mathbf{W} = \mathbf{0}$, catastrophically breaks training because it never breaks the symmetry between neurons.

Formal argument: if all weights in layer $l$ are identical (including zero), every neuron in that layer computes exactly the same pre-activation $z_j = \mathbf{w}^T \mathbf{x} + b$ for the same input. With zero weights, this is just $b$. Every neuron produces the same output. If the outgoing weights are also identical, every neuron receives the same gradient via backpropagation. The gradient of the loss with respect to neuron $j$'s weights depends only on the upstream gradient and the layer's input, and neither of these distinguishes neuron $j$ from neuron $k$. Every weight receives the same update. After the update, all weights remain identical.

This continues indefinitely. The network is stuck in a symmetric state - it effectively has one neuron per layer regardless of the declared width. A zero-initialized 10-layer network with 512 neurons per layer is, functionally, a chain of 1-neuron layers.

```python
import torch
import torch.nn as nn


def demonstrate_symmetry_failure(fill_value: float = 0.1):
    """Show that constant (or zero) initialization prevents learning by maintaining neuron symmetry."""
    torch.manual_seed(42)

    model = nn.Sequential(
        nn.Linear(4, 8),
        nn.ReLU(),
        nn.Linear(8, 1)
    )

    # Force constant initialization. With fill_value=0.0 the picture is even worse:
    # the hidden layer outputs 0, so neither weight matrix receives any gradient
    # and only the output bias learns. A nonzero constant makes the symmetry visible.
    with torch.no_grad():
        for layer in model:
            if isinstance(layer, nn.Linear):
                layer.weight.fill_(fill_value)
                layer.bias.zero_()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    print(f"Initial weights (all equal to {fill_value}):")
    print(model[0].weight[:2])  # first two neurons

    # Training loop on random regression data
    for step in range(500):
        x = torch.randn(16, 4)
        y = torch.randn(16, 1)
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("\nAfter 500 training steps:")
    w = model[0].weight.detach()  # (8, 4)
    print(f"Max difference between neuron 0 and neuron 1 weights: "
          f"{(w[0] - w[1]).abs().max().item():.10f}")
    # Essentially 0: all 8 neurons in the first layer are still identical,
    # so the 8-neuron layer is functionally 1 neuron

    # Verify by checking effective rank
    singular_values = torch.linalg.svdvals(w)
    print("Singular values of weight matrix (non-zero = effective rank):")
    print(singular_values)
    # One non-trivial singular value, the rest ~0: rank 1 despite the 8×4 shape


demonstrate_symmetry_failure()
```

:::danger Zero Initialization Rule
Never initialize the weight matrices of hidden layers to zero or any constant. Biases can be initialized to zero. Each bias affects only one neuron's output rather than a shared computation, and as long as the weights are random, all neurons still compute differently because their weight-input products differ.
:::

Constant initialization (all weights equal to some constant $c \neq 0$) has the same symmetry problem. The specific value does not matter - the symmetry does.

Naive Random Gaussian: Variance Matters Enormously

Consider initializing weights from a standard normal: $w_{ij} \sim \mathcal{N}(0, 1)$, so $\sigma_w^2 = 1$.

For a layer with $n_{\text{in}} = 500$ inputs:

$$
\text{Var}(z_k) = 500 \times 1.0 \times \sigma_x^2 = 500 \sigma_x^2
$$

Each layer amplifies variance by 500. Through 20 layers, that is $500^{20} \approx 10^{54}$ amplification. The activations' standard deviation reaches roughly $10^{27}$, and anything that squares them (an MSE loss, for example) overflows float32.

Now consider $w_{ij} \sim \mathcal{N}(0, 10^{-6})$, so $\sigma_w^2 = 10^{-6}$:

$$
\text{Var}(z_k) = 500 \times 10^{-6} \times \sigma_x^2 = 0.0005 \sigma_x^2
$$

Each layer shrinks variance by 2000x. Through 20 layers, that is $0.0005^{20} \approx 10^{-66}$. The signal effectively vanishes: a variance that small underflows to zero in float32.

The correct variance is $\sigma_w^2 = 1/500 = 0.002$, making $n_{\text{in}} \cdot \sigma_w^2 = 1$. But this was derived assuming a linear activation. The actual correct variance depends on the activation function.

Xavier/Glorot Initialization: Derived for Tanh/Sigmoid (Glorot and Bengio, 2010)

Glorot and Bengio (2010) derived the variance that preserves signal magnitude through a layer with approximately linear activation (tanh near zero behaves linearly).

Forward pass constraint: to preserve activation variance through a layer, we need:

$$
\text{Var}(\mathbf{a}^{(l)}) = \text{Var}(\mathbf{a}^{(l-1)}) \Rightarrow n_{\text{in}} \cdot \sigma_w^2 = 1 \Rightarrow \sigma_w^2 = \frac{1}{n_{\text{in}}}
$$

Backward pass constraint: for stable gradient flow backward through the same layer:

$$
\text{Var}\!\left(\frac{\partial L}{\partial \mathbf{a}^{(l-1)}}\right) = \text{Var}\!\left(\frac{\partial L}{\partial \mathbf{a}^{(l)}}\right) \Rightarrow n_{\text{out}} \cdot \sigma_w^2 = 1 \Rightarrow \sigma_w^2 = \frac{1}{n_{\text{out}}}
$$

The two constraints give different values. The Glorot compromise satisfies both approximately:

$$
\sigma_w^2 = \frac{2}{n_{\text{in}} + n_{\text{out}}}
$$

This is the harmonic mean of the two target variances, $1/n_{\text{in}}$ and $1/n_{\text{out}}$.

Glorot Uniform (the original formulation):

$$
w \sim \mathcal{U}\left[-\sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}},\ \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}\right]
$$

The $\sqrt{6}$ comes from the fact that the variance of a uniform distribution on $[-a, a]$ is $a^2/3$. To achieve variance $2/(n_{\text{in}}+n_{\text{out}})$: $a^2/3 = 2/(n_{\text{in}}+n_{\text{out}}) \Rightarrow a = \sqrt{6/(n_{\text{in}}+n_{\text{out}})}$.
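The NumPy sketch below samples from this distribution and checks the variance. The helper name glorot_uniform is ours; in PyTorch the equivalent built-in is nn.init.xavier_uniform_.

```python
import numpy as np


def glorot_uniform(n_in: int, n_out: int, rng: np.random.Generator) -> np.ndarray:
    """Glorot/Xavier uniform: U[-a, a] with a = sqrt(6 / (n_in + n_out))."""
    a = np.sqrt(6.0 / (n_in + n_out))
    return rng.uniform(-a, a, size=(n_out, n_in))


rng = np.random.default_rng(0)
n_in, n_out = 300, 200
W = glorot_uniform(n_in, n_out, rng)

target = 2.0 / (n_in + n_out)  # 0.004
print(f"target variance {target:.5f}, empirical {W.var():.5f}")
```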

Glorot Normal:

$$
w \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}} + n_{\text{out}}}\right), \quad \text{i.e. standard deviation } \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}}
$$

When to use Xavier/Glorot: tanh, sigmoid, and linear activations. It is a poor fit for ReLU because ReLU is not linear. ReLU zeroes out the negative half of its inputs, which halves the second moment passed to the next layer. Using Glorot with ReLU underestimates the required variance by roughly a factor of 2 and causes gradual activation collapse in deep ReLU networks.

He/Kaiming Initialization: Derived for ReLU (He et al., 2015)

Kaiming He et al. (2015) extended Glorot's analysis to ReLU by accounting for the fact that ReLU zeroes out the negative half of its inputs.

For a ReLU layer, suppose pre-activations $z_k$ are zero-mean and symmetric with variance $\sigma_z^2$. Then half of them (in probability) are positive and pass through ReLU, and the other half are set to zero. The expected value of $z_k^2$ after ReLU is therefore exactly $\sigma_z^2 / 2$: the positive half contributes $\sigma_z^2 / 2$ and the negative half contributes 0.

Formally, for a random variable $Z \sim \mathcal{N}(0, \sigma_z^2)$:

$$
\mathbb{E}[\text{ReLU}(Z)^2] = \frac{\sigma_z^2}{2}
$$

(This follows from integrating $z^2 \cdot \frac{1}{\sqrt{2\pi}\,\sigma_z} e^{-z^2/(2\sigma_z^2)}$ over $z > 0$. By symmetry, that integral is half of $\mathbb{E}[Z^2] = \sigma_z^2$.)
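The identity is also easy to confirm numerically. This is a minimal Monte Carlo sketch with an arbitrary $\sigma_z = 3$, where the prediction is $9/2 = 4.5$:

```python
import numpy as np

rng = np.random.default_rng(0)
sigma_z = 3.0
z = rng.normal(0.0, sigma_z, size=1_000_000)   # Z ~ N(0, sigma_z^2)

second_moment = float(np.mean(np.maximum(z, 0.0) ** 2))  # E[ReLU(Z)^2]
print(f"E[ReLU(Z)^2] ≈ {second_moment:.3f}, sigma_z^2 / 2 = {sigma_z**2 / 2:.3f}")
```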

For the pre-activation variance to stay the same from one ReLU layer to the next:

$$
\text{Var}(z^{(l)}) = n_{\text{in}} \cdot \sigma_w^2 \cdot \mathbb{E}\big[(a^{(l-1)})^2\big] = \frac{n_{\text{in}}}{2} \cdot \sigma_w^2 \cdot \text{Var}(z^{(l-1)}) = \text{Var}(z^{(l-1)})
$$

Here $a^{(l-1)} = \text{ReLU}(z^{(l-1)})$, and the middle step uses $\mathbb{E}[\text{ReLU}(Z)^2] = \sigma_z^2/2$.

Solving: $\frac{n_{\text{in}}}{2} \cdot \sigma_w^2 = 1 \Rightarrow \sigma_w^2 = \frac{2}{n_{\text{in}}}$

Kaiming Normal (most common):

$$
w \sim \mathcal{N}\!\left(0,\ \frac{2}{n_{\text{in}}}\right), \quad \text{i.e. standard deviation } \sqrt{2/n_{\text{in}}}
$$

Kaiming Uniform:

$$
w \sim \mathcal{U}\!\left[-\sqrt{\frac{6}{n_{\text{in}}}},\ \sqrt{\frac{6}{n_{\text{in}}}}\right]
$$
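The next sketch shows why the factor of 2 matters. It is an illustrative NumPy experiment that assumes square 256-wide layers, zero biases, and 30 ReLU layers. It tracks the pre-activation variance layer by layer under Kaiming ($\sigma_w^2 = 2/n_{\text{in}}$) and under Xavier, which here reduces to $\sigma_w^2 = 2/(n_{\text{in}} + n_{\text{out}}) = 1/n$.

```python
import numpy as np


def relu_preact_variances(std_fn, n: int = 256, depth: int = 30,
                          batch: int = 512, seed: int = 0) -> list:
    """Var(z^(l)) at every layer of a deep ReLU MLP with square n-by-n layers
    and no biases; std_fn(n_in, n_out) gives the weight standard deviation."""
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((batch, n))
    variances = []
    for _ in range(depth):
        W = rng.normal(0.0, std_fn(n, n), size=(n, n))
        z = a @ W.T
        variances.append(float(z.var()))
        a = np.maximum(z, 0.0)  # ReLU
    return variances


kaiming = relu_preact_variances(lambda n_in, n_out: np.sqrt(2.0 / n_in))
xavier = relu_preact_variances(lambda n_in, n_out: np.sqrt(2.0 / (n_in + n_out)))
print(f"Kaiming: layer 1 var {kaiming[0]:.3f}, layer 30 var {kaiming[-1]:.3f}")
print(f"Xavier:  layer 1 var {xavier[0]:.3f}, layer 30 var {xavier[-1]:.2e}")
```

Under Kaiming, the variance stays within a small factor of its layer-1 value. That value is about 2 here, because the raw input has not passed through a ReLU. Under Xavier, the variance roughly halves at every layer.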

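PyTorch's default for nn.Linear is not the ReLU-matched Kaiming initialization. nn.Linear.reset_parameters calls kaiming_uniform_ with a=math.sqrt(5), where a is interpreted as a leaky_relu negative slope $\alpha$. The leaky_relu version of Kaiming gives $\sigma_w = \sqrt{2 / ((1 + \alpha^2) n_{\text{in}})}$, which reduces to $\sqrt{2/n_{\text{in}}}$ for ReLU ($\alpha = 0$). With $\alpha = \sqrt{5}$, it gives $\sigma_w = 1/\sqrt{3 n_{\text{in}}}$, i.e. $w \sim \mathcal{U}[-1/\sqrt{n_{\text{in}}},\ 1/\sqrt{n_{\text{in}}}]$. That is six times less variance than the ReLU target $2/n_{\text{in}}$. For deep ReLU networks, call nn.init.kaiming_normal_ or nn.init.kaiming_uniform_ with nonlinearity='relu' explicitly.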
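You can check the default directly. This sketch assumes a PyTorch version whose nn.Linear.reset_parameters still uses a=math.sqrt(5). That holds for current releases, but it is worth confirming in your version's source.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(512, 256)  # PyTorch's default initialization
fan_in = layer.in_features
bound = 1.0 / math.sqrt(fan_in)

max_abs = layer.weight.abs().max().item()
default_var = layer.weight.var().item()
print(f"max |w| = {max_abs:.4f} (bound 1/sqrt(fan_in) = {bound:.4f})")
print(f"default variance {default_var:.6f} vs 1/(3*fan_in) = {1 / (3 * fan_in):.6f}")

# Explicit Kaiming for a ReLU network: variance 2/fan_in
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
kaiming_var = layer.weight.var().item()
print(f"kaiming_normal_ variance {kaiming_var:.6f} vs 2/fan_in = {2 / fan_in:.6f}")
```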

Fan-in vs fan-out mode:

  • mode='fan_in': $\sigma_w^2 = 2/n_{\text{in}}$ - keeps forward-pass activation variance stable
  • mode='fan_out': $\sigma_w^2 = 2/n_{\text{out}}$ - keeps backward-pass gradient variance stable
  • Fan-in is the default and is appropriate for most cases

LeCun Initialization: For SELU

LeCun initialization (LeCun et al., 1998) predates Glorot and He. It was designed for tanh-like sigmoid activations operating in their approximately linear regime:

$$
\sigma_w^2 = \frac{1}{n_{\text{in}}}
$$

This is the forward-only constraint from Glorot's derivation - it keeps forward-pass variance stable but does not balance the backward pass. It is also the required initialization for SELU networks: the SELU self-normalizing property was mathematically derived assuming LeCun initialization.

```python
import torch
import torch.nn as nn

# LeCun normal in PyTorch (shape is out_features=256, in_features=512)
weight = torch.empty(256, 512)
nn.init.normal_(weight, mean=0.0, std=(1.0 / 512) ** 0.5)
# std = 1/sqrt(fan_in) = 1/sqrt(512) ≈ 0.0442
```
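The self-normalizing claim can be checked with a small NumPy simulation. This sketch assumes 256-wide square layers, no biases, and standard-normal inputs. The SELU constants are the values from Klambauer et al. (2017), which nn.SELU also uses. After 32 LeCun-initialized SELU layers, the activations should still have mean near 0 and variance near 1.

```python
import numpy as np

# SELU constants (Klambauer et al., 2017)
ALPHA = 1.6732632423543772
LAMBDA = 1.0507009873554805


def selu(z: np.ndarray) -> np.ndarray:
    return LAMBDA * np.where(z > 0, z, ALPHA * (np.exp(z) - 1.0))


rng = np.random.default_rng(0)
n, depth = 256, 32
a = rng.standard_normal((1024, n))
for _ in range(depth):
    W = rng.normal(0.0, np.sqrt(1.0 / n), size=(n, n))  # LeCun normal: variance 1/n_in
    a = selu(a @ W.T)

print(f"after {depth} SELU layers: mean {a.mean():.3f}, var {a.var():.3f}")
```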

Orthogonal Initialization: For RNNs

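Orthogonal initialization draws weight matrices from the set of orthogonal matrices (square matrices where $\mathbf{W}^T \mathbf{W} = \mathbf{I}$). For a non-square shape, it produces a semi-orthogonal matrix, with orthonormal rows or orthonormal columns (whichever there are fewer of). Its singular values are likewise all 1.

Why this matters for RNNs: in a recurrent network, the hidden state is repeatedly multiplied by the recurrent weight matrix $\mathbf{W}_h$. If $\mathbf{W}_h$ has singular values $> 1$, the hidden state can explode over many timesteps; if they are $< 1$, it vanishes. Orthogonal matrices have all singular values equal to 1, so multiplication by $\mathbf{W}_h$ preserves the hidden state's norm exactly, however long the sequence. The input term and the nonlinearity still change the norm, so this is a statement about the recurrent matrix alone.

How to construct: sample a random Gaussian matrix $\mathbf{A}$, compute its QR decomposition $\mathbf{A} = \mathbf{Q}\mathbf{R}$, and use $\mathbf{Q}$, which is orthogonal by construction. Multiplying each column of $\mathbf{Q}$ by the sign of the matching diagonal entry of $\mathbf{R}$ makes the result uniformly distributed over orthogonal matrices. nn.init.orthogonal_ applies this correction.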
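The QR recipe can be sketched in NumPy as follows. The helper orthogonal_init is hypothetical (PyTorch's built-in is nn.init.orthogonal_). For non-square shapes it returns a semi-orthogonal matrix.

```python
import numpy as np


def orthogonal_init(rows: int, cols: int, gain: float = 1.0, seed: int = 0) -> np.ndarray:
    """Sketch of QR-based orthogonal initialization (semi-orthogonal if rows != cols)."""
    rng = np.random.default_rng(seed)
    # Always factor a tall matrix; transpose at the end if a wide result is needed
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)            # q has orthonormal columns
    q = q * np.sign(np.diag(r))       # sign fix -> uniform (Haar) distribution
    if rows < cols:
        q = q.T
    return gain * q


W = orthogonal_init(128, 128)
print(np.abs(W.T @ W - np.eye(128)).max())          # float64 rounding error only
W_wide = orthogonal_init(64, 128)
print(np.abs(W_wide @ W_wide.T - np.eye(64)).max())  # orthonormal rows
```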

```python
import torch
import torch.nn as nn

weight = torch.empty(128, 128)
nn.init.orthogonal_(weight, gain=1.0)

# Verify orthogonality: W^T @ W should be the identity
identity_check = (weight.T @ weight - torch.eye(128)).abs().max().item()
print(f"Max deviation from identity: {identity_check:.6f}")  # float32 rounding error only

# Singular values should all be 1
singular_values = torch.linalg.svdvals(weight)
print(f"Singular value range: [{singular_values.min().item():.4f}, {singular_values.max().item():.4f}]")
# Should be [1.0000, 1.0000]
```
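The sketch below connects this back to the RNN argument. It iterates only the linear part of a recurrence, $\mathbf{h}_{t+1} = \mathbf{W}_h \mathbf{h}_t$, for 200 steps with $\mathbf{W}_h = s\mathbf{Q}$ for an orthogonal $\mathbf{Q}$, so every singular value equals $s$. The scales 0.9, 1.0, and 1.1 are arbitrary; real RNNs also add inputs and a nonlinearity. The norm ratio should come out as exactly $s^{200}$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, steps = 64, 200
q, _ = np.linalg.qr(rng.standard_normal((n, n)))  # orthogonal matrix
h0 = rng.standard_normal(n)

norm_ratios = {}
for scale in (0.9, 1.0, 1.1):   # every singular value of scale * q equals scale
    W_h = scale * q
    h = h0.copy()
    for _ in range(steps):
        h = W_h @ h             # linear part of the recurrence only
    norm_ratios[scale] = np.linalg.norm(h) / np.linalg.norm(h0)
    print(f"singular values = {scale}: ||h_200|| / ||h_0|| = {norm_ratios[scale]:.3e}")
```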

LSUV: Layer-Sequential Unit Variance (Mishkin and Matas, 2016)

LSUV is a data-driven initialization that does not require knowing the activation function's statistics ahead of time. Algorithm:

  1. Initialize all weights with orthogonal initialization
  2. For each layer $l$ (sequentially, from first to last):
     a. Run a forward pass through layers $1, \ldots, l$ on a batch of data
     b. Compute the variance of layer $l$'s output activations
     c. Scale the weights of layer $l$ so that the output variance equals 1
     d. Repeat a-c until the variance is within a tolerance of 1 or an iteration limit is reached
     e. Optionally, subtract the mean to zero-center
```python
import torch
import torch.nn as nn


def lsuv_init(model: nn.Module, data_batch: torch.Tensor,
              target_var: float = 1.0, max_iter: int = 10, tol: float = 0.1) -> None:
    """
    Layer-Sequential Unit Variance initialization.
    Iteratively scales each layer's weights until output variance ≈ target_var.

    This is data-driven: it accounts for the actual activation statistics
    under the real data distribution, not theoretical assumptions.
    """
    was_training = model.training
    model.eval()
    layer_outputs = {}

    def get_hook(name):
        def hook(module, input, output):
            layer_outputs[name] = output.detach()
        return hook

    linear_layers = [(name, module) for name, module in model.named_modules()
                     if isinstance(module, nn.Linear)]

    # Step 1: orthogonal pre-initialization, zero biases
    for _, module in linear_layers:
        nn.init.orthogonal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    # Register hooks to capture each Linear layer's output
    hooks = [module.register_forward_hook(get_hook(name))
             for name, module in linear_layers]

    with torch.no_grad():
        # Step 2: process each layer sequentially, first to last
        for layer_name, layer_module in linear_layers:
            for iteration in range(max_iter):
                # Forward pass to get current activations
                layer_outputs.clear()
                _ = model(data_batch)

                if layer_name not in layer_outputs:
                    break

                current_var = layer_outputs[layer_name].var().item()

                if abs(current_var - target_var) < tol:
                    break  # variance is already close to target

                # Scale weights so output variance moves to target_var
                scale = (target_var / (current_var + 1e-8)) ** 0.5
                layer_module.weight.mul_(scale)

                print(f"  {layer_name}: iter {iteration}, var={current_var:.3f} → scaling by {scale:.3f}")

    for hook in hooks:
        hook.remove()

    model.train(was_training)  # restore the original mode


# Usage
model = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)

data_batch = torch.randn(512, 100)  # representative batch
lsuv_init(model, data_batch)
print("LSUV initialization complete.")
```

Bias Initialization Strategies

Biases do not suffer from the symmetry problem - each bias affects only one neuron's output, so all biases can be initialized identically. The question is what value to choose.

Standard: zeros. This is correct for hidden layers. At initialization, no activation has a directional bias - the network starts from a neutral state.

Output layer for class imbalance: if your dataset has 1% fraud (positive class) and 99% legitimate (negative class), initializing the output bias to 0 means the network starts by predicting 50% probability. It must learn the 99:1 ratio from data. This wastes early training steps. Instead, initialize the output bias to the log-odds of the base rate:

$$
b = \log\left(\frac{p}{1-p}\right)
$$

For $p = 0.01$: $b = \log(0.01/0.99) \approx -4.6$. This starts the network's predictions at the base rate instead of at 50%, so early training steps are not spent just learning the class prior.

```python
import torch
import torch.nn as nn
import math


def initialize_output_bias_for_imbalance(layer: nn.Linear, positive_rate: float) -> None:
    """
    Initialize output bias to log-odds of positive class rate.
    Particularly important when positive_rate << 0.5.
    """
    with torch.no_grad():
        bias_value = math.log(positive_rate / (1 - positive_rate))
        layer.bias.fill_(bias_value)
    print(f"Output bias initialized to {bias_value:.4f} for positive_rate={positive_rate}")


# Example: fraud detection with 1% fraud rate
output_layer = nn.Linear(64, 1)
initialize_output_bias_for_imbalance(output_layer, positive_rate=0.01)
# Output bias: -4.5951 → sigmoid(-4.5951) = 0.01 → starts at base rate
```
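The sketch below shows what the log-odds bias buys you. It computes the expected binary cross-entropy at the very first step, under the simplifying assumption that the logit equals the bias for every example (i.e. the weights' contribution is negligible at initialization). The helper initial_bce is hypothetical.

```python
import math


def initial_bce(bias: float, positive_rate: float) -> float:
    """Expected binary cross-entropy when every logit equals `bias`,
    on data with the given positive rate."""
    p_hat = 1.0 / (1.0 + math.exp(-bias))  # sigmoid(bias)
    return -(positive_rate * math.log(p_hat)
             + (1 - positive_rate) * math.log(1 - p_hat))


p = 0.01
zero_bias_loss = initial_bce(0.0, p)                   # log(2) ≈ 0.693
log_odds_loss = initial_bce(math.log(p / (1 - p)), p)  # binary entropy of p ≈ 0.056
print(f"initial loss with b=0: {zero_bias_loss:.3f}, with log-odds bias: {log_odds_loss:.3f}")
```

With $b = 0$, the initial loss is $\log 2 \approx 0.693$. With the log-odds bias, it drops to the entropy of the base rate, about 0.056, so the first gradient steps are spent on the features rather than on the prior.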

LSTM forget gate: LSTM forget gates should be initialized to values that bias toward remembering (not forgetting). Common practice: initialize the forget gate bias to 1.0. This biases the initial forget gate activation toward sigmoid(1) ≈ 0.73, meaning the gate starts by passing most of the previous state through rather than forgetting it.
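In PyTorch's nn.LSTM, each layer's biases are stored as bias_ih_l{k} and bias_hh_l{k}. Per the nn.LSTM documentation, each packs the four gates in the order input | forget | cell | output, and the two are summed. The following is a sketch, with arbitrary sizes, that sets the effective forget-gate bias to 1.0:

```python
import torch
import torch.nn as nn

hidden_size = 32
lstm = nn.LSTM(input_size=16, hidden_size=hidden_size, num_layers=1)

with torch.no_grad():
    for name, param in lstm.named_parameters():
        if name.startswith("bias_ih"):
            param.zero_()
            # PyTorch packs the gates as input | forget | cell | output
            param[hidden_size:2 * hidden_size].fill_(1.0)
        elif name.startswith("bias_hh"):
            # b_ih and b_hh are added together, so keep this one at zero
            param.zero_()

    forget_bias = (lstm.bias_ih_l0 + lstm.bias_hh_l0)[hidden_size:2 * hidden_size]

print(torch.sigmoid(forget_bias[:3]))  # each ≈ 0.7311 = sigmoid(1)
```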

Full NumPy Experiment: Gradient Flow Under Different Initializations

```python
import numpy as np


def build_deep_mlp_numpy(n_layers: int, width: int, init_scheme: str, seed: int = 42) -> list:
    """
    Build a deep MLP with specified initialization.
    Returns list of (W, b) tuples for each layer.
    """
    rng = np.random.default_rng(seed)
    layers = []

    for i in range(n_layers):
        in_dim = 100 if i == 0 else width
        out_dim = width if i < n_layers - 1 else 10

        if init_scheme == "zeros":
            W = np.zeros((out_dim, in_dim))
        elif init_scheme == "standard_normal":
            W = rng.standard_normal((out_dim, in_dim))  # sigma=1, too large
        elif init_scheme == "small_normal":
            W = rng.standard_normal((out_dim, in_dim)) * 0.01  # sigma=0.01, too small
        elif init_scheme == "xavier":
            std = np.sqrt(2.0 / (in_dim + out_dim))
            W = rng.standard_normal((out_dim, in_dim)) * std
        elif init_scheme == "kaiming":
            std = np.sqrt(2.0 / in_dim)
            W = rng.standard_normal((out_dim, in_dim)) * std
        else:
            raise ValueError(f"Unknown init_scheme: {init_scheme}")

        b = np.zeros(out_dim)
        layers.append((W, b))

    return layers


def relu(z):
    return np.maximum(0, z)


def relu_grad(z):
    return (z > 0).astype(float)


def forward_pass(layers: list, x: np.ndarray) -> tuple:
    """Forward pass, returning pre-activations for gradient computation."""
    a = x
    cache = [x]
    pre_activations = []

    for i, (W, b) in enumerate(layers):
        z = a @ W.T + b
        pre_activations.append(z)
        if i < len(layers) - 1:
            a = relu(z)
        else:
            a = z  # no activation on output layer
        cache.append(a)

    return a, cache, pre_activations


def backward_pass(layers: list, loss_grad: np.ndarray,
                  cache: list, pre_activations: list) -> list[float]:
    """Backward pass, returning the weight-gradient norm at each layer."""
    delta = loss_grad
    grad_norms = []

    for i in reversed(range(len(layers))):
        W, b = layers[i]
        z = pre_activations[i]
        a_prev = cache[i]

        if i < len(layers) - 1:
            delta = delta * relu_grad(z)

        dW = delta.T @ a_prev / delta.shape[0]
        grad_norms.insert(0, np.linalg.norm(dW))

        if i > 0:
            delta = delta @ W

    return grad_norms


def compare_initializations():
    """
    Run a single forward+backward pass for each initialization scheme
    and compare gradient norms across layers.
    """
    n_layers = 10
    width = 128
    schemes = ["standard_normal", "small_normal", "xavier", "kaiming"]

    # Random input batch (B=32, features=100)
    rng = np.random.default_rng(0)
    X = rng.standard_normal((32, 100))
    y_true = rng.integers(0, 10, 32)

    print("Gradient norms across layers (Layer 1 = earliest, Layer 10 = output-side):")
    print(f"{'Scheme':<18} | {'L1':>9} | {'L5':>9} | {'L10':>9} | Status")
    print("-" * 65)

    for scheme in schemes:
        layers = build_deep_mlp_numpy(n_layers, width, scheme)

        # Forward pass
        output, cache, pre_acts = forward_pass(layers, X)

        # Simple MSE-style gradient on output
        loss_grad = np.zeros_like(output)
        for i, label in enumerate(y_true):
            loss_grad[i, label] = output[i, label] - 1.0

        # Backward pass
        grad_norms = backward_pass(layers, loss_grad, cache, pre_acts)

        g1 = grad_norms[0]
        g5 = grad_norms[4]
        g10 = grad_norms[-1]

        if g1 < 1e-8:
            status = "VANISHING"
        elif g1 > 100:
            status = "EXPLODING"
        else:
            status = "STABLE   "

        print(f"{scheme:<18} | {g1:>9.2e} | {g5:>9.2e} | {g10:>9.2e} | {status}")


compare_initializations()
# Exact numbers depend on the seed. In this single-pass setup, expect
# standard_normal to report EXPLODING (activation variance grows ~64x per
# layer), small_normal to report VANISHING (~0.0064x per layer), and
# xavier/kaiming to report STABLE, with xavier's norms typically noticeably
# smaller than kaiming's because it undershoots the variance ReLU needs.
```

Measuring Activation Statistics: Pre-Training Diagnostic

A fast way to verify initialization before training begins is to run a single forward pass and measure activation statistics at each layer. Healthy statistics (pre-activations with mean near 0 and standard deviation of order 1, stable from layer to layer) indicate good initialization.

```python
import torch
import torch.nn as nn


def measure_activation_statistics(model: nn.Module, n_features: int,
                                  batch_size: int = 512) -> None:
    """
    Run a forward pass with random input and measure activation statistics
    at each layer. Use this BEFORE training to verify initialization quality.
    """
    was_training = model.training
    model.eval()
    x = torch.randn(batch_size, n_features)

    hooks = []
    stats = {}

    def make_hook(name):
        def hook(module, input, output):
            a = output.detach()
            stats[name] = {
                'mean': a.mean().item(),
                'std': a.std().item(),
                'frac_zero': (a == 0).float().mean().item(),
            }
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.ReLU, nn.GELU)):
            hooks.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(x)

    for hook in hooks:
        hook.remove()

    print(f"{'Layer':<30} | {'Mean':>8} | {'Std':>8} | {'Frac Zero':>10}")
    print("-" * 65)
    for name, s in stats.items():
        flags = []  # collect every issue instead of keeping only the last one
        if abs(s['mean']) > 2:
            flags.append("MEAN DRIFT")
        if s['std'] < 0.01 or s['std'] > 10:
            flags.append("VARIANCE ISSUE")
        if s['frac_zero'] > 0.9:
            flags.append("TOO SPARSE")
        flag = f" ← {', '.join(flags)}" if flags else ""
        print(f"{name[:28]:<30} | {s['mean']:>8.4f} | {s['std']:>8.4f} | {s['frac_zero']:>10.4f}{flag}")

    model.train(was_training)  # restore the original mode


# Compare two initializations on the same model architecture
def compare_pytorch_inits():
    in_dim, width, n_layers = 256, 256, 10

    # Build model with Kaiming init (correct for ReLU)
    model_kaiming = nn.Sequential(*[
        layer for i in range(n_layers)
        for layer in [nn.Linear(in_dim if i == 0 else width, width), nn.ReLU()]
    ])
    # PyTorch's default nn.Linear init is not ReLU-matched Kaiming, so apply it explicitly
    for module in model_kaiming.modules():
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)

    # Build model with Xavier init (incorrect for ReLU)
    model_xavier = nn.Sequential(*[
        layer for i in range(n_layers)
        for layer in [nn.Linear(in_dim if i == 0 else width, width), nn.ReLU()]
    ])
    for module in model_xavier.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)  # wrong for ReLU

    print("=== Kaiming Initialization ===")
    measure_activation_statistics(model_kaiming, in_dim)
    print("\n=== Xavier Initialization (wrong for ReLU) ===")
    measure_activation_statistics(model_xavier, in_dim)


compare_pytorch_inits()
```

PyTorch: The nn.init Module

```python
import torch
import torch.nn as nn


def custom_weight_init(model: nn.Module, activation: str = "relu") -> None:
    """
    Apply appropriate initialization for each layer type based on activation.

    Key principle: match the initialization to the activation's statistical
    properties - specifically how much variance it adds or removes.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            if activation in ("relu", "gelu", "silu", "leaky_relu"):
                # Kaiming with the ReLU gain: exact for ReLU, a common
                # approximation for GELU, SiLU, and small-slope leaky ReLU
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            elif activation in ("tanh", "sigmoid"):
                # Xavier: assumes approximately linear activation
                nn.init.xavier_normal_(module.weight)
            elif activation == "selu":
                # LeCun: required for SELU's self-normalizing property
                nn.init.normal_(module.weight, mean=0.0,
                                std=(1.0 / module.in_features) ** 0.5)
            else:
                raise ValueError(f"Unknown activation: {activation}")
            # Biases: always zeros for hidden layers
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Conv2d):
            # fan_out mode for conv (as in torchvision's ResNet): backward-pass stability
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            # Small normal for embeddings - GPT-2 convention
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            # gamma = 1 (no initial scaling), beta = 0 (no initial shift).
            # Guarded because the affine parameters can be disabled (None).
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


# Residual network initialization: zero the last layer of each residual branch
def residual_branch_init(model: nn.Module) -> None:
    """
    For residual connections: y = x + F(x), initialize F's last layer to zero.
    This means F(x) = 0 at initialization - the block computes identity.
    Training begins from a stable starting point.
    The same idea underlies Fixup (zero-initialized last layer per residual
    branch) and ReZero (a learnable residual scale initialized to zero).
    """
    for name, module in model.named_modules():
        # Heuristic: zero the weights of linear layers named 'out_proj', 'fc2',
        # or 'c_proj' (common names for the last layer in residual blocks)
        if isinstance(module, nn.Linear) and any(
            s in name for s in ['out_proj', 'fc2', 'c_proj']
        ):
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            print(f"Zeroed residual branch: {name}")


# GPT-2 style initialization
def gpt2_init(model: nn.Module, n_layers: int) -> None:
    """
    GPT-2 initialization (Radford et al., 2019):
    - Normal with standard deviation 0.02 for all weights
    - Residual projections scaled by 1/sqrt(2*n_layers)
      to prevent activation variance growth through residual connections
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # Scale down residual projections.
    # Each block has two residual branches (attention and MLP), so there are
    # 2*n_layers additions; if each adds O(1) variance, the total grows as
    # O(n_layers). Scaling the projections by 1/sqrt(2*n_layers) keeps it O(1).
    residual_scale = (2 * n_layers) ** -0.5
    with torch.no_grad():
        for name, param in model.named_parameters():
            if "c_proj.weight" in name or "out_proj.weight" in name:
                param.mul_(residual_scale)
                print(f"Scaled {name} by {residual_scale:.4f}")
```
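The reasoning behind the residual scaling can be checked with a toy model. The sketch below assumes, purely for illustration, that every residual branch outputs independent unit-variance noise. Real branches are correlated with the stream, so treat the numbers as an idealization. The helper residual_stream_std is hypothetical. With 48 blocks (96 additions), the unscaled stream's variance should be about $1 + 96$, while scaling each branch by $1/\sqrt{96}$ keeps it near $1 + 1 = 2$.

```python
import numpy as np


def residual_stream_std(n_blocks: int, scale_branches: bool, d: int = 512,
                        batch: int = 256, seed: int = 0) -> float:
    """Toy model: x <- x + scale * F(x), where F(x) is independent N(0, 1) noise."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, d))
    n_residual = 2 * n_blocks  # attention + MLP branch per block
    scale = n_residual ** -0.5 if scale_branches else 1.0
    for _ in range(n_residual):
        branch = rng.standard_normal((batch, d))  # stand-in for a branch output
        x = x + scale * branch
    return float(x.std())


unscaled = residual_stream_std(48, scale_branches=False)
scaled = residual_stream_std(48, scale_branches=True)
print(f"unscaled branches: stream std {unscaled:.2f}")        # ≈ sqrt(97)
print(f"branches scaled by 1/sqrt(96): stream std {scaled:.2f}")  # ≈ sqrt(2)
```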

Initialization Decision Guide
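The recommendations in this lesson can be condensed into a small lookup. recommended_init_std is a hypothetical helper that returns the standard deviation of a zero-mean normal initialization. It does not cover orthogonal recurrent weights, zero-initialized residual branches, or data-driven schemes like LSUV.

```python
import math


def recommended_init_std(activation: str, fan_in: int, fan_out: int) -> float:
    """Standard deviation for a zero-mean normal init (hypothetical helper)."""
    if activation in ("relu", "leaky_relu", "gelu", "silu"):
        return math.sqrt(2.0 / fan_in)               # Kaiming, fan_in mode
    if activation in ("tanh", "sigmoid", "linear"):
        return math.sqrt(2.0 / (fan_in + fan_out))   # Xavier/Glorot
    if activation == "selu":
        return math.sqrt(1.0 / fan_in)               # LeCun
    raise ValueError(f"No rule for activation {activation!r}")


for act in ("relu", "tanh", "selu"):
    print(f"{act:>5}: std = {recommended_init_std(act, fan_in=512, fan_out=256):.4f}")
```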

Layer Normalization as an Alternative

Layer normalization largely decouples training stability from initialization quality. By normalizing activations during the forward pass, LayerNorm prevents extreme activation statistics across a wide range of initializations.

This is why transformer architectures using Pre-LN (LayerNorm before each sub-layer) are relatively robust to initialization: each sub-layer sees normalized inputs no matter how activation statistics drift. However, good initialization still matters. In Pre-LN, LayerNorm normalizes only the inputs to each sub-layer, not the residual stream itself. How much each branch adds to the stream, and with it early-training stability, therefore still depends on the weight scale.

The combination of Kaiming initialization + LayerNorm is particularly robust. Kaiming keeps the initial signal scale reasonable; LayerNorm prevents drift as training proceeds.
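This NumPy sketch shows why LayerNorm tolerates scale errors. At initialization ($\gamma = 1$, $\beta = 0$), it divides each row by that row's own standard deviation, so the output scale barely depends on the input scale. eps = 1e-5 matches nn.LayerNorm's default, and the input scales are arbitrary.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the last axis in its initial state (gamma = 1, beta = 0)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 256))
stds = {scale: float(layer_norm(scale * x).std()) for scale in (1e-1, 1.0, 1e3)}
for scale, out_std in stds.items():
    print(f"input scale {scale:g}: output std {out_std:.4f}")
```

The invariance weakens once the per-row variance approaches eps. Very small activations are not fully rescued.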

:::warning The fan_in vs fan_out Mode in Kaiming
nn.init.kaiming_normal_ has a mode argument: 'fan_in' (default) and 'fan_out'. Fan-in uses $\sigma = \sqrt{2/n_{\text{in}}}$, preserving forward-pass variance. Fan-out uses $\sigma = \sqrt{2/n_{\text{out}}}$, preserving backward-pass gradient variance. For square layers the two coincide. For most MLPs, fan-in is appropriate. For very deep networks where backward-pass stability is the primary concern, fan-out can help. torchvision's ResNet implementation, for example, uses fan-out for its Conv2d layers; for Linear layers, fan-in is standard.
:::
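Here is a quick PyTorch check of the two modes on a Conv2d layer, sketched with arbitrary channel counts. For convolutions, the fans include the kernel area: $n_{\text{in}} = C_{\text{in}} k_h k_w$ and $n_{\text{out}} = C_{\text{out}} k_h k_w$.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3)
# weight shape: (out_channels, in_channels, kH, kW) = (64, 16, 3, 3)
fan_in = 16 * 3 * 3    # 144
fan_out = 64 * 3 * 3   # 576

nn.init.kaiming_normal_(conv.weight, mode="fan_in", nonlinearity="relu")
std_fan_in = conv.weight.std().item()
nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
std_fan_out = conv.weight.std().item()

print(f"fan_in:  empirical {std_fan_in:.4f} vs sqrt(2/{fan_in}) = {math.sqrt(2 / fan_in):.4f}")
print(f"fan_out: empirical {std_fan_out:.4f} vs sqrt(2/{fan_out}) = {math.sqrt(2 / fan_out):.4f}")
```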


Interview Q&A

Q1: Why does weight initialization matter? Can the optimizer fix a bad initialization?

Bad initialization creates structural problems that optimizers cannot fix. Zero initialization causes permanent symmetry: all neurons in a layer compute identical functions and receive identical gradients, effectively making width irrelevant. Standard-normal initialization (variance = 1) causes exponential variance explosion. In a linear stack of 500-wide layers, each layer amplifies variance by 500x, giving $500^{20} \approx 10^{54}$ amplification across 20 layers. Small-variance initialization causes exponential collapse: a per-layer factor of 0.001 gives $(0.001)^{20} = 10^{-60}$ shrinkage. Adam is more robust than SGD to initialization scale because its per-parameter step size normalizes away the raw gradient magnitude. But it cannot fix symmetry (all neurons receive identical gradients regardless of optimizer), and it cannot rescue extreme explosion or collapse that leads to NaN. Proper initialization ensures the network is in a learnable state before training begins.

Q2: Derive the Kaiming He initialization for ReLU. Why is Xavier insufficient?

Xavier/Glorot assumes approximately linear activation and derives $\sigma_w^2 = 2/(n_{\text{in}} + n_{\text{out}})$ to balance forward and backward pass variance. ReLU is not linear - it zeroes out the negative half of its inputs. For a pre-activation $Z \sim \mathcal{N}(0, \sigma_z^2)$, the expected squared output of ReLU is $\mathbb{E}[\text{ReLU}(Z)^2] = \sigma_z^2/2$ (only the positive half contributes). For the next layer's pre-activation variance to equal this layer's: $n_{\text{in}} \cdot \sigma_w^2 \cdot \sigma_z^2 / 2 = \sigma_z^2$, giving $\sigma_w^2 = 2/n_{\text{in}}$. Xavier underestimates this by a factor of 2: for square layers ($n_{\text{in}} = n_{\text{out}}$), it gives $\sigma_w^2 = 1/n_{\text{in}}$. In a deep ReLU network with Xavier initialization, each layer therefore halves the variance (due to the uncorrected ReLU factor), causing slow but systematic activation collapse.

Q3: What is the symmetry breaking problem and why does it prevent learning?

Symmetry breaking refers to the requirement that neurons in the same layer start with different parameters, so that gradient descent can differentiate them. With identical weights (e.g., all zeros, or all 0.01), every neuron in a layer computes the same weighted sum of inputs: $z_j = \mathbf{w}^T\mathbf{x} + b$ for all $j$. Backpropagation computes the gradient of the loss with respect to each neuron's weights. All neurons have the same incoming weights, the same outgoing weights, and the same activations, so the upstream gradient distributes identically to all neurons. All weight gradients are equal. All weight updates are equal. After the update, all weights remain identical. The network is permanently stuck in a state where each layer's weight matrix has rank at most 1: only one feature is represented regardless of declared width. Breaking symmetry requires random initialization where no two neurons start identically.

Q4: How would you initialize a transformer model from scratch?

For a transformer, the common practice follows the GPT-2 convention:

(1) Linear layers: $\mathcal{N}(0, 0.02^2)$, i.e. standard deviation 0.02. This is small enough to keep initial activations near zero-mean and large enough to break symmetry.

(2) Embedding layers: $\mathcal{N}(0, 0.02^2)$, the same convention for consistency.

(3) LayerNorm: $\gamma = 1, \beta = 0$, so the affine part starts as the identity.

(4) Residual projections (the output projection of each attention block and FFN): scaled down by $1/\sqrt{2L}$, where $L$ is the number of layers and $2L$ is the number of residual additions. This prevents the residual stream's standard deviation from growing as $O(\sqrt{L})$ with depth.

(5) Bias terms: zeros throughout.

The 0.02 standard deviation is not derived theoretically. It is an empirical convention inherited from OpenAI's GPT and GPT-2 models that has become standard for language models trained with the Adam optimizer.

Q5: A training run diverges immediately after the first batch. How would you diagnose initialization issues?

First, print the mean and standard deviation of activations at every layer before any gradient update: call model.eval(), then measure_activation_statistics(model, data_batch). Healthy pre-activation statistics are roughly $(\mu, \sigma) \approx (0, 1)$ throughout. Interpret the numbers as follows:

- Very large activations ($\sigma \gg 1$) in early layers mean the initialization variance is too high.
- Very small activations ($\sigma \ll 0.01$) mean the variance is too low.
- If every unit in a layer produces the same output as its neighbors, check for zero or constant initialization.

Second, use torch.autograd.set_detect_anomaly(True) to find the first operation that produces NaN in the backward pass. Third, check custom layers that may not inherit PyTorch's initialization: any nn.Module subclass with a custom __init__ that calls nn.Linear(...) but then does something unusual to the weights. Fourth, check BatchNorm and LayerNorm gamma initialization. If gamma is initialized to a large positive value, it rescales the normalized activations to extreme magnitudes.

In practice, the most common cause of immediate divergence is too-large weight variance combined with too large a learning rate. Reducing the learning rate by 10x usually stabilizes training enough to locate the underlying initialization error.
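The first check can be turned into a simple rule-based report. The sketch below is hypothetical: `diagnose_activations` is not a PyTorch API. It takes activations you have already collected (for example with forward hooks) as NumPy arrays. Its thresholds of 10 and 0.01 are illustrative choices, not standards.

```python
import numpy as np

def diagnose_activations(acts, high=10.0, low=0.01):
    """acts: dict of layer name -> activation array of shape (batch, units).
    Returns (layer, message) pairs for layers that look unhealthy.
    The `high`/`low` std thresholds are illustrative, not standard values."""
    report = []
    for name, a in acts.items():
        std = float(a.std())
        if std > high:
            report.append((name, f"std={std:.3g}: weight variance likely too high"))
        elif std < low:
            report.append((name, f"std={std:.3g}: weight variance likely too low"))
        if a.shape[1] > 1 and np.allclose(a, a[:, :1]):
            # Every unit matches the first unit for every example: symmetry never broken
            report.append((name, "all units identical: check for zero/constant init"))
    return report

rng = np.random.default_rng(0)
acts = {
    "healthy": rng.standard_normal((32, 16)),
    "exploding": 100.0 * rng.standard_normal((32, 16)),
    "vanishing": 1e-4 * rng.standard_normal((32, 16)),
    "constant_init": np.tile(rng.standard_normal((32, 1)), (1, 16)),
}
for layer, message in diagnose_activations(acts):
    print(f"{layer}: {message}")
```

The `constant_init` example has a perfectly reasonable standard deviation. It is caught only by the identical-units test, which is why that check is worth keeping separate from the variance thresholds.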

Signal Propagation in Practice: A Worked Numerical Example

Consider a 5-layer MLP with 256-wide hidden layers using ReLU activation. We compare how Kaiming initialization and a naive std=1 initialization propagate activation variance across the layers:

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Build a 5-layer ReLU MLP: 64 inputs -> four 256-wide hidden layers -> 10 outputs
dims = [64, 256, 256, 256, 256, 10]

# With Kaiming initialization (correct for ReLU)
def build_with_kaiming(dims):
    layers = []
    for i in range(len(dims) - 1):
        lin = nn.Linear(dims[i], dims[i + 1])
        nn.init.kaiming_normal_(lin.weight, nonlinearity='relu')
        nn.init.zeros_(lin.bias)
        layers.append(lin)
        if i < len(dims) - 2:  # no ReLU after the output layer
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# With naive initialization (std=1, too large)
def build_with_naive(dims):
    layers = []
    for i in range(len(dims) - 1):
        lin = nn.Linear(dims[i], dims[i + 1])
        nn.init.normal_(lin.weight, mean=0, std=1.0)  # wrong: ignores fan-in
        nn.init.zeros_(lin.bias)
        layers.append(lin)
        if i < len(dims) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

x = torch.randn(128, 64)

with torch.no_grad():  # statistics only; no autograd graph needed
    for label, model in [("Kaiming", build_with_kaiming(dims)),
                         ("Naive std=1", build_with_naive(dims))]:
        print(f"\n=== {label} ===")
        a = x
        for i, layer in enumerate(model):
            a = layer(a)
            if isinstance(layer, nn.Linear):
                # Sequential alternates Linear, ReLU, so i // 2 + 1 is the Linear's ordinal
                print(f"After Linear {i // 2 + 1}: mean={a.mean().item():.4f}, std={a.std().item():.4f}")
```

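Expected output for Kaiming: the standard deviation stays roughly constant across all 5 layers, at about $\sqrt{2} \approx 1.4$. The first layer sees raw $\mathcal{N}(0, 1)$ inputs that no ReLU has halved, so its output variance is $64 \cdot (2/64) \cdot 1 = 2$, and every ReLU-corrected layer after it preserves that.

For naive std=1, the first layer multiplies the standard deviation by $\sqrt{64} = 8$. Each later layer multiplies it by about $\sqrt{256/2} \approx 11.3$, reaching roughly $8 \cdot 11.3^4 \approx 1.3 \times 10^5$ at layer 5. Logits of that magnitude produce enormous losses and gradients, and training typically diverges to NaN within a few steps.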
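The expected numbers can be predicted without running the model. This pure-Python sketch applies the layer variance formula $\text{Var}(z) = n_{\text{in}} \cdot \sigma_w^2 \cdot \mathbb{E}[x^2]$ layer by layer. It assumes zero biases and zero-mean symmetric pre-activations, so that ReLU exactly halves the second moment.

```python
import math

def predicted_stds(dims, weight_std):
    """Predict the pre-activation std after each Linear layer of a ReLU MLP.
    weight_std(n_in) gives the weight std used for a layer with fan-in n_in."""
    second_moment = 1.0          # E[x^2] of the raw N(0, 1) input
    stds = []
    for n_in in dims[:-1]:
        var = n_in * weight_std(n_in) ** 2 * second_moment   # Var(z) = n_in * sigma_w^2 * E[x^2]
        stds.append(math.sqrt(var))
        second_moment = var / 2  # ReLU keeps half of E[z^2]
    return stds

dims = [64, 256, 256, 256, 256, 10]
kaiming = predicted_stds(dims, lambda n_in: math.sqrt(2.0 / n_in))
naive = predicted_stds(dims, lambda n_in: 1.0)
print("Kaiming:", [round(s, 2) for s in kaiming])
print("Naive:  ", [round(s, 1) for s in naive])
```

For these dims the recursion predicts about 1.41 at every Kaiming layer. For the naive model it predicts 8, 90.5, 1024, 11585.2 and 131072.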

Common Initialization Anti-Patterns

Three initialization mistakes that appear regularly in production codebases:

1. Re-initializing pretrained weights during fine-tuning:

```python
import torch
import torch.nn as nn

# WRONG: loads pretrained weights, then reinitializes them
model = MyModel()
state_dict = torch.load("pretrained.pt")
model.load_state_dict(state_dict)  # loads pretrained weights

# Then accidentally calls custom_weight_init, which overwrites them!
custom_weight_init(model, activation="relu")  # destroys pretrained weights

# CORRECT: apply custom init ONLY to newly added layers
def init_new_head_only(model, head_module_name="classifier"):
    for name, module in model.named_modules():
        # Substring match: make sure no pretrained module name also contains head_module_name
        if head_module_name in name and isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:  # Linear layers built with bias=False have no bias
                nn.init.zeros_(module.bias)
```
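A cheap guard against this mistake is to snapshot the parameters before your init call and confirm that only head parameters changed. The sketch below is framework-agnostic and assumes plain dicts of NumPy arrays. With PyTorch you could build them from `model.state_dict()`, e.g. `{k: v.detach().cpu().numpy().copy() for k, v in model.state_dict().items()}`. The copy matters, because the arrays otherwise share memory with the live parameters. The parameter names in the usage example are hypothetical.

```python
import numpy as np

def changed_params(before, after):
    """Return the sorted names of parameters whose values differ between two snapshots."""
    return sorted(name for name in before if not np.array_equal(before[name], after[name]))

def assert_only_head_changed(before, after, head_prefix="classifier."):
    """Fail loudly if any parameter outside the head was modified."""
    unexpected = [n for n in changed_params(before, after) if not n.startswith(head_prefix)]
    assert not unexpected, f"init overwrote pretrained parameters: {unexpected}"

# Hypothetical snapshots: the encoder is untouched, the new head was re-initialized
rng = np.random.default_rng(0)
before = {"encoder.weight": rng.standard_normal((4, 4)), "classifier.weight": np.zeros((2, 4))}
after = {k: v.copy() for k, v in before.items()}
after["classifier.weight"] = rng.standard_normal((2, 4))

print(changed_params(before, after))
assert_only_head_changed(before, after)   # passes: only the head changed
```

The prefix test here is deliberately stricter than the substring test in `init_new_head_only`. A pretrained module whose name merely contains "classifier" somewhere will be reported rather than silently accepted.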

2. Using default PyTorch init for embedding layers in LLMs:

```python
import torch.nn as nn

vocab_size, d_model = 50257, 768  # example sizes (GPT-2 small)

# PyTorch default: Embedding uses N(0, 1) - too large for language models
embedding = nn.Embedding(vocab_size, d_model)
# embedding.weight is initialized to N(0, 1) by default

# CORRECT: GPT-2 convention - N(0, 0.02) for all embedding/linear weights
nn.init.normal_(embedding.weight, mean=0.0, std=0.02)
# This keeps initial token embeddings close to zero, helping keep early attention
# scores from becoming too large and saturating the softmax
```

3. Randomly initializing a layer inserted into an already-trained model:

When you add a new intermediate layer to a trained model (e.g., an extra projection of the same width), initialize it so that it computes the identity function at first. Otherwise the next trained layer receives inputs it has never seen, which destroys the learned representations in the surrounding layers:

```python
import torch.nn as nn

def add_identity_linear(in_features: int, out_features: int) -> nn.Linear:
    """
    Initialize a new linear layer to compute exactly the identity.
    Insert it between two existing trained layers without disrupting representations.
    Requires in_features == out_features. For a bottleneck with a different inner width,
    use a residual branch x + up(down(x)) with `up` zero-initialized instead.
    """
    assert in_features == out_features, "Use a zero-initialized residual branch for different dims"
    lin = nn.Linear(in_features, out_features, bias=True)
    # Eye init: weight = I and bias = 0, so lin(x) == x at the start
    nn.init.eye_(lin.weight)
    nn.init.zeros_(lin.bias)
    return lin
```
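For mismatched widths, the docstring's fallback is a residual branch whose output projection starts at zero. It can be sketched in NumPy as follows. The class name, the ReLU inside the branch and the Kaiming scale for the down-projection are assumptions for illustration. The key property is that the block computes exactly `x` at initialization.

```python
import numpy as np

class ResidualBottleneck:
    """y = x + relu(x @ w_down) @ w_up, with w_up zero-initialized so y == x at the start."""

    def __init__(self, d_model: int, d_inner: int, rng: np.random.Generator):
        # Kaiming-scaled down-projection (fan-in d_model, followed by ReLU)
        self.w_down = rng.normal(0.0, np.sqrt(2.0 / d_model), size=(d_model, d_inner))
        # Zero up-projection: the branch contributes nothing until training updates it
        self.w_up = np.zeros((d_inner, d_model))

    def hidden(self, x: np.ndarray) -> np.ndarray:
        return np.maximum(x @ self.w_down, 0.0)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x + self.hidden(x) @ self.w_up

rng = np.random.default_rng(0)
block = ResidualBottleneck(d_model=16, d_inner=4, rng=rng)
x = rng.standard_normal((8, 16))
print("identity at init:", np.array_equal(block(x), x))
```

The hidden activations are nonzero, so the gradient with respect to `w_up` (namely `hidden(x).T @ upstream_grad`) is nonzero from the first step, and the branch can start learning. `w_down` begins receiving gradient once `w_up` has moved away from zero.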

:::tip Check Initialization Before Every Training Run

The cost is negligible: one forward pass on a random batch, printing activation statistics at every layer. This 5-second check catches bugs that would otherwise surface as confusing training curves hours later. Build measure_activation_statistics into your training script as a pre-training sanity check, not as something you run only when things go wrong.

:::

Summary: Initialization Quick Reference

A consolidated reference to answer any initialization question without recomputing the derivations:

| Scenario | Initialization | Formula | PyTorch |
|---|---|---|---|
| Linear/MLP + ReLU | Kaiming Normal | $\sigma = \sqrt{2/n_{\text{in}}}$ | `nn.init.kaiming_normal_(w, nonlinearity='relu')` |
| Linear/MLP + GELU/SiLU | Kaiming Normal | Same as ReLU (approximate) | `nn.init.kaiming_normal_(w, nonlinearity='relu')` |
| Linear/MLP + Tanh | Xavier Normal | $\sigma = \sqrt{2/(n_{\text{in}}+n_{\text{out}})}$ | `nn.init.xavier_normal_(w)` |
| Linear/MLP + Sigmoid | Xavier Normal | Same as Tanh | `nn.init.xavier_normal_(w)` |
| Linear/MLP + SELU | LeCun Normal | $\sigma = \sqrt{1/n_{\text{in}}}$ | `nn.init.normal_(w, std=(1/fan_in)**0.5)` |
| Conv2d + ReLU | Kaiming Normal, fan_out | $\sigma = \sqrt{2/(n_{\text{out}} k^2)}$ ($n_{\text{out}}$ = output channels, $k$ = kernel size) | `nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')` |
| RNN recurrent weights | Orthogonal | Singular values = 1 | `nn.init.orthogonal_(w)` |
| Transformer weights | Normal, std 0.02 | GPT-2 convention | `nn.init.normal_(w, std=0.02)` |
| Embedding | Normal, std 0.02 | Same as GPT-2 | `nn.init.normal_(w, std=0.02)` |
| LayerNorm gamma | Ones | $\gamma = 1$ | `nn.init.ones_(w)` |
| LayerNorm beta | Zeros | $\beta = 0$ | `nn.init.zeros_(b)` |
| BatchNorm gamma | Ones | $\gamma = 1$ | `nn.init.ones_(w)` |
| BatchNorm beta | Zeros | $\beta = 0$ | `nn.init.zeros_(b)` |
| Output bias (imbalanced, positive rate $p$) | Log-odds | $b = \log(p/(1-p))$ | `nn.init.constant_(layer.bias, math.log(p / (1 - p)))` |
| LSTM forget gate bias | 1.0 | Biases toward memory | `with torch.no_grad(): lstm.bias_ih_l0[H:2*H].fill_(1.0)` (gates ordered i, f, g, o; `H` = hidden size) |
| Residual projection | Zero (last layer) | $W = 0$ | `nn.init.zeros_(w)` |
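Two rows of the table involve arithmetic that is easy to get wrong, so here is a small pure-Python check. The positive rate `p = 0.01` and the 64-channel 3×3 convolution are example values only.

```python
import math

def prior_bias(p: float) -> float:
    """Output bias so that sigmoid(bias) equals the positive-class rate p at initialization."""
    return math.log(p / (1 - p))

def sigmoid(b: float) -> float:
    return 1.0 / (1.0 + math.exp(-b))

def kaiming_fan_out_std(out_channels: int, kernel_size: int) -> float:
    """Kaiming std for a Conv2d followed by ReLU with mode='fan_out': fan_out = out_channels * k * k."""
    return math.sqrt(2.0 / (out_channels * kernel_size ** 2))

b = prior_bias(0.01)
print(f"bias = {b:.3f}, initial predicted rate = {sigmoid(b):.3f}")
print(f"conv std = {kaiming_fan_out_std(64, 3):.4f}")
```

A 1% positive rate gives a starting bias of about -4.595, so the untrained model already predicts the base rate instead of 50%.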

The most common initialization mistakes in production:

1. Using nn.init.xavier_normal_ with ReLU (underestimates the needed variance by 2× for square layers)
2. Using nn.init.kaiming_normal_ (ReLU gain) with tanh (overestimates variance by ~2× relative to Xavier)
3. Forgetting to zero-initialize (or scale down) the final projection of residual branches (activation variance then grows in proportion to depth)
4. Using a standard Gaussian ($\sigma = 1$) for transformer embeddings (50× the 0.02 convention in standard deviation, which can saturate attention scores at initialization)
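The factor in mistake 1 comes from a simple ratio. Kaiming's $2/n_{\text{in}}$ divided by Xavier's $2/(n_{\text{in}}+n_{\text{out}})$ is $(n_{\text{in}}+n_{\text{out}})/n_{\text{in}}$. That equals 2 only for square layers. The sketch below uses example layer sizes to show how far off Xavier can be for rectangular ReLU layers such as a transformer FFN.

```python
def kaiming_over_xavier(n_in: int, n_out: int) -> float:
    """Ratio of the Kaiming (ReLU) weight variance to the Xavier weight variance."""
    kaiming_var = 2.0 / n_in
    xavier_var = 2.0 / (n_in + n_out)
    return kaiming_var / xavier_var      # simplifies to (n_in + n_out) / n_in

for n_in, n_out in [(256, 256), (768, 3072), (3072, 768)]:
    ratio = kaiming_over_xavier(n_in, n_out)
    print(f"{n_in:>4} -> {n_out:<4}: Xavier variance is {ratio:.2f}x smaller than Kaiming's")
```

An expanding layer (768 → 3072) is hit hardest, with a 5× shortfall. The contracting layer that follows it (3072 → 768) is only 1.25× short.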


© 2026 EngineersOfAI. All rights reserved.