

GPU Memory Management

The OOM That Should Not Have Happened

The training run had been going for 6 hours when it died: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.50 GiB (GPU 0; 39.59 GiB total capacity; 36.83 GiB already allocated; 1.47 GiB free; 37.50 GiB reserved by PyTorch using the caching allocator).

The engineer was confused. They were training GPT-2 Large - 774 million parameters - on an A100 with 40 GB of VRAM. GPT-2 Large's weights in float16 take 1.55 GB. The model should fit comfortably. Why was 36.83 GB already allocated?

The answer required understanding what actually gets stored in GPU memory during training. Model weights are only a small fraction of the total. The gradients double the footprint. The optimizer states (Adam's first and second moment buffers, plus fp32 master weights) add another 12 bytes per parameter, six times the fp16 weights themselves. And activations, the intermediate values saved during the forward pass for use in the backward pass, can dwarf all of these combined for large batch sizes and long sequences.

The breakdown for this 774M parameter model in mixed precision training with batch size 8, sequence length 1024:

  • Model weights (fp16): 1.55 GB
  • Gradients (fp16): 1.55 GB
  • Optimizer states (fp32): 9.3 GB (Adam: first moment + second moment + master weights)
  • Activations (fp16): ~28 GB (the true culprit)

Total: ~40 GB. They had hit the exact VRAM limit, and PyTorch's memory allocator overhead pushed them over.

Understanding why activations dominate, what tools exist to reduce them, and how to systematically attack the memory budget is the core skill this lesson teaches.


Why Memory Management Is Critical

VRAM is the hardest constraint in deep learning. You can scale compute by adding GPU time. You cannot trivially scale VRAM - when your model and its training state exceed the VRAM of a single GPU, you either need fundamentally different algorithms (gradient checkpointing, model parallelism) or different hardware. Making the wrong decision here adds weeks of debugging and migration work.

The stakes are particularly high because:

  1. OOM errors appear at runtime, often deep into a training run
  2. Many teams plan batch sizes and model sizes in parameter counts without computing actual memory requirements
  3. Memory requirements scale non-linearly with sequence length for transformer attention (quadratic in sequence length before FlashAttention)

What Lives in GPU Memory During Training

For a model with $P$ parameters, mixed precision training requires:

| Component | Size | Notes |
|---|---|---|
| Model weights (fp16/bf16) | $P \times 2$ bytes | Active computation |
| Gradients (fp16/bf16) | $P \times 2$ bytes | Accumulated during backward |
| Master weights (fp32) | $P \times 4$ bytes | For stable optimizer updates |
| Adam first moment (fp32) | $P \times 4$ bytes | $m_t$, gradient moving average |
| Adam second moment (fp32) | $P \times 4$ bytes | $v_t$, squared-gradient moving average |
| Activations | Variable | Saved for backward pass |
| CUDA allocator overhead | 5–15% | Fragmentation, buffer pools |

For a 7B parameter model:

  • Weights: 14 GB (fp16)
  • Gradients: 14 GB
  • Optimizer states: 84 GB
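  • Total before activations: 112 GB; the 84 GB of optimizer states alone need at least 3 × 40 GB GPUs

This is why training a 7B+ model with standard mixed-precision Adam does not fit on a single GPU, even one with 80 GB of VRAM, unless the training state is sharded or offloaded.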
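The per-parameter byte counts listed above make this arithmetic easy to script. A minimal sketch, assuming mixed-precision Adam with 16-bit weights and gradients plus the three fp32 optimizer buffers; activations and allocator overhead are deliberately left out, and the helper names `model_state_gb` and `min_gpus` are hypothetical:

```python
import math

# Bytes per parameter for mixed-precision Adam (per the component list above).
BYTES_PER_PARAM = {
    "weights_fp16": 2,
    "gradients_fp16": 2,
    "master_weights_fp32": 4,
    "adam_m_fp32": 4,
    "adam_v_fp32": 4,
}


def model_state_gb(n_params: int) -> dict:
    """Static training memory in GB (1e9 bytes): no activations, no allocator overhead."""
    weights = n_params * BYTES_PER_PARAM["weights_fp16"] / 1e9
    grads = n_params * BYTES_PER_PARAM["gradients_fp16"] / 1e9
    optimizer = n_params * (
        BYTES_PER_PARAM["master_weights_fp32"]
        + BYTES_PER_PARAM["adam_m_fp32"]
        + BYTES_PER_PARAM["adam_v_fp32"]
    ) / 1e9
    return {
        "weights_gb": weights,
        "gradients_gb": grads,
        "optimizer_gb": optimizer,
        "total_gb": weights + grads + optimizer,
    }


def min_gpus(required_gb: float, gpu_gb: float) -> int:
    """Smallest number of GPUs whose combined VRAM covers required_gb."""
    return math.ceil(required_gb / gpu_gb)


state = model_state_gb(7_000_000_000)
print(state)
print("40 GB GPUs needed for optimizer states alone:", min_gpus(state["optimizer_gb"], 40))
```

The same function reproduces the 9.3 GB optimizer figure from the GPT-2 Large example when called with 774M parameters.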

```python
def estimate_training_memory_gb(
    n_parameters: int,
    sequence_length: int,
    batch_size: int,
    n_layers: int,
    hidden_dim: int,
    n_heads: int,
    dtype: str = "fp16",  # or "fp32", "bf16"
) -> dict:
    """
    Estimate peak VRAM requirements for transformer model training.
    Rough: counts one attention score matrix and the FFN intermediate per layer,
    so it undercounts everything else a real layer saves for backward.
    """
    bytes_per_param = {"fp32": 4, "fp16": 2, "bf16": 2}[dtype]
    fp32_bytes = 4

    # Model weights
    weights_gb = (n_parameters * bytes_per_param) / 1e9

    # Gradients (same dtype as weights)
    grads_gb = weights_gb

    # Adam optimizer states (always fp32)
    # master weights + first moment + second moment
    optimizer_gb = (n_parameters * fp32_bytes * 3) / 1e9

    # Activation memory per layer (rough estimate for transformer)
    # Saved activations for attention: batch * heads * seq * seq
    # + feedforward activations: batch * seq * 4 * hidden
    attn_activations = batch_size * n_heads * sequence_length * sequence_length * 2  # fp16
    ff_activations = batch_size * sequence_length * 4 * hidden_dim * 2  # fp16
    activation_per_layer_gb = (attn_activations + ff_activations) / 1e9
    total_activation_gb = activation_per_layer_gb * n_layers

    # CUDA allocator overhead (~10%)
    subtotal = weights_gb + grads_gb + optimizer_gb + total_activation_gb
    overhead_gb = subtotal * 0.10

    return {
        "weights_gb": round(weights_gb, 2),
        "gradients_gb": round(grads_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "activations_gb": round(total_activation_gb, 2),
        "overhead_gb": round(overhead_gb, 2),
        "total_gb": round(subtotal + overhead_gb, 2),
        "note": "Activations dominate for large batch sizes / long sequences",
    }


# GPT-2 Large: 774M params, 36 layers, hidden=1280, heads=20
plan = estimate_training_memory_gb(
    n_parameters=774_000_000,
    sequence_length=1024,
    batch_size=8,
    n_layers=36,
    hidden_dim=1280,
    n_heads=20,
)
for key, value in plan.items():
    print(f"  {key}: {value}")
```

Mixed Precision Training (FP16/BF16)

Mixed precision training (Micikevicius et al., 2018) runs the forward and backward computation in FP16 or BF16 while keeping FP32 master weights for optimizer updates. Storing activations (and the working copy of the weights) in 16-bit roughly halves their memory, while the FP32 master copy keeps training numerically stable.

```python
import torch
import torch.nn as nn


def train_with_mixed_precision(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader,
    n_epochs: int = 3,
    use_bf16: bool = True,  # prefer bf16 for stability
):
    """
    Mixed precision training using torch.autocast and a GradScaler.
    BF16 is recommended over FP16 for A100/H100 (same Tensor Core speed, more stable).
    """
    device = torch.device("cuda")
    model = model.to(device)

    # GradScaler prevents gradient underflow in FP16. BF16 has FP32's exponent range,
    # so the scaler is disabled there: scale()/step() pass through, get_scale() is 1.0.
    # (Newer PyTorch releases spell this torch.amp.GradScaler("cuda", ...).)
    scaler = torch.cuda.amp.GradScaler(enabled=not use_bf16)

    dtype = torch.bfloat16 if use_bf16 else torch.float16

    for epoch in range(n_epochs):
        model.train()
        for batch_idx, (inputs, labels) in enumerate(data_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad(set_to_none=True)

            # Forward pass in reduced precision (torch.autocast takes device_type;
            # the older torch.cuda.amp.autocast does not)
            with torch.autocast(device_type="cuda", dtype=dtype):
                outputs = model(inputs)
                loss = nn.functional.cross_entropy(outputs, labels)

            # Backward pass (scaler multiplies the loss by the scale factor in FP16)
            scaler.scale(loss).backward()

            # Unscale gradients before clipping so the norm is computed on true values
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Optimizer step (skipped by the scaler if FP16 gradients contain inf/NaN)
            scaler.step(optimizer)
            scaler.update()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
                print(f"  Scale factor: {scaler.get_scale():.0f}")
                print(f"  GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
```

Activation Checkpointing (Gradient Checkpointing)

Activation checkpointing (Chen et al., 2016) is one of the most effective memory-reduction techniques for training. It trades compute for memory: instead of saving all activations during the forward pass, it saves only a subset (checkpoints) and recomputes intermediate activations during the backward pass as needed.

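  • Without checkpointing: every layer's activations are saved, so activation memory is O(n_layers).
  • With checkpointing: only the checkpoints are kept (O(√n_layers) of them with optimal placement) and the remaining activations are recomputed during the backward pass.
  • Memory reduction: typically 4–8× for deep networks.
  • Compute overhead: roughly 30–40%, about one extra forward pass per training step.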

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim: int, n_heads: int, ff_dim: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ff = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        return x


class TransformerWithCheckpointing(nn.Module):
    def __init__(
        self,
        n_layers: int,
        hidden_dim: int,
        n_heads: int,
        use_checkpoint: bool = True,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_dim, n_heads, hidden_dim * 4)
            for _ in range(n_layers)
        ])
        self.use_checkpoint = use_checkpoint

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Gradient checkpointing: keep only this layer's input and
                # recompute its internal activations during the backward pass
                x = checkpoint.checkpoint(
                    layer,
                    x,
                    use_reentrant=False,  # newer, preferred API
                )
            else:
                x = layer(x)
        return x


def compare_memory_usage():
    """Measure peak VRAM with and without gradient checkpointing."""
    device = torch.device("cuda")
    n_layers, hidden_dim, n_heads = 24, 1024, 16
    batch_size, seq_len = 8, 512

    for use_ckpt in [False, True]:
        # Release cached blocks, then start the peak counter from current usage
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        model = TransformerWithCheckpointing(
            n_layers=n_layers,
            hidden_dim=hidden_dim,
            n_heads=n_heads,
            use_checkpoint=use_ckpt,
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        x = torch.randn(batch_size, seq_len, hidden_dim, device=device)
        labels = torch.randn(batch_size, seq_len, hidden_dim, device=device)

        optimizer.zero_grad()
        out = model(x)
        loss = nn.functional.mse_loss(out, labels)
        loss.backward()
        optimizer.step()

        peak_mb = torch.cuda.max_memory_allocated() / 1e6
        print(f"Gradient checkpointing={use_ckpt}: Peak VRAM = {peak_mb:.0f} MB")

        # Drop every reference so this run's model, optimizer state and tensors
        # do not inflate the next run's peak measurement
        del model, optimizer, x, labels, out, loss
```
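The O(√n_layers) figure comes from splitting the network into segments: only segment inputs are kept during the forward pass, and during backward one segment at a time is recomputed with all of its activations held. A toy cost model makes the trade-off concrete. It assumes every layer saves one equal-sized unit of activations and that a stored segment input costs `boundary_units` units; both are simplifications, and the helper names are hypothetical. PyTorch's `torch.utils.checkpoint.checkpoint_sequential` applies this segment scheme to sequential models.

```python
import math


def peak_activation_units(n_layers: int, n_segments: int, boundary_units: float = 1.0) -> float:
    """Toy peak-memory model for segment checkpointing (illustration only).

    Assumptions: each layer saves 1 unit of activations when run normally, and
    each stored segment input costs `boundary_units` units. During backward, all
    segment inputs stay stored while one segment is recomputed in full.
    """
    if not 1 <= n_segments <= n_layers:
        raise ValueError("n_segments must be in [1, n_layers]")
    layers_per_segment = math.ceil(n_layers / n_segments)
    return n_segments * boundary_units + layers_per_segment


def best_segment_count(n_layers: int, boundary_units: float = 1.0) -> int:
    """Exhaustively pick the segment count with the lowest toy peak."""
    return min(
        range(1, n_layers + 1),
        key=lambda k: peak_activation_units(n_layers, k, boundary_units),
    )


n = 36
k = best_segment_count(n)
print(f"no checkpointing: {n} units")
print(f"best segment count: {k}, peak: {peak_activation_units(n, k)} units")
```

With equal-sized units the optimum lands at k = √36 = 6 segments. In real models a layer's input is much smaller than everything the layer saves internally; setting `boundary_units=0.1` models that, and one checkpoint per layer (as in `TransformerWithCheckpointing`) then also lands far below the no-checkpointing baseline.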

ZeRO Optimizer Stages

ZeRO (Zero Redundancy Optimizer, Rajbhandari et al., 2020) is a memory optimization technique for distributed data-parallel training that partitions training state (optimizer states, and in later stages gradients and parameters) across GPUs rather than replicating it on every GPU.

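Stage 1 (ZeRO-1): Partition optimizer states (fp32 master weights and Adam moments) across GPUs. Each GPU stores 1/N of the optimizer states, so their memory drops N×; total model-state memory per GPU approaches a 4× reduction (16 → about 4 bytes per parameter) as N grows.

Stage 2 (ZeRO-2): Partition gradients + optimizer states. Total model-state memory approaches an 8× reduction vs standard DDP (16 → about 2 bytes per parameter).

Stage 3 (ZeRO-3): Partition parameters + gradients + optimizer states. Each GPU stores only 1/N of the parameters, gradients and optimizer states, so model-state memory shrinks linearly with the number of GPUs. Tradeoff: extra communication, because parameters must be all-gathered for both the forward and the backward pass.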

```python
import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def setup_fsdp_training(
    model: nn.Module,
    layer_class,  # e.g., TransformerLayer - the class to wrap per-layer
    use_bf16: bool = True,
):
    """
    Configure FSDP (PyTorch's ZeRO-3 equivalent).
    Partitions parameters, gradients, and optimizer states across GPUs.
    Meant to run under torchrun, which sets LOCAL_RANK for each process.
    """
    # Initialize distributed process group (one process per GPU)
    dist.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Mixed precision configuration. With FP16 you would also need
    # torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler.
    low_precision = torch.bfloat16 if use_bf16 else torch.float16
    mixed_precision_policy = MixedPrecision(
        param_dtype=low_precision,
        reduce_dtype=low_precision,
        buffer_dtype=low_precision,
    )

    # Auto wrap: each layer_class instance gets its own FSDP unit
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={layer_class},
    )

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        # sharding_strategy controls ZeRO stage:
        # FULL_SHARD = ZeRO-3 (params + grads + optimizer)
        # SHARD_GRAD_OP = ZeRO-2 (grads + optimizer only)
        # NO_SHARD = DDP (replicates everything)
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )

    # Create the optimizer AFTER wrapping, from fsdp_model.parameters(), so its
    # states are allocated only for this rank's shard.
    return fsdp_model
```

ZeRO Stage Comparison

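| Strategy | Model-state memory per GPU (7B params, 8 GPUs, excluding activations) | Communication |
|---|---|---|
| Naive DDP | 112 GB | All-reduce gradients |
| ZeRO-1 | 38.5 GB | All-reduce gradients, all-gather updated parameters |
| ZeRO-2 | ~26 GB | Reduce-scatter gradients, all-gather updated parameters |
| ZeRO-3 | 14 GB | All-gather parameters (forward and backward), reduce-scatter gradients |
| ZeRO-3 + CPU offload | ~3 GB GPU | + CPU-GPU memory transfers |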
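The memory column follows from the ZeRO paper's per-parameter accounting: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients and 12 bytes of fp32 optimizer state, with each stage dividing more of those terms by the GPU count. A sketch of that arithmetic, assuming mixed-precision Adam and ignoring activations, buffers and fragmentation; `stage=0` stands for plain DDP, and the function name is hypothetical:

```python
def zero_model_state_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory (GB) for mixed-precision Adam under ZeRO.

    2 bytes fp16 params + 2 bytes fp16 grads + 12 bytes fp32 optimizer state
    (master weights + m + v) per parameter. Activations are NOT included.
    """
    params, grads, optim = 2 * n_params, 2 * n_params, 12 * n_params
    if stage == 0:    # plain DDP: everything replicated on every GPU
        total = params + grads + optim
    elif stage == 1:  # shard optimizer states
        total = params + grads + optim / n_gpus
    elif stage == 2:  # shard optimizer states + gradients
        total = params + (grads + optim) / n_gpus
    elif stage == 3:  # shard parameters + gradients + optimizer states
        total = (params + grads + optim) / n_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total / 1e9


for stage in range(4):
    label = "DDP" if stage == 0 else f"ZeRO-{stage}"
    print(f"{label}: {zero_model_state_gb(7e9, 8, stage):.2f} GB per GPU")
```

Activation memory comes on top of these figures and is not reduced by ZeRO's partitioning.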

CPU Offloading

For maximum memory efficiency, ZeRO-3 can offload optimizer states and even parameters to CPU RAM. The tradeoff: CPU-GPU transfers are slow (PCIe delivers on the order of 20–30 GB/s, versus over 1 TB/s of HBM bandwidth on an A100). CPU offloading pays off only when the model would not fit otherwise, or when the freed VRAM enables a batch size large enough to more than compensate for the transfer overhead.

```python
# DeepSpeed ZeRO-3 with CPU offloading configuration
# The "auto" values are filled in by the Hugging Face Transformers/Accelerate
# DeepSpeed integration; with plain deepspeed.initialize, use concrete numbers.
deepspeed_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,  # use pinned (non-pageable) CPU memory for faster transfers
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True,
        },
        "overlap_comm": True,  # overlap communication with computation
        "contiguous_gradients": True,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
    },
    "bf16": {
        "enabled": True,
    },
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,  # effective batch size = 4 * 8 * n_gpus
}
```
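To put the PCIe cost into numbers: with optimizer offload, each step ships 16-bit gradients to the CPU and 16-bit updated parameters back. A back-of-the-envelope sketch, assuming 2 bytes per parameter in each direction, an effective 25 GB/s link (PCIe 4.0 x16 peaks near 32 GB/s per direction) and no overlap with compute; parameter offload (`offload_param`) adds further traffic on every forward and backward pass. The helper name is hypothetical:

```python
def offload_transfer_seconds(n_params: float, pcie_gb_per_s: float = 25.0) -> float:
    """Rough per-step PCIe time for ZeRO optimizer offload (illustrative only).

    Assumes fp16 gradients go GPU -> CPU and fp16 updated parameters go
    CPU -> GPU (2 bytes each per parameter), with no overlap with compute.
    """
    bytes_moved = n_params * 2 + n_params * 2  # gradients down + parameters up
    return bytes_moved / (pcie_gb_per_s * 1e9)


print(f"7B params: {offload_transfer_seconds(7e9):.2f} s of PCIe traffic per step")
```

About 1.1 s per step for a 7B model under these assumptions, which is why `overlap_comm` and larger per-step batches matter when offloading.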

Debugging OOM Errors

When you hit a CUDA OOM, the default error message tells you how much memory was free but not what allocated the memory. These tools help.

```python
import gc

import torch


def debug_oom():
    """
    Tools for diagnosing CUDA OOM errors.
    """
    device = torch.cuda.current_device()

    # 1. Memory summary: allocated/reserved memory broken down by pool and size
    print("=== Memory Summary ===")
    print(torch.cuda.memory_summary(device=device, abbreviated=False))

    # 2. Identify largest tensors still referenced from Python
    print("\n=== Largest Tensors in GPU Memory ===")
    tensor_sizes = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                size_mb = obj.nelement() * obj.element_size() / 1e6
                tensor_sizes.append((size_mb, tuple(obj.shape), obj.dtype))
        except Exception:
            pass

    # Sort on size only: torch dtypes are not orderable, so ties would raise
    tensor_sizes.sort(key=lambda item: item[0], reverse=True)
    for size_mb, shape, dtype in tensor_sizes[:20]:
        print(f"  {size_mb:.1f} MB  {shape}  {dtype}")


def record_memory_snapshot(run_fn, path="memory_snapshot.pickle"):
    """3. Record allocation history while run_fn executes (PyTorch 2.1+, private API).
    run_fn is any zero-argument callable wrapping the code you want to trace."""
    torch.cuda.memory._record_memory_history(max_entries=100_000)
    try:
        run_fn()
    finally:
        # Load the pickle at https://pytorch.org/memory_viz to visualize it
        torch.cuda.memory._dump_snapshot(path)
        torch.cuda.memory._record_memory_history(enabled=None)  # stop recording


def _train_step(model, batch, optimizer, scaler):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(batch["input"])
        loss = output.loss

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


def safe_train_step(model, batch, optimizer, scaler):
    """Training step with OOM recovery (skips the batch on OOM)."""
    try:
        return _train_step(model, batch, optimizer, scaler)
    except torch.cuda.OutOfMemoryError:
        pass  # recover below: the traceback keeps the failed step's tensors alive here

    print("OOM! Clearing cache and skipping batch")
    optimizer.zero_grad(set_to_none=True)
    torch.cuda.empty_cache()
    return None
```

Gradient Accumulation

When batch size is limited by VRAM, gradient accumulation simulates a larger effective batch by accumulating gradients over multiple small batches before stepping the optimizer.

```python
import torch
import torch.nn as nn


def train_with_gradient_accumulation(
    model,
    data_loader,
    optimizer,
    accumulation_steps: int = 8,
):
    """
    Gradient accumulation: effective batch = batch_size * accumulation_steps.
    Allows training with large effective batch sizes on limited VRAM.
    """
    model.train()
    optimizer.zero_grad(set_to_none=True)

    for step, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.cuda()
        labels = labels.cuda()

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            output = model(inputs)
            # Divide loss by accumulation_steps so the summed gradients match
            # the gradient of the mean loss over the effective batch
            loss = nn.functional.cross_entropy(output, labels) / accumulation_steps

        loss.backward()  # gradients add up in each parameter's .grad

        if (step + 1) % accumulation_steps == 0:
            # Gradient clip and step after accumulating
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            # Logs the last micro-batch's loss, rescaled to its unscaled value
            print(f"Step {(step + 1) // accumulation_steps}: "
                  f"loss = {loss.item() * accumulation_steps:.4f}")
    # Gradients from a final incomplete group (when the number of batches is not
    # divisible by accumulation_steps) are never applied.
```
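Why divide the loss by `accumulation_steps`? With a mean-reduced loss and equal-sized micro-batches, the average of the micro-batch gradients equals the full-batch gradient, and dividing each micro-batch loss by the step count produces exactly that average when the gradients are summed. A dependency-free sketch on a one-parameter least-squares model (a stand-in for a real network; the helper names are hypothetical) checks this numerically:

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """d/dw of mean((w * x - y)^2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list[float], ys: list[float], accumulation_steps: int) -> float:
    """Sum of micro-batch gradients, each scaled by 1/accumulation_steps."""
    if len(xs) % accumulation_steps:
        raise ValueError("this sketch assumes equal-sized micro-batches")
    micro = len(xs) // accumulation_steps
    total = 0.0
    for i in range(accumulation_steps):
        lo, hi = i * micro, (i + 1) * micro
        # Scaling each micro-batch gradient mirrors dividing the loss
        # by accumulation_steps before loss.backward()
        total += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps
    return total


xs = [float(i) for i in range(8)]
ys = [2.0 * x + 1.0 for x in xs]
print(mse_grad(0.5, xs, ys), accumulated_grad(0.5, xs, ys, accumulation_steps=4))
```

Without the division the accumulated gradient would be `accumulation_steps` times too large. With unequal micro-batches (or token-level losses over padded sequences) the match is only approximate.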

Production Engineering Notes

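Set a memory fraction limit when processes share a GPU. If several processes run on one GPU (for example under MPS or plain time-slicing), cap each with torch.cuda.set_per_process_memory_fraction(fraction), choosing fractions that sum to less than 1, so one process cannot grab all memory and push the others into OOM. MIG instances already have isolated memory, so the cap is unnecessary there.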

Profile memory over the full training step, not just forward pass. Peak memory usually occurs during the backward pass, when saved forward activations and newly computed gradients are live simultaneously, or at the first optimizer.step(), when Adam lazily allocates its moment buffers. Engineers who profile only the forward pass can underestimate peak memory by 2–3×.

Use gradient checkpointing selectively. Checkpoint only the most memory-intensive layers (attention layers). Leaving feedforward layers un-checkpointed avoids most of the compute overhead while still recovering most of the memory.


Common Mistakes

:::danger Forgetting that Adam's optimizer states dwarf the model weights For a 1B parameter model, weights = 2 GB (fp16), but Adam optimizer states = 12 GB (fp32 master weights + first + second moments). Add 2 GB of gradients and training needs 16 GB, 8× the model's fp16 footprint, before a single activation is stored. Engineers who plan memory from model size alone are routinely surprised by this. Always include optimizer state estimates in your memory budget. :::

:::warning Using large batch sizes to "maximize GPU utilization" without checking memory impact Larger batches increase GPU utilization and throughput - up to a point. Activation memory grows linearly with batch size and, for standard attention, quadratically with sequence length (O(batch × heads × seq²)). Doubling the batch size doubles activation memory, which at long sequence lengths is already the largest term in the budget. Use gradient accumulation instead of increasing batch size if you are near VRAM limits. :::

:::tip The sequence length budget for attention is quadratic For standard attention, memory for the attention scores grows as O(batch × heads × seq²). With sequence_length=2048, that memory is 4× what it is at sequence_length=1024. FlashAttention reduces this to O(seq) by never materializing the full attention matrix. If you are training on long documents, FlashAttention is not optional - it is the difference between fitting in VRAM or not. :::
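A quick way to see the quadratic term is to size a single materialized score matrix. This sketch counts one fp16 score tensor per layer for the GPT-2 Large shapes used earlier (batch 8, 20 heads, 36 layers); real standard-attention implementations may save more than one tensor of this shape (for example the softmax output and a dropout mask), so treat it as a lower bound. The function name is hypothetical:

```python
def attention_scores_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized (batch, heads, seq, seq) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


GIB = 2**30
for seq_len in (1024, 2048):
    per_layer = attention_scores_bytes(batch=8, heads=20, seq_len=seq_len)
    print(f"seq={seq_len}: {per_layer / GIB:.4f} GiB per layer, "
          f"{36 * per_layer / GIB:.2f} GiB over 36 layers")
```

At seq=2048 the score tensors alone come to 45 GiB across 36 layers, more than an A100 40GB holds, which is the gap FlashAttention closes.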


Interview Questions

Q1: Explain activation checkpointing. When do you use it and what is the tradeoff?

During backpropagation, PyTorch needs the activations saved during the forward pass to compute gradients. By default, all activations are kept in memory until the backward pass processes them - this is O(n_layers) memory. Activation checkpointing (or gradient checkpointing) discards all but the checkpointed activations during the forward pass and recomputes the rest on demand during the backward pass, starting from the saved checkpoints. Memory reduces to O(√n_layers) for optimal checkpoint placement, at the cost of 30–40% more compute (the recomputed layers effectively run their forward pass twice). Use it when: training deep models (12+ layers) at batch sizes where activation memory dominates, or when you need to fit a model that marginally exceeds VRAM capacity.

Q2: What are ZeRO stages 1, 2, and 3? When would you choose each?

ZeRO partitions training state across GPUs instead of replicating it. Stage 1: partition optimizer states (fp32 master weights and Adam moments) across GPUs; per-GPU model-state memory approaches a 4× reduction vs DDP as the GPU count grows, with communication volume comparable to DDP. Use for models that almost fit in GPU memory. Stage 2: partition optimizer states + gradients, approaching an 8× reduction, still with DDP-level communication volume. Use when Stage 1 is insufficient. Stage 3: partition everything including model parameters, so memory scales as 1/N where N is the number of GPUs. Use for models that do not fit on a single GPU even with Stages 1/2 (>10B parameters on 40 GB GPUs). Stage 3 adds parameter all-gathers during both the forward and backward passes, which increases communication, so it is only worth it when memory is the hard constraint.

Q3: You are training a model and hit OOM at the exact same step every run. What does this tell you, and how do you debug it?

An OOM that occurs at the same step every run (rather than randomly) means something deterministic happens at that step. Either memory usage grows over time, or, with a fixed data order, that step always receives an unusually large batch (for example the longest sequences in the dataset). Classic leaks include a gradient or autograd graph that is never freed because something still references it (such as adding loss rather than loss.item() to a running total), an optimizer state that accumulates allocations, or a data loader that builds up tensors. Debug tools: (1) torch.cuda.memory_snapshot() before and after the OOM step to see which allocations grew; (2) check that optimizer.zero_grad(set_to_none=True) is being called - with set_to_none=False (the default before PyTorch 2.0), gradients are zeroed but their tensors stay allocated; (3) check for tensors accumulating in lists or dicts in training loops (logging buffers, history lists) that grow unboundedly; (4) log the input shapes at the failing step to rule out an oversized batch.

Q4: Mixed precision training uses FP16 for computation but FP32 for optimizer states. Why not FP16 for optimizer states too?

FP16 has limited dynamic range (5 exponent bits, max value ~65504). Optimizer states (Adam's first and second moments) are running averages that can have very small values, particularly the second moment $v_t$ early in training before it warms up, since it accumulates squared gradients. If $v_t$ underflows to zero in FP16, the update step $\theta_t = \theta_{t-1} - \eta \cdot m_t / (\sqrt{v_t} + \epsilon)$ explodes; the usual $\epsilon = 10^{-8}$ does not help, because it is itself below FP16's smallest subnormal. BF16 has the same exponent range as FP32 (8 exponent bits) and largely avoids underflow at the cost of lower mantissa precision - which is why BF16 is preferred for training. Even with BF16, master weights are typically kept in FP32 to avoid accumulated precision loss over millions of gradient steps.
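Python's `struct` module can round-trip values through IEEE half precision (format `'e'`), which makes the underflow visible without a GPU. In this sketch `to_fp16` uses that real format code, while `to_bf16` is an emulation that keeps the top 16 bits of the float32 encoding with round-to-nearest-even and skips NaN/Inf handling; both helper names are hypothetical. A gradient of 1e-4 survives in FP16, but its square, the quantity Adam's second moment accumulates, does not:

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16 rounding: keep the top 16 bits of the float32 encoding.

    Round-to-nearest-even on the dropped bits; NaN/Inf handling is omitted
    (a simplification for this sketch).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


grad = 1e-4
grad_sq = grad * grad  # the kind of value Adam's second moment accumulates
print("grad    in fp16:", to_fp16(grad))
print("grad**2 in fp16:", to_fp16(grad_sq))
print("grad**2 in bf16:", to_bf16(grad_sq))
```

The FP16 result for the squared gradient is exactly zero, while the BF16 emulation keeps it within a fraction of a percent, at the cost of only about three significant decimal digits.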

Q5: What is the memory cost difference between a 7B parameter model in inference vs training on a single GPU?

Inference (FP16): 7B × 2 bytes = 14 GB for weights, plus about 0.5 GB of KV cache at batch=1, seq=1024 for a Llama-style 7B (32 layers, hidden size 4096) and some runtime buffers: roughly 15–16 GB. That fits on an A100 40GB with room for larger batches. Training with Adam (FP16 compute, FP32 optimizer): weights 14 GB + gradients 14 GB + Adam states (master weights + m + v): 7B × 4 × 3 = 84 GB + activations (~30–50 GB for batch=4, seq=1024). Total: ~140–160 GB. This does not fit on a single A100 80GB; the training state has to be sharded across several A100s with ZeRO-2/3 (or partly offloaded to CPU), with the exact GPU count depending on how much activation memory each GPU carries. This is why inference serving is much cheaper than training at equivalent model size: roughly 15 GB versus about 150 GB.
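The comparison reduces to two formulas: the KV cache stores one key and one value vector per layer per token, and training carries 16 bytes of model state per parameter plus activations. A sketch, assuming standard multi-head attention (no grouped-query sharing), a Llama-style 7B shape of 32 layers and hidden size 4096, and the activation range quoted above; the helper names are hypothetical:

```python
def kv_cache_bytes(n_layers: int, hidden_dim: int, seq_len: int,
                   batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Keys + values for every layer and token: 2 * layers * hidden * seq * batch * bytes."""
    return 2 * n_layers * hidden_dim * seq_len * batch * bytes_per_elem


def training_gb(n_params: float, activation_gb: float) -> float:
    """Mixed-precision Adam: 2 + 2 + 12 = 16 bytes/param of model state, plus activations."""
    return n_params * 16 / 1e9 + activation_gb


N_PARAMS = 7e9
# Assumed Llama-style 7B shape: 32 layers, hidden size 4096
kv_gb = kv_cache_bytes(n_layers=32, hidden_dim=4096, seq_len=1024) / 1e9
inference_gb = N_PARAMS * 2 / 1e9 + kv_gb
print(f"inference: {inference_gb:.1f} GB (KV cache {kv_gb:.2f} GB)")
for act_gb in (30, 50):
    print(f"training with {act_gb} GB of activations: {training_gb(N_PARAMS, act_gb):.0f} GB")
```

The KV cache grows linearly with batch size and sequence length, so serving many long requests at once can still push inference memory well past the weights alone.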

© 2026 EngineersOfAI. All rights reserved.