Stochastic and Mini-Batch Gradient Descent
The Real Interview Moment
You are interviewing for a senior ML engineer role at a large tech company. The interviewer, who has trained models with billions of parameters, shows you a training loss curve that plateaued too early. "Our model seems to be getting stuck. We tried reducing the learning rate. We tried adding L2 regularization. What else would you try?"
Most candidates suggest increasing the model capacity. A few suggest learning rate schedules. The candidate who gets the role understands that the answer might be to increase gradient noise - specifically, to reduce the batch size, or to use a cyclic learning rate schedule that periodically increases the learning rate to escape the stuck region. They know that gradient noise is not a bug to be eliminated but a feature that provides implicit regularization.
This lesson gives you the framework to answer that question, and to understand why the stochastic optimization algorithm that drives essentially all of modern deep learning works the way it does.
Why Batch Gradient Descent Fails at Scale
Recall the batch gradient descent update from Lesson 02:
$$w \leftarrow w - \eta \, \nabla_w L(w), \qquad \nabla_w L(w) = \frac{1}{n} \sum_{i=1}^{n} \nabla_w \ell_i(w)$$
where $\eta$ is the learning rate and $\ell_i$ is the loss on sample $i$.
This requires processing all $n$ samples to compute one gradient. Take 500 features stored as 32-bit floats: a 1 TB dataset then holds about $5 \times 10^8$ samples. It does not fit in memory. Even with distributed storage, one gradient step reads the entire dataset, taking 40 minutes on a high-throughput cluster. With 10,000 steps needed for convergence, training takes 400,000 minutes, or roughly 278 days. The model is stale long before it finishes.
The solution is to use an approximation. Instead of the exact gradient over all samples, compute a gradient estimate from a small subset. This estimate is noisy, but it still points downhill in expectation - and for a subset of $B$ samples, computing it is $n/B$ times cheaper per step.
Stochastic Gradient Descent (SGD)
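In SGD, each update uses a single randomly chosen sample $(x_i, y_i)$:
$$w \leftarrow w - \eta \, \nabla_w \ell_i(w)$$
where $\ell_i$ is the loss on sample $i$ and $L(w) = \frac{1}{n} \sum_{i=1}^{n} \ell_i(w)$ is the full-data loss.
This estimate is unbiased - its expectation over a uniformly drawn index $i$ is the true full gradient:
$$\mathbb{E}_i\left[\nabla_w \ell_i(w)\right] = \frac{1}{n} \sum_{i=1}^{n} \nabla_w \ell_i(w) = \nabla_w L(w)$$
The cost per update drops from $O(nd)$ to $O(d)$ for $d$ features - a factor of $n$ speedup. For $n = 5 \times 10^8$ (the sample count of a 1 TB dataset with 500 float32 features), that is more than eight orders of magnitude faster per step.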
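The unbiasedness claim can be checked numerically. The sketch below assumes the linear-regression MSE loss used later in this lesson; the synthetic data and the names `per_sample_grads` and `expected_sgd_grad` are illustrative choices, not part of the original material.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 1000, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)
w = rng.normal(size=d)  # arbitrary point at which to compare gradients

residuals = X @ w - y                         # shape (n,)
full_grad = (2 / n) * X.T @ residuals         # exact batch gradient of the MSE

# Gradient of each single-sample loss (x_i . w - y_i)^2, stacked as rows
per_sample_grads = 2 * residuals[:, None] * X  # shape (n, d)

# The expectation over a uniformly random index i is the mean over rows
expected_sgd_grad = per_sample_grads.mean(axis=0)

print(np.allclose(expected_sgd_grad, full_grad))
```

Any single row of `per_sample_grads` can be far from `full_grad`; only the average matches, which is exactly what "unbiased" means here.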
Gradient Variance in SGD
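The variance of the SGD gradient estimate is the key quantity governing both convergence speed and generalization:
$$\sigma^2 = \mathbb{E}_i\left[\left\| \nabla_w \ell_i(w) - \nabla_w L(w) \right\|^2\right]$$
where $\ell_i$ is the loss on sample $i$ and $L$ is the full-data loss.
This variance is always $\geq 0$ (it equals zero only when all per-sample gradients are identical, which happens only in trivial problems). It represents the "noise" in the gradient estimate. For mini-batch gradient descent with batch size $B$, the variance decreases in proportion to $1/B$ (assuming samples are drawn independently and $B \ll n$):
$$\mathrm{Var}\left[\hat{g}_B\right] = \frac{\sigma^2}{B}$$
where $\hat{g}_B$ is the average of the $B$ per-sample gradients in the batch.
This is the fundamental trade-off: larger $B$ reduces variance (smoother gradient, more stable updates) but requires proportionally more compute per step.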
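The $1/B$ scaling can be estimated by Monte Carlo. This is a minimal sketch under stated assumptions: a synthetic linear-regression problem, batches sampled with replacement so the samples are independent, and a hypothetical helper `empirical_variance` introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 2000, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + rng.normal(size=n)
w = np.zeros(d)

residuals = X @ w - y
per_sample_grads = 2 * residuals[:, None] * X   # one MSE gradient per sample
full_grad = per_sample_grads.mean(axis=0)

# sigma^2: expected squared distance of a single-sample gradient from the full gradient
sigma2 = np.mean(np.sum((per_sample_grads - full_grad) ** 2, axis=1))

def empirical_variance(B, trials=4000):
    """Monte Carlo estimate of E||g_B - g||^2, sampling indices with replacement."""
    idx = rng.integers(0, n, size=(trials, B))
    batch_grads = per_sample_grads[idx].mean(axis=1)   # shape (trials, d)
    return np.mean(np.sum((batch_grads - full_grad) ** 2, axis=1))

for B in (1, 4, 16, 64):
    # The two printed numbers should agree up to Monte Carlo error
    print(B, empirical_variance(B), sigma2 / B)
```

When batches are drawn without replacement (as in a shuffled epoch), the variance is slightly smaller than $\sigma^2 / B$, and the difference is negligible for $B \ll n$.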
Mini-Batch Gradient Descent
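Mini-batch GD uses a batch of $B$ samples per update. For the linear-regression MSE loss used in the code in this lesson:
$$w \leftarrow w - \eta \cdot \frac{2}{B} X_B^\top \left(X_B w - y_B\right)$$
where $X_B$ is the submatrix of $B$ randomly selected rows and $y_B$ holds the corresponding targets.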
This balances three competing objectives:
- Gradient accuracy: Larger $B$ → closer to true gradient → smoother, more stable convergence
- Computation speed: Smaller $B$ → more updates per epoch → often faster progress in wall-clock time
- Hardware efficiency: GPUs run dense matrix operations most efficiently at moderate batch sizes, commonly 32–512, which keep the CUDA cores busy and make good use of memory bandwidth
Epoch vs. Iteration
- Iteration: one weight update, processing $B$ samples
- Epoch: one full pass through the training data, consisting of $\lceil n/B \rceil$ iterations
n = 100,000 samples, B = 256 batch size
→ 100,000 / 256 ≈ 390.6, so 391 iterations per epoch (the last batch is partial)
→ 50 epochs = 391 × 50 = 19,550 weight updates
→ vs. batch GD: 50 epochs = 50 weight updates
Mini-batch GD with $B = 256$ and 50 epochs makes 391× more weight updates than batch GD with the same epoch budget.
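The arithmetic above can be wrapped in a small helper. This is a sketch; `updates_per_epoch` is a hypothetical name, and its `drop_last` flag is modeled on the `drop_last` argument of PyTorch's `DataLoader`.

```python
import math

def updates_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of weight updates in one full pass over the data."""
    if drop_last:
        # The final partial batch is discarded
        return n_samples // batch_size
    # The final partial batch still triggers one update
    return math.ceil(n_samples / batch_size)

n, B, epochs = 100_000, 256, 50
per_epoch = updates_per_epoch(n, B)
print(per_epoch)                          # 391
print(per_epoch * epochs)                 # 19550
print(updates_per_epoch(n, n) * epochs)   # 50 (full-batch GD)
```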
The Key Insight: Noise as Implicit Regularization
The most important insight in this lesson - and one that took the field years to fully appreciate - is that the gradient noise of SGD is a feature, not a bug.
Keskar et al. 2016: Sharp vs Flat Minima
Nitish Keskar and colleagues published a landmark paper in 2016 demonstrating empirically that large-batch training tends to find sharp minima while small-batch SGD tends to find flat minima. Sharpness is commonly characterized by the largest eigenvalue of the Hessian at the minimum - a large eigenvalue means the loss increases rapidly in at least one direction as you move away from the minimum.
Why does this matter?
[Figure: two side-by-side sketches of loss versus $w$ around a minimum $w^*$. Left panel, "Sharp minimum (large-batch)": a narrow, steep basin. Right panel, "Flat minimum (small-batch SGD)": a wide, shallow basin. If test data shifts the loss slightly, the sharp minimum sees a huge increase in test loss; the flat minimum sees a small increase.]
A flat minimum generalizes better. Small perturbations to the data (or small distribution shifts between training and production) do not dramatically change the loss at a flat minimum. They do at a sharp minimum.
SGD's gradient noise acts as a random perturbation during training. This perturbation tends to push the optimization trajectory out of sharp minima (high loss in the neighborhood) and into flat minima (low loss in the neighborhood). This is why SGD often generalizes better than batch GD or large-batch GD, even when both reach similar training losses.
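A one-dimensional toy example makes the sharp-versus-flat argument concrete. The sketch assumes quadratic basins and models the train/test mismatch as a horizontal shift of the loss; both are simplifying assumptions for illustration, and `quadratic_loss` is a hypothetical helper.

```python
def quadratic_loss(w, curvature, center=0.0):
    """1-D loss 0.5 * curvature * (w - center)^2; curvature is the scalar Hessian."""
    return 0.5 * curvature * (w - center) ** 2

sharp_curvature = 100.0   # large second derivative: narrow basin
flat_curvature = 1.0      # small second derivative: wide basin
w_star = 0.0              # both training losses are minimized here
shift = 0.3               # assumed mismatch: the test-loss minimum sits at w = 0.3

for name, h in [("sharp", sharp_curvature), ("flat", flat_curvature)]:
    train_loss = quadratic_loss(w_star, h)
    test_loss = quadratic_loss(w_star, h, center=shift)
    print(f"{name}: train={train_loss:.3f}, test={test_loss:.3f}")
```

Both basins reach zero training loss at `w_star`, but the same shift costs $\frac{1}{2} h \delta^2$ at test time, so the penalty grows in proportion to the curvature $h$.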
Convergence Sketch with Decreasing Learning Rate
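For convex objectives, SGD with a decreasing learning rate schedule converges. The schedule must satisfy the Robbins-Monro conditions:
$$\sum_{t=1}^{\infty} \eta_t = \infty, \qquad \sum_{t=1}^{\infty} \eta_t^2 < \infty$$
The first condition ensures the algorithm can reach any point. The second ensures the noise from gradient variance eventually becomes negligible (the updates shrink fast enough).
The schedule $\eta_t = \eta_0 / t$ satisfies both:
- $\sum_t \eta_0 / t$: diverges (harmonic series) - can reach any point
- $\sum_t \eta_0^2 / t^2$: converges - noise eventually negligible
With a decaying schedule, the commonly quoted guarantee for convex objectives is (constants and logarithmic factors omitted; the exact rate depends on the schedule and on iterate averaging):
$$\mathbb{E}\left[L(w_T)\right] - L(w^*) = O\left(\frac{1}{\sqrt{T}}\right)$$
This rate is slower than batch GD's $O(1/T)$ rate per gradient step - but SGD takes $n$ steps for the cost of each batch GD step, so in wall-clock time it is typically faster.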
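The effect of the two conditions shows up even on a one-dimensional problem. The sketch below assumes the toy objective $L(w) = \frac{1}{2} w^2$ with additive Gaussian gradient noise; `sgd_1d` and both schedules are illustrative choices rather than anything prescribed by the theory.

```python
import numpy as np

rng = np.random.default_rng(0)

def sgd_1d(schedule, steps=5000, w0=5.0, noise_std=1.0):
    """Noisy gradient descent on L(w) = 0.5 * w^2 (minimum at w = 0)."""
    w = w0
    for t in range(1, steps + 1):
        noisy_grad = w + noise_std * rng.normal()   # the true gradient is w
        w -= schedule(t) * noisy_grad
    return w

def decaying(t):
    return 1.0 / t    # sum diverges, sum of squares converges

def constant(t):
    return 0.1        # sum of squares diverges: violates the second condition

# Partial sums for eta_t = 1/t: the first keeps growing, the second approaches pi^2 / 6
ts = np.arange(1, 1_000_001, dtype=float)
print(np.sum(1.0 / ts), np.sum(1.0 / ts**2))

final_decay = [abs(sgd_1d(decaying)) for _ in range(20)]
final_const = [abs(sgd_1d(constant)) for _ in range(20)]
print(np.mean(final_decay), np.mean(final_const))
```

With the decaying schedule the final iterate keeps tightening around the minimum as the step count grows, while the constant schedule settles into a noise floor whose size is set by the learning rate.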
Full NumPy Implementation: All Three Variants
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
np.random.seed(42)
X, y = make_regression(n_samples=5000, n_features=20, noise=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
n_train, d = X_train.shape
class LinearModel:
    """Minimal linear model with shared predict/loss/r2 logic."""
    def __init__(self, d):
        self.w = np.zeros(d)
        self.b = 0.0
    def predict(self, X):
        return X @ self.w + self.b
    def mse(self, X, y):
        return float(np.mean((self.predict(X) - y) ** 2))
    def r2(self, X, y):
        pred = self.predict(X)
        ss_res = np.sum((y - pred) ** 2)
        ss_tot = np.sum((y - np.mean(y)) ** 2)
        return float(1 - ss_res / ss_tot)
def batch_gd(X, y, X_val, y_val, lr=0.05, epochs=100):
    """Batch gradient descent - one weight update per epoch using all n samples."""
    model = LinearModel(X.shape[1])
    n = len(y)
    train_losses, val_losses = [], []
    for epoch in range(epochs):
        residuals = model.predict(X) - y  # shape (n,)
        grad_w = (2 / n) * X.T @ residuals
        grad_b = (2 / n) * residuals.sum()
        model.w -= lr * grad_w
        model.b -= lr * grad_b
        train_losses.append(model.mse(X, y))
        val_losses.append(model.mse(X_val, y_val))
    return model, train_losses, val_losses
def mini_batch_gd(X, y, X_val, y_val, lr=0.05, epochs=100, batch_size=64):
    """
    Mini-batch gradient descent.
    Processes B samples per update, shuffles before each epoch.
    """
    model = LinearModel(X.shape[1])
    n = len(y)
    train_losses, val_losses = [], []
    for epoch in range(epochs):
        # Essential: shuffle each epoch to ensure unbiased mini-batches
        idx = np.random.permutation(n)
        X_shuf, y_shuf = X[idx], y[idx]
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            Xb = X_shuf[start:end]
            yb = y_shuf[start:end]
            B = end - start  # actual batch size (last batch may be smaller)
            residuals = Xb @ model.w + model.b - yb
            model.w -= lr * (2 / B) * Xb.T @ residuals
            model.b -= lr * (2 / B) * residuals.sum()
        # Record full-data losses once per epoch
        train_losses.append(model.mse(X, y))
        val_losses.append(model.mse(X_val, y_val))
    return model, train_losses, val_losses
def sgd(X, y, X_val, y_val, lr=0.01, epochs=50):
    """
    Pure SGD - one sample per update.
    Shuffles each epoch. Note: very slow in Python for large n.
    In practice, prefer vectorized mini-batches (e.g. a PyTorch DataLoader)
    over a per-sample Python loop.
    """
    model = LinearModel(X.shape[1])
    n = len(y)
    train_losses, val_losses = [], []
    for epoch in range(epochs):
        idx = np.random.permutation(n)
        X_shuf, y_shuf = X[idx], y[idx]
        for i in range(n):
            xi = X_shuf[i]  # shape (d,)
            yi = float(y_shuf[i])
            pred_i = xi @ model.w + model.b
            residual_i = pred_i - yi
            # Single-sample gradient: (2/1) * residual * x
            model.w -= lr * 2 * residual_i * xi
            model.b -= lr * 2 * residual_i
        train_losses.append(model.mse(X, y))
        val_losses.append(model.mse(X_val, y_val))
        print(f"SGD Epoch {epoch+1}/{epochs}: train_loss={train_losses[-1]:.2f}", end="\r")
    print()
    return model, train_losses, val_losses
# Run all variants
print("Running Batch GD...")
m_bgd, tl_bgd, vl_bgd = batch_gd(
X_train, y_train, X_test, y_test, lr=0.05, epochs=100)
print("Running Mini-Batch GD (B=32)...")
m_mb32, tl_mb32, vl_mb32 = mini_batch_gd(
X_train, y_train, X_test, y_test, lr=0.05, epochs=100, batch_size=32)
print("Running Mini-Batch GD (B=256)...")
m_mb256, tl_mb256, vl_mb256 = mini_batch_gd(
X_train, y_train, X_test, y_test, lr=0.05, epochs=100, batch_size=256)
print("Running Mini-Batch GD (B=1024)...")
m_mb1024, tl_mb1024, vl_mb1024 = mini_batch_gd(
X_train, y_train, X_test, y_test, lr=0.05, epochs=100, batch_size=1024)
print(f"\nFinal Test R²:")
print(f" Batch GD: {m_bgd.r2(X_test, y_test):.4f}")
print(f" Mini-Batch (B=32): {m_mb32.r2(X_test, y_test):.4f}")
print(f" Mini-Batch (B=256): {m_mb256.r2(X_test, y_test):.4f}")
print(f" Mini-Batch (B=1024): {m_mb1024.r2(X_test, y_test):.4f}")
Convergence Comparison Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
def smooth(data, k=5):
    """Rolling mean smoothing for noisy loss curves."""
    return np.convolve(data, np.ones(k)/k, mode='valid')
# Training loss
ax = axes[0]
ax.semilogy(tl_bgd, label='Batch GD', linewidth=2.5, color='#2563eb')
ax.semilogy(smooth(tl_mb32, k=5), label='Mini-Batch B=32 (smoothed)',
linewidth=1.5, color='#16a34a', alpha=0.9)
ax.semilogy(smooth(tl_mb256, k=5), label='Mini-Batch B=256 (smoothed)',
linewidth=1.5, color='#ea580c', alpha=0.9)
ax.semilogy(smooth(tl_mb1024, k=5), label='Mini-Batch B=1024 (smoothed)',
linewidth=1.5, color='#7c3aed', alpha=0.9)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Loss (log scale)")
ax.set_title("Training Loss vs Epoch\n(smoothed for mini-batch variants)")
ax.legend()
ax.grid(True, alpha=0.3)
# Validation loss
ax = axes[1]
ax.semilogy(vl_bgd, label='Batch GD', linewidth=2.5, color='#2563eb')
ax.semilogy(smooth(vl_mb32, k=5), label='Mini-Batch B=32 (smoothed)',
linewidth=1.5, color='#16a34a', alpha=0.9)
ax.semilogy(smooth(vl_mb256, k=5), label='Mini-Batch B=256 (smoothed)',
linewidth=1.5, color='#ea580c', alpha=0.9)
ax.semilogy(smooth(vl_mb1024, k=5), label='Mini-Batch B=1024 (smoothed)',
linewidth=1.5, color='#7c3aed', alpha=0.9)
ax.set_xlabel("Epoch")
ax.set_ylabel("Validation MSE")
ax.set_title("Validation Loss vs Epoch")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("gd_comparison.png", dpi=150)
plt.show()
Full PyTorch DataLoader Mini-Batch Training
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
# Convert numpy arrays to PyTorch tensors
X_t = torch.FloatTensor(X_train)
y_t = torch.FloatTensor(y_train)
# TensorDataset + DataLoader handle shuffling and batching automatically
dataset = TensorDataset(X_t, y_t)
loader = DataLoader(dataset, batch_size=64, shuffle=True,
num_workers=0, drop_last=False)
# Linear model - equivalent to our NumPy version
model_pt = nn.Linear(d, 1)
optimizer = torch.optim.SGD(model_pt.parameters(), lr=0.05, momentum=0.9)
loss_fn = nn.MSELoss()
train_losses_pt = []
val_losses_pt = []
X_val_t = torch.FloatTensor(X_test)
y_val_t = torch.FloatTensor(y_test)
print("Training with PyTorch DataLoader:")
for epoch in range(100):
    model_pt.train()
    epoch_loss = 0.0
    n_batches = 0
    for Xb, yb in loader:
        optimizer.zero_grad()                # clear accumulated gradients
        preds = model_pt(Xb).squeeze(-1)     # shape (B,); squeeze(-1) stays 1-D even when B == 1
        loss = loss_fn(preds, yb)
        loss.backward()                      # compute gradients via autograd
        optimizer.step()                     # update weights
        epoch_loss += loss.item() * len(yb)  # weight by batch size (last batch may be smaller)
        n_batches += 1
    avg_train_loss = epoch_loss / len(y_train)
    train_losses_pt.append(avg_train_loss)
    # Validation loss (no gradient computation needed)
    model_pt.eval()
    with torch.no_grad():
        val_preds = model_pt(X_val_t).squeeze(-1)
        val_loss = loss_fn(val_preds, y_val_t).item()
    val_losses_pt.append(val_loss)
    if (epoch + 1) % 20 == 0:
        print(f" Epoch {epoch+1:3d}/100: train_loss={avg_train_loss:.4f}, "
              f"val_loss={val_loss:.4f}")
# Compute final R^2
with torch.no_grad():
    preds_final = model_pt(X_val_t).squeeze(-1).numpy()
ss_res = np.sum((y_test - preds_final)**2)
ss_tot = np.sum((y_test - np.mean(y_test))**2)
print(f"\nPyTorch model test R²: {1 - ss_res/ss_tot:.4f}")
Gradient Accumulation for Large Effective Batch Sizes
When the desired batch does not fit in GPU memory, gradient accumulation simulates a larger batch by accumulating gradients over multiple forward and backward passes before updating weights:
def train_with_gradient_accumulation(model, loader, optimizer, loss_fn,
                                     accumulation_steps=4, epochs=50):
    """
    Gradient accumulation - effectively trains with batch_size * accumulation_steps
    batch without the memory cost of a large batch.
    Use case: large models or high-resolution inputs that do not fit in GPU memory
    at a large batch size. Accumulate gradients from k small batches before stepping.
    """
    train_losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()  # clear at the start of each epoch
        for step, (Xb, yb) in enumerate(loader):
            preds = model(Xb).squeeze(-1)
            # Divide loss by accumulation_steps to maintain correct gradient scale
            loss = loss_fn(preds, yb) / accumulation_steps
            loss.backward()  # adds this batch's gradients to each parameter's .grad
            if (step + 1) % accumulation_steps == 0 or (step + 1) == len(loader):
                # After accumulation_steps batches (or at the end of the epoch,
                # where the window may be shorter), apply the accumulated gradient
                optimizer.step()
                optimizer.zero_grad()  # reset for next accumulation window
            # Undo the scaling so the logged value is the per-sample loss
            epoch_loss += loss.item() * accumulation_steps * len(yb)
        train_losses.append(epoch_loss / (len(loader.dataset)))
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: loss={train_losses[-1]:.4f}")
    return train_losses
# Simulate: effective batch = 64 * 4 = 256 with only 64 in memory
model_acc = nn.Linear(d, 1)
opt_acc = torch.optim.SGD(model_acc.parameters(), lr=0.05, momentum=0.9)
print("\nTraining with gradient accumulation (effective batch = 64*4 = 256):")
losses_acc = train_with_gradient_accumulation(
model_acc, loader, opt_acc, loss_fn,
accumulation_steps=4, epochs=50
)
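The claim that accumulation reproduces the large-batch gradient can be verified directly. This sketch uses NumPy and the MSE gradient from earlier in the lesson, and it assumes all micro-batches have the same size; `mse_grad` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, micro_bs, k = 8, 16, 4
X = rng.normal(size=(micro_bs * k, d))
y = rng.normal(size=micro_bs * k)
w = rng.normal(size=d)

def mse_grad(Xb, yb, w):
    """Gradient of the mean squared error over the given batch."""
    return (2 / len(yb)) * Xb.T @ (Xb @ w - yb)

# One large batch of size micro_bs * k
big_batch_grad = mse_grad(X, y, w)

# k micro-batches: scale each gradient by 1/k and sum,
# mirroring `loss / accumulation_steps` followed by backward()
accumulated = np.zeros(d)
for j in range(k):
    sl = slice(j * micro_bs, (j + 1) * micro_bs)
    accumulated += mse_grad(X[sl], y[sl], w) / k

print(np.allclose(accumulated, big_batch_grad))
```

The equality is exact only for equal-sized micro-batches. A shorter final window, as can occur at the end of an epoch, is still divided by `accumulation_steps` and therefore contributes a slightly down-scaled update.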
Learning Rate Schedules
A fixed learning rate is rarely optimal. The right schedule reduces the learning rate as training progresses, allowing large steps early (fast progress) and small steps later (fine-tuning near the minimum):
class LRScheduleComparison:
    """Run mini-batch GD with different learning rate schedules."""
    @staticmethod
    def constant(lr0, t, T):
        return lr0
    @staticmethod
    def step_decay(lr0, t, T, drop=0.5, period=20):
        """Halve the learning rate every `period` epochs."""
        return lr0 * (drop ** (t // period))
    @staticmethod
    def exponential(lr0, t, T, decay=0.97):
        """Multiply by decay each epoch."""
        return lr0 * (decay ** t)
    @staticmethod
    def cosine(lr0, t, T, eta_min=1e-5):
        """Cosine annealing: smooth decay from lr0 to eta_min."""
        return eta_min + 0.5 * (lr0 - eta_min) * (1 + np.cos(np.pi * t / T))
    @staticmethod
    def warmup_cosine(lr0, t, T, warmup_epochs=10, eta_min=1e-5):
        """
        Linear warmup for `warmup_epochs` followed by cosine annealing.
        Warmup prevents instability in early training with large batch sizes.
        """
        if t < warmup_epochs:
            return lr0 * (t + 1) / warmup_epochs
        t_adj = t - warmup_epochs
        T_adj = T - warmup_epochs
        return eta_min + 0.5 * (lr0 - eta_min) * (1 + np.cos(np.pi * t_adj / T_adj))
def run_with_schedule(X, y, X_val, y_val, schedule_fn, lr0=0.1,
                      epochs=150, batch_size=64):
    """Run mini-batch GD with a given lr schedule function."""
    model = LinearModel(X.shape[1])
    n = len(y)
    train_losses, val_losses, lrs = [], [], []
    for epoch in range(epochs):
        lr = schedule_fn(lr0, epoch, epochs)  # one learning rate per epoch
        lrs.append(lr)
        idx = np.random.permutation(n)
        X_shuf, y_shuf = X[idx], y[idx]
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            Xb, yb = X_shuf[start:end], y_shuf[start:end]
            B = end - start
            residuals = Xb @ model.w + model.b - yb
            model.w -= lr * (2 / B) * Xb.T @ residuals
            model.b -= lr * (2 / B) * residuals.sum()
        train_losses.append(model.mse(X, y))
        val_losses.append(model.mse(X_val, y_val))
    return model, train_losses, val_losses, lrs
sched = LRScheduleComparison()
T = 150
schedule_configs = {
'Constant': lambda lr0, t, T: sched.constant(lr0, t, T),
'Step Decay': lambda lr0, t, T: sched.step_decay(lr0, t, T),
'Exponential': lambda lr0, t, T: sched.exponential(lr0, t, T),
'Cosine': lambda lr0, t, T: sched.cosine(lr0, t, T),
'Warmup+Cosine': lambda lr0, t, T: sched.warmup_cosine(lr0, t, T),
}
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
colors = ['#2563eb', '#16a34a', '#ea580c', '#7c3aed', '#dc2626']
for (name, fn), color in zip(schedule_configs.items(), colors):
    m, tl, vl, lrs = run_with_schedule(
        X_train, y_train, X_test, y_test, fn, lr0=0.1, epochs=T)
    axes[0].semilogy(tl, label=name, linewidth=1.5, color=color)
    axes[1].semilogy(vl, label=name, linewidth=1.5, color=color)
    axes[2].plot(lrs, label=name, linewidth=1.5, color=color)
for ax, title in zip(axes, ["Training Loss", "Validation Loss", "Learning Rate"]):
    ax.set_xlabel("Epoch")
    ax.set_ylabel(title)
    ax.set_title(title + " vs Epoch")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)
plt.suptitle("Learning Rate Schedule Comparison", y=1.02, fontsize=13)
plt.tight_layout()
plt.savefig("lr_schedules.png", dpi=150)
plt.show()
Cyclic Learning Rates (Smith 2017)
Leslie Smith's 2017 paper introduced cyclic learning rates (CLR) - an alternative to monotonically decreasing schedules. The learning rate oscillates between a minimum and maximum value on a triangular or cosine cycle:
def cyclic_lr(lr_min, lr_max, step, cycle_length=20, mode='triangular'):
    """
    Cyclic learning rate (Smith 2017).
    Oscillates between lr_min and lr_max over `cycle_length` steps.
    mode='triangular': linear ramp up from lr_min to lr_max and back down
    mode='cosine': cosine annealing with restarts - each cycle starts at
                   lr_max and decays smoothly toward lr_min
    Why it works: periodically increasing the lr helps escape sharp minima
    (Loshchilov & Hutter 2016, SGDR). The high lr phases explore the landscape;
    the low lr phases settle into local minima.
    """
    cycle = step % cycle_length
    half_cycle = cycle_length // 2
    if mode == 'triangular':
        if cycle < half_cycle:
            return lr_min + (lr_max - lr_min) * cycle / half_cycle
        else:
            return lr_max - (lr_max - lr_min) * (cycle - half_cycle) / half_cycle
    elif mode == 'cosine':
        # 1 + cos(...) goes from 2 down toward 0 within a cycle, so the lr decays
        # from lr_max toward lr_min and jumps back to lr_max at each restart
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * cycle / cycle_length))
    else:
        raise ValueError(f"Unknown mode: {mode}")
# Demonstrate cyclic LR vs cosine schedule
model_clr = LinearModel(d)
clr_losses = []
clr_lrs = []
n = len(y_train)
epochs_clr = 150
for epoch in range(epochs_clr):
    lr = cyclic_lr(lr_min=0.001, lr_max=0.1, step=epoch, cycle_length=30,
                   mode='cosine')
    clr_lrs.append(lr)
    idx = np.random.permutation(n)
    X_shuf, y_shuf = X_train[idx], y_train[idx]
    for start in range(0, n, 64):
        end = min(start + 64, n)
        Xb, yb = X_shuf[start:end], y_shuf[start:end]
        B = end - start
        residuals = Xb @ model_clr.w + model_clr.b - yb
        model_clr.w -= lr * (2 / B) * Xb.T @ residuals
        model_clr.b -= lr * (2 / B) * residuals.sum()
    clr_losses.append(float(np.mean((X_train @ model_clr.w + model_clr.b - y_train)**2)))
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].semilogy(clr_losses, label='CLR (cosine, cycle=30)', color='#16a34a')
axes[0].set_xlabel("Epoch"); axes[0].set_ylabel("Training MSE (log)")
axes[0].set_title("Cyclic LR Training Loss")
axes[0].legend(); axes[0].grid(True, alpha=0.3)
axes[1].plot(clr_lrs, color='#7c3aed')
axes[1].set_xlabel("Epoch"); axes[1].set_ylabel("Learning Rate")
axes[1].set_title("Cyclic Learning Rate Schedule (cosine, cycle=30)")
axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("cyclic_lr.png", dpi=150)
plt.show()
The Linear Scaling Rule (Goyal et al. 2017)
Facebook AI Research's 2017 paper training ResNet-50 on ImageNet in 1 hour introduced the linear scaling rule: when you multiply the batch size by $k$, multiply the learning rate by $k$.
The intuition: with batch size $B$, one epoch makes $n/B$ gradient steps. With batch size $kB$, one epoch makes $n/(kB)$ steps. To make the same expected progress per epoch, each step must make $k$ times as much progress - which requires a $k$ times larger learning rate.
More precisely: with batch size $B$ and learning rate $\eta$, in expectation one step moves:
$$\mathbb{E}[\Delta w] = -\eta \, \nabla_w L(w)$$
With batch size $kB$ and learning rate $k\eta$:
$$\mathbb{E}[\Delta w] = -k\eta \, \nabla_w L(w)$$
Same expected update magnitude per epoch, since there are $k$ times fewer steps per epoch: $\frac{n}{B} \cdot \eta = \frac{n}{kB} \cdot k\eta$. The gradient estimate behind each step also has $k$ times smaller variance (more samples per batch), which offsets the larger step so that the noise accumulated over an epoch stays comparable.
Practical rule: double batch size → double learning rate. But this only works up to a point - at very large batch sizes (> 8192 for ImageNet-scale problems), the generalization gap from sharp minima outweighs the computational gains.
def linear_scaling_experiment(X_train, y_train, X_test, y_test,
                              base_bs=64, base_lr=0.05, epochs=100):
    """
    Demonstrate the linear scaling rule.
    Doubling batch size + doubling learning rate should give similar final loss.
    """
    # Labels assume the default base_bs=64 and base_lr=0.05
    configs = [
        (base_bs, base_lr, 'B=64, lr=0.05 (baseline)'),
        (base_bs * 2, base_lr * 2, 'B=128, lr=0.10 (2x - linear scaling)'),
        (base_bs * 4, base_lr * 4, 'B=256, lr=0.20 (4x - linear scaling)'),
        (base_bs * 4, base_lr, 'B=256, lr=0.05 (4x batch, no lr scale)'),
    ]
    print(f"\nLinear Scaling Rule Experiment:")
    print(f"{'Config':45s} | {'Final Val Loss':>15} | {'Test R²':>8}")
    print("-" * 75)
    for bs, lr, label in configs:
        m, tl, vl = mini_batch_gd(
            X_train, y_train, X_test, y_test,
            lr=lr, epochs=epochs, batch_size=bs
        )
        print(f"{label:45s} | {vl[-1]:>15.4f} | {m.r2(X_test, y_test):>8.4f}")
linear_scaling_experiment(X_train, y_train, X_test, y_test)
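In practice the scaled learning rate is usually reached gradually rather than applied from the first step. The helper below is a minimal sketch of linear scaling combined with a linear warmup from the base rate to the scaled rate; `scaled_lr` and its parameter names are hypothetical, and the warmup length is something to tune rather than a fixed recommendation.

```python
def scaled_lr(base_lr, base_batch, batch, step, warmup_steps=0):
    """
    Linear scaling rule with optional gradual warmup.
    Target lr = base_lr * (batch / base_batch). During warmup the lr ramps
    linearly from base_lr (at step 0) to the target (at step == warmup_steps).
    """
    target = base_lr * batch / base_batch
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr + (target - base_lr) * step / warmup_steps
    return target

# Batch size 64 -> 256 is k = 4, so the target lr is 4 * 0.05
for step in (0, 50, 100, 500):
    print(step, scaled_lr(0.05, 64, 256, step=step, warmup_steps=100))
```

Starting the ramp at `base_lr` rather than at zero means the first updates behave like the small-batch baseline, which is the regime the base learning rate was tuned for.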
Distributed SGD: AllReduce and Ring-AllReduce
At scale (multiple GPUs, multiple machines), each worker computes the gradient on its local data shard and the gradients must be aggregated:
AllReduce operation: Each of $K$ workers computes a local gradient $g_k$. The global gradient is $g = \frac{1}{K} \sum_{k=1}^{K} g_k$. After AllReduce, every worker has the same average gradient.
Ring-AllReduce (Baidu 2017, used in Horovod): Instead of routing all gradients through a central parameter server, workers are arranged in a ring. Each step, each worker sends $1/K$ of its gradient to the next worker and receives $1/K$ from the previous worker. After $2(K-1)$ send-receive cycles, every worker has accumulated the full sum. For a gradient of $G$ bytes, each worker sends $2 \frac{K-1}{K} G$ bytes - nearly optimal (vs. roughly $K \cdot G$ bytes arriving at the central node of a single parameter server), and the per-worker traffic stays roughly constant as the number of workers grows.
def simulate_distributed_sgd(X, y, n_workers=4, local_batch=64,
                             lr=0.05, epochs=50):
    """
    Simulate distributed SGD with AllReduce on a single machine.
    In production: use torch.distributed, Horovod, or DeepSpeed.
    Each worker processes local_batch samples; gradients are averaged.
    Note: each "epoch" here is one synchronized update, not a full data pass.
    """
    n, d_feat = X.shape
    global_model_w = np.zeros(d_feat)
    global_model_b = 0.0
    dist_losses = []
    for epoch in range(epochs):
        # Simulate K workers each computing a gradient on their own mini-batch
        grad_w_accum = np.zeros(d_feat)
        grad_b_accum = 0.0
        for worker_id in range(n_workers):
            # Each worker independently samples a mini-batch from the full data
            # (a simplification: real workers usually read disjoint shards)
            idx = np.random.choice(n, local_batch, replace=False)
            Xb, yb = X[idx], y[idx]
            residuals = Xb @ global_model_w + global_model_b - yb
            local_grad_w = (2 / local_batch) * Xb.T @ residuals
            local_grad_b = (2 / local_batch) * residuals.sum()
            # AllReduce: accumulate local gradients
            grad_w_accum += local_grad_w
            grad_b_accum += local_grad_b
        # Average the accumulated gradients across all workers
        # In ring-AllReduce this happens via peer-to-peer communication
        avg_grad_w = grad_w_accum / n_workers
        avg_grad_b = grad_b_accum / n_workers
        # Single update with the averaged gradient (equivalent to:
        # effective batch size = n_workers * local_batch)
        global_model_w -= lr * avg_grad_w
        global_model_b -= lr * avg_grad_b
        # Track loss with the current global model
        preds = X @ global_model_w + global_model_b
        dist_losses.append(float(np.mean((y - preds)**2)))
    final_preds = X @ global_model_w + global_model_b
    ss_res = np.sum((y - final_preds)**2)
    ss_tot = np.sum((y - np.mean(y))**2)
    r2 = 1 - ss_res / ss_tot
    print(f"\nDistributed SGD ({n_workers} workers, local_B={local_batch}, "
          f"effective_B={n_workers * local_batch}):")
    print(f" Final train R²: {r2:.4f}")
    print(f" Final train MSE: {dist_losses[-1]:.4f}")
    return global_model_w, dist_losses
w_dist, dist_losses = simulate_distributed_sgd(X_train, y_train, n_workers=4)
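The simulation above averages gradients directly and does not model the communication pattern. The ring-AllReduce procedure described earlier in this section can be sketched in a single process as a reduce-scatter phase followed by an all-gather phase. This is an illustrative simulation, not how any particular library implements it; `ring_allreduce` and its chunk bookkeeping are assumptions made for the example.

```python
import numpy as np

def ring_allreduce(local_grads):
    """
    Simulate ring-AllReduce (sum) over K workers in one process.
    local_grads: list of K equal-length 1-D arrays, one per worker.
    Returns (per-worker results, number of chunk sends per worker).
    """
    K = len(local_grads)
    # Each worker splits its own gradient copy into K chunks
    chunks = [np.array_split(g.astype(float), K) for g in local_grads]
    sends_per_worker = 0

    # Phase 1, reduce-scatter: after K-1 steps worker r holds the
    # complete sum of chunk (r + 1) % K.
    for step in range(K - 1):
        # Snapshot outgoing data so all sends in a step happen "simultaneously"
        outgoing = [((r - step) % K, chunks[r][(r - step) % K].copy()) for r in range(K)]
        for r, (c, data) in enumerate(outgoing):
            chunks[(r + 1) % K][c] += data
        sends_per_worker += 1

    # Phase 2, all-gather: completed chunks travel around the ring
    # for K-1 more steps, overwriting the partial sums.
    for step in range(K - 1):
        outgoing = [((r + 1 - step) % K, chunks[r][(r + 1 - step) % K].copy()) for r in range(K)]
        for r, (c, data) in enumerate(outgoing):
            chunks[(r + 1) % K][c] = data
        sends_per_worker += 1

    return [np.concatenate(ch) for ch in chunks], sends_per_worker

rng = np.random.default_rng(0)
K, G = 4, 10
grads = [rng.normal(size=G) for _ in range(K)]
results, sends = ring_allreduce(grads)
expected = np.sum(grads, axis=0)
print(all(np.allclose(r, expected) for r in results), sends)
```

Each worker sends $2(K-1)$ chunks of about $G/K$ elements, which is where the $2 \frac{K-1}{K} G$ per-worker figure comes from. Dividing the result by $K$ turns the sum into the average gradient.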
Practical Batch Size Selection
| Batch Size | Gradient Variance | Regularization | GPU Efficiency | Use When |
|---|---|---|---|---|
| 1 (pure SGD) | Maximum | Strongest | Very low | Online learning, tiny datasets |
| 16–64 | High | Strong | Moderate | Standard settings, limited GPU memory |
| 128–512 | Medium | Moderate | High | Production deep learning |
| 1024–4096 | Low | Weak | Very high | Distributed training with linear scaling rule |
| Full batch | None | None | N/A | Small datasets, convex problems |
:::warning Large batch generalization gap
Training with very large batches (> 4096 for most architectures) tends to find sharp minima with poor generalization (Keskar et al. 2016). If you must use large batches (for hardware efficiency), compensate with: (1) the linear scaling rule for learning rate; (2) longer warmup period; (3) explicit data augmentation; (4) label smoothing; (5) Sharpness-Aware Minimization (SAM). The generalization gap with batch sizes > 8192 is well-documented and requires active mitigation.
:::
Common Mistakes
:::danger Not shuffling data before each epoch
# WRONG - biased gradients if data is sorted by class or time
for start in range(0, n, batch_size):
    Xb = X[start:start + batch_size]
# CORRECT - shuffle before each epoch
idx = np.random.permutation(n)
X_shuf = X[idx]
for start in range(0, n, batch_size):
    Xb = X_shuf[start:start + batch_size]
Without shuffling, consecutive batches contain correlated or class-homogeneous samples. Early batches may only see one class, causing dramatic gradient bias. This is especially dangerous for imbalanced datasets sorted by label.
:::
:::danger Applying fit_transform to test data (data leakage)
# WRONG - refits the scaler on test data, so test statistics shape the test features
scaler = StandardScaler()
X_train_sc = scaler.fit_transform(X_train)
X_test_sc = scaler.fit_transform(X_test)  # BUG: should use .transform()
# CORRECT
scaler = StandardScaler()
X_train_sc = scaler.fit_transform(X_train)  # fit on train only
X_test_sc = scaler.transform(X_test)        # transform using train statistics
The scaler must be fitted on training data only. Using test statistics during test preprocessing is a subtle but serious form of data leakage, and it also means the model sees test features on a different scale than the one it was trained on.
:::
:::warning Comparing models trained with different batch sizes using the same epoch count
When you halve the batch size, each epoch makes twice as many gradient updates. A fair comparison requires the same number of updates, not the same number of epochs. Use n_updates = n_epochs * (n / batch_size) to normalize when comparing different batch sizes.
:::
YouTube Resources
| Resource | Channel | Why Watch |
|---|---|---|
| Stochastic Gradient Descent Clearly Explained | StatQuest with Josh Starmer | Best clear intro to SGD vs batch GD with visual examples |
| Large Scale Distributed Deep Learning | Google Brain Talk | Jeff Dean on AllReduce, ring-AllReduce, and distributed training at scale |
| SGDR: Cosine Annealing with Warm Restarts | Fast.ai | Loshchilov's cosine restart schedule and why it helps |
| Training ImageNet in 1 Hour (Goyal et al.) | Facebook AI | The linear scaling rule and batch size vs learning rate in practice |
| Why Large Batch Training Hurts Generalization | ICLR talks | Keskar et al. sharp vs flat minima - the foundational result |
Interview Q&A
Q1: Why does SGD generalize better than batch gradient descent in many settings?
SGD introduces gradient noise whose variance scales as $\sigma^2 / B$ for batch size $B$. This noise acts as implicit regularization: it steers optimization away from sharp minima (high Hessian eigenvalues - sensitive to distribution shift) toward flat minima (low Hessian eigenvalues - robust to distribution shift). Keskar et al. (2016) demonstrated this empirically on neural networks: large-batch training converges to sharper minima with worse test accuracy, even when training accuracy is similar. The mechanism: noise perturbs the model out of sharp narrow basins before it converges there, favoring wide flat basins which have lower test loss under distribution shift.
Q2: How do you choose batch size in practice?
Start with 64–256 - a range that usually keeps the GPU well utilized on most hardware. Key considerations: (1) If GPU memory is the bottleneck, use the largest batch that fits, with gradient accumulation for an effective larger batch. (2) If generalization is suffering, reduce batch size or add explicit regularization. (3) For distributed training, use the linear scaling rule: batch $k$ times larger → learning rate $k$ times larger, with linear warmup over the first ~5 epochs. (4) Monitor both train and validation loss - if they diverge with large batches, you are in the generalization gap regime.
Q3: What is the relationship between batch size, learning rate, and generalization?
The key identity: $\mathrm{Var}[\hat{g}_B] = \sigma^2 / B$, where $\hat{g}_B$ is the mini-batch gradient and $\sigma^2$ is the single-sample gradient variance. Larger batch → lower variance → more signal, less noise per step. The linear scaling rule ($B \to kB$, $\eta \to k\eta$) keeps the expected progress and the accumulated noise per epoch roughly unchanged by compensating the lower variance with a larger step. But beyond a critical batch size, the noise level falls below the threshold needed to escape sharp minima, causing a generalization gap. This threshold is problem-dependent; figures in the range 8192–65536 are often quoted for large deep learning tasks.
Q4: Why must you shuffle data before each epoch in mini-batch SGD?
Without shuffling, the gradient estimates have a fixed correlation structure. For sorted data (by label, time, or any feature), early batches systematically over-represent one class/regime and gradients are heavily biased. The model oscillates between fitting different biased subsets rather than converging. Shuffling ensures each mini-batch is an approximately iid sample from the training distribution, making each gradient an unbiased estimate of the full gradient in expectation - the theoretical requirement for SGD convergence. PyTorch's DataLoader with shuffle=True handles this automatically. For datasets too large to shuffle globally, shuffle within a large buffer.
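Buffered shuffling for streams that cannot be permuted in memory can be sketched as follows. This is one simple variant under stated assumptions: `shuffle_buffer` is a hypothetical helper, and the quality of the shuffle depends on the buffer being large relative to any ordering structure in the stream.

```python
import random
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def shuffle_buffer(stream: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """
    Approximate shuffling for data too large to permute globally.
    Holds `buffer_size` items; each new item replaces a randomly chosen
    buffered item, and the replaced item is emitted.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)      # fill phase: nothing is emitted yet
            continue
        j = rng.randrange(buffer_size)
        yield buffer[j]
        buffer[j] = item
    # Drain the remaining items in random order
    rng.shuffle(buffer)
    yield from buffer

shuffled = list(shuffle_buffer(range(20), buffer_size=5))
print(shuffled)
```

An item can only be emitted after it has entered the buffer, so with a small buffer the output still follows the input order loosely. If the data is sorted by label, the buffer needs to span several label blocks for the batches to look approximately iid.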
Q5: What is gradient accumulation and when do you use it?
Gradient accumulation simulates a larger effective batch by accumulating gradients from $k$ small forward and backward passes before applying the weight update. With micro-batch gradients $g_1, \dots, g_k$ of batch size $B$ each, the total gradient applied is $\frac{1}{k} \sum_{j=1}^{k} g_j$, equivalent to a single gradient step with batch $kB$. Use it when: (1) the ideal batch size (for stability or linear scaling rule) does not fit in GPU memory; (2) training large models (LLMs, diffusion models) where even a batch of 1 barely fits in memory; (3) maintaining reproducibility when changing hardware with different memory capacities. Key implementation detail: divide the loss by the number of accumulation steps before calling .backward() to correctly scale the gradient.
Q6: Explain distributed SGD and the AllReduce operation.
In data-parallel distributed training, each worker (GPU) holds a copy of the model and processes a different shard of the data. Each worker computes a local gradient on its shard. AllReduce aggregates these local gradients by summing them across all workers and broadcasting the average back to each worker. Every worker then applies the same averaged gradient, keeping all model copies synchronized.
Ring-AllReduce (Baidu 2017, used in Horovod, PyTorch DistributedDataParallel) is a standard implementation. Workers are arranged in a logical ring. In $2(K-1)$ communication steps (where $K$ = number of workers), each worker accumulates the full gradient sum using only point-to-point communication. Each worker sends $2 \frac{K-1}{K}$ times the gradient size over the network - nearly optimal. It avoids the bandwidth bottleneck of a central parameter server.
