
Large-Scale Memory Optimization

The Production Scenario

A team at a mid-size AI company receives approval to fine-tune Llama 3 70B on proprietary data. They have a budget for eight A100 80GB GPUs. They are told that training starts in two weeks. The naive calculation: 70 billion parameters at BF16 (2 bytes each) is 140 GB just for the weights. Eight A100s have 640 GB total VRAM. Plenty of room, they think.

They are wrong. Loading the model weights takes 140 GB. Each parameter has a corresponding gradient (another 140 GB). Mixed precision AdamW also keeps an FP32 master copy of the weights (280 GB) plus two FP32 moment vectors (560 GB). That is 1,120 GB for weights, gradients, and optimizer state alone, with zero room for activations or the training batch. With a sequence length of 2048 and batch size of 4, the attention activations for a single layer require another 2-4 GB. For 80 layers, that is 160-320 GB of activations.

Total naive requirement: roughly 1.3 TB. They have 640 GB. The training job will OOM before the first step completes.
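The whole failure is predictable with a few lines of arithmetic, using the 16-bytes-per-parameter rule derived later in this lesson (a back-of-the-envelope sketch; the activation figure uses the scenario's 2-4 GB per layer estimate):

# Back-of-the-envelope budget for naive Llama 3 70B fine-tuning
P = 70e9                            # parameters
weights_gb   = 2 * P / 1e9          # BF16 weights:                 140 GB
grads_gb     = 2 * P / 1e9          # BF16 gradients:               140 GB
optimizer_gb = 12 * P / 1e9         # FP32 master + 2 moments:      840 GB
acts_gb      = 80 * 3               # ~2-4 GB per layer, 80 layers: ~240 GB
total_gb = weights_gb + grads_gb + optimizer_gb + acts_gb
print(f"Naive total: {total_gb:.0f} GB vs. 8 x 80 GB = 640 GB available")
# Naive total: 1360 GB -- more than double the available VRAM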

In the next two weeks, the team has to learn gradient checkpointing, ZeRO-3 optimizer state sharding, mixed precision training, and gradient accumulation just to fit the job into their hardware budget. Engineers who understand these techniques going in can design a training run that fits on day one. Engineers who discover these techniques reactively when the job crashes spend a week debugging instead of training.

This lesson is the memory budget sheet for large model training and serving. Every number here follows from a formula you can verify. Understanding these numbers is what separates ML engineers who can reason about infrastructure from those who cargo-cult configuration flags.


Why This Exists

Deep learning hit a wall in the early 2010s: models large enough to be interesting were too large to fit in GPU memory, and GPUs were the only way to train them fast enough to be practical. The field developed a collection of techniques to trade compute for memory (gradient checkpointing), distribute memory across devices (ZeRO), and reduce the precision of stored numbers (mixed precision) - each technique independently valuable, all compatible with each other.

The LLM era made these techniques critical at a different scale. GPT-3 (175B parameters) and its successors could not be trained on any plausible number of GPUs without all three categories of optimization simultaneously. The ZeRO paper (Rajbhandari et al., 2019) directly enabled training models at this scale by solving the optimizer state memory problem. Flash Attention (Dao et al., 2022) solved the attention activation memory problem that threatened to dominate at long sequence lengths.

Inference has its own memory crisis. A 70B model at BF16 requires 140 GB just for weights. Serving it at any meaningful batch size requires additional memory for KV cache. The KV cache grows linearly with sequence length and batch size, which means serving long-context requests requires carefully managing a finite memory budget across concurrent requests. PagedAttention, the technique at the core of vLLM, applies virtual memory principles to KV cache management to maximize serving throughput.


Historical Context

Gradient checkpointing (also called rematerialization) was formalized by Chen et al. (2016) in "Training Deep Networks with Sublinear Memory Cost." The core insight was that you can reduce memory from $O(N)$ (storing all activations for $N$ layers) to $O(\sqrt{N})$ by recomputing activations during the backward pass instead of storing them.

Mixed precision training was systematized by Micikevicius et al. (2017, NVIDIA) in "Mixed Precision Training." The paper showed that FP16 (half precision) could be used for most operations in the forward and backward pass without accuracy loss, provided that a master copy of weights was kept in FP32 for optimizer updates. This halved the activation and gradient memory requirement.

BF16 (Brain Float 16), originally developed for Google's TPU training, became available in NVIDIA's A100 (2020). BF16 has the same range as FP32 (same number of exponent bits) but lower precision than FP16. For training, BF16 is generally preferred over FP16 because it avoids the overflow issues that plague FP16 in large model training.

ZeRO (Zero Redundancy Optimizer) was introduced by Rajbhandari et al. (Microsoft Research, 2019). The key observation was that data parallel training replicates the full optimizer state on every GPU - a massive redundancy. ZeRO partitions this state across GPUs - up to 8x per-GPU savings with ZeRO-2, and savings that scale linearly with GPU count under ZeRO-3 - while maintaining the same training throughput.

Flash Attention was introduced by Dao et al. (Stanford, 2022). It reformulates the attention computation to use a tiled approach that keeps intermediate results in on-chip SRAM rather than writing them to HBM (the slow GPU main memory). This reduces attention's memory requirement from $O(n^2)$ to $O(n)$ and also speeds up attention by 2-4x by eliminating repeated HBM reads.

PagedAttention was introduced by Kwon et al. (UC Berkeley, 2023) in "Efficient Memory Management for Large Language Model Serving with PagedAttention." It applies the OS virtual memory principle - dividing memory into fixed-size pages, allowing non-contiguous physical allocation for logically contiguous data - to the KV cache in LLM serving.


Core Concepts

The Training Memory Breakdown

Every training step uses memory in six distinct categories. Knowing the exact formula for each allows you to construct a memory budget before writing a single line of training code.

Let:

  • $P$ = number of parameters
  • $B$ = batch size (number of sequences)
  • $T$ = sequence length in tokens
  • $L$ = number of transformer layers
  • $H$ = model hidden dimension
  • $A$ = number of attention heads
  • $d_h = H/A$ = head dimension
  • bytes per parameter: 2 for BF16/FP16, 4 for FP32

$$\text{Weights (BF16)} = 2P \text{ bytes}$$
$$\text{Gradients (BF16)} = 2P \text{ bytes}$$
$$\text{Optimizer states (AdamW, FP32)} = 8P \text{ bytes} \quad (2 \text{ moments} \times 4 \text{ bytes each})$$
$$\text{Optimizer master weights (FP32)} = 4P \text{ bytes}$$

Total for optimizer + weights + gradients (mixed precision AdamW):
$$M_{param} = 2P + 2P + 8P + 4P = 16P \text{ bytes}$$

For a 7B parameter model: $16 \times 7 \times 10^9 = 112$ GB. That does not fit on a single A100 80GB.

Activation memory per transformer layer:
$$M_{act/layer} \approx B \cdot T \cdot H \cdot (34 + 5 A T / H) \text{ bytes}$$

For most models, a workable rule of thumb is roughly 24 stored values per token position per layer:
$$M_{act/layer} \approx 24 \cdot B \cdot T \cdot H \text{ values} \approx 48 \cdot B \cdot T \cdot H \text{ bytes at BF16}$$

For Llama 2 7B with $B=4$, $T=2048$, $H=4096$, $L=32$:
$$M_{act} = 32 \times 24 \times 4 \times 2048 \times 4096 \times 2 \text{ bytes} \approx 51 \text{ GB}$$

Total naive memory for 7B fine-tuning: 112 GB + 51 GB = 163 GB. This requires at minimum three A100 80GB GPUs even before we start optimizing.

import math

def training_memory_budget(
    n_params: int,
    n_layers: int,
    hidden_dim: int,
    n_heads: int,
    seq_len: int,
    batch_size: int,
    dtype: str = "bf16",
    optimizer: str = "adamw",
    gradient_checkpointing: bool = False,
    zero_stage: int = 0,
    n_gpus: int = 1,
) -> dict:
    """
    Compute per-GPU memory budget for a transformer training run.
    Returns breakdown in GB.
    """
    bytes_per_param = 2 if dtype in ("bf16", "fp16") else 4

    # Model weights
    weights_gb = (n_params * bytes_per_param) / 1e9

    # Gradients: same dtype as weights
    grads_gb = weights_gb

    # Optimizer states
    if optimizer == "adamw":
        # Mixed precision: FP32 master weights + FP32 m + FP32 v
        optimizer_gb = (n_params * 4 * 3) / 1e9  # 12 bytes per param
    elif optimizer == "adamw_8bit":
        # bitsandbytes 8-bit Adam: 1 byte per state value (m and v)
        optimizer_gb = (n_params * 2) / 1e9
    elif optimizer == "sgd":
        optimizer_gb = 0  # SGD with no momentum
    else:
        optimizer_gb = (n_params * 4 * 2) / 1e9

    # ZeRO sharding: stage 1 shards optimizer states, stage 2 adds
    # gradients, stage 3 adds the weights themselves
    if zero_stage >= 1:
        optimizer_gb /= n_gpus
    if zero_stage >= 2:
        grads_gb /= n_gpus
    if zero_stage >= 3:
        weights_gb /= n_gpus

    # Activation memory
    if gradient_checkpointing:
        # With GC: only store activations at checkpointed layers
        # Rule of thumb: sqrt(n_layers) checkpoints needed
        effective_layers = max(1, int(math.sqrt(n_layers)))
    else:
        effective_layers = n_layers

    # Per-layer activation: ~24 * B * T * H values at 2 bytes each (BF16)
    act_per_layer_gb = (24 * batch_size * seq_len * hidden_dim * 2) / 1e9
    activations_gb = act_per_layer_gb * effective_layers

    # Note: a separate inference-style KV cache does not apply during
    # training; attention intermediates are covered by the estimate above

    total_gb = weights_gb + grads_gb + optimizer_gb + activations_gb

    return {
        "weights_gb": round(weights_gb, 2),
        "gradients_gb": round(grads_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "activations_gb": round(activations_gb, 2),
        "total_per_gpu_gb": round(total_gb, 2),
    }

# Llama 2 7B fine-tuning scenarios
print("=== Llama 2 7B Memory Budget ===\n")

# Naive: no optimizations, single GPU
naive = training_memory_budget(
    n_params=7_000_000_000, n_layers=32, hidden_dim=4096, n_heads=32,
    seq_len=2048, batch_size=4, dtype="bf16",
    optimizer="adamw", gradient_checkpointing=False,
    zero_stage=0, n_gpus=1,
)
print("Naive (1 GPU, no optimizations):")
for k, v in naive.items():
    print(f"  {k}: {v} GB")

print()

# Optimized: GC + ZeRO-3 on 8 GPUs
optimized = training_memory_budget(
    n_params=7_000_000_000, n_layers=32, hidden_dim=4096, n_heads=32,
    seq_len=2048, batch_size=4, dtype="bf16",
    optimizer="adamw", gradient_checkpointing=True,
    zero_stage=3, n_gpus=8,
)
print("Optimized (8 GPUs, GC + ZeRO-3):")
for k, v in optimized.items():
    print(f"  {k}: {v} GB")

# Llama 2 7B fine-tuning scenarios
print("=== Llama 2 7B Memory Budget ===\n")

# Naive: no optimizations, single GPU
naive = training_memory_budget(
n_params=7e9, n_layers=32, hidden_dim=4096, n_heads=32,
seq_len=2048, batch_size=4, dtype="bf16",
optimizer="adamw", gradient_checkpointing=False,
zero_stage=0, n_gpus=1
)
print("Naive (1 GPU, no optimizations):")
for k, v in naive.items():
print(f" {k}: {v} GB")

print()

# Optimized: GC + ZeRO-3 on 8 GPUs
optimized = training_memory_budget(
n_params=7e9, n_layers=32, hidden_dim=4096, n_heads=32,
seq_len=2048, batch_size=4, dtype="bf16",
optimizer="adamw", gradient_checkpointing=True,
zero_stage=3, n_gpus=8
)
print("Optimized (8 GPUs, GC + ZeRO-3):")
for k, v in optimized.items():
print(f" {k}: {v} GB")

Mixed Precision Training (FP16 and BF16)

Mixed precision training stores weights, gradients, and activations in half precision (2 bytes) but keeps the optimizer's master weights and moments in FP32 (4 bytes). The key insight: most of the computation can happen in lower precision without loss of accuracy, but the optimizer update step is numerically sensitive and benefits from higher precision.

FP16 (IEEE 754 half precision):

  • 1 sign bit, 5 exponent bits, 10 mantissa bits
  • Range: approximately $\pm 65504$
  • Problem: gradient values during large model training often exceed this range, causing overflow to Inf or underflow to 0 (gradient vanishing)
  • Requires loss scaling to prevent underflow

BF16 (bfloat16):

  • 1 sign bit, 8 exponent bits, 7 mantissa bits
  • Range: same as FP32 ($\approx \pm 3.4 \times 10^{38}$)
  • Lower precision than FP16 but larger range - avoids overflow without loss scaling
  • Available in A100, H100, and TPUs - generally preferred for training
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Setup: a simple transformer block
model = nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# FP16 training with loss scaling (required for FP16 to prevent underflow)
scaler = GradScaler()  # manages loss scaling automatically

def train_step_fp16(inputs, targets):
    optimizer.zero_grad()

    with autocast(dtype=torch.float16):
        # Forward pass runs in FP16
        output = model(inputs)
        loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))

    # Scale loss to prevent gradient underflow, then backward
    scaler.scale(loss).backward()

    # Unscale gradients, then clip (clipping must happen after unscaling)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Step: scaler checks if gradients have Inf/NaN; skips step if so
    scaler.step(optimizer)
    scaler.update()

    return loss.item()


# BF16 training (simpler: no loss scaling needed)
def train_step_bf16(inputs, targets):
    optimizer.zero_grad()

    with autocast(dtype=torch.bfloat16):
        output = model(inputs)
        loss = criterion(output.view(-1, output.size(-1)), targets.view(-1))

    # Direct backward: BF16 does not need loss scaling
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return loss.item()


# Memory comparison
def measure_model_memory(model):
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1e6

model_fp32 = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_bf16 = nn.TransformerEncoderLayer(d_model=512, nhead=8).to(torch.bfloat16)

print(f"FP32 model: {measure_model_memory(model_fp32):.1f} MB")
print(f"BF16 model: {measure_model_memory(model_bf16):.1f} MB")
print(f"Memory ratio: {measure_model_memory(model_fp32) / measure_model_memory(model_bf16):.2f}x")

Gradient Checkpointing

Gradient checkpointing (also called "activation checkpointing" or "rematerialization") trades compute for memory. Normally, PyTorch stores all intermediate activations during the forward pass so they are available for gradient computation in the backward pass. With gradient checkpointing, you discard activations during the forward pass and recompute them during the backward pass.

The compute cost: one additional forward pass per checkpointed segment. The memory savings: instead of $O(L)$ activation memory, you need only $O(\sqrt{L})$ (if you checkpoint every $\sqrt{L}$ layers), or a single layer's worth of live activations at a time (if you checkpoint every layer).

For a 32-layer model with per-layer activation of 1.6 GB:

  • No checkpointing: 51 GB of activations
  • Checkpoint every layer: ~1.6 GB (one layer at a time), but 2x compute
  • Checkpoint every 4 layers: ~12.8 GB activations, ~1.3x compute overhead
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential

# Method 1: checkpoint individual modules
class CheckpointedTransformerLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                                batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # checkpoint() runs layer in a no_grad context during forward,
        # saves only the input, and recomputes during backward
        return checkpoint(self.layer, x, use_reentrant=False)


# Method 2: checkpoint a sequence of modules
class CheckpointedTransformer(nn.Module):
    def __init__(self, d_model: int, nhead: int, n_layers: int,
                 segments: int = 4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
                                       batch_first=True)
            for _ in range(n_layers)
        ])
        self.segments = segments  # number of checkpoint segments

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # checkpoint_sequential divides the layers into `segments` segments.
        # Only the segment boundaries are stored; internals are recomputed.
        return checkpoint_sequential(list(self.layers), self.segments, x,
                                     use_reentrant=False)


# Method 3: Hugging Face transformers native gradient checkpointing
# Most HF models support this with a single flag
from transformers import AutoModelForCausalLM

def load_with_gradient_checkpointing(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # Enable gradient checkpointing after loading
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    # use_reentrant=False is recommended for PyTorch >= 2.0:
    # it uses the more memory-efficient non-reentrant implementation

    return model


# Demonstration: measure activation memory with and without checkpointing
def measure_activation_memory(use_checkpoint: bool, n_layers: int = 16,
                              batch: int = 4, seq: int = 512, d_model: int = 512):
    """Measure peak GPU memory for a forward+backward pass."""
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    if use_checkpoint:
        model = CheckpointedTransformer(d_model, nhead=8, n_layers=n_layers)
    else:
        model = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
            for _ in range(n_layers)
        ])

    model = model.cuda().to(torch.bfloat16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    x = torch.randn(batch, seq, d_model, device='cuda', dtype=torch.bfloat16)

    optimizer.zero_grad()
    out = model(x)
    loss = out.mean()
    loss.backward()
    optimizer.step()

    peak_mb = torch.cuda.max_memory_allocated() / 1e6
    print(f"{'With' if use_checkpoint else 'Without'} checkpointing: "
          f"{peak_mb:.0f} MB peak GPU memory")

measure_activation_memory(use_checkpoint=False)
measure_activation_memory(use_checkpoint=True)

ZeRO Optimizer Stages

ZeRO (Zero Redundancy Optimizer) partitions training state across data parallel processes. In standard data parallelism, every GPU holds a complete copy of weights, gradients, and optimizer states. ZeRO progressively eliminates this redundancy.

ZeRO-1 partitions optimizer states. Each GPU holds $1/N$ of the optimizer states. Memory savings for optimizer states: $N\times$.

ZeRO-2 partitions gradients as well. Each GPU holds $1/N$ of gradients and $1/N$ of optimizer states. Memory savings for gradients + optimizer states: $N\times$.

ZeRO-3 partitions weights as well. Each GPU holds $1/N$ of weights, gradients, and optimizer states. Memory savings for all three: $N\times$.

Memory reduction formulas (for $N$ GPUs, FP16/BF16 training with AdamW):

$$\text{Baseline per GPU} = 16P \text{ bytes}$$
$$\text{ZeRO-1 per GPU} = 4P + 12P/N \text{ bytes}$$
$$\text{ZeRO-2 per GPU} = 2P + 14P/N \text{ bytes}$$
$$\text{ZeRO-3 per GPU} = 16P/N \text{ bytes}$$

For a 7B model on 8 GPUs:

  • Baseline: $16 \times 7 = 112$ GB per GPU
  • ZeRO-1: $4 \times 7 + 12 \times 7 / 8 = 28 + 10.5 = 38.5$ GB per GPU
  • ZeRO-2: $2 \times 7 + 14 \times 7 / 8 = 14 + 12.25 = 26.25$ GB per GPU
  • ZeRO-3: $16 \times 7 / 8 = 14$ GB per GPU
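These formulas collapse into a small helper; a sketch (the bytes-per-parameter counts follow the mixed precision AdamW breakdown above):

def zero_per_gpu_gb(n_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU GB for weights + gradients + AdamW states under ZeRO.
    Baseline: 2 (BF16 weights) + 2 (BF16 grads) + 12 (FP32 optimizer) bytes/param."""
    if stage == 0:
        bytes_per_param = 16.0
    elif stage == 1:                      # optimizer states sharded
        bytes_per_param = 4 + 12 / n_gpus
    elif stage == 2:                      # gradients sharded too
        bytes_per_param = 2 + 14 / n_gpus
    else:                                 # stage 3: weights sharded as well
        bytes_per_param = 16 / n_gpus
    return n_params * bytes_per_param / 1e9

for stage in (0, 1, 2, 3):
    print(f"ZeRO-{stage}: {zero_per_gpu_gb(7e9, 8, stage):6.1f} GB per GPU")
# ZeRO-0: 112.0, ZeRO-1: 38.5, ZeRO-2: 26.2, ZeRO-3: 14.0 GB per GPU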
# DeepSpeed ZeRO-3 configuration
# Save as deepspeed_config.json

import json

zero3_config = {
    "train_batch_size": "auto",  # set by trainer
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": 1.0,

    "bf16": {
        "enabled": True  # use BF16; for FP16, use a "fp16" block instead
    },

    "zero_optimization": {
        "stage": 3,

        # Communication tuning
        "overlap_comm": True,  # overlap gradient communication with backward pass
        "contiguous_gradients": True,  # contiguous gradient buffer (reduces fragmentation)
        "reduce_scatter": True,  # use reduce-scatter instead of all-reduce
        "reduce_bucket_size": 5e8,  # 500 MB communication bucket

        # CPU offloading (ZeRO-Offload / ZeRO-Infinity)
        # Offloading optimizer states moves the FP32 master weights and
        # moments (~12 bytes/param before sharding) off the GPU,
        # at the cost of slower optimizer updates
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True  # pin CPU memory for fast H2D transfer
        },

        # Uncomment to also offload parameters to CPU (slower but fits more)
        # "offload_param": {
        #     "device": "cpu",
        #     "pin_memory": True
        # },

        # Sub-group size for ZeRO-3 parameter gathering
        "sub_group_size": 1e9,

        # Stage 3 parameter gathering
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
}

print(json.dumps(zero3_config, indent=2))

# Using DeepSpeed with Hugging Face Trainer
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch = 4 * 8 * n_gpus
    learning_rate=2e-5,
    num_train_epochs=3,
    bf16=True,
    deepspeed="deepspeed_config.json",
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="epoch",
)
# Trainer automatically integrates DeepSpeed when deepspeed= is set

Activation Offloading to CPU

When GPU memory is exhausted even with gradient checkpointing, activations can be offloaded to CPU memory (pinned). This adds PCIe transfer overhead but allows training models that would otherwise not fit at all.

import torch
from torch.autograd.graph import save_on_cpu

class CPUOffloadCheckpoint(torch.nn.Module):
    """
    Wrapper that stores tensors saved for the backward pass in pinned CPU
    memory instead of GPU memory. Useful when GPU memory is critically tight.
    """

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *inputs):
        # save_on_cpu moves every tensor saved for backward to (pinned) CPU
        # memory during the forward pass; they are copied back to the GPU
        # on demand during the backward pass
        with save_on_cpu(pin_memory=True):
            return self.module(*inputs)

# PyTorch native CPU offloading via FSDP (Fully Sharded Data Parallel)
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def wrap_model_with_fsdp(model: torch.nn.Module, offload_to_cpu: bool = False):
    """Wrap a model with FSDP for ZeRO-3-equivalent sharding."""

    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    cpu_offload = CPUOffload(offload_params=offload_to_cpu)

    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=mixed_precision,
        cpu_offload=cpu_offload,
        device_id=torch.cuda.current_device(),
    )
    return model

KV Cache Memory in LLM Inference

During LLM inference (autoregressive decoding), the model generates one token at a time. Each token requires attention over all previous tokens. To avoid recomputing attention keys and values for previous tokens, they are cached.

The KV cache memory formula:

$$M_{KV} = 2 \times L \times H \times d_h \times T \times B \times \text{bytes}$$

Where:

  • Factor 2: keys and values (not queries - queries are not cached)
  • $L$ = number of layers
  • $H$ = number of attention heads (or KV heads for grouped-query attention)
  • $d_h$ = head dimension (= hidden_dim / n_heads)
  • $T$ = sequence length (including generated tokens so far)
  • $B$ = batch size (number of concurrent sequences)
  • bytes = 2 for BF16/FP16

For Llama 2 7B ($L=32$, $H=32$ heads, $d_h=128$, BF16):
$$M_{KV} = 2 \times 32 \times 32 \times 128 \times T \times B \times 2 \text{ bytes} = 524{,}288 \times T \times B \text{ bytes}$$

At $T=4096$, $B=32$ (moderate serving):
$$M_{KV} = 524{,}288 \times 4096 \times 32 \approx 68.7 \text{ GB}$$

That is almost the entire A100 80GB VRAM, for KV cache alone, with no room for model weights.

def kv_cache_size_gb(
    n_layers: int,
    n_kv_heads: int,  # = n_heads for MHA; smaller for GQA
    head_dim: int,
    seq_len: int,
    batch_size: int,
    dtype_bytes: int = 2,  # 2 for BF16/FP16, 4 for FP32
) -> float:
    """
    Compute KV cache memory in GB.
    n_kv_heads: for Grouped Query Attention (GQA), this is smaller than n_heads.
    """
    bytes_total = (
        2  # K and V
        * n_layers
        * n_kv_heads
        * head_dim
        * seq_len
        * batch_size
        * dtype_bytes
    )
    return bytes_total / 1e9

# Llama 2 7B: multi-head attention (MHA), no GQA
llama2_7b = kv_cache_size_gb(
    n_layers=32, n_kv_heads=32, head_dim=128,
    seq_len=4096, batch_size=32, dtype_bytes=2,
)
print(f"Llama 2 7B KV cache (4096 tokens, batch 32): {llama2_7b:.1f} GB")

# Llama 3 8B: uses GQA with 8 KV heads (vs 32 query heads)
llama3_8b = kv_cache_size_gb(
    n_layers=32, n_kv_heads=8, head_dim=128,
    seq_len=4096, batch_size=32, dtype_bytes=2,
)
print(f"Llama 3 8B KV cache (GQA, 4096 tokens, batch 32): {llama3_8b:.1f} GB")

# KV cache growth during generation
print("\nKV cache growth during generation (Llama 3 8B, batch=32):")
for seq_len in [256, 512, 1024, 2048, 4096, 8192]:
    size = kv_cache_size_gb(32, 8, 128, seq_len, 32)
    bar = "#" * int(size * 5)
    print(f"  {seq_len:5d} tokens: {size:5.2f} GB {bar}")

# PagedAttention / vLLM approach
# Instead of allocating one contiguous block per sequence,
# allocate fixed-size "pages" of KV cache (e.g. 16 tokens per page).
# Pages from different sequences can interleave in physical memory.
# This dramatically reduces memory fragmentation.

PAGE_SIZE = 16  # tokens per KV cache page

def pagedattention_capacity(
    total_kv_memory_gb: float,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int = 2,
    page_size_tokens: int = PAGE_SIZE,
) -> dict:
    """
    Compute serving capacity under PagedAttention.
    Returns: max tokens storable, max concurrent sequences at various lengths.
    """
    bytes_per_page = (
        2 * n_layers * n_kv_heads * head_dim * page_size_tokens * dtype_bytes
    )
    total_bytes = total_kv_memory_gb * 1e9
    total_pages = int(total_bytes / bytes_per_page)
    total_tokens = total_pages * page_size_tokens

    result = {"total_tokens": total_tokens, "total_pages": total_pages}
    for seq_len in [256, 512, 1024, 2048, 4096]:
        pages_per_seq = (seq_len + page_size_tokens - 1) // page_size_tokens
        max_concurrent = total_pages // pages_per_seq
        result[f"max_concurrent_{seq_len}_tokens"] = max_concurrent

    return result

# A100 80GB: reserve 20 GB for model weights, 60 GB for KV cache
capacity = pagedattention_capacity(60.0, 32, 8, 128)
print("\nPagedAttention capacity (60 GB KV budget, Llama 3 8B):")
for k, v in capacity.items():
    print(f"  {k}: {v:,}")

Gradient Accumulation

Gradient accumulation allows you to train with an effective batch size larger than what fits in GPU memory. You run multiple forward-backward passes on small micro-batches, accumulating gradients without taking an optimizer step. After $K$ accumulation steps, you take one optimizer step on the accumulated gradients - equivalent to one step on a batch $K$ times larger.

Memory cost: the same as a single micro-batch (gradients simply accumulate in place). Compute cost: identical to training with the large batch. Communication cost (in distributed training): gradient synchronization can be deferred to optimizer-step boundaries (for example with DDP's no_sync), so the same gradient volume moves in fewer, larger synchronizations.

$$\text{Effective batch size} = \text{micro-batch} \times \text{accumulation steps} \times n_{\text{GPUs}}$$

import torch
import torch.nn as nn

def train_with_gradient_accumulation(
    model: nn.Module,
    dataloader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    accumulation_steps: int = 8,
    device: str = "cuda",
):
    """
    Training loop with gradient accumulation.
    Effective batch = micro_batch_size * accumulation_steps.
    """
    model.train()
    total_loss = 0.0

    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # Forward pass
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            outputs = model(inputs)
            # Divide loss by accumulation steps so gradients scale correctly
            # (equivalent to averaging over the full effective batch)
            loss = criterion(outputs, targets) / accumulation_steps

        # Backward pass (gradients accumulate in .grad buffers)
        loss.backward()
        total_loss += loss.item() * accumulation_steps

        # Optimizer step only every accumulation_steps
        if (step + 1) % accumulation_steps == 0:
            # Gradient clipping: must happen before step
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

    # Handle leftover micro-batches (dataset size not divisible by accumulation_steps)
    if len(dataloader) % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    return total_loss / len(dataloader)

Quantization: INT8 and INT4

Quantization represents model weights (and optionally activations) in lower-bit integer formats. This reduces memory and often increases inference throughput on hardware with efficient integer kernels.

Memory savings:

  • FP32 (4 bytes) to INT8 (1 byte): 4x reduction
  • FP32 (4 bytes) to INT4 (0.5 bytes): 8x reduction
  • BF16 (2 bytes) to INT8 (1 byte): 2x reduction

A 7B parameter model:

  • BF16: 14 GB
  • INT8: 7 GB
  • INT4 (4-bit GPTQ or AWQ): 3.5 GB
# bitsandbytes quantization - post-training quantization
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# 8-bit quantization: minimal quality loss, 2x memory reduction vs BF16
bnb_8bit_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # Linear layers are quantized; LayerNorm, embed, lm_head stay in FP16
)

# 4-bit quantization: ~4x memory reduction vs BF16, moderate quality loss
# NF4 (Normal Float 4) preserves more information than regular INT4
bnb_4bit_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # nf4 or fp4
    bnb_4bit_compute_dtype=torch.bfloat16,  # computation in BF16
    bnb_4bit_use_double_quant=True,  # quantize the quantization constants too
)

# Load a 7B model with 4-bit quantization
model_4bit = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_4bit_config,
    device_map="auto",
)
print(f"4-bit model memory: {model_4bit.get_memory_footprint() / 1e9:.2f} GB")

# QLoRA: fine-tune a quantized model with LoRA adapters
# The quantized base model is frozen; only LoRA matrices are trained
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Prepare quantized model for training (handles some quantization quirks)
model_4bit = prepare_model_for_kbit_training(
    model_4bit,
    use_gradient_checkpointing=True,
)

lora_config = LoraConfig(
    r=16,  # LoRA rank
    lora_alpha=32,  # scaling factor
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model_qlora = get_peft_model(model_4bit, lora_config)
model_qlora.print_trainable_parameters()
# trainable params: ~17M (LoRA, r=16 on 4 projections x 32 layers)
# vs 7B total = ~0.24% of parameters
# Memory: ~3.5 GB quantized base + tens of MB of LoRA weights -
# the fine-tune fits on a single consumer GPU

Memory Optimization Decision Tree

flowchart TD
Q1["Does the model fit in<br/>GPU memory at all?"]:::blue
Q2["Reduce precision:<br/>FP32 -> BF16<br/>(2x savings)"]:::green
Q3["Do activations fit<br/>with BF16?"]:::blue
Q4["Enable gradient<br/>checkpointing<br/>(~10x activation savings)"]:::orange
Q5["Do optimizer states fit?"]:::blue
Q6["Apply ZeRO-1/2<br/>or use 8-bit Adam<br/>(8x optimizer savings)"]:::orange
Q7["Does the full six-component budget fit?"]:::blue
Q8["ZeRO-3 sharding<br/>across N GPUs<br/>(Nx all components)"]:::purple
Q9["CPU offloading<br/>(ZeRO-Infinity)"]:::teal
Q10["Ready to train"]:::green

Q1 -->|"No"| Q2
Q1 -->|"Yes"| Q10
Q2 --> Q3
Q3 -->|"No"| Q4
Q3 -->|"Yes"| Q5
Q4 --> Q5
Q5 -->|"No"| Q6
Q5 -->|"Yes"| Q7
Q6 --> Q7
Q7 -->|"No, need more GPUs"| Q8
Q7 -->|"No, no more GPUs"| Q9
Q7 -->|"Yes"| Q10
Q8 --> Q10
Q9 --> Q10

classDef blue fill:#dbeafe,color:#1e293b,stroke:#2563eb
classDef teal fill:#ccfbf1,color:#134e4a,stroke:#14b8a6
classDef orange fill:#ffedd5,color:#7c2d12,stroke:#ea580c
classDef green fill:#dcfce7,color:#14532d,stroke:#16a34a
classDef purple fill:#ede9fe,color:#4c1d95,stroke:#7c3aed
classDef red fill:#fee2e2,color:#7f1d1d,stroke:#dc2626

Production Engineering Notes

Memory Fragmentation in Long Training Runs

PyTorch's CUDA memory allocator uses a caching allocator: freed memory is not returned to CUDA immediately but held in a free list for reuse. This is efficient for steady-state training but causes fragmentation over time. Tensors of different sizes accumulate in the free list, and eventually there is no single contiguous free block large enough for a requested allocation, even if total free memory is ample.

Symptoms: CUDA out of memory errors that appear after N training steps but not at step 1, even though step 1 used more peak memory.

import torch

def log_fragmentation(step: int, every: int = 1000):
    """Call from inside the training loop to monitor and mitigate fragmentation."""
    if step % every != 0:
        return

    # Periodically clear the allocator cache in long training runs
    # (costs a millisecond or two per call, but prevents fragmentation OOM)
    torch.cuda.empty_cache()

    # Check fragmentation: reserved-but-inactive memory is the symptom
    mem_info = torch.cuda.memory_stats()
    active = mem_info['active_bytes.all.current']
    reserved = mem_info['reserved_bytes.all.current']
    fragmentation = 1.0 - (active / reserved) if reserved > 0 else 0.0
    print(f"Active: {active/1e9:.2f} GB, Reserved: {reserved/1e9:.2f} GB, "
          f"Fragmentation: {fragmentation*100:.1f}%")

Estimation Before Commitment

Always run a memory estimate before starting a multi-day training job. The torchinfo library can estimate model memory, and you can run a single batch forward-backward pass with torch.cuda.max_memory_allocated() to measure actual peak usage.

import torch

def estimate_training_memory(model, sample_batch, device='cuda'):
    """Run a single training step and report peak memory."""
    model = model.to(device)
    inputs, targets = sample_batch
    inputs, targets = inputs.to(device), targets.to(device)

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        out = model(inputs)
        loss = out.mean()

    loss.backward()
    optimizer.step()

    torch.cuda.synchronize(device)
    peak_gb = torch.cuda.max_memory_allocated(device) / 1e9
    print(f"Peak GPU memory: {peak_gb:.2f} GB")
    return peak_gb

Common Mistakes

danger

Not accounting for optimizer states in memory budgets: The single most common mistake when estimating training memory is counting only model weights. For AdamW with FP32 optimizer states, the optimizer adds 12 bytes per parameter on top of the 2-byte BF16 weights. A 7B model at BF16 = 14 GB of weights, but weights + gradients + optimizer states = 112 GB. Always include optimizer states in your budget.

danger

Enabling gradient checkpointing after the optimizer is initialized: In some training frameworks, enabling gradient checkpointing after optimizer initialization causes incorrect gradient accumulation. Always enable gradient checkpointing before creating the optimizer. In Hugging Face Trainer, set gradient_checkpointing=True in TrainingArguments before the Trainer is constructed.
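A minimal ordering sketch (the model name is just an example):

# Correct order: enable checkpointing first, build the optimizer afterwards
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16
)
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # created last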

warning

ZeRO-3 with small models is slower than DDP: ZeRO-3 adds all-gather communication for parameters before each forward pass and reduce-scatter after each backward pass. For small models where the communication latency is a significant fraction of the compute time, this overhead makes ZeRO-3 slower than standard DDP. Use ZeRO-3 only when the model does not fit in GPU memory without it. Use ZeRO-1 or ZeRO-2 when optimizer state memory is the binding constraint.

warning

Forgetting to divide loss by accumulation steps in gradient accumulation: Without the division, each micro-batch contributes its full gradient, and the accumulated gradient is N times larger than it should be (where N is the accumulation steps). This effectively multiplies the learning rate by N, causing training instability. The fix: divide loss by accumulation_steps before calling loss.backward().


Interview Q&A

Q: Walk through the complete memory budget for fine-tuning a 7B parameter model with AdamW in BF16. What components does memory consist of?

A: Memory consists of six components.

First, model weights at BF16: 2 bytes per parameter, so $7 \times 10^9 \times 2 = 14$ GB.

Second, gradients: same dtype as weights in the standard setup, so another 14 GB.

Third, optimizer states: AdamW with FP32 master weights stores (a) FP32 master copy of weights (4 bytes/param = 28 GB), (b) first moment (gradient mean, FP32, 28 GB), (c) second moment (gradient variance, FP32, 28 GB). Total optimizer: 84 GB.

Fourth, activation memory: this depends on batch size, sequence length, and whether gradient checkpointing is enabled. For a batch of 4, sequence length 2048, with 32 layers, activations are roughly 51 GB without checkpointing. With checkpointing enabled, this drops to roughly 1.6 GB (one layer at a time).

Fifth, the input batch itself: a batch of 4 sequences of 2048 tokens as int64 is negligible (less than 100 MB).

Sixth, CUDA workspace memory for kernels: typically 1-2 GB.

Without gradient checkpointing, the total is roughly 14 + 14 + 84 + 51 = 163 GB. With gradient checkpointing: 14 + 14 + 84 + 1.6 ≈ 113.6 GB. That fits on two A100 80GB GPUs with ZeRO-2 (which shards gradients and optimizer states across N=2). A single GPU cannot fit it at all without offloading or quantization; to fit comfortably, use ZeRO-3 across four or more GPUs.


Q: Explain the ZeRO optimizer stages and the memory savings formula for each.

A: ZeRO stands for Zero Redundancy Optimizer. In standard data parallel training, every GPU holds a complete replica of weights, gradients, and optimizer states. ZeRO progressively eliminates this redundancy.

ZeRO-1 partitions optimizer states across $N$ GPUs. Each GPU holds $1/N$ of the momentum and variance buffers. When the optimizer step runs, each GPU updates its $1/N$ shard of parameters. The updated parameters are then all-gathered across GPUs. Memory savings: the optimizer state contribution goes from $12P$ bytes to $12P/N$ bytes per GPU. For AdamW with mixed precision, per-GPU memory goes from $16P$ bytes to approximately $4P + 12P/N$ bytes.

ZeRO-2 additionally partitions gradients. After the backward pass, each GPU only needs to hold the $1/N$ slice of gradients corresponding to its optimizer shard. Memory savings: gradients go from $2P$ to $2P/N$ bytes per GPU. Per-GPU memory: approximately $2P + 14P/N$ bytes.

ZeRO-3 additionally partitions the model parameters themselves. At any point in the forward or backward pass, each GPU holds only the parameters for the layers currently being computed. Parameters are all-gathered just before use and discarded immediately after. Per-GPU memory: $16P/N$ bytes - a full $N\times$ reduction over the baseline.

The communication cost: ZeRO-1 adds one all-gather of parameters per optimizer step. ZeRO-2 adds reduce-scatter of gradients after backward. ZeRO-3 adds all-gather of parameters before each layer's forward and backward. In practice, ZeRO-3 has the highest communication overhead but enables the largest models on fixed hardware.


Q: What is the KV cache, how much memory does it consume, and how does PagedAttention reduce memory waste?

A: During autoregressive generation, the transformer computes attention over all previous tokens at each step. Rather than recomputing the key and value projections for all previous tokens at each step, the model caches them.

The KV cache size: $2 \times L \times H_{kv} \times d_h \times T \times B \times \text{bytes}$, where the factor of 2 is for keys and values, $L$ is layers, $H_{kv}$ is the number of KV heads (equal to the attention heads for MHA, smaller for GQA), $d_h$ is head dimension, $T$ is sequence length, and $B$ is batch size. For Llama 2 7B serving a batch of 32 at 4096 tokens, this is roughly 68 GB - nearly a full A100.

The problem: KV cache must be pre-allocated per-sequence because standard implementations require contiguous memory for efficient attention computation. This creates fragmentation. If you pre-allocate 8192 tokens per sequence but the average sequence is 512 tokens, you waste 94% of KV memory.

PagedAttention fixes this by dividing KV cache into fixed-size pages (typically 16 tokens per page) stored in non-contiguous physical memory, with a page table mapping logical sequence positions to physical pages. This is analogous to how the OS handles virtual memory pages for processes.

The practical benefit: memory waste drops from 60-80% (with contiguous pre-allocation) to under 4% (since wasted memory is at most one partial page per sequence). vLLM, which implements PagedAttention, typically achieves 2-4x higher serving throughput than naive implementations on the same hardware because it can serve 2-4x more concurrent sequences.
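The bookkeeping behind this can be sketched as a toy allocator (illustrative only, not vLLM's implementation; the 16-token page size matches the PAGE_SIZE constant used earlier):

PAGE_SIZE = 16  # tokens per KV page

class ToyPageTable:
    """Toy KV-cache page allocator in the spirit of PagedAttention."""
    def __init__(self, total_pages: int):
        self.free = list(range(total_pages))  # physical page ids
        self.table = {}                       # seq_id -> list of physical pages
        self.length = {}                      # seq_id -> tokens written so far

    def append_token(self, seq_id: int) -> int:
        """Record one token's KV; allocate a new page only on a page boundary."""
        n = self.length.get(seq_id, 0)
        if n % PAGE_SIZE == 0:                # first token, or current page full
            if not self.free:
                raise MemoryError("KV memory exhausted: preempt or evict a sequence")
            self.table.setdefault(seq_id, []).append(self.free.pop())
        self.length[seq_id] = n + 1
        return self.table[seq_id][n // PAGE_SIZE]  # physical page for this token

    def release(self, seq_id: int):
        """Finished sequence: its pages return to the free pool immediately."""
        self.free.extend(self.table.pop(seq_id, []))
        self.length.pop(seq_id, None)

pt = ToyPageTable(total_pages=4096)
for _ in range(20):                            # 20 tokens occupy 2 pages (16 + 4)
    pt.append_token(seq_id=0)
print(len(pt.table[0]), pt.length[0])          # 2 pages, 20 tokens
pt.release(0)                                  # waste: at most one partial page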


Q: Explain gradient checkpointing. What is the memory-compute tradeoff, and when is it worth using?

A: During the standard forward pass, PyTorch stores all intermediate activations (the outputs of every layer) because they are needed for computing gradients in the backward pass. For a transformer with $L$ layers, this is $O(L)$ memory, proportional to both layer count and batch size.

Gradient checkpointing selectively discards intermediate activations during the forward pass. When the backward pass reaches a layer whose activations were discarded, it re-runs the forward computation for that segment to regenerate the activations needed for the gradient. The memory cost is now proportional to the number of checkpoints $c$, not to $L$. Placing checkpoints every $\sqrt{L}$ layers gives $O(\sqrt{L})$ memory at the cost of roughly one extra forward pass of recomputation.

Checkpointing every layer (the most common setting in practice) reduces activation memory from $O(L \times B \times T \times H)$ to $O(B \times T \times H)$ - roughly a 10-30x reduction for typical models. The compute overhead is one additional forward pass per layer, which adds approximately 30-40% to total training time.

It is worth using whenever activation memory is a binding constraint. For large models with long sequences, activations dominate memory usage. A 70B model with sequence length 4096 and batch size 8 has activations requiring hundreds of GB - gradient checkpointing cuts this by an order of magnitude or more. The roughly 33% training time overhead is a good trade for fitting the job on available hardware.

When not to use it: if your batch size is already small (batch size 1) and activations are already minimal, gradient checkpointing adds compute overhead with no meaningful memory benefit. Also, gradient checkpointing is incompatible with some optimizations like torch.compile() in certain modes - verify compatibility with your specific PyTorch version.


Q: What is the difference between Flash Attention and standard attention in terms of memory usage?

A: Standard attention computes the full $n \times n$ attention score matrix, which requires $O(n^2)$ memory in the sequence length $n$. For sequence length 4096, this matrix has about 16.8 million entries. At BF16 (2 bytes), that is roughly 32 MB per head per layer. For Llama 2 7B with 32 heads and 32 layers, the attention matrices alone require $32 \times 32 \times 32 = 32{,}768$ MB $\approx$ 32 GB during training, for a single sequence.

Flash Attention reformulates the computation using tiling. Instead of computing the full $n \times n$ matrix and writing it to GPU HBM (high-bandwidth memory), it computes attention in tiles that fit in on-chip SRAM. For each tile of queries, it loads the corresponding keys and values, computes the partial softmax, and accumulates the weighted sum - all without writing the full attention matrix to HBM.

Memory requirement: Flash Attention's activation memory for attention is $O(n)$ rather than $O(n^2)$. The full attention matrix is never materialized in HBM. For sequence length 4096, this is a 4096x memory reduction for the attention component.

Speed improvement: Flash Attention is typically 2-4x faster than standard attention at long sequence lengths because it dramatically reduces HBM traffic. The bottleneck in standard attention is not arithmetic but memory bandwidth - repeatedly reading and writing the $n \times n$ matrix is slow. Flash Attention eliminates most of these HBM accesses.

Flash Attention 2 (Dao, 2023) improved on the original by better exploiting GPU parallelism (better partitioning of work across warps and blocks). Flash Attention 3 targets H100-specific hardware features (TMA, WGMMA instructions). In practice, all modern training frameworks use Flash Attention by default because the speedup is substantial and there is no quality tradeoff.
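The tiling idea is easiest to see in a tiny single-head sketch of the online softmax accumulation (illustrative only; the real kernel fuses these steps into SRAM-resident tiles and never allocates even the per-tile scores in HBM):

import torch

def tiled_attention(q, k, v, tile: int = 128):
    """Single-head attention via online softmax over key/value tiles.
    The full (n x n) score matrix is never materialized."""
    n, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)                    # running weighted sum of values
    row_max = torch.full((n, 1), float("-inf"))  # running max score per query
    row_sum = torch.zeros(n, 1)                  # running softmax denominator
    for start in range(0, n, tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        scores = (q @ kt.T) * scale              # one (n, tile) tile of scores
        new_max = torch.maximum(row_max, scores.max(-1, keepdim=True).values)
        rescale = torch.exp(row_max - new_max)   # correct earlier partial results
        p = torch.exp(scores - new_max)          # tile's softmax numerators
        row_sum = row_sum * rescale + p.sum(-1, keepdim=True)
        out = out * rescale + p @ vt
        row_max = new_max
    return out / row_sum

q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64 ** 0.5, dim=-1) @ v
print(torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4))  # True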
