
PyTorch Training Loop

It is 2022. An ML engineer at Meta is building an image classification model to help flag policy-violating content on Instagram before it reaches users. The model is a ResNet-50 fine-tuned on 40 million labeled images. She submits the first training run on a single A100 GPU and estimates it will finish in 72 hours - three full days. The accuracy plots look reasonable in the early epochs, so she leaves it running overnight.

The next morning: the loss is NaN. A single corrupted batch triggered gradient explosion because there was no gradient clipping. The run is dead.

She restarts with clip_grad_norm_. Training is now stable but still takes 3 days. Her manager asks why they are not using automatic mixed precision (AMP). She adds AMP, and the run now finishes in 18 hours. AMP also lets her double the batch size, yet 40% of GPU memory is still unused. She adds gradient accumulation to double the effective batch size again without using more memory, and the loss now converges 30% faster because the larger effective batch gives a less noisy gradient estimate.

Final optimization: the DataLoader had num_workers=0. A quick profile shows 40% GPU idle time between batches. She adds num_workers=8 and pin_memory=True. Training time drops to 4 hours.

Same model. Same GPU. Same data. Four hours instead of seventy-two, purely from training loop engineering.

You join an ML team that has a strong model architecture but a flaky training script. Some runs converge, some diverge, and occasionally the loss becomes NaN mid-way. The loss curve looks different every time despite the same seed. After reading the script, you find three bugs: zero_grad() is called after backward() instead of before, the validation loop runs with model.train() still active (Dropout is on), and gradients are not clipped - a single bad batch occasionally causes gradient explosion.

The training loop is not boilerplate. Every line of it matters.

The Correct Order of Operations

Before any code, the order is the most important thing to memorize:

zero_grad → forward → loss → backward → clip_grads → step

Each step has a specific reason:

  1. zero_grad() - PyTorch accumulates gradients by default. Clear them first.
  2. Forward pass - compute predictions
  3. Loss computation - scalar that measures error
  4. backward() - compute gradients via autograd (fills .grad on every parameter)
  5. Gradient clipping - cap gradient norms before they explode
  6. optimizer.step() - update parameters using the computed gradients
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import time

def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()  # enable Dropout, use batch stats in BatchNorm
    total_loss = 0.0
    n_batches = 0

    for batch_x, batch_y in dataloader:
        # 1. Move data to device
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        # 2. Zero gradients BEFORE the forward pass
        optimizer.zero_grad()

        # 3. Forward pass
        logits = model(batch_x)

        # 4. Compute loss
        loss = loss_fn(logits, batch_y)

        # 5. Backward pass - computes gradients
        loss.backward()

        # 6. (Optional but recommended) Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # 7. Optimizer step - update parameters
        optimizer.step()

        total_loss += loss.item()  # .item() returns a Python float; no graph is kept
        n_batches += 1

    return total_loss / n_batches


@torch.no_grad()
def evaluate(model, dataloader, loss_fn, device):
    model.eval()  # disable Dropout, use running stats in BatchNorm
    total_loss = 0.0
    all_preds = []
    all_targets = []

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)

        total_loss += loss.item()
        preds = logits.argmax(dim=1)  # for classification
        all_preds.append(preds.cpu())
        all_targets.append(batch_y.cpu())

    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)
    acc = (all_preds == all_targets).float().mean().item()

    return total_loss / len(dataloader), acc
```
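
Step 1 of the list says PyTorch accumulates gradients by default. The minimal sketch below (a toy one-parameter model, chosen only for illustration) shows what that means: a second backward() without zeroing adds the new gradient to the old one in .grad, and zero_grad() resets it.

```python
import torch

w = torch.nn.Parameter(torch.tensor(2.0))
opt = torch.optim.SGD([w], lr=0.1)

def toy_loss():
    return (3.0 * w) ** 2  # d/dw (3w)^2 = 18w = 36 at w = 2

toy_loss().backward()
first = w.grad.item()          # 36.0

toy_loss().backward()          # no zero_grad(): the new gradient is ADDED
accumulated = w.grad.item()    # 72.0

opt.zero_grad()                # clears .grad (sets it to None by default in PyTorch 2.x)
toy_loss().backward()
fresh = w.grad.item()          # 36.0 again

print(first, accumulated, fresh)  # 36.0 72.0 36.0
```

This same accumulation behavior is what gradient accumulation deliberately exploits later in this section.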

Learning Rate Scheduling

The learning rate is the most sensitive hyperparameter in training. Too high and the loss oscillates or explodes. Too low and training is painfully slow and may get stuck. The insight from years of practice: the optimal learning rate changes over the course of training. You need a high LR early to move quickly toward good regions of the loss landscape, and a low LR late to settle precisely into a minimum.

Learning rate schedulers automate this.

Cosine Annealing

The most widely used schedule for image models and LLMs. The learning rate follows a cosine curve from $\eta_{\max}$ down to $\eta_{\min}$:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t \pi}{T}\right)\right)$$

where $t$ is the current step and $T$ is the total number of steps. At $t=0$, $\eta_t = \eta_{\max}$. At $t=T$, $\eta_t = \eta_{\min}$.

```python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs,  # scheduler steps (here: epochs) to decay from the initial LR to eta_min
    eta_min=1e-6,    # minimum LR at the end
)

for epoch in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    scheduler.step()  # once per epoch, AFTER the epoch's optimizer.step() calls
    print(f"LR: {scheduler.get_last_lr()}")
```
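
As a sanity check on the formula, this sketch steps `CosineAnnealingLR` on a dummy one-parameter optimizer (an illustration-only setup with $\eta_{\max}=0.1$, $\eta_{\min}=0.001$, $T=10$) and compares each LR with the closed-form expression above.

```python
import math
import torch

def cosine_lr(t, T, eta_max, eta_min):
    """Closed form: eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(t * pi / T))."""
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(t * math.pi / T))

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
T = 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T, eta_min=0.001)

pytorch_lrs, formula_lrs = [], []
for t in range(T + 1):
    pytorch_lrs.append(optimizer.param_groups[0]["lr"])
    formula_lrs.append(cosine_lr(t, T, eta_max=0.1, eta_min=0.001))
    optimizer.step()   # no gradients here; called only to keep the correct order
    scheduler.step()

for t, (a, b) in enumerate(zip(pytorch_lrs, formula_lrs)):
    print(f"t={t:2d}  scheduler={a:.5f}  formula={b:.5f}")
```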

OneCycleLR - Fast Training

Based on Leslie Smith's "1cycle" policy, OneCycleLR uses a single cycle: ramp up from a low LR to a peak, then decay aggressively. Smith reported that this can reach strong accuracy in far fewer epochs than conventional schedules ("super-convergence"), and fast.ai popularized it.

```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Note: OneCycleLR overrides the lr above; it starts at max_lr / div_factor (default 25)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    steps_per_epoch=len(train_loader),
    epochs=n_epochs,
    pct_start=0.3,          # first 30% of steps: warmup phase
    anneal_strategy='cos',  # cosine annealing in both the warmup and decay phases
)

for epoch in range(n_epochs):
    model.train()
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # OneCycleLR: called EVERY BATCH, not every epoch
```
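
To see the shape of the cycle, this sketch records the per-batch LR on a dummy optimizer. The 10 epochs × 10 steps per epoch are arbitrary illustration values; the start value is `max_lr / 25`, the peak lands at 30% of the 100 total steps, and the LR then anneals to a tiny final value.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.01, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, steps_per_epoch=10, epochs=10,  # 100 total steps
    pct_start=0.3, anneal_strategy="cos",
)

lrs = []
for _ in range(100):  # exactly epochs * steps_per_epoch scheduler steps
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()

peak_step = lrs.index(max(lrs))
print(f"start={lrs[0]:.4f}  peak={max(lrs):.4f} at step {peak_step}  end={lrs[-1]:.2e}")
```

Stepping OneCycleLR more times than `epochs * steps_per_epoch` raises an error, which is one reason its step count has to match the loop exactly.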

ReduceLROnPlateau - Adaptive

Monitors a metric (usually validation loss). If the metric has not improved for more than `patience` epochs, the LR is multiplied by `factor`. Simple and effective when you do not know the right schedule in advance.

```python
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',    # 'min' for loss, 'max' for accuracy
    factor=0.5,    # multiply LR by 0.5 when a plateau is detected
    patience=5,    # tolerate 5 non-improving epochs; reduce on the next one
    min_lr=1e-6,
)

for epoch in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)
    scheduler.step(val_loss)  # pass the metric to monitor
    # verbose=True is deprecated in recent PyTorch; log the LR explicitly instead
    print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")
```
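
The exact trigger point is easy to get wrong, so here is a small simulation with a dummy optimizer and made-up validation losses (illustrative values only). With `patience=2`, the loss stops improving after epoch 1, and the LR is halved on the third non-improving epoch.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]  # improvement stops after epoch 1
lrs = []
for epoch, val_loss in enumerate(val_losses):
    optimizer.step()
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])
    print(f"epoch {epoch}: val_loss={val_loss}  lr={lrs[-1]}")
```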

CosineAnnealingWarmRestarts - Cyclic with Restarts

Periodically resets the LR back to $\eta_{\max}$, which can help the model escape poor local minima. Each cycle can be longer than the previous one (set `T_mult > 1`).

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{T_{cur}}{T_i} \pi\right)\right)$$

where $T_{cur}$ is the number of epochs since the last restart and $T_i$ is the length of the current cycle.

```python
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,     # first cycle lasts 10 epochs
    T_mult=2,   # each cycle is twice as long: 10, 20, 40 epochs...
    eta_min=1e-6,
)
```
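
With `T_0=10` and `T_mult=2`, the cycles last 10, 20, and 40 epochs, so restarts should happen at epochs 10, 30, and 70. This sketch (dummy optimizer, one `scheduler.step()` per epoch) finds the epochs where the LR jumps back up.

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

lrs = []
for epoch in range(71):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()  # one call per epoch

# Within a cycle the LR only decreases, so any increase marks a restart
restarts = [e for e in range(1, len(lrs)) if lrs[e] > lrs[e - 1]]
print(restarts)
```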

Warmup + Cosine Decay (Standard for LLM Training)

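The standard recipe for transformer and LLM training is a linear warmup phase followed by a decay phase. Warmup avoids instability in the first steps, when the freshly initialized model can produce large gradients and Adam's moment estimates are based on very few samples. GPT and LLaMA use linear warmup followed by cosine decay; BERT uses linear warmup followed by linear decay. The variant below uses cosine decay.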

```python
import math
import torch

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1):
    """
    Linear warmup followed by cosine decay.
    A common schedule for transformer and LLM training.
    Call scheduler.step() once per optimizer update.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # Linear warmup: 0 → 1 over warmup_steps
            return float(current_step) / float(max(1, warmup_steps))

        # Cosine decay from 1 → min_lr_ratio after warmup
        progress = float(current_step - warmup_steps) / float(
            max(1, total_steps - warmup_steps)
        )
        progress = min(1.0, progress)  # hold at min_lr_ratio past total_steps
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Scale to [min_lr_ratio, 1.0]
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# Usage: 2000 warmup steps, 100K total training steps
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    warmup_steps=2000,
    total_steps=100_000,
    min_lr_ratio=0.1,  # decay to 10% of peak LR
)
```
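
Written out, the multiplier that `lr_lambda` applies to the peak LR at step $s$, with $s_w$ warmup steps, $S$ total steps and $r$ = `min_lr_ratio`, is:

$$
\lambda(s) =
\begin{cases}
\dfrac{s}{s_w}, & s < s_w \\[6pt]
r + (1 - r)\cdot\dfrac{1}{2}\left(1 + \cos\left(\pi\,\dfrac{s - s_w}{S - s_w}\right)\right), & s_w \le s \le S
\end{cases}
$$

With the usage values above ($s_w = 2000$, $S = 100{,}000$, $r = 0.1$): at $s = 1000$, $\lambda = 0.5$; at $s = 2000$, $\lambda = 1$; at $s = 51{,}000$ the decay is half done, so $\lambda = 0.1 + 0.9 \cdot 0.5 = 0.55$; and at $s = 100{,}000$, $\lambda = 0.1$, i.e. 10% of the peak LR.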

When to Use Which Scheduler

| Scheduler | Best For | Key Characteristic |
|---|---|---|
| CosineAnnealingLR | Image classification, general training | Smooth decay, well understood |
| OneCycleLR | Fast convergence, limited compute budget | Super-convergence in few epochs |
| ReduceLROnPlateau | Unknown schedule, early experiments | Adaptive, but needs a validation set |
| CosineAnnealingWarmRestarts | Escaping poor local minima, snapshot ensembles | Periodic resets, exploratory |
| Warmup + cosine | LLMs, transformers, any large model | Industry standard at scale |

Gradient Clipping

Gradient clipping prevents gradient explosion - when a single bad batch sends gradients to extremely large values, causing a catastrophic parameter update that destroys training progress. It is especially critical for:

  • RNNs and LSTMs: gradients that flow back through many time steps can compound multiplicatively
  • Transformers: the attention mechanism can produce high-variance gradients early in training
  • Deep networks with skip connections: even with residuals, certain configurations are prone to large gradient spikes

The standard approach is L2 norm clipping: if the global gradient norm exceeds max_norm, all gradients are scaled down proportionally so the norm equals max_norm.

$$\text{if} \quad \|\mathbf{g}\|_2 > g_{\max}: \quad \mathbf{g} \leftarrow \frac{g_{\max}}{\|\mathbf{g}\|_2}\,\mathbf{g}$$

```python
# Basic gradient clipping
loss.backward()
# clip_grad_norm_ returns the total norm computed BEFORE clipping
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Monitor the gradient norm to inform your max_norm choice
def get_grad_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5

# Log it every N steps (after backward(), before clipping) to understand your gradient dynamics
grad_norm = get_grad_norm(model)
# If grad_norm is consistently below max_norm, clipping is not active
# If grad_norm frequently exceeds max_norm, training may be unstable
```
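
To see the clipping formula in action, this sketch hand-sets two hypothetical parameter gradients that together form the vector (3, 4), whose global L2 norm is 5. The norm is computed over all parameters jointly, and every gradient is scaled by the same factor.

```python
import torch

# Two parameters whose gradients together form the vector (3, 4): global L2 norm = 5
a = torch.nn.Parameter(torch.zeros(1))
b = torch.nn.Parameter(torch.zeros(1))
a.grad = torch.tensor([3.0])
b.grad = torch.tensor([4.0])

total_norm = torch.nn.utils.clip_grad_norm_([a, b], max_norm=1.0)

print(total_norm.item())             # 5.0: the norm BEFORE clipping
print(a.grad.item(), b.grad.item())  # about 0.6 and 0.8: both scaled by ~1.0 / 5.0
```

PyTorch adds a tiny epsilon to the denominator, so the clipped values land very slightly below 0.6 and 0.8; the direction of the gradient is unchanged.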

Choosing max_norm

  • 1.0: the most common default (for example, the default `max_grad_norm` of the Hugging Face Trainer)
  • 0.5: more aggressive clipping, useful for very deep networks or unstable early training
  • 5.0: gentler clipping, fine when gradient norms are occasionally large but training is stable

:::tip Monitoring Gradient Norms
Log the raw gradient norm before clipping at every step for the first few thousand steps. If you see norms consistently at 5–10x your max_norm, your learning rate is likely too high rather than needing more aggressive clipping. The root cause is usually the LR, not the gradients themselves.
:::

Gradient Accumulation for Large Effective Batch Sizes

GPU memory limits the batch size. On a 40GB A100, a ResNet-50 on 224×224 images fits roughly 256 images per batch. Large-batch training recipes, however, often use effective batch sizes of 1024, 2048, or even 8192, and buying 4–32x more GPUs is expensive. Gradient accumulation reaches the same effective batch size without extra memory, at the cost of processing the micro-batches sequentially.

The idea: perform N forward/backward passes without calling optimizer.step(), letting gradients accumulate across N micro-batches. After N passes, the accumulated gradient equals the gradient of a single batch of size N * batch_size, provided the loss is a mean over samples, it is divided by N, the micro-batches are the same size, and the model has no batch-dependent layers such as BatchNorm.

```python
ACCUMULATION_STEPS = 4  # effective batch = batch_size * 4

model.train()
optimizer.zero_grad()

for step, (batch_x, batch_y) in enumerate(dataloader):
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)

    logits = model(batch_x)
    # Divide loss by accumulation steps so gradients are correctly scaled
    loss = criterion(logits, batch_y) / ACCUMULATION_STEPS
    loss.backward()  # gradients are ADDED to .grad across micro-batches

    # Update every ACCUMULATION_STEPS micro-batches, and flush leftover
    # micro-batches at the end of the epoch so they don't leak into the next one
    is_last_batch = (step + 1) == len(dataloader)
    if (step + 1) % ACCUMULATION_STEPS == 0 or is_last_batch:
        # Clip after accumulation, before step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
```

:::tip Why divide the loss?
Each backward() call adds its gradients into .grad, so after N micro-batches the accumulated gradient is the sum of N per-micro-batch mean gradients: N times the mean gradient you actually want. Without dividing by N, plain SGD takes steps N times too large (equivalent to multiplying the learning rate by N), and the norm seen by gradient clipping is inflated by the same factor.
:::
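
The equivalence claim can be checked directly. This sketch uses a hypothetical toy linear classifier on random data: it compares the full-batch gradient of 32 samples with the gradient accumulated over 4 micro-batches of 8, with and without dividing each loss by 4.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)
loss_fn = nn.CrossEntropyLoss()  # mean reduction over the batch
x = torch.randn(32, 8)
y = torch.randint(0, 3, (32,))
accum = 4

# Gradient of one full batch of 32
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# 4 micro-batches of 8, each loss divided by 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum), y.chunk(accum)):
    (loss_fn(model(xb), yb) / accum).backward()
accum_grad = model.weight.grad.clone()

# Same micro-batches WITHOUT dividing the loss
model.zero_grad()
for xb, yb in zip(x.chunk(accum), y.chunk(accum)):
    loss_fn(model(xb), yb).backward()
undivided_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))              # True
print(torch.allclose(undivided_grad, accum * full_grad, atol=1e-5))  # True: 4x too large
```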

Gradient Accumulation with AMP

When combining gradient accumulation with mixed precision, the GradScaler must be used carefully:

```python
import torch

# torch.autocast accepts device_type; the older torch.cuda.amp.autocast does not
scaler = torch.amp.GradScaler("cuda")  # PyTorch >= 2.3; older versions: torch.cuda.amp.GradScaler()
ACCUMULATION_STEPS = 4
optimizer.zero_grad()

for step, (batch_x, batch_y) in enumerate(dataloader):
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)

    # AMP autocast wraps the forward pass and the loss
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        logits = model(batch_x)
        loss = criterion(logits, batch_y) / ACCUMULATION_STEPS

    # Scale the loss (GradScaler handles the scaling factor)
    scaler.scale(loss).backward()

    if (step + 1) % ACCUMULATION_STEPS == 0:
        # Unscale before clipping so max_norm is in the right units
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```

BatchNorm and Gradient Accumulation

BatchNorm computes statistics from the current mini-batch, not the effective batch. This means with ACCUMULATION_STEPS=4, BatchNorm sees batch_size samples per step, not batch_size*4. For tasks where BatchNorm statistics matter (small batch training), consider switching to LayerNorm or GroupNorm, which do not have this limitation.
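
The difference is easy to observe. In this sketch (random tensors, shapes chosen arbitrarily), the output for sample 0 is computed inside a batch of 8 and inside a micro-batch of 4: BatchNorm's result changes with the other samples in the batch, while GroupNorm's does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 16, 4, 4)  # (batch, channels, H, W)

bn = nn.BatchNorm2d(16).train()                   # train mode: uses batch statistics
gn = nn.GroupNorm(num_groups=4, num_channels=16)  # per-sample statistics

with torch.no_grad():
    # Output for sample 0 inside a batch of 8 vs. inside a micro-batch of 4
    bn_same = torch.allclose(bn(x)[0], bn(x[:4])[0], atol=1e-5)
    gn_same = torch.allclose(gn(x)[0], gn(x[:4])[0], atol=1e-5)

print(f"BatchNorm output unchanged: {bn_same}")  # False: depends on the other samples
print(f"GroupNorm output unchanged: {gn_same}")  # True: independent of batch composition
```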

Mixed Precision Training

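Mixed precision runs most computations (such as matrix multiplications and convolutions) in float16 or bfloat16, and keeps numerically sensitive operations (like softmax and loss computation) in float32. On GPUs with Tensor Cores (V100 and newer, such as the A100) it can give a 2–4x speedup, and it roughly halves activation memory; the model weights and optimizer state stay in float32.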

```python
import torch

# torch.autocast accepts device_type; the older torch.cuda.amp.autocast does not
scaler = torch.amp.GradScaler("cuda")  # scales loss to prevent float16 underflow (PyTorch >= 2.3)

def train_one_epoch_amp(model, dataloader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()

        # autocast: operations inside run in float16 where safe
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            logits = model(batch_x)
            loss = loss_fn(logits, batch_y)

        # GradScaler: scale loss to avoid fp16 underflow in gradients
        scaler.scale(loss).backward()

        # Unscale before gradient clipping so max_norm applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # scaler.step: calls optimizer.step() (gradients were already unscaled above)
        # If NaN/Inf is found in the gradients, the step is skipped automatically
        scaler.step(optimizer)

        # Update the scale factor for the next iteration
        scaler.update()

        total_loss += loss.item()

    return total_loss / len(dataloader)
```
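
Why the scaler exists: float16 cannot represent values below about 6e-8, so very small gradients round to zero. This sketch casts a hypothetical gradient value of 1e-8 to float16 directly, after multiplying by 2**16 (65536, GradScaler's default initial scale), and to bfloat16.

```python
import torch

tiny_grad = torch.tensor(1e-8)  # a small float32 gradient value

print(tiny_grad.to(torch.float16).item())            # 0.0: underflows in float16
print((tiny_grad * 65536).to(torch.float16).item())  # ~6.55e-4: survives after scaling by 2**16
print(tiny_grad.to(torch.bfloat16).item())           # ~1e-8: bfloat16 has float32's exponent range
```

GradScaler multiplies the loss (and therefore every gradient) by this factor before backward(), then divides it back out before the optimizer step.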

bfloat16 (preferred on A100/H100)

```python
# bfloat16 has the same exponent range as float32, so gradients do not underflow
# the way they can in float16: no GradScaler is needed.
# Requires hardware with bf16 support (e.g. A100, H100).
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits = model(batch_x)
    loss = loss_fn(logits, batch_y)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```
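
The same mechanism can be tried on a CPU. This sketch (a toy `nn.Linear`, shapes arbitrary) uses `torch.autocast(device_type="cpu", dtype=torch.bfloat16)` to show that the layer's output is bfloat16 while its weights and gradients stay float32, which is what "mixed" precision means.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(16, 4)  # parameters are created in float32
x = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)  # linear runs in bfloat16 under autocast

loss = out.float().pow(2).mean()
loss.backward()

print(out.dtype)                # torch.bfloat16
print(layer.weight.dtype)       # torch.float32: master weights stay in float32
print(layer.weight.grad.dtype)  # torch.float32: gradients match the parameter dtype
```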

Checkpointing: Save and Resume

Always save checkpoints - not just the model weights, but the full training state so you can resume exactly.

```python
def save_checkpoint(path, model, optimizer, scheduler, epoch, val_loss, config):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'val_loss': val_loss,
        'config': config,
    }, path)
    print(f"Saved checkpoint: {path}")


def load_checkpoint(path, model, optimizer=None, scheduler=None, device='cpu'):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint.get('epoch', 0)
    val_loss = checkpoint.get('val_loss', float('inf'))
    print(f"Loaded checkpoint from epoch {epoch} (val_loss={val_loss:.4f})")
    return epoch, val_loss


# Resume training
START_FROM_CHECKPOINT = 'checkpoints/epoch_12.pt'

# MLP, device and n_epochs are assumed to be defined elsewhere
model = MLP(784, [512, 256], 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)

start_epoch, best_val_loss = load_checkpoint(
    START_FROM_CHECKPOINT, model, optimizer, scheduler, device
)

for epoch in range(start_epoch + 1, n_epochs):
    # continue training...
    pass
```
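
A quick way to convince yourself that the full training state survives a round trip: this sketch trains a toy `nn.Linear` with Adam and a `StepLR` scheduler (both stand-ins, chosen for illustration) for three steps, saves to an in-memory buffer instead of a file, restores into fresh objects, and checks the weights, the decayed LR, and Adam's internal step count.

```python
import io
import torch
import torch.nn as nn

def make_training_state():
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, scheduler

torch.manual_seed(0)
model, optimizer, scheduler = make_training_state()
for _ in range(3):  # a few steps so Adam has moment estimates and the LR has decayed
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()

buffer = io.BytesIO()  # stands in for a checkpoint file
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, buffer)
buffer.seek(0)

# "Restart": build fresh objects, then restore everything from the checkpoint
model2, optimizer2, scheduler2 = make_training_state()
checkpoint = torch.load(buffer)
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler2.load_state_dict(checkpoint['scheduler_state_dict'])

same_weights = torch.equal(model.weight, model2.weight)
same_lr = optimizer.param_groups[0]['lr'] == optimizer2.param_groups[0]['lr']
adam_steps = optimizer2.state_dict()['state'][0]['step']
print(same_weights, same_lr, float(adam_steps))  # True True 3.0
```

Restoring only the model weights would silently reset Adam's moment estimates and the LR schedule, which is why the optimizer and scheduler state are saved too.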

Checkpoint Strategy in Production

```python
import os
from pathlib import Path

class CheckpointManager:
    """Manages saving periodic and best checkpoints."""

    def __init__(self, checkpoint_dir, keep_last_n=3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n
        self.checkpoints = []
        self.best_val_loss = float('inf')

    def save(self, model, optimizer, scheduler, epoch, val_loss, config):
        # Save periodic checkpoint
        path = self.checkpoint_dir / f"epoch_{epoch:04d}.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
            'val_loss': val_loss,
            'config': config,
        }, path)
        self.checkpoints.append(path)

        # Remove old checkpoints beyond keep_last_n
        while len(self.checkpoints) > self.keep_last_n:
            old = self.checkpoints.pop(0)
            if old.exists():
                old.unlink()

        # Save best model separately
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(model.state_dict(), best_path)
            print(f"  New best model saved (val_loss={val_loss:.4f})")
```

Early Stopping

```python
class EarlyStopping:
    """Stop training when validation loss stops improving."""

    def __init__(self, patience=10, min_delta=1e-4, checkpoint_path='best.pt'):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float('inf')
        self.counter = 0
        self.stopped = False

    def step(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.checkpoint_path)
            return False  # don't stop

        self.counter += 1
        print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.stopped = True
            return True  # stop training

        return False


# Usage
early_stop = EarlyStopping(patience=10, checkpoint_path='best_model.pt')

for epoch in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_acc = evaluate(model, val_loader, loss_fn, device)

    if early_stop.step(val_loss, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Load best weights before inference
model.load_state_dict(torch.load('best_model.pt'))
```

Loss Function Selection Guide

Choosing the wrong loss function is one of the most common bugs in ML - and it silently trains a model on the wrong objective.

| Task | Loss Function | PyTorch API | Notes |
|---|---|---|---|
| Binary classification | BCE with logits | `nn.BCEWithLogitsLoss` | Combines sigmoid + BCE; numerically stable. Never use `nn.BCELoss` directly on raw logits. |
| Multi-class classification | Cross-entropy | `nn.CrossEntropyLoss` | Combines LogSoftmax + NLLLoss. Input: raw logits. Target: integer class indices. |
| Multi-label classification | BCE with logits | `nn.BCEWithLogitsLoss` | One sigmoid per class, independent. Input: (B, C) logits. Target: (B, C) float binary. |
| Regression | MSE | `nn.MSELoss` | L2 loss. Sensitive to outliers. |
| Regression with outliers | Huber | `nn.HuberLoss` | Quadratic (L2) for small errors, linear (L1) for large ones. More robust. |
| Ranking | Margin ranking | `nn.MarginRankingLoss` | For learning to rank from pairs of items; for triplets, use `nn.TripletMarginLoss`. |
| Knowledge distillation | KL divergence | `nn.KLDivLoss` | Measures distribution mismatch between student and teacher. Input must be log-probabilities. |
| Semantic segmentation | Dice + BCE | Custom | Dice loss handles class imbalance better than pure CE. |
```python
# Binary classification: ONE output neuron, no sigmoid in model
criterion = nn.BCEWithLogitsLoss()
# logits and targets must have the same shape, e.g. both (B,) or both (B, 1);
# targets are floats in {0.0, 1.0}

# Multi-class: N output neurons, no softmax in model
criterion = nn.CrossEntropyLoss()
# logits: (B, num_classes), targets: (B,) int64 in [0, num_classes)

# Weighted cross-entropy for class imbalance
class_weights = torch.tensor([1.0, 5.0, 3.0], device=device)  # weight rare classes higher
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Label smoothing: prevents overconfidence
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Regression
criterion = nn.HuberLoss(delta=1.0)  # delta: threshold between L2 (small) and L1 (large) behavior
```
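
Two of the table's notes can be verified numerically. This sketch (random logits and a hand-picked extreme logit of 200, both illustrative) checks that `CrossEntropyLoss` equals LogSoftmax + NLLLoss, and shows why `BCELoss` on sigmoid outputs is the unstable choice: for a large logit the sigmoid rounds to exactly 1.0, and `BCELoss` clamps its log at -100, so it reports 100 instead of the true loss of about 200.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# 1. CrossEntropyLoss == LogSoftmax + NLLLoss on raw logits
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])
ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# 2. BCEWithLogitsLoss stays exact for extreme logits; sigmoid + BCELoss does not
big_logit = torch.tensor([200.0])
target = torch.tensor([0.0])
stable = nn.BCEWithLogitsLoss()(big_logit, target)         # ~200.0, the true loss
unstable = nn.BCELoss()(torch.sigmoid(big_logit), target)  # 100.0: sigmoid -> 1.0, log clamped at -100

print(ce.item(), manual.item())
print(stable.item(), unstable.item())
```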

Training Loop with Progress Bars and Logging

```python
import torch
from tqdm import tqdm
import wandb

def train_one_epoch_with_logging(
    model, dataloader, optimizer, loss_fn, device, epoch,
    scaler=None, accumulation_steps=1, log_every=50
):
    model.train()
    total_loss = 0.0
    grad_norm = 0.0  # defined before the first optimizer update, so logging never fails
    optimizer.zero_grad()

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for step, (batch_x, batch_y) in enumerate(pbar):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        if scaler:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(batch_x)
                loss = loss_fn(logits, batch_y) / accumulation_steps
            scaler.scale(loss).backward()
        else:
            logits = model(batch_x)
            loss = loss_fn(logits, batch_y) / accumulation_steps
            loss.backward()

        batch_loss = loss.item() * accumulation_steps  # undo the division for logging
        total_loss += batch_loss

        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(dataloader):
            if scaler:
                scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'grad_norm': f'{grad_norm:.3f}',
            'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })

        # Log to Weights & Biases
        if step % log_every == 0:
            wandb.log({
                'train/loss': batch_loss,
                'train/grad_norm': grad_norm,
                'train/lr': optimizer.param_groups[0]['lr'],
                'train/step': epoch * len(dataloader) + step,
            })

    return total_loss / len(dataloader)
```


Full Production Training Loop Template

A complete template that combines all the techniques above in a single training script.

```python
import math
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import wandb

# Assumes X_train, y_train, X_val, y_val (NumPy arrays), n_classes, and the MLP,
# get_cosine_schedule_with_warmup, EarlyStopping and CheckpointManager
# definitions from earlier in this section.

# ─── Configuration ─────────────────────────────────────────────────────────────
config = {
    'lr': 3e-4,
    'batch_size': 128,
    'n_epochs': 100,
    'hidden_dims': [512, 256],
    'dropout': 0.3,
    'weight_decay': 1e-4,
    'patience': 15,
    'amp': True,
    'accumulation_steps': 4,
    'max_grad_norm': 1.0,
    'warmup_epochs': 5,
    'checkpoint_dir': 'checkpoints/',
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─── Data ──────────────────────────────────────────────────────────────────────
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_val_t = torch.tensor(X_val, dtype=torch.float32)
y_val_t = torch.tensor(y_val, dtype=torch.long)

train_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t),
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)
val_loader = DataLoader(
    TensorDataset(X_val_t, y_val_t),
    batch_size=config['batch_size'] * 2,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

# ─── Model, Optimizer, Scheduler, Scaler ───────────────────────────────────────
model = MLP(X_train.shape[1], config['hidden_dims'], n_classes,
            config['dropout']).to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['lr'],
    weight_decay=config['weight_decay'],
)
# The scheduler is stepped once per optimizer UPDATE, so count updates, not batches
updates_per_epoch = math.ceil(len(train_loader) / config['accumulation_steps'])
total_steps = updates_per_epoch * config['n_epochs']
warmup_steps = updates_per_epoch * config['warmup_epochs']
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
use_amp = config['amp'] and device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda') if use_amp else None  # PyTorch >= 2.3
early_stop = EarlyStopping(patience=config['patience'])
ckpt_mgr = CheckpointManager(config['checkpoint_dir'])

os.makedirs(config['checkpoint_dir'], exist_ok=True)
wandb.init(project='my-project', config=config)

# ─── Training Loop ─────────────────────────────────────────────────────────────
best_val_loss = float('inf')
global_step = 0

for epoch in range(config['n_epochs']):
    # ── Train ──
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1:03d} [train]", leave=False)
    for step, (batch_x, batch_y) in enumerate(pbar):
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        if scaler:
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                logits = model(batch_x)
                loss = loss_fn(logits, batch_y) / config['accumulation_steps']
            scaler.scale(loss).backward()
        else:
            logits = model(batch_x)
            loss = loss_fn(logits, batch_y) / config['accumulation_steps']
            loss.backward()

        batch_loss = loss.item() * config['accumulation_steps']  # undo the division
        train_loss += batch_loss

        # Update every accumulation_steps batches, and flush leftovers at epoch end
        is_update_step = ((step + 1) % config['accumulation_steps'] == 0
                          or (step + 1) == len(train_loader))
        if is_update_step:
            if scaler:
                scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), config['max_grad_norm']
            ).item()
            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()  # step-level scheduler: once per optimizer update
            optimizer.zero_grad()
            global_step += 1

            wandb.log({
                'train/loss': batch_loss,
                'train/grad_norm': grad_norm,
                'train/lr': optimizer.param_groups[0]['lr'],
                'step': global_step,
            })
            pbar.set_postfix(loss=f"{batch_loss:.4f}", gnorm=f"{grad_norm:.3f}")

    avg_train_loss = train_loss / len(train_loader)

    # ── Validate ──
    model.eval()
    val_loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            logits = model(batch_x)
            val_loss_sum += loss_fn(logits, batch_y).item()
            correct += (logits.argmax(1) == batch_y).sum().item()
            total += len(batch_y)

    avg_val_loss = val_loss_sum / len(val_loader)
    val_acc = correct / total

    wandb.log({'val/loss': avg_val_loss, 'val/acc': val_acc, 'epoch': epoch})
    print(
        f"Epoch {epoch+1:03d} | "
        f"train={avg_train_loss:.4f} | "
        f"val={avg_val_loss:.4f} | "
        f"acc={val_acc:.4f} | "
        f"lr={optimizer.param_groups[0]['lr']:.2e}"
    )

    # ── Checkpoint ──
    ckpt_mgr.save(model, optimizer, scheduler, epoch, avg_val_loss, config)

    # ── Early stopping ──
    if early_stop.step(avg_val_loss, model):
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Load best weights
model.load_state_dict(torch.load(early_stop.checkpoint_path, map_location=device))
wandb.finish()
```

Debugging Training Problems

Training failures almost always fall into one of four categories. Knowing the symptoms lets you diagnose and fix them quickly.

(a) Loss Not Decreasing

| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss flat from epoch 1 | LR too low | Increase LR by 10x and observe |
| Loss decreases slowly | LR slightly too low | Use an LR finder (fast.ai) or try 3x the LR |
| Loss flat after initial drop | Stuck on a plateau or in a poor local minimum | Add LR warmup, change optimizer |
| Loss stuck at a specific value | Wrong loss function | Check the loss function against the task |
| Loss stays at log(num_classes) although labels look correct | Model still predicts a near-uniform distribution | Check the model output shape |
```python
import math

# Quick diagnostic: what should the random-init loss be?
# For CrossEntropyLoss with C classes, loss ≈ log(C) at initialization
# (assuming the freshly initialized model predicts a near-uniform distribution)
print(f"Expected initial loss for {n_classes} classes: {math.log(n_classes):.4f}")
# If your loss starts much higher or lower, something is wrong
```
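
Where the log(C) rule comes from: a uniform prediction assigns probability 1/C to the true class, and -log(1/C) = log(C). This sketch (random targets, 10 classes, illustrative logit scales) shows that zero or tiny logits give almost exactly log(C), while large random logits, a stand-in for a badly scaled initialization, start far above it.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
n_classes = 10
targets = torch.randint(0, n_classes, (256,))
loss_fn = nn.CrossEntropyLoss()

uniform_loss = loss_fn(torch.zeros(256, n_classes), targets).item()            # exactly log(C)
small_logits_loss = loss_fn(0.01 * torch.randn(256, n_classes), targets).item()  # close to log(C)
large_logits_loss = loss_fn(10.0 * torch.randn(256, n_classes), targets).item()  # far above log(C)

print(f"log(C)              = {math.log(n_classes):.4f}")
print(f"zero logits         = {uniform_loss:.4f}")
print(f"small logits        = {small_logits_loss:.4f}")
print(f"large random logits = {large_logits_loss:.4f}")
```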

(b) NaN Loss

NaN loss is catastrophic: once NaNs reach the weights, every later forward pass produces NaN, and the run can only be recovered by restoring a checkpoint.

```python
# Checklist when you see NaN loss:
# 1. Gradient explosion: add clip_grad_norm_ with max_norm=1.0
# 2. LR too high: reduce by 10x
# 3. log(0) in custom loss: add epsilon
#    loss = -torch.log(probs + 1e-8)  # NOT -torch.log(probs)
# 4. Division by zero: protect denominators
# 5. float16 overflow/underflow: switch to bfloat16 or use GradScaler

import torch

# Detect NaN/Inf early: call after loss.backward(), before optimizer.step()
def check_nan(loss, model, step):
    if not torch.isfinite(loss):
        print(f"Non-finite loss at step {step}: {loss.item()}")
        for name, param in model.named_parameters():
            if param.grad is not None and not torch.isfinite(param.grad).all():
                print(f"  Non-finite gradient in: {name}")
        raise RuntimeError("NaN/Inf detected - check learning rate and loss function")
```
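
This sketch makes two checklist items concrete with toy tensors (all values illustrative): the epsilon keeps log(0) finite, and a single NaN gradient is enough to corrupt the weight it touches, unless the step is skipped, which is the same idea GradScaler applies when it finds Inf/NaN gradients.

```python
import torch

# log(0) produces inf; a small epsilon keeps the loss finite
probs = torch.tensor([0.0, 0.5, 1.0])
unsafe = -torch.log(probs)       # first element is inf
safe = -torch.log(probs + 1e-8)  # all finite

# One non-finite gradient corrupts the weight it touches
w = torch.nn.Parameter(torch.ones(3))
optimizer = torch.optim.SGD([w], lr=0.1)
loss = (w * torch.tensor([float("nan"), 1.0, 1.0])).sum()
loss.backward()  # w.grad == [nan, 1, 1]

if torch.isfinite(w.grad).all():  # guard: only step on finite gradients
    optimizer.step()
guarded_w = w.detach().clone()    # step was skipped, weights intact

optimizer.step()                  # no guard: the NaN reaches the weights
print(unsafe, safe)
print(guarded_w, w.detach())      # tensor([1., 1., 1.]) tensor([nan, 0.9, 0.9])
```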

(c) GPU Memory OOM (Out Of Memory)

```python
# Symptoms: RuntimeError: CUDA out of memory
# Fixes in order of ease:
# 1. Reduce batch size
# 2. Use gradient accumulation to keep the effective batch size
# 3. Enable gradient checkpointing (trades compute for memory)
# 4. Use mixed precision (fp16/bf16) - roughly halves activation memory
# 5. Offload parts of the model or optimizer state to CPU

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Gradient checkpointing: recompute activations during backward instead of
# storing them. With about sqrt(n) segments, activation memory for n layers
# drops from O(n) to roughly O(sqrt(n)).
model.gradient_checkpointing_enable()  # Hugging Face Transformers models
# or manually, for a model whose `layers` attribute is an nn.Sequential
# (second argument = number of segments):
output = checkpoint_sequential(model.layers, 4, x, use_reentrant=False)

# Monitor GPU memory
print(f"Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
print(f"Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")
torch.cuda.empty_cache()  # return cached, unused blocks to the driver (live tensors are not freed)
```
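
Gradient checkpointing changes memory use, not results. This CPU sketch (a hypothetical 8-block MLP, sizes arbitrary) runs the same model with and without `checkpoint_sequential` and confirms the gradients match; only the segment boundaries are stored, and the rest is recomputed during backward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
layers = nn.Sequential(*[nn.Sequential(nn.Linear(32, 32), nn.ReLU()) for _ in range(8)])
x = torch.randn(16, 32)

# Standard forward/backward: every intermediate activation is stored
layers.zero_grad()
layers(x).sum().backward()
reference = [p.grad.clone() for p in layers.parameters()]

# Checkpointed: 4 segments; activations inside each segment are recomputed in backward
layers.zero_grad()
out = checkpoint_sequential(layers, 4, x, use_reentrant=False)
out.sum().backward()
checkpointed = [p.grad.clone() for p in layers.parameters()]

print(all(torch.allclose(a, b, atol=1e-6) for a, b in zip(reference, checkpointed)))  # True
```

The memory savings only become visible on large models; the point here is that the trade is purely compute for memory.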

(d) Training Good, Validation Worse (Overfitting)

```python
# Symptoms: train loss → 0, val loss diverges upward after its minimum
# Fixes:
# 1. Add dropout (0.1–0.5 depending on model size)
# 2. Add weight decay (1e-4 to 1e-2)
# 3. Data augmentation (image: flips, crops; text: back-translation)
# 4. Early stopping (stop at val loss minimum)
# 5. Reduce model capacity
# 6. More training data

# IMPORTANT: Check for data leakage first
# If val accuracy is suspiciously high early → data may be leaking
# Common causes: temporal splits done wrong, target encoding done before split
```

:::danger Data Leakage
If your validation accuracy starts very high (e.g., 95% at epoch 1 on a hard problem), suspect data leakage first. Common causes: preprocessing statistics computed over the full dataset before the train/val split, validation images also present in the training folder, or features derived from the target leaking into the inputs. Always fit all preprocessing on the training set only.
:::

Common Mistakes

Mistake 1: zero_grad() after backward

```python
# WRONG
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.zero_grad()  # too late - wipes the gradients that were just computed
optimizer.step()       # so this step has no gradients to apply

# CORRECT
optimizer.zero_grad()  # clear before computing
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()
```
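
This sketch runs one iteration of a toy `nn.Linear` model (random data, plain SGD, all illustrative) with each ordering and reports whether the weights moved. With the wrong order, `zero_grad()` erases the gradients before `step()` can use them.

```python
import torch

def one_iteration(zero_grad_first):
    torch.manual_seed(0)
    model = torch.nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    before = model.weight.detach().clone()
    x, y = torch.randn(4, 3), torch.randn(4, 1)

    if zero_grad_first:
        optimizer.zero_grad()  # CORRECT placement
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    if not zero_grad_first:
        optimizer.zero_grad()  # WRONG placement: between backward() and step()
    optimizer.step()
    return not torch.equal(before, model.weight.detach())  # did the weights change?

print("correct order updates weights:", one_iteration(zero_grad_first=True))   # True
print("wrong order updates weights:  ", one_iteration(zero_grad_first=False))  # False
```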

Mistake 2: Validation in train mode

```python
# WRONG
# (model.train() is still active from the training loop)
for val_x, val_y in val_loader:
    preds = model(val_x)  # Dropout is ON - results are noisy

# CORRECT
model.eval()
with torch.no_grad():
    for val_x, val_y in val_loader:
        preds = model(val_x)  # Dropout is OFF
model.train()  # switch back before the next training epoch
```
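
The noise is easy to demonstrate. In this sketch (a hypothetical small model with `Dropout(p=0.5)`, random input), two forward passes on the same input disagree in train mode and agree exactly in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5), nn.Linear(10, 2))
x = torch.randn(4, 10)

model.train()
with torch.no_grad():
    train_a, train_b = model(x), model(x)  # a different dropout mask on each call

model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)    # Dropout disabled: deterministic

print("train mode, repeated passes equal:", torch.equal(train_a, train_b))  # False
print("eval mode, repeated passes equal: ", torch.equal(eval_a, eval_b))    # True
```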

Mistake 3: Accumulating tensors instead of floats

# WRONG: builds up a computation graph over all batches - memory leak
total_loss_tensor = torch.tensor(0.0, device=device)
for batch_x, batch_y in dataloader:
    loss = criterion(model(batch_x), batch_y)
    total_loss_tensor += loss  # keeps every batch's graph alive!

# CORRECT: convert each batch loss to a Python float
total_loss = 0.0
for batch_x, batch_y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(batch_x), batch_y)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()  # Python float, no graph retained
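You can see the retained history directly. A tensor accumulator carries a `grad_fn` that links back through every batch's forward pass, while the float accumulator carries nothing. The small CPU sketch below omits `backward()` because only the bookkeeping matters here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

# WRONG: summing loss tensors chains every batch's graph onto one tensor
total_loss_tensor = torch.tensor(0.0)
for bx, by in batches:
    total_loss_tensor += criterion(model(bx), by)

# CORRECT: .item() returns a Python float with no autograd history
total_loss = 0.0
for bx, by in batches:
    total_loss += criterion(model(bx), by).item()
```

On GPU, `.item()` forces a host-device synchronization. If that sync shows up in profiles, accumulate `loss.detach()` into a tensor instead and call `.item()` once per epoch.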

Mistake 4: Wrong scheduler.step() timing

# WRONG: calling scheduler.step() before optimizer.step()
scheduler.step() # advances the schedule before any update: the first LR value is skipped
optimizer.step() # (PyTorch also emits a UserWarning about this order)

# ALSO WRONG: calling step() on ReduceLROnPlateau without a metric
scheduler.step() # TypeError: ReduceLROnPlateau.step() requires the monitored metric

# CORRECT: step after the optimizer, and pass the metric when required
optimizer.step()
scheduler.step() # e.g. CosineAnnealingLR
# or, for ReduceLROnPlateau (after validation):
scheduler.step(val_loss)
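The per-epoch pattern with ReduceLROnPlateau can be sketched as follows. The validation losses are hard-coded stand-ins, not real results, so the LR change is deterministic. The single training step per "epoch" is also just a placeholder. With `patience=1` and `factor=0.5`, the LR halves once the loss has failed to improve for two consecutive epochs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=1
)
x, y = torch.randn(16, 4), torch.randn(16, 1)

# Made-up validation losses: improve once, then plateau
fake_val_losses = [1.0, 0.8, 0.8, 0.8, 0.8]
lrs = []
for epoch, val_loss in enumerate(fake_val_losses):
    optimizer.zero_grad()                           # placeholder training step
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()                                # optimizer first...
    scheduler.step(val_loss)                        # ...then the scheduler, with the metric
    lrs.append(optimizer.param_groups[0]["lr"])
```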

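:::warning OneCycleLR Timing
OneCycleLR is stepped every batch, not every epoch, so it breaks the common "step once per epoch" pattern. Its cycle length is total_steps (or epochs × steps_per_epoch). Calling it once per epoch advances only one of those steps per epoch, so the LR typically never leaves the warmup phase and never anneals.
:::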
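The per-batch pattern can be sketched as follows, assuming a fixed number of batches per epoch (a list stands in for the DataLoader). `OneCycleLR` needs `total_steps`, or `epochs` together with `steps_per_epoch`, so it knows how long the cycle is. Its default `cycle_momentum=True` also requires an optimizer with momentum or Adam-style betas, hence SGD with `momentum=0.9`. A second scheduler shows the per-epoch misuse for comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
epochs, steps_per_epoch = 3, 10
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, epochs=epochs, steps_per_epoch=steps_per_epoch
)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(steps_per_epoch)]

lrs = [scheduler.get_last_lr()[0]]
for epoch in range(epochs):
    for bx, by in batches:
        optimizer.zero_grad()
        nn.functional.mse_loss(model(bx), by).backward()
        optimizer.step()
        scheduler.step()                  # CORRECT: once per BATCH
        lrs.append(scheduler.get_last_lr()[0])

# Misuse for comparison: stepping once per epoch covers only 3 of the 30 steps
opt2 = torch.optim.SGD(nn.Linear(4, 1).parameters(), lr=0.1, momentum=0.9)
sched2 = torch.optim.lr_scheduler.OneCycleLR(
    opt2, max_lr=0.1, epochs=epochs, steps_per_epoch=steps_per_epoch
)
for epoch in range(epochs):
    opt2.step()                           # (training elided)
    sched2.step()                         # WRONG: once per epoch
per_epoch_final_lr = sched2.get_last_lr()[0]   # still partway through warmup
```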

YouTube Resources

| Video | Creator | What It Covers |
|---|---|---|
| Neural Networks: Zero to Hero - makemore | Andrej Karpathy | Clean PyTorch training loops from first principles |
| Learning Rate Scheduling with W&B | Weights & Biases | LR schedules compared with live experiments |
| PyTorch Performance Tuning Guide | PyTorch | Mixed precision, DataLoader, and training best practices |
| Debugging PyTorch Models | PyTorch | Diagnosing NaN loss, OOM, and training instability |

Interview Q&A

Q1: Walk me through the order of operations in a correct PyTorch training loop.

The correct order is: (1) optimizer.zero_grad() - clear the gradients accumulated by the previous step, (2) forward pass: logits = model(x), (3) loss = criterion(logits, y), (4) loss.backward() - compute gradients via autograd, (5) optionally clip_grad_norm_() - prevent gradient explosion, (6) optimizer.step() - update parameters using the computed gradients. zero_grad() is needed because backward() adds into .grad instead of overwriting it. Calling zero_grad() at the end of the iteration, after step(), also works, because the gradients are cleared only after they have been used. What breaks training is calling it between backward() and step(): the fresh gradients are wiped before the update. The conventional and safest place is at the start of the iteration, before the forward pass.
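The six steps map one-to-one onto code. Below is a minimal sketch of one training epoch. The function name `train_one_epoch` and `max_norm=1.0` are illustrative choices, and `loader` is assumed to yield `(x, y)` pairs. The toy data is a noiseless linear regression, so the loss should fall.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, criterion, optimizer, max_norm=1.0):
    model.train()
    total, count = 0.0, 0
    for x, y in loader:
        optimizer.zero_grad()                    # (1) clear old gradients
        logits = model(x)                        # (2) forward pass
        loss = criterion(logits, y)              # (3) loss
        loss.backward()                          # (4) gradients via autograd
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # (5) clip
        optimizer.step()                         # (6) update parameters
        total += loss.item() * x.size(0)
        count += x.size(0)
    return total / count                         # mean loss over the epoch


torch.manual_seed(0)
true_w = torch.randn(4, 1)
data = [(xb, xb @ true_w) for xb in (torch.randn(16, 4) for _ in range(8))]
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
history = [train_one_epoch(model, data, nn.MSELoss(), optimizer) for _ in range(20)]
```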

Q2: What is gradient accumulation and how do you implement it correctly?

Gradient accumulation simulates a larger effective batch size when GPU memory cannot fit a single large batch. You run N forward/backward passes without calling optimizer.step(), letting gradients accumulate across N micro-batches, then step (and zero the gradients) once. The critical correctness requirement: with a mean-reduced loss, divide each micro-batch loss by N before calling backward(). Without this, the accumulated gradient is N times larger than the full-batch gradient. For plain SGD that is equivalent to multiplying the learning rate by N, which destabilizes training, and it also changes how often clipping activates. Gradient clipping must happen after the last micro-batch's backward() and before optimizer.step(). With AMP, call scaler.unscale_(optimizer) before clipping so the norms are computed on the true, unscaled gradients rather than on gradients multiplied by the loss scale. BatchNorm is the exception to the equivalence: it normalizes with the statistics of each micro-batch, so accumulation does not give it the benefit of the larger effective batch.
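The "divide by N" rule can be checked numerically. This sketch assumes a mean-reduced loss and equal-sized micro-batches, and the helper `grad_after` is hypothetical. It compares the gradient of one full batch of 8 with the gradient accumulated over four micro-batches of 2, with and without the division.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()          # reduction="mean"
x, y = torch.randn(8, 3), torch.randn(8, 1)
accum_steps = 4                   # 4 micro-batches of 2 = effective batch of 8


def grad_after(loss_scale: float, chunks: int) -> torch.Tensor:
    model.zero_grad()
    for xb, yb in zip(x.chunk(chunks), y.chunk(chunks)):
        (criterion(model(xb), yb) * loss_scale).backward()  # gradients accumulate in .grad
    return model.weight.grad.clone()


full_batch = grad_after(1.0, 1)                           # one batch of 8
accumulated = grad_after(1.0 / accum_steps, accum_steps)  # loss / N per micro-batch
unscaled = grad_after(1.0, accum_steps)                   # forgot to divide by N
```

In a training loop, the same structure becomes: call `backward()` on every micro-batch. Then, only when `(i + 1) % accum_steps == 0`, clip, call `optimizer.step()`, and call `optimizer.zero_grad()`.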

Q3: When should you use gradient clipping, and how do you choose max_norm?

Gradient clipping is essential for RNNs and transformers, where gradients flowing through many time steps or deep stacks of attention layers can compound multiplicatively. It is also a good safety net for any model against exploding gradients caused by corrupted data or a misconfigured learning rate. max_norm=1.0 is a common default; for example, the Hugging Face Trainer's max_grad_norm defaults to 1.0. To calibrate it, log the raw (pre-clip) gradient norm for the first 1000 steps. If norms stay well below max_norm, clipping never activates and acts purely as a safety net. If they frequently exceed it, training may be unstable, and you should consider lowering the learning rate rather than only tightening the clip. The key insight: clipping is a band-aid. If norms are consistently large, fix the root cause (LR, architecture, data).
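Logging the norm costs almost nothing, because `clip_grad_norm_` already returns the total gradient norm it measured before clipping. The sketch below runs on a toy model, with the loss deliberately multiplied by 1000 so that clipping actually engages. In a real loop you would append `pre_clip_norm.item()` to your logger every step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

loss = nn.functional.mse_loss(model(x), y) * 1000   # exaggerated loss -> large gradients
loss.backward()

# clip_grad_norm_ returns the total norm measured BEFORE clipping: log this value
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the total norm afterwards to confirm the clip took effect
post_clip_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(f"grad norm before clip: {pre_clip_norm.item():.1f}, after: {post_clip_norm.item():.4f}")
```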

Q4: Compare CosineAnnealingLR and OneCycleLR. When would you choose each?

CosineAnnealingLR decays the learning rate smoothly from eta_max (the optimizer's initial lr) to eta_min along a cosine curve over T_max scheduler steps. It is usually stepped once per epoch, with T_max set to the number of epochs. It is a standard choice for image classification, language modeling, and any run where you want a predictable, smooth LR decay. OneCycleLR runs a single cycle: a warmup phase up to a peak LR (max_lr), followed by a long annealing phase down to well below the starting LR. Both phases use cosine annealing by default (anneal_strategy="linear" is also available), and momentum is cycled inversely to the LR. It is stepped once per batch. It was designed for "super-convergence": reaching high accuracy in significantly fewer epochs than traditional schedules. Choose OneCycleLR when you have a limited compute budget and need fast convergence (e.g., training from scratch under a fixed time limit). Choose CosineAnnealingLR for standard training runs where stability matters more than speed.
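This sketch illustrates the "smooth, predictable" decay. It steps CosineAnnealingLR once per epoch and compares the result with the closed-form cosine formula. Here `eta_max` is the optimizer's initial `lr`, and the values 0.1, 0.001, and 10 epochs are arbitrary.

```python
import math

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
T_max, eta_min = 10, 0.001
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

lrs = [scheduler.get_last_lr()[0]]
for epoch in range(T_max):
    # ... one epoch of training would go here (optimizer.step() once per batch) ...
    optimizer.step()
    scheduler.step()                       # once per EPOCH
    lrs.append(scheduler.get_last_lr()[0])

# Closed form: eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T_max)) / 2
closed_form = [eta_min + (0.1 - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2
               for t in range(T_max + 1)]
```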

Q5: How do you debug NaN loss during training?

NaN loss has four common causes. (1) Gradient explosion: fix with clip_grad_norm_(model.parameters(), 1.0). (2) Learning rate too high: reduce it by 10x and retry. (3) Log of zero in a custom loss function: add an epsilon, as in torch.log(probs + 1e-8). (4) Float16 range problems under AMP: activations above float16's maximum (65504) overflow to inf and then produce NaN, while tiny gradients underflow to zero; use GradScaler, which skips steps with inf/NaN gradients, or switch to bfloat16. To diagnose, after detecting a NaN loss, iterate over model.named_parameters() and check which gradients contain NaN or inf; this narrows down where the problem enters. Also check the data: corrupted input (inf or NaN in the input tensor) propagates forward and produces a NaN loss even with a correct implementation. Add assert torch.isfinite(batch_x).all() in the first few training steps to rule out bad data; unlike torch.isnan, torch.isfinite also catches inf.
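Both checks from the answer fit in two small helpers; the names `batch_is_finite` and `nonfinite_grads` are hypothetical. The sketch feeds a batch with one corrupted value through a toy model and reports which parameters received non-finite gradients. Here every parameter is flagged because the NaN enters with the input, which is why checking the batch first is worthwhile.

```python
import torch
import torch.nn as nn


def batch_is_finite(batch_x: torch.Tensor) -> bool:
    # isfinite catches both NaN and +/-inf in the input
    return bool(torch.isfinite(batch_x).all())


def nonfinite_grads(model: nn.Module) -> list[str]:
    # Names of parameters whose gradient contains NaN or inf
    return [
        name
        for name, p in model.named_parameters()
        if p.grad is not None and not torch.isfinite(p.grad).all()
    ]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 1))
batch_x = torch.tensor([[1.0, float("nan")], [0.5, 2.0]])  # one corrupted value
batch_y = torch.zeros(2, 1)

loss = nn.functional.mse_loss(model(batch_x), batch_y)
loss.backward()
bad_params = nonfinite_grads(model)
```

For a finer trace, `torch.autograd.set_detect_anomaly(True)` makes `backward()` raise an error at the first operation that produces NaN. It slows training considerably, so enable it only while debugging.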

Q6: What should a production checkpoint save, and why does saving only model weights fail for resuming training?

A production checkpoint must save: the model state_dict, the optimizer state_dict, the scheduler state_dict, the current epoch number, the best validation metric, the random number generator states (optional but recommended for exact reproducibility), and the hyperparameter config. Saving only the model weights is enough for inference but breaks training resumption. Adam-style optimizers store a first moment (momentum) and a second moment (variance estimate) for every parameter, plus a step count used for bias correction. These moments take many steps to warm up and encode the training history. Resuming with only model weights resets them to zero, so the first few hundred updates after resuming behave as if training had just started instead of continuing where it left off. Missing scheduler state causes a second problem: any schedule that depends on the step count (warmup, cosine decay) restarts from step 0, for example re-running warmup in the middle of training.
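Below is a sketch of a save/resume pair covering the fields listed above. The helper names `save_checkpoint`, `load_checkpoint`, and `make`, as well as the config dict, are hypothetical. Add Python `random` and NumPy RNG states if your pipeline uses them. An in-memory buffer stands in for a checkpoint file.

```python
import io

import torch
import torch.nn as nn


def save_checkpoint(f, model, optimizer, scheduler, epoch, best_val_loss, config):
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),   # Adam moments + step counts
            "scheduler": scheduler.state_dict(),   # position in the LR schedule
            "epoch": epoch,
            "best_val_loss": best_val_loss,
            "torch_rng_state": torch.get_rng_state(),
            "config": config,
        },
        f,
    )


def load_checkpoint(f, model, optimizer, scheduler):
    ckpt = torch.load(f)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    torch.set_rng_state(ckpt["torch_rng_state"])
    return ckpt["epoch"] + 1, ckpt["best_val_loss"]   # resume at the next epoch


def make():
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    return model, optimizer, scheduler


torch.manual_seed(0)
model, optimizer, scheduler = make()
nn.functional.mse_loss(model(torch.randn(8, 4)), torch.randn(8, 1)).backward()
optimizer.step()
scheduler.step()

buffer = io.BytesIO()   # a file path works the same way
save_checkpoint(buffer, model, optimizer, scheduler, epoch=3, best_val_loss=0.42,
                config={"lr": 1e-3, "batch_size": 8})
buffer.seek(0)

model2, optimizer2, scheduler2 = make()   # fresh objects, as after a restart
start_epoch, best = load_checkpoint(buffer, model2, optimizer2, scheduler2)
```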

Q7: How does mixed precision training work, and when is GradScaler not needed?

Mixed precision training runs most of the forward and backward computation in float16 (half precision) to exploit GPU tensor core throughput, while the weights themselves stay in float32 for numerical stability during the optimizer step. With torch.autocast, only eligible ops (such as matrix multiplies and convolutions) are cast; precision-sensitive ones stay in float32. The problem: float16 has a much narrower representable range than float32, so very small gradient values underflow to zero. GradScaler addresses this by multiplying the loss by a large scale factor before backward(), making the gradients temporarily larger so they survive in float16. Before optimizer.step(), it divides the gradients by the same factor (unscaling). If inf or NaN gradients are detected during unscaling, the optimizer step is skipped for that iteration and the scale factor is reduced. GradScaler is not needed with bfloat16, because bfloat16 has the same exponent range as float32; only the mantissa precision is reduced. A100 and H100 GPUs support bfloat16 natively, and it has become the common choice for large language model training.

:::tip 🎮 Interactive Playground

Visualize this concept: Try the Training Dynamics demo on the EngineersOfAI Playground - no code required.

:::

© 2026 EngineersOfAI. All rights reserved.