
Training Dynamics and Debugging

The Production Scenario

It is 11 PM on the night before a quarterly review. Your model has been training for six hours on four A100s. The validation loss is 2.341. It was 2.340 two hours ago. It was 2.342 an hour ago. The training loss is still decreasing - slowly - but validation has not moved.

You open the experiment tracker. The loss curve is flat as a frozen lake for the last two hours. The gradient norm plot shows healthy numbers. No NaNs. No obvious spike. The learning rate is 3e-5, on a cosine schedule that has been running for 40% of the total budget.

Is this normal convergence? Is the model stuck near a saddle point, where the gradient is tiny and the curvature is close to zero in most directions? Did the cosine schedule decay the LR below the effective threshold two hours ago, turning the optimizer into a statue? Is there a subtle bug in your validation data pipeline that is computing the same metric for every batch? Is the model silently overfitting to the training set while validation quietly degrades, and you are measuring validation loss on a stale checkpoint?

You have a meeting in eight hours and cannot afford another six-hour experiment.

Every one of these hypotheses has a specific diagnostic test. Every pathology has a signature in the loss curve, the gradient norms, the activation statistics, and the learning rate schedule. Senior ML engineers do not guess - they diagnose. They open the right plot, run the right check, and know within twenty minutes whether to restart or wait.

This lesson is that diagnostic toolkit. Every symptom. Every test. Every fix.

Why Training Dynamics Matter

Neural network training is not a simple convex optimization. The loss function is a high-dimensional, non-convex surface with:

  • Saddle points and plateaus where the gradient is near zero even though the point is not a minimum
  • Sharp narrow valleys where large LR causes oscillation but small LR causes slow progress
  • Flat wide basins where many equally good solutions live
  • Cliff-like loss spikes in transformers and RNNs when gradients suddenly explode
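The saddle-point case is easy to check on a toy surface. The sketch below uses the textbook function f(x, y) = x^2 - y^2, an illustrative stand-in for a real loss. The helper names `saddle`, `grad_fd` and `curvature_fd` are hypothetical.

At the origin the gradient is exactly zero. The curvature, however, is positive along x and negative along y, so the origin is a saddle and not a minimum. A real network has the same problem in millions of dimensions, where a near-zero gradient alone cannot tell you that you have converged.

```python
from typing import Callable, Tuple

Point = Tuple[float, float]


def saddle(p: Point) -> float:
    """Toy surface f(x, y) = x^2 - y^2, which has a saddle point at the origin."""
    x, y = p
    return x * x - y * y


def grad_fd(f: Callable[[Point], float], p: Point, h: float = 1e-5) -> Point:
    """Gradient by central finite differences."""
    x, y = p
    gx = (f((x + h, y)) - f((x - h, y))) / (2 * h)
    gy = (f((x, y + h)) - f((x, y - h))) / (2 * h)
    return gx, gy


def curvature_fd(f: Callable[[Point], float], p: Point, h: float = 1e-3) -> Point:
    """Second derivatives along x and y (the diagonal of the Hessian)."""
    x, y = p
    fxx = (f((x + h, y)) - 2 * f(p) + f((x - h, y))) / (h * h)
    fyy = (f((x, y + h)) - 2 * f(p) + f((x, y - h))) / (h * h)
    return fxx, fyy


origin = (0.0, 0.0)
g = grad_fd(saddle, origin)
fxx, fyy = curvature_fd(saddle, origin)
print(f"gradient at origin: {g}")
print(f"curvature along x: {fxx:.3f}, along y: {fyy:.3f}")
# Zero gradient with curvature of opposite signs: a saddle, not a minimum.
```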

Understanding the geometry of this surface - and how your optimizer, learning rate, and model interact with it - is what separates engineers who debug in hours from engineers who debug in days.

The tools in this lesson are organized by the diagnostic question they answer:

  1. Loss landscape - what shape is the surface I am optimizing?
  2. Gradient flow - are gradients reaching all layers?
  3. Learning rate - is my LR in the right regime?
  4. NaN detection - where did the first NaN appear?
  5. Activation health - are neurons alive and well-distributed?
  6. Data pipeline - is what the model is seeing correct?

The Loss Landscape: Geometry and Generalization

Why Loss Landscape Shape Matters

The loss landscape is the function $L(\theta)$ over the high-dimensional parameter space. Finding a minimum is the goal of training, but not all minima are equal:

Sharp minima: narrow, deep basins. A small perturbation to the weights causes a large increase in loss. The model has memorized very specific features. When you deploy and real data differs slightly from training data, the model is already on the steep walls of the basin - validation loss is high.

Flat minima: wide, shallow basins. A large perturbation to the weights causes only a small increase in loss. The model has learned robust features. Slight distribution shift at inference keeps you in the bottom of the basin - validation loss stays low.

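Keskar et al. (2017) linked large-batch training to sharp minima and a generalization gap. The Sharpness-Aware Minimization (SAM) paper builds an optimizer around the same observation: flat minima tend to generalize better. In practice, this means: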

  • Large batch training tends toward sharp minima (the gradient estimate is less noisy, so the optimizer converges to the nearest minimum rather than exploring)
  • Small batch training adds gradient noise that helps escape sharp minima and find flatter regions
  • High learning rates and warmup schedules allow the optimizer to traverse the landscape more broadly before settling
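The claim that small batches add gradient noise can be checked on a toy problem. This is a minimal sketch that assumes a one-parameter linear regression on synthetic data. The helpers `make_data`, `minibatch_grad` and `grad_variance` are hypothetical.

The sketch draws many random minibatches and computes the minibatch gradient at a fixed weight. It then compares the variance of that estimate for a small batch and a large one. For independent samples the variance should shrink roughly as 1/B, so going from B=8 to B=512 should cut it by about 64x.

```python
import random
import statistics


def make_data(n: int, seed: int = 0):
    """Synthetic data for y = 2x + noise (a toy assumption, not a real dataset)."""
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    ys = [2.0 * x + rng.gauss(0.0, 0.5) for x in xs]
    return list(zip(xs, ys))


def minibatch_grad(w: float, batch) -> float:
    """Gradient of the mean squared error (w*x - y)^2 with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def grad_variance(data, w: float, batch_size: int, draws: int = 500, seed: int = 1) -> float:
    """Variance of the minibatch gradient across randomly drawn batches."""
    rng = random.Random(seed)
    grads = [minibatch_grad(w, rng.sample(data, batch_size)) for _ in range(draws)]
    return statistics.variance(grads)


data = make_data(10_000)
w = 0.0  # evaluate the gradient away from the optimum w = 2
var_small = grad_variance(data, w, batch_size=8)
var_large = grad_variance(data, w, batch_size=512)
print(f"B=8:   gradient variance {var_small:.4f}")
print(f"B=512: gradient variance {var_large:.6f}")
print(f"ratio: {var_small / var_large:.1f}  (1/B scaling predicts roughly 64)")
```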

The loss flatness metric: perturb every weight tensor in a random direction scaled to a fixed fraction (epsilon) of that tensor's norm, then measure how much the loss rises. If it barely changes (flat minimum), generalization is more likely to be good. If it spikes (sharp minimum), consider reducing batch size or adding noise.

```python
import copy

import numpy as np
import torch
import torch.nn as nn


def measure_sharpness(
    model: nn.Module,
    loss_fn,
    dataloader,
    epsilon: float = 0.05,
    n_samples: int = 20,
    n_batches: int = 4,
    device: str = "cuda",
) -> dict:
    """
    Estimate loss landscape sharpness by perturbing weights randomly.
    Lower sharpness_ratio -> flatter minimum -> better expected generalization.

    Each sample perturbs every parameter tensor in a random direction whose
    norm is epsilon * (that tensor's norm), measures the loss, and then
    restores the original weights so perturbations do not accumulate.

    Returns: dict with base_loss, mean_perturbed_loss, sharpness_ratio.
    """
    model = model.to(device)
    model.eval()

    # Compute base loss on a fixed sample of batches
    base_losses = []
    sample_batches = []
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader):
            if i >= n_batches:
                break
            x, y = x.to(device), y.to(device)
            base_losses.append(loss_fn(model(x), y).item())
            sample_batches.append((x, y))

    base_loss = float(np.mean(base_losses))

    # Save original parameters so every sample starts from the same point
    original_state = copy.deepcopy(model.state_dict())

    perturbed_losses = []
    for _ in range(n_samples):
        with torch.no_grad():
            for param in model.parameters():
                noise = torch.randn_like(param)
                # Normalize to unit norm, then scale by epsilon * param_norm
                noise = noise / (noise.norm() + 1e-8)
                param.add_(noise * epsilon * param.norm())

            # Measure loss with this perturbation
            sample_losses = [loss_fn(model(x), y).item() for x, y in sample_batches]
        perturbed_losses.append(float(np.mean(sample_losses)))

        # Undo this perturbation before drawing the next one
        model.load_state_dict(original_state)

    mean_perturbed = float(np.mean(perturbed_losses))
    sharpness_ratio = (mean_perturbed - base_loss) / (abs(base_loss) + 1e-8)

    return {
        "base_loss": base_loss,
        "mean_perturbed_loss": mean_perturbed,
        "sharpness_ratio": sharpness_ratio,
        # 0.1 is a heuristic cutoff, not a universal constant
        "interpretation": "sharp" if sharpness_ratio > 0.1 else "flat",
    }
```
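To see what sharpness_ratio measures without a real model, this minimal sketch uses two hypothetical one-dimensional losses. Both have the form L(w) = c + (a/2)(w - 1)^2, with the same minimum value but different curvature a. The loss and helper names are illustrative only.

The sketch follows the same recipe as measure_sharpness. It perturbs the weight by epsilon times its norm and compares the result against the base loss. It also uses the same 0.1 cutoff, which is a heuristic. In one dimension a unit direction is just +1 or -1, and both give the same loss here, so no random sampling is needed.

```python
def quadratic_loss(w: float, curvature: float, floor: float = 0.5) -> float:
    """Toy 1-D loss with minimum value `floor` at w = 1 (illustrative assumption)."""
    return floor + 0.5 * curvature * (w - 1.0) ** 2


def toy_sharpness_ratio(curvature: float, epsilon: float = 0.05) -> float:
    """Relative loss increase after perturbing the minimum by epsilon * |w|."""
    w_star = 1.0
    base = quadratic_loss(w_star, curvature)
    step = epsilon * abs(w_star)  # the 1-D analogue of epsilon * param_norm
    perturbed = [quadratic_loss(w_star + s * step, curvature) for s in (1.0, -1.0)]
    mean_perturbed = sum(perturbed) / len(perturbed)
    return (mean_perturbed - base) / (abs(base) + 1e-8)


for name, a in [("flat", 1.0), ("sharp", 100.0)]:
    ratio = toy_sharpness_ratio(a)
    label = "sharp" if ratio > 0.1 else "flat"
    print(f"{name:>5} basin (curvature {a:>5}): ratio = {ratio:.4f} -> {label}")
```

The ratio grows linearly with curvature: with epsilon = 0.05 it is 0.0025 for a = 1 and 0.25 for a = 100.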

Loss vs Accuracy Divergence

In classification, training loss should fall as training accuracy rises. When the two diverge, it is worth knowing why:

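Loss decreases, accuracy stalls: the model is adjusting its confidence without changing its decisions. Typically it becomes more confident on examples it already gets right rather than correcting the ones it gets wrong. Class imbalance is a common cause, because the loss keeps improving on the majority class. Label smoothing can also decouple the two, since it changes what the loss rewards.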

Accuracy increases, loss stalls: the model is making correct decisions but not with high enough confidence. Usually fine - it will continue to improve slowly.

Validation loss increases while validation accuracy stays flat: the model is becoming less calibrated. It grows more confident on the predictions it gets wrong, even though its argmax decisions have not changed. This is a warning sign of future degradation.

Training loss spikes then recovers: a bad batch with atypical examples. If this happens repeatedly, check for data corruption or extreme outliers. Add input validation or robust loss (Huber loss instead of MSE for regression).
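The first and third cases can be reproduced with made-up numbers. This minimal sketch uses a four-example binary task, where p is the probability the model assigns to the true class and a prediction counts as correct when p > 0.5. All probabilities are invented for illustration.

Sharpening the already-correct predictions lowers cross-entropy while accuracy stays the same. Making the one wrong prediction more confident raises cross-entropy, again with accuracy unchanged.

```python
import math
from typing import List


def cross_entropy(p_true: List[float]) -> float:
    """Mean negative log-likelihood of the true class."""
    return sum(-math.log(p) for p in p_true) / len(p_true)


def accuracy(p_true: List[float]) -> float:
    """Binary case: correct when the true class gets more than half the mass."""
    return sum(p > 0.5 for p in p_true) / len(p_true)


epoch_a = [0.60, 0.70, 0.65, 0.40]  # three right, one wrong
epoch_b = [0.90, 0.95, 0.92, 0.40]  # right ones more confident, wrong one unchanged
epoch_c = [0.60, 0.70, 0.65, 0.05]  # wrong one becomes confidently wrong

for name, probs in [("A", epoch_a), ("B", epoch_b), ("C", epoch_c)]:
    print(f"epoch {name}: loss={cross_entropy(probs):.3f}  acc={accuracy(probs):.2f}")
```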

Training Failure Diagnosis Tree

Loss Curve Pathologies

The shape of your loss curve is diagnostic. Every pathology has a specific cause.

```text
1. HEALTHY CONVERGENCE
Loss │*
     │ **
     │   ***
     │      ****
     │          *******
     └──────────────────── Steps

2. EXPLODING LOSS (LR too high or bad init)
Loss │             /
     │            /
     │     ____  /
     │    /    \/
     │****
     └──────────────────── Steps
Fix: reduce LR 10x, add grad clipping

3. STUCK LOSS (LR too small, dead neurons, bad init)
Loss │****
     │    ****
     │        *****
     │             ******** <- barely moves
     └──────────────────── Steps
Fix: raise LR, check dead neuron fraction

4. OSCILLATING (LR too high for late training)
Loss │ *   *   *   *   *   *   *
     │* * * * * * * * * * * * * *
     │   *   *   *   *   *   *
     └──────────────────── Steps
Fix: reduce LR or use LR warmup + decay

5. OVERFIT (good train, diverging val)
Loss │*
     │ *o                    o o  <- val starts rising
     │  * o              o o
     │   *  o o o o o o o
     │    ***
     │       ********** <- train still dropping
     └──────────────────── Steps
       (* = train, o = val)
Fix: regularization, more data, early stopping

6. UNDERFIT (both losses high, similar)
Loss │*****
     │     ****
     │         ************** <- flattens early and high
     │
     └──────────────────── Steps
Fix: larger model, reduce regularization, raise LR
```
| Symptom | Most Likely Cause | Primary Fix |
|---|---|---|
| Loss → NaN immediately | LR too high, log(0) | Reduce LR 10x, clamp inputs |
| Loss → NaN after N steps | Exploding gradients | Clip grad norm, reduce LR |
| Loss stuck from step 1 | Bad init, LR too small | He init, increase LR |
| Loss stuck after initial drop | LR decayed too fast | Longer warmup, check scheduler |
| Loss oscillates throughout | LR too high | Reduce LR by 3–10x |
| Train OK, val diverges | Overfitting | Regularization, more data |
| Both losses high and equal | Underfitting / data bug | Larger model, audit data |
| Val loss good but metric bad | Metric implementation bug | Verify metric from scratch |
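The table can be turned into a rough first-pass check on logged loss values. This is a minimal sketch. The function name, the 2x explosion rule, the window and the 1% stall tolerance are all illustrative assumptions to tune for your own run, not standard values. It returns "too_short" when there are fewer points than the window.

Because it only sees loss values, it cannot separate causes that produce the same curve shape. For example, a stuck loss from a tiny LR looks the same as one caused by dead neurons.

```python
import math
from typing import List, Optional


def _relative_drop(xs: List[float], window: int) -> float:
    """Fractional decrease over the last `window` values (positive = improving)."""
    recent = xs[-window:]
    return (recent[0] - recent[-1]) / max(abs(recent[0]), 1e-12)


def triage_loss_curve(
    train: List[float],
    val: Optional[List[float]] = None,
    window: int = 20,
    stall_tol: float = 0.01,
) -> str:
    """Map loss histories to one of the pathologies in the table (heuristic)."""
    val = val or []
    if any(math.isnan(x) or math.isinf(x) for x in train + val):
        return "nan"
    if len(train) < window:
        return "too_short"
    if train[-1] > 2.0 * min(train):
        return "exploding"
    train_improving = _relative_drop(train, window) > stall_tol
    if len(val) >= window and train_improving and val[-1] > min(val) * (1 + stall_tol):
        return "overfitting"
    if not train_improving:
        return "stuck"
    return "healthy"


steps = range(100)
healthy = [1.0 / (1.0 + 0.1 * i) for i in steps]
print(triage_loss_curve(healthy, [1.1 * x for x in healthy]))                    # healthy
print(triage_loss_curve([2.341] * 100))                                           # stuck
print(triage_loss_curve(healthy, [0.5 + 0.001 * (i - 50) ** 2 for i in steps]))   # overfitting
```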

Gradient Flow Analysis

Gradient flow is the circulatory system of learning. When gradients cannot reach early layers, those layers do not learn. When gradients are too large, parameters explode. Monitoring gradient norms per layer is the first diagnostic tool to reach for.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn as nn


def compute_gradient_norms(model: nn.Module) -> Dict[str, float]:
    """
    Compute L2 gradient norm for every named parameter.
    Call after loss.backward() and before optimizer.step().
    """
    norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            norms[name] = param.grad.detach().norm(2).item()
        else:
            norms[name] = 0.0  # no gradient - may indicate a bug
    return norms


def print_gradient_summary(
    model: nn.Module,
    threshold_vanish: float = 1e-6,
    threshold_explode: float = 10.0,
) -> Dict[str, float]:
    """
    Summarize gradient norms with pass/fail status per layer.
    Call after backward(), before optimizer.step().
    """
    norms = compute_gradient_norms(model)

    print(f"\n{'Layer':<50} {'Grad Norm':>12} {'Status':>12}")
    print("-" * 76)
    for name, norm in norms.items():
        if norm == 0.0:
            status = "NO GRAD"
        elif norm < threshold_vanish:
            status = "VANISHING"
        elif norm > threshold_explode:
            status = "EXPLODING"
        else:
            status = "OK"
        print(f"{name:<50} {norm:>12.4e} {status:>12}")

    all_norms = list(norms.values())
    if all_norms:
        print(f"\nSummary: min={min(all_norms):.2e}, "
              f"max={max(all_norms):.2e}, "
              f"mean={sum(all_norms) / len(all_norms):.2e}")

    return norms


class GradientFlowLogger:
    """
    Logs gradient norms per layer across training steps.
    Enables plotting to visualize gradient health over time.

    Usage:
        logger = GradientFlowLogger(model)
        # ... in training loop, after backward():
        logger.record(step)
        # ... after training:
        logger.plot()
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.history: Dict[str, List[Tuple[int, float]]] = defaultdict(list)

    def record(self, step: int):
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                norm = param.grad.detach().norm(2).item()
                self.history[name].append((step, norm))

    def plot(
        self,
        save_path: str = "gradient_flow.png",
        top_k: int = 6,
        vanish_threshold: float = 1e-6,
        explode_threshold: float = 10.0,
    ):
        """Plot gradient norms over time for the top_k most active layers."""
        import matplotlib.pyplot as plt  # imported lazily: only needed for plotting

        # Select top_k layers by mean gradient norm
        means = {
            name: sum(n for _, n in vals) / max(len(vals), 1)
            for name, vals in self.history.items()
        }
        top_layers = sorted(means, key=means.get, reverse=True)[:top_k]
        if not top_layers:
            print("No gradients recorded - call record() after backward().")
            return

        fig, axes = plt.subplots(len(top_layers), 1,
                                 figsize=(12, 3 * len(top_layers)), sharex=True)
        if len(top_layers) == 1:
            axes = [axes]

        for ax, name in zip(axes, top_layers):
            steps, norms = zip(*self.history[name])
            ax.semilogy(steps, norms, linewidth=1.5)
            ax.set_ylabel("Grad Norm (log)")
            ax.set_title(name, fontsize=9)
            ax.axhline(vanish_threshold, color="red", linestyle="--", alpha=0.5,
                       label="vanish threshold")
            ax.axhline(explode_threshold, color="orange", linestyle="--", alpha=0.5,
                       label="explode threshold")
            ax.legend(fontsize=7)

        axes[-1].set_xlabel("Training Step")
        fig.suptitle("Gradient Norms Per Layer Over Training")
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f"Gradient flow plot saved to {save_path}")
```

What to look for in gradient norm plots:

  • Healthy: all layers show gradient norms of a similar order of magnitude. Early-layer norms are within roughly 10–100x of late-layer norms.
  • Vanishing: early-layer norms are 1000x or more smaller than late-layer norms. This is catastrophic in deep networks without residual connections.
  • Exploding: any layer shows norms above the explode threshold (10 in the code above). Without clipping, this can produce NaN weights within a few steps.
  • Dead layers: the gradient norm is exactly 0.0, or near zero and constant. The layer is not participating in learning. Check for a disconnected computational graph or all-dead ReLU units.
  • Oscillating norms: the gradient norm bounces between large and small values. This is often caused by an LR that is too high, so the optimizer keeps overshooting.
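The vanishing pattern follows directly from the chain rule. The gradient reaching layer l is a product of one local derivative for each layer above it. This minimal scalar sketch assumes a hypothetical 20-layer chain of sigmoid units with weight 1. Real layers are matrices, but the multiplicative effect is the same.

Sigmoid's derivative is at most 0.25, so in the plain chain the product collapses with depth. A residual connection h + f(h) has local derivative 1 + f'(h), which stops the product from shrinking. The function `input_grads` is a hypothetical helper.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def input_grads(depth: int, residual: bool, h0: float = 0.5, w: float = 1.0) -> List[float]:
    """
    d(output)/d(input of layer i) for i = 0..depth-1 in a scalar chain:
      plain:    h_{i+1} = sigmoid(w * h_i)
      residual: h_{i+1} = h_i + sigmoid(w * h_i)
    """
    local = []  # local derivative dh_{i+1}/dh_i for each layer (forward pass)
    h = h0
    for _ in range(depth):
        s = sigmoid(w * h)
        d_sig = w * s * (1.0 - s)  # at most 0.25 * |w|
        local.append(1.0 + d_sig if residual else d_sig)
        h = h + s if residual else s
    # Backward pass: multiply local derivatives from the top layer down
    grads = [0.0] * depth
    g = 1.0
    for i in reversed(range(depth)):
        g *= local[i]
        grads[i] = g
    return grads


for residual in (False, True):
    grads = input_grads(depth=20, residual=residual)
    kind = "residual" if residual else "plain"
    print(f"{kind:>8}: first-layer grad {grads[0]:.2e}, "
          f"last-layer grad {grads[-1]:.2e}, ratio {grads[0] / grads[-1]:.2e}")
```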

Learning Rate: The Most Important Hyperparameter

Why LR Matters More Than Architecture

In practice, a suboptimal learning rate destroys training faster than a suboptimal architecture. Too high: the optimizer overshoots optima and oscillates or diverges. Too low: convergence is so slow that the model does not reach a good solution within the compute budget. The right LR is in a narrow range that varies per model, per dataset, per optimizer, and per batch size.

The linear scaling rule (Goyal et al., 2017): when multiplying batch size by $k$, multiply LR by $k$. Intuition: a larger batch gives a less noisy gradient estimate, so you can take larger steps. This is empirically validated for SGD on vision tasks. For Adam, the relationship is weaker but still directional.
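As arithmetic, the rule is a single ratio. This minimal sketch uses the reference point from Goyal et al.'s ResNet-50 ImageNet setup, an LR of 0.1 at batch size 256. The square-root variant is included only as a frequently cited alternative heuristic for adaptive optimizers; this section does not establish it. Both function names are hypothetical.

```python
import math


def linear_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: multiply the LR by k = new_batch / base_batch."""
    return base_lr * new_batch / base_batch


def sqrt_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Square-root scaling: a gentler heuristic sometimes used with Adam (assumption)."""
    return base_lr * math.sqrt(new_batch / base_batch)


for batch in (256, 1024, 8192):
    print(f"batch {batch:>5}: linear LR {linear_scaled_lr(0.1, 256, batch):.3f}, "
          f"sqrt LR {sqrt_scaled_lr(0.1, 256, batch):.3f}")
```

Large scale factors also need a warmup period. Goyal et al. paired linear scaling with a gradual warmup, because the scaled LR is too aggressive for a freshly initialized network.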

Learning Rate Finder (Leslie Smith's Algorithm)

Rather than guessing LR, run a systematic search:

  1. Start with a very small LR (1e-7)
  2. Train for one mini-batch, record the loss
  3. Multiply LR by a constant factor (e.g., 1.1–1.3)
  4. Repeat for 100–200 steps until loss diverges
  5. Plot loss vs LR (on log scale)
  6. Choose the LR at the steepest downward slope - just before the loss starts rising

The LR at the bottom of the curve is often too high, because the model is about to diverge there. The point of steepest descent is usually a safer choice. It often sits roughly 10x below the LR at the loss minimum.

```python
import copy
import math
from typing import List

import matplotlib.pyplot as plt
import torch
import torch.nn as nn


class LRFinder:
    """
    Leslie Smith's Learning Rate Finder.
    Trains for a range of LRs and records loss.
    Use the LR at the steepest decline in the plot.

    Reference: Smith (2015), "Cyclical Learning Rates for Training Neural Networks"
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: nn.Module,
        device: str = "cuda",
    ):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

        # Save initial state - restore after the search
        self._model_state = copy.deepcopy(model.state_dict())
        self._optim_state = copy.deepcopy(optimizer.state_dict())

        self.lr_history: List[float] = []
        self.loss_history: List[float] = []

    def run(
        self,
        dataloader,
        start_lr: float = 1e-7,
        end_lr: float = 10.0,
        n_steps: int = 100,
        smooth_f: float = 0.05,          # exponential smoothing factor
        diverge_threshold: float = 5.0,  # stop if loss > best * this
    ):
        """Run the LR range test."""
        # Set starting LR
        for group in self.optimizer.param_groups:
            group["lr"] = start_lr

        # Geometric step so that the LR reaches end_lr after n_steps
        lr_factor = (end_lr / start_lr) ** (1.0 / n_steps)

        best_loss = float("inf")
        avg_loss = 0.0

        data_iter = iter(dataloader)
        self.model.train()

        for step in range(n_steps):
            try:
                inputs, targets = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                inputs, targets = next(data_iter)

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # Exponential smoothing of loss
            current_loss = loss.item()
            avg_loss = smooth_f * current_loss + (1 - smooth_f) * avg_loss
            smoothed = avg_loss / (1 - (1 - smooth_f) ** (step + 1))  # bias correction

            current_lr = self.optimizer.param_groups[0]["lr"]
            self.lr_history.append(current_lr)
            self.loss_history.append(smoothed)

            if smoothed < best_loss:
                best_loss = smoothed

            # Divergence check
            if smoothed > diverge_threshold * best_loss:
                print(f"Loss diverged at step {step}, LR={current_lr:.2e}")
                break

            # Update LR for next step
            for group in self.optimizer.param_groups:
                group["lr"] *= lr_factor

        # Restore original model and optimizer state
        self.model.load_state_dict(self._model_state)
        self.optimizer.load_state_dict(self._optim_state)

        return self

    def suggest_lr(self) -> float:
        """
        Suggest the learning rate at the steepest loss decline.
        This is the point of maximum negative gradient of the loss-lr curve.
        """
        if len(self.loss_history) < 3:
            return self.lr_history[0] if self.lr_history else 1e-3

        # Compute numerical gradient of loss w.r.t. log(lr)
        log_lrs = [math.log(lr) for lr in self.lr_history]
        slopes = []
        for i in range(1, len(self.loss_history) - 1):
            slope = (self.loss_history[i + 1] - self.loss_history[i - 1]) / (
                log_lrs[i + 1] - log_lrs[i - 1]
            )
            slopes.append((slope, self.lr_history[i]))

        # Steepest descent = most negative slope
        best_slope, best_lr = min(slopes, key=lambda x: x[0])
        return best_lr

    def plot(self, save_path: str = "lr_finder.png"):
        """Plot the loss-vs-LR curve."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.semilogx(self.lr_history, self.loss_history, linewidth=2)
        suggested = self.suggest_lr()
        ax.axvline(suggested, color="red", linestyle="--",
                   label=f"Suggested LR: {suggested:.2e}")
        ax.set_xlabel("Learning Rate (log scale)")
        ax.set_ylabel("Loss (smoothed)")
        ax.set_title("Learning Rate Finder - Choose LR at Steepest Descent")
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"LR finder plot saved to {save_path}")
        print(f"Suggested LR: {suggested:.2e}")


# Usage
# model = MyModel()
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7)
# finder = LRFinder(model, optimizer, nn.CrossEntropyLoss(), device="cuda")
# finder.run(train_loader, start_lr=1e-7, end_lr=1.0, n_steps=100)
# finder.plot()
# print(f"Use LR: {finder.suggest_lr():.2e}")
```
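suggest_lr picks the steepest-slope point. A common alternative reads off the LR at the loss minimum and divides it by 10. This minimal pure-Python sketch compares the two heuristics on a synthetic loss-vs-LR curve. The curve shape is an invented assumption: a sigmoid-shaped drop centred at 1e-3 and a quadratic blow-up beyond 1e-1. The numbers therefore say nothing about real models, and all function names are hypothetical.

```python
import math
from typing import List, Tuple


def synthetic_range_test(n: int = 61) -> Tuple[List[float], List[float]]:
    """Fake LR range test over 1e-6 .. 1 on a log grid (illustrative only)."""
    xs = [-6.0 + 0.1 * i for i in range(n)]  # log10(lr)
    lrs = [10.0 ** x for x in xs]
    losses = [
        2.0 - 1.5 / (1.0 + math.exp(-3.0 * (x + 3.0))) + 5.0 * max(0.0, x + 1.0) ** 2
        for x in xs
    ]
    return lrs, losses


def steepest_slope_lr(lrs: List[float], losses: List[float]) -> float:
    """LR where d(loss)/d(log lr) is most negative (central differences)."""
    logs = [math.log(lr) for lr in lrs]
    best_i = min(
        range(1, len(losses) - 1),
        key=lambda i: (losses[i + 1] - losses[i - 1]) / (logs[i + 1] - logs[i - 1]),
    )
    return lrs[best_i]


def min_loss_over_ten_lr(lrs: List[float], losses: List[float]) -> float:
    """LR at the lowest loss, divided by 10."""
    best_i = min(range(len(losses)), key=losses.__getitem__)
    return lrs[best_i] / 10.0


lrs, losses = synthetic_range_test()
print(f"steepest slope: {steepest_slope_lr(lrs, losses):.1e}")
print(f"min loss / 10:  {min_loss_over_ten_lr(lrs, losses):.1e}")
```

On this curve the two heuristics land an order of magnitude apart. Treat the finder's output as a starting range to verify, not an exact value.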

Cyclical Learning Rates and Warmup

Linear Warmup: start with a very small LR and increase it linearly to the target LR over the first 5–10% of training. This prevents large early steps from corrupting the initialization before the optimizer has calibrated its moment estimates (for example, Adam's running averages of the gradient and its square). Warmup is critical for transformers: without it, attention weights can diverge early in training.

Cosine Annealing: after warmup, decay the LR following a cosine curve from the peak to a small floor value (e.g., 1/10th of peak). The cosine schedule is smoother than step decay - the gradual decrease allows the model to settle into a flat minimum rather than being abruptly stopped.
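The cosine multiplier is easy to evaluate by hand. That is useful for testing a hypothesis like "the schedule has already decayed the LR to nothing". This minimal sketch uses the same formula as get_cosine_schedule_with_warmup below, with peak multiplier 1.0 and floor min_lr_ratio; the warmup phase is ignored. The name cosine_multiplier is hypothetical.

With the default floor of 0.1, the multiplier is still about 0.69 of peak 40% of the way through the decay, and exactly 0.55 at the midpoint.

```python
import math


def cosine_multiplier(progress: float, min_lr_ratio: float = 0.1) -> float:
    """LR / peak_LR after `progress` (0..1) of the cosine decay phase."""
    progress = min(max(progress, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr_ratio + (1.0 - min_lr_ratio) * cosine


for p in (0.0, 0.4, 0.5, 0.9, 1.0):
    print(f"{p:.0%} of decay: LR = {cosine_multiplier(p):.3f} x peak")
```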

Cyclical LR (CLR): instead of decreasing monotonically, the LR cycles between a minimum and a maximum. This can help escape poor local minima (LR peaks kick the optimizer out) and find flatter minima (LR troughs allow settling). Leslie Smith introduced CLR in 2015, and the idea is still used in One-Cycle training.

```python
import math

import torch
import torch.optim as optim


def get_cosine_schedule_with_warmup(
    optimizer: optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    min_lr_ratio: float = 0.1,
) -> optim.lr_scheduler.LambdaLR:
    """
    Linear warmup followed by cosine annealing.
    Standard schedule for transformer training.

    num_warmup_steps: steps for linear warmup phase
    num_training_steps: total training steps
    min_lr_ratio: LR at end of cosine decay as fraction of peak LR
    """
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            # Linear warmup: from 0 at step 0 up to 1.0
            return float(current_step) / float(max(1, num_warmup_steps))

        # Cosine annealing from 1.0 to min_lr_ratio. Progress is clamped so
        # the LR stays at the floor if training runs past num_training_steps.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        progress = min(progress, 1.0)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine_decay

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def get_one_cycle_schedule(
    optimizer: optim.Optimizer,
    max_lr: float,
    total_steps: int,
    pct_start: float = 0.3,
) -> optim.lr_scheduler.OneCycleLR:
    """
    One Cycle LR policy (Leslie Smith, 2019).
    Raises the LR from max_lr / div_factor up to max_lr over the first
    pct_start of training, then anneals it down to a very small value.
    With anneal_strategy="cos", both phases follow a cosine curve.
    By default PyTorch also cycles momentum inversely to the LR:
    high momentum at low LR, low momentum at high LR.
    """
    return optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        total_steps=total_steps,
        pct_start=pct_start,       # fraction of steps spent increasing the LR
        anneal_strategy="cos",
        div_factor=25.0,           # initial_lr = max_lr / 25
        final_div_factor=1e4,      # min_lr = initial_lr / 1e4
    )


# Warmup schedule usage example
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
# scheduler = get_cosine_schedule_with_warmup(
#     optimizer,
#     num_warmup_steps=500,
#     num_training_steps=10000,
# )
# for step, batch in enumerate(dataloader):
#     loss = train_step(batch)  # train_step: your forward pass + loss
#     loss.backward()
#     optimizer.step()
#     scheduler.step()
#     optimizer.zero_grad()
```
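The schedules above cover warmup, cosine annealing and one-cycle, but not the plain cyclical policy. This minimal sketch implements Smith's triangular policy as a function of the iteration number. base_lr, max_lr and step_size are the policy's parameters; the example values are arbitrary, and triangular_clr is a hypothetical name.

Inside a real training loop, PyTorch's built-in torch.optim.lr_scheduler.CyclicLR with mode="triangular" is the more practical choice.

```python
import math


def triangular_clr(iteration: int, base_lr: float, max_lr: float, step_size: int) -> float:
    """
    Triangular cyclical LR: rises linearly from base_lr to max_lr over
    step_size iterations, falls back over the next step_size, and repeats.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)


for it in (0, 250, 500, 750, 1000, 1500):
    print(it, round(triangular_clr(it, base_lr=1e-4, max_lr=1e-2, step_size=500), 6))
```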

Gradient Clipping: Defending Against Explosions

Gradient clipping prevents exploding gradients from corrupting the parameter update. Two clipping strategies:

Clip by global norm (recommended): compute the L2 norm of all gradients combined, then scale down all gradients proportionally if the norm exceeds a threshold. This preserves the direction of the gradient while limiting its magnitude.

$$\text{if } \|\mathbf{g}\|_2 > \tau: \quad \mathbf{g} \leftarrow \frac{\tau}{\|\mathbf{g}\|_2} \mathbf{g}$$

Clip by value: independently clamp each gradient element to $[-c, c]$. Simpler, but it changes the direction of the gradient. Rarely used in modern practice.
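The difference between the two strategies shows up on a two-element gradient. In this minimal pure-Python sketch, the gradient values and thresholds are arbitrary and the helper names are hypothetical.

Global-norm clipping rescales the whole vector, so its direction is unchanged; the cosine similarity with the original stays at 1. Value clipping clamps only the large component and rotates the update. In PyTorch, torch.nn.utils.clip_grad_norm_ implements the first strategy and torch.nn.utils.clip_grad_value_ implements the second.

```python
import math
from typing import List


def l2(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def clip_by_global_norm(g: List[float], tau: float) -> List[float]:
    """If ||g||_2 > tau, rescale g to norm tau (direction preserved)."""
    norm = l2(g)
    if norm <= tau:
        return list(g)
    return [x * tau / norm for x in g]


def clip_by_value(g: List[float], c: float) -> List[float]:
    """Clamp each element to [-c, c] independently (direction may change)."""
    return [max(-c, min(c, x)) for x in g]


def cosine(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b)) / (l2(a) * l2(b))


g = [3.0, 0.4]
by_norm = clip_by_global_norm(g, tau=1.0)
by_value = clip_by_value(g, c=1.0)
print(f"norm-clipped:  {by_norm}, norm={l2(by_norm):.3f}, cos={cosine(g, by_norm):.4f}")
print(f"value-clipped: {by_value}, norm={l2(by_value):.3f}, cos={cosine(g, by_value):.4f}")
```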

```python
import torch
import torch.nn as nn


def clip_and_log_gradients(
    model: nn.Module,
    max_norm: float = 1.0,
    log_clipping: bool = True,
) -> float:
    """
    Clip gradients by global norm and optionally log when clipping is active.
    Returns the gradient norm BEFORE clipping.

    The returned pre-clip norm is valuable for monitoring:
    - Consistently large (>> max_norm): LR may be too high or init is poor
    - Occasionally large: normal, clipping is working as intended
    - Always small (<< max_norm): clipping is never active; could increase LR
    """
    # clip_grad_norm_ clips in place and returns the total norm before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=max_norm
    )
    norm_value = total_norm.item()

    if log_clipping and norm_value > max_norm:
        clip_ratio = max_norm / norm_value
        print(f"  [GradClip] Pre-clip norm: {norm_value:.4f}, "
              f"clip ratio: {clip_ratio:.4f}")

    return norm_value


# In training loop:
# loss.backward()
# pre_clip_norm = clip_and_log_gradients(model, max_norm=1.0)
# log_metric("grad/pre_clip_norm", pre_clip_norm, step=global_step)  # your tracker's logging call
# optimizer.step()
```

:::tip When to Set max_norm
For most tasks with Adam: max_norm=1.0. For RNNs and LSTMs: max_norm=0.1–0.5, because they are more prone to exploding gradients. For transformers: max_norm=1.0 is standard. If the pre-clip norm is almost always below max_norm, clipping is not doing anything. That is fine; it just means training is already stable.
:::

NaN Detection and Hunting

NaN values propagate silently. One NaN in any forward computation corrupts all downstream values and all gradients in the backward pass. After one optimizer step, all parameters that received NaN gradients become NaN. Within two steps, the entire model can be corrupted.
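The propagation rules are standard IEEE 754 floating-point behavior, so they apply to plain Python floats as well as to tensors. This minimal sketch shows the main cases:

- A single NaN poisons a mean.
- inf - inf and 0 * inf both produce NaN. This is how an infinite sqrt gradient or an overflowed exp turns into NaN further downstream.
- NaN is the only value that is not equal to itself, which is why a check like value != value works.

```python
import math

nan, inf = float("nan"), float("inf")

losses = [0.71, 0.69, nan, 0.68]
mean_loss = sum(losses) / len(losses)   # a single NaN contaminates the mean
print(mean_loss)                        # nan

print(inf - inf, 0.0 * inf)             # nan nan: how an inf turns into NaN
print(nan == nan, nan != nan)           # False True
print(math.isnan(mean_loss))            # True: clearer than the x != x idiom
```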

Common NaN sources:

| Source | Cause | Fix |
|---|---|---|
| log(0) | Cross-entropy on zero probability | torch.clamp(x, min=1e-8) before log |
| 0 / 0 | Division by zero in normalization | Add epsilon to denominator |
| sqrt(0) gradient | Gradient of sqrt at exactly zero is infinite | sqrt(x + 1e-8) |
| Large logits | exp() overflows to inf (above about 88 in float32, e.g. exp(700)) | Scale inputs, clip logits |
| fp16 overflow | Values exceed float16 range (~65504) | Use GradScaler, reduce LR |
| Exploding gradients | LR too high or no clipping | Clip gradients |
| NaN in input data | Corrupt data pipeline | Validate inputs before forward |
| In-place op on graph | Tensor modified after it was saved | Remove in-place ops |
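For the log(0) and large-logit rows, the standard remedy is to compute log-probabilities directly from logits with the log-sum-exp trick, instead of taking log(softmax(x)). This is a minimal pure-Python sketch. In plain Python, an overflow in math.exp raises OverflowError instead of returning inf as a float32 tensor would, but it fails at the same point.

In PyTorch the equivalent is torch.log_softmax. nn.CrossEntropyLoss already applies log-softmax to raw logits internally, which is one reason to feed it logits rather than probabilities.

```python
import math
from typing import List


def naive_log_softmax(logits: List[float]) -> List[float]:
    """log(softmax(x)) computed literally: overflows for large logits."""
    exps = [math.exp(z) for z in logits]
    total = sum(exps)
    return [math.log(e / total) for e in exps]


def stable_log_softmax(logits: List[float]) -> List[float]:
    """Log-sum-exp trick: subtract the max logit before exponentiating."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


logits = [1000.0, 0.0]
try:
    naive_log_softmax(logits)
except OverflowError as err:
    print(f"naive version failed: {err}")
print(stable_log_softmax(logits))  # [0.0, -1000.0]
```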
```python
import math
from typing import List

import torch
import torch.nn as nn


def find_nan_in_model(model: nn.Module) -> List[str]:
    """Find all parameters and buffers containing NaN or Inf."""
    problems = []
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            problems.append(f"NaN in parameter: {name}")
        if torch.isinf(param).any():
            problems.append(f"Inf in parameter: {name}")
    for name, buf in model.named_buffers():
        if torch.isnan(buf).any():
            problems.append(f"NaN in buffer: {name}")
        if torch.isinf(buf).any():
            problems.append(f"Inf in buffer: {name}")
    return problems


def register_nan_detection_hooks(model: nn.Module) -> List:
    """
    Register forward hooks that raise on the first NaN produced by any module.
    Pinpoints exactly which operation produces the first NaN.

    Usage:
        hooks = register_nan_detection_hooks(model)
        try:
            output = model(x)
        except RuntimeError as e:
            print(e)
        finally:
            for h in hooks:
                h.remove()
    """
    hooks = []

    def make_hook(name: str):
        def hook(module, input, output):
            # Only plain tensor outputs are checked; tuple outputs are skipped
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                # Inspect inputs for additional context
                input_nan = [
                    torch.isnan(inp).any().item()
                    for inp in input
                    if isinstance(inp, torch.Tensor)
                ]
                raise RuntimeError(
                    f"NaN detected in OUTPUT of module '{name}' "
                    f"(type: {type(module).__name__}).\n"
                    f"Input had NaN: {input_nan}\n"
                    f"Output shape: {output.shape}, "
                    f"NaN count: {torch.isnan(output).sum().item()}"
                )
        return hook

    # A child's hook fires before its parent's, so the first module to emit NaN raises
    for name, module in model.named_modules():
        hooks.append(module.register_forward_hook(make_hook(name)))

    return hooks


def validate_batch(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    step: int,
    input_clamp: float = 100.0,
) -> bool:
    """
    Validate a training batch before the forward pass.
    Returns True if batch is clean, False if problems found.
    """
    issues = []

    if torch.isnan(inputs).any():
        issues.append(f"NaN in inputs (count: {torch.isnan(inputs).sum().item()})")
    if torch.isinf(inputs).any():
        issues.append("Inf in inputs")
    if inputs.abs().max().item() > input_clamp:
        issues.append(f"Extreme input values: max_abs={inputs.abs().max().item():.2f}")
    if torch.isnan(targets).any():
        issues.append("NaN in targets")

    if issues:
        print(f"Step {step} batch validation FAILED:")
        for issue in issues:
            print(f"  - {issue}")
        return False

    return True


def loss_with_nan_check(
    loss_value: torch.Tensor,
    step: int,
    model: nn.Module,
    raise_on_nan: bool = False,
) -> float:
    """Safely extract loss value, checking for NaN and logging model state."""
    value = loss_value.item()

    if math.isnan(value):  # equivalent to value != value
        print(f"\nStep {step}: NaN loss detected!")
        problems = find_nan_in_model(model)
        if problems:
            print("NaN found in model parameters:")
            for p in problems:
                print(f"  {p}")
        else:
            print("Model parameters are clean - NaN came from this batch")

        if raise_on_nan:
            raise ValueError(f"NaN loss at step {step}")

    return value
```

Complete Debugging Checklist (Runnable Code)

```python
from typing import Dict, List, Optional

import torch
import torch.nn as nn

# find_nan_in_model() is defined in the NaN Detection section above.


def overfit_one_batch(
    model: nn.Module,
    sample_batch: tuple,
    criterion: nn.Module,
    lr: float = 1e-3,
    max_steps: int = 500,
    target_loss: float = 0.01,
    device: str = "cpu",
) -> Dict:
    """
    Test whether the model can memorize a single batch.
    A healthy model and loss function should drive loss near zero.
    If it cannot: architectural bug, wrong loss, disconnected graph,
    or initialization problem - no amount of training will help.

    Returns dict with success, final_loss, loss_history.
    """
    model = model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    inputs, targets = sample_batch
    inputs = inputs.to(device)
    targets = targets.to(device)

    history = []
    print(f"Overfit-one-batch test ({max_steps} max steps)...")

    for step in range(max_steps):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        history.append(loss_val)

        if step % 50 == 0:
            print(f"  Step {step:>4}: loss = {loss_val:.6f}")

        if loss_val != loss_val:  # NaN
            print("  FAILED: NaN loss - check loss function or architecture")
            return {"success": False, "final_loss": float("nan"), "history": history}

        if loss_val < target_loss:
            print(f"  SUCCESS: memorized in {step + 1} steps (loss={loss_val:.6f})")
            return {"success": True, "final_loss": loss_val, "history": history}

    print(f"  FAILED: did not reach target {target_loss} after {max_steps} steps")
    print(f"  Final loss: {history[-1]:.6f}")
    print("  Likely causes: wrong loss function, bad init, disconnected graph")
    return {"success": False, "final_loss": history[-1], "history": history}


class TrainingDebugger:
    """
    All-in-one debugging harness for the training loop.
    Monitors: loss trends, gradient norms, clipping activity, NaN detection.

    Usage:
        debugger = TrainingDebugger(model, log_every=100)

        for step, (x, y) in enumerate(loader):
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            # step() clips gradients and returns False on a NaN loss
            if debugger.step(step, loss.item(), optimizer):
                optimizer.step()
    """

    def __init__(
        self,
        model: nn.Module,
        log_every: int = 100,
        clip_grad_norm: float = 1.0,
        window_size: int = 200,
    ):
        self.model = model
        self.log_every = log_every
        self.clip_grad_norm = clip_grad_norm
        self.window_size = window_size

        self.loss_history: List[float] = []
        self.grad_norm_history: List[float] = []
        self.clip_events: List[int] = []  # steps where clipping was active
        self.nan_events: int = 0          # total NaN losses seen
        self._consecutive_nan: int = 0

    def step(
        self,
        global_step: int,
        loss: float,
        optimizer: torch.optim.Optimizer,
        extra_metrics: Optional[Dict[str, float]] = None,
    ) -> bool:
        """
        Call after loss.backward(), before optimizer.step().
        Handles: NaN detection, gradient clipping, logging.
        Returns False if the optimizer step should be skipped.
        """
        # NaN guard - skip optimizer step if loss is NaN
        if loss != loss:
            self.nan_events += 1
            self._consecutive_nan += 1
            print(f"\n[Step {global_step}] NaN loss ({self._consecutive_nan} consecutive)")
            if self._consecutive_nan >= 3:
                problems = find_nan_in_model(self.model)
                if problems:
                    print("NaN propagated to parameters:")
                    for p in problems:
                        print(f"  {p}")
            return False  # signal to skip optimizer step

        self._consecutive_nan = 0
        self.loss_history.append(loss)

        # Gradient clipping; clip_grad_norm_ returns the norm BEFORE clipping
        pre_clip_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=self.clip_grad_norm
        ).item()
        self.grad_norm_history.append(pre_clip_norm)

        if pre_clip_norm > self.clip_grad_norm:
            self.clip_events.append(global_step)

        # Periodic logging
        if global_step % self.log_every == 0:
            self._log(global_step, loss, pre_clip_norm, optimizer, extra_metrics)

        return True  # signal to proceed with optimizer step

    def _log(
        self,
        step: int,
        loss: float,
        grad_norm: float,
        optimizer: torch.optim.Optimizer,
        extra: Optional[Dict] = None,
    ):
        lr = optimizer.param_groups[0]["lr"]

        # Trend over recent window: last loss minus first loss in the window
        window = self.loss_history[-self.window_size:]
        trend = window[-1] - window[0] if len(window) > 10 else None
        trend_str = f"{trend:+.4f}" if trend is not None else "N/A"

        clip_pct = len(self.clip_events) / max(len(self.grad_norm_history), 1) * 100
        clipped_this_step = grad_norm > self.clip_grad_norm

        print(
            f"\n[Step {step:>7}] "
            f"loss={loss:.4f} "
            f"trend({self.window_size})={trend_str} | "
            f"grad_norm={grad_norm:.4f}{' [CLIPPED]' if clipped_this_step else ''} "
            f"clip_freq={clip_pct:.1f}% | "
            f"lr={lr:.2e}"
        )

        # Warnings
        if len(window) > 50 and trend is not None and abs(trend) < 1e-4:
            print(f"  WARNING: Loss has barely moved in the last {len(window)} steps.")
            print("  Possible: LR decayed too low, dead neurons, plateau.")

        if clip_pct > 50:
            print("  WARNING: Clipping >50% of steps - LR may be too high.")

        if extra:
            for k, v in extra.items():
                print(f"  {k}: {v:.4f}")

    def summary(self) -> Dict:
        """Return summary statistics for the training run."""
        if not self.loss_history:
            return {}
        return {
            "final_loss": self.loss_history[-1],
            "min_loss": min(self.loss_history),
            "mean_grad_norm": sum(self.grad_norm_history) / max(len(self.grad_norm_history), 1),
            "clip_frequency": len(self.clip_events) / max(len(self.grad_norm_history), 1),
            "nan_events": self.nan_events,
        }
```

TensorBoard and Weights & Biases Integration

Logging training metrics to a dashboard is not optional in production - it is how you catch problems before they waste hours of compute.

```python
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class TrainingLogger:
    """
    Dual-backend logger: TensorBoard + optional W&B.
    Logs: losses, gradient norms, learning rates, activation statistics.
    """

    def __init__(
        self,
        log_dir: str = "runs/experiment",
        use_wandb: bool = False,
        wandb_project: str = "my-project",
        wandb_config: Optional[dict] = None,
    ):
        self.writer = SummaryWriter(log_dir)
        self.use_wandb = use_wandb

        if use_wandb:
            import wandb
            wandb.init(project=wandb_project, config=wandb_config or {})

    def log_scalars(self, metrics: dict, step: int):
        """Log a dict of scalar metrics."""
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)

        if self.use_wandb:
            import wandb
            wandb.log({**metrics, "step": step})

    def log_gradient_norms(self, model: nn.Module, step: int):
        """Log per-layer gradient norms and total gradient norm."""
        norms = {}
        total_sq = 0.0

        for name, param in model.named_parameters():
            if param.grad is not None:
                norm = param.grad.detach().norm(2).item()
                norms[f"grad_norm/{name}"] = norm
                total_sq += norm ** 2

        norms["grad_norm/total"] = total_sq ** 0.5
        self.log_scalars(norms, step)

    def log_activation_histogram(
        self, model: nn.Module, inputs: torch.Tensor, step: int
    ):
        """
        Log activation histograms for each Linear/Conv2d layer.
        Useful for spotting saturation, dead neurons, or distribution collapse.
        """
        activation_cache = {}

        def make_hook(name):
            def hook(module, input, output):
                if isinstance(output, torch.Tensor):
                    activation_cache[name] = output.detach().cpu()
            return hook

        handles = []
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                handles.append(module.register_forward_hook(make_hook(name)))

        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                model(inputs)
        finally:
            # Always remove the hooks and restore the caller's mode: leaving the
            # model in eval mode here would silently reproduce the model.train() bug.
            for h in handles:
                h.remove()
            model.train(was_training)

        for name, acts in activation_cache.items():
            self.writer.add_histogram(f"activations/{name}", acts, step)

    def log_weight_histograms(self, model: nn.Module, step: int):
        """Log weight distributions per layer - watch for distribution collapse."""
        for name, param in model.named_parameters():
            self.writer.add_histogram(f"weights/{name}", param.detach().cpu(), step)
            if param.grad is not None:
                self.writer.add_histogram(f"grads/{name}", param.grad.detach().cpu(), step)

    def log_lr(self, optimizer: torch.optim.Optimizer, step: int):
        """Log learning rate for each parameter group."""
        for i, group in enumerate(optimizer.param_groups):
            self.writer.add_scalar(f"lr/group_{i}", group["lr"], step)

    def close(self):
        self.writer.close()
        if self.use_wandb:
            import wandb
            wandb.finish()
```


```python
# --- Putting it all together: production training loop ---

def production_train_loop(
    model: nn.Module,
    train_loader,
    val_loader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    n_epochs: int,
    log_dir: str = "runs/experiment",
    device: str = "cuda",
):
    logger = TrainingLogger(log_dir=log_dir)
    debugger = TrainingDebugger(model, log_every=100, clip_grad_norm=1.0)
    global_step = 0

    for epoch in range(1, n_epochs + 1):
        # ---- Training ----
        model.train()  # explicit, every epoch: the validation loop below switches to eval()
        epoch_losses = []

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # Validate inputs in first epoch
            if epoch == 1:
                validate_batch(batch_x, batch_y, global_step)

            optimizer.zero_grad(set_to_none=True)
            logits = model(batch_x)
            loss = criterion(logits, batch_y)
            loss.backward()
            loss_value = loss.item()  # one host-device sync per step instead of several

            # Debugger handles NaN checks, clipping + logging
            should_step = debugger.step(global_step, loss_value, optimizer)
            if should_step:
                optimizer.step()

            scheduler.step()

            # Log to TensorBoard (on normal steps these gradient norms are post-clip)
            if global_step % 50 == 0:
                logger.log_scalars({"train/loss": loss_value}, global_step)
                logger.log_gradient_norms(model, global_step)
                logger.log_lr(optimizer, global_step)

            epoch_losses.append(loss_value)
            global_step += 1

        # ---- Validation ----
        model.eval()
        val_losses, val_correct, val_total = [], 0, 0

        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                logits = model(batch_x)
                val_losses.append(criterion(logits, batch_y).item())
                preds = logits.argmax(dim=-1)
                val_correct += (preds == batch_y).sum().item()
                val_total += batch_y.size(0)

        avg_train = sum(epoch_losses) / max(len(epoch_losses), 1)
        avg_val = sum(val_losses) / max(len(val_losses), 1)
        val_acc = val_correct / max(val_total, 1)

        logger.log_scalars({
            "train/epoch_loss": avg_train,
            "val/loss": avg_val,
            "val/accuracy": val_acc,
        }, global_step)

        if epoch == 1 or epoch % 5 == 0:
            print(f"\nEpoch {epoch}: train={avg_train:.4f}, val={avg_val:.4f}, "
                  f"val_acc={val_acc:.4f}")

    print("\nTraining complete.")
    print("Debugger summary:", debugger.summary())
    logger.close()
```

Common Mistakes

:::danger Using model.eval() without model.train() in the training loop

The most common production bug. After validation with model.eval(), if you forget model.train() at the top of the next training epoch, dropout is disabled (no regularization) and BatchNorm uses frozen running statistics (wrong normalization). The model trains - but incorrectly. The symptom: validation metrics are erratic and do not improve even though training loss appears healthy. Always call model.train() explicitly at the start of each training epoch. Do not rely on any implicit mode switching.

:::
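You can observe the difference directly. This minimal sketch uses a toy Sequential model and a hypothetical helper, assert_all_training (not a PyTorch API), which fails fast if any submodule was left in eval mode. It assumes nothing about the rest of your training loop.

```python
import torch
import torch.nn as nn


def assert_all_training(model: nn.Module) -> None:
    """Hypothetical guard: raise if any submodule is not in train mode."""
    stale = [name or "<root>" for name, m in model.named_modules() if not m.training]
    if stale:
        raise RuntimeError(f"Modules left in eval mode: {stale}")


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(4, 8)

model.eval()  # e.g. left over from the validation loop
with torch.no_grad():
    a, b = model(x), model(x)
print(torch.equal(a, b))  # True: in eval mode Dropout is the identity, so no regularization

model.train()  # the call that must happen at the top of every training epoch
with torch.no_grad():
    c, d = model(x), model(x)
# Train mode: each forward draws a fresh dropout mask, so the outputs (almost surely) differ
print(torch.equal(c, d))
assert_all_training(model)  # passes now; raises if called after model.eval()
```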

:::danger Not calling torch.no_grad() during validation

Without torch.no_grad(), PyTorch records a computation graph for every validation forward pass. The saved activations cost GPU memory in proportion to the batch size (and to the whole validation set if you accumulate loss tensors instead of loss.item()). They can cause OOM errors mid-training and make the validation loop noticeably slower. Note that torch.no_grad() does not change layer behavior. Any BatchNorm still in train mode updates its running statistics from validation data even inside torch.no_grad(), which poisons the statistics the model uses at evaluation time. Only model.eval() prevents that, so validation needs both.

:::
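torch.no_grad() and model.eval() solve different problems, and a single BatchNorm1d layer is enough to see this. In this sketch, a shifted random batch stands in for validation data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
val_batch = torch.randn(16, 4) + 5.0  # stand-in for a validation batch (mean far from 0)

# Case 1: torch.no_grad() only. No graph is recorded, but BatchNorm is still in train mode.
bn.train()
before = bn.running_mean.clone()
with torch.no_grad():
    out = bn(val_batch)
updated_under_no_grad = not torch.equal(before, bn.running_mean)

# Case 2: model.eval() + torch.no_grad(), the correct validation setup.
bn.eval()
before = bn.running_mean.clone()
with torch.no_grad():
    bn(val_batch)
updated_in_eval = not torch.equal(before, bn.running_mean)

print(out.requires_grad)      # False: no autograd graph was built
print(updated_under_no_grad)  # True: running stats absorbed the validation data
print(updated_in_eval)        # False: eval mode leaves running stats alone
```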

:::danger Learning rate too high for the final phase of training

An LR that works well in the first half of training may cause oscillation in the second half as the model approaches a good minimum. The loss curve shows good initial convergence followed by growing oscillation: improvements stop, and the curve bounces. This is especially common with a constant LR, or with a decay that is too gentle. One example is a cosine schedule whose minimum LR is set too high. Another is a cosine schedule whose period is longer than the actual run, so it never reaches its floor. The fix: decay the LR aggressively in the late phase, for example with cosine annealing down to a small minimum LR or with explicit warmup followed by decay.

:::
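One common way to get aggressive late-phase decay is linear warmup followed by cosine decay to a small floor. The sketch below builds this schedule with PyTorch's LambdaLR. warmup_cosine is a hypothetical helper. The warmup length, total steps, and floor_ratio=0.01 are illustrative assumptions, not tuned values. If you only need the decay part, the built-in CosineAnnealingLR with a small eta_min does the same job.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine(warmup_steps: int, total_steps: int, floor_ratio: float = 0.01):
    """LambdaLR multiplier: linear warmup, then cosine decay to floor_ratio * base LR."""
    def factor(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = min((step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return floor_ratio + (1.0 - floor_ratio) * cosine
    return factor


param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=3e-4)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine(warmup_steps=100, total_steps=1000))

lrs = []
for step in range(1000):
    lrs.append(optimizer.param_groups[0]["lr"])  # the LR actually used at this step
    optimizer.step()  # no gradients here; a real loop would backprop first
    scheduler.step()

print(f"start={lrs[0]:.2e} peak={max(lrs):.2e} end={lrs[-1]:.2e}")
```

Plotting `lrs` (or logging the LR every step, as checklist item 10 suggests) is the quickest way to confirm that the schedule actually reaches its floor before the run ends.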

:::warning Gradient clipping too aggressive (max_norm too small)

If max_norm is set far below typical gradient norms, every single step is clipped. The optimizer always takes a scaled-down step: clip_grad_norm_ shrinks the gradient's magnitude without changing its direction. Learning is slower, and the model may not converge to the best solution. Check: if the pre-clip gradient norm is almost always larger than max_norm, you are over-clipping. Monitor clip_frequency. If it is above 80%, increase max_norm or reduce LR.

:::

:::warning Learning rate finder run on wrong data distribution

The LR finder assumes the batches used during the search are representative of the full training distribution. If you run the LR finder on a biased subset (e.g., all from one class), the suggested LR may not transfer. Always use a DataLoader that shuffles data and is drawn from the same distribution as full training.

:::

10-Point Pre-Training Checklist

| # | Check | How |
|---|---|---|
| 1 | Model outputs correct shape | print(model(sample).shape) |
| 2 | Loss decreases on step 1 | Compare loss before and after one grad step |
| 3 | Overfit-one-batch succeeds | Run overfit_one_batch function |
| 4 | No NaN in inputs | assert not torch.isnan(inputs).any() |
| 5 | Normalization applied | print(inputs.mean(), inputs.std()) ≈ (0, 1) |
| 6 | Train/val splits are disjoint | Log sample count per split |
| 7 | Augmentation not on val | Visually inspect val loader output |
| 8 | model.train() called before training | Explicit call, every epoch |
| 9 | Scheduler state saved in checkpoint | Check save/load logic |
| 10 | LR is logged every step | Confirm in experiment tracker |
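Items 1, 2, and 4 can be folded into a single pre-flight function. This minimal sketch uses a hypothetical name (quick_sanity_check). It assumes a model that maps batch_x to logits and a standard loss such as nn.CrossEntropyLoss. For the one-step test it uses plain SGD, whatever optimizer you actually train with.

```python
import copy

import torch
import torch.nn as nn


def quick_sanity_check(model, batch_x, batch_y, criterion, expected_shape, lr=0.05):
    """Hypothetical helper for checklist items 1, 2 and 4 on one batch.

    Runs on a deep copy so the real model's weights are never touched. The probe
    stays in eval mode so the before/after losses are deterministic (no dropout
    noise, no batch statistics).
    """
    # Item 4: inputs must be finite (catches NaN and Inf)
    assert torch.isfinite(batch_x).all(), "inputs contain NaN or Inf"

    probe = copy.deepcopy(model).eval()

    # Item 1: output shape
    with torch.no_grad():
        out = probe(batch_x)
    assert out.shape == expected_shape, f"output shape {tuple(out.shape)} != {tuple(expected_shape)}"

    # Item 2: one SGD step on this batch should lower the loss on this batch
    opt = torch.optim.SGD(probe.parameters(), lr=lr)
    loss_before = criterion(probe(batch_x), batch_y)
    opt.zero_grad()
    loss_before.backward()
    opt.step()
    with torch.no_grad():
        loss_after = criterion(probe(batch_x), batch_y)
    assert loss_after.item() < loss_before.item(), "loss did not decrease after one step"
    return loss_before.item(), loss_after.item()


# Usage on a toy classifier
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
x, y = torch.randn(16, 10), torch.randint(0, 3, (16,))
before, after = quick_sanity_check(model, x, y, nn.CrossEntropyLoss(), expected_shape=(16, 3))
print(f"loss {before:.4f} -> {after:.4f}")
```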

YouTube Resources

| Title | Channel | Why Watch |
|---|---|---|
| A Recipe for Training Neural Networks | Andrej Karpathy (blog + talks) | The definitive practical guide from one of the best practitioners - covers every pathology systematically |
| Training Deep Neural Networks | Stanford CS231n | Academic treatment of loss landscapes, gradient flow, and optimizer dynamics |
| Weights & Biases - Debugging ML Models | Weights & Biases | End-to-end walkthrough of using W&B for gradient monitoring, hyperparameter search, and experiment tracking |
| How to Find a Good Learning Rate | Fast.ai | Leslie Smith's LR finder explained visually with live code |
| Loss Landscape Visualization | NeurIPS Tutorial | The Li et al. (2018) "Visualizing the Loss Landscape of Neural Nets" paper explained - flat vs sharp minima |

Interview Q&A

Q1: How do you diagnose whether a stuck training loss is due to a learning rate problem versus dead neurons?

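Two distinct tests with distinct signatures. First, check the per-layer gradient norms alongside the logged LR. Suppose gradient norms are healthy in every layer but the LR has decayed to a tiny value. Then each update (whose size scales with the LR) is too small to move the weights, and the scheduler has effectively stopped the optimizer. If early-layer gradient norms are near zero while late-layer norms are healthy, the issue is vanishing gradients or dead neurons: the signal is not propagating backward. Uniformly near-zero gradient norms in every layer point to a plateau or saturated units rather than to the LR, because the LR does not enter the gradient itself.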

Second, run the overfit-one-batch test with a fresh learning rate (e.g., 1e-3 for Adam). If the model memorizes the batch, the architecture is healthy and the scheduler is the likely problem. If it fails to memorize, there is a structural issue - wrong activation function, bad initialization, or disconnected computation graph. This test costs two minutes and usually resolves the ambiguity.
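For the first test, gradient norms alone cannot reveal a scheduler problem, because the LR does not enter the gradient. One way to make it visible is to compare each update with the weights it changes. The helper below is a hypothetical sketch that measures ||ΔW|| / ||W|| per parameter tensor around optimizer.step(). The toy example contrasts a normal LR with an LR that has decayed to 1e-6. No universal target ratio is implied here; compare values across layers and over time.

```python
import torch
import torch.nn as nn


def step_with_update_ratios(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    """Call optimizer.step() and return ||delta W|| / ||W|| per parameter tensor.

    Hypothetical diagnostic: call it after loss.backward(), every N steps rather
    than every step, because it clones every trainable parameter.
    """
    before = {n: p.detach().clone() for n, p in model.named_parameters() if p.requires_grad}
    optimizer.step()
    ratios = {}
    for name, p in model.named_parameters():
        if name in before:
            delta = (p.detach() - before[name]).norm().item()
            ratios[name] = delta / (before[name].norm().item() + 1e-12)
    return ratios


torch.manual_seed(0)
model = nn.Linear(16, 4)
x, y = torch.randn(32, 16), torch.randn(32, 4)

results = {}
for lr in (1e-1, 1e-6):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    grad_norm = torch.cat([p.grad.flatten() for p in model.parameters()]).norm().item()
    ratios = step_with_update_ratios(model, opt)
    results[lr] = (grad_norm, ratios["weight"])
    # Similar gradient norms, but the decayed LR produces updates ~1e5x smaller
    print(f"lr={lr:.0e} grad_norm={grad_norm:.3f} update/weight={ratios['weight']:.2e}")
```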

Q2: Explain the loss landscape and why flat minima generalize better than sharp minima.

The loss landscape is the function $L(\theta)$ over the full parameter space $\theta \in \mathbb{R}^n$. During training, we are navigating this landscape using gradient descent. Not all minima have the same generalization properties.

A sharp minimum is a narrow valley - small perturbations to $\theta$ cause large increases in training loss. This corresponds to a solution that is highly sensitive to exact weight values. The shift between the training and test distributions can be viewed as a small perturbation of the effective loss landscape. If the minimum is sharp, the test loss can be much higher than the training loss - the model has overfitted to the exact training distribution.

A flat minimum is a wide basin - large perturbations to $\theta$ cause only small increases in training loss. The model has learned features that are robust to small parameter changes, and this tends to translate into robustness to distribution shift at test time. SGD with small batch sizes tends to find flatter minima, because the gradient noise makes it hard to settle into narrow, sharp valleys.
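Sharpness can also be probed numerically. Near a minimum $\theta^*$ the gradient vanishes, so a second-order expansion gives $L(\theta^* + \delta) - L(\theta^*) \approx \tfrac{1}{2}\delta^\top H \delta$. For Gaussian noise $\delta \sim \mathcal{N}(0, \sigma^2 I)$, the expected loss increase is therefore $\tfrac{\sigma^2}{2}\operatorname{tr}(H)$: more curvature means a larger increase. The sketch below estimates that increase by sampling noise. perturbation_sharpness is a hypothetical helper and a crude proxy, not the filter-normalized visualization from Li et al. (2018).

```python
import copy

import torch
import torch.nn as nn


def perturbation_sharpness(model, loss_fn, sigma=0.01, n_samples=8, seed=0):
    """Average loss increase when every weight gets Gaussian noise of std sigma.

    Larger values suggest a sharper minimum. loss_fn(model) must return a scalar tensor.
    """
    gen = torch.Generator().manual_seed(seed)  # fixed seed: identical noise across calls
    with torch.no_grad():
        base = loss_fn(model).item()
        increases = []
        for _ in range(n_samples):
            noisy = copy.deepcopy(model)
            for p in noisy.parameters():
                noise = torch.randn(p.shape, generator=gen, dtype=p.dtype).to(p.device)
                p.add_(noise * sigma)
            increases.append(loss_fn(noisy).item() - base)
    return sum(increases) / n_samples


class Quadratic(nn.Module):
    """Toy 'model' whose loss is curvature * ||w||^2, minimized at w = 0."""

    def __init__(self, curvature: float):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(10))
        self.curvature = curvature


def quad_loss(m):
    return m.curvature * (m.w ** 2).sum()


sharp = perturbation_sharpness(Quadratic(100.0), quad_loss)
flat = perturbation_sharpness(Quadratic(1.0), quad_loss)
print(f"sharp: {sharp:.4f}  flat: {flat:.4f}")  # same noise, so sharp is 100x flat
```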

Q3: What is the learning rate finder and how do you interpret its output?

Leslie Smith's LR range test starts with a very small LR and exponentially increases it over 100–200 training steps, recording the loss at each step. The resulting plot of loss vs. LR (on a log scale) shows three regions: initially the loss barely changes (LR too small to matter), then the loss decreases rapidly (the good LR range), then the loss spikes or becomes NaN (LR too large, training diverges).

The recommended LR is at the steepest downward slope of the curve - the point of maximum negative slope in the loss-vs-LR plot. This is typically 10–100x below the divergence point. Using the LR at the very bottom of the loss curve often leads to instability because you are just below the divergence threshold. The steepest slope point is more conservative and more robust. For cyclical or one-cycle schedules, the max LR in the cycle is often set to this found value.
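Here is a minimal sketch of the range test and the steepest-slope rule. It assumes a list of representative (x, y) batches and an optimizer whose param_groups carry an "lr" entry. lr_range_test and steepest_slope_lr are hypothetical names. The smoothing factor, the 4x divergence cutoff, and the default LR bounds are common choices, not values prescribed by this lesson.

```python
import copy
import math

import torch
import torch.nn as nn


def lr_range_test(model, optimizer, criterion, batches, start_lr=1e-7, end_lr=10.0,
                  num_iter=100, smooth=0.98, diverge_factor=4.0):
    """Sketch of an LR range test: raise the LR exponentially, one step per batch.

    `batches` is a list of (x, y) pairs, cycled if shorter than num_iter.
    Model/optimizer state and train/eval mode are restored afterwards.
    Returns (lrs, smoothed_losses).
    """
    model_state = copy.deepcopy(model.state_dict())
    opt_state = copy.deepcopy(optimizer.state_dict())
    was_training = model.training
    gamma = (end_lr / start_lr) ** (1.0 / (num_iter - 1))
    lrs, losses = [], []
    avg, best = 0.0, math.inf

    model.train()
    for i in range(num_iter):
        lr = start_lr * gamma ** i
        for group in optimizer.param_groups:
            group["lr"] = lr
        x, y = batches[i % len(batches)]
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

        # Bias-corrected exponential moving average to de-noise the curve
        avg = smooth * avg + (1 - smooth) * loss.item()
        smoothed = avg / (1 - smooth ** (i + 1))
        if not math.isfinite(smoothed):
            break  # NaN/Inf: stop without recording the point
        lrs.append(lr)
        losses.append(smoothed)
        best = min(best, smoothed)
        if smoothed > diverge_factor * best:
            break  # loss is blowing up: the divergence region has been reached

    model.load_state_dict(model_state)
    optimizer.load_state_dict(opt_state)
    model.train(was_training)
    return lrs, losses


def steepest_slope_lr(lrs, losses):
    """LR at the most negative slope of smoothed loss vs log10(LR)."""
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log10(lrs[i + 1]) - math.log10(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    return lrs[min(range(len(slopes)), key=slopes.__getitem__)]


# Usage on a toy regression problem
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)
batches = [(x[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]

lrs, losses = lr_range_test(model, optimizer, nn.MSELoss(), batches, num_iter=60)
print(f"swept {len(lrs)} LRs, suggested LR ~ {steepest_slope_lr(lrs, losses):.2e}")
```

Restoring the saved state dicts matters. The sweep deliberately ends in the divergence region, so the weights it leaves behind are not a useful starting point for real training.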

Q4: Walk through how you find the source of a NaN loss that appears after 500 training steps.

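Step 1: Add torch.autograd.set_detect_anomaly(True) before the training loop and rerun for about 510 steps with the same seed and data order. When a backward function returns NaN, PyTorch raises an error naming that function and prints the traceback of the forward operation that created it. Anomaly mode inspects the backward pass, so NaNs in forward outputs are better caught by the hooks in Step 2. It adds substantial overhead but is precise. Disable it after finding the source.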

Step 2: Register forward hooks on all modules that check output tensors for NaN. The hook that fires first identifies the culprit module. Use register_nan_detection_hooks(model) from the code in this lesson.

Step 3: Add validate_batch(inputs, targets, step) at the start of each training step. If the inputs themselves contain NaN (from a corrupt data file or preprocessing bug), the source is in the data pipeline, not the model.

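Step 4: Check the loss function and model for numerically unstable operations. Common culprits are log of values near zero and divisions by small denominators without an added epsilon (e.g., in attention). Another is a hand-written softmax that calls exp on very large logits without first subtracting the maximum, which overflows to inf. Prefer built-ins such as torch.softmax and torch.log_softmax, and use log_softmax instead of log(softmax(x)).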

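Step 5: If using mixed precision (torch.cuda.amp), check the GradScaler with scaler.get_scale(). Each step with inf/NaN gradients is skipped and the scale is reduced (halved by default). A scale that keeps shrinking, toward 1.0 or below, therefore means gradients are overflowing in float16 (max representable value is ~65504) on most steps. Add explicit casts to float32 for numerically sensitive operations.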
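Steps 1 and 4 can be reproduced in isolation. This sketch triggers anomaly mode with the square root of a negative number, using torch.autograd.detect_anomaly(), the scoped context-manager form of set_detect_anomaly. It then contrasts numerically unsafe softmax code with PyTorch's built-ins. The inputs are contrived to force each failure.

```python
import torch

# Step 1 in miniature: anomaly mode flags the backward op that returns NaN.
x = torch.tensor(-1.0, requires_grad=True)
anomaly_error = None
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)  # NaN already appears in the forward pass
    try:
        y.backward()   # SqrtBackward0 returns NaN, so anomaly mode raises
    except RuntimeError as err:
        anomaly_error = err
print(anomaly_error)   # names the failing backward function

# Step 4 in miniature: numerically unsafe softmax / log-softmax.
logits = torch.tensor([1000.0, 1000.0])
naive = torch.exp(logits) / torch.exp(logits).sum()  # exp overflows: inf / inf = nan
stable = torch.softmax(logits, dim=0)                # subtracts the max internally

scores = torch.tensor([0.0, -200.0])
log_of_softmax = torch.log(torch.softmax(scores, dim=0))  # exp(-200) underflows to 0 in float32
log_softmax = torch.log_softmax(scores, dim=0)            # computed in log space, stays finite

print(naive)           # tensor([nan, nan])
print(stable)          # tensor([0.5000, 0.5000])
print(log_of_softmax)  # second entry is -inf
print(log_softmax)     # close to [0, -200]
```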

Q5: What does it mean when the pre-clip gradient norm is consistently much larger than your clip threshold? What does it mean when it is consistently much smaller?

Consistently larger (e.g., norm = 15 with max_norm = 1.0): gradient clipping is always active. clip_grad_norm_ rescales all gradients by one common factor, so every update keeps the true gradient direction but has its magnitude cut to max_norm. The optimizer never sees the full gradient scale. This can slow convergence and may indicate the LR is too high (large gradients → large steps → gradient clipping → effective LR is reduced, but by a factor that varies from step to step). Possible fixes: reduce LR, improve initialization (better scaled initial weights produce smaller initial gradients), or increase max_norm if the task genuinely has large gradients (some RL and meta-learning settings).

Consistently smaller (e.g., norm = 0.001 with max_norm = 1.0): clipping is never active, so it is providing no benefit. The gradients are naturally small. This is fine if the model is still converging. But if the model is stuck, vanishingly small gradients mean it sits on a plateau or in a dead zone. The resulting steps may be too small to escape, particularly with plain SGD, where the step is LR × gradient. Check whether loss is still decreasing; if not, try a larger LR or check for dead neurons.
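A quick check of what torch.nn.utils.clip_grad_norm_ does with oversized gradients. It returns the total norm measured before clipping (the value the debugger above logs as pre_clip_norm). It then rescales all gradients by one common factor, so their direction is unchanged. The tiny linear model and the x100 loss scaling are contrived to force large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
loss = (model(torch.randn(32, 10)) * 100).pow(2).mean()  # deliberately large gradients
loss.backward()

grads_before = torch.cat([p.grad.flatten() for p in model.parameters()])  # a copy
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
grads_after = torch.cat([p.grad.flatten() for p in model.parameters()])

cosine = torch.nn.functional.cosine_similarity(grads_before, grads_after, dim=0).item()
print(f"returned (pre-clip) norm: {pre_clip:.3f}")
print(f"norm after clipping:      {grads_after.norm().item():.3f}")  # ~1.000 when clipping is active
print(f"cosine(before, after):    {cosine:.4f}")                     # ~1.0: same direction
```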

Q6: Describe the batch normalization eval-mode bug and how to catch it.

BatchNorm has two distinct operating modes. In training mode, it normalizes each batch using that batch's mean and variance, and updates its running statistics. In eval mode, it normalizes using the stored running mean and variance from training - the current batch statistics are ignored.

The bug: after the validation loop (model.eval()), if model.train() is not called before the next training epoch, the model trains with eval-mode BatchNorm. The normalization uses stale running statistics rather than the current batch, producing incorrect normalized activations. Gradients computed from these activations are wrong.

The symptom is subtle: training loss continues to decrease (slowly and noisily), but validation loss behaves erratically - sometimes improving, sometimes degrading. The gap between train and val is inconsistent across epochs without clear trend. The model is being trained differently from how it is being evaluated.

Detection: explicitly assert that model.training == True at the start of each training epoch. Or use a context manager that enforces train mode for training and eval mode for validation. Production code should be written so that mode-switching is explicit and mandatory, not assumed.
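A context manager makes the mode switch explicit and reversible. model_mode below is a hypothetical helper, not part of PyTorch. It forces train or eval mode inside the with block and restores each submodule's previous flag on exit, even if the block raises.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def model_mode(model: nn.Module, training: bool):
    """Hypothetical helper: force train/eval mode inside the block, restore it afterwards."""
    previous = {m: m.training for m in model.modules()}  # per-submodule flags
    model.train(training)
    try:
        yield model
    finally:
        for m, was_training in previous.items():
            m.training = was_training


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(0.1))
x = torch.randn(8, 4)

for epoch in range(2):
    with model_mode(model, training=True):
        assert model.training, "training step running in eval mode"
        model(x).sum().backward()  # stand-in for the real training steps
    with model_mode(model, training=False), torch.no_grad():
        model(x)  # stand-in for the validation loop
```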

:::tip 🎮 Interactive Playground

Visualize this concept: Try the Training Dynamics demo on the EngineersOfAI Playground - no code required.

:::

© 2026 EngineersOfAI. All rights reserved.