PyTorch Foundations
It is 2018. A team of researchers at Google Brain is trying to implement a novel attention mechanism that needs to branch differently based on the sequence length of each sample in the batch. In TensorFlow 1.x, this is nearly impossible: the computational graph is defined statically before any data flows through it. To implement dynamic behavior, you have to use tf.while_loop and tf.cond, which compile into the static graph at definition time and are notoriously difficult to debug. When something goes wrong, the stack trace points to graph construction code, not to the line in your research code where the actual logic lives.
A researcher on the team opens a pull request that ports the implementation to PyTorch. The same conditional logic is now just an if statement in Python. The debugging session that took two days in TensorFlow takes twenty minutes: they attach pdb, set a breakpoint inside forward(), and inspect the tensor values directly. They add print(x.shape) in the middle of the network - it just works, because PyTorch builds the computation graph dynamically as each Python line executes, not at definition time.
This is the mental model shift that made PyTorch dominate ML research within two years. In TensorFlow 1.x, you define the graph, then execute the graph - two separate phases, two separate mental models, two separate debugging strategies. In PyTorch, definition and execution happen simultaneously. The computation graph is a by-product of running your Python code. You can use ordinary Python debuggers, print statements, conditional logic, and data-dependent control flow. The "define-by-run" model did not just make PyTorch more convenient - it changed what kinds of models researchers were willing to try.
By 2020, the majority of papers published at NeurIPS and ICML used PyTorch. By 2022, even Google's own research teams were publishing PyTorch code alongside TensorFlow. The framework that wins researchers eventually wins production.
Why Dynamic Graphs Changed Research
The static graph approach in TensorFlow 1.x required you to think about computation differently. You were essentially writing a program that compiled into a computation graph, then handed that graph to a runtime executor. This meant:
- Debugging required tf.Print nodes inserted into the graph, or running a tf.Session in a special evaluation mode
- Dynamic shapes (where the graph structure changes based on data) required special primitives like tf.while_loop
- A bug in your network's logic might only manifest at graph execution time, far from where the bug was introduced
- Experimentation was slow because graph recompilation added overhead between iterations
PyTorch's "define-by-run" approach means the computation graph is built fresh on every forward pass. Each operation on a tensor that requires gradients adds a node to the graph. The graph only exists in the context of a single forward pass - it is not a persistent data structure you compile ahead of time. This makes every forward pass debuggable with standard Python tools.
The cost is that dynamic graphs are harder to optimize and deploy. This is why PyTorch 2.0 introduced torch.compile() - a way to get static-graph-style optimization while keeping dynamic-graph-style authoring. The best of both worlds.
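Here is a minimal sketch of define-by-run. It uses a hypothetical toy module, LengthDependentNet, whose forward() picks a branch based on the input length. This is the kind of logic the 2018 story needed tf.cond for. Each call records only the operations that actually ran, so backward() reaches only the parameters on that branch.

```python
import torch
import torch.nn as nn


class LengthDependentNet(nn.Module):
    """Hypothetical toy model that takes a different path for long and short inputs."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.short_path = nn.Linear(dim, 1)
        self.long_path = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # An ordinary Python `if` on the input's length decides which operations run,
        # so the graph recorded for backward differs from call to call.
        if x.shape[0] > 3:
            out = self.long_path(x)
        else:
            out = self.short_path(x)
        # A print(out.shape) or breakpoint() placed here runs like any other Python line.
        return out.mean()


torch.manual_seed(0)
model = LengthDependentNet()

# Short input: only short_path takes part in the graph
model(torch.randn(2, 4)).backward()
short_grad = model.short_path.weight.grad
long_grad_after_short = model.long_path[0].weight.grad
print(short_grad is not None, long_grad_after_short is None)   # True True

# Long input: a different graph, now through long_path
model.zero_grad()
model(torch.randn(6, 4)).backward()
print(model.long_path[0].weight.grad is not None)              # True
```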
Tensors: The Fundamental Data Structure
A PyTorch tensor is a multi-dimensional array with additional machinery for automatic differentiation and GPU acceleration. It is conceptually equivalent to a NumPy array with a tracked computation history.
```python
import torch
import numpy as np

# Creation
a = torch.tensor([1.0, 2.0, 3.0])        # from list - float32
b = torch.tensor([[1, 2], [3, 4]])       # from nested list - int64
c = torch.zeros(3, 4, dtype=torch.float32)
d = torch.ones(5, dtype=torch.float64)
e = torch.randn(100, 20)                 # standard normal, float32
f = torch.arange(0, 10, 2)               # [0, 2, 4, 6, 8]
g = torch.linspace(0, 1, 101)            # 101 evenly spaced points from 0 to 1

# From NumPy (shares memory - no copy)
arr = np.array([1.0, 2.0, 3.0])
t = torch.from_numpy(arr)                # view into the same memory (dtype float64)
t[0] = 99.0
print(arr[0])                            # 99.0 - same memory

# To NumPy
arr_back = t.numpy()                     # also shares memory (CPU tensors only)

# Shape inspection
print(e.shape)                           # torch.Size([100, 20])
print(e.dtype)                           # torch.float32
print(e.device)                          # cpu
print(e.numel())                         # 2000 - total elements

# Reshaping
x = torch.randn(12)
x2d = x.reshape(3, 4)                    # or x.view(3, 4) - view when possible
x_batched = x.unsqueeze(0)               # (1, 12) - add a batch dimension
x_flat = x_batched.squeeze()             # (12,) - remove size-1 dimensions
```
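torch.from_numpy() is not the only way to turn an array into a tensor. The sketch below contrasts it with torch.tensor(), which always copies, and torch.as_tensor(), which copies only when it has to (here because a dtype change is requested). NumPy defaults to float64, while most PyTorch models expect float32.

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])          # float64 NumPy array

shared = torch.from_numpy(arr)          # shares memory with arr, keeps dtype float64
copied = torch.tensor(arr)              # always copies the data
converted = torch.as_tensor(arr, dtype=torch.float32)  # the dtype change forces a copy

arr[0] = -1.0
print(shared[0].item())     # -1.0: sees the change
print(copied[0].item())     # 1.0: independent copy
print(converted[0].item())  # 1.0: copied during the dtype conversion
print(shared.dtype, copied.dtype, converted.dtype)
# torch.float64 torch.float64 torch.float32
```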
Dtypes and their ML use
| Dtype | Bits | When to use |
|---|---|---|
| torch.float32 | 32 | Default for training |
| torch.float16 | 16 | Mixed-precision (FP16) - needs GradScaler |
| torch.bfloat16 | 16 | Mixed-precision on A100/H100 - better range |
| torch.float64 | 64 | Numerical stability tests only |
| torch.int64 | 64 | Token IDs, labels, indices |
| torch.bool | 8 | Masks, attention masks |
```python
# Explicit dtype control
x = torch.randn(10, dtype=torch.float16)
y = x.float()                      # → float32
z = x.half()                       # → float16
w = x.to(torch.bfloat16)

# Type checking
print(x.is_floating_point())       # True
print(x.dtype == torch.float16)    # True
```
view vs reshape: memory layout matters
This distinction comes up often in interviews and causes real bugs.
```python
x = torch.randn(4, 6)                  # shape (4, 6), contiguous in memory

# view: requires contiguous memory, returns a view (shares memory)
y = x.view(6, 4)                       # shape (6, 4), shares data with x
y[0, 0] = 999.0
print(x[0, 0])                         # tensor(999.) - same memory

# After transpose, memory is non-contiguous
x_t = x.t()                            # shape (6, 4), non-contiguous
try:
    z = x_t.view(4, 6)                 # RuntimeError: view size is not compatible
except RuntimeError as e:
    print(f"view failed: {e}")

# reshape: handles non-contiguous tensors by copying when necessary
z = x_t.reshape(4, 6)                  # works - copies data if needed
z2 = x_t.contiguous().view(4, 6)       # explicit: make contiguous, then view

# Check contiguity
print(x.is_contiguous())               # True
print(x_t.is_contiguous())             # False
```
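The rule: use view when you need a guarantee that no copy is made; it raises an error instead of silently copying. Use reshape when you don't know or don't care, but remember that reshape may return either a view or a copy, so don't rely on writes through its result reaching the original tensor. In performance-critical code (e.g., inside a training loop that runs millions of times), avoiding unnecessary copies matters.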
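To check whether reshape returned a view or a copy, compare data_ptr() values. This sketch uses only standard tensor methods. The tensor names are illustrative.

```python
import torch

x = torch.arange(24.0).reshape(4, 6)   # contiguous
x_t = x.t()                            # non-contiguous view of the same storage

r1 = x.reshape(6, 4)     # contiguous input: reshape returns a view
r2 = x_t.reshape(4, 6)   # non-contiguous input: reshape must copy

print(r1.data_ptr() == x.data_ptr())   # True  - shares x's memory
print(r2.data_ptr() == x.data_ptr())   # False - fresh copy

r1[0, 0] = -1.0
r2[0, 0] = -2.0
print(x[0, 0].item())                  # -1.0: only the write through the view reached x
```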
Device Management: CPU, CUDA, MPS
Moving tensors between devices is explicit in PyTorch. This explicitness is a feature - you always know where your data is.
```python
import torch

# Detect available device
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()   # Apple Silicon
    else 'cpu'
)
print(f"Using device: {device}")

# Move tensors to device
x = torch.randn(100, 20)
x_dev = x.to(device)                     # or x.cuda() / x.cpu()
if torch.cuda.is_available():
    x_gpu0 = x.to('cuda:0')              # a specific GPU
    print(x_gpu0.device)                 # cuda:0

# Create tensors directly on the device (avoids a CPU-to-device copy)
w = torch.randn(20, 10, device=device)
b = torch.zeros(10, device=device)

# Move back to CPU for NumPy conversion
x_np = x_dev.cpu().numpy()               # .numpy() only works on CPU tensors
```
Device best practice: define device once, pass everywhere
```python
# At the top of your training script:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create model on device (MyModel and dataloader are defined elsewhere)
model = MyModel().to(device)

# Move batch to device in training loop
for batch_x, batch_y in dataloader:
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)
    output = model(batch_x)
```
:::warning Cross-device operations
Operations between tensors on different devices raise a RuntimeError. If you see "Expected all tensors to be on the same device", find where a tensor wasn't moved: the model, a label, or a mask.
:::
Autograd Deep Dive: How Automatic Differentiation Works
PyTorch builds a computation graph dynamically as operations execute. When you call .backward(), it walks this graph in reverse to compute gradients and accumulates them into the .grad attribute of every leaf tensor with requires_grad=True.
requires_grad and grad_fn
```python
# Enable gradient tracking
x = torch.tensor([2.0, 3.0], requires_grad=True)

# Forward pass - builds computation graph
y = x ** 2          # y = [4, 9], grad_fn=<PowBackward0>
z = y.sum()         # z = 13, grad_fn=<SumBackward0>

# Backward pass - computes gradients
z.backward()

# Access gradients
print(x.grad)       # tensor([4., 6.]) - dz/dx = 2*x
```
Each operation that involves a requires_grad=True tensor creates a grad_fn node. You can inspect the graph by following grad_fn references.
```python
a = torch.tensor(3.0, requires_grad=True)
b = a * 2           # b.grad_fn = MulBackward0
c = b + 1           # c.grad_fn = AddBackward0
d = torch.log(c)    # d.grad_fn = LogBackward0
# d = log(2a + 1)
# dd/da = 1/(2a+1) * 2 = 2/7 ≈ 0.2857 at a=3
d.backward()
print(a.grad)       # tensor(0.2857)

# Inspect the graph structure
print(d.grad_fn)                                       # <LogBackward0 object at 0x...>
print(d.grad_fn.next_functions)                        # ((<AddBackward0 object at 0x...>, 0),)
print(d.grad_fn.next_functions[0][0].next_functions)   # first entry: (<MulBackward0 object at 0x...>, 0)

# Intermediate nodes: requires_grad propagates automatically
print(b.requires_grad)   # True - inherited from the leaf a
print(b.is_leaf)         # False - created by an operation
print(a.is_leaf)         # True - user-created with requires_grad=True
```
The computation graph: forward builds, backward traverses
The critical insight is that the computation graph is built during the forward pass and traversed during the backward pass. It is not a static structure - it is rebuilt on every forward pass, which is what makes dynamic control flow possible.
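PyTorch applies the chain rule automatically. Each node in the graph stores a backward function, plus any tensors it saved during the forward pass. That function multiplies the incoming gradient by the operation's local derivative, which is a vector-Jacobian product; full Jacobian matrices are never built. During backward(), PyTorch traverses the graph in reverse topological order and passes each node's result on to the nodes that produced its inputs.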
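Take the earlier example $d = \log(2a + 1)$, with $b = 2a$ and $c = b + 1$. Each backward node contributes one local factor, and backward() multiplies them from the output toward the input. In the general vector case, let $\bar{y} = \partial L / \partial y$ be the gradient arriving at an operation $y = f(x)$, and let $J_f$ be its Jacobian (notation introduced here for illustration). Each node then returns the vector-Jacobian product $\bar{x}$:

$$
\frac{\partial d}{\partial a}
= \frac{\partial d}{\partial c}\,\frac{\partial c}{\partial b}\,\frac{\partial b}{\partial a}
= \frac{1}{c} \cdot 1 \cdot 2
= \frac{2}{7} \approx 0.2857 \quad \text{at } a = 3
$$

$$
\bar{x} = J_f(x)^{\top}\,\bar{y}, \qquad J_f(x) = \frac{\partial y}{\partial x}
$$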
```python
# Trace the gradient flow for a simple network
import torch

x = torch.randn(3, requires_grad=True)      # input

# Simple computation: L = sum(relu(Wx + b)^2)
W = torch.randn(4, 3, requires_grad=True)
b = torch.zeros(4, requires_grad=True)
h = torch.relu(W @ x + b)                   # shape (4,)
L = (h ** 2).sum()                          # scalar

# Forward pass is complete - graph is built
# Each intermediate tensor knows its grad_fn
L.backward()

# Now check gradients
print(f"x.grad shape: {x.grad.shape}")      # torch.Size([3])
print(f"W.grad shape: {W.grad.shape}")      # torch.Size([4, 3])
print(f"b.grad shape: {b.grad.shape}")      # torch.Size([4])

# h.grad is None - non-leaf tensors don't retain gradients by default
# (saves memory - you usually don't need intermediate gradients).
# Call h.retain_grad() before backward() if you do need one.
print(f"h.grad: {h.grad}")                  # None (PyTorch also emits a UserWarning)
```
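These gradients can be checked by hand. For $L = \sum_j h_j^2$ with $h = \mathrm{relu}(Wx + b)$, the gradient reaching the pre-activation is $2h$. The ReLU mask is already folded in, since $h$ is zero wherever ReLU was inactive. So b.grad should equal $2h$, W.grad should equal the outer product $2h\,x^{\top}$, and x.grad should equal $W^{\top}(2h)$. This sketch verifies that, using h.retain_grad() to keep the intermediate gradient:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, requires_grad=True)
W = torch.randn(4, 3, requires_grad=True)
b = torch.zeros(4, requires_grad=True)

h = torch.relu(W @ x + b)
h.retain_grad()            # keep dL/dh on this non-leaf tensor
L = (h ** 2).sum()
L.backward()

# Gradient at the pre-activation Wx + b: 2h (zero wherever ReLU was inactive)
g = 2 * h.detach()

print(torch.allclose(h.grad, g))                             # True: dL/dh = 2h
print(torch.allclose(b.grad, g))                             # True
print(torch.allclose(W.grad, torch.outer(g, x.detach())))    # True
print(torch.allclose(x.grad, W.detach().T @ g))              # True
```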
retain_graph: why it's expensive
By default, the computation graph is freed after backward() to release memory. If you call backward() twice on the same graph, you get a RuntimeError.
```python
x = torch.tensor([2.0], requires_grad=True)
y = x ** 3                       # y = 8, dy/dx = 3x^2 = 12
y.backward()                     # graph is freed here
print(x.grad)                    # tensor([12.])

try:
    y.backward()                 # RuntimeError: graph freed
except RuntimeError as e:
    print(f"Error: {e}")

# retain_graph=True: keep the graph for multiple backward passes
x.grad.zero_()
y = x ** 3
y.backward(retain_graph=True)    # graph is NOT freed
print(x.grad)                    # tensor([12.])
y.backward()                     # second backward - works now
print(x.grad)                    # tensor([24.]) - accumulated!
```
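retain_graph=True is expensive because it prevents the tensors saved during the forward pass (the intermediate activations) from being freed after the backward pass. They stay in memory until the graph itself is released. This is occasionally needed for multi-loss training, where you call loss1.backward(retain_graph=True) and then loss2.backward(). Each extra backward pass also repeats work, though. When the losses share a graph, summing them and calling (loss1 + loss2).backward() once is usually cheaper.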
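Gradients are linear in the loss, so one backward pass through the summed loss gives the same gradients as two retained backward passes. This small sketch uses a hypothetical helper, two_losses, that builds two losses on a shared intermediate:

```python
import torch

w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)


def two_losses(w):
    """Hypothetical pair of losses sharing one intermediate tensor."""
    shared = torch.tanh(w)
    return (shared ** 2).sum(), shared.abs().sum()


# Option A: two backward passes; the first must retain the shared part of the graph
loss1, loss2 = two_losses(w)
loss1.backward(retain_graph=True)
loss2.backward()
grad_a = w.grad.clone()

# Option B: one backward pass through the summed loss, no retain_graph needed
w.grad = None
loss1, loss2 = two_losses(w)
(loss1 + loss2).backward()
grad_b = w.grad.clone()

print(torch.allclose(grad_a, grad_b))   # True
```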
Gradient accumulation and zeroing
```python
x = torch.tensor([1.0, 2.0], requires_grad=True)

loss1 = (x ** 2).sum()
loss1.backward()
print(x.grad)      # tensor([2., 4.])

loss2 = (x ** 2).sum()
loss2.backward()
print(x.grad)      # tensor([4., 8.]) - ACCUMULATED, not replaced!

# Zero before each backward pass
x.grad.zero_()
loss3 = (x ** 2).sum()
loss3.backward()
print(x.grad)      # tensor([2., 4.]) - correct
```
This accumulation behavior is why optimizer.zero_grad() must be called before each backward pass in the training loop. It is also intentionally used for gradient accumulation to simulate larger batch sizes:
```python
accumulation_steps = 4
optimizer.zero_grad()

for i, (batch_x, batch_y) in enumerate(dataloader):
    output = model(batch_x)
    loss = criterion(output, batch_y) / accumulation_steps   # scale loss
    loss.backward()                                          # gradients accumulate in .grad

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # update weights with accumulated gradients
        optimizer.zero_grad()   # reset for next accumulation window

# Effect: approximately equivalent to training with batch_size * accumulation_steps samples
# per step (exactly so for a mean-reduced loss over equal-sized micro-batches; BatchNorm
# still computes its statistics on the small micro-batches).
# Useful when you can't fit a large batch in GPU memory.
# Gradients from a final partial window are not applied unless you step after the loop.
```
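Why divide by accumulation_steps? With a mean-reduced loss and equal-sized micro-batches, the sum of the scaled micro-batch gradients equals the full-batch gradient. This sketch checks that on a toy nn.Linear model with random data. The model has no BatchNorm, so the equivalence is exact up to floating-point error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(5, 1)
criterion = nn.MSELoss()            # mean reduction
X = torch.randn(16, 5)
y = torch.randn(16, 1)

# Full batch: one backward pass over all 16 samples
model.zero_grad()
criterion(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 4, each loss scaled by 1/4
accumulation_steps = 4
model.zero_grad()
for xb, yb in zip(X.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (criterion(model(xb), yb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))   # True
```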
Detaching tensors from the graph
Sometimes you want to use a tensor's value without propagating gradients through it.
```python
import torch
import torch.nn.functional as F

# .detach(): creates a new tensor that shares data but has no grad_fn
x = torch.randn(3, requires_grad=True)
y = x ** 2
y_detached = y.detach()                 # no grad_fn - gradients stop here
print(y_detached.requires_grad)         # False
print(y_detached.grad_fn)               # None

# Common use: computing metrics without building a gradient graph
# (model, val_x, target_net, main_net and state are assumed to be defined elsewhere)
with torch.no_grad():
    predictions = model(val_x)
# similar in effect to detaching the outputs, but cheaper: no graph is built at all

# Another common use: target networks in reinforcement learning
# The target network's outputs should NOT flow gradients back
target_output = target_net(state).detach()   # used as a label, not differentiated through
loss = F.mse_loss(main_net(state), target_output)
loss.backward()                              # gradients flow through main_net only
```
The Training Loop: Forward, Backward, Step
The standard PyTorch training loop uses optimizer.zero_grad(), loss.backward(), and optimizer.step() in that order.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


# Minimal training loop
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # 1. Zero gradients from previous batch
        optimizer.zero_grad()

        # 2. Forward pass - builds computation graph
        output = model(batch_x)

        # 3. Compute loss
        loss = criterion(output, batch_y)

        # 4. Backward pass - compute gradients
        loss.backward()

        # 5. Update weights
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)
```
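Here is a self-contained sketch of the loop in use, on a hypothetical synthetic regression problem. It repeats train_one_epoch so it runs on its own. It also adds an illustrative evaluate helper (not part of the original) that switches to eval mode and disables gradient tracking.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Same five steps as the loop above."""
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


@torch.no_grad()
def evaluate(model, dataloader, criterion, device):
    """Hypothetical validation helper: eval mode, no graph, average loss."""
    model.eval()
    total_loss = 0.0
    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        total_loss += criterion(model(batch_x), batch_y).item()
    return total_loss / len(dataloader)


torch.manual_seed(0)
device = torch.device('cpu')

# Synthetic data: y = X @ true_w + small noise
X = torch.randn(512, 10)
true_w = torch.randn(10, 1)
y = X @ true_w + 0.01 * torch.randn(512, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

model = nn.Linear(10, 1).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

initial_loss = evaluate(model, loader, criterion, device)
for epoch in range(20):
    train_one_epoch(model, loader, optimizer, criterion, device)
final_loss = evaluate(model, loader, criterion, device)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```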
nn.Module: Building Blocks of All PyTorch Models
Every PyTorch model is an nn.Module. The Module API provides parameter management, device movement, serialization, training/eval mode switching, and composition.
```python
import torch
import torch.nn as nn


class LinearModel(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # nn.Linear registers weight and bias as Parameters automatically
        self.fc = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))


model = LinearModel(20, 10)

# Inspect parameters
for name, param in model.named_parameters():
    print(f"{name}: shape={param.shape}, requires_grad={param.requires_grad}")
# fc.weight: shape=torch.Size([10, 20]), requires_grad=True
# fc.bias: shape=torch.Size([10]), requires_grad=True

# Total parameter count
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {n_params:,} | Trainable: {n_trainable:,}")
```
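The Module API serializes through state_dict(), a dictionary of parameter and buffer tensors rather than the Python object itself. Here is a minimal round-trip sketch. It uses an in-memory io.BytesIO buffer as a stand-in for a file path.

```python
import io

import torch
import torch.nn as nn


class LinearModel(nn.Module):   # same architecture as above
    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))


torch.manual_seed(0)
model = LinearModel(20, 10)

# Save only the state_dict (tensors), not the Python object
buffer = io.BytesIO()            # stands in for a path such as 'model.pt'
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture in code, then load the saved tensors into it
restored = LinearModel(20, 10)
restored.load_state_dict(torch.load(buffer))
restored.eval()

x = torch.randn(4, 20)
print(torch.equal(model(x), restored(x)))            # True
print(sum(p.numel() for p in model.parameters()))   # 210 = 20*10 weights + 10 biases
```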
Composing modules into larger models
```python
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.3):
        super().__init__()
        dims = [input_dim] + hidden_dims + [output_dim]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            if i < len(dims) - 2:   # not the last layer
                layers.append(nn.BatchNorm1d(dims[i+1]))
                layers.append(nn.ReLU())
                layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


model = MLP(input_dim=784, hidden_dims=[512, 256], output_dim=10)
x = torch.randn(32, 784)   # batch of 32
output = model(x)
print(output.shape)        # (32, 10)
```
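MLP contains Dropout and BatchNorm1d, so its output depends on the mode. This sketch uses a small stand-alone stack of the same layer types. In training mode, Dropout samples a fresh mask on every call. In eval mode, Dropout is disabled and BatchNorm uses its running statistics, so repeated calls agree. BatchNorm1d in training mode also needs more than one sample per batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-alone stack of the layer types MLP uses
net = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.3))
x = torch.randn(8, 16)

net.train()
out_a, out_b = net(x), net(x)       # each call samples a new dropout mask
print(torch.equal(out_a, out_b))    # False

net.eval()
out_c, out_d = net(x), net(x)       # dropout off, BatchNorm uses running statistics
print(torch.equal(out_c, out_d))    # True
```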
Common nn layers
```python
# Constructor signatures (argument names are placeholders)

# Linear layers
nn.Linear(in_features, out_features, bias=True)

# Convolution
nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)

# Recurrent
nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

# Normalization
nn.BatchNorm1d(num_features)
nn.LayerNorm(normalized_shape)

# Activation
nn.ReLU()
nn.GELU()
nn.Sigmoid()
nn.Softmax(dim=-1)

# Regularization
nn.Dropout(p=0.5)

# Embedding
nn.Embedding(num_embeddings, embedding_dim, padding_idx=0)
```
nn.Parameter vs register_buffer
```python
import torch
import torch.nn as nn


class ModelWithBuffer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Parameter: learnable, appears in model.parameters(), saved in state_dict
        self.weight = nn.Parameter(torch.randn(dim, dim))

        # Buffer: not learnable, moved with .to(device), saved in state_dict
        # Use for constants that need to follow the model to GPU
        running_mean = torch.zeros(dim)
        self.register_buffer('running_mean', running_mean)

        # Plain attribute: NOT moved with .to(device), NOT saved
        self.some_constant = 42          # fine for plain Python scalars

    def forward(self, x):
        return (x - self.running_mean) @ self.weight


model = ModelWithBuffer(20)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# model.weight        → on device
# model.running_mean  → on device
# model.some_constant → still 42 (int, unaffected)

# State dict includes both parameters and buffers
for key in model.state_dict():
    print(key)   # 'weight', 'running_mean'
```
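register_buffer also accepts persistent=False. Use it for tensors that should follow the module across devices and dtypes but stay out of the state_dict, such as values that are cheap to recompute. This sketch contrasts the three options with a hypothetical BufferKinds module. It uses a dtype cast instead of a GPU move so it runs on CPU; Module.to() casts floating-point buffers by the same rules it uses to move them.

```python
import torch
import torch.nn as nn


class BufferKinds(nn.Module):
    """Hypothetical module contrasting three ways to hold a non-learnable tensor."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.register_buffer('saved_buf', torch.zeros(dim))                     # moved + saved
        self.register_buffer('scratch_buf', torch.ones(dim), persistent=False)  # moved, not saved
        self.plain_tensor = torch.full((dim,), 2.0)                             # neither

    def forward(self, x):
        return x + self.saved_buf + self.scratch_buf


m = BufferKinds()
print(list(m.state_dict().keys()))               # ['saved_buf']
print([name for name, _ in m.named_buffers()])   # ['saved_buf', 'scratch_buf']

m.to(torch.float64)   # dtype casts reach buffers but not plain attributes
print(m.saved_buf.dtype, m.scratch_buf.dtype, m.plain_tensor.dtype)
# torch.float64 torch.float64 torch.float32
```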
torch.no_grad() and torch.inference_mode()
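During inference and validation, you do not need gradients. Disabling gradient tracking means no computation graph is built and no intermediate activations are kept for backward. This saves memory and usually speeds up computation; the exact gain depends on the model.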
```python
model.eval()

# Option 1: no_grad - disables gradient tracking
with torch.no_grad():
    output = model(x)
# output has no grad_fn - computation graph is not built

# Option 2: inference_mode - stricter, faster (PyTorch 1.9+)
with torch.inference_mode():
    output = model(x)
# output is an "inference tensor": it cannot be used in autograd, even after the block

# Which to use:
# - inference_mode() during prediction/serving (faster)
# - no_grad() during validation in training, or whenever the outputs may later be used
#   in operations that autograd records
```
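This sketch shows the practical difference. A tensor produced under no_grad() is an ordinary tensor and can later take part in autograd. A tensor produced under inference_mode() is flagged as an inference tensor (Tensor.is_inference()). Among other restrictions, it cannot be modified in-place outside inference mode.

```python
import torch

w = torch.randn(3, requires_grad=True)

with torch.no_grad():
    t_no_grad = torch.ones(3) * 2
with torch.inference_mode():
    t_inference = torch.ones(3) * 2

print(t_no_grad.is_inference(), t_inference.is_inference())   # False True

# A no_grad() result is an ordinary tensor: it can be used in autograd afterwards
(w * t_no_grad).sum().backward()
print(w.grad)   # tensor([2., 2., 2.])

# An inference tensor cannot be updated in-place outside inference mode
try:
    t_inference.add_(1)
    inplace_blocked = False
except RuntimeError:
    inplace_blocked = True
print(inplace_blocked)   # True
```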
Memory Management: CUDA Memory Model
GPU memory management is a critical skill for training large models. Understanding how PyTorch allocates and releases GPU memory prevents out-of-memory errors and unnecessary training interruptions.
CUDA memory basics
```python
import torch

# Check GPU memory usage
if torch.cuda.is_available():
    print(f"Allocated:     {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Reserved:      {torch.cuda.memory_reserved() / 1e9:.2f} GB")
    print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

# PyTorch uses a caching memory allocator:
# - "Allocated" memory = memory occupied by live tensors
# - "Reserved" memory = memory the allocator holds from the driver, including cached free blocks
#   (not handed back, ready to be reused for future allocations)
# - Reserved > Allocated is normal - it means the allocator has free cached blocks

# Release unused cached blocks back to the driver (rarely needed)
torch.cuda.empty_cache()
# This does NOT free live tensors - it only releases unoccupied cache blocks,
# e.g. so other processes on the same GPU can use them.
# Calling it frequently hurts performance (forces re-allocation)
```
Common OOM causes and fixes
```python
# CAUSE 1: Accumulating loss tensors across batches
total_loss = 0
for batch_x, batch_y in dataloader:
    loss = criterion(model(batch_x), batch_y)
    total_loss += loss           # WRONG: total_loss becomes a tensor that keeps every batch's graph alive
# FIX: use .item() to extract a Python float (no graph attached)
#     total_loss += loss.item()

# CAUSE 2: Storing predictions with gradients
all_preds = []
for batch_x, _ in dataloader:
    preds = model(batch_x)
    all_preds.append(preds)      # WRONG if preds has a grad_fn

# FIX: run inference under no_grad (or detach) and move results off the GPU
all_preds = []
with torch.no_grad():
    for batch_x, _ in dataloader:
        preds = model(batch_x)
        all_preds.append(preds.cpu())   # moving to CPU also frees GPU memory

# CAUSE 3: Stale objects from earlier experiments still referenced in a notebook
# FIX: drop the references first; empty_cache() alone cannot free live tensors
#     del model, optimizer
#     gc.collect()
#     torch.cuda.empty_cache()

# CAUSE 4: Batch size too large - reduce it and use gradient accumulation
```
Gradient checkpointing: training with less memory
Gradient checkpointing trades compute for memory. During the forward pass, instead of storing all intermediate activations (which are needed for backprop), it discards them and recomputes them during the backward pass.
For a network with $n$ layers, normal training stores $O(n)$ activations. With gradient checkpointing every $\sqrt{n}$ layers, memory drops to $O(\sqrt{n})$ at the cost of one additional forward pass.
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """A transformer block with gradient checkpointing."""

    def __init__(self, dim):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def _forward(self, x):
        # Self-attention: normalize once, use the result as query, key and value
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # Feed-forward
        x = x + self.ff(self.ln2(x))
        return x

    def forward(self, x):
        # checkpoint keeps only the block input and re-runs _forward during backward
        # to recompute the activations inside the block instead of storing them
        return checkpoint(self._forward, x, use_reentrant=False)


class TransformerWithCheckpointing(nn.Module):
    def __init__(self, dim, n_layers):
        super().__init__()
        self.blocks = nn.ModuleList([CheckpointedBlock(dim) for _ in range(n_layers)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


# Without checkpointing: stores the activations inside all 24 blocks simultaneously
# With checkpointing: keeps only each block's input and recomputes the rest during backward -
# roughly one extra forward pass of compute in exchange for much lower activation memory
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransformerWithCheckpointing(dim=512, n_layers=24).to(device)
```
:::tip When to use gradient checkpointing
Use gradient checkpointing when you are memory-constrained and would otherwise need to reduce batch size. The trade-off: roughly 30-40% slower training (one extra forward pass per backward pass), but you can often double or triple the batch size that fits in memory. The technique is commonly used when training large language models on limited GPU memory.
:::
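Checkpointing changes memory and compute, not the math, so the gradients should match a plain forward/backward pass. This sketch uses a small stand-alone MLP block that is deterministic and has no dropout, so recomputation reproduces the same activations.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
x = torch.randn(8, 32, requires_grad=True)

# Plain forward/backward: activations inside `block` are stored for backward
block(x).pow(2).sum().backward()
plain_grads = [p.grad.clone() for p in block.parameters()] + [x.grad.clone()]

# Checkpointed forward/backward: activations inside `block` are recomputed during backward
block.zero_grad()
x.grad = None
checkpoint(block, x, use_reentrant=False).pow(2).sum().backward()
ckpt_grads = [p.grad.clone() for p in block.parameters()] + [x.grad.clone()]

grads_match = all(torch.allclose(a, b) for a, b in zip(plain_grads, ckpt_grads))
print(grads_match)   # True
```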
Custom Autograd Functions
Sometimes you need to implement a custom gradient that PyTorch cannot derive automatically. For example, quantization operations (rounding) have a zero gradient almost everywhere, which would prevent learning. The straight-through estimator is a common workaround.
```python
import torch


class StraightThroughQuantize(torch.autograd.Function):
    """
    Quantize activations to {0, 1} in the forward pass.
    Use straight-through gradient in the backward pass (pretend it was identity).

    This is used in binary neural networks: the forward pass uses discrete values,
    but the backward pass passes gradients through as if the operation was identity.
    """

    @staticmethod
    def forward(ctx, x):
        # Quantize: threshold at 0 → binary output
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass gradient through unchanged
        # (pretend the quantization was an identity function)
        return grad_output   # no clipping here - full gradient passes through


# Usage
quantize = StraightThroughQuantize.apply
x = torch.randn(5, requires_grad=True)
y = quantize(x)       # forward: binary values {0, 1}
loss = y.sum()
loss.backward()
print(x.grad)         # all 1s - gradient passes straight through


# More common variant: cancel the gradient where |x| > 1
# (the derivative of hardtanh, as used in binarized neural networks)
class ClippedStraightThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Pass the gradient only where the input lies in [-1, 1]
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)
```
Custom gradient for numerical stability
Another use case: implementing a numerically stable operation where PyTorch's automatic differentiation would produce NaN or Inf for edge cases.
```python
import torch


class StableLogSigmoid(torch.autograd.Function):
    """
    Numerically stable log(sigmoid(x)).
    log(sigmoid(x)) = log(1 / (1 + e^{-x})) = -log(1 + e^{-x}) = -softplus(-x)
    For large negative x, sigmoid(x) → 0, log(sigmoid(x)) → -inf.
    Naive implementation: log(sigmoid(x)) computes sigmoid first → 0 → log(0) = -inf with NaN gradient.
    Stable implementation: -softplus(-x) avoids the intermediate zero.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Numerically stable: for x >= 0: -log(1 + e^{-x}), for x < 0: x - log(1 + e^x)
        return torch.where(x >= 0, -torch.log1p(torch.exp(-x)), x - torch.log1p(torch.exp(x)))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        sigmoid_x = torch.sigmoid(x)
        return grad_output * (1 - sigmoid_x)   # d/dx log(sigmoid(x)) = 1 - sigmoid(x)
```
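This sketch compares the naive and custom versions at extreme inputs. It uses an algebraically equivalent forward, min(x, 0) - log1p(exp(-|x|)), which never exponentiates a positive number. It checks the hand-written backward with torch.autograd.gradcheck, which compares it against finite differences in float64. In practice, torch.nn.functional.logsigmoid already provides a stable implementation, and it serves as the reference here.

```python
import torch
import torch.nn.functional as F


class StableLogSigmoid(torch.autograd.Function):
    """log(sigmoid(x)) with a hand-written backward (same math as above)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # Equivalent stable form: min(x, 0) - log(1 + e^{-|x|}); no intermediate inf
        return torch.clamp(x, max=0) - torch.log1p(torch.exp(-x.abs()))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.sigmoid(-x)   # 1 - sigmoid(x) == sigmoid(-x)


x = torch.tensor([-200.0, 0.0, 200.0], requires_grad=True)

# Naive version: sigmoid(-200) underflows to 0 in float32, so the gradient is inf * 0 = NaN
torch.log(torch.sigmoid(x)).sum().backward()
naive_grad = x.grad.clone()
x.grad = None

stable_out = StableLogSigmoid.apply(x)
stable_out.sum().backward()
stable_grad = x.grad.clone()

print(naive_grad)    # values: nan, 0.5, 0.0
print(stable_grad)   # values: 1.0, 0.5, 0.0
print(torch.allclose(stable_out, F.logsigmoid(x)))   # True

# gradcheck compares the hand-written backward with finite differences (needs float64)
x64 = torch.randn(8, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(StableLogSigmoid.apply, (x64,)))   # True
```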
Mixed Precision Training
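FP32 uses 32 bits per value. FP16 uses 16 bits. Training in FP16 halves the memory of every tensor stored in FP16 and can substantially increase throughput on Tensor Cores (NVIDIA GPUs from Volta onwards). But FP16 has a much smaller dynamic range than FP32:
- FP32 range: approximately $1.2 \times 10^{-38}$ to $3.4 \times 10^{38}$
- FP16 range: approximately $6.1 \times 10^{-5}$ to $6.55 \times 10^{4}$ (down to about $6 \times 10^{-8}$ with subnormal numbers)
Gradients in deep networks are often very small. In FP16, values smaller than about $6 \times 10^{-8}$ (the smallest subnormal, $2^{-24}$) round to zero ("underflow"). This kills training.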
GradScaler fixes this: it multiplies the loss by a large scale factor before backprop (making the gradients larger), then divides the gradients by the same factor before the optimizer step (restoring their true magnitude). If the scaled gradients overflow FP16 (produce Inf or NaN), the optimizer step is skipped and the scale factor is reduced.
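The underflow problem and the scaling fix can both be seen on a single value. This sketch uses an illustrative scale of $2^{16}$. GradScaler chooses and adjusts its own scale factor; the sketch only mimics the idea on one number.

```python
import torch

grad = torch.tensor(1e-8)             # a tiny FP32 gradient value
scale = 2.0 ** 16                     # illustrative scale factor

fp16_direct = grad.half()             # below FP16's smallest subnormal (~6e-8): underflows
print(fp16_direct.item())             # 0.0

scaled = (grad * scale).half()        # ~6.55e-4 fits comfortably in FP16
recovered = scaled.float() / scale    # unscale in FP32 before the optimizer step
print(recovered.item())               # approximately 1e-8 (FP16 keeps ~3 significant digits)

print(grad.to(torch.bfloat16).item() != 0.0)   # True - BF16 has FP32's exponent range
```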
```python
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler   # recent releases; older ones use torch.cuda.amp.GradScaler

device = torch.device('cuda')
model = MLP(784, [512, 256], 10).to(device)   # MLP as defined earlier
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# GradScaler manages the loss scaling factor
scaler = GradScaler()


def train_epoch_amp(model, dataloader, optimizer, criterion, scaler, device):
    model.train()
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        optimizer.zero_grad()

        # autocast: runs forward pass in FP16 where safe, FP32 elsewhere
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(batch_x)
            loss = criterion(output, batch_y)

        # Scale loss before backward (prevents FP16 gradient underflow)
        scaler.scale(loss).backward()

        # Unscale gradients so clipping sees their true magnitude (optional but recommended)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # If gradients are finite: optimizer.step(). If Inf/NaN: skip.
        scaler.step(optimizer)

        # Update the scale factor for next iteration
        scaler.update()
```
bfloat16: better range, similar performance
On A100 and H100 GPUs, bfloat16 (Brain Float 16) is often preferred over float16. It has the same number of exponent bits as float32 (8 bits), giving it the same dynamic range. It has fewer mantissa bits (7 vs 23 in FP32), which means less precision - but in practice, deep learning training is more sensitive to range than precision.
```python
# On A100/H100: use bfloat16 - no GradScaler needed
with autocast(device_type='cuda', dtype=torch.bfloat16):
    output = model(batch_x)
    loss = criterion(output, batch_y)
loss.backward()
optimizer.step()
# bfloat16 has (nearly) FP32's dynamic range, so gradient underflow is rarely a problem
# and GradScaler is unnecessary
```
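| Format | Exponent bits | Mantissa bits | Normal range (approx.) | Underflow risk |
|---|---|---|---|---|
| FP32 | 8 | 23 | $1.2 \times 10^{-38}$ to $3.4 \times 10^{38}$ | Very low |
| FP16 | 5 | 10 | $6.1 \times 10^{-5}$ to $6.55 \times 10^{4}$ | High |
| BF16 | 8 | 7 | $1.2 \times 10^{-38}$ to $3.4 \times 10^{38}$ | Very low |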
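:::tip Mixed precision memory savings
A 7B parameter model's weights take $7 \times 10^{9} \times 4$ bytes = 28 GB in FP32 and 14 GB in FP16/BF16. But weights are only part of the training footprint. With Adam, each parameter also needs a gradient plus two optimizer states (momentum and variance), usually kept in FP32. That comes to about 16 bytes per parameter in total, roughly 112 GB for 7B parameters before counting activations. Under torch.autocast, parameters, gradients and optimizer states all stay in FP32, so this part does not shrink. The savings come from two places: activations saved for backward in FP16/BF16, which can be the largest item at big batch sizes and sequence lengths, and faster Tensor Core math.
:::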
torch.compile(): PyTorch 2.0 Performance
PyTorch 2.0 introduced torch.compile(), which compiles your model using TorchDynamo (traces Python bytecode to extract the computation graph) and TorchInductor (generates optimized kernels for CPU or CUDA).
```python
import torch

model = MLP(784, [512, 256], 10).to('cuda')   # MLP as defined earlier

# Compile the model - first call triggers compilation (warm-up)
compiled_model = torch.compile(model)

# Compilation modes:
# - 'default': balance compile time vs runtime speedup
# - 'reduce-overhead': cut per-call Python and kernel-launch overhead using CUDA graphs
#   (helps most with small batches; uses extra memory)
# - 'max-autotune': search for the fastest kernels (much longer compile time - for production)
compiled_model = torch.compile(model, mode='reduce-overhead')

# First forward pass: compilation can take from seconds to minutes
x = torch.randn(32, 784, device='cuda')
with torch.no_grad():
    out = compiled_model(x)   # triggers compilation

# Subsequent passes reuse the compiled code and are often faster;
# the speedup depends heavily on the model, batch size and hardware
for i in range(100):
    with torch.no_grad():
        out = compiled_model(x)

# What torch.compile does:
# 1. TorchDynamo: captures Python bytecode, identifies tensor operations
# 2. TorchFX: represents the captured computation as a graph
# 3. TorchInductor: generates fused kernels (Triton on GPU, C++ on CPU) with operator fusion, tiling, etc.
# The result: fewer kernel launches, better memory access patterns, less Python overhead
```
TorchScript: ahead-of-time compilation for deployment
For deployment where you cannot have Python, torch.jit.script() compiles your model to TorchScript - a statically typed subset of Python that can run in C++.
```python
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.fc(x))


model = SimpleNet()

# Script: statically analyzes Python code
scripted = torch.jit.script(model)

# Trace: records operations for a specific input shape (less flexible)
example_input = torch.randn(1, 10)
traced = torch.jit.trace(model, example_input)

# Save for C++ deployment
scripted.save('model.pt')

# Load in Python or C++
loaded = torch.jit.load('model.pt')

# Difference between script and trace:
# - torch.jit.script: handles if/else, loops, dynamic control flow
# - torch.jit.trace: records ONE execution - misses dynamic behavior
# Use script when your forward() has data-dependent branching
```
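This sketch shows the difference with a hypothetical SignBranch module whose branch depends on the input values. Tracing with a positive example bakes in the x * 2 branch, while scripting keeps the if.

```python
import torch
import torch.nn as nn


class SignBranch(nn.Module):
    """Hypothetical module with data-dependent control flow."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        return -x


model = SignBranch()
pos = torch.ones(3)
neg = -torch.ones(3)

scripted = torch.jit.script(model)     # compiles the `if` into the TorchScript graph
traced = torch.jit.trace(model, pos)   # records only the branch taken for `pos`
                                       # (PyTorch emits a TracerWarning here)

print(scripted(neg))   # tensor([1., 1., 1.]) - took the `else` branch
print(traced(neg))     # tensor([-2., -2., -2.]) - replayed the recorded `x * 2` branch
```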
Distributed Training Concepts
When a model or dataset is too large for a single GPU, distributed training splits the work across multiple GPUs or machines.
DataParallel vs DistributedDataParallel
```python
import torch
import torch.nn as nn

model = MLP(784, [512, 256], 10)   # MLP as defined earlier

# DataParallel (DP): simple but slow - uses a single-process, multi-thread approach
# Scatters data to each GPU, gathers outputs to GPU 0, computes loss on GPU 0
# Problem: GPU 0 is a bottleneck (all output gathering happens there)
model_dp = nn.DataParallel(model)   # uses all available GPUs

# DistributedDataParallel (DDP): recommended - one process per GPU
# Each process has its own model replica
# After backward, gradients are synchronized via an all-reduce operation
# No single-GPU bottleneck - all GPUs contribute equally

# DDP setup (run with: torchrun --nproc_per_node=4 train.py)
# torchrun sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT for each process
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp(rank, world_size):
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)   # assumes one machine, so rank == local GPU index


def train_ddp(rank, world_size):
    setup_ddp(rank, world_size)
    model = MLP(784, [512, 256], 10).to(rank)
    model = DDP(model, device_ids=[rank])   # wrap with DDP
    optimizer = torch.optim.Adam(model.parameters())

    # Training loop: identical to single-GPU, DDP handles gradient sync.
    # Give the DataLoader a DistributedSampler so each process sees a different shard of the data.
    # After loss.backward(), all-reduce averages gradients across all GPUs
    # All GPUs see the same gradient → take the same optimizer step → stay in sync

    dist.destroy_process_group()
```
The all-reduce operation
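After each backward pass in DDP, PyTorch performs an "all-reduce" across all processes. Conceptually, every GPU contributes its local gradients and every GPU receives their sum. NCCL implements this efficiently, for example with ring algorithms, rather than by sending every gradient to every GPU. Each GPU then divides by the number of processes and uses this average for the optimizer step:
$$
\bar{g} = \frac{1}{N} \sum_{i=1}^{N} g_i
$$
Here $N$ is the number of GPUs (world size) and $g_i$ is the gradient computed on GPU $i$'s mini-batch. With a mean-reduced loss and equal per-GPU batch sizes, the result is equivalent to computing the gradient on a batch $N$ times larger, split evenly across GPUs.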
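Here is a single-process simulation of the same idea that does not use torch.distributed at all. Four copies of a toy model each compute a gradient on their shard. A manual "all-reduce" averages those gradients and writes the result back to every copy, and each copy then takes an SGD step. The copies stay identical and match one model trained on the whole batch, given a mean-reduced loss and equal shard sizes.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
world_size = 4                                   # simulated number of GPUs / processes
base = nn.Linear(6, 1)
replicas = [copy.deepcopy(base) for _ in range(world_size)]   # replicas start identical
optimizers = [torch.optim.SGD(r.parameters(), lr=0.1) for r in replicas]
criterion = nn.MSELoss()
X, y = torch.randn(32, 6), torch.randn(32, 1)

# 1. Each replica computes its local gradient g_i on its own equal-sized shard
for replica, xs, ys in zip(replicas, X.chunk(world_size), y.chunk(world_size)):
    criterion(replica(xs), ys).backward()

# 2. Simulated all-reduce: average g_i across replicas and give every replica the result
for params in zip(*(r.parameters() for r in replicas)):
    avg = torch.stack([p.grad for p in params]).mean(dim=0)
    for p in params:
        p.grad = avg.clone()

# 3. Identical gradients → identical optimizer steps → replicas stay in sync
for opt in optimizers:
    opt.step()

# Reference: one model, one step on the full batch of 32
reference = copy.deepcopy(base)
ref_opt = torch.optim.SGD(reference.parameters(), lr=0.1)
criterion(reference(X), y).backward()
ref_opt.step()

first = list(replicas[0].parameters())
in_sync = all(
    torch.equal(a, b) for r in replicas[1:] for a, b in zip(first, r.parameters())
)
matches = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(first, reference.parameters()))
print(in_sync, matches)   # True True
```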
```python
# DDP communication is overlapped with backward computation:
# as soon as a bucket of gradients is ready during the backward pass,
# the all-reduce for that bucket starts while the rest of backward continues.
# This hides much of the communication latency.

# When to use each:
# - Single GPU: plain model
# - Multiple GPUs, single machine: DDP (preferred) or DP
# - Multiple GPUs, multiple machines: DDP only (DP doesn't support multi-node)
# - Model too large for one GPU: model parallelism (beyond this lesson)
```
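DDP groups parameters into buckets, about 25 MB each by default, set by the `bucket_cap_mb` argument of `DistributedDataParallel`. Parameters go into buckets in roughly the reverse order of `model.parameters()`, because backward produces the last layer's gradients first. The pure-Python sketch below mimics only that assignment step. `assign_buckets` is a hypothetical helper and the parameter sizes are made up. The real logic lives in PyTorch's C++ reducer and handles details this sketch ignores.

```python
def assign_buckets(param_sizes_mb, cap_mb=25.0):
    """Hypothetical sketch: greedily pack gradients into buckets, last layer first.

    param_sizes_mb: gradient sizes in MB, in model.parameters() order.
    Returns a list of buckets, each a list of parameter indices.
    """
    buckets, current, current_size = [], [], 0.0
    # Walk parameters in reverse: backward produces the last layer's gradients first
    for idx in reversed(range(len(param_sizes_mb))):
        size = param_sizes_mb[idx]
        if current and current_size + size > cap_mb:
            buckets.append(current)  # bucket full: its all-reduce could start now
            current, current_size = [], 0.0
        current.append(idx)
        current_size += size
    if current:
        buckets.append(current)
    return buckets

# Made-up gradient sizes (MB) for six parameters, first layer to last
sizes = [10.0, 0.1, 20.0, 0.1, 8.0, 0.05]
for i, bucket in enumerate(assign_buckets(sizes)):
    print(f"bucket {i}: params {bucket}")
```

The first bucket to fill holds the last layers. Its all-reduce can start while backward is still computing gradients for earlier layers. In this sketch, a single parameter larger than the cap simply gets a bucket of its own.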
Reproducibility: Seed Everything
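ML results must be reproducible, and a PyTorch training run draws on several independent sources of randomness. These include Python's random module, NumPy, the PyTorch CPU generator, one generator per CUDA device, and GPU kernels that are nondeterministic by design. Seed the generators in one place and, when you need it, switch cuDNN to deterministic kernels.
```python
import os
import random

import numpy as np
import torch

def set_seed(seed: int = 42):
    """Seed every random number generator a typical training run touches."""
    random.seed(seed)                 # Python random
    np.random.seed(seed)              # NumPy
    torch.manual_seed(seed)           # PyTorch: CPU and all CUDA devices
    torch.cuda.manual_seed_all(seed)  # explicit CUDA seeding (redundant but harmless)
    # Only affects Python processes started after this line; to fix string
    # hashing in the current process, set PYTHONHASHSEED before launching Python.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # cuDNN: pick deterministic kernels and disable autotuning (may slow training).
    # For stricter guarantees, torch.use_deterministic_algorithms(True) makes
    # PyTorch raise an error on ops that have no deterministic implementation.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)
```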
:::note Determinism trade-off
Setting cudnn.deterministic = True disables some CUDA optimizations and can make training 10–30% slower. For most experiments, setting the manual seed is sufficient. Use full determinism only when debugging non-reproducible results.
:::
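`set_seed` covers the main process. `DataLoader` shuffling and worker processes, however, draw from their own random state. The sketch below follows the pattern in PyTorch's reproducibility notes: a seeded `torch.Generator` fixes the shuffle order, and a `worker_init_fn` reseeds NumPy and Python `random` inside each worker. `make_loader` and the toy dataset are hypothetical names for this sketch.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Each worker's torch seed is derived from the loader's base seed; reuse it
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(dataset, seed=42, num_workers=0):
    """Hypothetical helper: a DataLoader whose shuffle order depends only on `seed`."""
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,  # only called when num_workers > 0
        generator=g,                 # drives the shuffle order and the workers' base seed
    )

dataset = TensorDataset(torch.arange(20))
order_a = [int(v) for (batch,) in make_loader(dataset) for v in batch]
order_b = [int(v) for (batch,) in make_loader(dataset) for v in batch]
print(order_a == order_b)  # same seed, same shuffle order
```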
Common Mistakes
:::danger Forgetting zero_grad() between batches
```python
# WRONG: gradients accumulate across batches
for batch_x, batch_y in dataloader:
    output = model(batch_x)
    loss = criterion(output, batch_y)
    loss.backward()
    optimizer.step()  # uses gradients summed over ALL previous batches!

# CORRECT:
for batch_x, batch_y in dataloader:
    optimizer.zero_grad()  # reset before each batch
    output = model(batch_x)
    loss = criterion(output, batch_y)
    loss.backward()
    optimizer.step()
```
:::
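Here is a quick way to see the accumulation. Call `backward()` twice on the same computation without zeroing in between, and `.grad` holds the sum of both gradients. The single-weight example is arbitrary.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

def loss_fn():
    return w ** 2          # d(loss)/dw = 2 * w = 4.0

loss_fn().backward()
print(w.grad)              # tensor(4.)

loss_fn().backward()       # no zeroing in between
print(w.grad)              # tensor(8.): the two gradients were added

w.grad = None              # what optimizer.zero_grad() does by default in PyTorch 2.x
loss_fn().backward()
print(w.grad)              # tensor(4.): a fresh gradient again
```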
:::danger Storing loss tensor instead of loss.item()
```python
# WRONG: every stored loss keeps its own computation graph alive
losses = []
for batch_x, batch_y in dataloader:
    loss = criterion(model(batch_x), batch_y)
    losses.append(loss)  # graphs (and their saved activations) pile up → OOM

# CORRECT:
losses = []
for batch_x, batch_y in dataloader:
    loss = criterion(model(batch_x), batch_y)
    losses.append(loss.item())  # plain Python float, detached from the graph
```
:::
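The difference is visible on the objects themselves. A stored loss tensor still carries a `grad_fn` that references the backward graph, while `.item()` returns a plain Python float. If you need a tensor (for example, to average losses on the GPU), `.detach()` is the graph-free alternative. The model and data below are arbitrary stand-ins.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 10), torch.randn(8, 1)

loss = criterion(model(x), y)
print(loss.grad_fn is not None)    # True: the graph is still attached
print(type(loss.item()))           # <class 'float'>: nothing left to keep alive
print(loss.detach().grad_fn)       # None: a tensor with no graph
```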
:::warning Calling backward() on a non-scalar without specifying gradient
```python
import torch

x = torch.randn(3, requires_grad=True)
y = x ** 2  # shape (3,) - not a scalar
# y.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs

# Option 1: reduce to a scalar first
y.sum().backward()

# Option 2: provide the upstream gradient (vector-Jacobian product)
x.grad = None               # clear the gradient left by Option 1
y = x ** 2                  # rebuild the graph: the first backward() freed it
y.backward(torch.ones(3))   # equivalent to y.sum().backward()
```
:::
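The argument to `backward()` is the vector $v$ in the vector-Jacobian product $v^\top J$. For `y = x ** 2` the Jacobian is diagonal with entries $2x_i$, so `x.grad` becomes `v * 2 * x`. The specific values below are arbitrary.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2                         # Jacobian dy/dx = diag(2 * x)

v = torch.tensor([1.0, 0.1, 0.0])  # one weight per output element
y.backward(v)                      # computes v^T J without materializing J
print(x.grad)                      # values 2.0, 0.4, 0.0
```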
:::warning view() on non-contiguous tensor
```python
import torch

x = torch.randn(4, 6)
x_t = x.t()  # transpose: shape (6, 4), non-contiguous
# x_t.view(24)  # RuntimeError: view size is not compatible with input tensor's size and stride
# Fix 1: make contiguous first (copies the data)
x_t.contiguous().view(24)
# Fix 2: use reshape (returns a view when possible, copies otherwise)
x_t.reshape(24)
```
:::
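Strides explain why `view` fails here. A stride is the number of elements to skip in memory to move one step along a dimension. After `t()` the strides are swapped, so walking the tensor in row-major order no longer visits memory sequentially. As a result, no single stride can describe the flattened shape.

```python
import torch

x = torch.randn(4, 6)
print(x.stride(), x.is_contiguous())      # (6, 1) True

x_t = x.t()                               # shape (6, 4), same memory
print(x_t.stride(), x_t.is_contiguous())  # (1, 6) False

y = x_t.contiguous()                      # copies into row-major order
print(y.stride(), y.is_contiguous())      # (4, 1) True
```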
YouTube Resources
| Video | Channel | What it covers |
|---|---|---|
| PyTorch Tutorial | Patrick Loeber | Complete PyTorch from scratch |
| Autograd Explained | Elliot Waite | PyTorch autograd internals |
| PyTorch 2.0 compile | PyTorch | torch.compile walkthrough |
| Mixed Precision Training | Weights & Biases | AMP training guide |
Interview Q&A
Q1: How does PyTorch's autograd system work? What is the computation graph?
PyTorch uses dynamic computation graphs: the graph is built on the fly as operations execute, one Python line at a time. Each tensor operation that involves a requires_grad=True tensor adds a node to a directed acyclic graph (DAG) and attaches a grad_fn to the output tensor, the backward function for that operation. Leaf tensors (created directly by the user with requires_grad=True, such as model parameters) are where the graph starts and where gradients end up: backward() stores their gradients in .grad. When you call .backward() on a scalar loss, PyTorch traverses the graph from the loss back to the leaves in reverse topological order, applying the chain rule at each node and accumulating gradients. The graph is rebuilt fresh on every forward pass; it is not a persistent static structure. This is what allows arbitrary Python control flow (if statements, loops, recursion) inside forward(): the graph records whichever code path actually executed, not a compiled representation of all possible paths.
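Here is a small sketch of define-by-run. The hypothetical function `piecewise` takes a different Python branch depending on the input's value. The recorded `grad_fn` then reflects whichever branch actually ran, and so does the gradient.

```python
import torch

def piecewise(x):
    # Ordinary Python control flow on tensor data
    if x.item() > 0:
        return x ** 2    # recorded as a power node
    return -3 * x        # recorded as a multiplication node

results = {}
for value in (2.0, -2.0):
    x = torch.tensor(value, requires_grad=True)  # leaf tensor
    y = piecewise(x)                             # graph built for this branch only
    y.backward()
    results[value] = (type(y.grad_fn).__name__, x.grad.item())
    print(value, results[value])
```

For `2.0` the graph contains a power node and `x.grad` is `4.0` (the derivative $2x$). For `-2.0` it contains a multiplication node and `x.grad` is `-3.0`.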
Q2: What is the difference between view and reshape? When does view fail?
view creates a new tensor that shares the same underlying storage as the original; no data is copied. It only works when the requested shape is compatible with the tensor's existing strides. In practice this usually means the tensor must be contiguous, with its elements stored in one unbroken block in the order implied by the shape. reshape is more flexible: it returns a view when possible (same as view) but falls back to copying when the strides don't allow a view. Non-contiguous tensors arise after operations that change the logical ordering or spacing without moving data: transpose (x.t()), permute, expand, and, depending on the dimension, narrow and select. The practical rule: use view when you know the tensor is contiguous and want a guarantee that no copy happens. Use reshape when you are not sure or do not care. In a training loop the main difference is performance: reshape can silently introduce copies for non-contiguous tensors, and PyTorch won't warn you. The difference is also semantic: an in-place write to a view changes the original tensor, while a write to a copy does not.
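The semantic side of view versus copy is easy to observe. `data_ptr()` (the memory address of a tensor's first element) shows whether storage is shared, and an in-place write shows whether the original sees the change.

```python
import torch

x = torch.arange(6.0).reshape(2, 3)    # contiguous

v = x.view(6)                          # a view: shares storage with x
print(v.data_ptr() == x.data_ptr())    # True
v[0] = 100.0
print(x[0, 0].item())                  # 100.0: the write is visible through x

x_t = x.t()                            # non-contiguous
r = x_t.reshape(6)                     # no view is possible here, so reshape copies
print(r.data_ptr() == x.data_ptr())    # False
r[0] = -1.0
print(x[0, 0].item())                  # still 100.0: the copy is independent
```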
Q3: Why must you call optimizer.zero_grad() before each backward pass?
PyTorch accumulates gradients by adding to .grad on each backward() call rather than replacing it. This is intentional: it enables gradient accumulation (running several forward/backward passes before a single optimizer.step()) to simulate batch sizes larger than fit in memory. In a standard training loop, though, you want fresh gradients for each batch. Forgetting zero_grad() mixes in gradients from every previous batch: after $k$ batches without zeroing, .grad holds $\sum_{j=1}^{k} \nabla_\theta L_j$ instead of $\nabla_\theta L_k$. Each update then includes stale directions from earlier batches, and the summed gradient typically keeps growing, so training becomes unstable and the loss curves look erratic. The canonical order: zero_grad() → forward() → loss.backward() → optimizer.step().
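The same accumulation behavior is what makes intentional gradient accumulation work. Below is a minimal sketch that assumes a mean-reduced loss. Divide each micro-batch loss by the number of micro-batches (`accum_steps`, a hypothetical name) and step only every `accum_steps` batches. The final check confirms that two half-size micro-batches produce the same gradient as one full batch. The model and data are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randn(16, 1)

# Reference: gradient of the full 16-sample batch
optimizer.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: two micro-batches of 8, each loss scaled by 1/accum_steps
accum_steps = 2
optimizer.zero_grad()
for step, (xb, yb) in enumerate(zip(x.chunk(accum_steps), y.chunk(accum_steps))):
    loss = criterion(model(xb), yb) / accum_steps
    loss.backward()                       # adds into .grad
    if (step + 1) % accum_steps == 0:
        accum_grad = model.weight.grad.clone()
        optimizer.step()                  # one update for the whole effective batch
        optimizer.zero_grad()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```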
Q4: What is gradient checkpointing and when would you use it?
Gradient checkpointing (implemented in torch.utils.checkpoint.checkpoint) is a memory-compute trade-off. Normally the forward pass stores every intermediate activation needed for backprop. With checkpointing, only the activations at selected checkpoints are kept, and the rest are recomputed on demand during the backward pass. For a model with $n$ layers, normal backprop stores $O(n)$ activations. Placing a checkpoint every $\sqrt{n}$ layers reduces memory to $O(\sqrt{n})$, at the cost of roughly one extra forward pass per training step (about 30-40% slower training). Use it when you are memory-constrained. Common scenarios are training large transformers (BERT, GPT-style) on limited GPU memory, or fitting batch sizes that wouldn't otherwise fit. The trade-off is almost always worth it when the alternative is shrinking the batch size until training becomes unstable or very slow.
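In code, checkpointing wraps a chunk of the forward pass. The sketch below applies `torch.utils.checkpoint.checkpoint` to each block of a small hypothetical `CheckpointedMLP`, then checks that the gradients match an ordinary forward pass. Only the memory/compute profile changes, not the result. The example passes `use_reentrant=False`, the variant PyTorch's documentation recommends; recent versions warn if the argument is omitted.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedMLP(nn.Module):
    """Hypothetical model: a stack of blocks, optionally checkpointed."""

    def __init__(self, width=64, depth=4, use_checkpoint=True):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Activations inside `block` are not kept; they are
                # recomputed from `x` during the backward pass.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
model = CheckpointedMLP()
x = torch.randn(8, 64)

model.use_checkpoint = True
model(x).sum().backward()
grads_ckpt = [p.grad.clone() for p in model.parameters()]

model.zero_grad()
model.use_checkpoint = False
model(x).sum().backward()
grads_plain = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_ckpt, grads_plain)))
```

Checkpointing per block keeps the sketch simple. Checkpointing every few blocks instead is what yields the $O(\sqrt{n})$ memory figure above.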
Q5: When should you use torch.no_grad() vs torch.inference_mode(), and what is the difference?
Both context managers disable gradient computation and stop building the computation graph. The difference: inference_mode() (PyTorch 1.9+) is strictly stronger. Tensors created inside it are marked as inference tensors. Even after you leave the context, they cannot be used in operations that autograd records (for example, anything that must save them for backward) or modified in place. no_grad() merely skips building the graph; tensors created inside it are ordinary tensors that can later flow into gradient-tracked computation. Because of these stronger guarantees, inference_mode() can skip more bookkeeping (view tracking and version-counter updates), which makes it faster. Practical rule: use inference_mode() for production inference and benchmark measurements. Use no_grad() for the validation loop inside training, where outputs might later be reused in gradient-tracked code (e.g., for visualization or debugging after the block), so you never have to track which tensors were created inside it.
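The difference shows up when a tensor crosses back into autograd. In the sketch below, a tensor produced under `no_grad()` can later be used in a graph that needs to save it for backward. A tensor produced under `inference_mode()` is an inference tensor, and autograd rejects it as soon as it has to save it. The weight `w` and the values are arbitrary.

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.no_grad():
    a = torch.randn(3) * 2        # ordinary tensor, just no graph recorded
with torch.inference_mode():
    b = torch.randn(3) * 2        # inference tensor

print(a.requires_grad, a.is_inference())  # False False
print(b.requires_grad, b.is_inference())  # False True

(w * a).sum().backward()          # fine: `a` can be saved for backward
print(w.grad is not None)         # True

try:
    (w * b).sum().backward()      # `w * b` must save `b` to compute w's gradient
    inference_tensor_rejected = False
except RuntimeError as err:
    inference_tensor_rejected = True
    print("RuntimeError:", err)
```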
Q6: How does mixed precision training work, and what problem does GradScaler solve?
Mixed precision training runs most of the forward pass in FP16 (or BF16) to reduce memory and increase throughput on Tensor Cores. It keeps the master copy of the weights, and therefore the gradients applied to them, in FP32. The problem: FP16 has a much smaller dynamic range than FP32 (max ~65504 vs ~$3.4 \times 10^{38}$), and gradients in deep networks are often tiny. Anything below FP16's smallest subnormal (~$6 \times 10^{-8}$) underflows to zero, and the affected weights stop learning. GradScaler fixes this by multiplying the loss by a large constant (the "scale factor", 65536 by default) before backprop. By the chain rule, every gradient is scaled by the same factor, which lifts small values out of the underflow range. Before the optimizer step, the gradients are divided by the scale factor to restore their true magnitude. If any gradient overflows (becomes Inf or NaN), the optimizer step is skipped and the scale factor is halved. After a run of steps without overflow (2000 by default), the scale factor is doubled. Over time, GradScaler keeps the scale near the largest value your model tolerates. BF16 avoids the problem entirely because it has the same 8-bit exponent, and hence the same dynamic range, as FP32, which makes GradScaler unnecessary. This is why BF16 is preferred on GPUs that support it, such as A100/H100.
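Two things are worth seeing concretely. The first is the underflow itself: a value like $10^{-8}$ rounds to zero in FP16, survives once multiplied by the default scale of 65536, and survives as-is in BF16. The second is the usual shape of an AMP training step. The step below is a sketch that also runs on CPU, using autocast with BF16 and a disabled scaler there. It switches to FP16 with an active `GradScaler` only when CUDA is available. The tiny model and data are placeholders. Newer PyTorch releases also spell the scaler `torch.amp.GradScaler("cuda")`.

```python
import torch
import torch.nn as nn

# 1) Underflow, and how scaling avoids it
tiny_grad = 1e-8
print(torch.tensor(tiny_grad, dtype=torch.float16).item())          # 0.0: underflow
print(torch.tensor(tiny_grad * 65536, dtype=torch.float16).item())  # nonzero after scaling
print(torch.tensor(tiny_grad, dtype=torch.bfloat16).item())         # nonzero: FP32-like range

# 2) One AMP training step
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)  # only needed for FP16
x = torch.randn(32, 16, device=device)
y = torch.randn(32, 1, device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=amp_dtype):
    out = model(x)                               # matmul runs in low precision
loss = nn.functional.mse_loss(out.float(), y)    # compute the loss in FP32
scaler.scale(loss).backward()   # multiply by the scale factor, then backprop
scaler.step(optimizer)          # unscale grads; skip the step if any are Inf/NaN
scaler.update()                 # adjust the scale factor for the next step
print(loss.item(), scaler.get_scale())
```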
Q7: What is the difference between DataParallel and DistributedDataParallel, and why is DDP preferred?
DataParallel is a single-process, multi-thread approach. On every forward pass it replicates the model from GPU 0 to each GPU, scatters the input batch across GPUs, and runs the replicas in parallel threads. It then gathers all outputs on GPU 0 to compute the loss. During backward, each replica computes gradients on its own GPU. Those gradients are then all reduced onto GPU 0, which holds the only real copy of the parameters and runs the optimizer step. The problem: GPU 0 is a bottleneck. It receives all outputs, computes the loss, collects every gradient, and re-broadcasts the weights on the next iteration. On top of that, the threads contend for Python's GIL. In practice, GPU 0 is significantly more loaded (and uses more memory) than the others, wasting much of the parallelism.
DistributedDataParallel uses one process per GPU. Each process has its own model replica, its own optimizer, and its own data loader (typically with a DistributedSampler so each process sees a different shard). After each backward pass, an all-reduce operation averages the gradients across all processes. Each process then updates its own copy of the weights; because the gradients are identical after all-reduce, all copies stay synchronized. DDP is preferred for three reasons: (1) there is no GPU 0 bottleneck, since all GPUs are symmetric; (2) gradient communication is overlapped with the backward computation (as soon as a gradient bucket is ready, its all-reduce starts); (3) it works across multiple machines, while DataParallel is single-machine only. The cost: DDP requires torchrun (or the older, deprecated torch.distributed.launch) to spawn multiple processes, which is slightly more complex to set up.
:::tip 🎮 Interactive Playground
Visualize this concept: Try the Neural Network Forward Pass demo on the EngineersOfAI Playground - no code required.
:::
