Training Dynamics and Debugging
The Production Scenario
It is 11 PM on the night before a quarterly review. Your model has been training for six hours on four A100s. The validation loss is 2.341. It was 2.340 two hours ago. It was 2.342 an hour ago. The training loss is still decreasing - slowly - but validation has not moved.
You open the experiment tracker. The loss curve has been flat as a frozen lake for the last two hours. The gradient norm plot shows healthy numbers. No NaNs. No obvious spike. The learning rate is 3e-5, on a cosine schedule that has been running for 40% of the total budget.
Is this normal convergence? Is the model stuck near a saddle point, where the surface is nearly flat in most directions? Did the cosine schedule decay the LR below the effective threshold two hours ago, turning the optimizer into a statue? Is there a subtle bug in your validation data pipeline that is computing the same metric for every batch? Is the model silently overfitting to the training set while validation quietly degrades - and you are measuring validation loss on a stale checkpoint?
You have a meeting in eight hours and cannot afford another six-hour experiment.
Every one of these hypotheses has a specific diagnostic test. Every pathology has a signature in the loss curve, the gradient norms, the activation statistics, and the learning rate schedule. Senior ML engineers do not guess - they diagnose. They open the right plot, run the right check, and know within twenty minutes whether to restart or wait.
This lesson is that diagnostic toolkit. Every symptom. Every test. Every fix.
Why Training Dynamics Matter
Neural network training is not a simple convex optimization problem. The loss function is a high-dimensional, non-convex surface with:
- Saddle points and plateaus where the gradient is near zero even though the loss can still decrease in some directions
- Sharp narrow valleys where large LR causes oscillation but small LR causes slow progress
- Flat wide basins where many equally good solutions live
- Cliff-like loss spikes in transformers and RNNs when gradients suddenly explode
Understanding the geometry of this surface - and how your optimizer, learning rate, and model interact with it - is what separates engineers who debug in hours from engineers who debug in days.
The tools in this lesson are organized by the diagnostic question they answer:
- Loss landscape - what shape is the surface I am optimizing?
- Gradient flow - are gradients reaching all layers?
- Learning rate - is my LR in the right regime?
- NaN detection - where did the first NaN appear?
- Activation health - are neurons alive and well-distributed?
- Data pipeline - is what the model is seeing correct?
The Loss Landscape: Geometry and Generalization
Why Loss Landscape Shape Matters
The loss landscape is the training loss $L(\theta)$ viewed as a function over the high-dimensional parameter space $\theta$. Finding a minimum is the goal of training, but not all minima are equal:
Sharp minima: narrow, deep basins. A small perturbation to the weights causes a large increase in loss. The model has fit very specific features of the training set. Real data at deployment differs slightly from training data, which effectively shifts the loss surface; in a sharp basin even a small shift leaves the current weights on the steep walls, so validation loss is high.
Flat minima: wide, shallow basins. A large perturbation to the weights causes only a small increase in loss. The model has learned robust features. A slight distribution shift at inference still leaves the weights near the bottom of the basin, so validation loss stays low.
The Sharpness-Aware Minimization (SAM) paper builds on the empirical observation that flat minima tend to generalize better, and optimizes for flatness directly. In practice, this means:
- Large batch training tends toward sharp minima (the gradient estimate is less noisy, so the optimizer settles into the nearest minimum rather than exploring)
- Small batch training adds gradient noise that helps escape sharp minima and find flatter regions
- High learning rates and warmup schedules allow the optimizer to traverse the landscape more broadly before settling
The loss flatness metric: perturb the weights in a random direction of fixed relative size and measure how much the loss rises. The code below scales the perturbation of each parameter tensor to epsilon times that tensor's norm. If the loss barely changes (flat minimum), generalization tends to be better. If it spikes (sharp minimum), consider reducing batch size or adding noise.
```python
import copy

import numpy as np
import torch
import torch.nn as nn


def measure_sharpness(
    model: nn.Module,
    loss_fn,
    dataloader,
    epsilon: float = 0.05,
    n_samples: int = 20,
    n_batches: int = 4,
    device: str = "cuda",
) -> dict:
    """
    Estimate loss landscape sharpness by perturbing weights randomly.
    Lower sharpness_ratio -> flatter minimum -> better expected generalization.
    Returns: dict with base_loss, mean_perturbed_loss, sharpness_ratio.
    """
    model = model.to(device)
    model.eval()

    # Compute the base loss on a fixed sample of n_batches batches
    base_losses = []
    sample_batches = []
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader):
            if i >= n_batches:
                break
            x, y = x.to(device), y.to(device)
            base_losses.append(loss_fn(model(x), y).item())
            sample_batches.append((x, y))
    base_loss = float(np.mean(base_losses))

    # Save the original parameters so every sample starts from the same point
    original_state = copy.deepcopy(model.state_dict())

    perturbed_losses = []
    for _ in range(n_samples):
        with torch.no_grad():
            # Random unit direction per tensor, scaled to epsilon * ||param||
            for param in model.parameters():
                noise = torch.randn_like(param)
                noise = noise / (noise.norm() + 1e-8)
                param.add_(noise * epsilon * param.norm())

            # Measure the loss at the perturbed point
            sample_losses = [loss_fn(model(x), y).item() for x, y in sample_batches]
        perturbed_losses.append(float(np.mean(sample_losses)))

        # Restore BEFORE the next sample; otherwise the perturbations accumulate
        model.load_state_dict(original_state)

    mean_perturbed = float(np.mean(perturbed_losses))
    sharpness_ratio = (mean_perturbed - base_loss) / (abs(base_loss) + 1e-8)
    return {
        "base_loss": base_loss,
        "mean_perturbed_loss": mean_perturbed,
        "sharpness_ratio": sharpness_ratio,
        # 0.1 is a rough heuristic cutoff, not a universal constant
        "interpretation": "sharp" if sharpness_ratio > 0.1 else "flat",
    }
```
Loss vs Accuracy Divergence
In classification, training loss should fall as training accuracy rises. When the two diverge, something is wrong:
Loss decreases, accuracy stalls: the model is becoming more confident on examples it already classifies correctly without changing its wrong decisions. This is common with class imbalance: the model keeps sharpening its easy (often majority-class) predictions rather than correcting the hard ones.
Accuracy increases, loss stalls: the model is making correct decisions but not with high enough confidence. Usually fine - it will continue to improve slowly.
Validation loss increases while validation accuracy stays flat: the model is becoming less calibrated (more overconfident on the predictions it gets wrong) even though its argmax decisions have not changed. This is a warning sign of future degradation.
Training loss spikes then recovers: usually a bad batch with atypical examples. If this happens repeatedly, check for data corruption or extreme outliers. Add input validation or a robust loss (Huber loss instead of MSE for regression).
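The calibration warning above can be reproduced with a small hand-built example. The logits below are made-up numbers, not output from a real model: two "epochs" of predictions with identical argmax decisions, where the second is far more confident on its single mistake. Accuracy does not move, but cross-entropy rises sharply. The helper name loss_and_accuracy is illustrative.

```python
import torch
import torch.nn.functional as F


def loss_and_accuracy(logits: torch.Tensor, targets: torch.Tensor):
    """Return (mean cross-entropy, accuracy) for a batch of class logits."""
    loss = F.cross_entropy(logits, targets).item()
    accuracy = (logits.argmax(dim=1) == targets).float().mean().item()
    return loss, accuracy


targets = torch.tensor([0, 0, 1, 1])

# Epoch A: three correct predictions; the last example is wrong, mildly.
logits_a = torch.tensor([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [1.0, 0.0]])
# Epoch B: identical argmax decisions; the wrong prediction is now very confident.
logits_b = torch.tensor([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [8.0, 0.0]])

loss_a, acc_a = loss_and_accuracy(logits_a, targets)
loss_b, acc_b = loss_and_accuracy(logits_b, targets)
print(f"A: loss={loss_a:.3f} acc={acc_a:.2f}")  # A: loss=0.424 acc=0.75
print(f"B: loss={loss_b:.3f} acc={acc_b:.2f}")  # B: loss=2.095 acc=0.75
```

Almost all of the loss increase comes from the one confidently wrong example, which is why tracking loss alongside accuracy catches calibration drift that accuracy alone hides.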
Training Failure Diagnosis Tree
Loss Curve Pathologies
The shape of your loss curve is diagnostic. Every pathology has a specific cause.
```text
1. HEALTHY CONVERGENCE
Loss │*
     │ **
     │   ***
     │      ****
     │          *******
     └──────────────────── Steps

2. EXPLODING LOSS (LR too high or bad init)
Loss │             /
     │            /
     │   ____    /
     │  /    \  /
     │**      \/
     └──────────────────── Steps
Fix: reduce LR 10x, add grad clipping

3. STUCK LOSS (LR too small, dead neurons, bad init)
Loss │****
     │    ****
     │        *****
     │             ********  <- barely moves
     └──────────────────── Steps
Fix: raise LR, check dead neuron fraction

4. OSCILLATING (LR too high for late training)
Loss │ *   *   *   *   *
     │* * * * * * * * * *
     │   *   *   *   *   *
     └──────────────────── Steps
Fix: reduce LR or use LR warmup + decay

5. OVERFIT (good train, diverging val)
Loss │*
     │ *              ....  (val)    <- val starts rising
     │  *...........
     │   ***
     │      ****
     │          *******  (train)  <- still dropping
     └──────────────────── Steps
Fix: regularization, more data, early stopping

6. UNDERFIT (both losses high, similar)
Loss │*****
     │     ****
     │         ************  <- flattens early and high
     │
     │
     └──────────────────── Steps
Fix: larger model, reduce regularization, raise LR
```
| Symptom | Most Likely Cause | Primary Fix |
|---|---|---|
| Loss → NaN immediately | LR too high, log(0) | Reduce LR 10x, clamp inputs |
| Loss → NaN after N steps | Exploding gradients | Clip grad norm, reduce LR |
| Loss stuck from step 1 | Bad init, LR too small | He init, increase LR |
| Loss stuck after initial drop | LR decayed too fast | Longer warmup, check scheduler |
| Loss oscillates throughout | LR too high | Reduce by 3–10x |
| Train OK, val diverges | Overfitting | Regularization, more data |
| Both losses high and equal | Underfitting / data bug | Larger model, audit data |
| Val loss good but metric bad | Metric implementation bug | Verify metric from scratch |
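Several rows of this table can be turned into automatic checks on a logged loss history. The function below is a rough heuristic sketch, not part of this lesson's toolkit: the name diagnose_loss_curve and every threshold in it (2x for explosion, a 5% validation rise, a 1% relative change for "stuck") are illustrative assumptions, it assumes positive losses, and it does not attempt to detect oscillation or underfitting, which need a reference for what "high" means.

```python
import math
from statistics import mean
from typing import List, Optional


def diagnose_loss_curve(
    train: List[float],
    val: Optional[List[float]] = None,
    window: int = 50,
    stuck_tol: float = 0.01,
) -> str:
    """Heuristic mapping of a loss history onto the pathology table.

    All thresholds are illustrative assumptions; losses are assumed positive.
    """
    if not train:
        return "insufficient_data"
    if any(math.isnan(x) or math.isinf(x) for x in train):
        return "nan_or_inf"
    # Exploding: the latest loss is far above the best loss seen so far
    if train[-1] > 2.0 * min(train):
        return "exploding"
    # Overfitting: val has risen >5% above its best while train kept falling
    if val is not None and len(val) >= 3:
        best_idx = min(range(len(val)), key=val.__getitem__)
        if (best_idx < len(val) - 1 and val[-1] > 1.05 * val[best_idx]
                and train[-1] < train[0]):
            return "overfitting"
    if len(train) < 4:
        return "insufficient_data"
    # Stuck from step 1: almost no relative improvement over the whole run
    if (train[0] - train[-1]) / (abs(train[0]) + 1e-12) < stuck_tol:
        return "stuck_from_start"
    # Plateau: compare the mean of the two halves of the most recent window
    recent = train[-window:]
    half = len(recent) // 2
    older, newer = mean(recent[:half]), mean(recent[half:])
    if (older - newer) / (abs(older) + 1e-12) < stuck_tol:
        # Converged, or LR decayed too early: check the LR schedule to tell apart
        return "plateau"
    return "healthy"


healthy = [2.0 - 0.01 * i for i in range(100)]
print(diagnose_loss_curve(healthy))
print(diagnose_loss_curve(healthy, val=[2.0, 1.5, 1.2, 1.3, 1.5]))
```

A "plateau" label is deliberately ambiguous: it is exactly the situation in the opening scenario, where the next step is to check whether the LR has already decayed too far.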
Gradient Flow Analysis
Gradient flow is the circulatory system of learning. When gradients cannot reach early layers, those layers do not learn. When gradients are too large, parameters explode. Monitoring gradient norms per layer is the first diagnostic tool to reach for.
```python
from collections import defaultdict
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import torch.nn as nn


def compute_gradient_norms(model: nn.Module) -> Dict[str, float]:
    """
    Compute the L2 gradient norm for every trainable named parameter.
    Call after loss.backward() and before optimizer.step().
    """
    norms = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are expected to have no gradient
        if param.grad is not None:
            norms[name] = param.grad.detach().norm(2).item()
        else:
            norms[name] = 0.0  # trainable but no gradient - may indicate a bug
    return norms


def print_gradient_summary(
    model: nn.Module,
    threshold_vanish: float = 1e-6,
    threshold_explode: float = 10.0,
) -> Dict[str, float]:
    """
    Summarize gradient norms with pass/fail status per layer.
    Call after backward(), before optimizer.step().
    """
    norms = compute_gradient_norms(model)
    print(f"\n{'Layer':<50} {'Grad Norm':>12} {'Status':>12}")
    print("-" * 76)
    for name, norm in norms.items():
        if norm == 0.0:
            status = "NO GRAD"
        elif norm < threshold_vanish:
            status = "VANISHING"
        elif norm > threshold_explode:
            status = "EXPLODING"
        else:
            status = "OK"
        print(f"{name:<50} {norm:>12.4e} {status:>12}")

    all_norms = list(norms.values())
    if all_norms:
        print(f"\nSummary: min={min(all_norms):.2e}, "
              f"max={max(all_norms):.2e}, "
              f"mean={sum(all_norms) / len(all_norms):.2e}")
    return norms


class GradientFlowLogger:
    """
    Logs gradient norms per layer across training steps.
    Enables plotting to visualize gradient health over time.
    Usage:
        logger = GradientFlowLogger(model)
        # ... in training loop, after backward():
        logger.record(step)
        # ... after training:
        logger.plot()
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.history: Dict[str, List[Tuple[int, float]]] = defaultdict(list)

    def record(self, step: int):
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                norm = param.grad.detach().norm(2).item()
                self.history[name].append((step, norm))

    def plot(self, save_path: str = "gradient_flow.png", top_k: int = 6):
        """Plot gradient norms over time for the top_k most active layers."""
        if not self.history:
            print("No gradients recorded - call record() after backward().")
            return
        # Select the top_k layers by mean gradient norm
        means = {
            name: sum(n for _, n in vals) / max(len(vals), 1)
            for name, vals in self.history.items()
        }
        top_layers = sorted(means, key=means.get, reverse=True)[:top_k]

        fig, axes = plt.subplots(len(top_layers), 1,
                                 figsize=(12, 3 * len(top_layers)), sharex=True)
        if len(top_layers) == 1:
            axes = [axes]
        for ax, name in zip(axes, top_layers):
            steps, norms = zip(*self.history[name])
            ax.semilogy(steps, norms, linewidth=1.5)
            ax.set_ylabel("Grad Norm (log)")
            ax.set_title(name, fontsize=9)
            ax.axhline(1e-6, color="red", linestyle="--", alpha=0.5, label="vanish threshold")
            ax.axhline(10.0, color="orange", linestyle="--", alpha=0.5, label="explode threshold")
            ax.legend(fontsize=7)
        axes[-1].set_xlabel("Training Step")
        fig.suptitle("Gradient Norms Per Layer Over Training")
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Gradient flow plot saved to {save_path}")
```
What to look for in gradient norm plots:
- Healthy: all layers show gradient norms of similar order of magnitude. Early-layer norms stay within 10–100x of late-layer norms.
- Vanishing: early-layer norms are 1000x or more smaller than late-layer norms. Catastrophic in deep networks without residual connections.
- Exploding: norms far above their usual range (the summary above flags anything over 10), or growing from step to step. If not clipped, this can produce NaN weights within a few steps.
- Dead layers: gradient norm exactly 0.0 or near zero and constant. The layer is not participating in learning. Check for disconnected computational graph or all-dead ReLU units.
- Oscillating norms: gradient norm bounces between large and small. Often caused by LR too high - the optimizer is overshooting and bouncing around.
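To see the vanishing pattern concretely, here is a toy experiment, a sketch rather than a benchmark. It assumes a 12-layer sigmoid MLP of width 32 with PyTorch's default initialization and a made-up squared-output loss, and compares the last-layer to first-layer weight-gradient ratio with and without residual connections. The function name first_last_grad_ratio is hypothetical.

```python
import torch
import torch.nn as nn


def first_last_grad_ratio(depth: int = 12, width: int = 32, residual: bool = False) -> float:
    """Ratio of last-layer to first-layer weight-gradient norm in a deep sigmoid MLP."""
    torch.manual_seed(0)
    layers = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))
    head = nn.Linear(width, 1)

    h = torch.randn(64, width)
    for layer in layers:
        out = torch.sigmoid(layer(h))
        # The residual path h + out gives gradients an identity shortcut
        h = h + out if residual else out
    loss = head(h).pow(2).mean()
    loss.backward()

    first = layers[0].weight.grad.norm().item()
    last = layers[-1].weight.grad.norm().item()
    return last / first


plain = first_last_grad_ratio(residual=False)
resnet = first_last_grad_ratio(residual=True)
print(f"plain sigmoid MLP:    last/first = {plain:.1e}")
print(f"residual sigmoid MLP: last/first = {resnet:.1e}")
```

Without residuals, each sigmoid layer multiplies the backward signal by its derivative (at most 0.25) and the weight matrix, so the ratio grows geometrically with depth; the identity path keeps it near order one.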
Learning Rate: The Most Important Hyperparameter
Why LR Matters More Than Architecture
In practice, a suboptimal learning rate destroys training faster than a suboptimal architecture. Too high: the optimizer overshoots optima and oscillates or diverges. Too low: convergence is so slow that the model does not reach a good solution within the compute budget. The right LR is in a narrow range that varies per model, per dataset, per optimizer, and per batch size.
The linear scaling rule (Goyal et al., 2017): when multiplying the batch size by $k$, multiply the LR by $k$. Intuition: a larger batch gives a less noisy gradient estimate, so you can take larger steps. This is empirically validated for SGD on vision tasks: Goyal et al. found it held on ImageNet up to batch sizes of about 8k, provided training started with a gradual warmup. For Adam, the relationship is weaker but still directional.
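As arithmetic, the rule is a one-line rescaling. The sketch below uses a hypothetical helper scale_lr; the "sqrt" option is square-root scaling, a heuristic some practitioners use with adaptive optimizers such as Adam, included here as an assumption rather than as part of the linear rule.

```python
import math


def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "linear") -> float:
    """Rescale a tuned LR for a new batch size.

    rule="linear": linear scaling rule, LR is multiplied by k = new_batch / base_batch.
    rule="sqrt":   square-root scaling, an assumed alternative heuristic for Adam-style optimizers.
    """
    k = new_batch / base_batch
    if rule == "linear":
        return base_lr * k
    if rule == "sqrt":
        return base_lr * math.sqrt(k)
    raise ValueError(f"unknown rule: {rule}")


# Goyal et al.'s reference point: LR 0.1 for batch 256 with SGD on ImageNet.
print(scale_lr(0.1, 256, 8192))           # linear: 0.1 * 32
print(scale_lr(3e-4, 256, 1024, "sqrt"))  # sqrt: 3e-4 * 2
```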
Learning Rate Finder (Leslie Smith's Algorithm)
Rather than guessing LR, run a systematic search:
- Start with a very small LR (1e-7)
- Train for one mini-batch, record the loss
- Multiply LR by a constant factor (e.g., 1.1–1.3)
- Repeat for 100–200 steps until loss diverges
- Plot loss vs LR (on log scale)
- Choose an LR where the loss is falling fastest (the steepest downward slope), well before the point where the loss starts rising
The LR at the bottom of the curve is often too high, since the model is about to diverge. The steepest-descent point, often roughly 10x smaller than the LR where the loss starts to diverge, is the sweet spot.
```python
import copy
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn


class LRFinder:
    """
    Leslie Smith's Learning Rate Finder.
    Trains for a range of LRs and records loss.
    Use the LR at the steepest decline in the plot.
    Reference: Smith (2015), "Cyclical Learning Rates for Training Neural Networks"
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        device: str = "cuda",
    ):
        self.model = model  # expected to already be on `device`
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        # Save initial state - restore after the search
        self._model_state = copy.deepcopy(model.state_dict())
        self._optim_state = copy.deepcopy(optimizer.state_dict())
        self.lr_history: list[float] = []
        self.loss_history: list[float] = []

    def run(
        self,
        dataloader,
        start_lr: float = 1e-7,
        end_lr: float = 10.0,
        n_steps: int = 100,
        smooth_f: float = 0.05,          # exponential smoothing factor
        diverge_threshold: float = 5.0,  # stop if loss > best * this
    ):
        """Run the LR range test."""
        # Set starting LR
        for group in self.optimizer.param_groups:
            group["lr"] = start_lr
        lr_factor = (end_lr / start_lr) ** (1.0 / n_steps)
        best_loss = float("inf")
        avg_loss = 0.0
        data_iter = iter(dataloader)
        self.model.train()

        for step in range(n_steps):
            try:
                inputs, targets = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                inputs, targets = next(data_iter)
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # Exponential smoothing of loss
            current_loss = loss.item()
            avg_loss = smooth_f * current_loss + (1 - smooth_f) * avg_loss
            smoothed = avg_loss / (1 - (1 - smooth_f) ** (step + 1))  # bias correction

            current_lr = self.optimizer.param_groups[0]["lr"]
            self.lr_history.append(current_lr)
            self.loss_history.append(smoothed)

            if smoothed < best_loss:
                best_loss = smoothed
            # Divergence check
            if smoothed > diverge_threshold * best_loss:
                print(f"Loss diverged at step {step}, LR={current_lr:.2e}")
                break

            # Update LR for next step
            for group in self.optimizer.param_groups:
                group["lr"] *= lr_factor

        # Restore original model and optimizer state
        self.model.load_state_dict(self._model_state)
        self.optimizer.load_state_dict(self._optim_state)
        return self

    def suggest_lr(self) -> float:
        """
        Suggest the learning rate at the steepest loss decline.
        This is the point of maximum negative gradient of the loss-lr curve.
        """
        if len(self.loss_history) < 3:
            return self.lr_history[0] if self.lr_history else 1e-3

        # Compute numerical gradient of loss w.r.t. log(lr) (central differences)
        log_lrs = [math.log(lr) for lr in self.lr_history]
        slopes = []
        for i in range(1, len(self.loss_history) - 1):
            slope = (self.loss_history[i + 1] - self.loss_history[i - 1]) / (
                log_lrs[i + 1] - log_lrs[i - 1]
            )
            slopes.append((slope, self.lr_history[i]))

        # Steepest descent = most negative slope
        best_slope, best_lr = min(slopes, key=lambda x: x[0])
        return best_lr

    def plot(self, save_path: str = "lr_finder.png"):
        """Plot the loss-vs-LR curve."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.semilogx(self.lr_history, self.loss_history, linewidth=2)
        suggested = self.suggest_lr()
        ax.axvline(suggested, color="red", linestyle="--",
                   label=f"Suggested LR: {suggested:.2e}")
        ax.set_xlabel("Learning Rate (log scale)")
        ax.set_ylabel("Loss (smoothed)")
        ax.set_title("Learning Rate Finder - Choose LR at Steepest Descent")
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"LR finder plot saved to {save_path}")
        print(f"Suggested LR: {suggested:.2e}")


# Usage
# model = MyModel().to("cuda")
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7)
# finder = LRFinder(model, optimizer, nn.CrossEntropyLoss(), device="cuda")
# finder.run(train_loader, start_lr=1e-7, end_lr=1.0, n_steps=100)
# finder.plot()
# print(f"Use LR: {finder.suggest_lr():.2e}")
```
Cyclical Learning Rates and Warmup
Linear Warmup: start with a very small LR and linearly increase it to the target LR over the first 5–10% of training. This prevents early large-gradient steps from corrupting the initialization before an adaptive optimizer such as Adam has accumulated reliable moment estimates. Critical for transformers - without warmup, attention weights can diverge early in training.
Cosine Annealing: after warmup, decay the LR following a cosine curve from the peak to a small floor value (e.g., 1/10th of peak). The cosine schedule is smoother than step decay - the gradual decrease allows the model to settle into a flat minimum rather than being abruptly stopped.
Cyclical LR (CLR): instead of decreasing monotonically, cycle the LR between a minimum and a maximum. This can help the optimizer escape poor local minima (LR peaks kick it out) and find flatter minima (LR troughs allow settling). Leslie Smith introduced this in 2015, and the idea lives on in One-Cycle training.
```python
import math

import torch
import torch.optim as optim


def get_cosine_schedule_with_warmup(
    optimizer: optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.1,
) -> optim.lr_scheduler.LambdaLR:
    """
    Linear warmup followed by cosine annealing.
    Standard schedule for transformer training.
    num_warmup_steps: steps for the linear warmup phase
    num_training_steps: total training steps
    min_lr_ratio: LR at the end of cosine decay, as a fraction of peak LR
    """
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            # Linear warmup: multiplier goes from 0 to 1.0
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine annealing from 1.0 down to min_lr_ratio
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        progress = min(progress, 1.0)  # hold at the floor past num_training_steps
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def get_one_cycle_schedule(
    optimizer: optim.Optimizer,
    max_lr: float,
    total_steps: int,
    pct_start: float = 0.3,
) -> optim.lr_scheduler.OneCycleLR:
    """
    One Cycle LR policy (Leslie Smith, 2019).
    Raises the LR from max_lr / div_factor up to max_lr, then anneals it far below
    the starting value. With anneal_strategy="cos", both phases follow cosine curves.
    Momentum is cycled inversely by default: high momentum at low LR, low at high LR
    (the optimizer needs a momentum or betas hyperparameter).
    """
    return optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        pct_start=pct_start,    # fraction of steps spent increasing the LR
        anneal_strategy="cos",
        div_factor=25.0,        # initial_lr = max_lr / 25
        final_div_factor=1e4,   # final_lr = initial_lr / 1e4
    )


# Warmup schedule usage example
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# scheduler = get_cosine_schedule_with_warmup(
#     optimizer,
#     num_warmup_steps=500,
#     num_training_steps=10000,
# )
# for step, batch in enumerate(dataloader):
#     loss = train_step(batch)
#     loss.backward()
#     optimizer.step()
#     scheduler.step()
#     optimizer.zero_grad()
```
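The triangular cyclical policy described above is not covered by the two helpers, so here is a minimal sketch of its LR formula as given in Smith's 2015 paper. The function name triangular_clr is hypothetical; PyTorch also ships torch.optim.lr_scheduler.CyclicLR, whose mode="triangular" implements the same shape.

```python
import math


def triangular_clr(step: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """Triangular cyclical LR.

    The LR ramps linearly from base_lr to max_lr over step_size steps,
    back down to base_lr over the next step_size steps, and then repeats.
    """
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)  # 1 at cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


# One full cycle with step_size=4: up for 4 steps, down for 4 steps.
schedule = [round(triangular_clr(s, 1e-4, 1e-2, step_size=4), 6) for s in range(9)]
print(schedule)
```

Smith's paper suggests a step_size of a few epochs' worth of iterations; any concrete value here is only for illustration.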
Gradient Clipping: Defending Against Explosions
Gradient clipping prevents exploding gradients from corrupting the parameter update. Two clipping strategies:
Clip by global norm (recommended): compute the L2 norm of all gradients combined, then scale down all gradients proportionally if the norm exceeds a threshold. This preserves the direction of the gradient while limiting its magnitude.
Clip by value: independently clamp each gradient element to $[-c, c]$ for some threshold $c$ (torch.nn.utils.clip_grad_value_ in PyTorch). Simpler, but it changes the direction of the gradient. Rarely used in modern practice.
```python
import torch
import torch.nn as nn


def clip_and_log_gradients(
    model: nn.Module,
    max_norm: float = 1.0,
    log_clipping: bool = True,
) -> float:
    """
    Clip gradients by global norm and optionally log when clipping is active.
    Returns the gradient norm BEFORE clipping.
    The returned pre-clip norm is valuable for monitoring:
    - Consistently large (>> max_norm): LR may be too high or init is poor
    - Occasionally large: normal, clipping is working as intended
    - Always small (<< max_norm): clipping is never active; could increase LR
    With mixed precision (GradScaler), call scaler.unscale_(optimizer) first
    so the norm is measured on the true, unscaled gradients.
    """
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=max_norm
    )
    norm_value = total_norm.item()
    if log_clipping and norm_value > max_norm:
        clip_ratio = max_norm / norm_value
        print(f"  [GradClip] Pre-clip norm: {norm_value:.4f}, "
              f"clip ratio: {clip_ratio:.4f}")
    return norm_value


# In training loop:
# loss.backward()
# pre_clip_norm = clip_and_log_gradients(model, max_norm=1.0)
# log_metric("grad/pre_clip_norm", pre_clip_norm, step=global_step)
# optimizer.step()
```
:::tip When to Set max_norm
For most tasks with Adam: max_norm=1.0. For RNNs and LSTMs: max_norm=0.1–0.5 (they are more prone to exploding gradients). For transformers: max_norm=1.0 is standard. If the pre-clip norm is almost always less than max_norm, the clipping is not doing anything - this is fine, it just means the training is already stable.
:::
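The difference between the two clipping strategies is easiest to see on a single two-element gradient. This sketch attaches a hand-set gradient to a dummy parameter (an illustrative setup, not a training loop), clips it both ways with PyTorch's built-in utilities, and measures the cosine similarity to the unclipped gradient.

```python
import torch


def clip_demo():
    """Compare clip-by-norm and clip-by-value on the same 2-element gradient."""
    original = torch.tensor([3.0, 0.5])

    p_norm = torch.nn.Parameter(torch.zeros(2))
    p_norm.grad = original.clone()
    torch.nn.utils.clip_grad_norm_([p_norm], max_norm=1.0)

    p_value = torch.nn.Parameter(torch.zeros(2))
    p_value.grad = original.clone()
    torch.nn.utils.clip_grad_value_([p_value], clip_value=1.0)

    cos = torch.nn.functional.cosine_similarity
    return {
        "by_norm": p_norm.grad,
        "by_value": p_value.grad,
        # Cosine similarity to the unclipped gradient: 1.0 means same direction
        "cos_norm": cos(p_norm.grad, original, dim=0).item(),
        "cos_value": cos(p_value.grad, original, dim=0).item(),
    }


result = clip_demo()
print(result)
```

Clipping by norm rescales [3.0, 0.5] to unit length without turning it, while clipping by value produces [1.0, 0.5], which points in a different direction because only the large component was cut.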
NaN Detection and Hunting
NaN values propagate silently. One NaN in any forward computation corrupts all downstream values and all gradients in the backward pass. After one optimizer step, all parameters that received NaN gradients become NaN. Within two steps, the entire model can be corrupted.
Common NaN sources:
| Source | Cause | Fix |
|---|---|---|
| log(0) | Cross-entropy on zero probability | torch.clamp(x, min=1e-8) before log |
| 0 / 0 | Division by zero in normalization | Add epsilon to denominator |
| sqrt(0) gradient | Gradient of sqrt at exactly zero is infinite | sqrt(x + 1e-8) |
| Large logits | exp(x) overflows to inf (x above ~88.7 in fp32, ~11 in fp16) | Scale inputs, clip logits |
| fp16 overflow | Values exceed float16 range (~65504) | Use GradScaler, reduce LR |
| Exploding gradients | LR too high or no clipping | Clip gradients |
| NaN in input data | Corrupt data pipeline | Validate inputs before forward |
| In-place op on graph | Tensor modified after it was saved | Remove in-place ops |
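The first rows of this table are easy to reproduce on scalar tensors. This sketch triggers each failure and applies the corresponding clamp or epsilon fix, so you can see which operations yield -inf, NaN, or inf rather than taking the table on faith.

```python
import torch

zero = torch.tensor(0.0)

# log(0) -> -inf; clamping first keeps it finite
raw_log = torch.log(zero)
safe_log = torch.log(torch.clamp(zero, min=1e-8))

# 0 / 0 -> nan; an epsilon in the denominator avoids it
raw_div = zero / zero
safe_div = zero / (zero + 1e-8)

# d/dx sqrt(x) = 1 / (2 * sqrt(x)) is infinite at x = 0
x = torch.tensor(0.0, requires_grad=True)
torch.sqrt(x).backward()
raw_sqrt_grad = x.grad.clone()

x_safe = torch.tensor(0.0, requires_grad=True)
torch.sqrt(x_safe + 1e-8).backward()
safe_sqrt_grad = x_safe.grad.clone()

# float16 tops out at 65504; larger values overflow to inf
fp16_overflow = torch.tensor(70000.0).half()

for name, t in [("log(0)", raw_log), ("clamped log", safe_log),
                ("0/0", raw_div), ("0/(0+eps)", safe_div),
                ("sqrt grad at 0", raw_sqrt_grad), ("sqrt grad at eps", safe_sqrt_grad),
                ("70000 in fp16", fp16_overflow)]:
    print(f"{name:>18}: {t.item()}")
```

An inf is not yet a NaN, but it becomes one as soon as it meets a zero or another inf of opposite sign (for example 0 * inf), which is why the infinite sqrt gradient is listed as a NaN source.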
```python
import math
from typing import List

import torch
import torch.nn as nn


def find_nan_in_model(model: nn.Module) -> List[str]:
    """Find all parameters and buffers containing NaN or Inf."""
    problems = []
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            problems.append(f"NaN in parameter: {name}")
        if torch.isinf(param).any():
            problems.append(f"Inf in parameter: {name}")
    for name, buf in model.named_buffers():
        if torch.isnan(buf).any():
            problems.append(f"NaN in buffer: {name}")
        if torch.isinf(buf).any():
            problems.append(f"Inf in buffer: {name}")
    return problems


def register_nan_detection_hooks(model: nn.Module) -> List:
    """
    Register forward hooks that raise on the first NaN produced by any module.
    Child modules finish (and fire their hooks) before their parents, so the
    error names the innermost module whose output first contained NaN.
    Usage:
        hooks = register_nan_detection_hooks(model)
        try:
            output = model(x)
        except RuntimeError as e:
            print(e)
        finally:
            for h in hooks:
                h.remove()
    """
    hooks = []

    def make_hook(name: str):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                # Inspect inputs for additional context
                input_nan = [
                    torch.isnan(inp).any().item()
                    for inp in inputs
                    if isinstance(inp, torch.Tensor)
                ]
                raise RuntimeError(
                    f"NaN detected in OUTPUT of module '{name}' "
                    f"(type: {type(module).__name__}).\n"
                    f"Input had NaN: {input_nan}\n"
                    f"Output shape: {output.shape}, "
                    f"NaN count: {torch.isnan(output).sum().item()}"
                )
        return hook

    for name, module in model.named_modules():
        hooks.append(module.register_forward_hook(make_hook(name)))
    return hooks


def validate_batch(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    step: int,
    input_clamp: float = 100.0,
) -> bool:
    """
    Validate a training batch before the forward pass.
    Returns True if the batch is clean, False if problems were found.
    """
    issues = []
    if torch.isnan(inputs).any():
        issues.append(f"NaN in inputs (count: {torch.isnan(inputs).sum().item()})")
    if torch.isinf(inputs).any():
        issues.append("Inf in inputs")
    if inputs.abs().max().item() > input_clamp:
        issues.append(f"Extreme input values: max_abs={inputs.abs().max().item():.2f}")
    if torch.isnan(targets).any():
        issues.append("NaN in targets")

    if issues:
        print(f"Step {step} batch validation FAILED:")
        for issue in issues:
            print(f"  - {issue}")
        return False
    return True


def loss_with_nan_check(
    loss_value: torch.Tensor,
    step: int,
    model: nn.Module,
    raise_on_nan: bool = False,
) -> float:
    """Safely extract the loss value, checking for NaN and logging model state."""
    value = loss_value.item()
    if math.isnan(value):
        print(f"\nStep {step}: NaN loss detected!")
        problems = find_nan_in_model(model)
        if problems:
            print("NaN found in model parameters:")
            for p in problems:
                print(f"  {p}")
        else:
            print("Model parameters are clean - the NaN arose in this forward pass "
                  "(check this batch and the loss computation)")
        if raise_on_nan:
            raise ValueError(f"NaN loss at step {step}")
    return value
```
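Forward hooks only see forward outputs, so they miss NaNs that first appear during the backward pass. PyTorch's anomaly mode (the torch.autograd.detect_anomaly context manager, or torch.autograd.set_detect_anomaly(True)) checks every backward function's output and raises an error naming the op. The toy computation below is an illustrative assumption: its forward pass is finite, but its backward pass is not. Anomaly mode slows training considerably, so it is best enabled only while reproducing a failure.

```python
import torch


def find_backward_nan():
    """Locate the backward op that first produces NaN, using anomaly mode."""
    x = torch.zeros(3, requires_grad=True)
    try:
        with torch.autograd.detect_anomaly():
            # Forward is finite: sqrt(0) * 0 = 0. Backward is not:
            # the sqrt gradient is 0 / (2 * sqrt(0)) = 0 / 0 = nan.
            y = (torch.sqrt(x) * 0.0).sum()
            y.backward()
    except RuntimeError as err:
        return str(err)
    return None


message = find_backward_nan()
print(message)
```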
Complete Debugging Checklist (Runnable Code)
```python
from typing import Dict, List, Optional

import torch
import torch.nn as nn

# TrainingDebugger.step() uses find_nan_in_model() from the NaN-detection code above.


def overfit_one_batch(
    model: nn.Module,
    sample_batch: tuple,
    criterion: nn.Module,
    lr: float = 1e-3,
    max_steps: int = 500,
    target_loss: float = 0.01,
    device: str = "cpu",
) -> Dict:
    """
    Test whether the model can memorize a single batch.
    A healthy model and loss function should drive loss near zero.
    If it cannot: architectural bug, wrong loss, disconnected graph,
    or initialization problem - no amount of training will help.
    Returns dict with success, final_loss, history.
    """
    model = model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    inputs, targets = sample_batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    history = []

    print(f"Overfit-one-batch test ({max_steps} max steps)...")
    for step in range(max_steps):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        history.append(loss_val)
        if step % 50 == 0:
            print(f"  Step {step:>4}: loss = {loss_val:.6f}")
        if loss_val != loss_val:  # NaN
            print("  FAILED: NaN loss - check loss function or architecture")
            return {"success": False, "final_loss": float("nan"), "history": history}
        if loss_val < target_loss:
            print(f"  SUCCESS: memorized in {step + 1} steps (loss={loss_val:.6f})")
            return {"success": True, "final_loss": loss_val, "history": history}

    print(f"  FAILED: did not reach target {target_loss} after {max_steps} steps")
    print(f"  Final loss: {history[-1]:.6f}")
    print("  Likely causes: wrong loss function, bad init, disconnected graph")
    return {"success": False, "final_loss": history[-1], "history": history}


class TrainingDebugger:
    """
    All-in-one debugging harness for the training loop.
    Monitors: loss trends, gradient norms, clipping activity, NaN detection.
    Usage:
        debugger = TrainingDebugger(model, log_every=100)
        for step, (x, y) in enumerate(loader):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            if debugger.step(step, loss.item(), optimizer):
                optimizer.step()  # skipped when the loss was NaN
    """

    def __init__(
        self,
        model: nn.Module,
        log_every: int = 100,
        clip_grad_norm: float = 1.0,
        window_size: int = 200,
    ):
        self.model = model
        self.log_every = log_every
        self.clip_grad_norm = clip_grad_norm
        self.window_size = window_size
        self.loss_history: List[float] = []
        self.grad_norm_history: List[float] = []
        self.clip_events: List[int] = []  # steps where clipping was active
        self._consecutive_nan: int = 0
        self._total_nan: int = 0

    def step(
        self,
        global_step: int,
        loss: float,
        optimizer: torch.optim.Optimizer,
        extra_metrics: Optional[Dict[str, float]] = None,
    ) -> bool:
        """
        Call after loss.backward(), before optimizer.step().
        Handles: NaN detection, gradient clipping, logging.
        Returns False if the optimizer step should be skipped.
        """
        # NaN guard - skip the optimizer step if the loss is NaN
        if loss != loss:
            self._consecutive_nan += 1
            self._total_nan += 1
            print(f"\n[Step {global_step}] NaN loss ({self._consecutive_nan} consecutive)")
            if self._consecutive_nan >= 3:
                problems = find_nan_in_model(self.model)
                if problems:
                    print("NaN propagated to parameters:")
                    for p in problems:
                        print(f"  {p}")
            return False  # signal to skip optimizer step

        self._consecutive_nan = 0
        self.loss_history.append(loss)

        # Gradient clipping and norm tracking
        pre_clip_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=self.clip_grad_norm
        ).item()
        self.grad_norm_history.append(pre_clip_norm)
        if pre_clip_norm > self.clip_grad_norm * 0.99:
            self.clip_events.append(global_step)

        # Periodic logging
        if global_step % self.log_every == 0:
            self._log(global_step, loss, pre_clip_norm, optimizer, extra_metrics)
        return True  # signal to proceed with optimizer step

    def _log(
        self,
        step: int,
        loss: float,
        grad_norm: float,
        optimizer: torch.optim.Optimizer,
        extra: Optional[Dict[str, float]] = None,
    ):
        lr = optimizer.param_groups[0]["lr"]
        # Trend over the recent window: last loss minus first loss in the window
        window = self.loss_history[-self.window_size:]
        trend = window[-1] - window[0] if len(window) > 10 else None
        trend_str = f"{trend:+.4f}" if trend is not None else "N/A"
        clip_pct = len(self.clip_events) / max(len(self.grad_norm_history), 1) * 100
        clipped_this_step = grad_norm > self.clip_grad_norm * 0.99

        print(
            f"\n[Step {step:>7}] "
            f"loss={loss:.4f} "
            f"trend({self.window_size})={trend_str} | "
            f"grad_norm={grad_norm:.4f}{'[CLIPPED]' if clipped_this_step else ''} "
            f"clip_freq={clip_pct:.1f}% | "
            f"lr={lr:.2e}"
        )
        # Warnings
        if trend is not None and len(window) > 50 and abs(trend) < 1e-4:
            print(f"  WARNING: Loss has barely moved in the last {len(window)} steps.")
            print("  Possible: LR decayed too low, dead neurons, plateau.")
        if clip_pct > 50:
            print("  WARNING: Clipping >50% of steps - LR may be too high.")
        if extra:
            for k, v in extra.items():
                print(f"  {k}: {v:.4f}")

    def summary(self) -> Dict:
        """Return summary statistics for the training run."""
        if not self.loss_history:
            return {}
        return {
            "final_loss": self.loss_history[-1],
            "min_loss": min(self.loss_history),
            "mean_grad_norm": sum(self.grad_norm_history) / max(len(self.grad_norm_history), 1),
            "clip_frequency": len(self.clip_events) / max(len(self.grad_norm_history), 1),
            "nan_events": self._total_nan,
        }
```
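When overfit_one_batch fails, one of the causes its docstring lists, a disconnected graph, is cheap to confirm directly: after one backward pass, any trainable parameter whose .grad is still None never received a gradient. A minimal sketch with a deliberately broken toy model; DetachedEncoderNet and params_without_grad are hypothetical names, not part of this lesson's code.

```python
import torch
import torch.nn as nn


class DetachedEncoderNet(nn.Module):
    """Toy model with a deliberate bug: .detach() cuts the encoder out of the graph."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        h = torch.relu(self.encoder(x)).detach()  # BUG: no gradient reaches encoder
        return self.head(h)


def params_without_grad(model: nn.Module) -> list:
    """Trainable parameters whose .grad is still None after a backward pass."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


torch.manual_seed(0)
model = DetachedEncoderNet()
x, y = torch.randn(4, 8), torch.tensor([0, 1, 0, 1])
nn.CrossEntropyLoss()(model(x), y).backward()
missing = params_without_grad(model)
print(missing)  # ['encoder.weight', 'encoder.bias']
```

The same parameters would show up as NO GRAD in the gradient summary earlier in this lesson; the head can still memorize the batch partially, which is why the overfit test alone may not point at the encoder.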
TensorBoard and Weights & Biases Integration
Logging training metrics to a dashboard is not optional in production - it is how you catch problems before they waste hours of compute.
```python
import torch
import torch.nn as nn
from typing import Optional
from torch.utils.tensorboard import SummaryWriter


class TrainingLogger:
    """
    Dual-backend logger: TensorBoard + optional W&B.
    Logs: losses, gradient norms, learning rates, activation statistics.
    """

    def __init__(
        self,
        log_dir: str = "runs/experiment",
        use_wandb: bool = False,
        wandb_project: str = "my-project",
        wandb_config: Optional[dict] = None,
    ):
        self.writer = SummaryWriter(log_dir)
        self.use_wandb = use_wandb
        if use_wandb:
            import wandb
            wandb.init(project=wandb_project, config=wandb_config or {})

    def log_scalars(self, metrics: dict, step: int):
        """Log a dict of scalar metrics."""
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)
        if self.use_wandb:
            import wandb
            wandb.log({**metrics, "step": step})

    def log_gradient_norms(self, model: nn.Module, step: int):
        """Log per-layer gradient norms and total gradient norm."""
        norms = {}
        total_sq = 0.0
        for name, param in model.named_parameters():
            if param.grad is not None:
                norm = param.grad.detach().norm(2).item()
                norms[f"grad_norm/{name}"] = norm
                total_sq += norm ** 2
        norms["grad_norm/total"] = total_sq ** 0.5
        self.log_scalars(norms, step)

    def log_activation_histogram(
        self, model: nn.Module, inputs: torch.Tensor, step: int
    ):
        """
        Log activation histograms for each Linear/Conv2d layer.
        Useful for spotting saturation, dead neurons, or distribution collapse.
        """
        activation_cache = {}

        def make_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    activation_cache[name] = output.detach().cpu()
            return hook

        handles = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                handles.append(module.register_forward_hook(make_hook(name)))

        # Restore the caller's mode afterwards: leaving the model in eval mode
        # here would silently disable dropout and freeze BatchNorm during training.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model(inputs)
        finally:
            for h in handles:
                h.remove()
            model.train(was_training)

        for name, acts in activation_cache.items():
            self.writer.add_histogram(f"activations/{name}", acts, step)

    def log_weight_histograms(self, model: nn.Module, step: int):
        """Log weight distributions per layer - watch for distribution collapse."""
        for name, param in model.named_parameters():
            self.writer.add_histogram(f"weights/{name}", param.detach().cpu(), step)
            if param.grad is not None:
                self.writer.add_histogram(f"grads/{name}", param.grad.detach().cpu(), step)

    def log_lr(self, optimizer: torch.optim.Optimizer, step: int):
        """Log learning rate for each parameter group (to both backends)."""
        self.log_scalars(
            {f"lr/group_{i}": group["lr"] for i, group in enumerate(optimizer.param_groups)},
            step,
        )

    def close(self):
        self.writer.close()
        if self.use_wandb:
            import wandb
            wandb.finish()


# --- Putting it all together: production training loop ---
# TrainingDebugger and validate_batch are defined earlier in this lesson.
def production_train_loop(
    model: nn.Module,
    train_loader,
    val_loader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    n_epochs: int,
    log_dir: str = "runs/experiment",
    device: str = "cuda",
):
    logger = TrainingLogger(log_dir=log_dir)
    debugger = TrainingDebugger(model, log_every=100, clip_grad_norm=1.0)
    global_step = 0

    for epoch in range(1, n_epochs + 1):
        # ---- Training ----
        model.train()  # explicit every epoch: validation below switches to eval mode
        epoch_losses = []
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # Validate inputs in first epoch
            if epoch == 1:
                validate_batch(batch_x, batch_y, global_step)

            optimizer.zero_grad(set_to_none=True)
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            loss_value = loss.item()  # one host sync per step

            # Debugger handles clipping + logging
            should_step = debugger.step(global_step, loss_value, optimizer)
            if should_step:
                optimizer.step()
                scheduler.step()

            # Log to TensorBoard (gradient norms here are after the debugger's clipping)
            if global_step % 50 == 0:
                logger.log_scalars({"train/loss": loss_value}, global_step)
                logger.log_gradient_norms(model, global_step)
                logger.log_lr(optimizer, global_step)

            epoch_losses.append(loss_value)
            global_step += 1

        # ---- Validation ----
        model.eval()
        val_loss_sum, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                logits = model(batch_x)
                # Weight by batch size so a short final batch does not skew the mean
                # (assumes criterion uses the default mean reduction)
                val_loss_sum += criterion(logits, batch_y).item() * batch_y.size(0)
                preds = logits.argmax(dim=-1)
                val_correct += (preds == batch_y).sum().item()
                val_total += batch_y.size(0)

        avg_train = sum(epoch_losses) / max(len(epoch_losses), 1)
        avg_val = val_loss_sum / max(val_total, 1)
        val_acc = val_correct / max(val_total, 1)
        logger.log_scalars({
            "train/epoch_loss": avg_train,
            "val/loss": avg_val,
            "val/accuracy": val_acc,
        }, global_step)

        if epoch == 1 or epoch % 5 == 0:
            print(f"\nEpoch {epoch}: train={avg_train:.4f}, val={avg_val:.4f}, "
                  f"val_acc={val_acc:.4f}")

    print("\nTraining complete.")
    print("Debugger summary:", debugger.summary())
    logger.close()
```
Common Mistakes
:::danger Using model.eval() without model.train() in the training loop
The most common production bug. After validation with model.eval(), if you forget model.train() at the top of the next training epoch, dropout is disabled (no regularization) and BatchNorm uses frozen running statistics (wrong normalization). The model trains - but incorrectly. The symptom: validation metrics are erratic and do not improve even though training loss appears healthy. Always call model.train() explicitly at the start of each training epoch. Do not rely on any implicit mode switching.
:::
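Because model.training only reports the flag on the top-level module, a check that walks every submodule is stricter. A minimal sketch; modules_in_wrong_mode is a hypothetical helper name, not a PyTorch API:

```python
import torch.nn as nn


def modules_in_wrong_mode(model: nn.Module, expect_training: bool) -> list:
    """Names of modules whose .training flag differs from the expected mode.

    model.training only reports the root module; a submodule can be left in
    eval mode (e.g. after a stray submodule.eval()) while the root says train.
    """
    return [
        name or "<root>"
        for name, module in model.named_modules()
        if module.training != expect_training
    ]


# Usage at the top of each training epoch:
#   model.train()
#   assert not modules_in_wrong_mode(model, expect_training=True)
```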
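:::danger Not calling torch.no_grad() during validation
Without torch.no_grad(), PyTorch builds a computation graph for every validation forward pass and keeps the intermediate activations needed for a backward pass that never happens. Memory grows with the batch size, and if you accumulate graph-attached tensors across batches (for example, appending loss instead of loss.item()), it grows with the whole validation set. This can cause OOM errors mid-training and can make the validation loop 20–50% slower. Note that torch.no_grad() does not change layer behavior: if the model is still in train mode, BatchNorm will update its running statistics using validation data, poisoning the normalization for the next training epoch. Only model.eval() prevents that, so validation needs both.
:::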
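A validation helper can apply both switches and put the model back in whatever mode it was in. A sketch, assuming a mean-reduced criterion and a loader that yields (inputs, targets) pairs; evaluate is a hypothetical name:

```python
import torch
import torch.nn as nn


def evaluate(model: nn.Module, loader, criterion: nn.Module, device: str = "cpu") -> float:
    """Sample-weighted mean validation loss, with no graph and no BatchNorm updates."""
    was_training = model.training
    model.eval()  # dropout off; BatchNorm uses (and does not update) running stats
    total_loss, total_n = 0.0, 0
    try:
        with torch.no_grad():  # no autograd graph, so no saved activations
            for inputs, targets in loader:
                inputs, targets = inputs.to(device), targets.to(device)
                n = targets.size(0)
                # criterion is assumed to average over the batch, so re-weight by n
                total_loss += criterion(model(inputs), targets).item() * n
                total_n += n
    finally:
        model.train(was_training)  # restore the caller's mode even on error
    return total_loss / max(total_n, 1)
```

Restoring the previous mode, rather than calling model.train() unconditionally, keeps the helper safe to call from anywhere.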
:::danger Learning rate too high for the final phase of training
An LR that works well in the first half of training may cause oscillation in the second half as the model approaches a good minimum. The loss curve shows good initial convergence followed by increasing oscillation: improvements stop, and the curve bounces. This is especially common with cosine schedules that reach their floor too slowly, for example because T_max was set to more steps than the run actually takes. The fix: use a schedule with aggressive late-phase decay (cosine annealing with T_max matched to the real number of steps) or add explicit learning rate warmup + decay.
:::
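To see why a mis-set cosine schedule never reaches its floor, this sketch records the LR that PyTorch's CosineAnnealingLR produces over a 1,000-step run, once with T_max matched to the run and once with T_max three times too long. The step count and LR values are illustrative assumptions:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR


def lr_trace(total_steps: int, t_max: int, base_lr: float = 3e-4, eta_min: float = 1e-6) -> list:
    """LR used at each step of a `total_steps` run under CosineAnnealingLR(T_max=t_max)."""
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=base_lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=t_max, eta_min=eta_min)
    lrs = []
    for _ in range(total_steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()  # no gradients here; called only to keep the step order PyTorch expects
        scheduler.step()
    return lrs


matched = lr_trace(total_steps=1000, t_max=1000)
too_long = lr_trace(total_steps=1000, t_max=3000)
print(f"final LR, T_max matched:  {matched[-1]:.2e}")   # close to eta_min
print(f"final LR, T_max too long: {too_long[-1]:.2e}")  # still ~75% of base_lr
```

Logging the LR every step (checklist item 10 below) is what makes this kind of mismatch visible.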
:::warning Gradient clipping too aggressive (max_norm too small)
If max_norm is set far below typical gradient norms, every single step is clipped. The optimizer always takes a scaled-down step: clip_grad_norm_ keeps the gradient direction but caps its magnitude at max_norm. Learning is slower and the model may not converge to the best solution. Check: if the pre-clip gradient norm is almost always larger than max_norm, you are over-clipping. Monitor clip_frequency - if it is above 80%, increase max_norm or reduce LR.
:::
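The clip-frequency check can be computed directly from the pre-clip norms a training loop records. A small sketch; the 80% threshold mirrors the rule of thumb above, and clipping_report is a hypothetical helper:

```python
from statistics import median


def clipping_report(pre_clip_norms: list, max_norm: float, over_clip_threshold: float = 0.8) -> dict:
    """Summarize how often, and how hard, gradient clipping is biting."""
    if not pre_clip_norms:
        return {"clip_frequency": 0.0, "median_ratio": 0.0, "verdict": "no data"}
    clipped = sum(1 for n in pre_clip_norms if n > max_norm)
    freq = clipped / len(pre_clip_norms)
    ratio = median(pre_clip_norms) / max_norm  # > 1 means the typical step is clipped
    if freq > over_clip_threshold:
        verdict = "over-clipping: raise max_norm or lower the LR"
    elif freq == 0.0:
        verdict = "clipping inactive"
    else:
        verdict = "clipping only on spikes"
    return {"clip_frequency": freq, "median_ratio": ratio, "verdict": verdict}


print(clipping_report([15.0, 12.5, 18.0, 14.2], max_norm=1.0)["verdict"])
# prints "over-clipping: raise max_norm or lower the LR"
```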
:::warning Learning rate finder run on wrong data distribution
The LR finder assumes the batches used during the search are representative of the full training distribution. If you run the LR finder on a biased subset (e.g., all from one class), the suggested LR may not transfer. Always use a DataLoader that shuffles data and is drawn from the same distribution as full training.
:::
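A quick sanity check on the batches an LR finder will see is to count the labels in the first few of them. A sketch with a hypothetical dataset stored sorted by class, which is a typical way an unshuffled loader ends up biased; classes_seen is an illustrative helper:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def classes_seen(loader, num_batches: int) -> set:
    """Distinct labels in the first `num_batches` batches of a loader."""
    seen = set()
    for i, (_, labels) in enumerate(loader):
        if i >= num_batches:
            break
        seen.update(labels.tolist())
    return seen


# Hypothetical dataset stored sorted by class: 4 classes x 64 samples each
labels = torch.arange(4).repeat_interleave(64)
dataset = TensorDataset(torch.randn(256, 8), labels)

unshuffled = DataLoader(dataset, batch_size=32, shuffle=False)
shuffled = DataLoader(dataset, batch_size=32, shuffle=True,
                      generator=torch.Generator().manual_seed(0))

print(classes_seen(unshuffled, 2))  # {0}: an LR search on these batches sees one class
print(classes_seen(shuffled, 2))
```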
10-Point Pre-Training Checklist
| # | Check | How |
|---|---|---|
| 1 | Model outputs correct shape | print(model(sample).shape) |
| 2 | Loss decreases on step 1 | Compare loss before and after one grad step |
| 3 | Overfit-one-batch succeeds | Run overfit_one_batch function |
| 4 | No NaN/Inf in inputs | assert torch.isfinite(inputs).all() |
| 5 | Normalization applied | print(inputs.mean(), inputs.std()) ≈ (0, 1) |
| 6 | Train/val splits are disjoint | Intersect sample IDs or content hashes across splits; expect an empty set |
| 7 | Augmentation not on val | Visually inspect val loader output |
| 8 | model.train() called before training | Explicit call, every epoch |
| 9 | Scheduler state saved in checkpoint | Save and restore scheduler.state_dict() alongside model and optimizer state |
| 10 | LR is logged every step | Confirm in experiment tracker |
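Checks 2, 4 and 6 are easy to automate. A minimal sketch, assuming in-memory tensors; the helper names are hypothetical, and hashing each sample's values catches duplicates even when sample IDs differ:

```python
import copy
import hashlib

import torch
import torch.nn as nn


def inputs_finite(inputs: torch.Tensor) -> bool:
    """Check 4: no NaN or Inf anywhere in the batch."""
    return bool(torch.isfinite(inputs).all())


def split_overlap(train_x: torch.Tensor, val_x: torch.Tensor) -> int:
    """Check 6: number of distinct samples present in both splits (by content hash)."""
    def fingerprints(xs):
        return {hashlib.sha1(repr(x.flatten().tolist()).encode()).hexdigest() for x in xs}
    return len(fingerprints(train_x) & fingerprints(val_x))


def loss_drops_after_one_step(model: nn.Module, x, y, criterion, lr: float = 1e-2) -> bool:
    """Check 2: one SGD step on a batch lowers the loss on that same batch.
    Works on a deep copy, so the real model is untouched."""
    probe = copy.deepcopy(model)
    opt = torch.optim.SGD(probe.parameters(), lr=lr)
    loss_before = criterion(probe(x), y)
    opt.zero_grad()
    loss_before.backward()
    opt.step()
    with torch.no_grad():
        loss_after = criterion(probe(x), y)
    return loss_after.item() < loss_before.item()
```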
YouTube Resources
| Title | Channel | Why Watch |
|---|---|---|
| A Recipe for Training Neural Networks | Andrej Karpathy (blog + talks) | The definitive practical guide from one of the best practitioners - covers every pathology systematically |
| Training Deep Neural Networks | Stanford CS231n | Academic treatment of loss landscapes, gradient flow, and optimizer dynamics |
| Weights & Biases - Debugging ML Models | Weights & Biases | End-to-end walkthrough of using W&B for gradient monitoring, hyperparameter search, and experiment tracking |
| How to Find a Good Learning Rate | Fast.ai | Leslie Smith's LR finder explained visually with live code |
| Loss Landscape Visualization | NeurIPS Tutorial | The Li et al. (2018) "Visualizing the Loss Landscape of Neural Nets" paper explained - flat vs sharp minima |
Interview Q&A
Q1: How do you diagnose whether a stuck training loss is due to a learning rate problem versus dead neurons?
Two distinct tests with distinct signatures. First, check per-layer gradient norms alongside the size of the updates actually applied (the norm of each layer's update divided by the norm of its weights). Gradient magnitude does not depend on the LR, so a decayed LR shows up as healthy gradient norms with a collapsed update-to-weight ratio (a common rule of thumb puts healthy values around 1e-3): the signal is there, but the optimizer has stopped moving. If early-layer gradient norms are near zero while late-layer norms are healthy, the issue is vanishing gradients or dead neurons - the signal is not propagating backward. For ReLU networks, confirm by measuring the fraction of units that output zero for every example in a batch.
Second, run the overfit-one-batch test with a fresh learning rate (e.g., 1e-3 for Adam). If the model memorizes the batch, the architecture and gradient flow are healthy, which points to the LR schedule rather than the model. If it fails to memorize, there is a structural issue - wrong activation function, bad initialization, or disconnected computation graph. This test takes about two minutes and usually resolves the ambiguity.
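Both measurements can be scripted. A sketch, assuming activations come from nn.ReLU modules (functional F.relu calls are invisible to module hooks) and a standard torch.optim optimizer; both helper names are hypothetical:

```python
import torch
import torch.nn as nn


def dead_relu_fraction(model: nn.Module, inputs: torch.Tensor) -> dict:
    """Per nn.ReLU module: fraction of units that are zero for every sample in the batch."""
    stats, handles = {}, []

    def make_hook(name):
        def hook(module, inp, out):
            flat = out.detach().reshape(out.shape[0], -1)
            stats[name] = (flat <= 0).all(dim=0).float().mean().item()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_forward_hook(make_hook(name)))
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        for h in handles:
            h.remove()
        model.train(was_training)  # do not leave the model in eval mode
    return stats


def update_to_weight_ratios(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    """Call in place of optimizer.step() (after backward): takes the step and
    returns ||W_after - W_before|| / ||W_before|| for each trainable parameter."""
    before = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
    optimizer.step()
    return {
        n: ((p.detach() - before[n]).norm() / (before[n].norm() + 1e-12)).item()
        for n, p in model.named_parameters()
        if n in before
    }
```

The dead-unit check runs in eval mode, so layers after BatchNorm see running statistics rather than batch statistics; that is a deliberate simplification.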
Q2: Explain the loss landscape and why flat minima generalize better than sharp minima.
The loss landscape is the function $L(\theta)$ over the full parameter space $\theta \in \mathbb{R}^d$. During training, we navigate this landscape using gradient descent. Not all minima have the same generalization properties.
A sharp minimum is a narrow valley - small perturbations to $\theta$ cause large increases in training loss. This corresponds to a solution that is highly sensitive to exact weight values. The shift between train and test data can be viewed, heuristically, as a small perturbation of the effective loss landscape. If the minimum is sharp, that perturbation tends to push the test loss well above the training loss - the model has fit the exact training distribution too closely.
A flat minimum is a wide basin - even large perturbations to $\theta$ cause only small increases in training loss. The model has learned a solution that is robust to parameter changes, which tends to translate into robustness to distribution shift at test time. SGD with small batch sizes tends to find flatter minima because gradient noise makes it hard to settle into narrow, sharp valleys.
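Sharpness can be probed empirically by nudging the weights and watching the loss. This sketch uses random Gaussian directions scaled to a fixed fraction of each parameter tensor's norm, a simpler relative of the filter-normalized directions in Li et al. (2018); perturbation_sharpness and the toy quadratic losses are illustrative assumptions. On a 1-D quadratic, 100x the curvature gives 100x the loss increase for the same nudge:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def perturbation_sharpness(model: nn.Module, loss_fn, radius: float = 0.05,
                           n_samples: int = 20, seed: int = 0) -> float:
    """Average loss increase when each parameter tensor is moved by a random
    direction of norm radius * ||param||. The weights are restored afterwards."""
    gen = torch.Generator().manual_seed(seed)
    params = list(model.parameters())
    originals = [p.detach().clone() for p in params]
    base = loss_fn(model).item()
    increases = []
    for _ in range(n_samples):
        for p, p0 in zip(params, originals):
            noise = torch.randn(p0.shape, generator=gen).to(p0.device)
            p.copy_(p0 + radius * p0.norm() * noise / (noise.norm() + 1e-12))
        increases.append(loss_fn(model).item() - base)
    for p, p0 in zip(params, originals):
        p.copy_(p0)
    return sum(increases) / n_samples


class Scalar(nn.Module):
    """Toy model: a single weight sitting at the minimum w = 1."""
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.tensor([1.0]))


def quadratic(curvature: float):
    return lambda m: (curvature * (m.w - 1.0) ** 2).sum()


print(perturbation_sharpness(Scalar(), quadratic(100.0)))  # sharp basin
print(perturbation_sharpness(Scalar(), quadratic(1.0)))    # flat basin
```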
Q3: What is the learning rate finder and how do you interpret its output?
Leslie Smith's LR range test starts with a very small LR and exponentially increases it over 100–200 training steps, recording the loss at each step. The resulting plot of loss vs. LR (on a log scale) shows three regions: initially the loss barely changes (LR too small to matter), then the loss decreases rapidly (the good LR range), then the loss spikes or becomes NaN (LR too large, training diverges).
The recommended LR is at the steepest downward slope of the curve - the point of maximum negative slope in the loss-vs-LR plot. This is typically 10–100x below the divergence point. Using the LR at the very bottom of the loss curve often leads to instability because you are just below the divergence threshold. The steepest slope point is more conservative and more robust. For cyclical or one-cycle schedules, the max LR in the cycle is often set to this found value.
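Picking the steepest point from the recorded curve is a few lines of arithmetic. A sketch, assuming you already have parallel lists of LRs and losses from the range test; the smoothing constant and end-trimming are assumptions (beta=0.0 disables smoothing), and heavier smoothing shifts the pick toward larger LRs because the moving average lags the raw loss:

```python
import math


def suggest_lr(lrs, losses, beta: float = 0.9, skip_start: int = 5, skip_end: int = 5) -> float:
    """LR at the steepest descent of the smoothed loss-vs-log10(LR) curve."""
    # Bias-corrected exponential moving average to tame per-batch noise
    smoothed, avg = [], 0.0
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        smoothed.append(avg / (1 - beta ** i))
    best_i, best_slope = None, math.inf
    # Skip the noisy first points and the divergence tail at the end
    for i in range(max(skip_start, 1), len(lrs) - skip_end):
        slope = (smoothed[i] - smoothed[i - 1]) / (math.log10(lrs[i]) - math.log10(lrs[i - 1]))
        if slope < best_slope:
            best_i, best_slope = i, slope
    if best_i is None:
        raise ValueError("not enough points to estimate a slope")
    return lrs[best_i]


# Synthetic range test: loss falls fastest around LR = 1e-4 and diverges above 1e-1
lrs = [10 ** (-7 + 6 * i / 199) for i in range(200)]
losses = [
    2.0 - 1.5 / (1 + math.exp(-3 * (math.log10(lr) + 4)))
    + (10 * (math.log10(lr) + 1) ** 2 if lr > 0.1 else 0.0)
    for lr in lrs
]
print(f"suggested LR: {suggest_lr(lrs, losses, beta=0.0):.1e}")  # prints "suggested LR: 1.0e-04"
```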
Q4: Walk through how you find the source of a NaN loss that appears after 500 training steps.
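Step 1: Add torch.autograd.set_detect_anomaly(True) before the training loop and run for 510 steps. When a backward operation returns NaN, PyTorch raises an error naming that operation and prints the traceback of the forward call that created it. It checks gradients rather than forward outputs, so a NaN that arrives through the data surfaces at whichever operation first backpropagates it (Step 3 covers that case). Slow (2–3x overhead) but usually definitive. Disable after finding the source.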
Step 2: Register forward hooks on all modules that check output tensors for NaN. The hook that fires first identifies the culprit module. Use register_nan_detection_hooks(model) from the code in this lesson.
Step 3: Add validate_batch(inputs, targets, step) at the start of each training step. If the inputs themselves contain NaN (from a corrupt data file or preprocessing bug), the source is in the data pipeline, not the model.
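Step 4: Check the loss function and model for numerically unstable operations: log of values near zero, division by small denominators (e.g., attention normalization without an epsilon), or a hand-written softmax or log(softmax(x)) where exp of a large logit overflows to inf. Prefer torch.log_softmax and torch.nn.functional.cross_entropy, which subtract the maximum logit internally.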
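Step 5: If using mixed precision (torch.cuda.amp), check the GradScaler (scaler.get_scale()). The scaler halves its scale and skips the optimizer step whenever it finds inf/NaN gradients, then grows the scale again after a run of clean steps. If the scale keeps shrinking toward 1.0 or below instead of stabilizing, values are overflowing float16 (max representable is ~65504) even without much loss scaling. Add explicit casts to float32 for numerically sensitive operations.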
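Steps 1 and 4 in miniature. A sketch: first_nan_backward_op is a hypothetical wrapper around torch.autograd.detect_anomaly (the context-manager counterpart of set_detect_anomaly), and the second half contrasts a hand-written log-softmax with torch.log_softmax on a large logit:

```python
import torch


def first_nan_backward_op(fn, *inputs):
    """Backpropagate fn(*inputs).sum() with anomaly detection on.
    Returns the error message naming the backward op that produced NaN, or None."""
    with torch.autograd.detect_anomaly():
        try:
            fn(*inputs).sum().backward()
        except RuntimeError as err:
            return str(err)
    return None


# sqrt of a negative number yields NaN in forward; its backward then returns NaN,
# which anomaly mode reports together with the forward traceback.
x = torch.tensor([-1.0, 4.0], requires_grad=True)
print(first_nan_backward_op(torch.sqrt, x))


def naive_log_softmax(logits: torch.Tensor) -> torch.Tensor:
    exps = torch.exp(logits)  # exp(1000.) overflows to inf in float32
    return torch.log(exps / exps.sum())


logits = torch.tensor([1000.0, 0.0])
print(naive_log_softmax(logits))         # contains nan and -inf
print(torch.log_softmax(logits, dim=0))  # subtracts the max first: finite values
```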
Q5: What does it mean when the pre-clip gradient norm is consistently much larger than your clip threshold? What does it mean when it is consistently much smaller?
Consistently larger (e.g., norm = 15 with max_norm = 1.0): gradient clipping is always active. clip_grad_norm_ rescales all gradients by one common factor, so each update still points along the true gradient direction, but its magnitude is always capped by max_norm rather than set by the gradient. The optimizer never uses the full gradient magnitude. This can slow convergence and may indicate the LR is too high (large gradients → large steps → gradient clipping → effective LR is reduced, but inconsistently from step to step). Possible fixes: reduce LR, improve initialization (better-scaled initial weights produce smaller initial gradients), or increase max_norm if the task genuinely has large gradients (some RL and meta-learning settings).
Consistently smaller (e.g., norm = 0.001 with max_norm = 1.0): clipping is never active, so it is not doing anything beyond guarding against a future spike. The gradients are naturally small. This is fine if the model is still converging. But if the model is stuck, vanishingly small gradients point to a plateau, saturated activations, or dead units; with plain SGD, tiny gradients times a modest LR also yield steps too small to make progress. Check whether loss is still decreasing; if not, try a higher LR and check for dead neurons.
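The claim that clipping preserves direction is easy to check. A sketch with hand-set gradients (the values are arbitrary assumptions); torch.nn.utils.clip_grad_norm_ returns the pre-clip total norm and rescales the .grad tensors in place:

```python
import torch


def clip_and_compare(grads, max_norm: float):
    """Apply clip_grad_norm_ to hand-set gradients and report
    (pre-clip norm, post-clip norm, cosine similarity between before and after)."""
    params = [torch.nn.Parameter(torch.zeros_like(g)) for g in grads]
    for p, g in zip(params, grads):
        p.grad = g.clone()
    before = torch.cat([p.grad.flatten() for p in params])
    pre_clip = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
    after = torch.cat([p.grad.flatten() for p in params])
    cosine = torch.nn.functional.cosine_similarity(before, after, dim=0)
    return pre_clip.item(), after.norm().item(), cosine.item()


# Norm 15 vs max_norm 1.0: magnitude capped at ~1.0, direction unchanged (cosine ~1.0)
print(clip_and_compare([torch.tensor([9.0, 12.0]), torch.tensor([0.0])], max_norm=1.0))
# Norm 0.001 vs max_norm 1.0: clipping leaves the gradient untouched
print(clip_and_compare([torch.tensor([0.001, 0.0])], max_norm=1.0))
```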
Q6: Describe the batch normalization eval-mode bug and how to catch it.
BatchNorm has two distinct operating modes. In training mode, it normalizes each batch using that batch's mean and variance, and updates its running statistics. In eval mode, it normalizes using the stored running mean and variance from training - the current batch statistics are ignored.
The bug: after the validation loop (model.eval()), if model.train() is not called before the next training epoch, the model trains with eval-mode BatchNorm (and with dropout disabled). Normalization uses the stored running statistics rather than the current batch's statistics, and those running statistics stop updating, so they drift further out of date as the weights change. Gradients are computed for a different function than the one the training recipe intends.
The symptom is subtle: training loss continues to decrease (slowly and noisily), but validation loss behaves erratically - sometimes improving, sometimes degrading. The gap between train and val is inconsistent across epochs, without a clear trend. The model is not being trained in the mode the recipe assumes.
Detection: explicitly assert that the model and all of its submodules are in training mode at the start of each training epoch (model.training only reports the top-level module's flag, so check every module in model.modules()). Or use a context manager that enforces train mode for training and eval mode for validation. Production code should make mode-switching explicit and mandatory, not assumed.
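A complementary runtime check watches BatchNorm's running statistics: in train mode they move after every forward pass, so if they stay frozen across a training step, something has left the layer in eval mode. A sketch; the helper names are hypothetical, and only BatchNorm1d/2d/3d layers that track running statistics are checked:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def bn_snapshot(model: nn.Module) -> dict:
    """Copy every BatchNorm running_mean (skips layers with track_running_stats=False)."""
    return {
        name: m.running_mean.detach().clone()
        for name, m in model.named_modules()
        if isinstance(m, BN_TYPES) and m.running_mean is not None
    }


def frozen_bn_layers(model: nn.Module, snapshot: dict) -> list:
    """BatchNorm layers whose running_mean did not change since `snapshot`."""
    modules = dict(model.named_modules())
    return [name for name, old in snapshot.items() if torch.equal(modules[name].running_mean, old)]


# Usage on the first step of each training epoch:
#   snap = bn_snapshot(model)
#   loss = criterion(model(batch_x), batch_y)  # a train-mode forward updates the stats
#   assert not frozen_bn_layers(model, snap), "BatchNorm is not in train mode"
```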
