Chain Rule and Backpropagation - How Neural Networks Learn
Reading time: ~28 minutes | Level: Mathematical Foundations → ML Engineering
You type loss.backward(). One line. Half a second later, every one of the 7 billion parameters in your model has a gradient. How?
The answer is backpropagation. And backpropagation is exactly one mathematical rule applied cleverly: the chain rule of calculus. Nothing more. The "mystery" of how gradients flow backward through a deep network is simply the chain rule, computed in the right order, reusing intermediate results.
This lesson builds backpropagation from first principles. By the end, loss.backward() will not be magic - it will be obvious.
What You Will Learn
- The chain rule for single-variable and multi-variable functions
- What a computational graph is and how to read one
- Forward pass: computing outputs and storing intermediate values
- Backward pass: computing gradients via chain rule in reverse
- Manual backprop through a 2-layer fully connected network
- How PyTorch autograd builds and traverses the computational graph
- Common gradient flow issues: vanishing and exploding gradients
Prerequisites
- Lesson 01: Derivatives and Gradients (required)
- Python and NumPy
- Basic neural network concepts (layers, activation functions, loss)
Part 1 - The Chain Rule
Single-variable chain rule
If y = f(g(x)), then:

$$
\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx}
$$
Read it as: "the rate at which y changes with respect to x equals the rate at which y changes with respect to g, times the rate at which g changes with respect to x."
Example: y = sin(x²)
Let g = x², y = sin(g). Then:
- dg/dx = 2x
- dy/dg = cos(g) = cos(x²)
- dy/dx = cos(x²) · 2x
```python
import numpy as np

# Verify with numerical derivative
def y_fn(x):
    return np.sin(x**2)

def dy_dx_analytical(x):
    return np.cos(x**2) * 2 * x

def numerical_derivative(f, x, h=1e-6):
    return (f(x + h) - f(x - h)) / (2 * h)

x = 1.5
print(f"Analytical: {dy_dx_analytical(x):.8f}")
print(f"Numerical: {numerical_derivative(y_fn, x):.8f}")
# Both approximately -1.88452, confirming the chain rule
```
Multi-variable chain rule
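When a variable affects the output through several intermediate variables, the chain rule sums the contributions from all paths.
If z = f(x, y) where x = g(t) and y = h(t):

$$
\frac{dz}{dt} = \frac{\partial z}{\partial x}\,\frac{dx}{dt} + \frac{\partial z}{\partial y}\,\frac{dy}{dt}
$$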
This is critical for neural networks where one variable feeds into multiple downstream computations.
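The following is a quick numerical check of the two-path formula. The example is an arbitrary choice made for illustration (z = x·y with x = t² and y = sin t). Both paths from t to z contribute to the derivative, and dropping either term gives the wrong answer.

```python
import numpy as np

# Illustrative choice (not from the lesson): z = x * y, x = t**2, y = sin(t)
def z_of_t(t):
    x, y = t**2, np.sin(t)
    return x * y

def dz_dt_chain(t):
    x, y = t**2, np.sin(t)
    dz_dx, dz_dy = y, x               # partial derivatives of z = x * y
    dx_dt, dy_dt = 2 * t, np.cos(t)
    # Sum over both paths: t -> x -> z and t -> y -> z
    return dz_dx * dx_dt + dz_dy * dy_dt

def numerical_derivative(f, t, h=1e-6):
    return (f(t + h) - f(t - h)) / (2 * h)

t = 0.7
analytical = dz_dt_chain(t)
numerical = numerical_derivative(z_of_t, t)
one_path_only = np.sin(t) * 2 * t     # keeps only the t -> x -> z path
print(f"Chain rule (both paths): {analytical:.6f}")
print(f"Numerical:               {numerical:.6f}")
print(f"One path only:           {one_path_only:.6f}")
```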
Chain rule with multiple composed functions
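For a deep chain y = f₄(f₃(f₂(f₁(x)))) with intermediate values h₁ = f₁(x), h₂ = f₂(h₁), h₃ = f₃(h₂):

$$
\frac{dy}{dx} = \frac{\partial f_4}{\partial h_3}\cdot\frac{\partial f_3}{\partial h_2}\cdot\frac{\partial f_2}{\partial h_1}\cdot\frac{\partial f_1}{\partial x}
$$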
This is backpropagation for a 4-layer network with no branching. The gradient flows backward through each layer, multiplying local derivatives.
Forward: x → [f₁] → h₁ → [f₂] → h₂ → [f₃] → h₃ → [f₄] → y
Backward: δx ← [∂f₁/∂x]ᵀ ← δh₁ ← [∂f₂/∂h₁]ᵀ ← ... ← δy=1
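The backward arrow diagram can be sketched with a small vector chain in which each backward step multiplies by a transposed Jacobian. The four functions are arbitrary choices for illustration: a matrix A, an elementwise tanh, a matrix B, and a sum of squares. Note that the tanh step reuses the output saved during the forward pass.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))   # f1: h1 = A @ x   (Jacobian is A)
B = rng.standard_normal((2, 4))   # f3: h3 = B @ h2  (Jacobian is B)

def forward(x):
    h1 = A @ x                    # f1
    h2 = np.tanh(h1)              # f2 (elementwise)
    h3 = B @ h2                   # f3
    y = np.sum(h3 ** 2)           # f4 (scalar output)
    return y, (h1, h2, h3)

def backward(cache):
    h1, h2, h3 = cache
    delta_y = 1.0                         # dy/dy
    delta_h3 = delta_y * 2 * h3           # [∂f4/∂h3]ᵀ δy
    delta_h2 = B.T @ delta_h3             # [∂f3/∂h2]ᵀ δh3 = Bᵀ δh3
    delta_h1 = delta_h2 * (1 - h2 ** 2)   # tanh' = 1 - tanh², reusing saved h2
    delta_x = A.T @ delta_h1              # [∂f1/∂x]ᵀ δh1 = Aᵀ δh1
    return delta_x

x = rng.standard_normal(3)
y, cache = forward(x)
grad_x = backward(cache)

# Finite-difference check, one input coordinate at a time
eps = 1e-6
numeric = np.array([
    (forward(x + eps * e)[0] - forward(x - eps * e)[0]) / (2 * eps)
    for e in np.eye(3)
])
print("backprop :", grad_x)
print("numerical:", numeric)
```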
Part 2 - Computational Graphs
A computational graph is a directed acyclic graph (DAG) where:
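- Nodes are variables (input data, parameters, intermediate activations, output/loss) and the operations that combine them (addition, multiplication, activation functions)
- Edges connect each operation to its inputs and its output, showing how values flow through the computation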
Every computation in a neural network can be expressed as a computational graph.
Simple example: y = (x · w + b)²
```text
 x     w
  \   /
   [×]        ← multiplication: z₁ = x·w
    |
    z₁    b
     \   /
      [+]     ← addition: z₂ = z₁ + b = xw + b
       |
       z₂
       |
      [²]     ← square: y = z₂² = (xw + b)²
       |
       y
```
Forward pass (left to right):
- z₁ = x · w
- z₂ = z₁ + b
- y = z₂²
Backward pass (right to left, chain rule):
- dy/dz₂ = 2·z₂
- dz₂/dz₁ = 1, dz₂/db = 1
- dz₁/dx = w, dz₁/dw = x
Chain rule gives:
- dy/db = (dy/dz₂) · (dz₂/db) = 2z₂ · 1 = 2(xw+b)
- dy/dw = (dy/dz₂) · (dz₂/dz₁) · (dz₁/dw) = 2z₂ · 1 · x = 2(xw+b)·x
- dy/dx = (dy/dz₂) · (dz₂/dz₁) · (dz₁/dx) = 2z₂ · 1 · w = 2(xw+b)·w
```python
import numpy as np

# Manual forward and backward for y = (x*w + b)^2
x = 2.0
w = 3.0
b = 1.0

# ── Forward pass ──────────────────────────────────────────────────────────
z1 = x * w      # 6.0
z2 = z1 + b     # 7.0
y = z2 ** 2     # 49.0
print(f"Forward pass: z1={z1}, z2={z2}, y={y}")

# ── Backward pass (chain rule, right to left) ───────────────────────────
# dy/dz2: derivative of z2^2 is 2*z2
dy_dz2 = 2 * z2   # 14.0
# dz2/dz1 = 1 (linear), dz2/db = 1 (linear)
dz2_dz1 = 1.0
dz2_db = 1.0
# dz1/dw = x, dz1/dx = w
dz1_dw = x   # 2.0
dz1_dx = w   # 3.0

# Chain rule: multiply going back through the graph
dy_db = dy_dz2 * dz2_db              # 14.0 * 1 = 14.0
dy_dw = dy_dz2 * dz2_dz1 * dz1_dw    # 14.0 * 1 * 2.0 = 28.0
dy_dx = dy_dz2 * dz2_dz1 * dz1_dx    # 14.0 * 1 * 3.0 = 42.0
print(f"\nBackward pass:")
print(f"dy/db = {dy_db}")   # 14.0
print(f"dy/dw = {dy_dw}")   # 28.0
print(f"dy/dx = {dy_dx}")   # 42.0

# ── Verify with numerical gradients ───────────────────────────────────────
h = 1e-6
def f_fn(x, w, b): return (x*w + b) ** 2
print(f"\nNumerical verification:")
print(f"dy/db ≈ {(f_fn(x, w, b+h) - f_fn(x, w, b-h))/(2*h):.4f}")   # 14.0
print(f"dy/dw ≈ {(f_fn(x, w+h, b) - f_fn(x, w-h, b))/(2*h):.4f}")   # 28.0
print(f"dy/dx ≈ {(f_fn(x+h, w, b) - f_fn(x-h, w, b))/(2*h):.4f}")   # 42.0
```
Why computational graphs are powerful
The key insight: we can compute the gradient with respect to any input or parameter by simply tracing the graph backward from the output, multiplying local derivatives.
"Local derivative": At each node, we only need the derivative of that node's output with respect to its own inputs. We do not need to know what came before.
Reusing intermediate results: During the forward pass, we save intermediate activations (z₁, z₂, etc.). During the backward pass, we reuse them to compute local derivatives efficiently. This is the core efficiency of backpropagation.
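Both ideas can be sketched as a toy scalar autograd in plain Python. The `Value` class below is hypothetical; it is not part of any library. Each operation records its inputs, plus a local derivative computed from forward-pass values. `backward()` then walks the graph in reverse topological order, multiplying by local derivatives and summing over paths. Run on the example above, it should reproduce dy/db = 14, dy/dw = 28 and dy/dx = 42.

```python
class Value:
    """A scalar node in a computational graph (toy example, not a library API)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # inputs of the operation that produced this node
        self._local_grads = local_grads  # d(self)/d(parent), computed during the forward pass

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def square(self):
        return Value(self.data ** 2, (self,), (2 * self.data,))

    def backward(self):
        # Topological order: every node appears after all of its inputs
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # dy/dy
        for node in reversed(order):
            for parent, local in zip(node._parents, node._local_grads):
                # Chain rule: upstream gradient times local derivative,
                # accumulated with += so contributions from multiple paths add up
                parent.grad += node.grad * local


x, w, b = Value(2.0), Value(3.0), Value(1.0)
y = (x * w + b).square()
y.backward()
print(y.data, b.grad, w.grad, x.grad)  # 49.0 14.0 28.0 42.0
```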
Part 3 - Backpropagation Through a 2-Layer Network
Let us manually derive backpropagation for a 2-layer network. This is the exact computation PyTorch does automatically.
Network architecture
Input x (d,) → Linear₁ → ReLU → Linear₂ → Scalar output → MSE loss
- Layer 1: h = ReLU(W₁x + b₁), where W₁ is (H, d) and b₁ is (H,)
- Layer 2: ŷ = W₂h + b₂, where W₂ is (1, H) and b₂ is a scalar
- Loss: L = (ŷ - y)²
Forward pass
x → z₁ = W₁x + b₁ → h = ReLU(z₁) → ŷ = W₂h + b₂ → L = (ŷ - y)²
Variables to save for backward: x, z₁, h, ŷ (we need them to compute local derivatives)
Backward pass derivation
We want: ∂L/∂W₁, ∂L/∂b₁, ∂L/∂W₂, ∂L/∂b₂
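Step 1: Gradient of loss with respect to ŷ:

$$
\delta_{\hat y} = \frac{\partial L}{\partial \hat y} = 2(\hat y - y)
$$

Step 2: Gradient through Layer 2 (linear):

$$
\frac{\partial L}{\partial W_2} = \delta_{\hat y}\, h^\top, \qquad
\frac{\partial L}{\partial b_2} = \delta_{\hat y}, \qquad
\delta_h = \frac{\partial L}{\partial h} = W_2^\top \delta_{\hat y}
$$

Step 3: Gradient through ReLU:

$$
\delta_{z_1} = \frac{\partial L}{\partial z_1} = \delta_h \odot \mathbb{1}[z_1 > 0]
$$

(Element-wise: pass the gradient through where z₁ > 0, zero it out where z₁ ≤ 0)

Step 4: Gradient through Layer 1 (linear):

$$
\frac{\partial L}{\partial W_1} = \delta_{z_1}\, x^\top, \qquad
\frac{\partial L}{\partial b_1} = \delta_{z_1}
$$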
Full implementation from scratch
```python
import numpy as np

class TwoLayerNet:
    """2-layer neural network with manual forward and backward pass."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int = 1):
        # He initialization for ReLU
        self.W1 = np.random.randn(hidden_dim, input_dim) * np.sqrt(2.0 / input_dim)
        self.b1 = np.zeros(hidden_dim)
        self.W2 = np.random.randn(output_dim, hidden_dim) * np.sqrt(2.0 / hidden_dim)
        self.b2 = np.zeros(output_dim)
        self.cache = {}

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Forward pass: compute output and save activations for backward."""
        z1 = self.W1 @ x + self.b1      # (hidden_dim,)
        h = np.maximum(0, z1)           # ReLU
        z2 = self.W2 @ h + self.b2      # (output_dim,)
        # Save for backward
        self.cache = {'x': x, 'z1': z1, 'h': h}
        return z2

    def backward(self, y_hat: np.ndarray, y: np.ndarray) -> dict:
        """Backward pass: compute gradients of MSE loss w.r.t. all parameters."""
        x = self.cache['x']
        z1 = self.cache['z1']
        h = self.cache['h']
        # Step 1: Gradient of MSE loss
        # L = (y_hat - y)^2, so ∂L/∂y_hat = 2*(y_hat - y)
        delta_yhat = 2 * (y_hat - y)              # (output_dim,)
        # Step 2: Backward through Layer 2
        dW2 = np.outer(delta_yhat, h)             # (output_dim, hidden_dim)
        db2 = delta_yhat                          # (output_dim,)
        delta_h = self.W2.T @ delta_yhat          # (hidden_dim,)
        # Step 3: Backward through ReLU
        relu_mask = (z1 > 0).astype(float)
        delta_z1 = delta_h * relu_mask            # (hidden_dim,)
        # Step 4: Backward through Layer 1
        dW1 = np.outer(delta_z1, x)               # (hidden_dim, input_dim)
        db1 = delta_z1                            # (hidden_dim,)
        return {'dW1': dW1, 'db1': db1, 'dW2': dW2, 'db2': db2}

    def mse_loss(self, y_hat: np.ndarray, y: np.ndarray) -> float:
        return float(np.mean((y_hat - y) ** 2))

    def update(self, grads: dict, lr: float = 0.01):
        self.W1 -= lr * grads['dW1']
        self.b1 -= lr * grads['db1']
        self.W2 -= lr * grads['dW2']
        self.b2 -= lr * grads['db2']

# Gradient check: verify manual backprop matches numerical gradients
np.random.seed(42)
net = TwoLayerNet(input_dim=4, hidden_dim=8)
x = np.random.randn(4)
y = np.array([1.0])
y_hat = net.forward(x)
loss = net.mse_loss(y_hat, y)
grads = net.backward(y_hat, y)

# Check W1[0,0]
h_eps = 1e-5
original = net.W1[0, 0]
net.W1[0, 0] = original + h_eps
loss_plus = net.mse_loss(net.forward(x), y)
net.W1[0, 0] = original - h_eps
loss_minus = net.mse_loss(net.forward(x), y)
net.W1[0, 0] = original
numerical = (loss_plus - loss_minus) / (2 * h_eps)
analytical = grads['dW1'][0, 0]
print(f"Gradient check for W1[0,0]:")
print(f" Analytical: {analytical:.8f}")
print(f" Numerical: {numerical:.8f}")
rel_err = abs(analytical - numerical) / (abs(analytical) + abs(numerical) + 1e-8)
print(f" Relative error: {rel_err:.2e}")  # should be < 1e-4
```
Part 4 - Backpropagation in Matrix Form (Batched)
Real training uses batches. For a batch of size B, input X is (B, d), targets Y are (B, 1):
Batched forward pass
```text
Z1 = X @ W1.T + b1      → (B, H)
H  = ReLU(Z1)           → (B, H)
Z2 = H @ W2.T + b2      → (B, 1)
L  = mean((Z2 - Y)^2)   → scalar
```
Batched backward pass
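In equation form, with one example per row of X, H, Z₁ and Z₂, the batched backward pass in the code below reads as follows. This is a summary derived from that code, not an independent result. The 2/B factor comes from averaging the squared error over the batch, and the bias gradients sum the per-row deltas over the batch dimension ((·)ᵢ denotes row i):

$$
\begin{aligned}
\Delta_{Z_2} &= \frac{2}{B}\,(Z_2 - Y) \\
\frac{\partial L}{\partial W_2} &= \Delta_{Z_2}^\top H, \qquad \frac{\partial L}{\partial b_2} = \textstyle\sum_{i} (\Delta_{Z_2})_{i} \\
\Delta_{H} &= \Delta_{Z_2}\, W_2 \\
\Delta_{Z_1} &= \Delta_{H} \odot \mathbb{1}[Z_1 > 0] \\
\frac{\partial L}{\partial W_1} &= \Delta_{Z_1}^\top X, \qquad \frac{\partial L}{\partial b_1} = \textstyle\sum_{i} (\Delta_{Z_1})_{i}
\end{aligned}
$$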
```python
import numpy as np

def forward_batch(X, W1, b1, W2, b2):
    """Batched forward pass. X shape: (B, d)"""
    Z1 = X @ W1.T + b1          # (B, H)
    H = np.maximum(0, Z1)       # (B, H)
    Z2 = H @ W2.T + b2          # (B, 1)
    return Z2, H, Z1

def backward_batch(X, Z1, H, Z2, Y, W2):
    """Batched backward pass."""
    B = X.shape[0]
    delta_Z2 = (2 / B) * (Z2 - Y)       # (B, 1)
    dW2 = delta_Z2.T @ H                # (1, H)
    db2 = delta_Z2.sum(axis=0)          # (1,)
    delta_H = delta_Z2 @ W2             # (B, H)
    delta_Z1 = delta_H * (Z1 > 0)       # (B, H)
    dW1 = delta_Z1.T @ X                # (H, d)
    db1 = delta_Z1.sum(axis=0)          # (H,)
    return {'dW1': dW1, 'db1': db1, 'dW2': dW2, 'db2': db2}

# Test batched backprop
np.random.seed(42)
B, d, H_dim = 32, 4, 16
W1 = np.random.randn(H_dim, d) * 0.1
b1 = np.zeros(H_dim)
W2 = np.random.randn(1, H_dim) * 0.1
b2 = np.zeros(1)
X_batch = np.random.randn(B, d)
Y_batch = np.random.randn(B, 1)

Z2, H_out, Z1 = forward_batch(X_batch, W1, b1, W2, b2)
loss = np.mean((Z2 - Y_batch)**2)
grads = backward_batch(X_batch, Z1, H_out, Z2, Y_batch, W2)
print(f"Batch size: {B}, Loss: {loss:.4f}")
print(f"dW1 shape: {grads['dW1'].shape}")  # (H, d) = (16, 4)
print(f"dW2 shape: {grads['dW2'].shape}")  # (1, H) = (1, 16)
```
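Shape checks alone do not prove that the batched gradients are right. A finite-difference check over every parameter entry does, at least for this toy size. The sketch below re-declares the two functions so that it runs on its own. The random data, sizes and weight scale are arbitrary.

```python
import numpy as np

# Re-declared so this check is self-contained; same math as the batched code above
def forward_batch(X, W1, b1, W2, b2):
    Z1 = X @ W1.T + b1
    H = np.maximum(0, Z1)
    Z2 = H @ W2.T + b2
    return Z2, H, Z1

def backward_batch(X, Z1, H, Z2, Y, W2):
    B = X.shape[0]
    delta_Z2 = (2 / B) * (Z2 - Y)
    delta_Z1 = (delta_Z2 @ W2) * (Z1 > 0)
    return {'dW1': delta_Z1.T @ X, 'db1': delta_Z1.sum(axis=0),
            'dW2': delta_Z2.T @ H, 'db2': delta_Z2.sum(axis=0)}

def loss_fn(params, X, Y):
    Z2, _, _ = forward_batch(X, *params)
    return np.mean((Z2 - Y) ** 2)

rng = np.random.default_rng(0)
B, d, H_dim = 32, 4, 16
params = [rng.standard_normal((H_dim, d)) * 0.5, np.zeros(H_dim),
          rng.standard_normal((1, H_dim)) * 0.5, np.zeros(1)]
X, Y = rng.standard_normal((B, d)), rng.standard_normal((B, 1))

Z2, H, Z1 = forward_batch(X, *params)
grads = backward_batch(X, Z1, H, Z2, Y, params[2])

eps = 1e-6
max_err = 0.0
for p, name in zip(params, ['dW1', 'db1', 'dW2', 'db2']):
    numeric = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        orig = p[idx]
        p[idx] = orig + eps               # nudge one entry up
        plus = loss_fn(params, X, Y)
        p[idx] = orig - eps               # nudge it down
        minus = loss_fn(params, X, Y)
        p[idx] = orig                     # restore
        numeric[idx] = (plus - minus) / (2 * eps)
    max_err = max(max_err, np.max(np.abs(numeric - grads[name])))
print(f"max |analytical - numerical| = {max_err:.2e}")
```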
Part 5 - PyTorch Autograd: The Chain Rule Automated
PyTorch's autograd system builds the computational graph dynamically as you execute operations.
How autograd works
When you perform operations on tensors with requires_grad=True, PyTorch:
- Executes the operation (forward pass)
- Records the operation and its inputs into a node in the computation graph
- Each node stores its grad_fn - a function that implements the backward pass for that operation
When you call .backward(), PyTorch traverses the graph in reverse topological order, calling each grad_fn and accumulating gradients.
```python
import torch

# Visualizing the computational graph
x = torch.tensor([2.0, 3.0], requires_grad=True)
W = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([0.5, -0.5], requires_grad=True)

# Forward pass - PyTorch builds the graph
z = W @ x + b        # z = [8.5, 17.5]
h = torch.relu(z)
loss = h.sum()

print(f"x.grad_fn: {x.grad_fn}")        # None (leaf tensor)
print(f"z.grad_fn: {z.grad_fn}")        # <AddBackward0 ...> (W @ x is recorded as its own node, then + b)
print(f"h.grad_fn: {h.grad_fn}")        # <ReluBackward0 ...>
print(f"loss.grad_fn: {loss.grad_fn}")  # <SumBackward0 ...>

# Backward pass: traverses graph in reverse, accumulates gradients
loss.backward()
print(f"\nAfter backward():")
print(f"x.grad: {x.grad}")  # dL/dx = Wᵀ·1[z>0] = tensor([4., 6.])
print(f"W.grad: {W.grad}")  # dL/dW = outer(1[z>0], x) = [[2., 3.], [2., 3.]]
print(f"b.grad: {b.grad}")  # dL/db = 1[z>0] = tensor([1., 1.])
```
Implementing a custom autograd function
For non-standard operations, implement both forward and backward manually:
```python
import torch

class SigmoidCustom(torch.autograd.Function):
    """Custom sigmoid with manual forward and backward."""

    @staticmethod
    def forward(ctx, x):
        """
        Compute output and save tensors needed for backward.
        ctx (context object) stores intermediate values.
        """
        output = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(output)  # save sigmoid output for backward
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradient with respect to input.
        grad_output: dL/d(output) - upstream gradient from next layer
        Returns: dL/dx = dL/d(output) * d(output)/dx
        """
        output, = ctx.saved_tensors
        # Sigmoid derivative: sigma * (1 - sigma)
        sigmoid_grad = output * (1 - output)
        # Chain rule: multiply upstream gradient by local derivative
        return grad_output * sigmoid_grad

# Use as a regular function
sigmoid_custom = SigmoidCustom.apply
x = torch.randn(5, requires_grad=True)
y = sigmoid_custom(x)
y.sum().backward()

# Verify against PyTorch built-in
x2 = x.detach().clone().requires_grad_(True)
torch.sigmoid(x2).sum().backward()
print("Custom gradient:", x.grad.numpy().round(6))
print("PyTorch gradient:", x2.grad.numpy().round(6))
print("Match:", torch.allclose(x.grad, x2.grad))
```
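PyTorch also provides `torch.autograd.gradcheck`, which compares a function's analytical backward against finite differences and expects double-precision inputs. The sketch below re-declares `SigmoidCustom` so that it runs on its own. It also adds a deliberately broken variant, `BrokenSigmoid` (hypothetical, for illustration), which forgets to multiply by `grad_output`, so you can see what a failed check looks like.

```python
import torch

class SigmoidCustom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return grad_output * output * (1 - output)   # chain rule applied correctly

class BrokenSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        output = 1.0 / (1.0 + torch.exp(-x))
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return output * (1 - output)   # BUG: ignores the upstream gradient

x = torch.randn(6, dtype=torch.float64, requires_grad=True)
ok_custom = torch.autograd.gradcheck(SigmoidCustom.apply, (x,), eps=1e-6, atol=1e-4)
ok_broken = torch.autograd.gradcheck(BrokenSigmoid.apply, (x,), eps=1e-6, atol=1e-4,
                                     raise_exception=False)
print("SigmoidCustom passes:", ok_custom)
print("BrokenSigmoid passes:", ok_broken)
```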
Gradient checkpointing
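For large models, storing every intermediate activation for the backward pass can exhaust GPU memory. Gradient checkpointing trades compute for memory: it keeps only some activations and recomputes the others during the backward pass: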
```python
import torch
import torch.utils.checkpoint as checkpoint

class LargeBlock(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(),
            torch.nn.Linear(dim * 4, dim),
        )

    def forward(self, x):
        return self.layers(x)

model = LargeBlock(512)
x = torch.randn(32, 512, requires_grad=True)
# With checkpointing, the block's internal activations are not stored;
# its forward pass is re-run during backward to recover them.
# The memory saved and the extra compute (roughly one more forward pass
# of the checkpointed block) depend on the model.
y = checkpoint.checkpoint(model, x, use_reentrant=False)
y.sum().backward()
```
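Checkpointing should change memory use but not the math. One way to confirm this is to compare gradients computed with and without it. The following is a minimal sketch with a small block; the sizes are arbitrary.

```python
import torch
import torch.utils.checkpoint as checkpoint

torch.manual_seed(0)
block = torch.nn.Sequential(
    torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64)
)
x = torch.randn(8, 64, requires_grad=True)

def grads_of(use_checkpoint):
    block.zero_grad(set_to_none=True)
    x.grad = None
    if use_checkpoint:
        # Inner activations are discarded and recomputed during backward
        out = checkpoint.checkpoint(block, x, use_reentrant=False)
    else:
        out = block(x)  # inner activations are stored for backward
    out.sum().backward()
    return [p.grad.clone() for p in block.parameters()] + [x.grad.clone()]

plain = grads_of(use_checkpoint=False)
ckpt = grads_of(use_checkpoint=True)
same = all(torch.allclose(a, b) for a, b in zip(plain, ckpt))
print("Gradients match:", same)
```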
Part 6 - Gradient Flow Issues
Vanishing gradients
In deep networks with sigmoid/tanh activations, gradients can shrink to near-zero:
Why: Sigmoid derivative σ(x)(1-σ(x)) ≤ 0.25. For a 10-layer network with sigmoid:
- Each layer multiplies gradient by its derivative (≤ 0.25)
- After 10 layers: 0.25^10 ≈ 9.5 × 10⁻⁷ - effectively zero
Early layers receive essentially no gradient signal.
```python
import numpy as np

np.random.seed(0)  # fixed seed so the printed table is reproducible

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1 - s)

def relu_prime(x):
    return float(x > 0)

n_layers = 15
print("Gradient after each layer:")
print(f"{'Layer':>6} | {'sigmoid':>12} | {'ReLU (p=0.5)':>14}")
print("-" * 40)
grad_sigmoid = 1.0
grad_relu = 1.0
for layer in range(n_layers):
    x = np.random.randn()
    grad_sigmoid *= sigmoid_prime(x)
    # ReLU: derivative is 0 or 1, roughly 50% active. Along this single scalar
    # path the product is either exactly 1 or exactly 0 (blocked); a real layer
    # has many units, so the gradient still flows through the active ones.
    grad_relu *= relu_prime(x)
    if layer < 5 or layer == n_layers - 1:
        print(f"{layer+1:>6} | {grad_sigmoid:>12.2e} | {grad_relu:>14.2e}")
print()
print("After 15 sigmoid layers, gradient is nearly zero - vanishing gradient problem.")
```
Solutions:
- ReLU activations: derivative is 1 for positive inputs
- Residual connections (ResNet): direct gradient path through skip connections
- Batch normalization: keeps activations in a well-scaled range
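The following rough scalar sketch shows why skip connections help. It assumes each residual block computes x + f(x), so its local derivative is 1 + f'(x) rather than f'(x) alone. For illustration only, the branch is taken to be f(x) = w·sigmoid(x) with a random, hypothetical weight w. The plain stack multiplies small branch derivatives together. The residual stack multiplies factors close to 1.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1 - s)

n_layers = 15
grad_plain = 1.0
grad_residual = 1.0
for _ in range(n_layers):
    x = rng.standard_normal()
    w = 0.5 * rng.standard_normal()   # hypothetical branch weight
    branch = w * sigmoid_prime(x)     # derivative of the branch f(x) = w * sigmoid(x)
    grad_plain *= branch              # plain stack:    y = f(x)
    grad_residual *= 1.0 + branch     # residual stack: y = x + f(x)
print(f"plain stack:    {grad_plain:.2e}")
print(f"residual stack: {grad_residual:.2e}")
```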
Exploding gradients
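```python
import numpy as np

np.random.seed(0)
# Exploding: per-layer factors larger than 1 multiply as the gradient flows backward
n_layers = 10
gradient = 1.0
print("Gradient explosion with large weights:")
for layer in range(n_layers):
    # |N(0,1)| has a geometric mean of about 0.53, so 3.0 * |N(0,1)| scales the
    # gradient by about 1.6x per layer on average (in log space)
    w_scale = abs(np.random.randn()) * 3.0
    gradient *= w_scale
    print(f"  Layer {layer+1}: gradient = {gradient:.2e}")
```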
Solution: Gradient clipping (covered in depth in Lesson 08)
```python
import torch

def clip_gradients(model: torch.nn.Module, max_norm: float = 1.0):
    """Clip gradient norm to prevent exploding gradients."""
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    return total_norm.item()

# In training loop:
# optimizer.zero_grad()
# loss.backward()
# grad_norm = clip_gradients(model, max_norm=1.0)
# optimizer.step()
```
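The idea behind `clip_grad_norm_` can be sketched in NumPy. Compute one L2 norm over all gradients together. If that norm exceeds `max_norm`, rescale every gradient by the same factor, which preserves the overall update direction. This is a simplified sketch of the idea with a small epsilon in the denominator; the `clip_by_global_norm` helper is hypothetical and not PyTorch's implementation.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient arrays so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scale up
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([[12.0]])]  # combined norm = sqrt(9 + 16 + 144) = 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(f"norm before: {norm_before:.1f}, after: {norm_after:.4f}")  # 13.0, 1.0000
```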
:::tip ReLU greatly reduces vanishing gradients
ReLU has derivative exactly 1 for positive inputs and 0 for negative inputs. In a well-initialized network, roughly half the neurons in each layer are active (positive pre-activation), and the gradient passes through those units unscaled. With He initialization, the factor of 2 in the weight variance compensates for the inactive half, so the gradient magnitude is roughly preserved from layer to layer. With sigmoid, it shrinks by a factor of up to 0.25 at every layer. This is the primary reason ReLU replaced sigmoid as the default activation function in deep networks.
:::
Part 7 - The Vector-Jacobian Product (VJP)
PyTorch's key efficiency insight: for a scalar loss L and network output y = f(x), we never compute the full Jacobian J = ∂y/∂x. We compute the VJP: vᵀJ where v = ∂L/∂y is the upstream gradient.
```python
import torch

# VJP illustration: non-scalar output requires explicit upstream gradient
x = torch.randn(3, requires_grad=True)
y = x ** 2  # vector output, not scalar

# Must supply the upstream gradient v (the vector in the VJP)
upstream_grad = torch.ones(3)   # equivalent to y.sum().backward()
y.backward(upstream_grad)       # computes vᵀJ where v = upstream_grad
print(f"x.grad (VJP): {x.grad}")    # 2x * 1 = 2x
print(f"Expected: {2*x.detach()}")  # matches

# Why not compute full Jacobians?
# For 1M parameters and 1 scalar loss, ∂L/∂params has 1M entries: that is the gradient.
# But the Jacobian of 1M intermediate values w.r.t. 1M parameters would have
# 1M x 1M = 10^12 entries. Reverse mode never forms it: it propagates the vector
# v = ∂L/∂(intermediate) backward, which is the same size as the intermediate tensor.
```
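The same point can be shown in NumPy; the sizes here are arbitrary. For an elementwise op y = x², the full Jacobian is an n×n diagonal matrix, yet the VJP is just v·2x, which takes O(n) memory and compute. For a linear map y = Wx, the VJP is Wᵀv, which always has the shape of x.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)
v = rng.standard_normal(n)          # upstream gradient dL/dy

# Elementwise op y = x**2: the full Jacobian is an n x n diagonal matrix
J = np.diag(2 * x)
vjp_full = v @ J                    # O(n^2) memory and compute
vjp_direct = v * 2 * x              # O(n): what reverse mode actually computes

# Linear map y = W @ x: the VJP is W.T @ u, the same shape as x
W = rng.standard_normal((3, n))
u = rng.standard_normal(3)          # upstream gradient for the 3 outputs
vjp_linear = W.T @ u
print(np.allclose(vjp_full, vjp_direct), vjp_linear.shape)  # True (5,)
```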
Part 8 - Common Backpropagation Bugs
:::danger In-place operations break autograd
In-place operations can overwrite tensor values that autograd saved for the backward pass.
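```python
import torch

# WRONG: in-place modification of a tensor that autograd saved for backward
x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)   # sigmoid's backward reuses its saved output y
y.mul_(2)              # in-place! overwrites the saved output
try:
    y.sum().backward()
except RuntimeError as err:
    print("RuntimeError:", err)  # ...modified by an inplace operation
# (In-place ops on a leaf that requires grad, such as x += 1, are rejected immediately.)

# RIGHT: create a new tensor
x = torch.randn(3, requires_grad=True)
y = torch.sigmoid(x)
y = y * 2              # not in-place; the saved sigmoid output is untouched
y.sum().backward()     # works
```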
:::
:::danger Calling backward() twice without retain_graph
After the first backward pass, autograd frees the tensors the graph saved for computing gradients.
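```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x ** 2).sum()  # pow saves x for the backward pass
y.backward()        # works, then frees the saved tensors
try:
    y.backward()    # second call on the same graph
except RuntimeError as err:
    print("RuntimeError:", err)  # trying to backward through the graph a second time
# Fix: retain the graph on the first backward call (here on a freshly built graph)
y = (x ** 2).sum()
y.backward(retain_graph=True)  # graph kept
y.backward()                   # now works; gradients accumulate in x.grad
```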
:::
:::warning Gradient accumulation across batches
PyTorch accumulates gradients by default. Always call optimizer.zero_grad() before loss.backward(). If you forget, the gradient from the previous batch is added to the current one.
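```python
import torch

# Placeholder training setup (hypothetical, for illustration only)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataloader = [torch.randn(8, 4) for _ in range(3)]
def compute_loss(batch):
    return model(batch).pow(2).mean()

# WRONG: gradients accumulate
for batch in dataloader:
    loss = compute_loss(batch)
    loss.backward()  # adds to previous .grad
    optimizer.step()

# RIGHT: zero before each step
for batch in dataloader:
    optimizer.zero_grad()
    loss = compute_loss(batch)
    loss.backward()
    optimizer.step()
```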
:::
Interview Questions
Q1: Explain backpropagation in one paragraph without using the word "backpropagation."
Training a neural network requires computing the gradient of the loss with respect to every parameter. We do this by representing the computation as a directed graph from input to loss, then applying the chain rule of calculus in reverse. Starting from the loss (gradient = 1), we work backward through each node, computing the gradient at each node by multiplying the upstream gradient by the local derivative of that node's operation. Because we reuse upstream gradients (computed once and passed back), this is efficient - O(n) operations where n is the number of operations in the forward pass. PyTorch builds this graph automatically as you execute operations, and calling .backward() traverses the graph in reverse topological order.
Q2: Why does backpropagation work efficiently for scalar losses?
The key insight is that training minimizes a scalar loss. For a scalar loss L and a parameter tensor W, we need ∂L/∂W, which has the same shape as W, not a full Jacobian.
Without this: If the output were a vector of dimension m and the input were a vector of dimension n, we would need the full m×n Jacobian - O(mn) numbers. For a network with millions of parameters and millions of output dimensions, this is impossible.
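With a scalar loss: We need only the vector-Jacobian product (VJP): (∂L/∂output)ᵀ · J. The VJP always has the same shape as the input, so storing it takes O(n) numbers. Computing it costs about as much as the layer's forward pass (for example, O(mn) for a dense m×n layer), and J itself is never formed.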
Reverse mode automatic differentiation is efficient precisely because it computes VJPs in a single backward sweep. Forward mode would be efficient for few inputs and many outputs - which is not the ML case.
Q3: What is the gradient of a ReLU activation, and why does it matter for deep networks?
ReLU(x) = max(0, x). Its derivative is:
- 1 when x > 0 (gradient passes through unchanged)
- 0 when x ≤ 0 (gradient is blocked - "dead neuron")
Why it matters:
- Vanishing gradients: With sigmoid, the local derivative is σ(1-σ) ≤ 0.25. After 10 layers, the product of these activation derivatives is at most 0.25^10 ≈ 9.5 × 10⁻⁷. Early layers get almost no gradient signal and barely learn. ReLU has derivative 1 for positive inputs, so gradients passing through active units do not shrink.
- Dead neurons: If a neuron's pre-activation is always negative, its gradient is always 0, and the neuron never receives a learning signal. This is why initialization matters: neurons need to start in the active regime.
- Why deep learning became practical: Before ReLU, networks with many layers were nearly untrainable due to vanishing gradients. ReLU (mainstream from ~2012) was a key enabler of the deep learning revolution.

Variants addressing dead neurons: Leaky ReLU (a small slope for x < 0, commonly 0.01), ELU, and GELU (a smooth ReLU-like activation used in many transformers).
Q4: Walk me through the backward pass for a single linear layer y = Wx + b.
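Let W have shape (m, n) and x be n-dimensional. Given the upstream gradient δ = ∂L/∂y (shape: m, for the m-dimensional output):

Gradient with respect to W (parameter update):

$$
\frac{\partial L}{\partial W} = \delta\, x^\top \qquad (m \times n)
$$

Each entry: (∂L/∂W)[i,j] = δᵢ · xⱼ

Gradient with respect to b:

$$
\frac{\partial L}{\partial b} = \delta
$$

Gradient with respect to x (passed back to the previous layer):

$$
\frac{\partial L}{\partial x} = W^\top \delta
$$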
Intuition: W maps x (n-dim) to y (m-dim). When propagating gradients backward, we transpose this mapping - Wᵀ maps output gradient (m-dim) back to input gradient (n-dim). This is the mathematical reason why backprop involves transposed weight matrices.
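The three formulas can be checked numerically using a small trick. If the loss is defined as L = δᵀ(Wx + b) for a fixed vector δ, then ∂L/∂y is exactly δ. That lets finite differences test each formula directly. Shapes and values below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4
W = rng.standard_normal((m, n))
b = rng.standard_normal(m)
x = rng.standard_normal(n)
delta = rng.standard_normal(m)   # stands in for the upstream gradient dL/dy

# Choose L = delta · (W x + b), so that dL/dy is exactly delta
def loss(W, b, x):
    return delta @ (W @ x + b)

dW = np.outer(delta, x)          # δ xᵀ, shape (m, n)
db = delta                       # δ
dx = W.T @ delta                 # Wᵀ δ, shape (n,)

eps = 1e-6
def num_grad(f, p):
    """Central finite differences of scalar f with respect to each entry of p."""
    g = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        e = np.zeros_like(p)
        e[idx] = eps
        g[idx] = (f(p + e) - f(p - e)) / (2 * eps)
    return g

ok_W = np.allclose(dW, num_grad(lambda P: loss(P, b, x), W))
ok_b = np.allclose(db, num_grad(lambda P: loss(W, P, x), b))
ok_x = np.allclose(dx, num_grad(lambda P: loss(W, b, P), x))
print(ok_W, ok_b, ok_x)  # True True True
```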
Q5: What is the difference between forward mode and reverse mode automatic differentiation?
Both compute exact derivatives via the chain rule. They differ in traversal direction and efficiency characteristics.
Forward mode AD:
- Propagates derivatives forward alongside the primal computation
- Each forward sweep computes the derivative with respect to ONE input
- Cost: O(n_inputs) sweeps to get all partial derivatives
- Efficient when: few inputs, many outputs
Reverse mode AD (used by PyTorch):
- Requires one complete forward pass (storing all intermediates)
- One backward pass computes derivatives with respect to ALL inputs
- Cost: O(1) backward passes regardless of n_inputs
- Efficient when: few outputs (scalar loss), many inputs (millions of parameters)
In ML: Millions of parameters (inputs to loss), exactly 1 loss (output). Reverse mode is optimal - one backward pass computes all gradients simultaneously.
Memory trade-off: Reverse mode stores all intermediate activations from the forward pass. For very deep networks, this can be memory-intensive. Gradient checkpointing (recomputing activations during backward) trades extra compute for less memory - often used in training large transformers.
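Forward mode can be sketched with dual numbers. Each value carries its derivative with respect to one chosen input, so a function of n inputs needs n forward sweeps for the full gradient. The `Dual` class below is a hypothetical toy, not a library API. It reuses y = (x·w + b)² from Part 2 and needs three sweeps, whereas reverse mode gets all three partial derivatives from one backward pass.

```python
class Dual:
    """Value plus its derivative with respect to one chosen input (toy forward-mode AD)."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        return Dual(self.value + other.value, self.deriv + other.deriv)

    def __mul__(self, other):
        # Product rule, propagated forward alongside the value
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

def f(x, w, b):
    z = x * w + b
    return z * z  # (x*w + b)^2

inputs = {'x': 2.0, 'w': 3.0, 'b': 1.0}
grad = {}
for name in inputs:  # one forward sweep per input
    args = {k: Dual(v, 1.0 if k == name else 0.0) for k, v in inputs.items()}
    grad[name] = f(**args).deriv
print(grad)  # {'x': 42.0, 'w': 28.0, 'b': 14.0}
```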
Quick Reference
| Concept | Formula | Code |
|---|---|---|
| Chain rule (1 var) | dy/dx = (dy/dg)(dg/dx) | Multiply local derivatives |
| Chain rule (multiple paths) | ∂z/∂t = Σᵢ (∂z/∂xᵢ)(∂xᵢ/∂t) | Sum contributions |
| Linear layer backward (W) | ∂L/∂W = δxᵀ | np.outer(delta, x) |
| Linear layer backward (x) | ∂L/∂x = Wᵀδ | W.T @ delta |
| ReLU backward | δ · 1[z>0] | delta * (z > 0) |
| Sigmoid backward | δ · σ(1-σ) | delta * s * (1 - s) |
| PyTorch autograd | graph built dynamically | loss.backward() |
| Custom backward | VJP = vᵀ·J | torch.autograd.Function |
| Gradient checkpointing | recompute activations | checkpoint.checkpoint(fn, x) |
Key Takeaways
- Backpropagation is the chain rule of calculus applied to a computational graph in reverse order
- The forward pass computes the output and saves intermediate activations needed for backward
- The backward pass computes gradients by propagating the upstream gradient through each operation's local derivative
- For linear layers: ∂L/∂W = δxᵀ and ∂L/∂x = Wᵀδ - transposed weight matrices appear naturally from the chain rule
- ReLU greatly reduces vanishing gradients by having derivative 1 for positive inputs (vs sigmoid's max 0.25)
- PyTorch autograd builds the computational graph dynamically and computes vector-Jacobian products efficiently
- Always zero gradients before each backward pass; avoid in-place operations on tensors that require gradients
Next: Gradient Descent Mechanics →
