Memory Capacity Planning for LLMs
The $2 Million Mistake
The infrastructure team had done everything right, or so they thought. They had ordered 16 DGX H100 nodes - 128 GPUs total, 80 GB HBM3 each, 10.24 TB of GPU memory. The order took six months to arrive and cost approximately $2 million. The goal: fine-tune a GPT-4-class 175B parameter model on proprietary data.
The first training attempt launched. CUDA out-of-memory error on every node. Sixteen DGX systems, $2M of hardware, and the job crashed in under 30 seconds.
The mistake was a memory planning error. The engineer who provisioned the cluster had calculated for weights alone and concluded that 10 TB of GPU memory was more than enough. What the calculation missed: optimizer states (another 1.4 TB for Adam), gradients (700 GB), and activation memory during the forward pass (another 600 GB with full activation storage). The actual training memory requirement was roughly 3 TB, not 350 GB - and it had to be laid out across 128 GPUs with a parallelism strategy that actually sharded it. The cluster was sufficient; the memory accounting and the parallelism configuration were not.
After three days of debugging, they got the job running. The fix was a corrected memory calculation, proper tensor parallelism configuration, and activation checkpointing. But those three days, plus the months of planning based on incorrect math, were expensive and avoidable.
This lesson gives you the exact formulas to never make this mistake. Memory capacity planning for LLMs is deterministic math - there are no hidden surprises if you account for every component. By the end you will be able to calculate, to within 5%, the GPU memory requirement for training or deploying any LLM, and size your cluster correctly the first time.
Why This Exists - The Problem with GPU Memory Constraints
The Fundamental Tension
Language models are fundamentally limited by memory. Unlike convolutional networks or transformers for computer vision (where models are typically hundreds of MB), large language models have crossed into the range where fitting them in GPU memory requires careful engineering.
GPT-2 (2019): 1.5 billion parameters, 3 GB in fp16 - fits on a single V100. GPT-3 (2020): 175 billion parameters, 350 GB in fp16 - requires a rack of GPUs just to load. LLaMA-2-70B (2023): 70 billion parameters, 140 GB in bf16 - needs several 80 GB GPUs for inference once the KV cache is included. GPT-4-class models (estimated 1T+ parameters): require custom cluster configurations.
The problem is that memory capacity scales with hardware procurement (slow, expensive), while model sizes have been scaling much faster. Engineers need to predict memory requirements before hardware exists, plan parallelism strategies, and make trade-offs between model capability and hardware cost.
Before systematic memory planning frameworks existed, teams would empirically discover memory limits through trial-and-error runs - burning cluster time on jobs that OOM in the first minute. The formulas in this lesson make memory planning a pre-launch calculation rather than a post-launch discovery.
Historical Context - From "Fits on One GPU" to "Needs a Datacenter"
The Tipping Point
In 2018, BERT-Large had 340M parameters and comfortably fit on a single V100 with 16 GB HBM. Training took 4 days on 64 TPUs, but inference was trivial. The ML community had not yet needed rigorous memory planning because single-GPU inference was the norm.
GPT-2 in 2019 (1.5B parameters) was the first model where practitioners started to notice memory as a constraint. But it still fit on a 32 GB V100 for inference in fp16.
The shift happened with GPT-3 in 2020. Brown et al. trained the model on a Microsoft-provided high-bandwidth cluster of V100 GPUs. At 175B parameters, the model weights alone require 350 GB even in fp16 - the only way to run inference was across multiple GPUs. Suddenly, memory planning became a first-class engineering discipline.
The formalization of LLM memory requirements came from several papers:
- ZeRO (Zero Redundancy Optimizer, Rajbhandari et al., 2020) - systematically analyzed memory components and showed 3-8x reduction was possible
- Megatron-LM (Shoeybi et al., 2019, updated 2022) - formalized tensor and pipeline parallelism memory splitting
- FlashAttention (Dao et al., 2022) - showed that attention's activation memory was quadratic in sequence length and could be reduced with recomputation
The "aha moment" was the ZeRO paper's analysis showing that for Adam optimizer with mixed precision, the total memory per parameter is not 2 bytes (bf16) but 16 bytes when you include all optimizer states. This 8x multiplier was the source of most memory planning disasters in 2020-2022.
Core Concepts - The Complete Memory Formula
The Five Components of LLM Memory
Every byte of GPU memory consumed by an LLM falls into exactly one of five categories: model weights, gradients, optimizer states, activations, and the KV cache (inference only).
Each component is calculable from model hyperparameters. Let's derive each one.
Component 1 - Model Weights
The most intuitive component. If a model has $P$ parameters and each parameter is stored at $b$ bytes of precision: $M_{\text{weights}} = P \times b$.
Precision options and their byte costs:
| Precision | Bytes per param | Notes |
|---|---|---|
| fp32 | 4 | Full precision, used for master weights |
| bf16 | 2 | Standard training precision (2024) |
| fp16 | 2 | Older, less numerically stable than bf16 |
| fp8 (E4M3) | 1 | Inference and some training (H100+) |
| int8 | 1 | Quantized inference |
| int4 | 0.5 | Aggressive quantization (GPTQ, AWQ) |
For LLaMA-2-70B in bf16: $M_{\text{weights}} = 70 \times 10^9 \times 2\ \text{bytes} = 140\ \text{GB}$.
For inference-only deployments, this is often the dominant term. For training, it is the smallest component.
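As a quick sanity check, here is a minimal helper that applies the weights-only formula above to the precision table (the helper name and structure are ours, purely illustrative):

```python
# Weight memory alone: M_weights = P * bytes_per_param, reported in GB.
PRECISION_BYTES = {"fp32": 4, "bf16": 2, "fp16": 2, "fp8": 1, "int8": 1, "int4": 0.5}

def weight_memory_gb(num_params: float, precision: str) -> float:
    return num_params * PRECISION_BYTES[precision] / 1e9

for precision in ("fp32", "bf16", "int4"):
    print(f"LLaMA-2-70B weights in {precision}: {weight_memory_gb(70e9, precision):.1f} GB")
# fp32: 280.0 GB, bf16: 140.0 GB, int4: 35.0 GB
```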
Component 2 - Gradients
During training, you need to store a gradient for every trainable parameter, typically in the same precision as the working weights during the backward pass: $M_{\text{grads}} = P \times b$.
In mixed precision training (bf16 working weights, fp32 master weights), gradients are computed in bf16 and cast to fp32 for the optimizer update. That fp32 copy is usually folded into the optimizer-state accounting (see below), so the practical gradient memory is the bf16 copy: $M_{\text{grads}} = P \times 2\ \text{bytes}$.
For LLaMA-70B: $70 \times 10^9 \times 2\ \text{bytes} = 140\ \text{GB}$.
Component 3 - Optimizer States
This is where most engineers significantly underestimate memory. The Adam optimizer (and AdamW, which is used for virtually all LLM training) maintains two running statistics per parameter: the first moment (mean of gradients) and the second moment (mean of squared gradients).
In standard Adam with fp32 master weights, the optimizer keeps:
- fp32 master weights: $4P$ bytes
- fp32 first moment $m$: $4P$ bytes
- fp32 second moment $v$: $4P$ bytes
Plus the bf16 working copy of the weights used in the forward/backward pass: $2P$ bytes.
For LLaMA-70B training with Adam: $70 \times 10^9 \times (4 + 4 + 4 + 2)\ \text{bytes} = 980\ \text{GB}$.
This is 7x the weight storage alone. This is why "fits in GPU memory for inference" does not mean "fits in GPU memory for training."
The ZeRO Breakdown:
The ZeRO paper formalized this as the "mixed precision training" memory per parameter:
| Component | fp32 | bf16 |
|---|---|---|
| Master weights (fp32) | 4 bytes | - |
| Working weights (bf16) | - | 2 bytes |
| Gradients (bf16) | - | 2 bytes |
| Adam m (fp32) | 4 bytes | - |
| Adam v (fp32) | 4 bytes | - |
| Total per parameter | 16 bytes | 2 bytes (inference only) |
For LLaMA-70B: $70 \times 10^9 \times 16\ \text{bytes} = 1{,}120\ \text{GB}$.
This requires at minimum 14 H100 80 GB GPUs just for the parameter-related state, before activations.
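The same accounting as a few lines of Python - a sketch that reproduces the 16 bytes/parameter total from the table (component names follow the table; the function itself is illustrative):

```python
# Per-parameter byte accounting for mixed-precision Adam training (ZeRO-style).
def adam_mixed_precision_bytes_per_param() -> dict:
    return {
        "fp32 master weights": 4,
        "bf16 working weights": 2,
        "bf16 gradients": 2,
        "Adam first moment (fp32)": 4,
        "Adam second moment (fp32)": 4,
    }

P = 70e9  # LLaMA-2-70B
components = adam_mixed_precision_bytes_per_param()
for name, b in components.items():
    print(f"{name:28s}: {P * b / 1e9:6.0f} GB")
print(f"{'Total':28s}: {P * sum(components.values()) / 1e9:6.0f} GB")  # 1120 GB
```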
Component 4 - Activations
Activation memory is the most variable component - it depends on batch size, sequence length, and whether you use activation checkpointing.
For a transformer model with $L$ layers, each layer's full activation storage is approximately a fixed number of bytes per token per hidden unit, i.e. proportional to $B \times S \times H$, where:
- $B$ = batch size (number of sequences)
- $S$ = sequence length (tokens per sequence)
- $H$ = hidden dimension
- the stored tensors include attention outputs, MLP inputs/outputs, and normalization inputs
The full derivation from Korthikanti et al. (2022, Megatron-LM "Reducing Activation Recomputation in Large Transformer Models") gives the exact per-layer activation memory in mixed precision (bf16) as $S \cdot B \cdot H \cdot \left(34 + 5\frac{aS}{H}\right)$ bytes, where $a$ is the number of attention heads.
For most practitioners - especially when FlashAttention is used, so the quadratic $5aS/H$ attention-score term is never materialized - a simpler approximation works within roughly 20%: $M_{\text{act}} \approx 34 \times B \times S \times H \times L$ bytes.
The factor 34 is the approximate number of bytes that must be kept per token per hidden dimension for a full backward pass (roughly 17 bf16 tensors of shape $B \times S \times H$ per layer).
For LLaMA-70B (L=80, H=8192, S=2048, B=1): $34 \times 1 \times 2048 \times 8192 \times 80 \approx 45.6\ \text{GB}$.
With B=8: $\approx 365\ \text{GB}$.
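A small sketch of both estimates - the exact Korthikanti expression and the 34-bytes-per-element approximation (function names here are ours, not from any library):

```python
# Per-layer activation memory estimates for a transformer in bf16.
def activation_memory_exact_gb(B, S, H, L, attn_heads):
    # Korthikanti et al.: S*B*H*(34 + 5*a*S/H) bytes per layer
    per_layer_bytes = S * B * H * (34 + 5 * attn_heads * S / H)
    return per_layer_bytes * L / 1e9

def activation_memory_approx_gb(B, S, H, L):
    # Drops the quadratic attention-score term (not stored with FlashAttention)
    return 34 * B * S * H * L / 1e9

# LLaMA-2-70B geometry: 80 layers, hidden 8192, 64 attention heads
print(activation_memory_approx_gb(1, 2048, 8192, 80))     # ~45.6 GB
print(activation_memory_exact_gb(1, 2048, 8192, 80, 64))  # ~153 GB if attention scores are materialized
print(activation_memory_approx_gb(8, 2048, 8192, 80))     # ~365 GB at batch size 8
```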
Activation Checkpointing:
Activation checkpointing (also called gradient checkpointing or recomputation) trades memory for compute. Instead of storing all activations, you store only one activation per layer boundary. During the backward pass, you recompute the activations from the stored checkpoints.
Memory reduction: from $O(L)$ layers' worth of intermediates to $O(\sqrt{L})$ or even $O(1)$ stored segments, depending on checkpointing granularity.
In PyTorch:
import torch
from torch.utils.checkpoint import checkpoint

# MultiHeadAttention and FeedForward are assumed to be user-defined modules.
class CheckpointedTransformerLayer(torch.nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Without checkpointing: every intermediate from this layer stays in
        # memory until the backward pass consumes it.
        # With checkpointing: intermediates are discarded after the forward
        # pass and recomputed from x during the backward pass.
        def layer_fn(x):
            return self.ff(self.norm2(x + self.attn(self.norm1(x), mask)))
        # Cuts stored activations to roughly one tensor per layer (~17x less here)
        # at the cost of ~30% more compute (one extra forward pass per layer).
        return checkpoint(layer_fn, x, use_reentrant=False)
With activation checkpointing, the approximation becomes $M_{\text{act}} \approx 2 \times B \times S \times H \times L$ bytes - one bf16 layer-input checkpoint per layer.
For LLaMA-70B, B=8, S=2048, checkpointed: $2 \times 8 \times 2048 \times 8192 \times 80 \approx 21.5\ \text{GB}$.
A 17x reduction at the cost of ~30% more compute - almost always worth it for training.
Component 5 - KV Cache (Inference Only)
The KV cache stores the key and value tensors from the attention mechanism for previously-computed tokens, avoiding recomputation during autoregressive generation.
The exact formula: $M_{KV} = 2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times S \times B \times \text{bytes per element}$.
Breaking this down:
- Factor of 2: one tensor for keys, one for values
- $L$: number of transformer layers
- $n_{\text{kv}}$: number of KV heads (with grouped-query attention, this is the number of KV groups, not the number of query heads)
- $d_{\text{head}}$: dimension per attention head
- $S$: maximum sequence length (context window)
- $B$: batch size (number of concurrent sequences)
- bytes per element: 2 for bf16
For LLaMA-2-70B inference:
- $L = 80$ layers
- Grouped Query Attention (GQA): $n_{\text{kv}} = 8$ KV heads (not the 64 query heads)
- $d_{\text{head}} = 128$
- $S = 4096$ tokens
- $B = 1$ request
- bf16 (2 bytes per element)
$M_{KV} = 2 \times 80 \times 8 \times 128 \times 4096 \times 1 \times 2\ \text{bytes} \approx 1.34\ \text{GB per sequence}$.
With a batch of 100 concurrent sequences: $\approx 134\ \text{GB}$.
This is why serving LLMs at scale requires careful KV cache management. vLLM's PagedAttention was designed for exactly this problem - it manages KV cache memory like virtual memory, allocating fixed-size blocks of tokens on demand rather than reserving the full context window up front for every sequence.
KV Cache Scaling Example - GPT-4 context window (128K tokens):
If we estimate a 70B-class model (80 layers, 8 KV heads, head dim 128) with a 128K context window: $2 \times 80 \times 8 \times 128 \times 131{,}072 \times 1 \times 2\ \text{bytes} \approx 43\ \text{GB}$.
A single 128K-context sequence requires 43 GB of KV cache - more than half of an H100's 80 GB of HBM. This is why long-context inference is so much more expensive than short-context, and why companies like Anthropic and Google invest heavily in KV cache compression techniques.
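A GQA-aware sketch of the formula that reproduces the numbers above (the helper name is ours):

```python
# KV cache size: 2 (K and V) * layers * kv_heads * head_dim * seq_len * batch * bytes.
def kv_cache_gb(num_layers, num_kv_heads, head_dim, seq_len, batch_size, bytes_per_elem=2):
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem / 1e9

# LLaMA-2-70B geometry: 80 layers, 8 KV heads (GQA), head_dim 128, bf16
print(kv_cache_gb(80, 8, 128, 4096, 1))     # ~1.34 GB per 4K-context sequence
print(kv_cache_gb(80, 8, 128, 4096, 100))   # ~134 GB for 100 concurrent sequences
print(kv_cache_gb(80, 8, 128, 131072, 1))   # ~43 GB for one 128K-context sequence
```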
Complete Memory Planning Calculation
Planning LLaMA-70B Training
Let's do the full calculation for fine-tuning LLaMA-2-70B:
Model hyperparameters:
- Parameters: $P = 70 \times 10^9$
- Layers: $L = 80$
- Hidden dim: $H = 8192$
- Sequence length: $S = 4096$
- Batch size: $B = 8$ sequences
Step 1 - Working weights (bf16): $70 \times 10^9 \times 2\ \text{bytes} = 140\ \text{GB}$
Step 2 - Gradients (bf16): $140\ \text{GB}$
Step 3 - Adam optimizer states (fp32 master weights + m + v): $70 \times 10^9 \times 12\ \text{bytes} = 840\ \text{GB}$
Step 4 - Activations (checkpointed, one bf16 layer-input checkpoint per layer): $2 \times 8 \times 4096 \times 8192 \times 80 \approx 43\ \text{GB}$
Step 5 - Total: $140 + 140 + 840 + 43 \approx 1{,}163\ \text{GB}$, or roughly $1{,}280\ \text{GB}$ with a 10% overhead allowance.
At 80 GB per H100: $1{,}280 / 80 = 16$ GPUs.
Conveniently, 16 is already a power of two, which allows clean tensor- and data-parallel splits: 16 x H100 GPUs (2 nodes of 8xH100 each). This matches DGX H100 topology perfectly.
import math

def compute_training_memory(
    num_params_billions,
    num_layers,
    hidden_dim,
    num_kv_heads,  # unused for training memory; kept for symmetry with the inference calculator
    seq_len,
    batch_size,
    use_activation_checkpointing=True,
    precision_bytes=2,  # bf16
):
    """
    Compute total GPU memory required for LLM training.
    All outputs in GB.
    """
    P = num_params_billions * 1e9

    # Weights (bf16 working copy)
    weights_gb = P * precision_bytes / 1e9
    print(f"Weights ({precision_bytes*8}-bit): {weights_gb:.1f} GB")

    # Gradients (bf16)
    grads_gb = P * precision_bytes / 1e9
    print(f"Gradients ({precision_bytes*8}-bit): {grads_gb:.1f} GB")

    # Adam optimizer states (fp32 master weights + m + v = 12 bytes/param);
    # the bf16 working copy is already counted in weights_gb
    optimizer_gb = P * 12 / 1e9
    print(f"Adam optimizer states (fp32): {optimizer_gb:.1f} GB")

    # Activations
    if use_activation_checkpointing:
        act_bytes_per_element = precision_bytes  # one stored layer-input checkpoint per layer
    else:
        act_bytes_per_element = 34  # ~34 bytes per token per hidden dim, full storage
    acts_gb = (batch_size * seq_len * hidden_dim * num_layers
               * act_bytes_per_element) / 1e9
    print(f"Activations ({'checkpointed' if use_activation_checkpointing else 'full'}): {acts_gb:.1f} GB")

    total_gb = weights_gb + grads_gb + optimizer_gb + acts_gb
    # Add 10% overhead for CUDA kernels, fragmentation, misc buffers
    total_with_overhead = total_gb * 1.1
    print(f"\nTotal (before overhead): {total_gb:.1f} GB")
    print(f"Total (with 10% overhead): {total_with_overhead:.1f} GB")

    h100_count = total_with_overhead / 80
    print(f"Minimum H100 (80 GB) GPUs needed: {h100_count:.1f}")
    print(f"Recommended (next power of 2): {2 ** (math.ceil(h100_count) - 1).bit_length()}")
    return total_with_overhead
# LLaMA-2-70B
compute_training_memory(
num_params_billions=70,
num_layers=80,
hidden_dim=8192,
num_kv_heads=8,
seq_len=4096,
batch_size=8,
use_activation_checkpointing=True,
)
Output:
Weights (16-bit): 140.0 GB
Gradients (16-bit): 140.0 GB
Adam optimizer states (fp32): 840.0 GB
Activations (checkpointed): 42.9 GB
Total (before overhead): 1162.9 GB
Total (with 10% overhead): 1279.2 GB
Minimum H100 (80 GB) GPUs needed: 16.0
Recommended (next power of 2): 16
Planning LLaMA-70B Inference
Inference removes optimizer states and gradients, but adds KV cache:
def compute_inference_memory(
num_params_billions,
num_layers,
hidden_dim,
num_kv_heads,
head_dim,
max_seq_len,
batch_size,
precision_bytes=2, # bf16
):
"""
Compute GPU memory for LLM inference deployment.
"""
P = num_params_billions * 1e9
# Weights only - no optimizer, no gradients
weights_gb = P * precision_bytes / 1e9
print(f"Model weights ({precision_bytes*8}-bit): {weights_gb:.1f} GB")
# KV cache
kv_cache_gb = (2 * num_layers * num_kv_heads * head_dim
* max_seq_len * batch_size * precision_bytes) / 1e9
print(f"KV cache (batch={batch_size}, seq_len={max_seq_len}): "
f"{kv_cache_gb:.1f} GB")
# Small activation buffer for current forward pass
# (only current token's activations, not full sequence)
activation_gb = (batch_size * 1 * hidden_dim * num_layers
* 4 * precision_bytes) / 1e9
print(f"Activation buffer (current token): {activation_gb:.2f} GB")
total_gb = weights_gb + kv_cache_gb + activation_gb
total_with_overhead = total_gb * 1.05 # 5% overhead for inference
print(f"\nTotal (with 5% overhead): {total_with_overhead:.1f} GB")
h100_count = total_with_overhead / 80
print(f"Minimum H100 (80 GB) GPUs: {h100_count:.1f}")
return total_with_overhead
# LLaMA-2-70B inference, serving 100 concurrent users
compute_inference_memory(
num_params_billions=70,
num_layers=80,
hidden_dim=8192,
num_kv_heads=8,
head_dim=128,
max_seq_len=4096,
batch_size=100,
precision_bytes=2,
)
Output:
Model weights (16-bit): 140.0 GB
KV cache (batch=100, seq_len=4096): 134.2 GB
Activation buffer (current token): 0.52 GB
Total (with 5% overhead): 288.5 GB
Minimum H100 (80 GB) GPUs: 3.6
Four H100 GPUs to serve 100 concurrent LLaMA-70B users with 4K context. Scale down to 10 concurrent users and the requirement drops to roughly 161 GB, which fits on three H100s (or two if you trim the context window or batch slightly).
Parallelism Memory Splitting
Tensor Parallelism
Tensor parallelism splits individual weight matrices across GPUs. With tensor parallelism degree $T$, the weight memory per GPU becomes $M_{\text{weights}} / T$.
Tensor parallelism also shards the gradients and optimizer states of the matrices it splits, but for a 70B model the per-GPU footprint is still too large. With ZeRO Stage 1 layered on top, optimizer states are additionally sharded across the data-parallel ranks, not just the tensor-parallel ranks.
For LLaMA-70B with tensor parallelism on 8 GPUs: weights are $140 / 8 = 17.5$ GB per GPU.
But weights + gradients + optimizer states in pure tensor parallelism (no ZeRO): $(140 + 140 + 840) / 8 = 140$ GB per GPU... which still overflows an 80 GB H100 before any activations.
This is why production training uses ZeRO Stage 3 + Tensor Parallelism together:
# DeepSpeed ZeRO Stage 3 + Tensor Parallelism configuration
ds_config = {
"zero_optimization": {
"stage": 3, # Shard weights + grads + optimizer
"overlap_comm": True, # Overlap all-gather with compute
"contiguous_gradients": True,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e7,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": True
},
"bf16": {
"enabled": True
},
"gradient_clipping": 1.0,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": 1
}
With ZeRO Stage 3 across $N$ data-parallel GPUs, the parameter-related memory per GPU becomes $16P / N$ bytes, plus the (unsharded) activations for the local micro-batch.
For LLaMA-70B across 16 GPUs with ZeRO Stage 3: $1{,}120 / 16 = 70$ GB per GPU, plus roughly 5 GB of checkpointed activations at micro-batch size 1.
Just barely fits on an H100 80 GB when paired with aggressive activation checkpointing.
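A quick sketch of that per-GPU arithmetic, assuming ZeRO Stage 3 shards the 16 bytes/parameter state across the data-parallel group while activations stay local to each GPU's micro-batch (the helper name is ours):

```python
# Per-GPU memory under ZeRO Stage 3: sharded parameter state + local activations.
def zero3_per_gpu_gb(num_params, dp_degree, micro_batch, seq_len, hidden, layers,
                     checkpointing=True):
    param_state_gb = num_params * 16 / dp_degree / 1e9
    act_bytes = 2 if checkpointing else 34  # bytes per token per hidden dim
    acts_gb = micro_batch * seq_len * hidden * layers * act_bytes / 1e9
    return param_state_gb, acts_gb, param_state_gb + acts_gb

print(zero3_per_gpu_gb(70e9, dp_degree=16, micro_batch=1,
                       seq_len=4096, hidden=8192, layers=80))
# ~(70.0, 5.4, 75.4) GB per GPU -> fits in 80 GB with little headroom
```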
Pipeline Parallelism
Pipeline parallelism splits the model vertically - different groups of layers run on different GPUs. With pipeline degree $D$, each GPU holds $L / D$ layers, so weight, gradient, and optimizer memory per GPU drop to $1/D$ of the total.
The catch: pipeline parallelism introduces pipeline bubbles (GPU idle time waiting for microbatches from the previous stage) and requires keeping activations alive for every in-flight microbatch at each pipeline stage boundary.
Fewer pipeline stages mean a smaller bubble, less idle time, and better efficiency. The Megatron-LM interleaved pipeline schedule reduces the bubble fraction from $(p-1)/m$ to $(p-1)/(v \cdot m)$, where $p$ is the number of pipeline stages, $m$ the number of microbatches per batch, and $v$ the number of interleaved model chunks per GPU, at the cost of more complex scheduling code.
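A one-liner makes the bubble arithmetic concrete, assuming the $(p-1)/m$ and $(p-1)/(v \cdot m)$ fractions above (the helper is illustrative):

```python
# Pipeline bubble fraction: idle time relative to ideal compute time.
def bubble_fraction(pipeline_stages: int, microbatches: int, interleave: int = 1) -> float:
    return (pipeline_stages - 1) / (interleave * microbatches)

print(bubble_fraction(4, 8))                 # 0.375 -> 37.5% idle
print(bubble_fraction(4, 32))                # 0.094 -> more microbatches shrink the bubble
print(bubble_fraction(4, 8, interleave=4))   # 0.094 -> interleaved schedule, same microbatch count
```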
(Architecture diagrams omitted here: LLM memory components, ZeRO stage memory sharding, and KV cache scaling with context length.)
Production Engineering Notes
The Memory Budget Spreadsheet Pattern
Production teams at large AI companies maintain a "memory budget spreadsheet" before every major training run. Here is a template:
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelConfig:
name: str
num_params_B: float # billions
num_layers: int
hidden_dim: int
num_query_heads: int
num_kv_heads: int # same as query_heads if MHA, fewer if GQA/MQA
head_dim: int
vocab_size: int
max_seq_len: int
@dataclass
class TrainingConfig:
batch_size: int # sequences per GPU
seq_len: int # tokens per sequence
use_activation_checkpointing: bool
zero_stage: int # 0, 1, 2, or 3
tensor_parallel_degree: int
pipeline_parallel_degree: int
data_parallel_degree: int
precision: str # 'bf16', 'fp16', 'fp32', 'fp8'
def memory_budget(model: ModelConfig, training: TrainingConfig) -> dict:
"""Full memory budget calculation."""
P = model.num_params_B * 1e9
d = {'bf16': 2, 'fp16': 2, 'fp32': 4, 'fp8': 1}[training.precision]
N = training.data_parallel_degree # for ZeRO sharding
results = {}
# Weights
if training.zero_stage >= 3:
results['weights_gb'] = P * d / N / 1e9
else:
results['weights_gb'] = P * d / 1e9
# Gradients
if training.zero_stage >= 2:
results['gradients_gb'] = P * d / N / 1e9
else:
results['gradients_gb'] = P * d / 1e9
# Optimizer (fp32 master + m + v = 12 bytes/param)
if training.zero_stage >= 1:
results['optimizer_gb'] = P * 12 / N / 1e9
else:
results['optimizer_gb'] = P * 12 / 1e9
    # Activations (per GPU, driven by the local micro-batch; not sharded by ZeRO)
    if training.use_activation_checkpointing:
        act_bytes_per_element = d   # one stored layer-input checkpoint per layer
    else:
        act_bytes_per_element = 34  # ~34 bytes per token per hidden dim, full storage
    results['activations_gb'] = (
        training.batch_size * training.seq_len * model.hidden_dim
        * model.num_layers * act_bytes_per_element / 1e9
    )
# Overhead (CUDA context, fragmentation, misc)
subtotal = sum(results.values())
results['overhead_gb'] = subtotal * 0.10
results['total_gb'] = subtotal + results['overhead_gb']
# GPU requirements
results['h100_80gb_needed'] = results['total_gb'] / 80
return results
# Example: LLaMA-70B with ZeRO Stage 3 on 16 GPUs
llama70b = ModelConfig(
name="LLaMA-2-70B",
num_params_B=70,
num_layers=80,
hidden_dim=8192,
num_query_heads=64,
num_kv_heads=8,
head_dim=128,
vocab_size=32000,
max_seq_len=4096
)
training_cfg = TrainingConfig(
batch_size=1,
seq_len=4096,
use_activation_checkpointing=True,
zero_stage=3,
    tensor_parallel_degree=1,   # pure ZeRO-3 sharding across the 16 data-parallel ranks
    pipeline_parallel_degree=1,
data_parallel_degree=16,
precision='bf16'
)
budget = memory_budget(llama70b, training_cfg)
for k, v in budget.items():
print(f"{k:30s}: {v:.1f}")
Accounting for Memory Fragmentation
CUDA memory allocators suffer from fragmentation over long training runs. Two patterns cause this:
1. Variable-length sequence batching - If you sort sequences by length and pack them, tensor shapes vary between batches. The allocator carves out differently-sized regions over time, leaving small unusable gaps.
2. Optimizer step memory spikes - During the optimizer step, PyTorch temporarily allocates intermediate buffers. Peak memory during the optimizer step can be 10-20% higher than during the forward/backward pass.
import torch
# Monitor memory fragmentation
def memory_stats():
stats = torch.cuda.memory_stats()
allocated = stats['allocated_bytes.all.current'] / 1e9
reserved = stats['reserved_bytes.all.current'] / 1e9
fragmentation = (reserved - allocated) / reserved * 100
print(f"Allocated: {allocated:.2f} GB")
print(f"Reserved: {reserved:.2f} GB")
print(f"Fragmentation: {fragmentation:.1f}%")
print(f"Peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
# Reduce fragmentation with expandable segments (PyTorch 2.0+).
# Note: PYTORCH_CUDA_ALLOC_CONF is read when the CUDA allocator initializes,
# so set it before the first CUDA allocation.
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# Or tune the garbage collection threshold
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = (
'max_split_size_mb:512,'
'garbage_collection_threshold:0.8,'
'expandable_segments:True'
)
Quantization as a Memory Strategy
When you cannot afford full bf16 precision, quantization reduces model weights to lower bit representations:
# GPTQ quantization - 4-bit weights with fp16 activations
# Memory: ~0.5 bytes/param instead of 2 bytes/param
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
# 4-bit NF4 quantization (bitsandbytes)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # Normal Float 4 - better than int4 for LLMs
bnb_4bit_compute_dtype=torch.bfloat16, # Compute in bf16, store in 4-bit
bnb_4bit_use_double_quant=True, # Quantize the quantization constants too
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-hf",
quantization_config=bnb_config,
device_map="auto"
)
# Memory check: 70B * 0.5 bytes = 35 GB instead of 140 GB
# Fits on a single H100 80GB with room for KV cache
# AWQ (Activation-aware Weight Quantization) - better quality than GPTQ
from awq import AutoAWQForCausalLM
model = AutoAWQForCausalLM.from_quantized(
"TheBloke/Llama-2-70B-AWQ",
fuse_layers=True,
trust_remote_code=False
)
Memory breakdown for LLaMA-70B with 4-bit quantization serving 10 concurrent 4K-context sequences:
Weights (4-bit NF4): 70B * 0.5 bytes = 35 GB
KV cache (bf16, 10 seqs, 4096 tokens): 13.4 GB
Activations (current token only): 0.04 GB
Overhead (5%): 2.4 GB
Total: 50.8 GB - fits on one H100 80GB
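The same budget as a small function, which also shows how quickly the KV cache takes over as concurrency grows (the helper name and the 5% overhead factor mirror the earlier inference calculation; values are illustrative):

```python
# Serving budget with quantized weights and a bf16 KV cache.
def quantized_serving_gb(num_params, weight_bytes, num_layers, num_kv_heads,
                         head_dim, seq_len, concurrent_seqs, kv_bytes=2):
    weights = num_params * weight_bytes / 1e9
    kv = 2 * num_layers * num_kv_heads * head_dim * seq_len * concurrent_seqs * kv_bytes / 1e9
    return (weights + kv) * 1.05  # 5% overhead

# LLaMA-2-70B in 4-bit NF4, 10 concurrent 4K-context sequences
print(quantized_serving_gb(70e9, 0.5, 80, 8, 128, 4096, 10))   # ~50.8 GB -> one H100
# Same model at 100 concurrent sequences: the KV cache dominates again
print(quantized_serving_gb(70e9, 0.5, 80, 8, 128, 4096, 100))  # ~178 GB -> 3 H100s
```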
Common Mistakes
:::danger Forgetting Optimizer States - The 8x Multiplier
The most common memory planning mistake in production: calculating only weight storage and forgetting that Adam optimizer adds 12 bytes per parameter (fp32 master weights + first moment + second moment).
For a 70B model in bf16:
- Weights: 140 GB (this is what people calculate)
- Gradients: 140 GB (often forgotten)
- Adam states: 840 GB (usually completely missing from estimates)
The total is 1,120 GB, not 140 GB. Estimates based on weights alone are 8x too low.
Always use the 16 bytes/parameter rule for mixed-precision Adam training:
Total training memory = num_params * 16 bytes
+ activation memory
:::
:::danger Confusing MHA and GQA Head Counts in KV Cache Calculations
LLaMA-2-70B uses Grouped Query Attention (GQA) with 64 query heads but only 8 KV heads. If you calculate KV cache using 64 heads instead of 8, your estimate is 8x too high.
Always check the model config for num_key_value_heads vs num_attention_heads. Models that use MQA (Multi-Query Attention, 1 KV head) or GQA (Grouped Query Attention) have significantly smaller KV caches.
from transformers import AutoConfig
config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
print(f"Query heads: {config.num_attention_heads}") # 64
print(f"KV heads: {config.num_key_value_heads}") # 8 for GQA
# Use num_key_value_heads for KV cache calculation
:::
:::warning Ignoring Activation Memory Spikes with Variable Sequence Lengths
Activation memory calculations assume a fixed sequence length. In practice, if you use dynamic batching (grouping sequences by length), the longest sequence in a batch determines the activation tensor size. A single 8192-token sequence in a batch of 4096-token sequences causes activation memory to double.
Always calculate activation memory using the MAXIMUM sequence length in your training data, not the average. Add 20% buffer for batching variance. :::
:::warning Pipeline Parallelism Bubble Memory is Often Missed
With pipeline parallelism, the GPU holding stage 0 must keep the activations of every in-flight microbatch alive until the corresponding backward pass returns from the later stages. This "in-flight activation" memory is proportional to the number of in-flight microbatches (up to the pipeline depth) and is separate from the activation checkpointing memory.
For a 4-stage pipeline running 8 microbatches per batch: at peak, stage 0 holds activations for up to 4 in-flight microbatches of its own layers simultaneously. Always add this to your pipeline parallelism memory estimate. :::
:::warning KV Cache at Maximum Context Fills Memory Before Weights
For long-context models (32K-128K tokens), the KV cache at maximum context can exceed the model weight storage. A 70B model serving a single 128K-context request requires ~43 GB for KV cache vs 140 GB for weights. At batch size 3, the KV cache roughly equals the weights; beyond that it dominates.
Always plan KV cache memory separately and check whether your serving configuration can actually sustain the claimed context length at your target batch size. Many deployments silently limit context length because the KV cache overflows. :::
Interview Q&A
Q1: Walk me through the complete memory breakdown for training GPT-3 (175B parameters). How many A100 80GB GPUs do you need minimum?
The complete breakdown uses the 16 bytes/parameter rule for mixed-precision Adam training:
- Working weights (bf16): $175 \times 10^9 \times 2\ \text{bytes} = 350$ GB
- Gradients (bf16): $350$ GB
- Adam optimizer states (fp32 master + m + v): $175 \times 10^9 \times 12\ \text{bytes} = 2{,}100$ GB
- Subtotal (parameters): $2{,}800$ GB
For activations with checkpointing (assuming batch=1, seq=2048, L=96, H=12288): $2 \times 1 \times 2048 \times 12288 \times 96 \approx 5$ GB
Total: approximately $2{,}805$ GB $\approx$ 2.8 TB.
Minimum A100 (80 GB) GPUs: $2{,}805 / 80 \approx 35$, so at least 36 GPUs before any overhead.
In practice you want the tensor-parallel degree to map onto a full node (8-way), so the minimum practical deployment is 40 A100s with ZeRO Stage 3 (8-way tensor parallel x 5-way data parallel), or more commonly 64 A100s (8-way tensor parallel x 8-way data parallel) for clean topology and communication efficiency.
OpenAI trained GPT-3 on a Microsoft-provided cluster of V100 GPUs, reportedly thousands of them. A large GPU count lets you keep tensor and pipeline parallel degrees small, reducing communication overhead at the cost of more GPUs.
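The arithmetic above as a quick sanity-check script (GPT-3 geometry assumed: 96 layers, hidden size 12,288; variable names are ours):

```python
# Q1 arithmetic: GPT-3 (175B params), mixed-precision Adam, checkpointed activations.
P, L, H = 175e9, 96, 12288
weights_gb = P * 2 / 1e9              # bf16 working weights -> 350 GB
grads_gb = P * 2 / 1e9                # bf16 gradients       -> 350 GB
optimizer_gb = P * 12 / 1e9           # fp32 master + m + v  -> 2100 GB
acts_gb = 2 * 1 * 2048 * H * L / 1e9  # checkpointed, B=1, S=2048 -> ~4.8 GB
total_gb = weights_gb + grads_gb + optimizer_gb + acts_gb
print(f"{total_gb:.0f} GB total, {total_gb / 80:.1f} x A100-80GB minimum")  # ~2805 GB, ~35.1 GPUs
```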
Q2: Explain the difference between ZeRO Stage 1, 2, and 3, and which components each stage shards.
ZeRO (Zero Redundancy Optimizer) progressively eliminates memory redundancy across data-parallel workers. In standard data parallel training, each GPU holds a complete copy of weights, gradients, and optimizer states - exactly N times the minimum required memory for N GPUs.
ZeRO Stage 1 - Optimizer State Partitioning: Each of the $N$ data-parallel workers holds only $1/N$ of the Adam optimizer states (fp32 master weights, first and second moments). Weights and gradients are still replicated. Memory savings: up to 4x for large models where optimizer states dominate. Communication: same all-reduce volume as standard DDP.
ZeRO Stage 2 - Gradient and Optimizer Partitioning: Extends Stage 1 by also sharding gradients. After each layer's backward pass, gradients are immediately reduced and the non-local shards discarded. Each worker only keeps its gradient shard. Memory savings: 8x. Communication: same total volume as all-reduce but structured as reduce-scatter.
ZeRO Stage 3 - Full Sharding (Weights + Gradients + Optimizer): Even the model weights are sharded. Each worker holds only $1/N$ of the parameters. During the forward pass, an all-gather collects the full parameters for each layer just before its computation, and the non-local shards are discarded after use. Memory savings: linear in $N$ - $16/N$ bytes per parameter instead of 16. Communication: roughly 1.5x the volume of standard DDP (an extra all-gather for the weights).
The trade-off: higher ZeRO stages reduce memory but increase communication. Stage 3 is most beneficial when the model barely fits or does not fit, and when the inter-GPU network is fast (NVLink). On slower connections (PCIe or Ethernet), the extra communication of Stage 3 may make training slower despite better memory efficiency.
Q3: How does activation checkpointing work, and what is the memory-compute trade-off?
Activation checkpointing (gradient checkpointing) is a technique to reduce the memory required to store intermediate activations during the forward pass.
Without checkpointing: every intermediate tensor computed during the forward pass must be kept alive in GPU memory until the backward pass uses it to compute gradients. For a transformer with 80 layers, this means 80 layers' worth of activation tensors held simultaneously - memory proportional to $B \times S \times H \times L$.
With checkpointing: only selected "checkpoint" activations are kept (typically the input to each transformer layer). All other intermediates are discarded after the forward pass. During the backward pass, when a non-checkpointed activation is needed, it is recomputed by running the relevant forward computation again from the nearest checkpoint.
The trade-off: you pay with compute (each checkpointed segment is computed twice - once during forward, once during backward recomputation) to save memory. For a full-layer checkpointing strategy:
- Memory reduction: from every intermediate tensor in every layer (roughly 34 bytes per token per hidden dimension per layer) to one stored checkpoint per layer boundary (2 bytes per token per hidden dimension in bf16)
- Compute overhead: approximately 30-33% more FLOPs (one extra forward pass for every recomputed segment)
The formula I use: activation memory with checkpointing $\approx 2 \cdot B \cdot S \cdot H \cdot L$ bytes, versus $\approx 34 \cdot B \cdot S \cdot H \cdot L$ bytes without. A 17x reduction in activation memory for a ~33% compute increase is almost always the right trade-off for large model training.
In PyTorch, torch.utils.checkpoint.checkpoint handles this automatically. In Hugging Face Transformers, model.gradient_checkpointing_enable() enables it globally.
Q4: How do you calculate KV cache memory, and why does it grow with both batch size and context length?
The KV cache stores the key and value matrices from every transformer attention layer for every token in every active sequence. It exists because autoregressive generation reuses previously-computed keys and values at every step - without the cache, each generation step would recompute attention over the entire previous context, making per-token compute $O(S^2)$ rather than $O(S)$.
The formula: $M_{KV} = 2 \times L \times n_{\text{kv}} \times d_{\text{head}} \times S \times B \times \text{bytes per element}$.
Why it grows with batch size ($B$): each concurrent request/sequence requires its own independent KV cache - different users have different conversation contexts.
Why it grows with context length ($S$): each new token appends its keys and values to the cache in every layer. A 4096-token conversation has 8x the KV cache of a 512-token conversation.
For a production serving system, KV cache is often the binding constraint on throughput. The practical implication: to double the number of concurrent users while holding GPU count constant, you must either halve the maximum context length, use a model with fewer/smaller KV heads (GQA or MQA), or reduce KV precision (int8 KV cache quantization).
vLLM's PagedAttention addresses KV cache fragmentation by managing the cache in fixed-size blocks of tokens (16 tokens per block by default) allocated on demand, similar to virtual memory paging. This allows the system to serve many short conversations efficiently without reserving max-context memory for each.
Q5: A startup is building a chatbot product using LLaMA-2-70B. They want to serve 1,000 concurrent users with up to 8,192 tokens of context. Size the GPU cluster for this.
Let's work through the full calculation:
Model weights (bf16): $70 \times 10^9 \times 2\ \text{bytes} = 140$ GB
KV cache calculation: LLaMA-2-70B has 8 KV heads, head dim = 128, 80 layers. Per user at 8,192 tokens: $2 \times 80 \times 8 \times 128 \times 8192 \times 2\ \text{bytes} \approx 2.68$ GB.
For 1,000 concurrent users: $1{,}000 \times 2.68\ \text{GB} \approx 2{,}684$ GB
Activation buffer (current token only): approximately 0.5 GB at batch = 1,000
Total: $140 + 2{,}684 + 0.5 \approx 2{,}825$ GB
With 10% overhead: $\approx 3{,}107$ GB
H100 (80 GB) GPUs needed: $3{,}107 / 80 \approx 39$ GPUs
Round up to 40 GPUs (5 nodes of 8xH100 each) for clean parallelism.
But wait - not all 1,000 users will have full 8,192-token contexts simultaneously. If the average active context is 2,048 tokens (a realistic p50), and you plan for p95 = 6,144 tokens on 200 of the 1,000 users:
KV memory: $800 \times 0.67\ \text{GB} + 200 \times 2.0\ \text{GB} \approx 940$ GB.
This drops the total to roughly 1,080 GB, fitting on 14 H100s. Real deployments use vLLM's continuous batching and PagedAttention to handle this variance efficiently.
The recommended configuration for this workload: 2 nodes of 8xH100 (16 GPUs, 1,280 GB of HBM), serving the 140 GB model with tensor parallelism and leaving roughly 1,080 GB for KV cache. That budget saturates at roughly 400 concurrent full-8K-context sessions, or well over 1,000 sessions at the mixed-context profile above.
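The same sizing as a short script, using the GQA-aware KV formula from earlier (helper names are ours):

```python
# Q5 sizing: 1,000 concurrent users, up to 8K context, LLaMA-2-70B (GQA, 8 KV heads).
def kv_gb(seq_len, batch, layers=80, kv_heads=8, head_dim=128, bytes_=2):
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_ / 1e9

weights_gb = 140.0  # 70B params in bf16

worst_case = (weights_gb + kv_gb(8192, 1000) + 0.5) * 1.10  # 10% overhead
print(worst_case, worst_case / 80)   # ~3107 GB, ~39 H100s -> round to 40

# Mixed-context profile: 800 users near the 2K median, 200 near the 6K p95
realistic = weights_gb + kv_gb(2048, 800) + kv_gb(6144, 200)
print(realistic, realistic / 80)     # ~1080 GB, ~14 H100s
```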
Q6: What is the memory difference between using full Multi-Head Attention versus Grouped Query Attention for a 70B model serving 1,000 concurrent 4096-token requests?
This comparison shows exactly why GQA was introduced.
Full MHA (64 query heads, 64 KV heads): per sequence at 4,096 tokens, $2 \times 80 \times 64 \times 128 \times 4096 \times 2\ \text{bytes} \approx 10.7$ GB.
For 1,000 concurrent sequences: $\approx 10{,}740$ GB - requiring 135 H100 GPUs just for KV cache.
GQA with 8 KV heads (LLaMA-2-70B actual configuration): $\approx 1.34$ GB per sequence.
For 1,000 concurrent sequences: $\approx 1{,}340$ GB.
MQA (Multi-Query Attention, 1 KV head): $\approx 0.17$ GB per sequence.
For 1,000 concurrent sequences: $\approx 168$ GB.
The comparison:
- MHA: 10,740 GB KV cache for 1K users
- GQA (8 KV heads): 1,340 GB KV cache (8x reduction)
- MQA (1 KV head): 168 GB KV cache (64x reduction)
GQA at 8 KV heads achieves most of MQA's memory benefit with minimal quality degradation (Ainslie et al., 2023, showed GQA with 8 groups matches MHA quality, while MQA shows a measurable quality drop). This is why most modern LLMs (LLaMA-2 and LLaMA-3 at the larger sizes, Mistral, and many other recent releases) use GQA rather than full MHA or pure MQA.
The practical impact: switching from MHA to GQA in a 70B model reduces the KV cache memory by 8x, enabling either 8x more concurrent users on the same hardware or 8x longer context windows at the same batch size.
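The comparison as a loop over KV head counts (70B-class geometry assumed; the helper name is ours):

```python
# KV cache for 1,000 concurrent 4,096-token requests under MHA, GQA, and MQA.
def kv_gb(kv_heads, layers=80, head_dim=128, seq_len=4096, batch=1000, bytes_=2):
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_ / 1e9

for name, heads in [("MHA (64 KV heads)", 64), ("GQA (8 KV heads)", 8), ("MQA (1 KV head)", 1)]:
    print(f"{name:20s}: {kv_gb(heads):8.0f} GB")
# MHA ~10,737 GB, GQA ~1,342 GB, MQA ~168 GB
```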
