
ZeRO and Memory Efficiency

The Night the Cluster Ran Out of Memory

It is 2 AM at a major AI lab. The training run for a 70-billion parameter model has been queued for three weeks. The team fought for 512 A100 GPUs, justified the cost to three levels of management, and finally got the green light. The job script fires. Thirty seconds later, every GPU throws the same error: CUDA out of memory. Tried to allocate 2.50 GiB.

The model itself is only 140 GB in BF16. Spread across 512 GPUs, that is barely 274 MB per device. The math seems fine. But the memory profiler tells a different story. Optimizer states (Adam's first and second moments) consume 560 GB. Gradients take another 140 GB. Master weights in FP32 eat 280 GB. Total: over 1.1 terabytes, and plain data parallelism replicates all of it on every single GPU. Each 80 GB A100 is being asked to hold 1.12 TB of model state, 14 times its capacity, before a single activation is computed. Divided evenly across the 512 GPUs instead, the same state would be just over 2 GB per device.

This is the redundancy problem at the heart of data parallel training. For two years, the standard answer was "buy more GPUs" or "use a smaller model." Neither answer was acceptable in 2020, when researchers at Microsoft were trying to train models at scales nobody had attempted before. They needed a different approach, one that eliminated the redundancy instead of paying for it.

The Zero Redundancy Optimizer (ZeRO) was the result. Published by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, and Yuxiong He at Microsoft Research in 2020, ZeRO did something that seemed obvious in retrospect but had been missed for years: it partitioned the redundant state across GPUs rather than duplicating it. The insight was surgical. You only need the full gradient on the GPU that owns that parameter's optimizer step. You only need the full optimizer state on the GPU doing the update. You can reconstruct everything else on demand.

That insight let the Microsoft team train a 17-billion parameter model on 32 V100 GPUs in early 2020, a scale that plain data parallelism could not reach. Today, ZeRO-style sharding is a standard building block of large-scale training, whether through DeepSpeed directly or through PyTorch's FSDP implementation. Understanding it is not optional for any engineer working on models above a few billion parameters.

Why This Exists - The Redundancy That Kills You

Before ZeRO, the standard approach to scaling data parallel training was Distributed Data Parallel (DDP). DDP is conceptually clean: each GPU holds a full copy of the model, processes a different mini-batch, computes gradients, synchronizes gradients via all-reduce, and each GPU independently updates its parameters. The problem is what "full copy" means in memory.

For a model with $N$ parameters trained with Adam in mixed precision, the memory breakdown per GPU looks like this:

$$
\text{Memory per GPU} = \underbrace{2N}_{\text{FP16 params}} + \underbrace{2N}_{\text{FP16 grads}} + \underbrace{4N}_{\text{FP32 master}} + \underbrace{4N}_{\text{Adam } m} + \underbrace{4N}_{\text{Adam } v} = 16N \text{ bytes}
$$

For a 7-billion parameter model, that is $16 \times 7 \times 10^9 = 112$ GB per GPU. An A100-80GB cannot hold it without aggressive offloading. Even if you could fit it, every GPU in the cluster keeps its own copy of the $12N$ bytes of optimizer state (FP32 master weights plus both Adam moments). With 128 GPUs, that is 128 copies where one would do. That is not a minor inefficiency. Optimizer state is three quarters of the $16N$ bytes of model state on every GPU, and 127 of its 128 copies are pure waste.
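As a sanity check on this arithmetic, here is a minimal Python sketch that applies the same 2 + 2 + 4 + 4 + 4 bytes-per-parameter layout, using decimal gigabytes. The names are illustrative and do not come from any library.

```python
# Bytes per parameter for mixed-precision Adam, as in the breakdown above.
BYTES_PER_PARAM = {
    "fp16_params": 2,
    "fp16_grads": 2,
    "fp32_master": 4,
    "adam_m": 4,
    "adam_v": 4,
}


def model_state_bytes(n_params: float) -> dict:
    """Per-GPU model-state bytes under plain DDP (every GPU holds everything)."""
    breakdown = {name: b * n_params for name, b in BYTES_PER_PARAM.items()}
    breakdown["total"] = sum(BYTES_PER_PARAM.values()) * n_params
    return breakdown


if __name__ == "__main__":
    for n in (7e9, 70e9):
        s = model_state_bytes(n)
        optimizer = s["fp32_master"] + s["adam_m"] + s["adam_v"]
        print(f"{n / 1e9:.0f}B params: {s['total'] / 1e9:,.0f} GB per GPU "
              f"(optimizer state {optimizer / 1e9:,.0f} GB)")
```

For the 7B model this gives 112 GB. For the 70B model in the opening story it gives 1,120 GB, of which 840 GB is optimizer state.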

The approach before ZeRO that tried to solve this was model parallelism - split the model across GPUs and have each GPU own different layers. This works, but it is painful. It requires rewriting the model, it introduces pipeline bubbles, it does not compose cleanly with data parallelism, and it requires detailed knowledge of the model architecture. ZeRO's breakthrough was realizing you could get the memory benefits of model parallelism with the programming model of data parallelism - no model rewriting required.

The fundamental question ZeRO answers is: why are we storing the same bytes on every GPU? In DDP, GPU 0 and GPU 1 both hold Adam's first moment vector $m$ for every parameter, and both apply the identical update to all of it. If GPU 0 instead updated only $m[0:N/D]$ and GPU 1 only $m[N/D:2N/D]$, each could drop the rest. The duplicated copies are 100% redundant. ZeRO eliminates that redundancy without changing the observable behavior of the training algorithm.

Historical Context - From Parameter Servers to ZeRO

The story of distributed training memory starts with the parameter server architecture. Google's DistBelief (Dean et al., 2012) used it, and Mu Li et al. at Carnegie Mellon systematized it (OSDI 2014). The parameter server insight was that not every worker needs every parameter at every moment: you can fetch-and-push on demand. This worked well for sparse models (recommendation systems, NLP with bag-of-words) but was awkward for dense neural networks, where every layer reads every parameter on every forward pass.

The direct predecessor to ZeRO was the Megatron-LM work from NVIDIA (Shoeybi et al., 2019) which used tensor parallelism to split individual layers across GPUs. Megatron could train multi-billion parameter models, but required careful partitioning of attention heads and FFN columns, making it architecture-specific.

Rajbhandari et al. at Microsoft Research had the key insight in 2019: treat the data parallel group not just as a communication group for gradient synchronization, but as a memory pool. Each GPU in the group owns a shard of the total state. When you need a piece of state you do not own, you communicate to get it. The paper "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" was published at SC'20 (Supercomputing 2020) and won Best Paper.

The timing was perfect. GPT-3 had just demonstrated that scale was the path to capability. Everyone wanted to train larger models. ZeRO made it possible without requiring new hardware or model rewrites. It shipped as part of Microsoft's DeepSpeed library and was used with Megatron-LM to train Microsoft's 17-billion-parameter Turing-NLG. Its sharding scheme was later re-implemented in FairScale and then upstreamed into PyTorch core as FSDP (Fully Sharded Data Parallel).

Core Concepts - The Three Stages of ZeRO

The Partition Principle

ZeRO works by partitioning three categories of model state across the $D$ data-parallel GPUs:

  1. Optimizer states (ZeRO Stage 1)
  2. Gradients (ZeRO Stage 2)
  3. Parameters (ZeRO Stage 3)

Each stage builds on the previous one. The key invariant is that each partitioned category of state is stored exactly once across the data-parallel group. Each GPU keeps only the shard it owns, and anything else it needs is reconstructed through communication on demand.

ZeRO Stage 1 - Optimizer State Partitioning

Imagine you have 4 GPUs and a model with $N$ parameters. Each parameter has an Adam optimizer state, $m_i$ (first moment) and $v_i$ (second moment), plus its FP32 master copy. In vanilla DDP, every GPU holds $m[0:N]$ and $v[0:N]$. In ZeRO-1, GPU 0 owns the optimizer states for parameters $[0:N/4]$, GPU 1 for $[N/4:N/2]$, and so on.

During the forward and backward passes, everything looks identical to DDP. After the backward pass, gradient synchronization uses a reduce-scatter instead of an all-reduce: each GPU ends up with the summed gradients only for the parameters whose optimizer state it owns. Each GPU then runs the optimizer update on its shard. Finally, an all-gather sends every GPU's updated parameter shard to all the others, so each GPU again holds the full, updated FP16 parameters.

The memory savings are on the optimizer state only:

$$
\text{ZeRO-1 memory} = 2N + 2N + \frac{4N + 4N + 4N}{D} = 4N + \frac{12N}{D}
$$

For large $D$, this approaches $4N$ bytes per GPU, a 4x reduction from the $16N$ baseline. Importantly, communication volume is identical to DDP. A reduce-scatter followed by an all-gather moves the same number of bytes as an all-reduce; in fact, that pair is how a bandwidth-optimal all-reduce is implemented.
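To see that ZeRO-1 changes where the optimizer math happens but not its result, here is a single-process simulation in plain Python. It is a sketch under simplifying assumptions, not DeepSpeed's implementation:

- "ranks" are list entries, and the collectives are written out by hand;
- Adam omits weight decay, and the parameter list stands in for the FP32 master weights;
- all helper names are hypothetical.

```python
import math

LR, B1, B2, EPS = 1e-3, 0.9, 0.999, 1e-8


def adam_update(p, g, m, v, t):
    """Elementwise Adam on Python lists; mutates p, m, v in place."""
    for i in range(len(p)):
        m[i] = B1 * m[i] + (1 - B1) * g[i]
        v[i] = B2 * v[i] + (1 - B2) * g[i] * g[i]
        m_hat = m[i] / (1 - B1 ** t)
        v_hat = v[i] / (1 - B2 ** t)
        p[i] -= LR * m_hat / (math.sqrt(v_hat) + EPS)


def ddp_step(params, local_grads, m, v, t):
    """DDP: all-reduce full gradients, then every rank updates all N params."""
    n = len(params)
    summed = [sum(g[i] for g in local_grads) for i in range(n)]
    adam_update(params, summed, m, v, t)  # identical work on every rank
    return params


def zero1_step(params, local_grads, m_shards, v_shards, t):
    """ZeRO-1: reduce-scatter grads, update own shard, all-gather params."""
    world = len(local_grads)
    shard = len(params) // world  # assume N divisible by world size
    new_shards = []
    for rank in range(world):
        lo, hi = rank * shard, (rank + 1) * shard
        # Reduce-scatter: this rank receives the summed gradient for its slice only
        g_shard = [sum(g[i] for g in local_grads) for i in range(lo, hi)]
        p_shard = params[lo:hi]  # copy of this rank's parameter slice
        # Optimizer states exist only for the owned slice (N/D elements each)
        adam_update(p_shard, g_shard, m_shards[rank], v_shards[rank], t)
        new_shards.append(p_shard)
    # All-gather: every rank reassembles the full updated parameter vector
    return [x for s in new_shards for x in s]


if __name__ == "__main__":
    world, n = 4, 8
    params = [0.1 * i for i in range(n)]
    # Each rank's local gradient from its own micro-batch (made-up numbers)
    grads = [[0.01 * (r + 1) * (i + 1) for i in range(n)] for r in range(world)]
    ddp = ddp_step(list(params), grads, [0.0] * n, [0.0] * n, t=1)
    zero = zero1_step(list(params), grads,
                      [[0.0] * (n // world) for _ in range(world)],
                      [[0.0] * (n // world) for _ in range(world)], t=1)
    print("ZeRO-1 matches DDP:", all(math.isclose(a, b) for a, b in zip(ddp, zero)))
```

Each rank only ever allocates `m` and `v` for its own `N/D` slice. That is the entire memory saving, and the updated parameters come out the same as in DDP.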

ZeRO Stage 2 - Gradient Partitioning

ZeRO-2 extends Stage 1 by also partitioning gradients. Each GPU only keeps the gradients for the parameters whose optimizer state it owns. During the backward pass, instead of computing full gradients everywhere and then doing all-reduce, ZeRO-2 uses a streaming reduce-scatter: as gradients are computed layer by layer during backward, they are immediately scattered to the owning GPU and discarded on all other GPUs.

Memory per GPU with ZeRO-2:

$$
\text{ZeRO-2 memory} = 2N + \frac{2N + 4N + 4N + 4N}{D} = 2N + \frac{14N}{D}
$$

For large $D$, this approaches $2N$ bytes, an 8x reduction from baseline. You are still holding the full FP16 parameters on every GPU, because the forward pass needs them. Communication volume remains identical to DDP.
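The streaming reduce-scatter is what keeps ZeRO-2's gradient memory low during the backward pass, not just after it. The sketch below counts gradient elements only (no values), under these simplifying assumptions:

- gradients are produced one layer at a time in reverse order;
- each layer is reduced as a single bucket as soon as it is complete;
- ownership follows contiguous ranges of a flattened parameter vector.

`simulate_zero2_backward` is a hypothetical helper, not a DeepSpeed API.

```python
def simulate_zero2_backward(layer_sizes, world_size):
    """Return (peak, final) gradient elements held by each rank during backward."""
    total = sum(layer_sizes)
    assert total % world_size == 0, "assume N divisible by world size"
    shard = total // world_size

    # Start offset of each layer in the flattened parameter vector
    offsets, start = [], 0
    for size in layer_sizes:
        offsets.append(start)
        start += size

    owned = [0] * world_size  # reduced gradient elements each rank keeps
    peak = [0] * world_size
    for layer in reversed(range(len(layer_sizes))):
        lo, hi = offsets[layer], offsets[layer] + layer_sizes[layer]
        for rank in range(world_size):
            # While this layer's backward runs, its full local gradient is live
            peak[rank] = max(peak[rank], owned[rank] + layer_sizes[layer])
            # Bucket reduce: keep only the part inside this rank's shard;
            # the rest is sent to its owner and freed locally
            s_lo, s_hi = rank * shard, (rank + 1) * shard
            owned[rank] += max(0, min(hi, s_hi) - max(lo, s_lo))
    return peak, owned


if __name__ == "__main__":
    layers = [100] * 8  # 8 equal layers, N = 800 gradient elements
    peak, final = simulate_zero2_backward(layers, world_size=4)
    print("DDP holds per rank:   ", sum(layers))
    print("ZeRO-2 peak per rank: ", peak)
    print("ZeRO-2 final per rank:", final)
```

In this toy setup, each rank's peak stays near one shard plus one layer's gradient, instead of the full 800 elements that DDP keeps until the all-reduce.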

ZeRO Stage 3 - Parameter Partitioning

ZeRO-3 is the most aggressive stage. Each GPU also holds only $1/D$ of the parameters. This requires new communication during the forward pass: before computing each layer, all GPUs must receive the full parameters for that layer via an all-gather. After the computation, each GPU discards the parameters it does not own. The same all-gather is repeated during the backward pass, because computing each layer's gradients needs its full weights again. With activation checkpointing, recomputing that layer's activations needs them too.

Memory per GPU with ZeRO-3:

$$
\text{ZeRO-3 memory} = \frac{16N}{D}
$$

This is perfect linear scaling with the number of GPUs. A 7B parameter model requiring 112 GB per GPU in DDP requires only $112/D$ GB with ZeRO-3. With 128 GPUs, that is under 900 MB per GPU for model state, freeing most of each 80 GB A100 for activations and batch data.
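The gather-compute-discard cycle can be simulated the same way. In this sketch:

- each "layer" is just an elementwise scale by a weight vector;
- each rank stores one contiguous slice of every layer's weights;
- `all_gather` is a hypothetical stand-in for the NCCL collective.

The sketch checks that the sharded forward matches an unsharded one and reports the peak number of resident parameter elements on a rank.

```python
def shard(vec, world_size):
    """Split a weight vector into world_size contiguous, equal shards."""
    k = len(vec) // world_size  # assume divisible
    return [vec[r * k:(r + 1) * k] for r in range(world_size)]


def all_gather(shards):
    """Stand-in for an all-gather: concatenate every rank's shard in order."""
    return [x for s in shards for x in s]


def zero3_forward(x, layer_weights, world_size):
    """Forward pass in which each layer's weights exist in full only while used."""
    # Persistent state: each rank keeps 1/D of every layer
    sharded = [shard(w, world_size) for w in layer_weights]
    resident = sum(len(w) for w in layer_weights) // world_size
    peak = resident
    for layer_shards in sharded:
        full_w = all_gather(layer_shards)            # materialize this layer
        peak = max(peak, resident + len(full_w))     # shards + one full layer
        x = [xi * wi for xi, wi in zip(x, full_w)]   # compute the layer
        del full_w                                   # discard non-owned params
    return x, peak


def reference_forward(x, layer_weights):
    """Unsharded forward pass with every weight always resident."""
    for w in layer_weights:
        x = [xi * wi for xi, wi in zip(x, w)]
    return x


if __name__ == "__main__":
    weights = [[float(i + k) for i in range(8)] for k in range(1, 5)]
    out, peak = zero3_forward([1.0] * 8, weights, world_size=4)
    print("matches unsharded:", out == reference_forward([1.0] * 8, weights))
    print("peak resident params:", peak, "of", sum(len(w) for w in weights))
```

The peak is one full layer plus the $1/D$ shards of everything else. Fine-grained units keep that extra term small.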

The cost is communication volume. ZeRO-3 adds one parameter all-gather per layer (more precisely, per sharded unit) in the forward pass and another in the backward pass, on top of the gradient reduce-scatter. That raises per-step traffic from about $2N$ to about $3N$ parameter-sized transfers, 1.5x DDP. In practice, prefetching the next layer's parameters while the current one computes hides most of this overhead on systems with fast interconnects (NVLink at 600 GB/s or 400 Gb/s InfiniBand).
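The per-step traffic can be tallied in units of the parameter count, following the accounting above:

- an all-reduce moves about $2N$ elements per GPU (reduce-scatter plus all-gather);
- ZeRO-3 adds one more parameter all-gather.

The sketch includes the $(D-1)/D$ ring factor. It assumes one micro-batch per optimizer step and BF16 (2-byte) communication. `comm_elements_per_step` is a hypothetical helper.

```python
def comm_elements_per_step(stage: str, n_params: float, world_size: int) -> float:
    """Approximate elements sent per GPU per step (ring collectives, 1 micro-batch)."""
    ring = (world_size - 1) / world_size  # each collective moves (D-1)/D of the data
    collectives = {
        "DDP": 2,     # gradient all-reduce = reduce-scatter + all-gather
        "ZeRO-1": 2,  # gradient reduce-scatter + all-gather of updated params
        "ZeRO-2": 2,  # same as ZeRO-1; gradients are simply not re-gathered
        "ZeRO-3": 3,  # forward all-gather + backward all-gather + reduce-scatter
    }[stage]
    return collectives * ring * n_params


if __name__ == "__main__":
    N, D = 7e9, 64
    ddp = comm_elements_per_step("DDP", N, D)
    for stage in ("DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"):
        elems = comm_elements_per_step(stage, N, D)
        print(f"{stage:7s}: {elems * 2 / 1e9:6.1f} GB per GPU per step "
              f"({elems / ddp:.1f}x DDP)")
```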

ZeRO-Infinity - Offloading to CPU and NVMe

ZeRO-Infinity extends Stage 3 by offloading optimizer states and even parameters to CPU memory or NVMe SSDs. A training node's CPU memory is typically several times its aggregate GPU memory, and its NVMe capacity is larger again by roughly an order of magnitude, though with much lower bandwidth.

The key to making ZeRO-Infinity practical is a prefetch pipeline: while the GPU is computing layer ii, the system prefetches layer i+1i+1's parameters from CPU/NVMe via PCIe DMA in parallel with the GPU computation. If computation time is longer than data transfer time - which holds for large transformer layers - the offloading latency is fully hidden.

For large transformer models, a single layer's computation at sequence length 2048 takes several milliseconds on an A100. A 128 MB PCIe 4.0 DMA transfer at 25 GB/s takes about 5 ms. The two are the same order of magnitude, so with enough tokens per step the compute covers the transfer. For small models or short sequences it does not, and ZeRO-Infinity adds overhead.
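The overlap argument is back-of-the-envelope arithmetic, so it fits in a few lines. The sketch below uses a simplified one-layer-lookahead model, in which layer $i+1$'s transfer starts when layer $i$'s compute starts. It estimates forward compute time from the usual $24 \cdot \text{tokens} \cdot h^2$ matmul FLOP count at a chosen fraction of peak throughput.

- 312 TFLOPS is the A100's published dense BF16 peak.
- The 50% utilization, the 32-layer depth, and the function names are assumptions for illustration.

```python
def transfer_ms(num_bytes: float, bandwidth_gb_s: float) -> float:
    """Time to move num_bytes over a link of bandwidth_gb_s (decimal GB/s)."""
    return num_bytes / (bandwidth_gb_s * 1e9) * 1e3


def layer_forward_ms(tokens: int, hidden: int, peak_tflops: float,
                     utilization: float) -> float:
    """Forward matmul time for one transformer layer (~24 * tokens * hidden^2 FLOPs)."""
    flops = 24 * tokens * hidden**2
    return flops / (peak_tflops * 1e12 * utilization) * 1e3


def step_time_ms(compute, transfer, prefetch=True):
    """Total time for one pass over the layers, given per-layer times in ms.

    Without prefetch, each layer waits for its own transfer. With one-layer
    lookahead, layer i+1's transfer runs while layer i computes, so only the
    first transfer (and any transfer longer than the compute hiding it) is exposed.
    """
    if not prefetch:
        return sum(c + t for c, t in zip(compute, transfer))
    total = transfer[0]
    for i, c in enumerate(compute):
        nxt = transfer[i + 1] if i + 1 < len(transfer) else 0.0
        total += max(c, nxt)
    return total


if __name__ == "__main__":
    t = transfer_ms(128e6, 25)  # 128 MB over PCIe 4.0 at ~25 GB/s
    c = layer_forward_ms(tokens=8192, hidden=4096, peak_tflops=312, utilization=0.5)
    n = 32
    print(f"per layer: transfer {t:.2f} ms, compute {c:.2f} ms")
    print(f"no prefetch: {step_time_ms([c] * n, [t] * n, prefetch=False):.1f} ms")
    print(f"prefetch:    {step_time_ms([c] * n, [t] * n):.1f} ms")
```

Shrinking `tokens` or `hidden` until compute drops below the transfer time shows the regime in which offloading stops being free.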

With ZeRO-Infinity, theoretical model scale is bounded only by storage, not GPU memory. This enabled training trillion-parameter models on hardware clusters that would be completely inadequate without offloading.

Memory Analysis: Full Comparison

$$
\begin{aligned}
\text{DDP: } & 16N \text{ bytes/GPU} \\
\text{ZeRO-1: } & 4N + \frac{12N}{D} \approx 4N \quad (\text{large } D) \\
\text{ZeRO-2: } & 2N + \frac{14N}{D} \approx 2N \quad (\text{large } D) \\
\text{ZeRO-3: } & \frac{16N}{D} \quad (\text{linear scaling})
\end{aligned}
$$

For a concrete example with a 30B parameter model ($N = 30 \times 10^9$) and 64 GPUs:

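| Stage | Model-state memory per GPU | Savings vs DDP |
|---|---|---|
| DDP | 480 GB | 1x |
| ZeRO-1 | 125.6 GB | 3.8x |
| ZeRO-2 | 66.6 GB | 7.2x |
| ZeRO-3 | 7.5 GB | 64x |

DDP and ZeRO-1 do not fit on an 80 GB A100 at all in this example. ZeRO-2 fits its model state but leaves only about 13 GB for activations and framework overhead. Only ZeRO-3 leaves most of the GPU free.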
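The table follows directly from the four formulas. This small sketch reproduces it and makes it easy to try other model sizes or GPU counts. It counts model state only (no activations) in decimal GB, and the helper name is hypothetical.

```python
def zero_model_state_gb(stage: str, n_params: float, world_size: int) -> float:
    """Per-GPU model-state memory in GB for mixed-precision Adam (16 bytes/param)."""
    n, d = n_params, world_size
    bytes_per_gpu = {
        "DDP": 16 * n,
        "ZeRO-1": 4 * n + 12 * n / d,  # shard optimizer states
        "ZeRO-2": 2 * n + 14 * n / d,  # + shard gradients
        "ZeRO-3": 16 * n / d,          # + shard parameters
    }[stage]
    return bytes_per_gpu / 1e9


if __name__ == "__main__":
    N, D = 30e9, 64
    ddp = zero_model_state_gb("DDP", N, D)
    for stage in ("DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"):
        gb = zero_model_state_gb(stage, N, D)
        print(f"{stage:7s} {gb:7.1f} GB  {ddp / gb:5.1f}x")
```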

FSDP - PyTorch's ZeRO-3 Equivalent

Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of ZeRO-3, introduced in PyTorch 1.11 and stabilized in PyTorch 2.0. It is co-designed with the transformer architecture but works with any module.

FSDP wraps submodules in nested FullyShardedDataParallel instances, usually called FSDP units. Each unit manages the lifecycle of its parameters: all-gather before the unit's forward/backward, discard after, and store only the local shard otherwise. The sharding happens automatically when you wrap a module.

The critical FSDP concept is the wrapping policy. You tell FSDP which modules to treat as independent sharding units. A coarse-grained policy (wrap the whole model as one unit) minimizes communication rounds but means you hold the whole model in memory during each layer's computation. A fine-grained policy (wrap each transformer block as its own unit) minimizes peak memory but increases the number of communication rounds.

The right policy for a 7B+ parameter model is almost always fine-grained: wrap each transformer decoder layer as a separate FSDP unit. This lets the system discard each layer's full parameters immediately after computing it, keeping only the local shard for the next optimizer step.
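The trade-off between coarse and fine wrapping shows up in a simple peak-memory count. The sketch makes these assumptions:

- it counts parameter memory only, in BF16 (2 bytes/param);
- it ignores prefetch, although real FSDP with backward prefetch can hold about two units unsharded at once;
- the unit sizes are a rough LLaMA-7B-like layout chosen for illustration: 32 decoder layers of about 202M parameters, plus about 262M for embeddings and the output head.

```python
def peak_param_gb(unit_sizes, world_size, bytes_per_param=2):
    """Peak per-GPU parameter memory (GB) when one unit is unsharded at a time.

    Every unit is stored as a 1/world_size shard, except the unit currently
    running, which is all-gathered in full. Ignores prefetch, gradients,
    optimizer state, and activations.
    """
    sharded = sum(unit_sizes) / world_size
    largest = max(unit_sizes)
    # The active unit grows from its local shard to its full size
    peak_params = sharded + largest * (1 - 1 / world_size)
    return peak_params * bytes_per_param / 1e9


if __name__ == "__main__":
    # Rough LLaMA-7B-like layout (assumed): a root unit with embeddings and
    # output head, plus 32 decoder layers of ~202M parameters each
    fine = [262e6] + [202e6] * 32
    coarse = [sum(fine)]  # whole model as a single FSDP unit
    for name, units in (("coarse", coarse), ("per-layer", fine)):
        print(f"{name:9s}: {peak_param_gb(units, world_size=8):5.2f} GB peak params")
```

With 8 GPUs, the coarse policy peaks at the full ~13.5 GB of BF16 weights, because its single unit is gathered all at once. Per-layer wrapping peaks at around 2.1 GB.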

Code Examples

DeepSpeed ZeRO-3 Configuration

```json
{
  "train_batch_size": 512,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 16,
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "stage3_prefetch_bucket_size": 5e7,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_max_live_params": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "betas": [0.9, 0.95],
      "eps": 1e-8,
      "weight_decay": 0.1
    }
  }
}
```
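DeepSpeed requires `train_batch_size` to equal `train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size`, and it raises an error at initialization otherwise. The config above (4 × 16 = 64 samples per GPU per step) therefore implies exactly 8 GPUs. A quick pre-flight check in plain Python catches a mismatch before you queue the job; the helper name is hypothetical.

```python
import json


def check_batch_config(ds_config: dict, world_size: int) -> None:
    """Raise ValueError if the DeepSpeed batch-size fields are inconsistent."""
    micro = ds_config["train_micro_batch_size_per_gpu"]
    accum = ds_config["gradient_accumulation_steps"]
    total = ds_config["train_batch_size"]
    if micro * accum * world_size != total:
        raise ValueError(
            f"train_batch_size={total} != {micro} x {accum} x {world_size} "
            f"= {micro * accum * world_size}"
        )


if __name__ == "__main__":
    cfg = json.loads("""{"train_batch_size": 512,
                         "train_micro_batch_size_per_gpu": 4,
                         "gradient_accumulation_steps": 16}""")
    check_batch_config(cfg, world_size=8)  # consistent: 4 x 16 x 8 = 512
    try:
        check_batch_config(cfg, world_size=16)  # 4 x 16 x 16 = 1024
    except ValueError as err:
        print(err)
```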

DeepSpeed ZeRO-3 Training Loop

```python
# training_zero3.py
import json

import torch
import deepspeed
from transformers import AutoModelForCausalLM
# Older transformers releases expose this as transformers.deepspeed.HfDeepSpeedConfig
from transformers.integrations import HfDeepSpeedConfig


def get_model_and_engine(model_name: str, ds_config_path: str):
    """Initialize a Hugging Face model with DeepSpeed ZeRO-3.

    Returns (engine, optimizer, scheduler, dschf). Keep dschf referenced for
    as long as the model is in use.
    """

    with open(ds_config_path) as f:
        ds_config = json.load(f)

    # For ZeRO-3, register the DeepSpeed config with transformers *before*
    # from_pretrained. While this object is alive, from_pretrained builds the
    # model under deepspeed.zero.Init (params are sharded as they are created,
    # never fully materialized on CPU) and loads the pretrained weights into
    # the sharded parameters correctly.
    dschf = HfDeepSpeedConfig(ds_config)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )

    engine, optimizer, _, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )

    return engine, optimizer, scheduler, dschf


def training_step(engine, batch):
    """One micro-batch forward/backward with ZeRO-3."""

    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()
    attention_mask = batch["attention_mask"].cuda()

    # Forward pass - ZeRO-3 all-gathers each layer's params just before use
    outputs = engine(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    )
    loss = outputs.loss

    # Backward - ZeRO-3 reduce-scatters gradients to their owning ranks
    engine.backward(loss)

    # engine.step() applies the optimizer update (plus gradient clipping, if
    # "gradient_clipping" is set in the config) only at gradient-accumulation
    # boundaries; on other micro-batches it does not update the weights
    engine.step()

    return loss.item()


def measure_memory_usage(tag: str = ""):
    """Log GPU memory stats, then reset the peak counter."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        peak = torch.cuda.max_memory_allocated() / 1e9
        print(f"[{tag}] Allocated: {allocated:.2f} GB | "
              f"Reserved: {reserved:.2f} GB | Peak: {peak:.2f} GB")
        torch.cuda.reset_peak_memory_stats()
```

ZeRO-Infinity with CPU Offloading

```python
import torch

# Extended ds_config with CPU offload
zero_infinity_additions = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True,
            "buffer_count": 4,
            "fast_init": False,
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True,
            "buffer_count": 5,
            "buffer_size": 1e8,
            "max_in_cpu": 1e9,  # only used when offloading params to NVMe
        },
        "stage3_max_live_params": 1e9,
        "stage3_prefetch_bucket_size": 1e9,
        "stage3_param_persistence_threshold": 1e6,
    }
}


# Measuring savings programmatically
def get_zero_memory_stats(ds_engine):
    """Print parameter count and memory stats from a running ZeRO engine."""

    # Under ZeRO-3, partitioned params are empty placeholders (p.numel() == 0);
    # DeepSpeed records each parameter's full element count in p.ds_numel.
    total_params = sum(
        getattr(p, "ds_numel", p.numel()) for p in ds_engine.module.parameters()
    )
    stage = ds_engine.zero_optimization_stage()
    world_size = ds_engine.world_size

    print(f"Total parameters: {total_params / 1e9:.2f}B")
    print(f"Stage: ZeRO-{stage}")
    print(f"World size (sharding factor): {world_size}")
    if stage == 3:
        print(f"Effective param memory per GPU: "
              f"{total_params * 2 / world_size / 1e9:.2f} GB (16-bit shard)")

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    gpu_memory_gb = props.total_memory / 1e9
    used_gb = torch.cuda.memory_allocated() / 1e9
    print(f"GPU memory: {used_gb:.2f}/{gpu_memory_gb:.0f} GB used")
```

PyTorch FSDP - Complete Setup for LLaMA

```python
# fsdp_training.py
import os
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
    StateDictType,
    FullStateDictConfig,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def setup_fsdp_model(
    model_config: LlamaConfig,
    sharding_strategy: str = "FULL_SHARD",
    cpu_offload: bool = False,
    mixed_precision: bool = True,
) -> FSDP:
    """
    Wrap a LLaMA model with FSDP.

    Assumes the process group is already initialized (e.g. launched with
    torchrun, then dist.init_process_group("nccl")) and that
    torch.cuda.set_device(local_rank) has been called.

    sharding_strategy options:
        FULL_SHARD    = ZeRO-3 (shard params + grads + optim states)
        SHARD_GRAD_OP = ZeRO-2 (shard grads + optim states only)
        NO_SHARD      = DDP equivalent
        HYBRID_SHARD  = FULL_SHARD within node, replicate across nodes
    """

    mp_policy = None
    if mixed_precision:
        mp_policy = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )

    cpu_offload_policy = None
    if cpu_offload:
        cpu_offload_policy = CPUOffload(offload_params=True)

    strategy_map = {
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
        "HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,
    }

    # Wrap each LlamaDecoderLayer as an independent FSDP unit
    # This is critical - coarse wrapping defeats the memory benefit
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    # Rank 0 materializes real (randomly initialized) weights on CPU; every
    # other rank builds the model on the meta device, which allocates no
    # memory. sync_module_states=True below broadcasts rank 0's parameters
    # and buffers to all ranks. (Calling to_empty() on a meta model alone
    # yields uninitialized memory, not valid weights.)
    rank = dist.get_rank()
    if rank == 0:
        model = LlamaForCausalLM(model_config)
    else:
        with torch.device("meta"):
            model = LlamaForCausalLM(model_config)

    # On non-zero ranks, allocate empty GPU storage for each meta module;
    # its contents are then overwritten by the broadcast from rank 0.
    param_init_fn = None
    if rank != 0:
        param_init_fn = lambda module: module.to_empty(
            device=torch.device("cuda"), recurse=False
        )

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=strategy_map[sharding_strategy],
        cpu_offload=cpu_offload_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        param_init_fn=param_init_fn,
    )

    return fsdp_model


def save_fsdp_checkpoint(fsdp_model: FSDP, optimizer, output_dir: str):
    """
    Save a full consolidated state dict from an FSDP model.
    rank0_only=True + offload_to_cpu=True gather the full model only on rank 0,
    in CPU memory, avoiding an OOM from gathering on every GPU.
    """

    save_policy = FullStateDictConfig(
        offload_to_cpu=True,
        rank0_only=True,
    )

    # Both calls below are collectives: every rank must execute them
    with FSDP.state_dict_type(
        fsdp_model,
        StateDictType.FULL_STATE_DICT,
        save_policy,
    ):
        state_dict = fsdp_model.state_dict()
        # Consolidates optimizer state on rank 0 (newer PyTorch releases
        # prefer FSDP.optim_state_dict)
        opt_state = FSDP.full_optim_state_dict(fsdp_model, optimizer)

    if dist.get_rank() == 0:
        checkpoint = {
            "model": state_dict,
            "optimizer": opt_state,
        }
        os.makedirs(output_dir, exist_ok=True)
        torch.save(checkpoint, f"{output_dir}/checkpoint.pt")
        print(f"Checkpoint saved to {output_dir}/checkpoint.pt")


def get_fsdp_memory_breakdown(fsdp_model: FSDP) -> dict:
    """Measure actual memory breakdown for an FSDP model."""

    torch.cuda.synchronize()

    stats = {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
    }

    # Outside forward/backward, parameters() yields each rank's local flat
    # shards (including padding), so the summed total can slightly exceed
    # the true parameter count
    local_params = sum(p.numel() for p in fsdp_model.parameters())
    total_params_tensor = torch.tensor(local_params, device="cuda")
    dist.all_reduce(total_params_tensor, op=dist.ReduceOp.SUM)

    stats["local_params_M"] = local_params / 1e6
    stats["total_params_B"] = total_params_tensor.item() / 1e9
    stats["sharding_factor"] = dist.get_world_size()

    return stats
```

Choosing the Right ZeRO Stage

```python
# zero_stage_selector.py


def recommend_zero_stage(
    model_params_B: float,
    gpu_memory_gb: float,
    num_gpus: int,
    batch_size_per_gpu: int,
    seq_len: int,
) -> dict:
    """
    Recommend a ZeRO stage based on memory requirements.

    Memory model (mixed precision, Adam optimizer):
        DDP:    16 bytes/param total
        ZeRO-1: (4 + 12/D) bytes/param
        ZeRO-2: (2 + 14/D) bytes/param
        ZeRO-3: 16/D bytes/param
    """

    N = model_params_B * 1e9
    D = num_gpus
    bytes_per_gb = 1e9

    # Model state memory per GPU (GB)
    ddp_mem = 16 * N / bytes_per_gb
    zero1_mem = (4 + 12 / D) * N / bytes_per_gb
    zero2_mem = (2 + 14 / D) * N / bytes_per_gb
    zero3_mem = 16 * N / D / bytes_per_gb

    # Rough activation estimate (a heuristic, not a measurement).
    # Shape: assume N ~ 12 * n_layers * hidden^2 and n_layers ~ hidden / 128,
    # roughly the LLaMA-family aspect ratio (7B: hidden 4096, 32 layers).
    hidden = (128 * N / 12) ** (1 / 3)
    n_layers = max(1, round(N / (12 * hidden**2)))
    # Bytes: assume activation checkpointing, keeping about 4 bytes (two BF16
    # values) per token per hidden unit per layer; without checkpointing,
    # activations are several times larger.
    act_gb = (4 * seq_len * batch_size_per_gpu * hidden * n_layers) / bytes_per_gb

    overhead = 1.15  # 15% framework overhead

    results = {
        "model_params_B": model_params_B,
        "gpu_memory_gb": gpu_memory_gb,
        "num_gpus": num_gpus,
        "estimated_activation_gb": round(act_gb, 2),
    }

    for stage, model_mem in [
        ("DDP", ddp_mem),
        ("ZeRO-1", zero1_mem),
        ("ZeRO-2", zero2_mem),
        ("ZeRO-3", zero3_mem),
    ]:
        total = (model_mem + act_gb) * overhead
        results[stage] = {
            "model_state_gb": round(model_mem, 2),
            "total_estimate_gb": round(total, 2),
            "fits": total <= gpu_memory_gb,
        }

    # Recommend the least aggressive stage that fits
    for stage in ["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"]:
        if results[stage]["fits"]:
            results["recommendation"] = stage
            break
    else:
        results["recommendation"] = "ZeRO-3 + CPU offload"

    return results


if __name__ == "__main__":
    # LLaMA-7B on 8x A100-80GB
    r = recommend_zero_stage(7.0, 80.0, 8, 4, 2048)
    print(f"Model: {r['model_params_B']}B on {r['num_gpus']}x {r['gpu_memory_gb']}GB")
    print(f"Recommendation: {r['recommendation']}")
    for s in ["DDP", "ZeRO-1", "ZeRO-2", "ZeRO-3"]:
        i = r[s]
        ok = "OK " if i["fits"] else "OOM"
        print(f"  {s:8s}: {i['model_state_gb']:6.1f} GB model, "
              f"{i['total_estimate_gb']:6.1f} GB total [{ok}]")
```

Production Engineering Notes

Critical DeepSpeed ZeRO-3 Config Knobs

stage3_param_persistence_threshold controls which small parameters stay on the GPU permanently rather than being sharded. Parameters smaller than this threshold (in number of elements) are never evicted. The default of 100,000 is good for most models. Attention bias terms and layer norm parameters are typically much smaller than this threshold, so they naturally persist. Lowering to 10,000 saves a small amount of memory with marginal extra communication.
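To get a feel for which tensors the threshold keeps resident, the sketch below classifies the parameters of one decoder layer against the default of 100,000 elements.

- The shapes are those of a LLaMA-7B-style layer (hidden size 4096, MLP size 11008), assumed for illustration.
- It only mirrors the "smaller than the threshold" rule described here; DeepSpeed's actual bookkeeping has more detail.

```python
from math import prod

# Parameter shapes of one LLaMA-7B-style decoder layer (illustrative)
LAYER_SHAPES = {
    "self_attn.q_proj.weight": (4096, 4096),
    "self_attn.k_proj.weight": (4096, 4096),
    "self_attn.v_proj.weight": (4096, 4096),
    "self_attn.o_proj.weight": (4096, 4096),
    "mlp.gate_proj.weight": (11008, 4096),
    "mlp.up_proj.weight": (11008, 4096),
    "mlp.down_proj.weight": (4096, 11008),
    "input_layernorm.weight": (4096,),
    "post_attention_layernorm.weight": (4096,),
}


def split_by_persistence(shapes, threshold=100_000):
    """Return (persistent, sharded) parameter names using numel < threshold."""
    persistent, sharded = [], []
    for name, shape in shapes.items():
        (persistent if prod(shape) < threshold else sharded).append(name)
    return persistent, sharded


if __name__ == "__main__":
    keep, shard = split_by_persistence(LAYER_SHAPES)
    kept = sum(prod(LAYER_SHAPES[n]) for n in keep)
    total = sum(prod(s) for s in LAYER_SHAPES.values())
    print(f"persistent: {keep}")
    print(f"persistent fraction of layer params: {kept / total:.5%}")
```

Only the two norm weights stay resident, and they are a negligible fraction of the layer. That is why the threshold barely affects memory but saves many tiny all-gathers.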

overlap_comm: true overlaps gradient reduction with the backward computation: the reduction for layer $i$ proceeds while the GPU computes earlier layers. In ZeRO-3, parameter all-gathers for upcoming layers are also prefetched, with the prefetch volume governed by stage3_prefetch_bucket_size. Together these can hide 60-80% of ZeRO-3 communication overhead on NVLink systems. Always enable it except when debugging timing issues. On InfiniBand clusters with certain NCCL MTU settings, it occasionally causes hangs - see the warning below.

allgather_bucket_size and reduce_bucket_size are counted in elements, not bytes. The value 5e8 means 500M elements, about 1 GB in 16-bit precision, which is appropriate for 40-80 GB GPUs. On 600 GB/s NVLink, increase to 1e9 for fewer, larger operations. On slower interconnects (100 Gb/s Ethernet), decrease to 2e8 for lower latency per operation.

sub_group_size controls how ZeRO-3 partitions optimizer work within each rank. Keep at 1e9 unless you are hitting memory spikes during the optimizer step.

FSDP with Gradient Checkpointing

FSDP and gradient checkpointing interact in a critical way. When using FULL_SHARD with checkpointing, the backward pass needs to all-gather parameters twice per checkpointed layer: once for the recomputed forward and once for the backward proper. This doubles communication for those layers, which is acceptable given the activation memory savings (see Lesson 06 for full treatment).

```python
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def wrap_with_checkpointing(fsdp_model: FSDP) -> FSDP:
    """Apply activation checkpointing to every decoder layer after FSDP wrapping."""

    def check_fn(module):
        return isinstance(module, LlamaDecoderLayer)

    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=partial(
            checkpoint_wrapper,
            # NO_REENTRANT avoids issues with FSDP's param management
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        ),
        check_fn=check_fn,
    )

    return fsdp_model
```

Checkpoint Saving Strategy

Checkpoint saving with ZeRO-3 is the most error-prone part of production use. The three approaches and their trade-offs:

  1. Full consolidated checkpoint (rank 0 only): Gathers all parameters on rank 0 CPU. Simple, portable. Requires rank 0 to have enough CPU memory for the full model. Best for models up to ~70B.

  2. Distributed checkpoint (PyTorch DCP): Each rank saves its own shard to separate files. Saves and loads are fast, and DCP can reshard at load time, so a checkpoint can be resumed on a different world size. Best for large checkpoints during training.

  3. In-flight consolidation: Uses DeepSpeed's _zero3_consolidated_16bit_state_dict(), which gathers parameters module by module and assembles a full 16-bit state dict on rank 0. It avoids materializing the whole model on any GPU, but it is the slowest option.

For production training, use distributed checkpoints during training (fast) and consolidate to a full checkpoint at the end for deployment.

Monitoring and Debugging ZeRO

```bash
# Filter DeepSpeed's startup log for ZeRO stage and memory lines
deepspeed --num_gpus=8 train.py \
    --deepspeed ds_config.json \
    2>&1 | grep -E "ZeRO|memory|stage"

# Monitor all GPUs in real time during training
watch -n 2 nvidia-smi \
    --query-gpu=index,memory.used,memory.total,utilization.gpu \
    --format=csv,noheader,nounits

# Check whether overlap_comm is actually overlapping: there is no CLI for this.
# Wrap a few steps of train_fsdp.py in torch.profiler.profile(...), export a
# Chrome trace with prof.export_chrome_trace(...), and look for NCCL kernels
# running concurrently with compute kernels on the GPU timeline.
```

:::tip When to Use Each ZeRO Stage

  • ZeRO-1: Model state fits in GPU memory but optimizer states are tight. Free 4x savings with zero communication overhead increase. Use this before anything else.
  • ZeRO-2: ZeRO-1 is not enough. Same communication cost as DDP, 8x total reduction. Good default for 7-30B models on 8-64 GPU clusters.
  • ZeRO-3 / FSDP FULL_SHARD: Model itself does not fit per GPU. Accept communication overhead for linear memory scaling. Required for 30B+ models on standard GPU configs.
  • ZeRO-Infinity: Scales beyond available GPU memory entirely. 20-40% training overhead but enables previously impossible scales (100B+ on a single node). :::

:::note FSDP vs DeepSpeed ZeRO-3 For PyTorch users starting a new project in 2024+, FSDP FULL_SHARD is the preferred choice over DeepSpeed ZeRO-3. FSDP is native PyTorch, composes correctly with torch.compile, and does not require learning the DeepSpeed config format. Use DeepSpeed ZeRO-3 specifically when you need ZeRO-Infinity (NVMe offload), which FSDP does not yet support natively. :::

Common Mistakes

:::danger ZeRO-3 with Out-of-Order Parameter Access ZeRO-3 assumes parameters are accessed in a predictable layer-by-layer order matching the module hierarchy. If your model accesses parameters out of order - conditional computation, mixture-of-experts routing, dynamic skip connections - ZeRO-3 triggers unexpected all-gather operations and can run 5-10x slower than DDP. The symptom is very high NCCL communication time and GPU compute utilization near zero. Always profile a few steps with torch.profiler before committing to ZeRO-3 for a new architecture. :::

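:::danger Saving ZeRO-3 Checkpoints Without Consolidation Calling torch.save(model.state_dict()) directly on a DeepSpeed ZeRO-3 model does not save the weights. The partitioned parameters are empty placeholders on every rank, and the real shards live inside the ZeRO optimizer's partitions, so the resulting files are useless for inference. FSDP fails differently: its default state_dict() gathers the full model on every rank at once, which can run out of memory. The correct approaches are engine._zero3_consolidated_16bit_state_dict() for DeepSpeed, and the FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)) context manager for FSDP. This is the single most common ZeRO production mistake. :::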

:::warning ZeRO-2 and Pipeline Parallelism ZeRO-2 and ZeRO-3 do not compose cleanly with pipeline parallelism in most frameworks. If you are using Megatron-style pipeline parallelism, use ZeRO-1 only (optimizer state sharding). Pipeline schedules run a backward pass for every micro-batch and accumulate gradients across them. Partitioning gradients (and, in ZeRO-3, parameters) would force a reduce-scatter, and a re-gather of parameters, for every micro-batch instead of once per step. DeepSpeed's pipeline engine, for example, accepts only ZeRO stages 0 and 1. :::

:::warning The overlap_comm Deadlock overlap_comm: true occasionally causes deadlocks on clusters with custom NCCL configurations or non-standard InfiniBand MTU settings. If your job hangs at exactly step 1 (completes step 0 successfully, then freezes), disable overlap_comm first as a diagnostic step. This is especially common on clusters where NCCL's IB GDR (GPU Direct RDMA) is misconfigured. The symptom is all ranks blocked in an NCCL call with no progress indicator. :::

:::warning FSDP Wrapping Policy Silent Failures FSDP's transformer_auto_wrap_policy silently falls back to wrapping the entire model as a single unit if no modules match the provided transformer_layer_cls. This happens when your model uses a custom class name, a wrapper around the standard class, or a model loaded from a checkpoint where class names differ. The symptom is training appearing to work but using as much memory as DDP. Always verify wrapping depth by checking that isinstance(m, FSDP) is true for individual layer modules, not just the top-level model. :::

Interview Q&A

Q1: Explain the memory breakdown for a 13B parameter model trained with Adam in mixed precision. How does ZeRO-3 on 64 GPUs change this?

A: In mixed precision training with Adam, each parameter requires 16 bytes total:

  • 2 bytes: FP16 or BF16 parameters (forward/backward)
  • 2 bytes: FP16 or BF16 gradients
  • 4 bytes: FP32 master weight copy (numerically stable optimizer update)
  • 4 bytes: Adam first moment $m$ in FP32
  • 4 bytes: Adam second moment $v$ in FP32

For 13B parameters: $13 \times 10^9 \times 16 = 208$ GB per GPU in DDP. No single A100-80GB can hold this. With ZeRO-3 on 64 GPUs, all state is sharded: $208/64 = 3.25$ GB per GPU for model state. You still need activation memory (typically 10-30 GB depending on batch size and sequence length), but the model state itself is now comfortably manageable on an 80 GB A100. The trade-off is that ZeRO-3 adds parameter all-gathers for every layer, raising communication volume to about 1.5x DDP. On a cluster with 400 Gb/s InfiniBand, this is typically acceptable.

Q2: What is the communication volume difference between DDP, ZeRO-2, and ZeRO-3? Why do ZeRO-1 and ZeRO-2 have the same communication cost as DDP?

A: DDP performs all-reduce on gradients. A bandwidth-optimal all-reduce is implemented as a reduce-scatter followed by an all-gather, so each GPU sends about $2N$ elements per step: $N$ for each phase, ignoring the $(D-1)/D$ factor. That is the minimum for full gradient synchronization.

ZeRO-1 replaces the all-reduce with a reduce-scatter for gradients plus an all-gather of the updated parameters after the optimizer step. Total: $2N$ elements per GPU, identical to DDP. The only difference is which GPU computes the optimizer step for which parameters.

ZeRO-2 does the same reduce-scatter for gradients, and each GPU simply keeps only its shard instead of re-gathering the full gradient. Gradient communication is $N$ (reduce-scatter). After the optimizer step, an all-gather of the updated parameters ($N$) restores the replicated FP16 weights. Total: $2N$, still identical to DDP.

ZeRO-3 adds an all-gather of parameters before each layer, once in the forward pass ($N$ elements summed across all layers) and again in the backward pass ($N$). Add the reduce-scatter for gradients ($N$). Total: approximately $3N$ elements per GPU, or 1.5x DDP. Because each rank keeps only its own parameter shard, no separate post-step all-gather is needed. The sliding-window prefetch hides much of this on high-bandwidth interconnects.

Q3: When would you choose FSDP over DeepSpeed ZeRO-3 for a new training project?

A: FSDP is the right choice when:

  • Building a new system in pure PyTorch where native integration matters
  • Using torch.compile(): FSDP composes correctly with it, while DeepSpeed's support is not yet complete
  • Your team has deep PyTorch expertise but limited DeepSpeed experience
  • You need fine-grained per-layer mixed precision control

DeepSpeed ZeRO-3 is the right choice when:

  • You need ZeRO-Infinity (NVMe offload) - FSDP does not support NVMe natively
  • Working with an existing DeepSpeed codebase (Megatron-DeepSpeed, many LLM repos)
  • You need the DeepSpeed optimizer library (1-bit Adam, CPU Adam with lower precision)
  • You want DeepSpeed's built-in memory profiler and flops counter

For a greenfield 7B-70B model training project in 2024 on 8-256 A100s, FSDP FULL_SHARD is the cleaner choice. For very large scale (1T+) or offload requirements, DeepSpeed.

Q4: Explain ZeRO-Infinity's offloading strategy. Why does offloading to CPU not simply make training 10x slower?

A: ZeRO-Infinity keeps only the currently active layer's parameters on the GPU. All other parameters live in CPU DRAM or on NVMe SSD. The key is compute-communication overlap: while the GPU computes layer $i$, the system uses PCIe DMA to prefetch the parameters of layer $i+1$ from CPU memory in parallel.

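This overlap works because a transformer layer's compute grows with the number of tokens in the micro-batch, while the bytes to transfer do not. A 128 MB PCIe 4.0 DMA transfer at 25 GB/s takes about 5 ms. Those 128 MB hold about 64M BF16 parameters, and the forward pass spends roughly 2 FLOPs per parameter per token. A micro-batch of 8,192 tokens therefore needs about 1 TFLOP of forward compute, which is a few milliseconds on an A100, and roughly twice that in the backward pass. The latency is hidden as long as the DMA finishes before the GPU needs the next layer. That condition holds for large layers and sufficiently large micro-batches.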
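The sketch below makes the comparison concrete. It is back-of-the-envelope arithmetic under these assumptions: BF16 weights (2 bytes per parameter), about 2 FLOPs per parameter per token in the forward pass, a single PCIe 4.0 link at 25 GB/s, and a sustained throughput you supply. The 150 TFLOPS default is an assumed A100 figure, not a measurement. The sketch ignores attention-score FLOPs, and the function names are hypothetical.

```python
def transfer_ms(n_params: float, pcie_gb_per_s: float = 25.0) -> float:
    """Time to copy BF16 weights (2 bytes each) over PCIe, in milliseconds."""
    return n_params * 2 / (pcie_gb_per_s * 1e9) * 1e3


def forward_ms(n_params: float, tokens: int, tflops: float = 150.0) -> float:
    """Forward matmul time at ~2 FLOPs per parameter per token, in milliseconds."""
    return 2 * n_params * tokens / (tflops * 1e12) * 1e3


def break_even_tokens(pcie_gb_per_s: float = 25.0, tflops: float = 150.0) -> float:
    """Tokens per micro-batch above which forward compute hides the prefetch.

    The parameter count P cancels: 2*P*T / F = 2*P / B  =>  T = F / B.
    """
    return (tflops * 1e12) / (pcie_gb_per_s * 1e9)


if __name__ == "__main__":
    layer = 64e6  # ~128 MB of BF16 weights
    print(f"transfer: {transfer_ms(layer):.2f} ms")
    print(f"forward, 8192 tokens: {forward_ms(layer, 8192):.2f} ms")
    print(f"break-even: {break_even_tokens():.0f} tokens per micro-batch")
```

Under these assumptions, the break-even point is 6,000 tokens per micro-batch. A 64M-parameter layer takes about 5.1 ms to transfer, against about 7 ms of forward compute over 8,192 tokens. Because the parameter count cancels, the ratio depends on tokens per micro-batch and on the sustained throughput the layer actually achieves, which drops for small, narrow layers.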

For small models with fast layers, the compute-to-transfer ratio is unfavorable and ZeRO-Infinity does slow things down. For wide, deep transformers with large hidden dimensions, the ratio is favorable. In practice, ZeRO-Infinity adds 20-40% training overhead vs on-GPU ZeRO-3, while enabling models 5-10x larger. The 20-40% overhead is worth it when the alternative is "cannot train this model at all."

Q5: A training run using FSDP FULL_SHARD on 8 GPUs is 3x slower than DDP on the same 8 GPUs. What is likely wrong and how do you diagnose it?

A: The most common cause is an incorrect wrapping policy: the entire model is wrapped as a single FSDP unit instead of each transformer block being wrapped individually. FSDP then all-gathers every parameter (all 7B for a 7B model) at once before any compute starts. It keeps them for the whole forward pass, frees them, and then all-gathers them all again for backward. No communication overlaps with compute, and peak memory has to hold the full set of parameters at once. This is maximally inefficient.

Diagnostic steps:

First, check the wrapping depth: [(n, type(m).__name__) for n, m in model.named_modules() if 'FullySharded' in type(m).__name__]. If you see only one entry at the top level, wrapping is wrong.

Second, profile with torch.profiler and look for a single massive all_gather operation taking 500ms+ before any matmuls appear. Correct wrapping shows many small all_gather calls interleaved with compute.

Third, check that transformer_layer_cls in your auto_wrap_policy references the exact Python class of your transformer blocks, not a parent class or a wrapper. To verify, print the class names present in the model, for example sorted({type(m).__name__ for m in model.modules()}), and confirm the decoder-layer class is among them.

Fix: ensure the auto-wrap policy targets each individual decoder layer class. After fixing, the forward pass should show one small all-gather per decoder layer (32 for a 32-layer model such as Llama 2 7B) instead of one large one.
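The sketch below shows per-block wrapping together with the diagnostic checks described above. It assumes PyTorch's FullyShardedDataParallel, ShardingStrategy, and transformer_auto_wrap_policy from torch.distributed.fsdp. The helper names (repeated_module_classes, fsdp_unit_names, wrap_per_block) are hypothetical. wrap_per_block must run after torch.distributed.init_process_group, for example under torchrun.

```python
import functools
from collections import Counter

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def repeated_module_classes(model: nn.Module, min_count: int = 2) -> dict:
    """Class names that occur at least min_count times in the module tree.

    The decoder-layer class normally appears once per layer, which makes it
    the candidate to pass as transformer_layer_cls.
    """
    counts = Counter(type(m).__name__ for m in model.modules())
    return {name: n for name, n in counts.items() if n >= min_count}


def fsdp_unit_names(model: nn.Module) -> list:
    """Names of all FSDP-wrapped submodules; a single entry means one giant unit."""
    return [name for name, m in model.named_modules() if isinstance(m, FSDP)]


def wrap_per_block(model: nn.Module, block_cls: type) -> FSDP:
    """Make every block_cls instance its own FSDP unit.

    Requires an initialized process group and one GPU per process.
    """
    policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={block_cls}
    )
    return FSDP(
        model,
        auto_wrap_policy=policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device(),
    )
```

For a Hugging Face Llama model, the block class is LlamaDecoderLayer. After wrap_per_block, fsdp_unit_names should list the root module plus one entry per decoder layer rather than a single entry.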

Q6: How does ZeRO-2 handle gradient communication during gradient accumulation? What is the subtle trap?

A: In DDP, gradient accumulation is implemented by skipping the all-reduce for accumulation steps (using model.no_sync() context) and only synchronizing on the final step. This works because all-reduce is a pure aggregation with no state.

In ZeRO-2, the situation is more subtle. The reduce-scatter is not just synchronization: it routes each gradient shard to the GPU that owns it. Suppose you skip the reduce-scatter during accumulation steps. Each GPU then accumulates its local gradients without sending them to the owner GPU. When you finally do the reduce-scatter on the last accumulation step, the gradients from earlier steps are already summed into the local tensors, and the routing is correct because summation commutes with the reduce-scatter. From a correctness standpoint, the result is the same as doing a reduce-scatter every step. The cost is memory: skipping the reduce-scatter means every GPU holds a full-size gradient buffer across the accumulation window, which gives back ZeRO-2's gradient-sharding savings.
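The correctness claim rests on linearity. The toy pure-Python simulation below compares the two orders of operations. Lists stand in for gradient tensors, integers keep the sums exact, and reduce_scatter is a hypothetical helper, not a library call.

```python
import random


def reduce_scatter(per_gpu_grads: list) -> list:
    """Sum gradients across GPUs, then give GPU r only its contiguous shard."""
    n_gpus = len(per_gpu_grads)
    length = len(per_gpu_grads[0])
    shard = length // n_gpus  # assumes length divides evenly
    total = [sum(g[i] for g in per_gpu_grads) for i in range(length)]
    return [total[r * shard:(r + 1) * shard] for r in range(n_gpus)]


def add(a: list, b: list) -> list:
    return [x + y for x, y in zip(a, b)]


def compare(n_gpus: int = 4, n_micro: int = 3, length: int = 8, seed: int = 0):
    rng = random.Random(seed)
    # grads[step][gpu] = local gradient vector for that micro-batch
    grads = [[[rng.randint(-5, 5) for _ in range(length)] for _ in range(n_gpus)]
             for _ in range(n_micro)]

    # Option A: reduce-scatter every micro-batch, accumulate only the owned shard
    every_step = [[0] * (length // n_gpus) for _ in range(n_gpus)]
    for step in grads:
        every_step = [add(acc, s) for acc, s in zip(every_step, reduce_scatter(step))]

    # Option B: accumulate full local gradients, reduce-scatter once at the end
    local = [[0] * length for _ in range(n_gpus)]
    for step in grads:
        local = [add(acc, g) for acc, g in zip(local, step)]
    once = reduce_scatter(local)
    return every_step, once


if __name__ == "__main__":
    a, b = compare()
    print("identical shards:", a == b)
```

The two options produce identical shards. What differs is memory: option B keeps a full-length buffer on every GPU for the whole accumulation window, while option A only needs to keep each GPU's shard between micro-batches.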

The trap is when using DeepSpeed's gradient_accumulation_steps config. DeepSpeed handles this internally - do not also manually implement gradient accumulation in the training loop, as you will get double-accumulation. Either use DeepSpeed's config-driven accumulation or implement it manually with engine.backward(loss) called every step and engine.step() called every N steps. Not both.

Q7: What happens to ZeRO-3 when a model has tied weights, such as shared embedding and output projection matrices?

A: Tied weights are a significant complication for ZeRO-3. In ZeRO-3, every parameter is partitioned into shards, with each GPU owning one shard. Tied weights are two names pointing to the same underlying tensor. If the partitioner treats the two names as separate parameters, it shards the same tensor twice. That creates two problems. First, there are two independent sets of shards for what should be one tensor. Second, gradient accumulation during backward goes wrong: the embedding and the LM head both contribute gradients to the same underlying parameter from different backward paths, and those contributions must be summed into a single gradient exactly once.

DeepSpeed handles this by detecting tied parameters during deepspeed.zero.Init() and marking them to use the same shard assignment. FSDP handles it through module wrapping: the two references must fall under the same FSDP unit, so the shared tensor is flattened and sharded only once.

The practical debugging symptom of incorrectly handled tied weights is a language model whose loss decreases for 100-200 steps and then diverges or plateaus much higher than expected (perplexity stuck above 100 for a language model that should reach 20-30). The embedding and LM head stop behaving as one shared matrix because gradients are being incorrectly doubled or misrouted. Always verify tied-weight handling explicitly after ZeRO initialization. Checking that model.embed_tokens.weight.data_ptr() == model.lm_head.weight.data_ptr() returns True is meaningful when the full tensors are materialized. Under DeepSpeed ZeRO-3, however, local parameters are replaced by empty placeholders, so also confirm that both names refer to the same Parameter object (model.embed_tokens.weight is model.lm_head.weight).

Q8: How do you benchmark the actual throughput impact of ZeRO stage upgrades before committing to a long training run?

A: The benchmark methodology is a short fixed-step run comparing throughput (tokens per second) and peak memory across ZeRO stages. Here is a practical approach:

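```python
# zero_benchmark.py
# Launch once per stage, e.g.: torchrun --nproc_per_node=8 zero_benchmark.py 3 <model_name>
import os
import sys
import time

import torch
import torch.distributed as dist
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM


def benchmark_zero_stage(stage: int, model_name: str, n_steps: int = 20):
    """Quick throughput benchmark for a given ZeRO stage."""
    # Bind this process to its own GPU before anything is allocated.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()

    ds_config = {
        # train_batch_size is inferred as micro_batch * grad_accum * world_size
        "train_micro_batch_size_per_gpu": 4,
        "gradient_accumulation_steps": 8,
        "bf16": {"enabled": True},
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
        "zero_optimization": {"stage": stage, "overlap_comm": True},
    }

    # Random weights are fine for throughput. zero.Init partitions parameters
    # at construction time, which only applies to stage 3.
    config = AutoConfig.from_pretrained(model_name)
    with deepspeed.zero.Init(config_dict_or_path=ds_config, enabled=(stage == 3)):
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

    engine, _, _, _ = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), config=ds_config
    )

    # Fake batch for benchmarking, created on this rank's GPU
    batch_size, seq_len = 4, 2048
    device = torch.device("cuda", local_rank)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
    labels = input_ids.clone()

    # Warmup: one full accumulation cycle, so one optimizer step
    for _ in range(ds_config["gradient_accumulation_steps"]):
        outputs = engine(input_ids=input_ids, labels=labels)
        engine.backward(outputs.loss)
        engine.step()

    torch.cuda.synchronize()  # don't time leftover warmup kernels
    torch.cuda.reset_peak_memory_stats()
    t0 = time.perf_counter()

    for _ in range(n_steps):
        outputs = engine(input_ids=input_ids, labels=labels)
        engine.backward(outputs.loss)
        engine.step()  # optimizer updates only at accumulation boundaries

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    tokens_per_sec = (n_steps * batch_size * seq_len) / elapsed  # per GPU

    return {
        "stage": stage,
        "tokens_per_sec": tokens_per_sec,
        "peak_memory_gb": peak_gb,
        "elapsed_sec": elapsed,
    }


if __name__ == "__main__":
    result = benchmark_zero_stage(stage=int(sys.argv[1]), model_name=sys.argv[2])
    if dist.get_rank() == 0:
        print(result)
```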

Run this with torchrun --nproc_per_node=8 once per stage, in separate launches so each measurement starts from a fresh process. The typical result on 8x A100s for a 7B model shows ZeRO-1 within 2% of DDP throughput, ZeRO-2 within 5%, and ZeRO-3 within 10-20% (depending on NVLink vs InfiniBand). If ZeRO-3 shows more than 30% throughput loss, check that overlap_comm is enabled and that the bucket sizes are large enough for your bandwidth. On NVLink-equipped nodes, ZeRO-3 overhead should be under 15%.
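To apply these thresholds consistently, turn the per-stage results into relative overhead. This small sketch consumes dicts shaped like the ones benchmark_zero_stage returns. The function names are hypothetical. The sketch uses ZeRO-1 as a stand-in baseline, since the text places it within about 2% of DDP, and takes its 30% cutoff from the text.

```python
def throughput_overhead(results: list, baseline_stage: int = 1) -> dict:
    """Percent throughput loss of each stage relative to baseline_stage.

    results: dicts with at least "stage" and "tokens_per_sec" keys.
    """
    by_stage = {r["stage"]: r["tokens_per_sec"] for r in results}
    base = by_stage[baseline_stage]
    return {s: 100.0 * (base - tps) / base for s, tps in sorted(by_stage.items())}


def needs_tuning(results: list, baseline_stage: int = 1, limit_pct: float = 30.0) -> bool:
    """True if ZeRO-3 loses more than limit_pct throughput versus the baseline."""
    overhead = throughput_overhead(results, baseline_stage)
    return overhead.get(3, 0.0) > limit_pct
```

If needs_tuning returns True, start with overlap_comm and the bucket sizes before concluding that ZeRO-3 is too slow for the cluster.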

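The key insight from benchmarking is that the trade-off shifts with the number of GPUs. ZeRO-3's per-GPU communication volume stays roughly constant (about $3N$ bytes per step) as GPUs are added, while its per-GPU model-state memory shrinks in proportion to the GPU count. On 8 GPUs, ZeRO-3 overhead is noticeable. On 64 GPUs, the memory savings are large and the communication is well pipelined. In that regime, ZeRO-3 often outperforms ZeRO-2 in overall training efficiency, because the freed memory lets you run much larger micro-batches.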

© 2026 EngineersOfAI. All rights reserved.