Gradient Checkpointing and Rematerialization

When Memory Runs Out Before Compute Does

It is 3 AM and a training run for a 13B parameter model has just crashed on step 1. The error is not an NaN. It is not a gradient explosion. It is a simple, brutal CUDA out of memory. The engineer on call has already spent six hours tuning ZeRO-3 and FSDP sharding. Model state is down to 8 GB per A100. The problem now is activations - 62 GB per GPU, and growing with every increase in sequence length or batch size.

Every modern deep learning model faces the same tension during training. The forward pass must store intermediate activations - the outputs of every layer - because the backward pass needs them to compute gradients. For a 32-layer transformer with 2048 sequence length and batch size 8, the activations alone can exceed 40 GB per GPU. Scale to sequence length 8192 and the number jumps past 150 GB. Even an H100-80GB cannot hold it.

The brute-force fix - reduce batch size until everything fits - destroys GPU utilization and slows training by 4-8x. For many modern training setups, the minimum batch size needed for stable training at scale already exceeds what fits in memory after model state. You are stuck: you need the activations for backprop, but you cannot afford to store them all.

Gradient checkpointing is the solution. The insight is surprisingly simple: you do not have to store every activation. You only need to store enough checkpoints that you can recompute any activation you need from the nearest checkpoint. The forward pass becomes a partial storage problem. The backward pass becomes a recompute problem. The total memory drops dramatically. The total compute increases by about 33%. For most training regimes, that is an excellent trade.

The technique was independently discovered and named multiple times - "rematerialization" in the compilers community, "gradient checkpointing" in the deep learning community, "activation recomputation" in the GPU training community. They are all the same idea. And understanding it deeply - not just knowing the API exists, but knowing when to checkpoint what - is the difference between engineers who can train at scale and those who cannot.

Why This Exists - The Activation Memory Problem

Understanding why activation memory is so large requires understanding what backpropagation actually needs. During the forward pass through a transformer block, the layer norm computes a normalized activation, the attention mechanism computes query, key, value projections and attention weights, and the FFN computes two linear projections with a nonlinearity. Every single one of these intermediate tensors must be kept alive until its corresponding backward pass.

For a single transformer layer with hidden dimension $d$, sequence length $s$, and batch size $b$, the activation memory is roughly:

$$\text{Activations per layer} \approx 12 \cdot s \cdot b \cdot d \text{ values} \times 2 \text{ bytes (FP16)}$$

This comes from:

  • Layer norm output: $s \cdot b \cdot d$ values (2 bytes per position)
  • QKV projections: $3 \cdot s \cdot b \cdot d$ values (6 bytes total)
  • Attention weights: $s^2 \cdot b \cdot \text{num\_heads}$ values (variable, quadratic in $s$)
  • Attention output: $s \cdot b \cdot d$ values (2 bytes per position)
  • FFN intermediate: $s \cdot b \cdot 4d$ values (8 bytes for the expanded dimension)

The linear terms listed here sum to $9 \cdot s \cdot b \cdot d$; the factor of 12 is a round figure that leaves room for tensors not itemized above.

For LLaMA-7B ($d = 4096$, 32 layers, batch=8, seq=2048):

$$\text{Total activations} \approx 32 \times 12 \times 2048 \times 8 \times 4096 \times 2 \text{ bytes} \approx 51.5 \text{ GB}$$

This is before model state (16N bytes in DDP, or sharded with ZeRO). Even with perfect ZeRO-3 sharding of model parameters, activations can dominate GPU memory for long-context or large-batch training.
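The LLaMA-7B arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch of the rule of thumb only: the function name `activation_bytes` and its parameters are illustrative, and the factor of 12 values per position is the rough estimate used in this section, not a measured number.

```python
def activation_bytes(n_layers, seq_len, batch, d_model,
                     values_per_position=12, bytes_per_value=2):
    """Rough activation footprint: layers * 12 * s * b * d values * 2 bytes (FP16)."""
    per_layer = values_per_position * seq_len * batch * d_model * bytes_per_value
    return n_layers * per_layer


# LLaMA-7B-like shape from the text: 32 layers, d=4096, batch=8, seq=2048
llama7b = activation_bytes(n_layers=32, seq_len=2048, batch=8, d_model=4096)
print(f"{llama7b / 1e9:.1f} GB")  # 51.5 GB

# The estimate is linear in sequence length (attention weights are excluded),
# so 4x the sequence length means 4x the activations.
long_ctx = activation_bytes(n_layers=32, seq_len=8192, batch=8, d_model=4096)
print(long_ctx // llama7b)  # 4
```

Because the quadratic attention-weight term is left out, this is a lower bound for implementations that materialize the full attention matrix.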

The opposite extreme is to store no activations at all and recompute whatever the backward pass needs from the network input. This uses constant activation memory, but the input to every layer has to be rebuilt from the start, so the recompute cost grows as $O(L^2)$ layer evaluations for an $L$-layer network rather than one extra forward pass. Gradient checkpointing finds a practical middle ground between these extremes: store a few activations and pay roughly one extra forward pass.

Historical Context - From Compiler Theory to Deep Learning

The idea of trading computation for memory in algorithms goes back to work in programming language theory in the 1980s. The specific formulation for backpropagation appears in work by Andreas Griewank (1992) on "Achieving logarithmic growth of temporal and spatial complexity in reverse automatic differentiation." As the title indicates, Griewank showed that recursive checkpointing of a sequential computation of depth $L$ keeps both memory and recomputation growing only logarithmically in $L$. The simpler single-level scheme, which needs $O(\sqrt{L})$ memory (compared to $O(L)$ for full storage) at the cost of one extra forward pass, is the variant most deep learning frameworks adopted.

A related line of work is "The Reversible Residual Network: Backpropagation Without Storing Activations" by Gomez et al. (2017), which proposed reversible residual connections as a way to reconstruct activations without storing them. This was the method used in RevNet and later i-RevNet.

Chen et al. (2016) at the University of Washington published "Training Deep Nets with Sublinear Memory Cost" which applied Griewank's checkpointing idea directly to deep learning. They showed you could train networks with $O(\sqrt{N})$ memory in activations by checkpointing every $\sqrt{N}$-th layer. This paper is the direct ancestor of torch.utils.checkpoint.

PyTorch integrated gradient checkpointing in version 0.4 (2018) via torch.utils.checkpoint.checkpoint(). JAX/XLA uses the term "rematerialization" and supports it via jax.checkpoint decorator. The concepts are identical but the execution differs due to XLA's compilation model.

The modern resurgence came with large language models. Transformers at 7B+ parameters, trained at sequence length 2048+ with reasonable batch sizes, made activation checkpointing not an optimization but a requirement. Every major LLM training framework - Megatron-LM, DeepSpeed, Hugging Face Accelerate - ships it as a built-in option.

Core Concepts - Intuition Before Math

The Checkpoint Boundary Idea

Imagine training a neural network with $L$ layers. Standard backprop stores the output of every layer: $a_0, a_1, a_2, \ldots, a_L$. This requires $O(L)$ memory for activations.

Gradient checkpointing divides the network into $\sqrt{L}$ segments of $\sqrt{L}$ layers each. At the segment boundaries, we store the activations (the "checkpoints"). Within each segment, we store nothing during the forward pass.

During backpropagation, when we need the activations for layer $i$ (which is inside some segment), we:

  1. Retrieve the checkpoint at the start of that segment
  2. Recompute the forward pass from the checkpoint to layer $i$
  3. Use the recomputed activation for the gradient computation
  4. Discard the activation immediately after

This recomputation happens during the backward pass. The total memory for activations at any point is bounded by:

  • The $\sqrt{L}$ checkpoint activations stored permanently: $O(\sqrt{L})$
  • The activations in the currently active segment being recomputed: $O(\sqrt{L})$

Total: $O(\sqrt{L})$ activation memory instead of $O(L)$.

The compute cost: each segment of $\sqrt{L}$ layers is recomputed once during backward. Total extra forward operations: $\sqrt{L} \times \sqrt{L} = L$. Exactly one extra forward pass. So gradient checkpointing costs one extra forward pass in exchange for $O(\sqrt{L})$ activation memory instead of $O(L)$.
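The counting argument above can be checked with a small bookkeeping simulation. This is a sketch under simplifying assumptions: every stored tensor (checkpoint or in-segment activation) counts as one unit, all checkpoints stay alive for the whole backward pass (an upper bound), and the name `simulate_segment_checkpointing` is made up for illustration.

```python
import math


def simulate_segment_checkpointing(n_layers, segment_size):
    """Count checkpoints, peak live activations, and recomputed layers for one step."""
    starts = list(range(0, n_layers, segment_size))
    stored_checkpoints = len(starts)      # one boundary activation per segment
    peak_live = 0
    recomputed_layers = 0

    # Backward walks the segments from last to first.
    for start in reversed(starts):
        end = min(start + segment_size, n_layers)
        seg_len = end - start
        recomputed_layers += seg_len      # re-run forward inside this segment
        # Live memory: all checkpoints plus the activations of this one segment.
        peak_live = max(peak_live, stored_checkpoints + seg_len)

    return stored_checkpoints, peak_live, recomputed_layers


L = 36
k = math.isqrt(L)  # segment size sqrt(L) = 6
ckpts, peak, extra = simulate_segment_checkpointing(L, k)
print(ckpts, peak, extra)  # 6 12 36
```

For 36 layers the peak is 12 units instead of 36, which is the $2\sqrt{L}$ figure, and the 36 recomputed layers amount to exactly one extra forward pass.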

The Memory-Compute Trade-off Formula

Let $C_f$ be the cost of one forward pass (in FLOPs or wall time). Let $C_b$ be the cost of one backward pass. In standard training, the total cost is $C_f + C_b \approx 3C_f$ (backward is roughly 2x forward).

With gradient checkpointing over the full model:

$$\text{Total cost} = C_f \,(\text{first forward}) + C_f \,(\text{recompute during backward}) + C_b \approx 4C_f$$

That is a 33% increase in compute: $4C_f / 3C_f \approx 1.33$. The memory savings depend on the model. For a 32-layer transformer going from $O(32)$ to $O(\sqrt{32}) \approx O(6)$ checkpoint segments, activation memory drops by roughly $32/6 \approx 5\times$ (in practice, 8-10x with optimized implementations).
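The same cost model can be written as a small function, which makes it easy to see how the overhead shrinks when only part of the forward pass is recomputed. This is a sketch of the formula in this section, not a benchmark; `checkpointing_overhead` and its parameters are illustrative names, and the 2x backward-to-forward ratio is the approximation stated above.

```python
def checkpointing_overhead(recompute_fraction=1.0, backward_to_forward=2.0):
    """Extra compute relative to a normal step, with forward cost normalized to 1.

    recompute_fraction: share of the forward pass that is re-run during backward
                        (1.0 = full checkpointing).
    backward_to_forward: cost of backward in units of one forward pass.
    """
    baseline = 1.0 + backward_to_forward          # C_f + C_b
    with_ckpt = baseline + recompute_fraction     # plus the recomputed forward work
    return with_ckpt / baseline - 1.0


print(f"{checkpointing_overhead():.1%}")     # 33.3%
print(f"{checkpointing_overhead(0.5):.1%}")  # 16.7%
```

Recomputing only half of the forward work lands at about 17% overhead under this model.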

Why 33% Extra Compute Is Often Worth It

The question engineers ask is: if gradient checkpointing costs 33% more compute, why not just buy 33% more GPUs and avoid the problem? The answer is economic:

  1. Activation memory scales with batch size and sequence length; model state does not. If you are training at sequence length 8192, activations dominate memory regardless of how much you shard model state. Adding GPUs through data parallelism does not help at a fixed per-GPU batch size, because every GPU still holds the activations for its own micro-batch.

  2. Gradient checkpointing lets you increase batch size per GPU by 8-10x. A larger effective batch size often enables higher learning rates and faster convergence, potentially offsetting the 33% compute overhead.

  3. The 33% is a ceiling for full checkpointing. Selective checkpointing (only expensive layers) can achieve 5-8x memory savings with only 15-20% compute overhead.
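To make the batch-size argument concrete, the sketch below estimates how many sequences fit in a fixed activation budget with and without checkpointing. The 40 GB budget is a hypothetical number, the per-sample footprint reuses the $12 \cdot s \cdot b \cdot d$ rule of thumb from earlier, and the savings factor of 8 is simply the low end of the 8-10x range quoted above, taken as an input rather than derived.

```python
def max_batch_size(activation_budget_gb, per_sample_gb, savings_factor=1.0):
    """Largest batch whose activations fit; savings_factor > 1 models checkpointing."""
    effective_per_sample = per_sample_gb / savings_factor
    return int(activation_budget_gb // effective_per_sample)


# Per-sample activations for 32 layers, d=4096, seq=2048 in FP16 (rule of thumb)
per_sample_gb = 32 * 12 * 2048 * 4096 * 2 / 1e9   # about 6.44 GB

budget_gb = 40  # hypothetical memory left for activations on one GPU
no_ckpt = max_batch_size(budget_gb, per_sample_gb)
with_ckpt = max_batch_size(budget_gb, per_sample_gb, savings_factor=8)
print(no_ckpt, with_ckpt)  # 6 49
```

The model ignores the working set needed while a segment is being recomputed, so treat the second number as optimistic.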

Selective Checkpointing - The Engineer's Real Decision

Full gradient checkpointing (checkpoint every layer) is the safe default, but it is suboptimal. Different layers have dramatically different activation-to-compute ratios:

Attention layers produce large activations relative to their compute. The $O(s^2)$ attention weight matrix for long sequences is expensive to store but cheap to recompute (it is just a softmax over the $QK^\top$ scores). For sequence length 2048 and 32 heads, attention weights alone are $2048^2 \times 32 \times 2 \approx 268$ MB per layer per sequence in the batch. Always checkpoint attention.

FFN layers expand to a $4 \times d_{model}$ intermediate dimension. For $d_{model} = 4096$, the FFN intermediate is $4 \times 4096 = 16384$ features per token. At 2 bytes/feature and 2048 tokens: $16384 \times 2048 \times 2 \approx 67$ MB per layer per sequence in the batch. Moderately expensive to store, cheap to recompute (two matmuls and a GELU). Checkpoint if memory is tight.

Layer norm activations are small ($d_{model}$ values per token) but layer norm backward requires the pre-normalized input. The activation is small enough that storing it is usually worth avoiding the recompute overhead. Avoid checkpointing layer norm.

Embedding layers are only at the input - no need to checkpoint.

The practical decision matrix:

| Layer Type | Memory Cost | Recompute Cost | Checkpoint? |
| --- | --- | --- | --- |
| Attention (softmax weights) | High | Low | Yes |
| FFN intermediate | Medium | Low | If tight |
| Layer norm | Low | Very low | No |
| Residual adds | Very low | Trivial | No |
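The "High" versus "Medium" memory ratings can be reproduced from the per-sequence sizes worked out above. The sketch below assumes FP16 storage, a 4x FFN expansion, and an implementation that materializes the full attention matrix; the helper names are illustrative.

```python
def attention_weight_bytes(seq_len, n_heads, bytes_per_value=2):
    """Softmax attention weights for one sequence: s^2 values per head."""
    return seq_len ** 2 * n_heads * bytes_per_value


def ffn_intermediate_bytes(seq_len, d_model, expansion=4, bytes_per_value=2):
    """Expanded FFN activation for one sequence: s * 4d values."""
    return seq_len * expansion * d_model * bytes_per_value


for s in (2048, 8192):
    attn = attention_weight_bytes(s, n_heads=32)
    ffn = ffn_intermediate_bytes(s, d_model=4096)
    print(f"seq={s}: attention={attn / 1e6:.0f} MB, "
          f"ffn={ffn / 1e6:.0f} MB, ratio={attn / ffn:.0f}x")
# seq=2048: attention=268 MB, ffn=67 MB, ratio=4x
# seq=8192: attention=4295 MB, ffn=268 MB, ratio=16x
```

The ratio grows linearly with sequence length, which is why attention is the first thing to checkpoint for long-context training.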

The Selective Checkpointing API Pattern

PyTorch's torch.utils.checkpoint.checkpoint() wraps a function. Any computation wrapped in it will not store activations - they will be recomputed on demand during backward. You can apply it at any granularity: entire layers, individual sublayers, or even specific operations.

```python
from torch.utils.checkpoint import checkpoint

# Checkpoint entire layer (most common)
def forward(self, x):
    for layer in self.layers:
        x = checkpoint(layer, x, use_reentrant=False)
    return x

# Checkpoint only the expensive part (attention + FFN, not layer norm)
def forward(self, hidden_states, attention_mask):
    # Layer norm is cheap - do not checkpoint
    normed = self.layer_norm(hidden_states)

    # Attention is expensive - checkpoint it
    attn_out = checkpoint(
        self.attention, normed, attention_mask,
        use_reentrant=False
    )

    # Residual add is trivial
    hidden_states = hidden_states + attn_out

    # FFN - checkpoint if memory is tight
    normed2 = self.layer_norm2(hidden_states)
    ffn_out = checkpoint(self.ffn, normed2, use_reentrant=False)

    return hidden_states + ffn_out
```

Code Examples

Basic torch.utils.checkpoint Usage

```python
# basic_checkpointing.py
import math

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, d_ff: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention sublayer
        normed = self.norm1(x)
        attn_out, _ = self.self_attn(normed, normed, normed)
        x = x + attn_out

        # FFN sublayer
        normed2 = self.norm2(x)
        x = x + self.ffn(normed2)
        return x


class CheckpointedTransformer(nn.Module):
    def __init__(
        self,
        n_layers: int,
        d_model: int,
        nhead: int,
        d_ff: int,
        use_checkpointing: bool = True,
        checkpoint_every_n: int = 1,
    ):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, nhead, d_ff)
            for _ in range(n_layers)
        ])
        self.use_checkpointing = use_checkpointing
        self.checkpoint_every_n = checkpoint_every_n

    def _run_segment(self, start: int, end: int, x: torch.Tensor) -> torch.Tensor:
        """Run layers[start:end]; this is the unit that gets recomputed."""
        for layer in self.layers[start:end]:
            x = layer(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.use_checkpointing:
            return self._run_segment(0, len(self.layers), x)

        # Store only the input of each segment of checkpoint_every_n layers;
        # everything inside a segment is recomputed during backward.
        # use_reentrant=False is the recommended variant and is required with FSDP.
        n = self.checkpoint_every_n
        for start in range(0, len(self.layers), n):
            end = min(start + n, len(self.layers))
            x = checkpoint(self._run_segment, start, end, x, use_reentrant=False)
        return x


def measure_peak_memory(model_fn, *args, device="cuda"):
    """Measure peak memory (parameters included) during a forward+backward pass."""
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.empty_cache()

    # Run forward + backward
    out = model_fn(*args)
    loss = out.sum()
    loss.backward()

    peak_mb = torch.cuda.max_memory_allocated(device) / 1e6
    return peak_mb


def compare_checkpointing_memory():
    """Compare memory usage with and without gradient checkpointing."""

    device = "cuda"
    d_model, nhead, d_ff = 1024, 16, 4096
    n_layers = 24
    batch_size, seq_len = 8, 512

    x = torch.randn(batch_size, seq_len, d_model, device=device, requires_grad=True)

    # Segment size 1 = checkpoint every layer; sqrt(L) is the classic schedule
    sqrt_interval = max(1, int(math.sqrt(n_layers)))
    configs = {
        "No checkpointing": dict(use_checkpointing=False),
        "Checkpoint every layer": dict(use_checkpointing=True, checkpoint_every_n=1),
        "Segments of 2 layers": dict(use_checkpointing=True, checkpoint_every_n=2),
        f"Segments of sqrt({n_layers})={sqrt_interval} layers": dict(
            use_checkpointing=True, checkpoint_every_n=sqrt_interval
        ),
    }

    results = {}
    for name, kwargs in configs.items():
        model = CheckpointedTransformer(n_layers, d_model, nhead, d_ff, **kwargs).to(device)
        results[name] = measure_peak_memory(model, x)
        del model
        torch.cuda.empty_cache()

    print("\nGradient checkpointing memory comparison")
    print(f"Model: {n_layers}L x d={d_model}, batch={batch_size}, seq={seq_len}")
    print("-" * 55)
    baseline = results["No checkpointing"]
    for strategy, peak_mb in results.items():
        savings = baseline / peak_mb
        print(f"  {strategy:40s}: {peak_mb:7.1f} MB ({savings:.2f}x savings)")
```

LLaMA-Style Selective Checkpointing

```python
# selective_checkpointing.py - Production-ready selective checkpointing
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from dataclasses import dataclass
from typing import Optional


@dataclass
class CheckpointConfig:
    """Configure which parts of a transformer block to checkpoint."""
    checkpoint_attention: bool = True
    checkpoint_ffn: bool = True
    checkpoint_layer_norm: bool = False  # Usually not worth it


class SelectivelyCheckpointedDecoderLayer(nn.Module):
    """
    LLaMA-style decoder layer with configurable selective checkpointing.

    Memory hierarchy for checkpointing decisions (most to least expensive):
    1. Attention weights (O(seq^2)) - always checkpoint for long sequences
    2. FFN intermediate (O(seq * 4*d)) - checkpoint if memory tight
    3. Layer norm output (O(seq * d)) - usually not worth checkpointing
    4. Residual connections (O(seq * d)) - never checkpoint
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        ckpt_config: Optional[CheckpointConfig] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.ckpt = ckpt_config or CheckpointConfig()

        self.input_layernorm = nn.LayerNorm(d_model, elementwise_affine=True)
        self.post_attention_layernorm = nn.LayerNorm(d_model, elementwise_affine=True)

        # Attention projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        # FFN
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def _attention_forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Attention computation - checkpoint wrapper target."""
        B, S, _ = x.shape

        q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scale = self.head_dim ** -0.5
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_out = torch.matmul(attn_weights, v)

        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, self.d_model)
        return self.o_proj(attn_out)

    def _ffn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """SwiGLU FFN - checkpoint wrapper target."""
        gate = torch.nn.functional.silu(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        residual = hidden_states

        # Layer norm - not checkpointed (small, cheap)
        normed = self.input_layernorm(hidden_states)

        # Attention - checkpointed if configured
        if self.ckpt.checkpoint_attention:
            attn_out = checkpoint(
                self._attention_forward, normed, attention_mask,
                use_reentrant=False
            )
        else:
            attn_out = self._attention_forward(normed, attention_mask)

        hidden_states = residual + attn_out

        # Second layer norm - not checkpointed
        residual = hidden_states
        normed2 = self.post_attention_layernorm(hidden_states)

        # FFN - checkpointed if configured
        if self.ckpt.checkpoint_ffn:
            ffn_out = checkpoint(self._ffn_forward, normed2, use_reentrant=False)
        else:
            ffn_out = self._ffn_forward(normed2)

        return residual + ffn_out
```

Measuring Checkpointing Overhead for LLaMA-7B

```python
# profile_checkpointing.py
import math
from contextlib import contextmanager

import torch


@contextmanager
def cuda_timer(label: str):
    """Context manager for CUDA event timing."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    yield
    end.record()
    torch.cuda.synchronize()
    elapsed = start.elapsed_time(end)
    print(f"  {label}: {elapsed:.1f} ms")


def profile_step(model, optimizer, x, labels, label: str):
    """Profile a complete training step: memory and time."""

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    optimizer.zero_grad()

    with cuda_timer(f"{label} - forward"):
        outputs = model(x)
        loss = torch.nn.functional.cross_entropy(
            outputs.view(-1, outputs.size(-1)), labels.view(-1)
        )

    fwd_mem_gb = torch.cuda.memory_allocated() / 1e9
    print(f"  {label} - memory after forward: {fwd_mem_gb:.2f} GB")

    with cuda_timer(f"{label} - backward"):
        loss.backward()

    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"  {label} - peak memory: {peak_gb:.2f} GB")

    optimizer.step()

    return {
        "peak_gb": peak_gb,
        "fwd_mem_gb": fwd_mem_gb,
        "loss": loss.item(),
    }


def estimate_memory_savings_llama7b():
    """
    Theoretical memory estimate for LLaMA-7B at different configs.

    LLaMA-7B: 32 layers, d=4096, n_heads=32, d_ff=11008
    """

    d_model = 4096
    n_heads = 32
    d_ff = 11008
    n_layers = 32
    bytes_fp16 = 2

    # sqrt(L) schedule: number of checkpoints is fixed by the layer count
    n_ckpts = max(1, int(math.sqrt(n_layers)))

    print("LLaMA-7B Activation Memory Estimates (FP16)")
    print("=" * 65)

    for seq_len in [512, 1024, 2048, 4096, 8192]:
        for batch_size in [1, 4, 8]:
            # Per-layer activation estimate (approx)
            # QKV projections: 3 * seq * batch * d_model
            qkv_mem = 3 * seq_len * batch_size * d_model * bytes_fp16
            # Attention weights: seq^2 * batch * n_heads
            attn_mem = seq_len**2 * batch_size * n_heads * bytes_fp16
            # Attention output: seq * batch * d_model
            attn_out_mem = seq_len * batch_size * d_model * bytes_fp16
            # FFN intermediate: seq * batch * d_ff
            ffn_mem = seq_len * batch_size * d_ff * bytes_fp16
            # Layer norm: 2 * seq * batch * d_model
            ln_mem = 2 * seq_len * batch_size * d_model * bytes_fp16

            per_layer = (qkv_mem + attn_mem + attn_out_mem + ffn_mem + ln_mem)
            total_no_ckpt = per_layer * n_layers / 1e9
            # With checkpointing: one live segment plus the stored checkpoints.
            # Each checkpoint is charged a full layer's worth here, which
            # overestimates it, so treat the result as a conservative bound.
            total_with_ckpt = (per_layer * (n_layers / n_ckpts + n_ckpts)) / 1e9

            print(f"  seq={seq_len:5d}, batch={batch_size}: "
                  f"no_ckpt={total_no_ckpt:.1f}GB, "
                  f"with_ckpt(sqrt)={total_with_ckpt:.1f}GB, "
                  f"savings={total_no_ckpt/max(total_with_ckpt,0.001):.1f}x")
```

Gradient Checkpointing with FSDP Integration

```python
# fsdp_with_checkpointing.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from functools import partial
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def create_fsdp_model_with_checkpointing(model, local_rank: int):
    """
    FSDP + gradient checkpointing integration.

    Order matters: wrap with FSDP first, then apply checkpointing.
    Applying checkpointing before FSDP changes the module hierarchy
    and breaks the auto-wrap policy.

    Communication note: with FULL_SHARD + checkpointing, the backward pass
    can need TWO all-gather operations per checkpointed layer:
    1. All-gather for the recomputed forward pass (inside checkpoint)
    2. All-gather for the actual backward pass
    Whether the first one is a separate collective depends on how the FSDP
    and checkpoint wrappers nest; in the worst case it is about 33% more
    communication than FSDP without checkpointing.
    It is almost always worth it for the activation memory savings.
    """

    # Step 1: Wrap with FSDP (NO checkpointing yet)
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    )

    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        device_id=local_rank,
    )

    # Step 2: Apply activation checkpointing AFTER FSDP wrapping
    # Use NO_REENTRANT - required for FSDP compatibility.
    # Reentrant checkpointing runs a nested backward that conflicts with
    # FSDP's parameter management during the all-gather in backward.
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=check_fn,
    )

    return fsdp_model


def training_loop_with_ckpt(fsdp_model, dataloader, optimizer, n_steps: int):
    """Complete training loop with FSDP + gradient checkpointing."""

    fsdp_model.train()
    step_times = []
    peak_memories = []

    for step, batch in enumerate(dataloader):
        if step >= n_steps:
            break

        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats()

        input_ids = batch["input_ids"].cuda()
        labels = batch["labels"].cuda()

        t0.record()

        optimizer.zero_grad()

        # FSDP automatically handles all-gather before each layer
        # Checkpointing handles recompute during backward
        outputs = fsdp_model(input_ids=input_ids, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()

        t1.record()
        torch.cuda.synchronize()

        step_time = t0.elapsed_time(t1)
        peak_mem = torch.cuda.max_memory_allocated() / 1e9
        step_times.append(step_time)
        peak_memories.append(peak_mem)

        if dist.get_rank() == 0 and step % 10 == 0:
            print(f"Step {step}: loss={loss.item():.4f}, "
                  f"time={step_time:.0f}ms, "
                  f"peak_mem={peak_mem:.2f}GB")

    return step_times, peak_memories
```

JAX/XLA Rematerialization

```python
# jax_rematerialization.py
# Equivalent concept in JAX (uses different terminology: "rematerialization")
import jax
import jax.numpy as jnp
from jax import checkpoint as jax_checkpoint
from functools import partial


def transformer_layer(params, x):
    """Simple transformer layer in pure JAX."""
    # Layer norm
    normed = jax.nn.standardize(x, axis=-1) * params["ln_scale"] + params["ln_bias"]

    # Attention (simplified: single head, no causal mask, for clarity)
    q = jnp.dot(normed, params["q_weight"])
    k = jnp.dot(normed, params["k_weight"])
    v = jnp.dot(normed, params["v_weight"])

    scale = q.shape[-1] ** -0.5
    attn = jax.nn.softmax(jnp.einsum("bsd,btd->bst", q, k) * scale, axis=-1)
    attn_out = jnp.einsum("bst,btd->bsd", attn, v)
    attn_out = jnp.dot(attn_out, params["o_weight"])

    x = x + attn_out

    # FFN
    normed2 = jax.nn.standardize(x, axis=-1) * params["ln2_scale"] + params["ln2_bias"]
    ffn_out = jax.nn.gelu(jnp.dot(normed2, params["ffn_w1"]))
    ffn_out = jnp.dot(ffn_out, params["ffn_w2"])

    return x + ffn_out


# Apply rematerialization (equivalent to gradient checkpointing)
# This tells XLA: "do not save activations from transformer_layer during forward;
# recompute them during backward"
checkpointed_layer = jax_checkpoint(transformer_layer)


def transformer_stack_with_remat(params_list, x):
    """Stack of transformer layers with rematerialization."""
    for params in params_list:
        # Each layer is rematerialized independently
        x = checkpointed_layer(params, x)
    return x


# Selective rematerialization with custom policy
# jax.checkpoint supports a policy argument to control which intermediates to save
def attention_heavy_remat_policy(prim, *args, **kwargs):
    """
    Custom rematerialization policy (a sketch; JAX also ships ready-made
    policies in jax.checkpoint_policies).
    Return True means 'save this value' (do not recompute).
    Return False means 'recompute this value during backward'.

    Rule of thumb: save values that are small or costly to recompute,
    recompute values that are large but cheap to regenerate.
    """
    if prim.name == "add":
        return True  # Residual adds - trivial to save
    # Note: this matches every matmul (projections, attention scores, FFN),
    # not only the attention weights, so all of them are recomputed.
    if prim.name == "dot_general":
        return False
    return True  # Default: save everything else


checkpointed_layer_selective = partial(
    jax_checkpoint, policy=attention_heavy_remat_policy
)(transformer_layer)


# jax.make_jaxpr prints the traced computation, which makes it possible to
# compare the graph with and without remat
def inspect_remat_graph(params, x):
    """Inspect the JAX computation graph with and without rematerialization."""

    # Without rematerialization
    no_remat_jaxpr = jax.make_jaxpr(
        lambda p, x: jax.value_and_grad(
            lambda p: transformer_layer(p, x).sum()
        )(p)
    )(params, x)

    # With rematerialization
    remat_jaxpr = jax.make_jaxpr(
        lambda p, x: jax.value_and_grad(
            lambda p: checkpointed_layer(p, x).sum()
        )(p)
    )(params, x)

    # Count top-level equations only. Remat regions can appear as a single
    # nested equation, so these counts are a rough signal; print the jaxprs
    # themselves to see the recomputed operations.
    no_remat_ops = len(no_remat_jaxpr.jaxpr.eqns)
    remat_ops = len(remat_jaxpr.jaxpr.eqns)

    print(f"Without rematerialization: {no_remat_ops} top-level operations in graph")
    print(f"With rematerialization: {remat_ops} top-level operations in graph")
    print(f"Difference in top-level ops: {remat_ops - no_remat_ops}")
```

Production Engineering Notes

The use_reentrant=False Requirement for FSDP

The use_reentrant parameter in torch.utils.checkpoint.checkpoint() controls how the recomputation is triggered during backward. The older reentrant implementation runs a nested (reentrant) backward pass through the autograd engine for each checkpointed region. That nested backward interacts poorly with the hooks FSDP uses to all-gather and reshard parameters during backward, and it can cause hangs or errors in FSDP + checkpointing configurations.

Always pass use_reentrant=False explicitly when combining checkpointing with FSDP. The non-reentrant variant is built on saved-tensor hooks rather than a nested backward call and is compatible with FSDP. Recent PyTorch releases recommend use_reentrant=False and warn when the argument is omitted, so do not rely on the default.

Checkpoint Granularity and Peak Memory

The peak memory during gradient checkpointing depends on which segment is being recomputed. If you checkpoint every $k$ layers of an $N$-layer model, and measure memory in units of one layer's activations, during backward the system holds:

  • All stored checkpoints: roughly $N/k$ units
  • The recomputed activations for the current $k$-layer segment: $k$ units

Peak memory during backward is approximately $N/k + k$ units (the checkpoints plus one segment), compared with $N$ units without checkpointing. Setting $k = \sqrt{N}$ minimizes this to $2\sqrt{N}$ units, confirming the $O(\sqrt{L})$ bound.

In practice, for a 32-layer model, setting $k = 1$ (checkpoint every layer) is simple to implement and gives roughly 10x activation memory savings with 33% compute overhead. This is better than the formula suggests because a checkpoint is only the layer's input tensor, which is far smaller than the full set of intermediates inside a layer, so the $N/k$ term is cheap in real models. The difference between $k=1$ and $k=\sqrt{32} \approx 6$ in total memory is small compared to the simplicity benefit of $k=1$. Use $k=1$ for most production systems.
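A small sweep over the interval shows why checkpoint-every-layer does so well once checkpoints are cheaper than full layers. This is a sketch under stated assumptions: a layer's intermediates are taken as 12 units and a checkpoint as 1 unit (the ratio implied by the $12 \cdot s \cdot b \cdot d$ estimate earlier in this article versus a single $s \cdot b \cdot d$ boundary tensor), and `peak_units` is an illustrative name.

```python
import math


def peak_units(n_layers, k, layer_units=12, ckpt_units=1):
    """Peak activation memory: all checkpoints plus one recomputed k-layer segment."""
    n_ckpts = math.ceil(n_layers / k)
    return n_ckpts * ckpt_units + k * layer_units


N = 32
baseline = N * 12  # every layer's intermediates kept alive: 384 units
for k in (1, 2, 4, 6, 8):
    p = peak_units(N, k)
    print(f"k={k}: peak={p} units, savings={baseline / p:.1f}x")
# k=1: peak=44 units, savings=8.7x
# k=2: peak=40 units, savings=9.6x
# k=4: peak=56 units, savings=6.9x
# k=6: peak=78 units, savings=4.9x
# k=8: peak=100 units, savings=3.8x
```

With equal-sized checkpoints and layers (`layer_units=1, ckpt_units=1`) the same function reproduces the $2\sqrt{N}$ optimum near $k \approx 6$; with realistic size ratios the optimum moves toward $k = 1$ or $k = 2$.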

Detecting Checkpointing Overhead in Profiles

```python
# profile_with_torch_profiler.py
import torch
from torch.profiler import profile, record_function, ProfilerActivity


def _device_time(event):
    """Device time of a profiler event; the attribute name differs by PyTorch version."""
    return getattr(event, "device_time_total", None) or getattr(event, "cuda_time_total", 0)


def profile_checkpointing_overhead(model, x, n_warmup=2, n_profile=5):
    """
    Profile a model to measure checkpointing recompute overhead.
    Look for doubled forward operations in the backward phase.
    """

    # Warmup
    for _ in range(n_warmup):
        out = model(x)
        out.sum().backward()

    # Profile
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
    ) as prof:
        for step in range(n_profile):
            with record_function("forward"):
                out = model(x)
                loss = out.sum()

            with record_function("backward"):
                loss.backward()

            model.zero_grad()

    # Print top operations - look for matmul appearing in both forward and backward
    print(prof.key_averages().table(
        sort_by="cuda_time_total",
        row_limit=20,
    ))

    # Export for visualization
    prof.export_chrome_trace("checkpointing_trace.json")

    # Key metric: ratio of backward device time to forward device time, read
    # from the two record_function ranges. Keys are matched exactly so that
    # autograd nodes whose names contain "Backward" are not double counted.
    # Kernel attribution to these ranges can vary by PyTorch version, so
    # cross-check against the exported trace.
    # Without checkpointing: ~2x (backward is ~2x forward)
    # With full checkpointing: ~3x (backward is ~2x forward + ~1x recompute)
    events = prof.key_averages()
    fwd_time = sum(_device_time(e) for e in events if e.key == "forward")
    bwd_time = sum(_device_time(e) for e in events if e.key == "backward")
    if fwd_time > 0:
        print(f"\nBackward/Forward time ratio: {bwd_time/fwd_time:.2f}x")
        print("Expected without checkpointing: ~2.0x")
        print("Expected with full checkpointing: ~3.0x (one extra forward, 33% more total compute)")
```

Mixed Precision Interaction with Checkpointing

Gradient checkpointing recomputes the forward pass during backward. If your model uses mixed precision (BF16 forward, FP32 accumulation), the recomputation also happens in BF16. This means the recomputed activations may differ slightly from the original activations due to floating-point non-determinism. In practice, this difference is negligible for training stability. However, for architectures that depend on exact activation values (certain normalization schemes), be aware of this.

```python
# Ensure consistent dtype during recompute
import torch
from torch.utils.checkpoint import checkpoint


def stable_checkpoint_layer(layer_fn, *args):
    """
    Wrapper that ensures recomputed activations are in the same dtype
    as the original computation.
    """
    # Use the dtype of the first positional tensor argument as the reference
    input_dtype = args[0].dtype if args and isinstance(args[0], torch.Tensor) else None

    def wrapped_fn(*a):
        # Ensure same dtype in recompute as in original forward
        if input_dtype is not None:
            a = tuple(
                t.to(input_dtype) if isinstance(t, torch.Tensor) else t
                for t in a
            )
        return layer_fn(*a)

    return checkpoint(wrapped_fn, *args, use_reentrant=False)
```

Checkpointing with Gradient Accumulation

Gradient accumulation (running multiple micro-batches before stepping the optimizer) interacts correctly with gradient checkpointing by default in PyTorch. Each micro-batch's backward pass triggers its own recompute from the checkpoints. The gradients accumulate in .grad tensors normally. No special handling is needed.

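With FSDP, use the no_sync() context manager during accumulation steps to skip the gradient reduce-scatter until the final micro-batch. Checkpointing still works correctly inside no_sync() because the recompute involves no gradient communication. Note that under no_sync() FSDP accumulates full, unsharded gradients, so peak memory is higher than when gradients are synchronized on every micro-batch.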

# Gradient accumulation with FSDP + checkpointing
def train_with_accumulation(fsdp_model, dataloader, optimizer, accumulation_steps=4):
    """Correct gradient accumulation with FSDP and checkpointing."""

    optimizer.zero_grad()

    for step, batch in enumerate(dataloader):
        is_last_accumulation = (step + 1) % accumulation_steps == 0

        if not is_last_accumulation:
            # No gradient sync for accumulation steps
            with fsdp_model.no_sync():
                outputs = fsdp_model(**batch)
                # Normalize loss by accumulation steps
                loss = outputs.loss / accumulation_steps
                loss.backward()
                # Checkpointing recompute happens here, no sync
        else:
            # Final accumulation step - sync gradients across FSDP ranks
            outputs = fsdp_model(**batch)
            loss = outputs.loss / accumulation_steps
            loss.backward()
            # Gradient sync happens here (reduce-scatter in FULL_SHARD mode)

            # Use FSDP's own clip_grad_norm_ so the norm covers all shards,
            # not just the parameters held locally on this rank
            fsdp_model.clip_grad_norm_(1.0)
            optimizer.step()
            optimizer.zero_grad()

:::tip Choosing Checkpointing Granularity in Practice
For 99% of production transformer training, use checkpoint-every-layer (k=1). The implementation is one line of code, memory savings are 8-12x for activation memory, and the 33% compute overhead is usually dominated by communication costs in distributed training anyway. Only tune the interval $k$ if you have profiled and confirmed that checkpointing compute - not communication or data loading - is your bottleneck.
:::

:::note Flash Attention and Checkpointing Interaction
Flash Attention (Dao et al., 2022) is designed to be memory-efficient without checkpointing: it recomputes the attention softmax during backward instead of storing the attention weights. If you are using Flash Attention (via torch.nn.functional.scaled_dot_product_attention in PyTorch, or attn_implementation="flash_attention_2" in Hugging Face Transformers), you get the attention memory savings for free. In this case, you should still checkpoint the FFN layers but may be able to skip checkpointing the attention sublayer, reducing compute overhead while keeping the key memory savings.
:::

Common Mistakes

:::danger Combining Reentrant Checkpointing with FSDP
Using torch.utils.checkpoint.checkpoint(fn, *args) with the reentrant implementation (use_reentrant=True, or the legacy behavior when the argument is omitted) can cause hangs or incorrect gradients when combined with FSDP. The reentrant implementation launches a nested backward pass, which conflicts with FSDP's parameter materialization during the all-gather in the backward pass. Always use checkpoint(fn, *args, use_reentrant=False) or apply_activation_checkpointing with CheckpointImpl.NO_REENTRANT when using FSDP. The typical symptom is a training job that hangs with no error message after completing the first backward pass.
:::

:::danger Checkpointing After FSDP Wrapping vs Before
The order of FSDP wrapping and checkpointing application matters. If you apply apply_activation_checkpointing BEFORE wrapping with FSDP, the auto-wrap policy sees CheckpointWrapper modules instead of LlamaDecoderLayer modules and fails to match, wrapping the entire model as a single shard unit. Always apply FSDP wrapping first, then apply activation checkpointing. The symptom of wrong ordering is FSDP using as much memory as DDP (no sharding happening), combined with checkpointing overhead but none of the memory savings.
:::

:::warning Checkpointing Modules with Internal State
Do not checkpoint modules that maintain internal state modified during the forward pass - such as custom normalization layers that update running statistics, or modules with caches that change during forward. When the forward is recomputed during backward, the state-modifying side effect runs again, corrupting the statistics. Batch norm is the classic example: if you checkpoint a batch norm layer, the running mean and variance get updated twice (once in the original forward, once in the recomputed forward during backward). Use model.eval() mode or freeze running stats if you must checkpoint such layers.
:::
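
The double update described in the warning above can be illustrated with plain arithmetic. The sketch below is not batch norm itself: `ema_update`, the momentum of 0.1, and the batch statistic are hypothetical stand-ins for the exponential-moving-average rule that running statistics typically follow.

```python
def ema_update(running, batch_stat, momentum=0.1):
    """One running-statistic update (exponential moving average)."""
    return (1.0 - momentum) * running + momentum * batch_stat


running_mean = 0.0
batch_mean = 1.0  # hypothetical statistic of the current batch

# Without checkpointing: the forward pass runs once per step.
once = ema_update(running_mean, batch_mean)

# With checkpointing: the original forward and the recompute both update.
twice = ema_update(ema_update(running_mean, batch_mean), batch_mean)

print(f"one update:  {once:.2f}")
print(f"two updates: {twice:.2f}")
```

Under these assumptions, two updates per step behave like a single update with momentum 1 - 0.9² = 0.19 instead of 0.1, so the running statistics track recent batches almost twice as aggressively as intended.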

:::warning Tensor Arguments vs Non-Tensor Arguments in checkpoint()
torch.utils.checkpoint.checkpoint(fn, *args) only saves and tracks gradients for tensor arguments. Non-tensor arguments (integers, strings, booleans, Python objects) are passed through by reference, and anything fn reads from an enclosing scope or a module attribute is read again at recompute time. This means that if such a value changes between forward and backward (e.g., a training vs inference mode flag, or a dropout probability stored on the module), the recomputed forward will use the current value, not the value at forward time. Keep these values fixed until backward has finished. Randomness is handled separately: with the default preserve_rng_state=True, checkpoint saves the RNG state at forward time and restores it for the recompute, so dropout masks match without manual seeding.
:::
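
To make the randomness point concrete, here is a minimal sketch of the save-and-restore idea, using Python's standard `random` module rather than PyTorch. It is an analogy for what the `preserve_rng_state` option of `torch.utils.checkpoint.checkpoint` does, not its implementation; `dropout_mask` is a hypothetical helper.

```python
import random


def dropout_mask(n, p, rng):
    """Sample a keep/drop mask: True means the element is kept."""
    return [rng.random() >= p for _ in range(n)]


rng = random.Random(0)

# Original forward: stash the RNG state, then sample the mask.
saved_state = rng.getstate()
mask_forward = dropout_mask(8, 0.5, rng)

# Recompute WITHOUT restoring: the generator has advanced, so the mask is
# drawn from a later point in the random stream and will generally differ.
mask_unrestored = dropout_mask(8, 0.5, rng)

# Recompute WITH the stashed state restored: the mask is reproduced exactly.
rng.setstate(saved_state)
mask_restored = dropout_mask(8, 0.5, rng)

print(mask_restored == mask_forward)
```

If the recomputed mask did not match the original, the backward pass would differentiate a different function from the one that produced the loss.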

:::warning Memory Spikes During the Last Backward Segment
With gradient checkpointing at interval $k$, the backward pass has a predictable memory spike structure: each time a segment is recomputed, both the stored checkpoints AND that segment's recomputed activations are in memory simultaneously. The worst case is the last segment of the network (layers $L - k$ through $L$), which backward processes first: all $L/k$ checkpoints are still live, plus the activations for $k$ layers. If you set $k$ too large (e.g., checkpoint only the first and last layer), this spike can exceed your GPU memory even though the average memory usage is low. Profile peak memory in the backward pass separately from average memory, and verify that the spike does not cause OOM.
:::
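
The spike can be estimated before launching a run. The following is a rough sketch under a deliberately simplified cost model: every stored checkpoint and every recomputed layer is counted as one "layer-activation" unit. The function name and the unit are assumptions for illustration, not measurements.

```python
import math


def peak_activation_units(num_layers, k):
    """Peak backward memory in layer-activation units under a simplified
    model: one stored checkpoint per segment plus k recomputed layers."""
    num_checkpoints = math.ceil(num_layers / k)
    return num_checkpoints + k


L = 32
for k in (1, 2, 4, 6, 8, 16, 32):
    print(f"k={k:>2}: peak ~ {peak_activation_units(L, k)} units")

# Smallest interval that minimizes the modeled peak for this depth.
best_k = min(range(1, L + 1), key=lambda k: peak_activation_units(L, k))
print("best k:", best_k)
```

In this model the peak is flat near $k \approx \sqrt{L}$ and climbs back toward the no-checkpointing level as $k$ approaches $L$. The model is pessimistic for k=1: in a real transformer the stored layer input is much smaller than the layer's internal activations, so treat the numbers as a shape, not a forecast, and confirm with a profiler.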

Interview Q&A

Q1: Explain gradient checkpointing intuitively. What memory-compute trade-off does it make, and what is the optimal checkpointing interval for a model with $L$ layers?

A: Standard backpropagation stores the output of every layer during the forward pass because the backward pass needs them to compute gradients. For a model with $L$ layers, this requires $O(L)$ activation memory.

Gradient checkpointing divides the network into segments and stores only the activations at segment boundaries - the "checkpoints." During backward, when gradients for a layer inside a segment are needed, the system recomputes the forward pass from the nearest preceding checkpoint up to that layer. Activations inside a segment are discarded after the forward pass and exist only briefly while that segment is being recomputed.

The optimal segment size balances two memory costs: storing checkpoints (proportional to $L/k$, where $k$ is the segment size) and holding the activations of the one segment currently being recomputed (proportional to $k$). Minimizing $L/k + k$ over $k$ gives $k = \sqrt{L}$, resulting in $O(\sqrt{L})$ memory for activations.

The compute overhead: each of the $L/k = \sqrt{L}$ segments is recomputed once during backward. Total extra forward compute: $\sqrt{L} \times \sqrt{L} = L$ layer evaluations - exactly one extra forward pass. For modern transformers, this is approximately 33% total overhead (standard training is roughly $3\times$ forward in total compute: $1$ forward $+\ 2$ backward).
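
The two results above follow from a short calculation. As a sketch, write $L$ for the number of layers, $k$ for the segment size, $M(k)$ for peak activation memory in per-layer units, and $F$ for the cost of one forward pass, assuming (as the answer does) that backward costs about $2F$:

```latex
\begin{aligned}
M(k) &= \frac{L}{k} + k \\
\frac{dM}{dk} &= -\frac{L}{k^{2}} + 1 = 0
  \quad\Longrightarrow\quad k^{*} = \sqrt{L} \\
M(k^{*}) &= \frac{L}{\sqrt{L}} + \sqrt{L} = 2\sqrt{L} \\
\text{overhead} &= \frac{(F + 2F) + F}{F + 2F} - 1 = \frac{1}{3} \approx 33\%
\end{aligned}
```

The second derivative $2L/k^{3}$ is positive for $k > 0$, so $k^{*}$ is a minimum. The overhead figure does not depend on $k$: whenever every segment is checkpointed, the recompute adds up to one full forward pass.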

Q2: How does Flash Attention change the calculus for attention checkpointing?

A: Flash Attention (Dao et al., 2022) is an IO-aware attention algorithm that avoids materializing the full $O(s^2)$ attention weight matrix in GPU HBM by performing the attention computation in tiles that fit in SRAM. The key property for checkpointing is that Flash Attention's backward pass recomputes the softmax attention weights from Q, K, V rather than storing them. This is built-in rematerialization for the attention mechanism.

Concretely, standard attention stores the $s^2 \times b \times h$ attention weight matrix for every layer (about 2.1 GB in FP16 for seq=2048, batch=8, 32 heads, which is 268 MB per sample). Flash Attention does not store this matrix at all - it only stores a small $O(s)$ log-sum-exp normalization factor per head and recomputes the softmax tile-by-tile during backward.

For models using Flash Attention, you get attention activation savings for free. You should still checkpoint the FFN layers (their intermediate activations are significant: $s \times b \times 4d$ elements in FP16). The typical setup with Flash Attention is: checkpoint FFN sublayers, do not checkpoint attention sublayers. This reduces compute overhead from 33% to approximately 18-20% while achieving most of the memory savings.
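
The sizes involved are easy to check by hand. The helper names below are hypothetical, and storing the log-sum-exp statistic in FP32 is an assumption made for the estimate:

```python
def attention_weights_bytes(seq_len, batch, heads, bytes_per_elem=2):
    """Bytes to store the full s x s attention weights for one layer (FP16)."""
    return seq_len * seq_len * batch * heads * bytes_per_elem


def logsumexp_bytes(seq_len, batch, heads, bytes_per_elem=4):
    """Bytes for one log-sum-exp value per query row per head (assumed FP32)."""
    return seq_len * batch * heads * bytes_per_elem


per_sample = attention_weights_bytes(2048, 1, 32)
full_batch = attention_weights_bytes(2048, 8, 32)
stats = logsumexp_bytes(2048, 8, 32)

print(f"attention weights, one sample: {per_sample / 1e6:.0f} MB")
print(f"attention weights, batch of 8: {full_batch / 1e9:.2f} GB")
print(f"log-sum-exp statistics:        {stats / 1e6:.1f} MB")
```

Under these assumptions the stored statistics are about a thousand times smaller than the attention weights they replace, and the gap widens linearly with sequence length.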

Q3: What is the difference between gradient checkpointing and reversible networks? When would you choose one over the other?

A: Reversible networks (Gomez et al., 2017) use invertible architectures where activations can be exactly reconstructed from later activations. The RevNet idea: maintain two streams $x_1, x_2$ and update them as $y_1 = x_1 + F(x_2)$, $y_2 = x_2 + G(y_1)$. This is invertible: given $y_1, y_2$ you can recover $x_2 = y_2 - G(y_1)$ and then $x_1 = y_1 - F(x_2)$ without storing them. No stored activations means $O(1)$ activation memory (just the current layer's tensors).
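
The coupling can be checked numerically in a few lines. In this sketch, `F` and `G` are arbitrary hypothetical functions on scalars, standing in for the sublayers of a real reversible block:

```python
import math


def F(x):
    return math.tanh(x)        # hypothetical residual function


def G(x):
    return 0.5 * math.sin(x)   # hypothetical residual function


def rev_forward(x1, x2):
    """Forward coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2


def rev_inverse(y1, y2):
    """Recover the inputs from the outputs, undoing the steps in reverse."""
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2


x1, x2 = 0.3, -1.2
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)
print(math.isclose(r1, x1, abs_tol=1e-12), math.isclose(r2, x2, abs_tol=1e-12))
```

Note that the inverse only evaluates `F` and `G`; it never inverts them. The reconstruction is exact up to floating-point rounding, which is why the check uses a tolerance.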

Gradient checkpointing is architecture-agnostic and works with any network. Reversible networks require the two-stream coupling structure described above. The functions $F$ and $G$ do not themselves need to be invertible (the Reformer uses attention and feed-forward sublayers in these roles), but a model built on standard single-stream residual connections, such as LLaMA or GPT, cannot be made reversible without changing the architecture and retraining.

Choose reversible networks when you are designing a new architecture from scratch, need maximum memory efficiency, and are willing to constrain the architecture. They give better memory properties than checkpointing but worse than some modern alternatives like Flash Attention.

Choose gradient checkpointing when working with an existing architecture (LLaMA, GPT, BERT) that was not designed to be reversible. It is the universal, architecture-agnostic solution with controllable overhead.

In practice, gradient checkpointing combined with Flash Attention is the standard for transformer training. Reversible networks are primarily used in research contexts with specialized architectures.

Q4: A training job runs correctly for 100 steps, then OOMs during the backward pass at step 101. Gradient checkpointing is enabled. What could cause this specific failure pattern?

A: OOM during backward (not forward) with gradient checkpointing is a specific failure mode. Several causes:

First, growing sequence length or batch size across steps. Some dynamic batching schemes increase the batch size over the first 100 steps as part of a ramp-up schedule. If step 101 has a longer sequence or larger batch, the activations recomputed during backward exceed memory.

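Second, a memory leak across steps. Forgetting optimizer.zero_grad() does not by itself grow memory, because .grad tensors have a fixed size; it corrupts the update instead. What does grow memory is holding references to per-step tensors, for example appending undetached outputs to a list for logging, or accumulating total_loss += loss instead of total_loss += loss.item(). The retained tensors and their autograd history pile up until an allocation during backward fails.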

Third, recompute buffer fragmentation. Fragmentation in PyTorch's CUDA caching allocator can cause OOM even when there is technically enough free memory. The recompute in backward allocates and frees many tensors of varying sizes, and the allocator may fail to find a contiguous block. Fix: enable the allocator's expandable segments option (PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True), or call torch.cuda.empty_cache() periodically as a stopgap.

Fourth, custom checkpointed function with state side effects. If the checkpointed function modifies state during forward (batch norm stats, KV cache), the recompute runs the modification again, potentially growing a cache unboundedly. Check for any state-modifying operations inside checkpointed regions.
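
For the fragmentation case, the allocator option is read from the environment, so it has to be in place before the first CUDA allocation. One way to do that from Python, sketched here on the assumption that this code runs before `torch` is imported (exporting the variable in the launch shell works equally well):

```python
import os

# Set the allocator option only if the launcher has not already set one.
# This must run before torch initializes CUDA to take effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

alloc_conf = os.environ["PYTORCH_CUDA_ALLOC_CONF"]
print("PYTORCH_CUDA_ALLOC_CONF =", alloc_conf)
```

Using `setdefault` keeps any value the job scheduler or launch script already provided, which avoids silently overriding cluster-wide settings.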

Q5: How does gradient checkpointing interact with ZeRO-3 or FSDP FULL_SHARD in terms of communication overhead?

A: The interaction is additive but well-defined. In FSDP FULL_SHARD without checkpointing, each layer's parameters are all-gathered once in forward and once in backward, and its gradients are reduce-scattered once. With checkpointing, the backward pass can need TWO all-gathers per checkpointed layer, depending on how the checkpoint and FSDP wrappers are nested: one to materialize parameters for the recomputed forward pass, and one for the actual backward gradient computation.

Counting each all-gather or reduce-scatter as one pass over the parameters, FSDP FULL_SHARD alone costs about 3 passes per step and checkpointing raises that to about 4, so roughly 1.33x the communication volume of FSDP FULL_SHARD alone. Vanilla DDP's all-reduce costs about 2 passes (it is equivalent to a reduce-scatter followed by an all-gather), so the combination is approximately 2x the communication of DDP.

Why is this acceptable? Because the whole reason you are using FSDP FULL_SHARD in the first place is that the model does not fit in GPU memory without sharding. Similarly, checkpointing is used when activations do not fit without recomputation. The models that require both are typically in the 30-70B+ parameter range. At that scale, training throughput is usually limited by compute utilization (GPU MFU is the bottleneck), not by communication bandwidth - especially on NVLink nodes. The extra communication is hidden behind the compute in well-tuned setups.

For cross-node training (InfiniBand), the communication may become a bottleneck. In that case, the standard optimization is to increase micro-batch size to amortize the per-step communication overhead, and to use pipeline parallelism across nodes (with FSDP within nodes).

Q6: Describe the sqrt(L) checkpointing strategy and implement it for a model with L transformer layers.

A: The $\sqrt{L}$ strategy places checkpoints every $k = \lfloor\sqrt{L}\rfloor$ layers (rounding up works equally well). For $L = 32$ layers, $\sqrt{32} \approx 5.66$, so $k = 5$ or $6$; with $k = 6$ the checkpoints sit at layers 0, 6, 12, 18, 24, 30. During forward, only these checkpoint layers store their activations. During backward, to get gradients at any layer $i$, we recompute from the nearest preceding checkpoint.

Total activation memory: approximately $\sqrt{L}$ checkpoints plus one segment being recomputed at any time $= 2\sqrt{L}$ times the per-layer activation size. For $L = 32$, this is $2\sqrt{32} \approx 11$ layers worth of activations instead of 32.

import math

from torch.utils.checkpoint import checkpoint


def forward_with_sqrt_checkpointing(layers, x):
    """Run `layers` in segments of about sqrt(L) layers, checkpointing each segment."""
    L = len(layers)
    k = max(1, int(math.sqrt(L)))  # checkpoint interval (segment length)

    for start in range(0, L, k):
        # checkpoint() saves only the segment input. Activations inside the
        # segment are dropped after forward and recomputed during backward.
        # Do not detach the input: that would stop gradients from flowing
        # back into earlier segments.
        x = checkpoint(_run_segment, layers, x, start, k, use_reentrant=False)

    return x


def _run_segment(layers, x, start, k):
    """Run (or recompute) the k layers that begin at index `start`."""
    for i in range(start, min(start + k, len(layers))):
        x = layers[i](x)
    return x

In practice, you implement this by wrapping each group of $k$ consecutive layers in torch.utils.checkpoint.checkpoint() rather than the per-layer wrapping more commonly seen. The per-layer variant ($k=1$) is simpler and only slightly worse in memory; use $k > 1$ only when you have profiled that checkpointing compute overhead is the bottleneck.

Q7: Explain the "reentrant vs non-reentrant" checkpointing implementations in PyTorch. When does the distinction matter?

A: PyTorch's torch.utils.checkpoint.checkpoint() has two internal implementations controlled by the use_reentrant argument.

The reentrant implementation (the original one) is a custom torch.autograd.Function. Its forward runs the checkpointed function under torch.no_grad(), so no graph is built inside the region and only the inputs are saved. Its backward re-runs the function with gradients enabled and then calls torch.autograd.backward() on the recomputed graph. That nested call re-enters the autograd engine from inside a backward pass, which is where the name "reentrant" comes from. Because the outer graph sees only an opaque Function, this variant requires at least one input with requires_grad=True.

The non-reentrant implementation (use_reentrant=False, the variant PyTorch recommends; current releases ask you to pass the argument explicitly rather than rely on a default) is built on saved-tensor hooks. The forward runs normally and records the autograd graph, but tensors that would be saved for backward are replaced by lightweight placeholders. When backward needs one of them, the hook re-runs the function to recompute it. No nested backward call is made, so the checkpointed region behaves like ordinary autograd.

The distinction matters in three cases:

  1. FSDP: FSDP drives its backward all-gathers and reduce-scatters from autograd hooks. The nested backward call made by reentrant checkpointing can run those hooks in an order FSDP does not expect, which can lead to hangs or wrong gradients. Non-reentrant avoids the nested call entirely.

  2. DataParallel: Similar issue with DP's backward synchronization.

  3. Models with kwargs (keyword arguments): The reentrant implementation cannot handle keyword arguments in the checkpointed function. checkpoint(fn, arg1, arg2) works, but checkpoint(fn, arg1, key=arg2) does not. Non-reentrant supports both.

For all new code in 2024+, use use_reentrant=False. The only reason to use reentrant is compatibility with pre-2.0 PyTorch versions or legacy code that relies on the specific gradient accumulation behavior of the reentrant mechanism.

© 2026 EngineersOfAI. All rights reserved.