Adam Optimizer - The Default That Deserves Understanding
Reading time: ~30 min | Interview relevance: High | Roles: MLE, Research Engineer, Data Scientist
The Real Interview Moment
You're 40 minutes into a research engineer interview at a major AI lab. The interviewer has been probing your understanding of training dynamics, and then asks: "You're training a large transformer model and notice that loss plateaus after 10K steps. You're currently using Adam with default hyperparameters. Walk me through exactly what Adam is doing under the hood, why it might plateau, and what modifications you'd try. Also - can you explain why people sometimes switch to SGD with momentum for the final phase of training?"
You've used torch.optim.Adam a thousand times. You know it "just works." But can you derive the update rule from scratch? Can you explain why bias correction matters in the first few steps? Do you know when AdamW is actually critical versus a nice-to-have? The interviewer isn't testing whether you can call optimizer.step() - they're testing whether you understand the optimizer well enough to debug training when it breaks.
This is the optimizer interview. It separates practitioners who blindly use defaults from engineers who can reason about why training succeeds or fails.
What You Will Master
After reading this page, you will be able to:
- Derive Adam from first principles, starting from vanilla SGD
- Explain momentum and RMSprop and why Adam combines both
- Write out the full Adam update rule with bias correction and explain each term
- Distinguish AdamW from Adam with L2 regularization and explain when it matters
- Implement Adam from scratch in Python
- Diagnose common optimizer-related training failures
- Argue convincingly for Adam vs SGD in different scenarios
- Answer every common interview question about optimizers
Part 1 - The Problem Adam Solves
Why Vanilla SGD Is Not Enough
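Stochastic gradient descent with a fixed learning rate $\eta$ has fundamental limitations:
$$\theta_t = \theta_{t-1} - \eta\, g_t$$
where $g_t = \nabla_\theta \mathcal{L}(\theta_{t-1})$ is the (mini-batch) gradient at step $t$.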
Never say "Adam is just better than SGD." In many settings (e.g., image classification with ResNets), well-tuned SGD with momentum outperforms Adam on final test accuracy. The advantage of Adam is faster convergence and robustness to hyperparameter choices, not universally better results.
Three problems with vanilla SGD:
| Problem | Description | Example |
|---|---|---|
| Noisy gradients | Mini-batch gradients are noisy estimates of the true gradient | Loss oscillates wildly, slow convergence |
| Uniform learning rate | Same step size for all parameters, regardless of gradient magnitude | Rare features get tiny effective updates; frequent features dominate |
| Saddle points & ravines | Gets stuck oscillating in steep directions while making slow progress in flat directions | Loss surface has very different curvatures along different axes |
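The "ravines" row is easiest to see on a toy quadratic. The sketch below makes some illustrative assumptions: `sgd_on_ravine` is a hypothetical helper, and the curvatures 100 and 1 are arbitrary choices. It runs vanilla SGD on $f(x, y) = \tfrac{1}{2}(100x^2 + y^2)$. Because one step size is shared by both coordinates, the steep direction caps the learning rate. On this quadratic, gradient descent diverges once $\eta > 2/100$, and at that small $\eta$ the flat direction barely moves.

```python
import numpy as np

def sgd_on_ravine(lr, steps=100, curvatures=(100.0, 1.0)):
    """Run vanilla SGD on f(x, y) = 0.5 * (100 * x**2 + y**2) starting from (1, 1).

    The x-direction is steep (curvature 100) and the y-direction is flat (curvature 1).
    Returns the trajectory as an array of shape (steps + 1, 2).
    """
    h = np.array(curvatures)
    theta = np.array([1.0, 1.0])
    path = [theta.copy()]
    for _ in range(steps):
        grad = h * theta            # exact gradient of the quadratic
        theta = theta - lr * grad   # one step size shared by both coordinates
        path.append(theta.copy())
    return np.array(path)

path = sgd_on_ravine(lr=0.019)      # just below the stability limit 2 / 100 = 0.02
print("x (steep):", path[:5, 0])    # sign flips every step: bouncing across the ravine
print("y (flat): ", path[-1, 1])    # after 100 steps, still far from the minimum at 0
```

Raising `lr` to 0.021 makes the x-coordinate blow up instead of converging. This is the trade-off that per-parameter step sizes are meant to remove.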
Part 2 - Building Blocks: Momentum and RMSprop
Momentum (Polyak, 1964)
Momentum smooths gradients by maintaining an exponential moving average (EMA) of past gradients:
$$m_t = \beta\, m_{t-1} + (1 - \beta)\, g_t, \qquad \theta_t = \theta_{t-1} - \eta\, m_t$$
Intuition: Think of a ball rolling downhill. Momentum accumulates velocity in consistent directions and dampens oscillations in inconsistent directions.
- Typical value: $\beta = 0.9$ (averages roughly over the last $1/(1-\beta) = 10$ gradients)
- Effect: Smoother trajectory, faster convergence through flat regions, reduced oscillation
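A short sketch can confirm the averaging behavior described above. The helper `ema` is hypothetical and implements only the $m_t$ recursion, starting from $m_0 = 0$. It is fed two illustrative gradient streams: one that always points the same way and one that flips sign every step.

```python
import numpy as np

def ema(grads, beta=0.9):
    """Momentum's running average m_t = beta * m_{t-1} + (1 - beta) * g_t, with m_0 = 0."""
    m = 0.0
    history = []
    for g in grads:
        m = beta * m + (1 - beta) * g
        history.append(m)
    return np.array(history)

consistent = ema([1.0] * 50)            # gradient always points the same way
oscillating = ema([1.0, -1.0] * 25)     # gradient flips sign every step
print(consistent[-1])                   # builds up toward the gradient value 1.0
print(np.abs(oscillating[-10:]).max())  # stays near 0.05: the flips largely cancel
```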
"Momentum maintains a running average of past gradients. When gradients consistently point in the same direction, momentum accelerates updates. When gradients oscillate, momentum cancels out the noise. It's like a heavy ball rolling downhill - it builds speed on consistent slopes and isn't disturbed by small bumps."
RMSprop (Hinton, unpublished lecture notes, 2012)
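RMSprop adapts the learning rate per parameter based on the magnitude of recent gradients:
$$v_t = \beta\, v_{t-1} + (1 - \beta)\, g_t^2, \qquad \theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{v_t} + \epsilon}\, g_t$$
Intuition: Parameters with large gradients get smaller effective learning rates. Parameters with small gradients get larger effective learning rates.
- Typical value: $\beta = 0.9$
- $\epsilon$ prevents division by zero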
- Effect: Handles parameters at very different scales, useful for sparse gradients
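The per-parameter scaling can be sketched in a few lines. This is a minimal illustration, not a library implementation. `rmsprop_update_sizes` is a hypothetical helper, and the gradient is held fixed so the effect is easy to read. The two parameters' gradients differ by a factor of 10,000. Under plain SGD their steps would differ by the same factor, while RMSprop steps end up nearly equal.

```python
import numpy as np

def rmsprop_update_sizes(grad, steps=200, lr=1e-3, beta=0.9, eps=1e-8):
    """Apply RMSprop to a fixed gradient vector and record |update| per parameter at each step."""
    g = np.asarray(grad, dtype=float)   # gradient held constant for clarity
    v = np.zeros_like(g)
    sizes = []
    for _ in range(steps):
        v = beta * v + (1 - beta) * g**2             # per-parameter running mean of g^2
        sizes.append(np.abs(lr * g / (np.sqrt(v) + eps)))
    return np.array(sizes)

grad = np.array([100.0, 0.01])          # hypothetical: one large-gradient, one small-gradient parameter
print("SGD step sizes:    ", 1e-3 * np.abs(grad))            # 0.1 vs 1e-5: a 10,000x spread
print("RMSprop step sizes:", rmsprop_update_sizes(grad)[-1])  # both close to lr = 1e-3
```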
Comparison of Building Blocks
| Property | SGD | SGD + Momentum | RMSprop | Adam |
|---|---|---|---|---|
| Gradient smoothing | None | Yes (1st moment) | None | Yes (1st moment) |
| Per-parameter LR | No | No | Yes (2nd moment) | Yes (2nd moment) |
| Bias correction | N/A | No | No | Yes |
| Memory overhead | 0 | 1x params | 1x params | 2x params |
| Typical use | Research baselines | Image classification | RNNs (historical) | Transformers, default |
Part 3 - The Adam Algorithm
Full Derivation
Adam (Adaptive Moment Estimation) combines momentum and RMSprop with a critical addition: bias correction.
Algorithm (Kingma & Ba, 2015):
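Initialize: $m_0 = 0$, $v_0 = 0$, $t = 0$
For each step $t = 1, 2, \dots$:
- $g_t = \nabla_\theta \mathcal{L}(\theta_{t-1})$ - compute gradient
- $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$ - update first moment (mean)
- $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$ - update second moment (uncentered variance)
- $\hat{m}_t = m_t / (1 - \beta_1^t)$ - bias-corrected first moment
- $\hat{v}_t = v_t / (1 - \beta_2^t)$ - bias-corrected second moment
- $\theta_t = \theta_{t-1} - \eta\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$ - update parameters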
Why Bias Correction Matters
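Since $m_0 = 0$ and $v_0 = 0$, the moment estimates are biased toward zero in early steps.
At step 1:
$$m_1 = \beta_1 m_0 + (1 - \beta_1)\, g_1 = (1 - \beta_1)\, g_1$$
With $\beta_1 = 0.9$, $m_1 = 0.1\, g_1$ - the moment estimate is 10x smaller than the actual gradient!
The bias correction factor $1/(1 - \beta_1^t)$ compensates for this. At $t = 1$: $1 - \beta_1 = 0.1$, so $\hat{m}_1 = m_1 / 0.1 = g_1$. As $t \to \infty$, $\beta_1^t \to 0$, and the correction becomes negligible.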
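Many candidates can state the Adam update rule but cannot explain why bias correction exists. The interviewer will specifically ask about those division terms. If you just say "it corrects for initialization," expect a follow-up, so be ready with the math. Without correction, both moment estimates are biased toward zero, but the second moment is biased far more, because $\beta_2 = 0.999$ makes its bias persist for thousands of steps. The uncorrected ratio $m_t / \sqrt{v_t}$ is therefore inflated. At $t = 1$ it equals $(1 - \beta_1)/\sqrt{1 - \beta_2} \approx 3.16$ times the bias-corrected value, so early updates come out too large, not too small.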
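For the second moment with $\beta_2 = 0.999$, bias correction is even more critical. At step $t = 1$:
$$v_1 = (1 - \beta_2)\, g_1^2 = 0.001\, g_1^2$$
which is 1000x smaller than $g_1^2$.
The correction factor $1/(1 - \beta_2^t)$ doesn't approach 1 until $t$ reaches several thousand (at $t = 1000$, $1 - \beta_2^t \approx 0.63$). This means that without correction, the variance estimate is severely wrong for the first thousand steps.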
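The arithmetic above can be checked numerically. Below is a minimal sketch in which `first_step` is a hypothetical helper written for this illustration. It computes the size of Adam's very first update for a scalar gradient, with and without bias correction, and prints $1 - \beta_2^t$ for a few step counts.

```python
import math

def first_step(g, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, bias_correction=True):
    """Adam's update at t = 1 for a scalar gradient g, starting from m_0 = v_0 = 0."""
    m = (1 - beta1) * g        # m_1 = (1 - beta1) * g
    v = (1 - beta2) * g * g    # v_1 = (1 - beta2) * g^2
    if bias_correction:
        m = m / (1 - beta1)    # m_hat_1 = g
        v = v / (1 - beta2)    # v_hat_1 = g^2
    return lr * m / (math.sqrt(v) + eps)

for g in (1e-4, 1.0, 1e4):
    print(f"g={g:g}: corrected={first_step(g):.2e}, "
          f"uncorrected={first_step(g, bias_correction=False):.2e}")

# Denominator of the second-moment correction, 1 - beta2^t, for beta2 = 0.999
for t in (1, 10, 100, 1000, 5000):
    print(f"t={t}: 1 - beta2^t = {1 - 0.999 ** t:.3f}")
```

With correction, the first step is about `lr` regardless of the gradient's scale. Without it, the step is roughly $\sqrt{10} \approx 3.16$ times larger, because $v_1$ is shrunk 1000x while $m_1$ is shrunk only 10x. The printed factors climb from 0.001 at $t = 1$ to about 0.63 at $t = 1000$.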
Default Hyperparameters
| Hyperparameter | Default | Meaning | Sensitivity |
|---|---|---|---|
| $\eta$ (learning rate) | $10^{-3}$ | Base step size | Very high - most important to tune |
| $\beta_1$ | 0.9 | Momentum decay | Low - rarely changed |
| $\beta_2$ | 0.999 | Variance decay | Medium - sometimes 0.98 or 0.95 for unstable training |
| $\epsilon$ | $10^{-8}$ | Numerical stability | Very low - almost never changed |
Part 4 - Adam from Scratch
```python
import numpy as np

class Adam:
    """Adam optimizer implemented from scratch."""

    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        # params: list of float NumPy arrays; step() updates them in place
        self.params = list(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.t = 0
        # Initialize moment estimates at zero (the source of the bias corrected below)
        self.m = [np.zeros_like(p) for p in self.params]  # First moment
        self.v = [np.zeros_like(p) for p in self.params]  # Second moment

    def step(self, grads):
        """Perform one optimization step."""
        self.t += 1
        for i, (param, grad) in enumerate(zip(self.params, grads)):
            # Update biased first moment estimate
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
            # Update biased second moment estimate
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * grad**2
            # Bias correction
            m_hat = self.m[i] / (1 - self.beta1**self.t)
            v_hat = self.v[i] / (1 - self.beta2**self.t)
            # Update parameters in place (param is the same array object as self.params[i])
            param -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
        return self.params
```
PyTorch equivalent:
```python
import torch

# Assumes `model` (where model(batch) returns a scalar loss) and `dataloader` are defined elsewhere.

# Standard Adam
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

# AdamW (decoupled weight decay - preferred for transformers); use one optimizer or the other
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)

# Training loop
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before stepping
    optimizer.step()
```
Part 5 - AdamW: Decoupled Weight Decay
The Weight Decay vs L2 Regularization Confusion
This is one of the most commonly confused topics in deep learning.
L2 regularization adds a penalty to the loss:
$$\mathcal{L}_{\text{reg}}(\theta) = \mathcal{L}(\theta) + \frac{\lambda}{2}\|\theta\|^2$$
This modifies the gradient:
$$g_t \leftarrow \nabla_\theta \mathcal{L}(\theta_{t-1}) + \lambda\, \theta_{t-1}$$
Weight decay directly shrinks weights at each step:
$$\theta_t = \theta_{t-1} - \eta\, g_t - \eta \lambda\, \theta_{t-1}$$
"For vanilla SGD, L2 regularization and weight decay are mathematically equivalent (with matching coefficients). But for Adam, they're different! With L2 regularization, the regularization gradient gets scaled by Adam's adaptive learning rate, which means parameters with large recent gradient magnitudes get less regularization. AdamW fixes this by applying weight decay directly to the parameters, bypassing Adam's scaling. This is called 'decoupled weight decay' and is why AdamW is the standard for training transformers."
Why They Diverge in Adam
With L2 regularization in Adam, the regularization term $\lambda \theta_{t-1}$ gets divided by $\sqrt{\hat{v}_t} + \epsilon$. This means:
- Parameters with large gradients (large $\hat{v}_t$) get less regularization
- Parameters with small gradients (small $\hat{v}_t$) get more regularization
This is not the intended behavior. Weight decay should apply uniformly.
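AdamW update rule (Loshchilov & Hutter, 2019):
$$\theta_t = \theta_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)$$
The weight decay term $\lambda \theta_{t-1}$ is outside the adaptive scaling.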
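The difference shows up clearly when only the shrinkage term is isolated. This is a simplified sketch under stated assumptions. The steady-state $\hat{v}_t$ values are hypothetical. The sketch also ignores that, in Adam + L2, the penalty gradient also feeds into $m_t$ and $v_t$. Two parameters with equal weights but very different gradient scales receive identical decay under AdamW and wildly different decay under Adam + L2.

```python
import numpy as np

lr, wd, eps = 1e-3, 0.01, 1e-8
theta = np.array([1.0, 1.0])
# Assumed steady-state second moments: one parameter sees gradients ~100, the other ~0.01
v_hat = np.array([100.0, 0.01]) ** 2

# Adam + L2: the penalty gradient wd * theta is divided by the adaptive denominator
l2_shrink = lr * wd * theta / (np.sqrt(v_hat) + eps)
# AdamW: decay is applied to the weights directly, outside the denominator
adamw_shrink = lr * wd * theta

print("Adam + L2 shrinkage:", l2_shrink)     # roughly 1e-7 vs 1e-3
print("AdamW shrinkage:    ", adamw_shrink)  # 1e-5 for both parameters
```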
When AdamW Actually Matters
| Scenario | Impact of AdamW vs Adam+L2 |
|---|---|
| Small models, low regularization | Minimal difference |
| Large transformers with weight decay | Significant - AdamW gives better generalization |
| Models with very heterogeneous gradient scales | Critical - L2 regularization becomes unpredictable |
| Fine-tuning pre-trained models | Important - uniform decay prevents selective forgetting |
Part 6 - Learning Rate Warmup
Why Warmup Helps Adam
Despite bias correction, Adam's variance estimate is still imprecise during early training. A large learning rate combined with an imprecise $\hat{v}_t$ can cause dangerously large updates.
Linear warmup gradually increases the learning rate:
$$\eta_t = \eta_{\max} \cdot \min\left(1, \frac{t}{T_{\text{warmup}}}\right)$$
```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR

def get_warmup_cosine_scheduler(optimizer, warmup_steps, total_steps):
    """Warmup + cosine decay - the standard for transformer training."""
    def lr_lambda(step):
        # Returns a multiplier applied to the optimizer's base LR
        if step < warmup_steps:
            return step / warmup_steps  # Linear warmup
        # Cosine decay from 1 to 0; clamp so the LR stays at 0 past total_steps
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

# Typical usage for a 100K step training run (assumes `model` is defined)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = get_warmup_cosine_scheduler(optimizer, warmup_steps=2000, total_steps=100000)
# Call scheduler.step() once after each optimizer.step()
```
- Google/DeepMind: Often use warmup + inverse square-root decay: $\eta_t = d_{\text{model}}^{-0.5} \cdot \min(t^{-0.5},\ t \cdot T_{\text{warmup}}^{-1.5})$. This is the original Transformer schedule.
- OpenAI: Cosine decay with warmup is standard for GPT-style models.
- Meta: Sometimes use linear decay with warmup for LLaMA-style training.
- Startups: Often just use cosine with warmup - it's robust and works well.
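The inverse square-root schedule from the Google/DeepMind bullet can be written as a plain function. This is a minimal sketch. `noam_lr` is a hypothetical helper name, and `d_model=512`, `warmup_steps=4000` are the base-model values from the original Transformer paper.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Original Transformer LR: linear warmup, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)  # avoid 0 ** -0.5 at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1000, 4000, 16000, 64000):
    print(s, f"{noam_lr(s):.2e}")
```

The LR peaks at `step == warmup_steps` and then halves each time the step count quadruples. `LambdaLR` multiplies the optimizer's base LR by the lambda's return value, so one way to use this function there is to set the base `lr` to 1.0.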
Part 7 - Adam vs SGD: The Great Debate
When to Use Each
| Factor | Favor Adam | Favor SGD + Momentum |
|---|---|---|
| Architecture | Transformers, LSTMs, attention models | CNNs (ResNet, EfficientNet) |
| Dataset size | Small to medium datasets | Large datasets with long training |
| Tuning budget | Limited hyperparameter search budget | Can afford extensive grid search |
| Training stability | Need robust convergence | Can tolerate more tuning |
| Final accuracy | Good but sometimes slightly worse | Often better generalization with tuning |
| Speed to good solution | Faster - fewer steps to converge | Slower - but can reach better final loss |
| Sparse gradients | Excellent (per-parameter LR) | Poor (uniform LR for all params) |
| Memory | 2x parameter memory for moments | 1x parameter memory for momentum |
The Generalization Gap
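A well-documented phenomenon: Adam often converges faster, while well-tuned SGD with momentum sometimes reaches better test accuracy. A common (though debated) explanation is that SGD's gradient noise steers it toward flatter minima that generalize better, whereas Adam more often settles in sharper ones.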
Don't simply say "Adam generalizes worse than SGD." The generalization gap has been significantly reduced by AdamW, proper weight decay, and warmup schedules. For transformers, Adam/AdamW typically generalizes better than SGD because transformers have loss landscapes that SGD struggles with. The generalization gap is primarily observed in CNN training on image classification tasks.
Key research findings:
- Wilson et al. (2017): SGD generalizes better than Adam on image tasks - but with extensive tuning
- Loshchilov & Hutter (2019): AdamW closes much of the gap by fixing weight decay
- Liu et al. (2020): RAdam (rectified Adam) addresses early training instability
- In practice (2024-2025): AdamW with warmup is the de facto standard for transformers
The Practical Answer
```python
import torch

# Assumes `model` is defined elsewhere.

# For transformers / LLMs / NLP - use AdamW
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,  # Lower than SGD typically
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8
)

# For CNNs / image classification - SGD with momentum is competitive
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,  # Higher than Adam typically
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True  # Nesterov momentum is often slightly better
)
```
Part 8 - Adam Variants
AMSGrad
Problem: Adam's second moment estimate can decrease over time, causing the effective learning rate to increase, potentially leading to divergence.
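AMSGrad fix: Use the maximum of all past $v_t$ values in the denominator:
$$v_t^{\max} = \max(v_{t-1}^{\max},\ v_t)$$
Because $v_t^{\max}$ never decreases, the adaptive denominator can never shrink.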
In practice: AMSGrad rarely makes a significant difference. Included in PyTorch (amsgrad=True) but seldom used.
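The failure mode AMSGrad targets can be reproduced on a single scalar parameter. In this sketch, `effective_lrs` is a hypothetical helper, bias correction is omitted to keep the comparison minimal, and the gradient schedule is made up. When gradients suddenly shrink, Adam's $v_t$ decays and its effective learning rate climbs. AMSGrad's running maximum holds the effective learning rate in place.

```python
import numpy as np

def effective_lrs(grads, lr=1e-3, beta2=0.999, eps=1e-8):
    """Effective LR lr / (sqrt(v) + eps) per step for Adam and AMSGrad on one scalar parameter.

    Bias correction is omitted to keep the comparison minimal.
    """
    v, v_max = 0.0, 0.0
    adam, ams = [], []
    for g in grads:
        v = beta2 * v + (1 - beta2) * g * g
        v_max = max(v_max, v)            # AMSGrad: running maximum never decreases
        adam.append(lr / (np.sqrt(v) + eps))
        ams.append(lr / (np.sqrt(v_max) + eps))
    return np.array(adam), np.array(ams)

# Hypothetical schedule: gradients of size 1 for 1000 steps, then 100x smaller
grads = [1.0] * 1000 + [0.01] * 3000
adam, ams = effective_lrs(grads)
print("Adam:   ", adam[999], "->", adam[-1])   # grows as v decays
print("AMSGrad:", ams[999], "->", ams[-1])     # stays put
```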
RAdam (Rectified Adam)
Problem: Adam's variance estimate has high variance in early training (not enough samples), causing unstable updates even with bias correction.
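RAdam fix: Estimate how variable the adaptive learning rate is. When that variance is intractable (the first few steps), skip the adaptive term and take an SGD-with-momentum step using the bias-corrected first moment. After that, multiply the Adam update by a rectification term $r_t$ that rises smoothly toward 1, transitioning to full Adam as the estimate stabilizes.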
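The rectification term can be computed in closed form. This sketch follows the definitions in the RAdam paper (Liu et al., 2020): the maximum approximated SMA length $\rho_\infty = 2/(1-\beta_2) - 1$, its value at step $t$, $\rho_t = \rho_\infty - 2t\beta_2^t/(1-\beta_2^t)$, and the paper's tractability threshold $\rho_t > 4$. `radam_rectification` is a hypothetical helper, and the result should be checked against the paper before relying on it.

```python
import math

def radam_rectification(t, beta2=0.999):
    """RAdam's rectification term r_t at step t (t >= 1).

    Returns None when rho_t <= 4, i.e. when the adaptive term is skipped
    and an SGD-with-momentum step is taken instead.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0                    # maximum approximated SMA length
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)  # approximated SMA length at step t
    if rho_t <= 4.0:
        return None
    return math.sqrt(
        (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t)
    )

for t in (1, 4, 5, 10, 100, 1000):
    r = radam_rectification(t)
    print(t, "SGD-with-momentum step" if r is None else f"r_t = {r:.3f}")
```

With $\beta_2 = 0.999$, $\rho_t \approx t$ early on. The adaptive term is therefore skipped for only the first few steps, after which $r_t$ stays well below 1 for the first thousand steps and approaches 1 later.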
LAMB (Layer-wise Adaptive Moments)
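Designed for large-batch training. Normalizes the Adam update by the ratio of parameter norm to update norm, per layer $l$:
$$\theta_t^{(l)} = \theta_{t-1}^{(l)} - \eta\, \frac{\|\theta_{t-1}^{(l)}\|}{\|u_t^{(l)}\|}\, u_t^{(l)}$$
where $u_t^{(l)} = \hat{m}_t^{(l)} / (\sqrt{\hat{v}_t^{(l)}} + \epsilon) + \lambda\, \theta_{t-1}^{(l)}$ is the Adam update (with decoupled weight decay) for that layer. Used to train BERT with batch size 65K.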
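A minimal sketch of the per-layer trust ratio, assuming the raw Adam update for each layer has already been computed. `lamb_layer_update` is a hypothetical helper. The fallback ratio of 1 for zero norms (for example, zero-initialized biases) is an assumption of this sketch. However large or small a layer's raw update is, its parameters move by a fraction `lr` of their own norm.

```python
import numpy as np

def lamb_layer_update(theta, adam_update, lr=1e-3):
    """Scale one layer's Adam update by the trust ratio ||theta|| / ||update||."""
    w_norm = np.linalg.norm(theta)
    u_norm = np.linalg.norm(adam_update)
    # Assumed fallback: use a ratio of 1 when either norm is zero
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return theta - lr * trust_ratio * adam_update

# Two hypothetical layers whose raw Adam updates differ in scale by 1000x
layers = [np.array([3.0, 4.0]), np.array([0.6, 0.8])]
updates = [np.array([0.0, 100.0]), np.array([0.1, 0.0])]
for w, u in zip(layers, updates):
    new_w = lamb_layer_update(w, u)
    print(np.linalg.norm(w - new_w) / np.linalg.norm(w))  # relative change equals lr for both
```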
8-bit Adam (bitsandbytes)
Quantizes the optimizer states $m_t$ and $v_t$ to 8 bits, reducing optimizer-state memory by 75% (1 byte per value instead of 4) with negligible accuracy loss. Critical for training large models on limited GPU memory.
```python
import bitsandbytes as bnb

optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999)
)
```
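The 75% figure is simple arithmetic on the two state tensors. The sketch below uses a hypothetical 7B-parameter model and a hypothetical helper `adam_state_gb`. It counts only $m_t$ and $v_t$, ignoring weights, gradients, activations, and any quantization bookkeeping.

```python
def adam_state_gb(num_params, bytes_per_value):
    """Memory for Adam's two state tensors (m and v), in GB (1e9 bytes)."""
    return 2 * num_params * bytes_per_value / 1e9

n_params = 7e9                         # hypothetical 7B-parameter model
fp32 = adam_state_gb(n_params, 4)      # 32-bit states
int8 = adam_state_gb(n_params, 1)      # 8-bit states
print(f"32-bit: {fp32:.0f} GB, 8-bit: {int8:.0f} GB, saved: {1 - int8 / fp32:.0%}")
```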
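Variant Summary
| Variant | Problem addressed | Key change |
|---|---|---|
| AdamW | L2 regularization is rescaled by the adaptive denominator | Applies weight decay directly to the parameters |
| AMSGrad | A shrinking $v_t$ can raise the effective learning rate | Uses the running maximum of $v_t$ in the denominator |
| RAdam | Unreliable variance estimate in early training | Falls back to SGD with momentum, then rectifies the adaptive term |
| LAMB | Instability in large-batch training | Rescales each layer's update by a trust ratio (parameter norm / update norm) |
| 8-bit Adam | Optimizer-state memory | Stores $m_t$ and $v_t$ in 8 bits |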
Part 9 - Diagnosing Optimizer Issues
Common Training Failures and Optimizer Fixes
| Symptom | Likely Cause | Fix |
|---|---|---|
| Loss explodes early in training | LR too high + imprecise variance estimate | Add warmup, reduce initial LR |
| Loss plateaus after fast initial drop | LR too high for fine-grained optimization | Add LR decay (cosine) |
| Loss oscillates wildly | LR too high or too low | Reduce LR, increase |
| Training loss drops but val loss doesn't | Overfitting, not optimizer issue | Increase weight decay, add dropout |
| NaN loss after many steps | Numerical instability, gradient explosion | Gradient clipping, increase $\epsilon$ |
| Different layers converge at very different rates | Expected - Adam handles this | This is why you use Adam |
| Fine-tuning destroys pretrained knowledge | LR too high for pretrained layers | Use layer-wise LR decay |
Debugging Checklist
```python
import torch

# Assumes `model` and `optimizer` (torch.optim.Adam or AdamW) exist and loss.backward() has run.

# 1. Monitor gradient norms per layer
for name, param in model.named_parameters():
    if param.grad is not None:
        grad_norm = param.grad.norm().item()
        print(f"{name}: grad_norm={grad_norm:.6f}")

# 2. Check for NaN/Inf in gradients
for name, param in model.named_parameters():
    if param.grad is not None:
        if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
            print(f"WARNING: NaN/Inf gradient in {name}")

# 3. Monitor effective learning rate (per-parameter multiplier on the bias-corrected first moment)
for group in optimizer.param_groups:
    beta2 = group['betas'][1]
    for p in group['params']:
        state = optimizer.state[p]
        if 'exp_avg_sq' in state:
            # Bias-correct v before use; state['step'] may be an int or a tensor
            v_hat = state['exp_avg_sq'] / (1 - beta2 ** float(state['step']))
            effective_lr = group['lr'] / (v_hat.sqrt() + group['eps'])
            print(f"Effective LR range: [{effective_lr.min().item():.8f}, {effective_lr.max().item():.8f}]")
```
Part 10 - Practice Problems
Problem 1: The Bias Correction Derivation
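Derive the bias correction factor $1/(1 - \beta_1^t)$ for the first moment estimate. Show that $\mathbb{E}[\hat{m}_t] = \mathbb{E}[g_t]$ when the gradient distribution is stationary.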
Hint 1 - Direction
Expand $m_t$ by recursively substituting the update rule. You'll get a geometric series involving $\beta_1$.
Hint 2 - Key Insight
$m_t = (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} g_i$. Now take the expectation, assuming $\mathbb{E}[g_i] = \mathbb{E}[g_t]$ for all $i$.
Full Answer
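Expand the recursion:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t = (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} g_i \qquad (m_0 = 0)$$
Taking expectation (assuming stationary gradients where $\mathbb{E}[g_i] = \mathbb{E}[g_t]$):
$$\mathbb{E}[m_t] = \mathbb{E}[g_t]\, (1 - \beta_1) \sum_{i=1}^{t} \beta_1^{t-i} = \mathbb{E}[g_t]\, (1 - \beta_1^t)$$
Therefore:
$$\mathbb{E}[\hat{m}_t] = \frac{\mathbb{E}[m_t]}{1 - \beta_1^t} = \mathbb{E}[g_t]$$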
The bias correction factor exactly compensates for the initialization at zero.
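The derivation can be sanity-checked by simulation. This is a minimal Monte Carlo sketch that assumes i.i.d. Gaussian gradients with mean `mu`, an illustrative choice since any stationary distribution would do. Averaged over many trials, the raw $m_t$ matches $(1 - \beta_1^t)\,\mu$, and the bias-corrected estimate matches $\mu$.

```python
import numpy as np

rng = np.random.default_rng(0)
beta1, t, mu, trials = 0.9, 5, 2.0, 200_000

# Hypothetical stationary gradients: g_i ~ Normal(mu, 1), independent across steps and trials
g = rng.normal(loc=mu, scale=1.0, size=(trials, t))

m = np.zeros(trials)
for i in range(t):
    m = beta1 * m + (1 - beta1) * g[:, i]   # m_0 = 0, so m is biased toward zero

print(m.mean(), (1 - beta1**t) * mu)        # E[m_t] vs (1 - beta1^t) * mu
print((m / (1 - beta1**t)).mean(), mu)      # E[m_hat_t] vs mu
```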
Problem 2: Adam vs SGD Recommendation
You're training a ResNet-50 on ImageNet and a GPT-2 model on a text corpus. Recommend an optimizer for each and justify your choices with specific technical reasons.
Hint 1 - Direction
Consider the loss landscape properties of CNNs vs transformers. Think about the role of batch normalization in CNNs.
Full Answer + Rubric
ResNet-50 on ImageNet: SGD with momentum (0.9), Nesterov, LR=0.1, cosine decay, weight decay=1e-4.
Reasons: (1) BatchNorm normalizes activations, reducing the need for per-parameter adaptive LRs. (2) Extensive literature shows SGD achieves better final accuracy on ImageNet with ResNets. (3) Lower memory cost (1x vs 2x params). (4) Well-established training recipes exist.
GPT-2 on text: AdamW, LR=3e-4, warmup 2000 steps, cosine decay, weight decay=0.01.
Reasons: (1) Transformers have highly heterogeneous gradient scales across attention heads and layers. (2) No batch normalization - LayerNorm doesn't provide the same normalizing effect on gradients. (3) Token embeddings receive sparse gradients, since only tokens present in a batch get updated, and per-parameter adaptive LRs handle this well. (4) Adam's adaptive LR is critical for stable training.
Scoring:
- Strong Hire: Gives both recommendations correctly with architecture-specific justifications
- Lean Hire: Correct recommendations but generic justifications
- No Hire: Says "always use Adam" or "always use SGD"
Problem 3: Debugging a Training Run
You're training a transformer with AdamW (lr=1e-3, warmup=500 steps). Training loss drops rapidly for 1000 steps, then starts increasing. What's happening and how do you fix it?
Hint 1 - Direction
The LR of 1e-3 is quite high for Adam on transformers. Consider what happens after warmup when the full learning rate kicks in.
Full Answer + Rubric
Diagnosis: The learning rate 1e-3 is likely too high. During warmup (0-500 steps), the effective LR ramps from 0 to 1e-3, and training is stable. After warmup, the full LR causes overshooting - the optimizer steps past minima, and the adaptive estimates from the low-LR warmup phase don't reflect the dynamics at the full LR.
Fixes (in order of priority):
- Reduce LR to 3e-4 (standard for transformers)
- Extend warmup to 2000+ steps
- Add gradient clipping (max norm = 1.0)
- Reduce $\beta_2$ from 0.999 to 0.98 (faster adaptation to current gradients)
- Add cosine decay to gradually reduce LR
Scoring:
- Strong Hire: Identifies LR as root cause, suggests multiple targeted fixes, explains the warmup-to-full-LR transition dynamics
- Lean Hire: Says "reduce learning rate" without deeper analysis
- No Hire: Suggests switching optimizers or increasing batch size without diagnosing
Part 11 - The Paper in Context
Key Contributions of the Adam Paper
- Combined momentum + adaptive LR - not novel individually, but the combination with bias correction was
- Bias correction - the key technical contribution. Without it, Adam performs poorly in early training
- Default hyperparameters - $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ work surprisingly well across many domains
- Convergence proof - proved convergence for convex objectives (the non-convex case is harder)
What the Paper Got Wrong
- Convergence proof was flawed: Reddi et al. (2018) showed Adam can diverge on simple convex problems where SGD converges
- L2 regularization: The paper didn't distinguish between L2 reg and weight decay (fixed by AdamW in 2019)
- Generalization: The paper didn't address the generalization gap, which became apparent later
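Historical Timeline
- 1964: Momentum (Polyak)
- 2012: RMSprop (Hinton, unpublished lecture notes)
- 2015: Adam (Kingma & Ba)
- 2017: Wilson et al. report that well-tuned SGD generalizes better than Adam on image tasks
- 2018: Reddi et al. show Adam can diverge on simple convex problems and propose AMSGrad
- 2019: AdamW (Loshchilov & Hutter) decouples weight decay
- 2020: RAdam (Liu et al.) addresses early-training instability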
Interview Cheat Sheet
| Question Pattern | Framework | Key Phrases |
|---|---|---|
| "Explain Adam" | SGD → Momentum (1st moment) → RMSprop (2nd moment) → Adam (both + bias correction) | "Adam combines momentum's gradient smoothing with RMSprop's per-parameter scaling, plus bias correction for the zero initialization" |
| "What is bias correction?" | Zero init → biased estimates → correction factor → proof | "Since moments are initialized at zero, early estimates underestimate the true values. The factor 1/(1-beta^t) exactly compensates." |
| "Adam vs AdamW?" | L2 reg vs weight decay → equivalent for SGD → different for Adam → decoupled decay | "In Adam, L2 regularization gets scaled by the adaptive learning rate. AdamW applies weight decay directly, giving uniform regularization." |
| "When would you use SGD over Adam?" | CNNs with BatchNorm, large-scale image classification, when you can tune extensively | "For CNNs, SGD with momentum often achieves better generalization with proper tuning. For transformers, AdamW is almost always better." |
| "Why use warmup?" | Imprecise early variance estimates → large steps → instability | "The second moment estimate needs hundreds of steps to stabilize. Warmup gives it time while keeping updates small." |
| "Adam is not converging. Debug it." | Check LR → check gradients → check warmup → check weight decay → check data | "I'd first verify the learning rate is appropriate, then check for gradient explosion, then examine the warmup schedule." |
Spaced Repetition Checkpoints
- Day 0: Read this page. Write out the Adam update rule from memory. Explain bias correction to a rubber duck.
- Day 3: Without looking, explain why AdamW differs from Adam + L2. Implement the Adam update step in 10 lines of Python.
- Day 7: Compare Adam vs SGD for transformers vs CNNs. What three reasons make Adam better for transformers?
- Day 14: Solve all three practice problems from memory. Time yourself - you should answer each in 5-8 minutes.
- Day 21: Explain the full optimizer landscape (SGD → Momentum → RMSprop → Adam → AdamW → 8-bit Adam) in 3 minutes.
Next Steps
- Continue to LoRA and PEFT to learn how parameter-efficient fine-tuning changes the optimizer equation
- Review Attention Is All You Need to see the original transformer training setup
- For more on training dynamics, see Batch Normalization
