Optimization Algorithms Deep Dive - SGD, Adam, AdamW, and Beyond
Reading time: ~28 minutes | Level: Mathematical Foundations → ML Engineering
The PyTorch optimizer line in your training script is usually one of three choices: optim.SGD, optim.Adam, or optim.AdamW. Most engineers pick one, set a learning rate from a paper, and train. Few understand why these algorithms work, what they are actually computing, or when one outperforms another.
This lesson derives each optimizer from first principles and gives you the mathematical intuition to make principled choices - not cargo-cult ones.
What You Will Learn
- SGD with momentum: the math and when it excels
- AdaGrad: per-parameter adaptive learning rates
- RMSProp: fixing AdaGrad's diminishing rates
- Adam: combining momentum and RMSProp - the math step by step
- AdamW: decoupling weight decay from the adaptive gradient update
- Learning rate schedules: cosine annealing, linear warmup
- Gradient clipping: preventing exploding gradients
- Comprehensive comparison table and decision guide
- From-scratch Python implementations of all optimizers
Prerequisites
- Lesson 01: Derivatives and Gradients (required)
- Lesson 03: Gradient Descent Mechanics (recommended)
- NumPy and basic PyTorch
Part 1 - SGD with Momentum (Revisited and Extended)
The algorithm
At each step t, given mini-batch gradient g_t = ∇L̃(θ_t):
m_t = β₁·m_{t-1} + (1-β₁)·g_t
θ_{t+1} = θ_t - α·m_t
With β₁ = 0 (no momentum): reduces to vanilla mini-batch SGD.
The "effective step size" interpretation
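With β₁ = 0.9 and learning rate α, m_t is an exponential moving average (EMA) of past gradients:
m_t = (1-β₁)·(g_t + β₁·g_{t-1} + β₁²·g_{t-2} + ...)
If the gradient is a constant g for many steps, the EMA converges to:
m_t → (1-β₁)·g·(1 + β₁ + β₁² + ...) = (1-β₁)·g/(1-β₁) = g
The EMA of a constant converges to that constant, so in this formulation the effective step size under a constant gradient is α, not α/(1-β₁). The α/(1-β₁) factor appears only in the unnormalized convention v_t = β₁·v_{t-1} + g_t, which is the one PyTorch uses.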
In directions where the gradient oscillates (changes sign), the EMA cancels out oscillations. In directions where the gradient is consistent, the EMA accumulates.
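A small sketch can make the cancellation concrete. The helper below (`ema_final`, a name chosen here for illustration) runs the EMA form of momentum on a constant gradient stream and on one that flips sign every step, assuming m₀ = 0 and β₁ = 0.9. It is a toy on scalar gradients, not a full optimizer.

```python
def ema_final(grads, beta1=0.9):
    """Return the EMA m_t after consuming all gradients, starting from m_0 = 0."""
    m = 0.0
    for g in grads:
        m = beta1 * m + (1.0 - beta1) * g
    return m


steps = 200
constant = [1.0] * steps                            # consistent direction
oscillating = [(-1.0) ** t for t in range(steps)]   # sign flips every step

m_const = ema_final(constant)
m_osc = ema_final(oscillating)

print(f"constant gradient:    m = {m_const:.4f}")
print(f"oscillating gradient: m = {m_osc:+.4f}")
```

For the sign-flipping stream the EMA settles at a magnitude of (1-β₁)/(1+β₁) ≈ 0.053, about 19x smaller than for the constant stream, which is why momentum damps zig-zagging directions while keeping full speed along consistent ones.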
PyTorch SGD: the subtleties
import torch
import torch.nn as nn

# PyTorch SGD with momentum - note the implementation convention
# PyTorch uses: v_{t+1} = momentum * v_t + g_t (without (1-momentum) scaling)
# Then: θ_{t+1} = θ_t - lr * v_{t+1}
# This differs from the standard definition above - the learning rate is effectively
# lr * (1/(1-momentum)) for steady-state gradients

class SGDMomentum:
    """
    SGD with momentum - PyTorch convention:
        v = momentum * v + grad
        θ = θ - lr * v
    """
    def __init__(self, params, lr: float = 0.01, momentum: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        self.velocities = [torch.zeros_like(p) for p in self.params]

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()

    @torch.no_grad()  # parameter updates must not be tracked by autograd
    def step(self):
        for p, v in zip(self.params, self.velocities):
            if p.grad is None:
                continue
            # PyTorch convention: v = momentum * v + g
            v.mul_(self.momentum).add_(p.grad)
            # Update: θ = θ - lr * v
            p.sub_(v, alpha=self.lr)

# Demo
model = nn.Linear(5, 1)
optimizer = SGDMomentum(model.parameters(), lr=0.01, momentum=0.9)
X = torch.randn(32, 5)
y = torch.randn(32, 1)

for step in range(5):
    optimizer.zero_grad()
    loss = ((model(X) - y)**2).mean()
    loss.backward()
    optimizer.step()
    print(f"Step {step}: loss = {loss.item():.4f}")
When SGD with momentum is the right choice
SGD with momentum is the long-standing default for convolutional vision models (ResNets, VGG, and similar ConvNets). It often generalizes better than Adam on these tasks; a common explanation is that it settles in flatter minima.
The usual argument: SGD's noise and its lack of per-parameter adaptation mean it explores more broadly. For image classification with a well-tuned learning rate and schedule, SGD often slightly outperforms Adam on final accuracy.
Part 2 - AdaGrad: Per-Parameter Adaptive Learning Rates
The problem AdaGrad solves
Vanilla gradient descent uses the same learning rate α for all parameters. But parameters that appear frequently in training (features seen often) accumulate large gradients; parameters that appear rarely (rare features) accumulate small gradients. A uniform learning rate treats them identically.
AdaGrad adapts the learning rate per parameter based on the history of gradient magnitudes.
The algorithm
Maintain the running sum of squared gradients for each parameter (element-wise):
G_t = G_{t-1} + g_t²
Update:
θ_{t+1} = θ_t - α·g_t/(√G_t + ε)
Parameters with large historical gradients → large G_t → small effective learning rate. Parameters with small historical gradients → small G_t → large effective learning rate.
ML Connection: sparse gradients
AdaGrad was designed for NLP with sparse features (word counts, bag-of-words). Rare words have gradient only when they appear. AdaGrad accumulates small G for rare features → large effective lr → makes learning more democratic across vocabulary.
import numpy as np

class AdaGrad:
    """
    AdaGrad optimizer.
        G_t = G_{t-1} + g_t^2
        θ_{t+1} = θ_t - α * g_t / (sqrt(G_t) + ε)
    """
    def __init__(self, params: list, lr: float = 0.01, eps: float = 1e-8):
        self.params = params
        self.lr = lr
        self.eps = eps
        # Accumulated squared gradients for each parameter
        self.G = [np.zeros_like(p) for p in params]

    def step(self, grads: list):
        for i, (p, g) in enumerate(zip(self.params, grads)):
            # Accumulate squared gradient
            self.G[i] += g ** 2
            # Adaptive update (in place, so the caller's array is modified)
            p -= self.lr * g / (np.sqrt(self.G[i]) + self.eps)

# Demonstrate: AdaGrad on a sparse gradient problem
# Parameter 0: receives a gradient every step, Parameter 1: only every 5th step
np.random.seed(42)
theta = np.array([0.0, 0.0])       # starting point
true_theta = np.array([3.0, 3.0])
optimizer = AdaGrad([theta], lr=0.5)

print("AdaGrad on sparse problem:")
print(f"{'Step':>5} | {'θ[0]':>8} | {'θ[1]':>8} | {'G[0]':>10} | {'G[1]':>10}")
print("-" * 55)

for step in range(20):
    g = np.zeros(2)
    g[0] = 2 * (theta[0] - true_theta[0])        # frequent
    if step % 5 == 0:
        g[1] = 2 * (theta[1] - true_theta[1])    # rare
    # Use the optimizer itself: it accumulates G and updates theta in place
    optimizer.step([g])
    if step % 5 == 0:
        print(f"{step:>5} | {theta[0]:>8.4f} | {theta[1]:>8.4f} | "
              f"{optimizer.G[0][0]:>10.2f} | {optimizer.G[0][1]:>10.2f}")
AdaGrad's problem: learning rate death
Because G_t accumulates all historical squared gradients, it grows monotonically. The effective learning rate α/√G_t → 0 over time. For long training, AdaGrad's learning rate shrinks to near-zero, stopping learning.
This is AdaGrad's fundamental flaw. RMSProp fixes it.
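To see how fast the decay sets in, the sketch below tracks the effective learning rate α/(√G_t + ε) for a single parameter whose gradient magnitude stays constant. The function name `adagrad_effective_lr` and the values lr = 0.1 and |g| = 1 are illustrative assumptions, not taken from any library.

```python
import math


def adagrad_effective_lr(lr, grad_magnitude, n_steps, eps=1e-8):
    """Effective learning rate lr / (sqrt(G_t) + eps) at each step, for a constant |g|."""
    G = 0.0
    history = []
    for _ in range(n_steps):
        G += grad_magnitude ** 2          # the sum only ever grows
        history.append(lr / (math.sqrt(G) + eps))
    return history


rates = adagrad_effective_lr(lr=0.1, grad_magnitude=1.0, n_steps=10_000)
for t in (1, 100, 10_000):
    print(f"step {t:>6}: effective lr = {rates[t - 1]:.5f}")
```

With a constant gradient magnitude, G_t = t·g², so the effective rate falls as 1/√t: a 10x drop after 100 steps and a 100x drop after 10,000.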
Part 3 - RMSProp: Decaying Gradient History
The fix
Instead of summing all historical squared gradients, use an exponential moving average, which gradually forgets old history:
v_t = β₂·v_{t-1} + (1-β₂)·g_t²
θ_{t+1} = θ_t - α·g_t/(√v_t + ε)
With β₂ = 0.9 or 0.999: old squared gradients are gradually forgotten. The effective learning rate adapts to recent gradient magnitudes, not all-time history.
import numpy as np

class RMSProp:
    """
    RMSProp optimizer.
        v_t = β₂ * v_{t-1} + (1-β₂) * g_t^2
        θ_{t+1} = θ_t - α * g_t / (sqrt(v_t) + ε)
    """
    def __init__(self, params: list, lr: float = 0.001, beta2: float = 0.9, eps: float = 1e-8):
        self.params = params
        self.lr = lr
        self.beta2 = beta2
        self.eps = eps
        self.v = [np.zeros_like(p) for p in params]

    def step(self, grads: list):
        for i, (p, g) in enumerate(zip(self.params, grads)):
            # Exponential moving average of squared gradients
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g**2
            # Adaptive update
            p -= self.lr * g / (np.sqrt(self.v[i]) + self.eps)

# RMSProp was introduced by Geoffrey Hinton in his Coursera lecture slides (not a paper)
# It is the direct predecessor of Adam
Why RMSProp works
The denominator √v_t estimates the RMS (root mean square) of recent gradients. Dividing by it normalizes the update step to have roughly unit scale per parameter, regardless of the gradient magnitude. This makes the effective learning rate independent of gradient scale - crucial for deep networks where different layers can have very different gradient magnitudes.
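The scale-independence claim is easy to check numerically. The sketch below applies the RMSProp update to a constant scalar gradient at two very different magnitudes and records the size of each step. `rmsprop_step_sizes` is a hypothetical helper written for this illustration, using the defaults from the class above (lr = 0.001, β₂ = 0.9).

```python
import math


def rmsprop_step_sizes(grad, n_steps=100, lr=0.001, beta2=0.9, eps=1e-8):
    """Return the magnitude of each RMSProp update for a constant scalar gradient."""
    v = 0.0
    sizes = []
    for _ in range(n_steps):
        v = beta2 * v + (1.0 - beta2) * grad ** 2
        sizes.append(abs(lr * grad / (math.sqrt(v) + eps)))
    return sizes


small = rmsprop_step_sizes(grad=1e-3)
large = rmsprop_step_sizes(grad=1e3)

print(f"grad=1e-3: first step = {small[0]:.6f}, final step = {small[-1]:.6f}")
print(f"grad=1e+3: first step = {large[0]:.6f}, final step = {large[-1]:.6f}")
```

Both runs settle at a step of about lr even though the gradients differ by six orders of magnitude. The first steps are larger (about lr/√(1-β₂)) because v starts at zero and this RMSProp form has no bias correction, which is one of the things Adam adds.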
Part 4 - Adam: Combining Momentum and RMSProp
Adam (Adaptive Moment Estimation) combines:
- Momentum (1st moment): exponential moving average of gradients
- RMSProp (2nd moment): exponential moving average of squared gradients
- Bias correction: corrects for zero-initialized moments in early training
The Adam algorithm
Given mini-batch gradient g_t at step t:
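First moment (momentum):
m_t = β₁·m_{t-1} + (1-β₁)·g_t
Second moment (squared gradient EMA):
v_t = β₂·v_{t-1} + (1-β₂)·g_t²
Bias correction (because m₀ = v₀ = 0, estimates are biased toward zero early in training):
m̂_t = m_t/(1-β₁^t), v̂_t = v_t/(1-β₂^t)
Parameter update:
θ_{t+1} = θ_t - α·m̂_t/(√v̂_t + ε)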
Why bias correction matters
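At t=1, with m₀ = 0:
m₁ = β₁·m₀ + (1-β₁)·g₁ = (1-β₁)·g₁
With β₁ = 0.9: m₁ = 0.1 · g₁, which underestimates the gradient by 10x!
The bias-corrected estimate:
m̂₁ = m₁/(1-β₁¹) = 0.1·g₁/0.1 = g₁
This recovers g₁ exactly. As t grows, β^t → 0 and 1 - β^t → 1, so the correction fades. For β₁ = 0.9 it is negligible after a few dozen steps; for β₂ = 0.999 the factor 1/(1-β₂^t) is still about 1.58 at t = 1000 and only becomes negligible after several thousand steps.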
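The size of the correction at different points in training can be tabulated directly. The sketch below evaluates the multiplier 1/(1-β^t) for the default β₁ = 0.9 and β₂ = 0.999; `bias_correction_factor` is a name introduced here for illustration.

```python
def bias_correction_factor(beta, t):
    """Multiplier 1 / (1 - beta**t) that Adam applies to a zero-initialized EMA at step t."""
    return 1.0 / (1.0 - beta ** t)


for t in (1, 10, 100, 1000, 10_000):
    c1 = bias_correction_factor(0.9, t)
    c2 = bias_correction_factor(0.999, t)
    print(f"t={t:>6}: first-moment factor = {c1:8.4f}, second-moment factor = {c2:10.4f}")
```

The first-moment factor starts at 10 and is essentially 1 by step 100. The second-moment factor starts at 1000 and decays much more slowly, because β₂ is so close to 1.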
Complete Adam implementation
import numpy as np
from typing import List

class Adam:
    """
    Adam optimizer - Kingma & Ba (2015).
        m_t = β₁ * m_{t-1} + (1-β₁) * g_t
        v_t = β₂ * v_{t-1} + (1-β₂) * g_t^2
        m̂_t = m_t / (1 - β₁^t)
        v̂_t = v_t / (1 - β₂^t)
        θ_{t+1} = θ_t - α * m̂_t / (sqrt(v̂_t) + ε)
    """
    def __init__(
        self,
        params: List[np.ndarray],
        lr: float = 1e-3,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
    ):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.t = 0  # step counter
        # First and second moment estimates, initialized to zero
        self.m = [np.zeros_like(p) for p in params]
        self.v = [np.zeros_like(p) for p in params]

    def step(self, grads: List[np.ndarray]):
        self.t += 1
        for i, (p, g) in enumerate(zip(self.params, grads)):
            # Update biased first moment estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            # Update biased second moment estimate
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g**2
            # Bias correction
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            # Update parameters (in place)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# Demonstration: Adam vs plain gradient descent on a linear regression problem
def run_optimizer_comparison():
    np.random.seed(42)
    n, d = 500, 10
    X = np.random.randn(n, d)
    true_w = np.random.randn(d)
    y = X @ true_w + 0.1 * np.random.randn(n)

    def mse_loss(w): return np.mean((X @ w - y)**2)
    def mse_grad(w): return (2/n) * X.T @ (X @ w - y)

    # SGD (full-batch gradient here, so there is no sampling noise)
    w_sgd = np.zeros(d)
    losses_sgd = []
    for step in range(200):
        g = mse_grad(w_sgd)
        w_sgd -= 0.01 * g
        losses_sgd.append(mse_loss(w_sgd))

    # Adam
    w_adam = np.zeros(d)
    adam = Adam([w_adam], lr=0.01)
    losses_adam = []
    for step in range(200):
        g = mse_grad(w_adam)
        adam.step([g])
        losses_adam.append(mse_loss(w_adam))

    print(f"SGD  (200 steps): loss = {losses_sgd[-1]:.6f}")
    print(f"Adam (200 steps): loss = {losses_adam[-1]:.6f}")

    # Steps to reach within 0.01 of true minimum (≈0.01 for this setup)
    sgd_steps = next((i for i, l in enumerate(losses_sgd) if l < 0.02), 200)
    adam_steps = next((i for i, l in enumerate(losses_adam) if l < 0.02), 200)
    print(f"Steps to loss<0.02: SGD={sgd_steps}, Adam={adam_steps}")

run_optimizer_comparison()
Adam defaults
| Hyperparameter | Default | Effect |
|---|---|---|
| α (learning rate) | 1e-3 | Base step size; each parameter's update has magnitude of roughly α or less |
| β₁ (momentum) | 0.9 | Gradient EMA decay; higher = more history |
| β₂ (squared grad EMA) | 0.999 | Squared-gradient EMA decay; higher = smoother, slower-adapting scale estimate |
| ε (numerical stability) | 1e-8 | Prevents division by zero |
:::tip β₂ for transformer training
For transformers, β₂ = 0.98 or 0.95 (instead of 0.999) often works better. Very high β₂ means the second-moment estimate changes slowly - good for stable problems but slow to adapt when gradient scales change rapidly, as they can in transformer training. The LLaMA paper used β₂ = 0.95.
:::
Part 5 - AdamW: Fixing Weight Decay in Adam
The problem with Adam + L2 regularization
L2 regularization adds λ‖θ‖² to the loss. The gradient of L + λ‖θ‖² is g_t + 2λθ.
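In vanilla gradient descent, this results in:
θ_{t+1} = θ_t - α·(g_t + 2λθ_t) = (1 - 2αλ)·θ_t - α·g_t
The weight shrinks by a factor (1 - 2αλ) each step, independent of the gradient - well-defined weight decay. (PyTorch's weight_decay argument adds λθ rather than 2λθ to the gradient, which corresponds to the penalty (λ/2)‖θ‖²; the factor of 2 is simply absorbed into λ.)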
In Adam, the gradient g_t + 2λθ_t is passed to the adaptive update. The second moment v_t adapts to the combined gradient including the regularization term. The effective weight decay is not (1 - 2αλ) - it is modulated by 1/√v_t, which varies per parameter and over time. L2 regularization in Adam does not behave like true weight decay.
AdamW: the fix
AdamW (Loshchilov & Hutter, 2019) decouples weight decay from the adaptive gradient update:
θ_{t+1} = θ_t - α·m̂_t/(√v̂_t + ε) - α·λ·θ_t
The weight decay term -αλθ_t is applied directly to the parameter, not through the adaptive scaling. This produces true weight decay behavior regardless of the gradient scale.
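The difference shows up clearly in a toy experiment. The sketch below runs scalar Adam on a single weight whose loss gradient alternates between +s and -s, so the loss itself exerts no net pull and any systematic shrinkage comes from the decay term. The function `final_weight`, the alternating-gradient setup, and the hyperparameters (500 steps, lr = 1e-3, λ = 0.1) are all assumptions made for this illustration.

```python
import math


def final_weight(grad_scale, decoupled, steps=500, lr=1e-3, wd=0.1,
                 beta1=0.9, beta2=0.999, eps=1e-8):
    """Scalar Adam on a weight whose loss gradient alternates +/- grad_scale.

    decoupled=False: L2 style, wd * theta is added to the gradient.
    decoupled=True:  AdamW style, theta shrinks by lr * wd * theta directly.
    """
    theta, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_scale if t % 2 else -grad_scale
        if not decoupled:
            g += wd * theta                      # decay goes through the moments
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
        if decoupled:
            theta -= lr * wd * theta             # decay bypasses the moments
    return theta


for scale in (0.01, 10.0):
    l2 = final_weight(scale, decoupled=False)
    aw = final_weight(scale, decoupled=True)
    print(f"grad scale {scale:>5}: Adam+L2 -> {l2:.4f}, AdamW -> {aw:.4f}")
```

In this setup, Adam + L2 shrinks the small-gradient weight heavily and barely touches the large-gradient one, while AdamW shrinks both by the same amount. That gradient-scale dependence is what decoupling removes.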
AdamW implementation
import numpy as np
from typing import List

class AdamW:
    """
    AdamW optimizer - Loshchilov & Hutter (2019).
    Same as Adam but with decoupled weight decay:
        θ_{t+1} = θ_t - α * m̂_t / (sqrt(v̂_t) + ε) - α * λ * θ_t
    The weight decay term is NOT passed through the adaptive scaling.
    This is the de facto standard optimizer for transformer training.
    """
    def __init__(
        self,
        params: List[np.ndarray],
        lr: float = 1e-3,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-8,
        weight_decay: float = 0.01,  # λ - typical value for transformers
    ):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.t = 0
        self.m = [np.zeros_like(p) for p in params]
        self.v = [np.zeros_like(p) for p in params]

    def step(self, grads: List[np.ndarray]):
        self.t += 1
        alpha_t = self.lr  # could also apply schedule here
        for i, (p, g) in enumerate(zip(self.params, grads)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g**2
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            # Adaptive gradient step
            adaptive_step = alpha_t * m_hat / (np.sqrt(v_hat) + self.eps)
            # Decoupled weight decay (applied separately)
            decay_step = alpha_t * self.weight_decay * p
            p -= adaptive_step + decay_step

# PyTorch AdamW usage
import torch
import torch.nn as nn

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True),
    num_layers=4
)

# AdamW with typical transformer hyperparameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),   # β₂=0.95 instead of 0.999 for transformers
    eps=1e-8,
    weight_decay=0.1,    # typical for large language models
)
print("AdamW configured for transformer training.")
print("Parameters: lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1")
Part 6 - Optimizer Comparison
Mathematical summary
| Optimizer | Update rule | Key property |
|---|---|---|
| SGD | θ - α·g | Simplest; noisy for large learning rates |
| SGD + Momentum | θ - α·m (m = β₁m + (1-β₁)g) | Smooths oscillations; builds speed in consistent directions |
| AdaGrad | θ - α·g/(√G+ε) (G = Σg²) | Per-param lr; effective lr decays toward zero in long runs |
| RMSProp | θ - α·g/(√v+ε) (v = β₂v+(1-β₂)g²) | AdaGrad with forgetting; no dying |
| Adam | θ - α·m̂/(√v̂+ε) | Combines momentum + RMSProp; fast convergence |
| AdamW | θ - α·m̂/(√v̂+ε) - α·λ·θ | Adam + true weight decay; standard for transformers |
When to use each optimizer
| Use case | Recommended optimizer | Why |
|---|---|---|
| Image classification (ResNet, etc.) | SGD + momentum | Better final accuracy; weight decay works correctly |
| Transformer pretraining (BERT, GPT) | AdamW | Copes with widely varying gradient scales across layers; decoupled weight decay |
| Fine-tuning transformers | AdamW | Same as pretraining |
| RNNs / LSTMs | Adam or RMSProp | Handles sparse gradient dynamics |
| GANs | Adam (β₁=0.5) | Lower momentum tends to stabilize adversarial training (as in DCGAN) |
| Reinforcement learning | Adam or RMSProp | Fast adaptation to changing distributions |
| Linear/logistic regression | Adam or L-BFGS | Fast convergence for convex problems |
| Very large models (>10B params) | 8-bit AdamW (e.g. bitsandbytes) | Optimizer-state memory becomes critical |
import torch
import torch.nn as nn

# Side-by-side comparison on a simple problem
def train_with_optimizer(optimizer_name: str, n_steps: int = 200) -> list:
    torch.manual_seed(42)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    X = torch.randn(100, 10)
    y = torch.randn(100, 1)

    # Factories, so only the requested optimizer is actually constructed
    optimizers = {
        'SGD': lambda p: torch.optim.SGD(p, lr=0.1),
        'SGD+Mom': lambda p: torch.optim.SGD(p, lr=0.1, momentum=0.9),
        'AdaGrad': lambda p: torch.optim.Adagrad(p, lr=0.1),
        'RMSProp': lambda p: torch.optim.RMSprop(p, lr=0.01),
        'Adam': lambda p: torch.optim.Adam(p, lr=0.01),
        'AdamW': lambda p: torch.optim.AdamW(p, lr=0.01, weight_decay=0.01),
    }
    opt = optimizers[optimizer_name](model.parameters())

    losses = []
    for step in range(n_steps):
        opt.zero_grad()
        loss = ((model(X) - y)**2).mean()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

print(f"{'Optimizer':>12} | {'Initial loss':>13} | {'Final loss':>12} | {'Steps to <0.5':>14}")
print("-" * 60)
for name in ['SGD', 'SGD+Mom', 'AdaGrad', 'RMSProp', 'Adam', 'AdamW']:
    losses = train_with_optimizer(name)
    steps_to_threshold = next((i for i, l in enumerate(losses) if l < 0.5), 200)
    print(f"{name:>12} | {losses[0]:>13.4f} | {losses[-1]:>12.4f} | {steps_to_threshold:>14}")
Part 7 - Learning Rate Schedules in Practice
Warmup + cosine decay (the transformer standard)
import math

import torch

class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Linear warmup followed by cosine decay.
    Common for transformer pretraining (e.g. GPT-3, LLaMA).
    """
    def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr_ratio: float = 0.1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        super().__init__(optimizer)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup: 0 → 1
            scale = step / self.warmup_steps
        else:
            # Cosine decay: 1 → min_lr_ratio (progress is clamped so the lr stays at the floor past total_steps)
            progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, progress)
            scale = self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
        return [base_lr * scale for base_lr in self.base_lrs]

# Usage
model = torch.nn.Linear(100, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
total_steps = 10000
warmup_steps = 500
scheduler = WarmupCosineScheduler(optimizer, warmup_steps, total_steps, min_lr_ratio=0.1)

# Print schedule at key points
print("Warmup + Cosine schedule:")
for step in [0, 100, 500, 1000, 5000, 9999]:
    # Fast-forward the scheduler to the requested step
    for _ in range(step - scheduler.last_epoch):
        optimizer.step()   # no gradients yet, so this is a no-op; it keeps the expected call order
        scheduler.step()
    print(f"  Step {step:5d}: lr = {scheduler.get_last_lr()[0]:.2e}")
Cosine annealing with warm restarts (SGDR)
# PyTorch has this built-in
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=1000, # initial restart period
T_mult=2, # period doubles after each restart
eta_min=1e-6 # minimum learning rate
)
# Or the simpler version without restarts:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=10000, # total steps
eta_min=1e-6
)
One-cycle learning rate policy
Proposed by Leslie Smith and popularized by fast.ai. The learning rate rises from a low value to a maximum, then anneals back down to a value well below where it started:
# PyTorch's OneCycleLR
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=3e-4,
total_steps=10000,
pct_start=0.3, # 30% of steps for warmup
anneal_strategy='cos',
div_factor=25, # initial_lr = max_lr / 25
final_div_factor=1e4, # final_lr = initial_lr / 1e4
)
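To show the shape these arguments produce, here is a simplified pure-Python sketch of a one-cycle curve with cosine interpolation in both phases. It is an approximation written for this lesson: the function `one_cycle_lr` is hypothetical, it does not reproduce PyTorch's exact phase boundaries, and it ignores the momentum cycling that `OneCycleLR` can also perform.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3,
                 div_factor=25.0, final_div_factor=1e4):
    """Simplified one-cycle schedule: rise to max_lr, then anneal far below the start."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_steps = pct_start * total_steps

    def cosine_interp(start, end, frac):
        # frac = 0 returns start, frac = 1 returns end
        return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * frac))

    if step < warmup_steps:
        return cosine_interp(initial_lr, max_lr, step / warmup_steps)
    frac = (step - warmup_steps) / max(1.0, total_steps - warmup_steps)
    return cosine_interp(max_lr, min_lr, min(1.0, frac))


for s in (0, 1500, 3000, 6500, 10000):
    print(f"step {s:>5}: lr = {one_cycle_lr(s, 10000, 3e-4):.3e}")
```

With max_lr = 3e-4 and the values above, the curve starts at max_lr/25, peaks at 30% of training, and ends four orders of magnitude below its starting value.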
Part 8 - Gradient Clipping: The Math and Practice
Global norm clipping
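The standard gradient clipping approach: if the total gradient norm exceeds max_norm, scale all gradients by the same factor so that the norm equals max_norm; otherwise leave them unchanged:
g ← g · min(1, max_norm / ‖g‖₂)
where g is the concatenated gradient vector across all parameters.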
This preserves the relative direction of all parameter gradients while bounding the total update magnitude.
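Before turning to the PyTorch utility, the rule can be written out in a few lines of plain Python. The sketch below treats each parameter's gradient as a list of floats; `clip_by_global_norm` is a name made up for this example, and the small constant in the denominator is only there to guard against division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale per-parameter gradient lists so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, original_norm).
    """
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [[x * scale for x in g] for g in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]     # two "parameters"; joint norm is 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(x * x for g in clipped for x in g))

print(f"norm before: {norm_before:.4f}, norm after: {norm_after:.4f}")
print(f"clipped gradients: {clipped}")
```

Every component is multiplied by the same factor, so the 3:4 ratio between the two parameters' gradients survives clipping. Gradients whose norm is already below max_norm pass through unchanged.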
import torch
import torch.nn as nn

def train_with_gradient_clipping(model, optimizer, X, y, max_norm=1.0):
    """Training step with gradient clipping."""
    optimizer.zero_grad()
    output = model(X)
    loss = nn.MSELoss()(output, y)
    loss.backward()

    # Compute gradient norm BEFORE clipping (for monitoring)
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            total_norm += param_norm.item() ** 2
    total_norm = total_norm ** 0.5

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

    # Compute norm AFTER clipping
    clipped_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            clipped_norm += p.grad.data.norm(2).item() ** 2
    clipped_norm = clipped_norm ** 0.5

    optimizer.step()
    return loss.item(), total_norm, clipped_norm

# Example
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
torch.manual_seed(0)
X = torch.randn(32, 10)
y = torch.randn(32, 1)
loss, before_clip, after_clip = train_with_gradient_clipping(model, optimizer, X, y, max_norm=1.0)
print(f"Loss: {loss:.4f}")
print(f"Gradient norm before clipping: {before_clip:.4f}")
print(f"Gradient norm after clipping: {after_clip:.4f} (at most max_norm=1.0)")
When to clip and what max_norm to use
import numpy as np
import torch
import torch.nn as nn

# Typical max_norm values by task
clipping_guidelines = {
    'Language model (GPT, BERT)': 1.0,
    'LSTM/RNN': 5.0,
    'Vision transformer': 1.0,
    'Diffusion model': 1.0,
    'ResNet (usually not needed)': None,
}

# How to choose max_norm:
# 1. Train for a few steps, monitor gradient norms
# 2. max_norm should be >= median gradient norm but < maximum spike

# Example: monitoring gradient norms during training
def train_and_monitor(model, n_steps=100):
    optimizer = torch.optim.AdamW(model.parameters())
    norms = []
    for step in range(n_steps):
        X = torch.randn(32, 10)
        y = torch.randn(32, 1)
        optimizer.zero_grad()
        loss = nn.MSELoss()(model(X), y)
        loss.backward()
        # Measure norm before clipping
        total_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=float('inf')  # infinity = no clipping, just measure
        )
        norms.append(total_norm.item())
        optimizer.step()

    norms = np.array(norms)
    print(f"Gradient norm stats: mean={norms.mean():.4f}, "
          f"p90={np.percentile(norms, 90):.4f}, "
          f"max={norms.max():.4f}")
    print(f"Recommended max_norm ≈ {np.percentile(norms, 90):.2f} (90th percentile)")
Part 9 - 8-Bit Adam and Memory-Efficient Optimizers
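For models with billions of parameters, Adam's optimizer states (m and v) take 2 × the model size in memory: 8 bytes per parameter in float32, on top of the 4 bytes per parameter for the float32 weights themselves. Weights plus optimizer state therefore need roughly three times the memory of the weights alone, before counting gradients and activations.
8-bit Adam (Dettmers et al.) quantizes optimizer states to 8 bits per value, shrinking optimizer-state memory by about 4x while maintaining training stability: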
# Using bitsandbytes for 8-bit Adam
# pip install bitsandbytes
try:
    import bitsandbytes as bnb
    optimizer_8bit = bnb.optim.Adam8bit(model.parameters(), lr=3e-4)
    print("8-bit Adam available")
except ImportError:
    print("Install bitsandbytes for 8-bit Adam")

# Memory savings: for a 7B parameter model
# Standard AdamW: 7B params × 3 (model+m+v) × 4 bytes = 84 GB
# 8-bit AdamW: 7B params × (4 + 1 + 1) bytes = 42 GB (saves 42 GB)
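The arithmetic in those comments generalizes to any model size. The sketch below is a back-of-the-envelope calculator, assuming float32 weights, two moment buffers per parameter, and 1 GB = 1e9 bytes. The function `training_state_gb` is hypothetical, and the estimate deliberately ignores gradients, activations, and the small per-block metadata that quantized states carry.

```python
def training_state_gb(n_params, weight_bytes=4, state_bytes_per_moment=4):
    """Memory for weights plus Adam's two moment buffers, in GB (1 GB = 1e9 bytes)."""
    total_bytes = n_params * (weight_bytes + 2 * state_bytes_per_moment)
    return total_bytes / 1e9


n = 7_000_000_000
fp32_states = training_state_gb(n)                              # 4 + 4 + 4 bytes per parameter
int8_states = training_state_gb(n, state_bytes_per_moment=1)    # 4 + 1 + 1 bytes per parameter

print(f"fp32 optimizer states:  {fp32_states:.0f} GB")
print(f"8-bit optimizer states: {int8_states:.0f} GB")
print(f"saved:                  {fp32_states - int8_states:.0f} GB")
```

Only the optimizer states shrink by 4x; total memory for weights plus states halves in this example because the float32 weights are untouched.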
Part 10 - Production Training Recipes
Complete training loop with all best practices
import math

import torch
import torch.nn as nn

def create_optimizer_and_scheduler(
    model: nn.Module,
    learning_rate: float = 3e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 1000,
    total_steps: int = 100000,
    betas: tuple = (0.9, 0.95),
) -> tuple:
    """
    Production-grade optimizer setup for transformer training.
    Key choices:
    - AdamW with decoupled weight decay
    - No weight decay for biases and LayerNorm parameters
    - Warmup + cosine schedule
    - Gradient clipping in the training loop
    """
    # Separate parameters: apply weight decay to weights but NOT biases/norms
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Don't decay bias terms or normalization parameters
        if 'bias' in name or 'norm' in name or param.ndim == 1:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    param_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=learning_rate,
        betas=betas,
        eps=1e-8,
    )

    # Warmup + cosine decay
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay to 10%

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler

def training_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    batch: tuple,
    max_grad_norm: float = 1.0,
    accumulation_steps: int = 1,
    current_accumulation: int = 0,
) -> dict:
    """
    Single training step with gradient clipping and accumulation.
    Returns metrics dict.
    """
    X, y = batch
    # Forward pass
    output = model(X)
    loss = nn.CrossEntropyLoss()(output, y)
    # Scale loss for gradient accumulation
    loss = loss / accumulation_steps
    loss.backward()
    metrics = {'loss': loss.item() * accumulation_steps}

    # Only update on accumulation boundary
    if (current_accumulation + 1) % accumulation_steps == 0:
        # Gradient clipping
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        metrics['grad_norm'] = grad_norm.item()
        # Optimizer step
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        metrics['lr'] = scheduler.get_last_lr()[0]
    return metrics
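The reason for dividing the loss by accumulation_steps can be verified on a tiny example. The sketch below uses a one-parameter linear model with a mean-squared-error loss and made-up data, and checks that summing the scaled micro-batch gradients reproduces the full-batch gradient. It assumes equally sized micro-batches; all names and values here are illustrative.

```python
def mse_grad(w, xs, ys):
    """Gradient with respect to w of mean((w*x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 3.9, 6.1, 8.0, 9.8, 12.1, 14.0, 16.2]
w = 0.5

full_batch_grad = mse_grad(w, xs, ys)

accumulation_steps = 4
micro = len(xs) // accumulation_steps
accumulated = 0.0
for i in range(accumulation_steps):
    chunk_x = xs[i * micro:(i + 1) * micro]
    chunk_y = ys[i * micro:(i + 1) * micro]
    # Dividing by accumulation_steps mirrors `loss = loss / accumulation_steps`
    accumulated += mse_grad(w, chunk_x, chunk_y) / accumulation_steps

print(f"full-batch gradient:  {full_batch_grad:.6f}")
print(f"accumulated gradient: {accumulated:.6f}")
```

Without the division, the accumulated gradient would be accumulation_steps times too large, which behaves like silently multiplying the learning rate (for SGD) or distorting the clipping threshold.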
Part 11 - Common Mistakes
:::danger Using Adam with L2 regularization instead of AdamW
torch.optim.Adam(params, weight_decay=0.01) does NOT implement true weight decay - it adds the L2 gradient to the adaptive update, which modulates the decay by the second moment estimate. This is L2 regularization, not decoupled weight decay, and it is usually not what you want with Adam. Use torch.optim.AdamW for true decoupled weight decay.
# WRONG for weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.01)
# RIGHT for weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
:::
:::warning Applying weight decay to all parameters
Biases and normalization layer parameters (LayerNorm, BatchNorm) are conventionally excluded from weight decay. Decaying them pulls normalization gains and offsets toward zero and can hurt training, particularly in transformers.
# Correct: no weight decay for biases and norm params
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=3e-4)
:::
:::warning Not resetting optimizer state when resuming training with a new learning rate
If you resume training with a significantly different learning rate, Adam's moment estimates (m_t and v_t) restored from the checkpoint were accumulated along the old trajectory and may no longer match the gradients the new learning rate produces. Because the step counter t is restored as well, bias correction does not compensate for this, and the first updates after resuming can be badly scaled.
Solution: when changing the learning rate significantly, consider resetting optimizer state (for example, by creating a fresh optimizer instead of loading the old state_dict) or starting with a brief warmup.
:::
:::tip Monitor gradient norm during training
If you do not monitor gradient norms, you cannot know if clipping is being triggered too frequently (threshold too low) or never (threshold might be too high for the instability you are experiencing).
# Add to training loop:
if step % 100 == 0:
    total_norm = sum(p.grad.data.norm(2).item()**2 for p in model.parameters()
                     if p.grad is not None) ** 0.5
    print(f"Step {step}: grad_norm={total_norm:.4f}")
:::
Interview Questions
Q1: Explain Adam's update rule step by step. What does each term do?
Adam maintains estimates of two statistics about the gradient history:
m_t = β₁·m_{t-1} + (1-β₁)·g_t - First moment (exponential moving average of gradients). Functions as momentum: accumulates gradient direction, smoothing out oscillations. β₁ = 0.9 is standard.
v_t = β₂·v_{t-1} + (1-β₂)·g_t² - Second moment (EMA of squared gradients). Estimates the uncentered variance (mean square) of each parameter's gradient. Large v_t → parameter has large recent gradients → smaller effective learning rate. β₂ = 0.999 is standard.
m̂_t = m_t/(1-β₁^t), v̂_t = v_t/(1-β₂^t) - Bias correction. Since m₀ = v₀ = 0, early estimates are biased toward zero. Dividing by (1-β^t) corrects this. The effect on m̂ is negligible after a few dozen steps (β₁ = 0.9); on v̂ it takes a few thousand steps to vanish (β₂ = 0.999).
θ_{t+1} = θ_t - α·m̂_t/(√v̂_t + ε) - Update. The effective learning rate per parameter is α/√v̂_t: parameters with large gradient history get smaller updates (already well-estimated), parameters with small gradient history get larger updates (need more correction).
Why it works: Adam adapts the learning rate per parameter based on the relative magnitudes of first and second gradient moments. Parameters with large gradient magnitudes receive smaller effective learning rates, and vice versa. This is often interpreted as a rough diagonal preconditioner, loosely analogous to a diagonal approximation of Newton's method.
Q2: What is the difference between Adam and AdamW?
Adam with weight_decay: Adds λ·θ to the gradient before the adaptive update. The update becomes:
θ_{t+1} = θ_t - α·m̂(g + λθ) / (√v̂(g + λθ) + ε)
The weight decay term λθ is modulated by 1/√v̂ - different per parameter and changing over time. The actual regularization effect depends on the gradient scale of each parameter. This is not true weight decay.
AdamW applies weight decay directly to the parameter, after the adaptive gradient step:
θ_{t+1} = θ_t - α·m̂_t/(√v̂_t + ε) - α·λ·θ_t
The decay term -α·λ·θ_t is a simple proportional shrinkage, independent of the adaptive gradient scaling. This IS true weight decay - equivalent, up to a rescaling of λ, to L2 regularization in plain SGD.
Practical impact: AdamW often generalizes better than Adam + L2, and it is the standard choice for transformers. The original BERT code used a custom Adam optimizer with decoupled weight decay, in the same spirit as AdamW. The Hugging Face transformers Trainer also defaults to an AdamW optimizer.
Q3: Why does AdaGrad "die" and how does RMSProp fix it?
AdaGrad accumulates all historical squared gradients: G_t = Σ_{k=1}^{t} g_k². This sum only grows - it never decreases. As t → ∞, G_t → ∞, so the effective learning rate α/√G_t → 0.
For sparse NLP tasks (short training), this is fine - G_t does not get too large. For dense tasks with long training (deep learning), G_t grows without bound, and learning effectively stops after thousands of steps.
RMSProp fixes this with exponential forgetting: v_t = β₂·v_{t-1} + (1-β₂)·g_t². With β₂ = 0.9:
- The most recent squared gradient contributes with weight (1-β₂) = 0.1
- The gradient from 10 steps ago contributes (1-β₂)·β₂^10 ≈ 0.035
- The gradient from 100 steps ago contributes (1-β₂)·β₂^100 ≈ 2.7e-6 (essentially zero)
The effective memory window is ~1/(1-β₂) = 10 steps. v_t reaches a steady state (RMS of recent gradients) instead of growing without bound. The effective learning rate stabilizes rather than dying.
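These weights follow directly from unrolling the EMA recursion, and a few lines of Python are enough to check them. The helper `ema_weight` below is a name introduced for this illustration.

```python
def ema_weight(beta, k):
    """Weight that an EMA with decay beta assigns to the value from k steps ago."""
    return (1.0 - beta) * beta ** k


for k in (0, 10, 100):
    print(f"k={k:>3}: weight = {ema_weight(0.9, k):.3e}")

# Share of the total weight carried by the most recent 1/(1-beta) = 10 steps
recent_mass = sum(ema_weight(0.9, k) for k in range(10))
print(f"weight on the last 10 steps: {recent_mass:.3f}")
```

The last 10 steps carry 1 - 0.9^10 ≈ 65% of the total weight, which is the sense in which 1/(1-β₂) acts as an effective window rather than a hard cutoff.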
Q4: When should you use SGD with momentum instead of Adam, and why?
Use SGD + momentum for computer vision (image classification, object detection, segmentation):
- Better final accuracy: SGD often achieves slightly higher accuracy on ImageNet-scale tasks. Adam converges faster initially but SGD with a well-tuned learning rate schedule often surpasses it at convergence.
- Better generalization: SGD's noise and lack of adaptive per-parameter scaling means it explores more broadly, finding flatter minima that generalize better. Adam's adaptive learning rate can cause it to converge to sharper minima.
- L2 regularization works correctly: SGD + weight_decay applies true L2 regularization. For vision tasks with strong augmentation, this regularization scheme works reliably.
Use Adam/AdamW for:
- Transformer training (NLP, vision transformers, multimodal)
- Any architecture with sparse gradients (rare features, sparse attention)
- Faster initial convergence when training time is constrained
- Problems with widely varying gradient scales across layers
Empirical rule: For ResNet-50 on ImageNet, SGD + momentum with a good schedule reaches roughly 76% top-1, and Adam-family optimizers with comparable tuning are often reported a point or two lower. For BERT fine-tuning, AdamW is standard; plain SGD is much harder to get to converge reliably.
Q5: Explain gradient clipping mathematically and when to set max_norm.
Global norm clipping: Given the gradient vector g (concatenation of all parameter gradients), if ‖g‖₂ > max_norm:
g ← g · (max_norm / ‖g‖₂)
This scales all gradient components down proportionally, preserving direction but bounding magnitude. If ‖g‖₂ ≤ max_norm, the gradient is left unchanged.
Why global norm, not per-parameter: Per-parameter clipping would distort the relative direction of different parameter gradients, changing the optimization trajectory. Global norm clipping preserves direction.
When to use it:
- RNNs/LSTMs: Backprop through time multiplies many Jacobians, causing exponential gradient explosion. Max_norm = 5.0 is standard.
- Transformers: Max_norm = 1.0 is the de facto standard (used in GPT-2, GPT-3, BERT, LLaMA).
- Early training: Any model can have large initial gradients due to random initialization.
How to set max_norm:
- Train for 100-200 steps with max_norm = ∞ (no clipping), record gradient norms
- Set max_norm ≈ 90th percentile of these norms
- If training is stable, gradient clipping is rarely active - which is fine (it is a safety net, not a constant intervention)
- If clipping triggers on more than ~50% of steps, the threshold is probably too low for this model: raise max_norm or reduce the learning rate
Quick Reference
| Optimizer | Key hyperparameters | Best for |
|---|---|---|
| SGD | lr, momentum=0.9 | Vision, when final accuracy matters most |
| Adam | lr=1e-3, β₁=0.9, β₂=0.999 | General purpose, NLP, sparse gradients |
| AdamW | lr=3e-4, β₁=0.9, β₂=0.95, wd=0.01-0.1 | Transformers, language models |
| RMSProp | lr=1e-3, β₂=0.9 | RNNs, RL |
| AdaGrad | lr=0.01 | Very sparse features, short training |
| Schedule | Use case |
|---|---|
| Cosine annealing | Standard for deep learning |
| Warmup + cosine | Transformer training (standard practice) |
| OneCycleLR | Fast training cycles, fast.ai |
| Step decay | Classical ML, simple baselines |
Key Takeaways
- SGD with momentum smooths gradient oscillations; adaptive methods (Adam, AdamW) additionally adapt learning rates per parameter
- AdaGrad accumulates all squared gradients (learning rate dies); RMSProp fixes this with exponential forgetting
- Adam combines momentum (1st moment) and RMSProp (2nd moment) with bias correction for the initialization
- AdamW decouples weight decay from the adaptive gradient update - essential for correct regularization in transformer training; use it instead of Adam + L2
- SGD + momentum often generalizes better on vision tasks; AdamW is the standard for transformers
- Use warmup for transformer training - Adam's moment estimates are noisy in the first steps, so early updates can be badly scaled
- Gradient clipping (max_norm = 1.0 for transformers, 5.0 for RNNs) prevents catastrophic updates; monitor gradient norms during training
- Do not apply weight decay to biases and normalization parameters
Module Complete: Calculus and Optimization for Machine Learning
