
Learning Rate Scheduling

The Real Interview Moment

An ML engineer at a computer vision company trains a ResNet-50 on an internal product defect dataset. She sets a constant learning rate of 1e-3, trains for 20 epochs, and watches validation accuracy climb steadily to 87%. Solid, she thinks. She ships the model.

A colleague on another team runs the same experiment - same architecture, same dataset, same batch size, same optimizer. The only change: a cosine annealing schedule that starts at 1e-3 and smoothly decays to 1e-6. His final validation accuracy: 91.3%.

Four percentage points from a scheduler. The model weights, loss function, and data pipeline were identical. No architecture change. No data augmentation. No hyperparameter search. Just scheduling.

The difference comes down to how learning rate interacts with the geometry of the loss landscape. A constant rate that is large enough for early exploration is too large for fine-grained convergence later. Near a good minimum, the loss surface resembles a narrow valley - large steps keep bouncing off the walls instead of settling to the bottom. This lesson explains the geometry behind that failure, every major scheduling strategy, how to pick one, and how to implement each correctly in PyTorch.

Why Constant Learning Rate Fails

Think of the loss landscape as mountainous terrain and the optimizer as a hiker trying to reach the lowest valley. A large step size lets you cover ground quickly and explore broadly, but once near a good valley, large steps cause you to leap over the minimum or bounce between the walls. A small step is precise but agonizingly slow when far from any minimum.

The loss landscape has a hierarchical structure:

  • Early training: parameters are far from any minimum, gradients are large and informative, a relatively large LR makes rapid progress. Loss decreases steeply.
  • Mid training: gradients shrink as you approach a basin, a moderate LR continues useful descent. Loss decreases more slowly.
  • Late training: you are near a minimum, gradients are tiny and noisy, a large LR causes oscillation rather than settling.

A constant LR must be a compromise - small enough to not blow up late training, which makes it unnecessarily slow early on. Scheduling provides the right rate at each training phase.
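
The trade-off can be seen on a toy problem. The sketch below is illustrative only: it assumes a one-dimensional quadratic loss with a hand-picked curvature of 10 and a hand-picked decay factor, not a real network. It compares a constant step size with one that shrinks over time.

```python
def gradient_descent(lr_schedule, steps=50, x0=1.0, curvature=10.0):
    """Minimize f(x) = 0.5 * curvature * x**2 and return every visited point."""
    x = x0
    path = [x]
    for t in range(steps):
        grad = curvature * x            # f'(x)
        x = x - lr_schedule(t) * grad   # plain gradient descent update
        path.append(x)
    return path


# Constant LR: each update multiplies x by (1 - 0.19 * 10) = -0.9,
# so the iterate jumps across the minimum on every step.
constant_path = gradient_descent(lambda t: 0.19)

# Decaying LR: same starting rate, shrunk by 10% per step (an arbitrary choice).
decayed_path = gradient_descent(lambda t: 0.19 * 0.9 ** t)

print(f"constant LR: |x_50| = {abs(constant_path[-1]):.2e}")
print(f"decayed LR:  |x_50| = {abs(decayed_path[-1]):.2e}")
```

With the constant rate the iterate changes sign on every step and shrinks slowly; with the decaying rate the overshoot dies out after a few steps and the iterate ends much closer to the minimum.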

Step Decay: Simple and Interpretable

The simplest non-constant schedule: reduce LR by a fixed factor every $k$ epochs.

$$\eta_t = \eta_0 \cdot \gamma^{\lfloor t / k \rfloor}$$

Where $\eta_0$ is the initial LR, $\gamma \in (0, 1)$ is the decay factor (typically 0.1 or 0.5), and $k$ is the step size in epochs.

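Typical ImageNet training: $\gamma = 0.1$, $k = 30$ epochs in a 90-epoch run - multiply by 0.1 at epochs 30 and 60. The divide-by-10 pattern goes back to the original AlexNet and ResNet papers, which cut the LR by 10x whenever validation error plateaued; fixed 30-epoch milestones are the common scripted version of that recipe.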

The problem: abrupt LR drops create discontinuities. The optimizer has learned a trajectory at a given rate and suddenly must adjust to a 10x smaller step. Training loss curves show visible spikes at milestone epochs - brief instability as the optimizer re-adapts to the new scale.

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR

model = ...  # placeholder: any nn.Module
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# Drop LR by 10x every 30 epochs
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(90):
    # Read the LR before stepping so the printout matches the epoch being trained
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch {epoch+1}: LR = {current_lr:.6f}")
    train_one_epoch(model, optimizer)  # placeholder for your training loop
    scheduler.step()  # call AFTER the epoch
# epochs 1–30: LR = 0.1
# epochs 31–60: LR = 0.01
# epochs 61–90: LR = 0.001

# MultiStepLR: specify exact drop epochs (more common).
# Use it instead of the StepLR above, not in addition to it.
scheduler_multi = MultiStepLR(optimizer, milestones=[30, 60, 80], gamma=0.1)
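
The closed-form step-decay formula can also be evaluated directly, which is handy for sanity-checking a logged LR curve. This is a minimal pure-Python sketch; the function name `step_decay_lr` is hypothetical, and epochs are counted from 0 as in the formula.

```python
def step_decay_lr(epoch, base_lr=0.1, gamma=0.1, step_size=30):
    """eta_t = eta_0 * gamma ** floor(t / k), with epochs counted from 0."""
    return base_lr * gamma ** (epoch // step_size)


# The LR only changes when the epoch index crosses a multiple of step_size.
for epoch in (0, 29, 30, 59, 60, 89):
    print(f"epoch {epoch:2d}: LR = {step_decay_lr(epoch):.4f}")
```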

Exponential Decay: Smooth but Aggressive

$$\eta_t = \eta_0 \cdot \gamma^t$$

Each epoch, the LR is multiplied by $\gamma$ (typically 0.9 to 0.99). Smooth decay, no discontinuities. The problem: exponential decay is aggressive - after enough steps, the LR falls so low that learning stops. For a 200-epoch run with $\gamma = 0.95$: $0.95^{200} \approx 3.5 \times 10^{-5}$. If you started at 0.1, you end at $3.5 \times 10^{-6}$ - potentially too small for meaningful learning in the final epochs.

Exponential decay is appropriate for runs up to ~100 epochs. For longer training, cosine annealing is preferred because it decays more slowly in the beginning and reaches a user-specified floor.
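
To judge whether exponential decay suits a given run length, it helps to solve $\eta_0 \gamma^t \le \eta_{\text{floor}}$ for $t$. The sketch below does that in plain Python; the helper name and the floor of 1e-5 are illustrative assumptions, not values from any library.

```python
import math


def epochs_until_floor(base_lr, gamma, floor):
    """Smallest integer t such that base_lr * gamma**t <= floor."""
    return math.ceil(math.log(floor / base_lr) / math.log(gamma))


t = epochs_until_floor(base_lr=0.1, gamma=0.95, floor=1e-5)
print(f"LR drops below 1e-5 after {t} epochs")
```

If the answer is well short of the planned number of epochs, the tail of the run is effectively wasted and a schedule with an explicit floor is the safer choice.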

from torch.optim.lr_scheduler import ExponentialLR

scheduler = ExponentialLR(optimizer, gamma=0.95) # 5% decay each epoch
# epoch 10: LR = 0.1 * 0.95^10 = 0.0599
# epoch 50: LR = 0.1 * 0.95^50 = 0.0077
# epoch 100: LR = 0.1 * 0.95^100 = 0.00059

Cosine Annealing: The Modern Standard

The most widely used schedule for vision models. LR follows a cosine curve from $\eta_{\max}$ to $\eta_{\min}$ over $T$ steps:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{\pi t}{T}\right)\right)$$

Why cosine outperforms step decay:

  1. No discontinuities: the cosine curve is smooth - the optimizer never experiences a sudden step change
  2. Slow start of decay: cosine stays close to $\eta_{\max}$ early on (only about 15% of the total decay has happened after the first quarter of training) - aggressive early exploration
  3. Rapid mid-decay: the steep middle portion efficiently reduces LR during convergence
  4. Slow finish: the curve flattens near $\eta_{\min}$ - the optimizer can settle precisely without overshooting

[Figure: cosine annealing curve. LR on the vertical axis against steps on the horizontal axis (0, T/2, T): flat near 0, steepest around T/2, flat again near T.]

The derivative of the cosine schedule is $-\frac{\pi}{2T}(\eta_{\max} - \eta_{\min})\sin(\pi t/T)$, which is zero at $t=0$ and $t=T$ (slow changes at the boundaries) and largest in magnitude at $t=T/2$ (fastest decay in the middle). This matches the desired behavior: hold, drop, settle.
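
The shape described above can be checked numerically. This is a small pure-Python sketch of the formula, not PyTorch's implementation; the function name and the $\eta_{\max}$, $\eta_{\min}$ defaults are chosen for illustration.

```python
import math


def cosine_lr(t, T, eta_max=1e-3, eta_min=1e-6):
    """Cosine annealing from eta_max (t = 0) down to eta_min (t = T)."""
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T))


T = 100
for t in (0, 25, 50, 75, 100):
    print(f"t = {t:3d}: LR = {cosine_lr(t, T):.2e}")

# Compare how much the LR falls over 10 steps at the start vs. around the midpoint.
early_drop = cosine_lr(0, T) - cosine_lr(10, T)
middle_drop = cosine_lr(45, T) - cosine_lr(55, T)
print(f"early drop = {early_drop:.2e}, middle drop = {middle_drop:.2e}")
```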

from torch.optim.lr_scheduler import CosineAnnealingLR

# Per-epoch cosine annealing
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=90,      # complete one cosine cycle over 90 epochs
    eta_min=1e-6,  # minimum LR at end of cycle
)

# Per-step cosine (preferred for smooth decay)
total_steps = len(train_loader) * num_epochs
scheduler_step = CosineAnnealingLR(
    optimizer,
    T_max=total_steps,  # one cycle over all training steps
    eta_min=1e-7,
)
# Call scheduler_step.step() after each BATCH (not each epoch)

:::tip Per-Step vs Per-Epoch Scheduling
Set T_max to the total number of training steps (not epochs) and call scheduler.step() after each batch for smoother, finer-grained decay. Per-epoch scheduling produces a step-function approximation of the smooth cosine curve. With 100 batches per epoch and 90 epochs, per-batch scheduling gives 9000 update points vs 90 for per-epoch. The difference matters most for short training runs.
:::

SGDR: Cosine with Warm Restarts (Loshchilov and Hutter, 2016)

Stochastic Gradient Descent with Warm Restarts applies cosine annealing in cycles. At the end of each cycle, LR is reset to $\eta_{\max}$ and the next cycle length is multiplied by $T_{\text{mult}}$:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{\pi t_{\text{curr}}}{T_{\text{curr}}}\right)\right)$$

Where $t_{\text{curr}}$ is the step within the current cycle and $T_{\text{curr}}$ is the current cycle length ($T_0, T_0 T_\text{mult}, T_0 T_\text{mult}^2, \ldots$).
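
The bookkeeping for $t_{\text{curr}}$ and $T_{\text{curr}}$ is the only subtle part of SGDR. The following is a minimal pure-Python sketch of that bookkeeping under the formula above; it is not PyTorch's implementation, and the function name and default values are illustrative.

```python
import math


def sgdr_lr(epoch, T_0=10, T_mult=2, eta_max=0.1, eta_min=1e-6):
    """LR at a global epoch index (counted from 0) under cosine warm restarts."""
    t_curr, T_curr = epoch, T_0
    # Walk through completed cycles; each cycle is T_mult times longer than the last.
    while t_curr >= T_curr:
        t_curr -= T_curr
        T_curr *= T_mult
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_curr / T_curr))


# With T_0=10 and T_mult=2 the LR jumps back to eta_max at epochs 10, 30, 70, ...
for epoch in (0, 9, 10, 29, 30, 69, 70):
    print(f"epoch {epoch:2d}: LR = {sgdr_lr(epoch):.5f}")
```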

Why restarts help: each LR reset kicks the optimizer out of a sharp local minimum into a neighboring basin. The loss landscape around a sharp minimum is narrow - a large LR step easily escapes it. Flat minima are wider and harder to escape - the optimizer tends to settle back near them after a restart. Over multiple cycles, SGDR explores multiple basins and naturally gravitates toward flatter, more generalizable minima.

Snapshot ensembling: save the model at the end of each cosine cycle (when LR has decayed to $\eta_{\min}$ and the model has converged within its current basin). Each saved model has typically converged to a different local minimum, so the checkpoints make different errors. Averaging their predictions gives an ensemble that costs no additional training and often improves accuracy by around 1–2%.

import copy

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,     # first restart after 10 epochs
    T_mult=2,   # each cycle is twice as long: 10 → 20 → 40 → ...
    eta_min=1e-6,
)

# For snapshot ensembling: save the model at the end of each cycle
cycle_ends = [10, 30, 70, 150]  # cumulative sums for T_mult=2: 10, 10+20, 10+20+40, ...
saved_models = []

for epoch in range(max_epochs):
    train_one_epoch(model, optimizer)
    scheduler.step()  # call after each epoch

    if epoch + 1 in cycle_ends:
        # The cycle has just finished annealing to eta_min - save this checkpoint
        saved_models.append(copy.deepcopy(model.state_dict()))
        print(f"Snapshot saved at epoch {epoch + 1}")

Linear Warmup: Why Adam Needs It

Transformers and large models trained with Adam require warmup: starting from a very small LR and linearly increasing to the target LR over the first $w$ steps:

$$\eta_t = \eta_{\max} \cdot \frac{t}{w} \quad \text{for } t \leq w$$

The mechanical reason warmup is needed for Adam: Adam maintains an exponential moving average of squared gradients $v_t$ to scale updates. At step $t$:

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$

With $v_0 = 0$ and $\beta_2 = 0.999$, at step 1: $v_1 = 0.001 g_1^2$. After bias correction: $\hat{v}_1 = v_1 / (1 - 0.999^1) = v_1 / 0.001 = g_1^2$. The bias correction makes this correct in expectation.

However, the second moment estimate is based on only one gradient observation. The estimate has high variance - for some parameter dimensions, $g_1^2$ may be very small (producing a very small $\hat{v}_1$ for that dimension), which makes the effective learning rate $\eta / (\sqrt{\hat{v}_1} + \epsilon)$ very large. In the first few steps, Adam can take catastrophically large steps in dimensions where the single gradient observation happened to be small. Warmup keeps $\eta$ small during this unreliable period.
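
A worked first step makes the arithmetic concrete. The sketch below assumes the standard Adam update with $\epsilon$ added outside the square root and default-style hyperparameters ($\beta_1 = 0.9$, $\beta_2 = 0.999$); the function name is hypothetical.

```python
def adam_first_step(g1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return (effective per-parameter LR, update size) for Adam's very first step."""
    m1 = (1 - beta1) * g1            # first moment, starting from m_0 = 0
    v1 = (1 - beta2) * g1 ** 2       # second moment, starting from v_0 = 0
    m_hat = m1 / (1 - beta1)         # bias-corrected: equals g1
    v_hat = v1 / (1 - beta2)         # bias-corrected: equals g1 ** 2
    effective_lr = lr / (v_hat ** 0.5 + eps)
    return effective_lr, effective_lr * m_hat


scale_big, update_big = adam_first_step(g1=1.0)
scale_small, update_small = adam_first_step(g1=1e-4)
print(f"g1 = 1.0   -> effective LR {scale_big:.3g}, update {update_big:.3g}")
print(f"g1 = 1e-4  -> effective LR {scale_small:.3g}, update {update_small:.3g}")
```

The dimension with the tiny gradient gets an effective learning rate roughly four orders of magnitude larger, so both parameters move by about `lr` even though one gradient carried almost no signal. A small `lr` during warmup caps that movement.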

Why warmup particularly matters for pre-trained models: the pretrained representations are fragile - they encode information accumulated over billions of tokens. A single large gradient step from a fine-tuning batch can overwrite the pretrained weights, destroying the representations. Warmup prevents this: the first few hundred steps have a tiny LR, allowing the optimizer's statistics to stabilize before full-magnitude updates occur.

Typical warmup duration: 4–10% of total training steps. For a 100K step language model pre-training run, warm up for 4,000–10,000 steps. For fine-tuning (typically 1K–10K steps), warm up for 50–500 steps.

import math

import torch
from torch.optim.lr_scheduler import LambdaLR, LinearLR, CosineAnnealingLR, SequentialLR


def make_linear_warmup_cosine_scheduler(
    optimizer,
    warmup_steps: int,
    total_steps: int,
    min_lr_ratio: float = 0.0,
) -> LambdaLR:
    """
    Linear warmup for warmup_steps, then cosine decay to min_lr_ratio * base_lr.
    Warmup followed by decay is the standard pattern for transformer training
    (GPT-3 and LLaMA decay with a cosine; BERT decays linearly).

    Args:
        optimizer: PyTorch optimizer
        warmup_steps: Number of warmup steps
        total_steps: Total number of training steps
        min_lr_ratio: Final LR = base_lr * min_lr_ratio (default: 0 = full decay)
    """
    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            # Linear warmup: 0 → 1 over warmup_steps
            return float(current_step) / float(max(1, warmup_steps))

        # Cosine decay from warmup_steps to total_steps
        progress = float(current_step - warmup_steps) / float(
            max(1, total_steps - warmup_steps)
        )
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

    return LambdaLR(optimizer, lr_lambda)


# Usage: BERT-style fine-tuning
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
total_steps = len(train_loader) * num_epochs
warmup_steps = int(0.06 * total_steps)  # 6% warmup

scheduler = make_linear_warmup_cosine_scheduler(optimizer, warmup_steps, total_steps)

for step in range(total_steps):
    batch = next(train_iter)                # placeholder: iterator over train_loader
    loss = forward(model, batch, optimizer)  # placeholder: forward pass returning the loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()  # called after each STEP, not each epoch
    optimizer.zero_grad()


# Alternative: compose with PyTorch built-in schedulers (cleaner for production)
def make_warmup_cosine_pytorch(optimizer, warmup_steps: int, total_steps: int,
                               eta_min: float = 0.0) -> SequentialLR:
    """Using PyTorch built-in components - more readable."""
    warmup = LinearLR(
        optimizer,
        start_factor=1e-8,  # start at almost zero
        end_factor=1.0,     # ramp up to full LR
        total_iters=warmup_steps,
    )
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=total_steps - warmup_steps,
        eta_min=eta_min,
    )
    return SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_steps],
    )

1-Cycle Policy and Super-Convergence (Smith, 2018)

Leslie Smith's 1-cycle policy enables super-convergence - reaching competitive accuracy in 5–10x fewer epochs than standard training. The schedule has two phases:

Phase 1 (first 30% of total steps): LR increases from max_lr / div_factor to max_lr. Momentum decreases from max_momentum to base_momentum. (Momentum and LR are anti-correlated - high LR works with low momentum, and vice versa.)

Phase 2 (remaining 70% of steps): LR decreases from max_lr all the way to max_lr / (div_factor × final_div_factor). Momentum returns to max_momentum.

[Figure: 1-cycle schedule, two panels plotted against training steps. LR rises to a peak and then anneals to near zero; momentum does the opposite, dipping to its minimum at the LR peak and then recovering.]
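
The two phases can be written down as a small function. This is a simplified pure-Python sketch of the schedule as described above, with cosine interpolation in both phases; it is not PyTorch's `OneCycleLR` code, whose phase boundaries differ by one step, and the function name is hypothetical.

```python
import math


def one_cycle_lr(step, total_steps, max_lr=0.1, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """LR at a given step for a two-phase 1-cycle schedule."""
    initial_lr = max_lr / div_factor
    final_lr = initial_lr / final_div_factor
    peak_step = pct_start * total_steps

    def cos_anneal(start, end, pct):
        # Moves smoothly from `start` (pct = 0) to `end` (pct = 1).
        return end + (start - end) / 2.0 * (1 + math.cos(math.pi * pct))

    if step <= peak_step:
        # Phase 1: ramp up from initial_lr to max_lr
        return cos_anneal(initial_lr, max_lr, step / peak_step)
    # Phase 2: anneal from max_lr down to final_lr
    return cos_anneal(max_lr, final_lr, (step - peak_step) / (total_steps - peak_step))


for step in (0, 150, 300, 650, 1000):
    print(f"step {step:4d}: LR = {one_cycle_lr(step, total_steps=1000):.2e}")
```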

Why the high peak LR works as regularization: the very high LR during the middle of phase 1 acts as a strong regularizer - it prevents settling into any single sharp basin and forces exploration of flatter regions. Networks trained with 1cycle generalize better because they are pushed toward flat minima during training, not just during the final convergence.

Super-convergence mechanism: with a sufficiently high peak LR, the network undergoes a qualitative phase transition - it escapes sharp local minima entirely and finds a globally flatter region. Once in this flat region, the aggressive annealing in phase 2 allows precise convergence. The result: the network finds a better optimum in fewer iterations.

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR


def train_with_one_cycle(model, train_loader, num_epochs: int = 20,
                         max_lr: float = 0.1):
    """
    1cycle training: find max_lr with LR range test first!
    """
    optimizer = optim.SGD(
        model.parameters(),
        lr=max_lr / 25,  # overwritten by OneCycleLR (initial LR = max_lr / div_factor)
        momentum=0.85,   # overwritten by OneCycleLR, which starts at max_momentum
        weight_decay=5e-4,
    )

    scheduler = OneCycleLR(
        optimizer,
        max_lr=max_lr,
        steps_per_epoch=len(train_loader),
        epochs=num_epochs,
        pct_start=0.3,          # 30% of steps for warmup phase
        div_factor=25.0,        # initial_lr = max_lr / 25
        final_div_factor=1e4,   # final_lr = initial_lr / 10000 ≈ 4e-7
        anneal_strategy='cos',  # smooth cosine annealing in both phases
        cycle_momentum=True,    # momentum varies anti-correlated with LR
        base_momentum=0.85,     # momentum at peak LR
        max_momentum=0.95,      # momentum at minimum LR
    )

    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            scheduler.step()  # PER BATCH - not per epoch!

        print(f"Epoch {epoch+1}/{num_epochs}, LR: {optimizer.param_groups[0]['lr']:.6f}")

:::warning OneCycleLR: Per-Batch, Not Per-Epoch
OneCycleLR must be stepped after each batch, not each epoch. This differs from most other PyTorch schedulers. If the scheduler is built for steps_per_epoch × num_epochs total steps but step() is called only once per epoch, the schedule advances just num_epochs steps: the LR crawls along the very start of the warmup ramp and never gets near max_lr. PyTorch raises no error in this case and silently produces wrong results.
:::

LR Range Test: Finding the Right Peak LR

The LR range test (Smith, 2015) is the fastest way to find a good learning rate empirically, especially for OneCycleLR's max_lr.

Procedure:

  1. Save the model state (the test will modify weights)
  2. Start with LR at 1e-7
  3. For each mini-batch: run forward+backward pass, update parameters, then multiply LR by a fixed factor to increase exponentially
  4. Record the smoothed loss at each LR
  5. Plot smoothed loss vs LR on a log scale
  6. Choose max_lr where the loss descends most steeply
  7. Restore the saved model state

from copy import deepcopy
import torch
import torch.nn as nn


class LRFinder:
    """
    Learning rate range test.
    Exponentially increases LR over mini-batches and records loss.
    Restores the model and optimizer afterwards so the test leaves no trace.
    """

    def __init__(self, model: nn.Module, optimizer, criterion, device='cuda'):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

        # Save initial state - MUST restore after test
        self._model_state = deepcopy(model.state_dict())
        self._opt_state = deepcopy(optimizer.state_dict())

    def range_test(self, train_loader, start_lr: float = 1e-7,
                   end_lr: float = 10.0, num_iter: int = 200,
                   smooth_f: float = 0.05, diverge_th: float = 5.0):
        """
        Returns (lrs, losses) for plotting.
        smooth_f: exponential smoothing factor (smaller = smoother)
        diverge_th: stop when loss exceeds best_loss * diverge_th
        """
        lrs, losses = [], []
        best_loss = float('inf')
        avg_loss = 0.0

        # LR grows by this factor each step: end_lr = start_lr * mult^(num_iter - 1)
        mult = (end_lr / start_lr) ** (1.0 / (num_iter - 1))
        lr = start_lr

        for pg in self.optimizer.param_groups:
            pg['lr'] = lr

        self.model.train()
        data_iter = iter(train_loader)

        for step in range(num_iter):
            try:
                inputs, targets = next(data_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the data
                data_iter = iter(train_loader)
                inputs, targets = next(data_iter)

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            # Exponential smoothing to reduce noise
            avg_loss = smooth_f * loss.item() + (1 - smooth_f) * avg_loss
            # Bias correction for smoothed average (same logic as Adam's)
            smoothed = avg_loss / (1 - (1 - smooth_f) ** (step + 1))

            if smoothed < best_loss:
                best_loss = smoothed

            lrs.append(lr)
            losses.append(smoothed)

            if smoothed > diverge_th * best_loss:
                print(f"Loss diverged at LR={lr:.2e}. Stopping early.")
                break

            # Increase LR exponentially
            lr *= mult
            for pg in self.optimizer.param_groups:
                pg['lr'] = lr

        # CRITICAL: restore original model and optimizer state
        self.model.load_state_dict(self._model_state)
        self.optimizer.load_state_dict(self._opt_state)

        return lrs, losses

    def suggest_lr(self, lrs: list, losses: list) -> dict:
        """
        Suggest LR based on steepest descent and minimum loss.
        The steepest descent point is the most reliable recommendation.
        """
        # Find steepest descent: the most negative change between consecutive points
        gradients = [losses[i+1] - losses[i] for i in range(len(losses)-1)]
        steepest_idx = min(range(len(gradients)), key=lambda i: gradients[i])

        min_loss_idx = losses.index(min(losses))

        return {
            "steepest_descent_lr": lrs[steepest_idx],
            "min_loss_lr": lrs[min_loss_idx],
            "recommended_max_lr": lrs[steepest_idx],
            "note": "Use steepest descent LR as max_lr for OneCycleLR. "
                    "Use min_loss_lr / 10 as static LR for other schedulers."
        }


# Usage
optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.9)
finder = LRFinder(model, optimizer, nn.CrossEntropyLoss(), device='cuda')
lrs, losses = finder.range_test(train_loader, num_iter=300)
suggestions = finder.suggest_lr(lrs, losses)
print(f"Recommended max_lr: {suggestions['recommended_max_lr']:.2e}")

Polynomial Decay: For Large-Scale Training

Some large-scale training recipes use polynomial decay; BERT's reference implementation, for example, pairs linear warmup with polynomial decay at power 1 (a straight line down to zero):

$$\eta_t = (\eta_0 - \eta_{\text{end}}) \cdot \left(1 - \frac{t}{T}\right)^p + \eta_{\text{end}}$$

With $p = 1$ this is linear decay. With $p = 2$ it decays faster near the beginning and slower near the end. The advantage over cosine: the decay rate is explicit (controlled by $p$) and the schedule is fully determined by the training length $T$. Note that PyTorch's PolynomialLR always decays to zero, which corresponds to $\eta_{\text{end}} = 0$.
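
The effect of the power $p$ is easy to see by evaluating the formula at the midpoint of training. This is a plain-Python sketch of the equation above with illustrative defaults, not a library API.

```python
def polynomial_lr(t, T, eta_0=1e-3, eta_end=0.0, power=1.0):
    """eta_t = (eta_0 - eta_end) * (1 - t / T) ** power + eta_end"""
    return (eta_0 - eta_end) * (1 - t / T) ** power + eta_end


T = 100
for power in (1.0, 2.0):
    midpoint = polynomial_lr(T // 2, T, power=power)
    print(f"power = {power}: LR at T/2 = {midpoint:.2e}")
```

With $p = 1$ half of the LR is left at the midpoint; with $p = 2$ only a quarter is left, which is the "faster near the beginning" behavior.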

from torch.optim.lr_scheduler import LinearLR, PolynomialLR, SequentialLR

# Polynomial decay with power=1 (linear)
scheduler = PolynomialLR(
    optimizer,
    total_iters=total_training_steps,
    power=1.0,  # linear
)

# For warmup + polynomial decay (the BERT-style recipe):
warmup_sched = LinearLR(optimizer, start_factor=1e-6, end_factor=1.0,
                        total_iters=warmup_steps)
poly_sched = PolynomialLR(optimizer, total_iters=total_steps - warmup_steps, power=1.0)
scheduler = SequentialLR(optimizer, [warmup_sched, poly_sched],
                         milestones=[warmup_steps])

All Schedulers in PyTorch: Reference Implementation

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import (
    StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR,
    CosineAnnealingWarmRestarts, OneCycleLR, LinearLR, PolynomialLR,
    SequentialLR, LambdaLR, ReduceLROnPlateau,
)
import torch.nn as nn


def all_scheduler_examples(model: nn.Module, train_loader) -> dict:
    """Reference implementation of every major PyTorch scheduler."""

    base_lr = 0.1
    total_steps = len(train_loader) * 90  # 90 epochs

    def new_optimizer():
        # Each scheduler gets its own optimizer: a scheduler's constructor
        # rewrites the optimizer's LR, so sharing one would mix the schedules.
        return optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)

    examples = {
        "StepLR": StepLR(new_optimizer(), step_size=30, gamma=0.1),

        "MultiStepLR": MultiStepLR(new_optimizer(), milestones=[30, 60, 80], gamma=0.1),

        "ExponentialLR": ExponentialLR(new_optimizer(), gamma=0.95),

        "CosineAnnealingLR": CosineAnnealingLR(new_optimizer(), T_max=90, eta_min=1e-6),

        "CosineWarmRestarts": CosineAnnealingWarmRestarts(
            new_optimizer(), T_0=10, T_mult=2, eta_min=1e-6
        ),

        "OneCycleLR": OneCycleLR(
            new_optimizer(), max_lr=0.1,
            steps_per_epoch=len(train_loader), epochs=90,
            pct_start=0.3, div_factor=25, final_div_factor=1e4,
        ),

        # make_linear_warmup_cosine_scheduler is defined in the warmup section
        "LinearWarmup+Cosine": make_linear_warmup_cosine_scheduler(
            new_optimizer(), warmup_steps=int(0.04 * total_steps), total_steps=total_steps
        ),

        "PolynomialDecay": PolynomialLR(new_optimizer(), total_iters=90, power=1.0),

        "ReduceLROnPlateau": ReduceLROnPlateau(
            new_optimizer(), mode='min', factor=0.5, patience=5,
            threshold=1e-4, min_lr=1e-6
        ),
    }

    # Note: ReduceLROnPlateau.step(val_loss) - takes metric, not step count
    print("Scheduler summary:")
    for name, sched in examples.items():
        print(f" {name:<30}: {type(sched).__name__}")

    return examples


# The correct training loop pattern for different scheduler types:

def train_with_scheduler(model, train_loader, val_loader, optimizer, scheduler,
                         scheduler_type: str):
    """
    Different schedulers require different .step() call patterns.
    `optimizer` must be the optimizer the scheduler was constructed with.
    """
    criterion = nn.CrossEntropyLoss()
    # Schedulers configured in training steps are stepped after every batch
    per_batch = ("OneCycleLR", "LinearWarmup+Cosine")

    for epoch in range(90):
        model.train()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            # Step-based schedulers: step after EVERY batch
            if scheduler_type in per_batch:
                scheduler.step()

        # Most schedulers: step after each EPOCH
        if scheduler_type not in per_batch:
            if scheduler_type == "ReduceLROnPlateau":
                # ReduceLROnPlateau: needs validation metric
                val_loss = evaluate_loss(model, val_loader, criterion)  # placeholder helper
                scheduler.step(val_loss)
            else:
                scheduler.step()

        print(f"Epoch {epoch+1}: LR = {optimizer.param_groups[0]['lr']:.6f}")

Schedule Selection Guide

Scheduler Comparison Table

| Scheduler | Shape | Restarts | Best For | Limitation |
|---|---|---|---|---|
| Constant | Flat | No | Debugging only | Sub-optimal convergence |
| StepLR | Stepped | No | Quick baselines | Abrupt drops cause spikes |
| MultiStepLR | Stepped | No | Known milestone epochs | Same as StepLR |
| ExponentialLR | Smooth monotone | No | Short runs (<100 epochs) | LR too small in long runs |
| CosineAnnealingLR | Smooth cosine | No | Most CNN/MLP tasks | Single cycle only |
| CosineWarmRestarts | Repeating cosine | Yes | Long runs, ensembling | Requires T_0, T_mult tuning |
| OneCycleLR | Triangle-like | No | Fast CNN convergence | Needs LR range test |
| LinearWarmup+Cosine | Ramp + smooth | No | Transformers, LLMs | Warmup length and total steps must be fixed up front |
| PolynomialLR | Smooth polynomial | No | Large-scale training (BERT-style linear decay) | Extra power hyperparameter |
| ReduceLROnPlateau | Adaptive stepped | No | Unknown convergence | Requires validation metric |

Production Engineering: Checkpointing and Gradient Accumulation

import torch
import torch.nn as nn


# CRITICAL: Always save scheduler state in checkpoints
def save_checkpoint(model, optimizer, scheduler, epoch: int, path: str) -> None:
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),  # DO NOT OMIT
    }, path)


def load_checkpoint(model, optimizer, scheduler, path: str) -> int:
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])  # restores step count
    return checkpoint['epoch']


# Gradient accumulation: only step scheduler when optimizer steps
def train_with_gradient_accumulation(model, train_loader, optimizer, scheduler,
                                     criterion, accumulation_steps: int = 4) -> None:
    """
    With gradient accumulation, call scheduler.step() only when
    optimizer.step() is called - not every backward pass.
    """
    model.train()
    optimizer.zero_grad()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Scale loss by accumulation steps - prevents effective gradient magnitude
        # from growing with accumulation_steps
        loss = criterion(model(inputs), targets) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping before optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()  # only here - not every backward pass
            optimizer.zero_grad()


# Logging LR every step - catch bugs immediately
import logging

def train_step_with_logging(model, batch, optimizer, scheduler,
                            criterion, global_step: int) -> float:
    inputs, targets = batch
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

    # Log current LR - the fastest way to spot scheduling bugs
    current_lr = optimizer.param_groups[0]['lr']
    logging.info(f"step={global_step}, loss={loss.item():.4f}, lr={current_lr:.2e}")

    return loss.item()
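
One consequence of stepping the scheduler only on optimizer steps is that a step-based schedule must be sized in optimizer steps rather than batches. The helper below is a hypothetical sketch that counts them the same way the accumulation loop above does, where leftover batches at the end of an epoch never trigger a step.

```python
def optimizer_steps(num_batches, accumulation_steps, num_epochs):
    """Number of optimizer.step() calls when stepping every accumulation_steps batches."""
    steps_per_epoch = num_batches // accumulation_steps  # leftover batches are dropped
    return steps_per_epoch * num_epochs


# 1000 batches per epoch, accumulate 4 batches per update, train for 3 epochs
total_optimizer_steps = optimizer_steps(1000, 4, 3)
print(total_optimizer_steps)  # use this, not 1000 * 3, as T_max or total_steps
```

Passing the raw batch count instead would stretch the schedule fourfold in this example, so training would end while the LR is still far from its intended final value.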

:::warning Save the Scheduler State Dict
Forgetting to save scheduler.state_dict() in checkpoints is a common bug. If you resume training without it, the scheduler restarts from step 0. For OneCycleLR, this means a second warmup phase at an unexpected point in training. For CosineAnnealingLR, this resets to the maximum LR. Neither produces an error message - the model just trains with a wrong LR trajectory for the remainder of the run, potentially harming final performance. Always include scheduler.state_dict() alongside model and optimizer state dicts.
:::

YouTube Resources

| Video | Channel | Why Watch It |
|---|---|---|
| Learning Rate Scheduling - Fast.ai | Fast.ai | 1cycle policy, super-convergence, LR finder explained |
| Cosine Annealing Explained | Andrej Karpathy | Cosine schedule intuition with loss landscape visualization |
| Super-Convergence (1cycle) Paper | Yannic Kilcher | Full walkthrough of Smith 2018 super-convergence paper |
| CS231n - Learning Rate Tuning | Stanford CS231n | Learning rate warmup, decay, and practical selection |
| LR Range Test - Leslie Smith | PyImageSearch | Practical implementation of the LR finder with code |

Interview Q&A

Q1: Why does cosine annealing outperform step decay even though both reduce LR over time?

The shape of the decay matters as much as the total decay. Cosine annealing is smooth - no abrupt drops. It keeps LR near its maximum for roughly the first quarter of training (aggressive exploration of the loss landscape), decays rapidly through the middle (efficient convergence), and flattens near the minimum LR at the end (precise settlement without overshooting). Step decay's discrete drops cause the optimizer to experience sudden step changes. The optimizer has built up a trajectory (momentum buffers, Adam's moment estimates) calibrated for one learning rate, and must re-calibrate for a 10x smaller one. This creates visible instability spikes in the loss curve at each milestone. Cosine annealing avoids all discontinuities. Empirically, cosine consistently outperforms step decay by 0.3–1.5% on image classification benchmarks with identical total training time.

Q2: Why do transformers specifically need learning rate warmup? Explain the mechanism.

Adam maintains an exponential moving average of squared gradients $v_t$ to adaptively scale updates per parameter. At step $t=1$, $v_1 = (1-\beta_2) g_1^2 = 0.001 g_1^2$ with $\beta_2=0.999$. Even after bias correction, this estimate is derived from a single gradient observation - extremely noisy. For dimensions where $g_1$ happens to be small, $\hat{v}_1$ is tiny, making the effective LR $\eta/(\sqrt{\hat{v}_1} + \epsilon)$ very large. In the first few steps, Adam can take catastrophically large steps in these dimensions. For pre-trained transformers, a single large step can overwrite learned representations. Warmup keeps the LR near zero while the second moment estimate stabilizes over many batches. The warmup period is sufficient for $v_t$ to see enough gradient samples to provide a reliable estimate. Practically: for a 100K step training run, 4000 warmup steps are enough for $v_t$ to converge even with $\beta_2=0.999$.

Q3: Explain the 1-cycle policy. What is super-convergence and why does it happen?

The 1-cycle policy (Smith, 2018) uses a triangular LR schedule: increase from initial_lr to max_lr over 30% of total steps, then decrease to near zero. Momentum is varied anti-correlated (high LR with low momentum; low LR with high momentum). Super-convergence refers to the empirical observation that under this policy, networks can reach competitive accuracy in 5–10x fewer epochs than standard training. The mechanism: the high peak LR acts as aggressive regularization - it is too large for the network to settle into any single sharp minimum, so it continues exploring. Networks that would otherwise converge to sharp minima are kept "bouncing" until they find a flatter, more generalizable basin. The subsequent rapid annealing then allows precise convergence within that flat basin. The result is better generalization (flat minima) achieved much faster (fewer total epochs). The prerequisite: finding the right max_lr via the LR range test. The policy fails if max_lr is too small (no super-convergence effect) or too large (divergence).

Q4: How do you find the right learning rate for a new model and dataset?

Use the LR range test: exponentially increase LR from 1e-7 to 10 over 100–300 mini-batches, recording smoothed loss at each LR. Plot loss vs LR on a log scale. The loss first decreases as LR increases (too small LR → slow progress), then reaches a minimum, then increases as LR becomes too large and training destabilizes. Choose the LR at the point of steepest descent - where the loss is decreasing most rapidly. For static LR training (SGD), a good starting point is the LR at the minimum loss divided by 10. For OneCycleLR, use the LR at the steepest descent directly as max_lr. After the test: restore the model state (the test corrupts weights), then train normally. The LR range test takes 5 minutes and saves hours of manual hyperparameter search. Run it whenever starting on a new dataset or architecture.

Q5: When resuming training from a checkpoint, what must be restored and why?

Three things must be restored: model state dict (the weights - obvious), optimizer state dict (momentum buffers and Adam's first/second moment estimates per parameter), and the scheduler state dict (the current step count and internal state). Missing the optimizer state means momentum buffers and Adam's moment estimates start from zero - the first steps after resuming will be as if training just started, with no useful gradient history. Missing the scheduler state means the scheduler resets to step 0. For a cosine schedule at epoch 50 out of 90, the scheduler would jump back to the maximum LR and restart the cosine cycle - training the final 40 epochs at the wrong LR. For OneCycleLR, this would trigger a second warmup phase. None of these failures produce error messages - the training continues silently with the wrong LR trajectory. This is particularly painful for large-scale runs where a missed scheduler save can waste GPU-hours. Always checkpoint all three.

Q6: Compare linear warmup + cosine decay versus SGDR for a transformer pre-training run. Which would you choose?

For transformer pre-training, linear warmup + cosine decay is the usual choice. The reasons: (1) Simplicity: a single continuous schedule with no restarts. SGDR's restarts complicate resuming from checkpoints - you need to track which part of which cycle you are in. (2) Established baseline: GPT-3, LLaMA, and many other large language model papers use linear warmup + cosine decay. BERT uses linear warmup with linear decay and T5 an inverse square root schedule, neither of which uses restarts. The recipe is well-understood. (3) No snapshot ensembling benefit: ensembling transformer checkpoints from different loss basins is not standard practice in NLP - the training runs are too expensive to produce multiple convergence points. (4) Warmup correctness: SGDR typically starts from the maximum LR immediately (or with a very short warmup). For transformers, a proper warmup of 4–10% of steps is important for the reasons described above. SGDR is worth considering for very long training runs where snapshot ensembling is valuable (computer vision, 200+ epoch training), or for tasks where the optimizer is likely to get stuck in local minima and benefits from periodic kicks. For most production transformer training, linear warmup + cosine decay is the correct default.
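The structural difference between the two schedules shows up when they are written side by side. This is a minimal sketch: `warmup_cosine_lr` and `sgdr_lr` are illustrative helpers, and the step counts and peak LR are made-up values, not settings from any particular paper.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, max_lr, min_lr=0.0):
    """Linear warmup to max_lr, then one cosine decay toward min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def sgdr_lr(step, first_cycle_steps, max_lr, min_lr=0.0, cycle_mult=2):
    """Cosine annealing with warm restarts; each cycle is cycle_mult times longer."""
    cycle_len, cycle_start = first_cycle_steps, 0
    # Locate the cycle that contains `step`
    while step >= cycle_start + cycle_len:
        cycle_start += cycle_len
        cycle_len *= cycle_mult
    progress = (step - cycle_start) / cycle_len
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Hypothetical run: 1000 steps, 5% warmup, peak LR 3e-4
for step in (0, 49, 99, 100, 299, 300, 999):
    a = warmup_cosine_lr(step, total_steps=1000, warmup_steps=50, max_lr=3e-4)
    b = sgdr_lr(step, first_cycle_steps=100, max_lr=3e-4)
    print(f"step {step:4d}: warmup+cosine={a:.2e}  sgdr={b:.2e}")
```

Note that the SGDR value at a given step depends on which cycle it falls in, which is exactly the extra state a checkpoint has to capture. The first function can also be turned into a multiplier for PyTorch's `LambdaLR` by dividing its result by `max_lr`.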

ReduceLROnPlateau: Adaptive Scheduling

ReduceLROnPlateau monitors a metric (typically validation loss) and reduces the LR when no improvement is seen for patience epochs:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau


def train_with_plateau_scheduler(model: nn.Module, train_loader, val_loader,
                                 n_epochs: int = 100) -> None:
    """
    ReduceLROnPlateau: the most adaptive scheduler.
    Best when you do not know the convergence profile in advance.
    """
    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    # Reduce by factor 0.5 when val_loss hasn't improved for 10 epochs
    # Minimum LR: 1e-6 (stop reducing below this)
    # Threshold: count a change as an improvement only if it exceeds 0.01% (relative)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',            # watching for minimum (loss) - use 'max' for accuracy
        factor=0.5,            # multiply LR by 0.5 on plateau
        patience=10,           # wait 10 epochs before reducing
        threshold=1e-4,        # minimum significant improvement
        threshold_mode='rel',  # relative improvement (not absolute)
        cooldown=5,            # wait 5 epochs after reduction before allowing another
        min_lr=1e-6,           # never go below this
    )

    for epoch in range(n_epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        # Evaluate on validation set
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                val_loss += criterion(model(x), y).item()
        val_loss /= len(val_loader)

        # ReduceLROnPlateau takes the metric - NOT a step counter
        scheduler.step(val_loss)  # different signature from other schedulers!

        # Log the LR every epoch so reductions are visible in the output
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}: val_loss={val_loss:.4f}, lr={current_lr:.2e}")

When to use ReduceLROnPlateau: tabular MLP tasks with unknown convergence behavior, runs where the total number of training epochs is not fixed, and cases where a monitored metric is more reliable than a step count. The key limitation: it reacts to past performance, not future trajectory. On tasks where the loss decreases smoothly and monotonically, the plateau condition is rarely met, so ReduceLROnPlateau reduces the LR too conservatively. Cosine annealing is more aggressive and usually better for vision and language tasks.
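The interaction between patience, cooldown, and the relative threshold is easier to see in isolation. The class below is a simplified model of the plateau logic written for this illustration (the name `PlateauTracker` is hypothetical); it is not PyTorch's implementation and only covers `mode='min'` with a relative threshold.

```python
class PlateauTracker:
    """Simplified plateau logic: cut the LR after more than `patience` bad epochs."""

    def __init__(self, lr, factor=0.5, patience=10, threshold=1e-4,
                 cooldown=5, min_lr=1e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0
        self.cooldown_left = 0

    def step(self, metric):
        # An epoch counts as an improvement only if it beats the best by the threshold
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Bad epochs are ignored while cooling down after a reduction
        if self.cooldown_left > 0:
            self.cooldown_left -= 1
            self.bad_epochs = 0
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.cooldown_left = self.cooldown
            self.bad_epochs = 0
        return self.lr


# Synthetic validation loss: improves for 20 epochs, then stalls
val_losses = [1.0 - 0.02 * e for e in range(20)] + [0.62] * 40
tracker = PlateauTracker(lr=1e-3)
previous = tracker.lr
for epoch, loss in enumerate(val_losses, start=1):
    current = tracker.step(loss)
    if current != previous:
        print(f"epoch {epoch}: LR reduced to {current:.2e}")
    previous = current
```

Running it on the synthetic curve shows how many stalled epochs pass before the first cut, which is the delay the limitation above refers to.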

Learning Rate vs Batch Size Scaling

When changing batch size, the learning rate should be rescaled so that each update keeps roughly the same effective step size:

Linear scaling rule (Goyal et al., Facebook, 2017): if you multiply the batch size by $k$, multiply the learning rate by $k$. This keeps the gradient signal magnitude the same per effective step.

$$\eta_{\text{new}} = \eta_{\text{base}} \times \frac{B_{\text{new}}}{B_{\text{base}}}$$

For example, if your baseline is LR=0.1 at batch size 256, and you scale to batch size 1024 (4× larger), use LR=0.4.

Square root scaling: some practitioners use $\eta_{\text{new}} = \eta_{\text{base}} \times \sqrt{B_{\text{new}} / B_{\text{base}}}$ for very large batch sizes where the linear rule overshoots.

Warmup is required at large batch sizes: the linear scaling rule breaks down early in training, when the weights are changing rapidly, so starting directly at the linearly scaled LR tends to destabilize optimization. The Facebook paper used 5 epochs of linear warmup when scaling to batch size 8192 for ImageNet training.

def scale_lr_for_batch_size(base_lr: float, base_batch: int,
                            new_batch: int, rule: str = "linear") -> float:
    """
    Compute the appropriate learning rate when changing batch size.

    Args:
        base_lr: Learning rate at base_batch
        base_batch: Reference batch size (where base_lr was tuned)
        new_batch: New batch size
        rule: "linear" (Goyal 2017) or "sqrt" (gentler scaling)
    """
    ratio = new_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    elif rule == "sqrt":
        return base_lr * (ratio ** 0.5)
    else:
        raise ValueError(f"Unknown rule: {rule}")


# Examples
print(scale_lr_for_batch_size(0.1, 256, 1024, "linear"))   # 0.4
print(scale_lr_for_batch_size(0.1, 256, 1024, "sqrt"))     # 0.2
print(scale_lr_for_batch_size(3e-4, 512, 4096, "linear"))  # 2.4e-3
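The warmup needed at large batch sizes can be sketched the same way. This is a minimal illustration, assuming the ramp starts at the base LR and rises linearly to the scaled LR; the helper name `gradual_warmup_lr` and the figure of 100 steps per epoch are hypothetical.

```python
def gradual_warmup_lr(step: int, warmup_steps: int,
                      base_lr: float, batch_ratio: float) -> float:
    """Ramp linearly from base_lr to base_lr * batch_ratio over warmup_steps.

    Assumption: warmup starts at the small-batch LR rather than at zero.
    """
    target_lr = base_lr * batch_ratio
    if step >= warmup_steps:
        return target_lr
    return base_lr + (target_lr - base_lr) * step / warmup_steps


# Baseline LR 0.1 at batch size 256, scaled to batch size 8192 (ratio 32)
steps_per_epoch = 100               # hypothetical value for illustration
warmup_steps = 5 * steps_per_epoch  # 5 epochs of warmup
for step in (0, 125, 250, 375, 500, 1000):
    lr = gradual_warmup_lr(step, warmup_steps, base_lr=0.1, batch_ratio=8192 / 256)
    print(f"step {step:4d}: lr={lr:.3f}")
```

After the warmup window, the regular schedule (step decay, cosine, and so on) takes over from the scaled LR.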

:::warning The Learning Rate Is the Most Important Hyperparameter
All scheduling strategies in this lesson optimize around the learning rate. Before spending time on schedule choice, ensure the base LR is correct. The LR range test identifies this in a few minutes. The performance gap between a good LR with a simple StepLR schedule and a bad LR with a perfectly tuned cosine schedule is typically far larger than the gap between two good schedules. Tune LR first. Then schedule.
:::

