Automatic Differentiation - How PyTorch Really Computes Gradients
Reading time: ~24 minutes | Level: Mathematical Foundations → ML Engineering
PyTorch does not symbolically differentiate your code the way Mathematica does, and it does not approximate derivatives with finite differences. When you call loss.backward(), it applies the chain rule to every elementary operation in your forward pass and returns derivatives that are exact up to floating-point rounding.
This is automatic differentiation (AD). It is not approximation. It is not magic. It is the chain rule, implemented as a data structure traversal.
Understanding how AD works makes you a better ML engineer: you can implement custom differentiable operations, debug gradient flow issues, optimize memory usage with gradient checkpointing, and write efficient torch.autograd.Function subclasses.
What You Will Learn
- What automatic differentiation is (and is not)
- Dual numbers: the mathematical foundation of forward mode AD
- Forward mode AD: propagating derivatives forward
- Reverse mode AD: propagating gradients backward (what PyTorch uses)
- How PyTorch builds and traverses the computational graph
- `requires_grad`, `.backward()`, `.grad` - the mechanics
- Custom gradient functions: `torch.autograd.Function`
- Gradient checkpointing for memory efficiency
- `torch.no_grad()` and when to use it
Prerequisites
- Lesson 02: Chain Rule and Backpropagation (required)
- Python and PyTorch basics
- Comfort with Python classes and decorators
Part 1 - Three Ways to Compute Derivatives
Before diving into AD, let us contrast the three approaches:
Manual symbolic differentiation
Compute derivatives by hand and hard-code them:
import numpy as np

# Manual: derivative of f(x) = x^3 * sin(x)
# f'(x) = 3x^2 * sin(x) + x^3 * cos(x) (product rule)
def f(x): return x**3 * np.sin(x)
def df(x): return 3*x**2 * np.sin(x) + x**3 * np.cos(x)
Problems: Error-prone for complex functions. Does not scale to millions of operations. Requires re-derivation if the model changes.
Numerical differentiation (finite differences)
Approximate derivatives numerically:
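def numerical_gradient(f, x, h=1e-5):
    # Central difference: two evaluations of f per input dimension
    return (f(x + h) - f(x - h)) / (2 * h)
Problems: Truncation error is O(h²) for this central difference. Floating-point cancellation grows as h shrinks. A d-dimensional gradient needs O(d) function evaluations. The result is never exact.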
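The two error sources pull in opposite directions, which a short experiment makes visible. The sketch below reuses f(x) = x³·sin(x) from the manual example; the step sizes are arbitrary choices for illustration, and the exact error values depend on the platform's floating-point behavior.

```python
import math

def f(x):
    return x**3 * math.sin(x)

def df(x):
    # Hand-derived reference derivative (product rule)
    return 3 * x**2 * math.sin(x) + x**3 * math.cos(x)

def central_difference(fn, x, h):
    return (fn(x + h) - fn(x - h)) / (2 * h)

x0 = 1.5
exact = df(x0)
errors = {}
for h in (1e-1, 1e-3, 1e-5, 1e-8, 1e-12):
    errors[h] = abs(central_difference(f, x0, h) - exact)
    print(f"h={h:.0e}  abs error={errors[h]:.3e}")
```

The error first shrinks as h decreases (truncation error dominates), then grows again once cancellation in f(x + h) - f(x - h) takes over. No single h removes both.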
Automatic differentiation
Apply the chain rule mechanically to elementary operations:
import torch
x = torch.tensor(1.5, requires_grad=True)
y = x**3 * torch.sin(x)
y.backward()
print(x.grad) # exact gradient, no approximation
Pros: Exact up to floating-point rounding, one backward pass for a scalar output regardless of parameter count, works for any computation built from differentiable operations the framework supports.
Part 2 - Dual Numbers: The Mathematics of Forward Mode AD
What are dual numbers?
A dual number has the form a + b·ε, where ε is a special symbol satisfying ε² = 0 (but ε ≠ 0).
Operations on dual numbers:
- Addition: (a + bε) + (c + dε) = (a+c) + (b+d)ε
- Multiplication: (a + bε)(c + dε) = ac + (ad + bc)ε + bdε² = ac + (ad+bc)ε
- The product rule emerges automatically from multiplication: with a = f, b = f', c = g, d = g', the ε coefficient ad + bc is f·g' + f'·g
How dual numbers compute derivatives
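To compute f'(x₀), evaluate f at the dual number x₀ + 1·ε and expand as a Taylor series:

$$
f(x_0 + \varepsilon) = f(x_0) + f'(x_0)\,\varepsilon + \frac{f''(x_0)}{2}\,\varepsilon^2 + \cdots = f(x_0) + f'(x_0)\,\varepsilon
$$

(since ε² = 0, higher-order terms vanish exactly)
The real part gives f(x₀), the dual part gives f'(x₀)!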
class DualNumber:
    """
    Dual number a + b*eps where eps^2 = 0.
    Used to compute derivatives in forward mode AD.
    """
    def __init__(self, real: float, dual: float):
        self.real = real  # function value
        self.dual = dual  # derivative value

    def __add__(self, other):
        if isinstance(other, DualNumber):
            return DualNumber(self.real + other.real, self.dual + other.dual)
        return DualNumber(self.real + other, self.dual)

    def __mul__(self, other):
        if isinstance(other, DualNumber):
            # (a + b*eps)(c + d*eps) = ac + (ad+bc)*eps
            return DualNumber(
                self.real * other.real,
                self.real * other.dual + self.dual * other.real
            )
        return DualNumber(self.real * other, self.dual * other)

    def __radd__(self, other): return self.__add__(other)
    def __rmul__(self, other): return self.__mul__(other)
    def __repr__(self): return f"DualNumber({self.real:.6f} + {self.dual:.6f}*eps)"

import math

def sin_dual(d: DualNumber) -> DualNumber:
    """sin for dual numbers: sin(a + b*eps) = sin(a) + b*cos(a)*eps"""
    return DualNumber(math.sin(d.real), d.dual * math.cos(d.real))

def exp_dual(d: DualNumber) -> DualNumber:
    """exp for dual numbers: exp(a + b*eps) = exp(a) + b*exp(a)*eps"""
    return DualNumber(math.exp(d.real), d.dual * math.exp(d.real))

# Compute derivative of f(x) = x^3 * sin(x) at x = 1.5
# To compute f'(x0): evaluate at DualNumber(x0, 1)
x0 = 1.5
x_dual = DualNumber(x0, 1.0)  # x₀ + 1*eps
y = x_dual * x_dual * x_dual * sin_dual(x_dual)  # x^3 * sin(x)
print(f"f({x0}) = {y.real:.6f}")   # function value
print(f"f'({x0}) = {y.dual:.6f}")  # derivative
# Verify with analytical derivative: f'(x) = 3x^2*sin(x) + x^3*cos(x)
analytical = 3*x0**2 * math.sin(x0) + x0**3 * math.cos(x0)
print(f"Analytical f'({x0}) = {analytical:.6f}")
print(f"Match: {abs(y.dual - analytical) < 1e-10}")
Forward mode AD for multi-variable functions
To compute the gradient ∇f(x₀) for f: ℝⁿ → ℝ, we need n forward passes - one for each input dimension:
Pass 1: x₀ + e₁·ε → gives ∂f/∂x₁
Pass 2: x₀ + e₂·ε → gives ∂f/∂x₂
...
Pass n: x₀ + eₙ·ε → gives ∂f/∂xₙ
This requires n forward passes to compute the full gradient - O(n) cost.
For ML with n = millions of parameters and one scalar loss: forward mode is extremely expensive.
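The n-pass cost can be seen directly in PyTorch's own forward mode. The sketch below assumes PyTorch 2.0 or later, where `torch.func.jvp` is available; the function `f` and the input values are made up for illustration. It builds the gradient one basis vector at a time and compares it with a single reverse-mode pass.

```python
import math
import torch
from torch.func import jvp

def f(x):
    # Illustrative scalar function of a 3-vector: f(x) = x0 * x1 + sin(x2)
    return x[0] * x[1] + torch.sin(x[2])

x0 = torch.tensor([2.0, 3.0, 0.5])

# Forward mode: one pass per input dimension, seeded with basis vector e_i
partials = []
for i in range(x0.numel()):
    e_i = torch.zeros_like(x0)
    e_i[i] = 1.0
    _, directional = jvp(f, (x0,), (e_i,))
    partials.append(directional)
forward_grad = torch.stack(partials)

# Reverse mode: one backward pass yields every partial derivative
x_rev = x0.clone().requires_grad_(True)
f(x_rev).backward()

print(forward_grad)
print(x_rev.grad)
```

Both results agree, but the forward-mode loop ran three passes for three inputs, while reverse mode needed one.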
Part 3 - Reverse Mode AD: One Pass for All Gradients
The key insight
For a scalar function f: ℝⁿ → ℝ (like a loss function), forward mode requires n passes. Reverse mode requires only 1 backward pass, regardless of n.
This is why PyTorch uses reverse mode AD.
How reverse mode works
Reverse mode AD has two phases:
Phase 1 (forward pass): Execute the computation normally. Record each operation and its inputs in the computational graph. Save all intermediate values.
Phase 2 (backward pass): Traverse the graph in reverse. At each node, compute the vector-Jacobian product (VJP) - the gradient contribution from that node - using the chain rule.
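At each node with operation y = f(x₁, x₂, ..., xₖ), the backward pass receives the upstream gradient v = ∂L/∂y and computes, for each input:

$$
\frac{\partial L}{\partial x_i} = v \cdot \frac{\partial y}{\partial x_i}, \qquad i = 1, \dots, k
$$

When an input xᵢ feeds into several nodes, the contributions from each node are summed.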
Manual reverse mode on a computational graph
import numpy as np

class Value:
    """
    Scalar value with automatic differentiation (reverse mode).
    Inspired by Karpathy's micrograd - the minimal AD engine.
    """
    def __init__(self, data: float, _children=(), _op: str = ''):
        self.data = data
        self.grad = 0.0  # dL/d(self)
        self._backward = lambda: None  # function to compute gradients of inputs
        self._prev = set(_children)
        self._op = _op

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')
        def _backward():
            # d(a+b)/da = 1, d(a+b)/db = 1
            # upstream gradient * local derivative
            self.grad += out.grad * 1.0
            other.grad += out.grad * 1.0
        out._backward = _backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other), '*')
        def _backward():
            # d(a*b)/da = b, d(a*b)/db = a
            self.grad += out.grad * other.data
            other.grad += out.grad * self.data
        out._backward = _backward
        return out

    def relu(self):
        out = Value(max(0, self.data), (self,), 'ReLU')
        def _backward():
            # d ReLU(x)/dx = 1 if x > 0 else 0
            self.grad += out.grad * (out.data > 0)
        out._backward = _backward
        return out

    def __pow__(self, power):
        assert isinstance(power, (int, float))
        out = Value(self.data**power, (self,), f'**{power}')
        def _backward():
            # d(x^n)/dx = n * x^(n-1)
            self.grad += out.grad * power * (self.data ** (power - 1))
        out._backward = _backward
        return out

    def backward(self):
        """Topological sort + reverse traversal."""
        # Build topological order
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        build_topo(self)
        # Backward pass
        self.grad = 1.0  # dL/dL = 1
        for v in reversed(topo):
            v._backward()

    def __repr__(self): return f"Value({self.data:.4f}, grad={self.grad:.4f})"
    def __rmul__(self, other): return self.__mul__(other)
    def __radd__(self, other): return self.__add__(other)
    def __neg__(self): return self * -1

# Test: compute gradient of L = (wx + b)^2 at w=2, x=3, b=1
w = Value(2.0)
x = Value(3.0)
b = Value(1.0)
z = w * x + b  # z = wx + b = 7
L = z ** 2     # L = z^2 = 49
print(f"L = {L.data}")
L.backward()
print(f"\nGradients:")
print(f"dL/dw = {w.grad:.4f}")  # should be 2z*x = 2*7*3 = 42
print(f"dL/dx = {x.grad:.4f}")  # should be 2z*w = 2*7*2 = 28
print(f"dL/db = {b.grad:.4f}")  # should be 2z*1 = 2*7*1 = 14
# Verify numerically
h = 1e-6
def f_w(w_val): return (w_val*3 + 1)**2
def f_x(x_val): return (2*x_val + 1)**2
def f_b(b_val): return (2*3 + b_val)**2
print(f"\nNumerical verification:")
print(f"dL/dw ≈ {(f_w(2+h) - f_w(2-h))/(2*h):.4f}")
print(f"dL/dx ≈ {(f_x(3+h) - f_x(3-h))/(2*h):.4f}")
print(f"dL/db ≈ {(f_b(1+h) - f_b(1-h))/(2*h):.4f}")
Part 4 - PyTorch Autograd Deep Dive
The computational graph in PyTorch
PyTorch builds a dynamic computational graph (also called a "define-by-run" graph). The graph is created during the forward pass and, by default, freed after the backward pass (pass retain_graph=True to backward() to keep it).
import torch
# Visualizing the computational graph structure
x = torch.tensor(3.0, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
# Each operation creates a node with a grad_fn
z = w * x + b # z.grad_fn = AddBackward0
L = z ** 2 # L.grad_fn = PowBackward0
print(f"x.grad_fn: {x.grad_fn}") # None (leaf)
print(f"z.grad_fn: {z.grad_fn}") # AddBackward0
print(f"L.grad_fn: {L.grad_fn}") # PowBackward0
print()
# Inspect the graph structure
print(f"L's inputs: {L.grad_fn.next_functions}") # ((AddBackward0 for z, 0),)
print(f"z's inputs: {z.grad_fn.next_functions}") # ((MulBackward0 for w*x, 0), (AccumulateGrad for b, 0))
# Backward pass: traverses graph, accumulates gradients
L.backward()
print(f"\ndL/dw = {w.grad}") # 42.0
print(f"dL/dx = {x.grad}") # 28.0
print(f"dL/db = {b.grad}") # 14.0
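"Define-by-run" means ordinary Python control flow decides which operations get recorded on each call. The sketch below uses a hypothetical `forward` helper with a boolean switch to show two different graphs being built from the same function.

```python
import torch

def forward(x, use_square: bool):
    # The branch taken at run time determines which ops enter the graph
    if use_square:
        return (x ** 2).sum()
    return (3 * x).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)

out_a = forward(x, use_square=True)
out_b = forward(x, use_square=False)

# The node feeding sum() differs between the two graphs
print(type(out_a.grad_fn.next_functions[0][0]).__name__)
print(type(out_b.grad_fn.next_functions[0][0]).__name__)

out_a.backward()
grad_square = x.grad.clone()   # d/dx sum(x^2) = 2x
x.grad = None                  # reset before the second backward
out_b.backward()
grad_linear = x.grad.clone()   # d/dx sum(3x) = 3
print(grad_square, grad_linear)
```

Because each call rebuilds the graph, loops, conditionals, and recursion work without any special graph-construction syntax.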
Leaf tensors and non-leaf tensors
import torch
# Leaf tensors: created directly by the user, not by operations
a = torch.tensor([1.0, 2.0], requires_grad=True) # leaf
b = torch.tensor([3.0, 4.0], requires_grad=False) # leaf (no grad)
# Non-leaf tensor: result of an operation
c = a * b # non-leaf
d = c.sum() # non-leaf, scalar
print(f"a.is_leaf: {a.is_leaf}") # True - PyTorch stores .grad for leaf tensors
print(f"c.is_leaf: {c.is_leaf}") # False - intermediate result
d.backward()
print(f"a.grad: {a.grad}") # [3., 4.] - b values (chain rule for multiplication)
print(f"c.grad: {c.grad}") # None - only leaf tensors store .grad by default
# To retain non-leaf gradients (for debugging):
a2 = torch.tensor([1.0, 2.0], requires_grad=True)
b2 = torch.tensor([3.0, 4.0])
c2 = a2 * b2
c2.retain_grad() # ← explicitly request intermediate gradient
d2 = c2.sum()
d2.backward()
print(f"c2.grad (retained): {c2.grad}") # [1., 1.] - gradient of sum w.r.t. each c2 element
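Another way to look at an intermediate gradient, besides `retain_grad()`, is a tensor hook. The sketch below uses `Tensor.register_hook`, which calls the given function with the gradient during backward; the `captured` dictionary is just an illustrative place to store it.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor([3.0, 4.0])
c = a * b  # non-leaf intermediate

captured = {}
# The hook receives dL/dc during backward; returning None leaves it unchanged
handle = c.register_hook(lambda grad: captured.update(c=grad.clone()))

c.sum().backward()
handle.remove()  # hooks stay attached until removed

print(captured["c"])  # gradient of sum() with respect to c
print(a.grad)
```

A hook avoids keeping the gradient attached to the tensor, which can be preferable when you only want to log or inspect it once.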
Gradient accumulation vs. zeroing
import torch
import torch.nn as nn
model = nn.Linear(3, 1)
X = torch.randn(10, 3)
y = torch.randn(10, 1)
# WRONG: gradients accumulate across backward calls
for epoch in range(3):
    output = model(X)
    loss = ((output - y)**2).mean()
    loss.backward()  # .grad adds to previous .grad
    # model.weight.grad now holds the sum of every backward pass so far
    print(f"After epoch {epoch}: weight grad norm = {model.weight.grad.norm().item():.4f}")
print()
# RIGHT: zero gradients before each backward
for epoch in range(3):
    model.zero_grad()  # or optimizer.zero_grad()
    output = model(X)
    loss = ((output - y)**2).mean()
    loss.backward()
    print(f"After epoch {epoch}: weight grad norm = {model.weight.grad.norm().item():.4f}")
Intentional gradient accumulation
Sometimes accumulation is useful - for simulating large batch sizes on limited GPU memory:
import torch
import torch.nn as nn
model = nn.Linear(100, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
# Simulate batch_size=128 as 4 accumulated micro-batches of 32
accumulation_steps = 4
effective_batch_size = 128
micro_batch_size = effective_batch_size // accumulation_steps
optimizer.zero_grad()
for micro_step in range(accumulation_steps):
    X_micro = torch.randn(micro_batch_size, 100)
    y_micro = torch.randint(0, 10, (micro_batch_size,))
    output = model(X_micro)
    loss = criterion(output, y_micro)
    # Divide by accumulation steps to get correct average gradient
    loss = loss / accumulation_steps
    loss.backward()  # gradients accumulate
# After all micro-batches, gradient is equivalent to large batch
optimizer.step()
optimizer.zero_grad()
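The equivalence claimed above can be checked numerically. The following sketch assumes equal-sized micro-batches and a mean-reduced loss (the default for `nn.CrossEntropyLoss`); the seed and shapes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(100, 10)
criterion = nn.CrossEntropyLoss()  # default reduction="mean"
X = torch.randn(128, 100)
y = torch.randint(0, 10, (128,))

# Reference: one backward pass over the full batch
model.zero_grad()
criterion(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 equal micro-batches, each loss scaled by 1/4
accumulation_steps = 4
model.zero_grad()
for X_micro, y_micro in zip(X.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = criterion(model(X_micro), y_micro) / accumulation_steps
    loss.backward()
accumulated_grad = model.weight.grad

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))
```

The match holds up to floating-point rounding. It would not hold exactly for unequal micro-batch sizes, or for layers such as batch normalization whose output depends on the batch composition.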
Part 5 - Custom Gradient Functions
When you need a non-standard operation or want to override PyTorch's gradient:
import torch
class ClampedSigmoid(torch.autograd.Function):
    """
    Sigmoid with gradient clamping to prevent vanishing gradients.
    Forward: standard sigmoid
    Backward: clip gradient to [min_grad, 1] to ensure minimum gradient flow
    """
    @staticmethod
    def forward(ctx, x, min_grad=0.01):
        output = torch.sigmoid(x)
        ctx.save_for_backward(output)
        ctx.min_grad = min_grad
        return output

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        # Standard sigmoid gradient: sigma * (1 - sigma)
        sigmoid_grad = output * (1 - output)
        # Clamp to minimum to prevent vanishing
        sigmoid_grad = torch.clamp(sigmoid_grad, min=ctx.min_grad)
        # Chain rule: upstream gradient * local derivative
        return grad_output * sigmoid_grad, None  # None for min_grad (not a tensor)

# Usage
x = torch.tensor([-5.0, -2.0, 0.0, 2.0, 5.0], requires_grad=True)
y = ClampedSigmoid.apply(x, 0.01)
y.sum().backward()
print(f"x: {x.detach().numpy()}")
print(f"Clamped sigmoid gradient: {x.grad.numpy().round(4)}")
# Compare with standard sigmoid gradient
x2 = x.detach().clone().requires_grad_(True)
torch.sigmoid(x2).sum().backward()
print(f"Standard sigmoid gradient: {x2.grad.numpy().round(4)}")
Straight-through estimator (for discrete operations)
A classic custom gradient: make a discrete/non-differentiable operation differentiable by using the identity gradient:
import torch
class StraightThrough(torch.autograd.Function):
    """
    Round in forward pass; identity (pass gradient unchanged) in backward pass.
    Used for:
    - Quantization-aware training
    - Discretizing continuous relaxations
    - Binary neural networks
    """
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)  # discrete: 0.7 → 1.0, 0.3 → 0.0

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: pass gradient as if rounding was identity
        return grad_output  # no modification!

# Usage in quantization-aware training
x = torch.tensor([0.3, 0.7, 1.2, 1.8], requires_grad=True)
x_quantized = StraightThrough.apply(x)
print(f"Forward (rounded): {x_quantized.detach()}")
loss = x_quantized.sum()
loss.backward()
print(f"Backward (straight-through): {x.grad}")  # [1, 1, 1, 1]
# Gradient flows as if rounding was identity - enables training
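The same straight-through behavior is often written without a custom Function, using `.detach()`. This is a common alternative idiom rather than part of the class above; `round_ste` is a hypothetical helper name.

```python
import torch

def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Forward value is round(x) up to floating-point rounding;
    # the detached term contributes no gradient, so backward sees d(out)/dx = 1
    return x + (torch.round(x) - x).detach()

x = torch.tensor([0.3, 0.7, 1.2, 1.8], requires_grad=True)
out = round_ste(x)
out.sum().backward()

print(out.detach())
print(x.grad)
```

The custom Function makes the intent explicit and keeps the forward value exactly equal to `torch.round(x)`; the detach form is shorter and needs no class.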
Numerically stable custom operations
import torch
class LogSumExp(torch.autograd.Function):
    """
    Numerically stable log-sum-exp: log(Σ exp(xᵢ)).
    Naive implementation overflows for large x.
    Stable: shift by max(x) before computing exp.
    """
    @staticmethod
    def forward(ctx, x):
        x_max = x.max(dim=-1, keepdim=True).values
        shifted = x - x_max
        exp_shifted = torch.exp(shifted)
        sum_exp = exp_shifted.sum(dim=-1, keepdim=True)
        log_sum = x_max + torch.log(sum_exp)
        ctx.save_for_backward(x, log_sum)
        return log_sum.squeeze(-1)

    @staticmethod
    def backward(ctx, grad_output):
        x, log_sum = ctx.saved_tensors
        # Gradient: softmax(x)
        softmax = torch.exp(x - log_sum)
        return grad_output.unsqueeze(-1) * softmax

# Test
x = torch.tensor([[1000.0, 1001.0, 1002.0]], requires_grad=True)
# Naive: exp(1000) overflows to inf, so the result is inf (no exception is raised)
naive = torch.log(torch.exp(x).sum(dim=-1))
print(f"Naive: {naive}")  # inf
# Stable custom
stable = LogSumExp.apply(x)
print(f"Stable: {stable}")
stable.sum().backward()
print(f"Gradient (softmax): {x.grad}")
Part 6 - Memory Optimization: Gradient Checkpointing
In training large models, intermediate activations from the forward pass consume most GPU memory (they must be saved for the backward pass). Gradient checkpointing trades memory for compute.
How it works
Instead of saving all intermediate activations:
- Divide the computation graph into segments ("checkpoints")
- During backward, recompute the forward pass for each segment on the fly
- Only the checkpoint inputs need to be saved, not all intermediates
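Memory savings: O(√n) activation memory instead of O(n) for a chain of n layers, when a checkpoint is placed every √n layers.
Compute overhead: roughly one extra forward pass, or about 33% more compute per training step if the backward pass costs about twice the forward pass.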
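The √n figure follows from a simple count. The sketch below uses a deliberately simplified cost model (an assumption, not a measurement): one saved input per segment, plus the intermediates of the single segment being recomputed. `peak_stored_activations` is a hypothetical helper.

```python
import math

def peak_stored_activations(n_layers: int, segment_len: int) -> int:
    """Activations held at peak under the simplified model:
    one saved input per segment + intermediates of one recomputed segment."""
    n_segments = math.ceil(n_layers / segment_len)
    return n_segments + segment_len

n = 100
costs = {k: peak_stored_activations(n, k) for k in range(1, n + 1)}
best_k = min(costs, key=costs.get)

print(f"no checkpointing: {n} activations")
print(f"best segment length: {best_k}, peak stored: {costs[best_k]}")
```

With n = 100 the minimum falls at a segment length of 10 = √100, storing 20 activations instead of 100. Real savings depend on which tensors each layer actually saves.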
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
class TransformerLayer(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Attention + residual
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        # FFN + residual
        x = self.norm2(x + self.ff(x))
        return x

class ModelWithCheckpointing(nn.Module):
    def __init__(self, n_layers: int, d_model: int = 512, n_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])

    def forward(self, x, use_checkpoint: bool = False):
        for layer in self.layers:
            if use_checkpoint:
                # Recompute layer during backward instead of storing activations
                x = checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

# Compare memory usage
def measure_memory_usage(use_checkpoint: bool):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = ModelWithCheckpointing(n_layers=12, d_model=512)
    model = model.cuda()
    batch_size, seq_len, d_model = 8, 512, 512
    x = torch.randn(batch_size, seq_len, d_model, device='cuda')
    output = model(x, use_checkpoint=use_checkpoint)
    loss = output.sum()
    loss.backward()
    peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
    return peak_memory

# Uncomment to run on GPU:
# mem_normal = measure_memory_usage(use_checkpoint=False)
# mem_ckpt = measure_memory_usage(use_checkpoint=True)
# print(f"Without checkpointing: {mem_normal:.2f} GB")
# print(f"With checkpointing: {mem_ckpt:.2f} GB")
# print(f"Memory savings: {(1 - mem_ckpt/mem_normal)*100:.1f}%")
print("Gradient checkpointing reduces activation memory at the cost of roughly one extra forward pass.")
print("The exact savings depend on the model and batch shape; measure them with the function above.")
Part 7 - torch.no_grad() and When to Use It
What torch.no_grad() does
Inside a torch.no_grad() context:
- No computational graph is built for operations
- Results of operations have `requires_grad=False`, even when their inputs have `requires_grad=True`
- Memory use drops, because intermediate activations are not saved for backward
- Inference usually runs faster; how much depends on the model and hardware
import torch
import torch.nn as nn
model = nn.Linear(100, 10)
X = torch.randn(32, 100)
# For inference/evaluation: always use no_grad
with torch.no_grad():
    output = model(X)
    predictions = output.argmax(dim=1)
# No graph built, no activations stored
# output.grad_fn is None
print(f"output.grad_fn: {output.grad_fn}")  # None
# For training: don't use no_grad
output_train = model(X)
print(f"output.grad_fn: {output_train.grad_fn}")  # AddmmBackward0 for a 2-D input with bias
When to use torch.no_grad()
import torch
import torch.nn as nn
class TrainingLoop:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def train_step(self, X, y):
        """Training: need gradients."""
        # DO NOT use no_grad here
        self.optimizer.zero_grad()
        output = self.model(X)
        loss = nn.CrossEntropyLoss()(output, y)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    @torch.no_grad()  # decorator form - applies to entire method
    def evaluate(self, X, y):
        """Evaluation: no gradients needed."""
        output = self.model(X)
        loss = nn.CrossEntropyLoss()(output, y)
        accuracy = (output.argmax(1) == y).float().mean()
        return loss.item(), accuracy.item()

    @torch.no_grad()
    def predict(self, X):
        """Inference: no gradients needed."""
        return self.model(X).softmax(dim=1)

    def compute_gradients_for_analysis(self, X, y):
        """Computing gradients but not updating: use grad but no optimizer."""
        self.optimizer.zero_grad()
        output = self.model(X)
        loss = nn.CrossEntropyLoss()(output, y)
        loss.backward()  # need backward here to get gradients
        grads = {name: param.grad.clone() for name, param in self.model.named_parameters()
                 if param.grad is not None}
        return grads
torch.inference_mode() - faster than no_grad()
import torch
# torch.inference_mode() is stronger than no_grad():
# - Also disables view tracking and version counter bumps (less bookkeeping overhead)
# - Tensors created inside cannot later be used in computations recorded by autograd
# - Suited to production inference when you will never need autograd on the results
@torch.inference_mode()
def batch_inference(model, X: torch.Tensor) -> torch.Tensor:
    """Production inference: maximum speed."""
    return model(X).softmax(dim=1)
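The practical difference between the two modes shows up when a result is reused later. This sketch, with made-up tensor values, marks the outputs of each mode and then tries to use them inside an autograd computation.

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = w * 2
with torch.inference_mode():
    b = w * 2

print(a.requires_grad, a.is_inference())
print(b.requires_grad, b.is_inference())

# A no_grad result can re-enter autograd as a constant
(w * a).sum().backward()
grad_from_no_grad = w.grad.clone()
print(grad_from_no_grad)

# An inference tensor is expected to be rejected when autograd needs to save it
try:
    (w * b).sum().backward()
except RuntimeError as err:
    print(f"RuntimeError: {err}")
```

If a tensor might feed a later training computation, `torch.no_grad()` is the safer choice; `torch.inference_mode()` fits code paths that never touch autograd again.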
Part 8 - Debugging Gradient Issues
import torch
import torch.nn as nn
def check_gradient_flow(model: nn.Module) -> dict:
    """
    Check gradient magnitudes through all layers of a model.
    Useful for detecting vanishing/exploding gradients.
    Assumes loss.backward() has already been called.
    """
    stats = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            stats[name] = {'status': 'NO GRADIENT', 'grad_norm': 0.0}
        else:
            grad_norm = param.grad.detach().norm(2).item()
            param_norm = param.detach().norm(2).item()
            stats[name] = {
                'status': 'ok',
                'grad_norm': grad_norm,
                'param_norm': param_norm,
                'grad_param_ratio': grad_norm / (param_norm + 1e-8)
            }
    return stats

def find_nan_gradients(model: nn.Module) -> list:
    """Find parameters with NaN or Inf gradients."""
    bad_params = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                bad_params.append(name)
    return bad_params

# Example usage
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)
X = torch.randn(16, 10)
y = torch.randint(0, 5, (16,))
loss = nn.CrossEntropyLoss()(model(X), y)
loss.backward()
stats = check_gradient_flow(model)
print("Gradient flow analysis:")
for name, s in stats.items():
    if s['status'] == 'ok':
        print(f" {name}: grad_norm={s['grad_norm']:.2e}, "
              f"param_norm={s['param_norm']:.2e}, "
              f"ratio={s['grad_param_ratio']:.2e}")
    else:
        print(f" {name}: {s['status']}")
nan_params = find_nan_gradients(model)
print(f"\nNaN/Inf gradients in: {nan_params if nan_params else 'none'}")
Part 9 - Common AD Mistakes
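:::danger In-place operations on tensors that need gradients
In-place operations overwrite values that autograd may need for the backward pass. PyTorch tracks a version counter on each tensor, so in most cases it raises a RuntimeError instead of silently returning a wrong gradient.
# WRONG
x = torch.randn(3, requires_grad=True)
a = x * 2
y = a ** 2 # PowBackward0 saves a for the backward pass
a += 1 # in-place! overwrites the value autograd saved
y.sum().backward() # RuntimeError: a variable needed for gradient computation was modified by an inplace operation
# Note: an in-place op directly on a leaf that requires grad (x += 1) raises immediately
# RIGHT
x = torch.randn(3, requires_grad=True)
a = x * 2
y = a ** 2
a = a + 1 # creates a new tensor; the saved value is untouched
y.sum().backward() # works correctly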
:::
:::warning Detaching tensors incorrectly
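.detach() creates a tensor that shares storage with the original but does not participate in gradient computation. In-place changes to a detached tensor also change the original; autograd's correctness checks usually catch this and raise an error during backward.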
# Correct use: detach for metric computation (don't want gradient through it)
loss_value = loss.detach().item() # Python float, no graph
# Correct use: stop gradient through a part of the network (e.g., target network in RL)
with torch.no_grad():
target = target_network(state) # no gradient through target network
# Wrong use: detaching then expecting gradients later
x = torch.randn(3, requires_grad=True)
y = x.detach() # y has no grad_fn and requires_grad=False
z = y ** 2
z.sum().backward() # RuntimeError: z does not require grad - the path back to x was cut
:::
:::tip Debugging with anomaly_detection
PyTorch can give you a more informative traceback when NaN gradients occur:
torch.autograd.set_detect_anomaly(True) # Enable before training
# Now backward() will show where the NaN first appeared
# Disable for production (it's slow):
torch.autograd.set_detect_anomaly(False)
:::
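To see what anomaly detection buys you, here is a small sketch using the context-manager form, `torch.autograd.detect_anomaly()`. The `unstable` function is a contrived example that produces 0 · ∞ in the backward pass.

```python
import torch

def unstable(x):
    # sqrt has an infinite derivative at 0; the upstream gradient is 0,
    # so backward computes 0 / 0 = NaN inside sqrt's backward
    return (torch.sqrt(x) * 0.0).sum()

# Without anomaly detection: NaN shows up with no hint of where it came from
x = torch.zeros(1, requires_grad=True)
unstable(x).backward()
nan_without_check = torch.isnan(x.grad).all().item()
print(x.grad)

# With anomaly detection: backward raises and names the offending function
x = torch.zeros(1, requires_grad=True)
caught = None
with torch.autograd.detect_anomaly():
    try:
        unstable(x).backward()
    except RuntimeError as err:
        caught = str(err)
print(caught)
```

The context manager limits the slowdown to the code under investigation, which is usually more convenient than toggling the global switch.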
Interview Questions
Q1: What is the difference between automatic differentiation, numerical differentiation, and symbolic differentiation?
Symbolic differentiation (Mathematica, SymPy): manipulates mathematical expressions to produce a closed-form derivative expression. Exact but can produce exponentially large expressions for complex functions. Not suitable for arbitrary code.
Numerical differentiation (finite differences): f'(x) ≈ (f(x+h) - f(x-h))/(2h). Easy to implement but approximate (error O(h²)), slow (O(d) evaluations for d parameters), and numerically unstable for small h.
Automatic differentiation: applies the chain rule to each elementary operation mechanically. Exact (no approximation), efficient (one backward pass for all d gradients in reverse mode), works for any computation expressible in the framework.
AD is not approximation. loss.backward() in PyTorch gives the exact gradient of your loss, not an estimate.
Q2: Explain forward mode vs reverse mode AD, and why ML uses reverse mode.
Both apply the chain rule exactly. They differ in traversal direction.
Forward mode: Propagates a "tangent" vector alongside each value. To compute ∂y/∂xᵢ for one input, evaluate once with seed vector eᵢ. For d inputs and m outputs, requires d forward sweeps. O(d·cost_of_forward). Efficient when d << m (few inputs, many outputs).
Reverse mode: Runs one forward pass (storing all intermediates), then one backward pass computing ∂L/∂xᵢ for ALL inputs simultaneously via VJPs. For d inputs and 1 output (scalar loss), costs 1 backward sweep. O(cost_of_forward). Efficient when m << d (few outputs, many inputs).
In ML: We have d = millions of parameters (inputs to loss), m = 1 (scalar loss). Reverse mode is the obvious choice - one backward pass computes all gradients simultaneously.
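Memory trade-off: Reverse mode must store the intermediate values from the forward pass, so its memory grows with the size of the graph. Forward mode carries only one tangent alongside each value, so its memory overhead is a constant factor, but a full gradient costs d sweeps.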
Q3: What does `requires_grad=True` do in PyTorch, and what is a leaf tensor?
requires_grad=True: Tells PyTorch to track all operations on this tensor in the computational graph. When loss.backward() is called, PyTorch will compute the gradient of the loss with respect to this tensor.
Leaf tensors: Tensors created directly by the user (not by operations on other tensors). Parameters in nn.Module are leaf tensors. Non-leaf tensors (results of operations) are intermediates in the graph.
PyTorch stores .grad automatically only for leaf tensors with requires_grad=True. For non-leaf tensors, you must call .retain_grad() explicitly.
Memory implication: After backward, PyTorch frees the graph's saved intermediate buffers by default. Leaf tensor .grad values remain. This is why calling backward() a second time on the same graph raises an error unless the first call passed retain_graph=True.
Practical use: nn.Module.parameters() returns the module's leaf parameter tensors. optimizer.zero_grad() resets .grad for every parameter in the optimizer's parameter groups; in recent PyTorch versions it sets .grad to None by default rather than filling it with zeros.
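The freed-graph behavior is easy to reproduce. A minimal sketch with a single scalar (the values are arbitrary):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

# Default: the saved buffers are freed after the first backward
y = x ** 3
y.backward()
first = x.grad.clone()  # 3 * x^2 = 12
second_failed = False
try:
    y.backward()
except RuntimeError:
    second_failed = True
print(first, second_failed)

# retain_graph=True keeps the buffers, and .grad accumulates across calls
x.grad = None
y = x ** 3
y.backward(retain_graph=True)
y.backward()
print(x.grad)  # sum of two identical passes
```

Note that the second successful backward adds to `.grad` rather than replacing it, which ties back to the accumulation behavior described earlier.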
Q4: When would you implement a custom `torch.autograd.Function`?
Use torch.autograd.Function when:
- Custom mathematical operation not in PyTorch: e.g., a specialized activation function, a differentiable rendering operation, or a physics simulation step
- Numerically stable backward that differs from naive chain rule: The log-sum-exp example - the naive gradient computation overflows; the custom backward uses the numerically stable form
- Straight-through estimator: For discrete operations (rounding, argmax, quantization) that have zero or undefined gradient, replace with a surrogate gradient in backward
- Memory-efficient backward: When the forward computation requires intermediate values that are expensive to store, but the gradient can be recomputed efficiently from the output
- Third-party library integration: Wrapping CUDA kernels, C++ extensions, or specialized linear algebra routines
The key requirement: implement forward (returns output, saves needed values in ctx) and backward (receives upstream gradient, returns input gradients via chain rule).
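A hand-written backward is easy to get subtly wrong, so it helps to verify it. The sketch below uses `torch.autograd.gradcheck`, which compares the analytical backward against finite differences; `Cube` is a hypothetical example operation, not one from this lesson.

```python
import torch
from torch.autograd import gradcheck

class Cube(torch.autograd.Function):
    """Hypothetical example op: y = x^3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Chain rule: upstream gradient * d(x^3)/dx
        return grad_output * 3 * x ** 2

# Double precision keeps the finite-difference estimate accurate enough
x = torch.randn(5, dtype=torch.double, requires_grad=True)
ok = gradcheck(Cube.apply, (x,), eps=1e-6, atol=1e-4)
print(ok)
```

`gradcheck` raises with a detailed message if the two gradients disagree. It is not meaningful for deliberately "wrong" gradients such as the straight-through estimator or the clamped sigmoid above.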
Q5: What is gradient checkpointing and when should you use it?
Gradient checkpointing reduces memory usage by trading compute for memory. Instead of storing all intermediate activations from the forward pass (needed for backward), it:
- Stores only "checkpoint" tensors at segment boundaries
- During backward, re-runs the forward computation for each segment to recompute the intermediate activations on the fly
Memory complexity: Without checkpointing, storing all activations for a model with n layers requires O(n) memory. With a checkpoint every √n layers, it requires O(√n) memory.
Compute overhead: Each segment's forward is run twice (once during forward, once during backward). If the backward pass costs about twice the forward pass, that one extra forward adds approximately 33% compute.
When to use it:
- Training large language models (GPT, BERT, LLaMA) where GPU memory is the bottleneck
- When the batch size you want does not fit in GPU memory during training, even though the model itself fits for inference
- Training with very long sequences (attention memory scales quadratically with sequence length)
When NOT to use it:
- Inference (no backward needed, no activations to store)
- Models that fit comfortably in GPU memory (saves memory but wastes compute)
In PyTorch: torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False) wraps any module or function. Recent releases recommend passing use_reentrant explicitly.
Quick Reference
| Concept | Forward Mode | Reverse Mode |
|---|---|---|
| Direction | Input → output | Output → input |
| Storage | O(1) per sweep | O(n) intermediate values |
| Cost for d inputs, 1 output | O(d · forward) | O(forward) |
| Efficient for | Few inputs, many outputs | Many inputs, few outputs (ML) |
| PyTorch | torch.func.jvp, torch.autograd.forward_ad | autograd / .backward() (default) |
| JAX | jvp | vjp / grad |

| PyTorch Concept | What It Does |
|---|---|
| `requires_grad=True` | Track tensor in computational graph |
| `.grad` | Accumulated gradient (leaf tensors by default) |
| `.backward()` | Reverse mode pass; frees graph by default |
| `retain_graph=True` | Keep graph for multiple backward passes |
| `torch.no_grad()` | Disable graph building (inference) |
| `torch.inference_mode()` | Maximum speed inference (no view tracking) |
| `.detach()` | New tensor, same data, no grad |
| `ctx.save_for_backward(...)` | Save tensors needed for backward |
| `checkpoint.checkpoint(fn, x)` | Gradient checkpointing |
Key Takeaways
- Automatic differentiation is not approximation - it applies the chain rule to elementary operations and gives exact gradients
- Forward mode AD propagates tangent vectors from input to output; efficient for few inputs, many outputs
- Reverse mode AD (used by PyTorch) propagates gradients from output to inputs in one backward pass; efficient for many inputs, one scalar loss - the ML setting
- PyTorch builds a dynamic computational graph during the forward pass; `backward()` traverses it in reverse topological order
- Only leaf tensors with `requires_grad=True` store `.grad`; intermediate tensors need `.retain_grad()`
- Custom `torch.autograd.Function` is for non-standard operations, numerically stable custom backward, or straight-through estimators
- Gradient checkpointing trades ~33% extra compute for O(√n) memory, essential for large model training
- Use `torch.no_grad()` (or `torch.inference_mode()`) for all inference/evaluation code
Next: Optimization Algorithms Deep Dive →
