
Fault Tolerance in Large Cluster Training

The 3 AM Wake-Up Call

It is week four of a six-week training run. GPT-scale, 1024 H100s, $2.4 million in compute budget allocated. Your loss curve is clean, you are on track, and you finally slept more than five hours. Then your phone lights up at 3:14 AM.

A single NVLink switch on node 47 has started throwing uncorrectable ECC errors. NCCL is hanging. The collective operation that synchronizes gradients across all 128 nodes started 47 minutes ago and never finished. Every GPU in the cluster is frozen, waiting for node 47 to respond. Forty-seven minutes of compute time across 1024 GPUs - roughly $8,000 in wasted GPU hours - just evaporated. And it will keep evaporating until someone manually kills the job.

You stare at the monitoring dashboard. The last checkpoint saved 6 hours ago, so 6 hours of cluster-wide training must be redone, and replaying it costs another 6 hours of wall-clock time on top of the restart itself. The training run might still finish on schedule if you act fast. But if node 47's issue is hardware rather than transient, you need to replace it, restart from checkpoint, and rebalance the cluster topology for the remaining nodes. All of this needs to happen before the morning standup.

This is not a hypothetical. Every team that has trained a large model at scale has a story like this. OpenAI's GPT-3 training logs mention multiple hardware failures. Meta's OPT-175B training diary, published openly, documents dozens of hardware failures, NaN gradients, and forced restarts over its 33-day training run. Google's PaLM training involved thousands of TPU chips over weeks and required careful fault handling to complete.

The math is brutal and unavoidable. With 1000 GPUs running continuously, if each GPU has a 1% chance of failing on any given day (a deliberately pessimistic figure; reported annual failure rates for server-grade GPUs run 2-5%), you should expect to lose a GPU roughly every 2.4 hours. At 10,000 GPUs, that becomes one failure roughly every 15 minutes. Training a frontier model is no longer a software problem with occasional hardware hiccups. It is a systems reliability problem where failure is the default state and successful training is what you engineer for.

This lesson is about building the infrastructure to survive that reality: checkpointing strategies, failure detection, elastic training, and recovery automation that turns a catastrophic 3 AM event into a 15-minute automated recovery.


Why This Exists - The Single-Machine Assumption Breaks at Scale

Early deep learning frameworks were built around a single machine. You had one GPU, then four, then eight. If something failed, you restarted. The training run was short - hours or days - and losing progress was painful but survivable.

That assumption does not survive contact with modern scale. Consider what changes:

Training duration explodes. GPT-2 trained in days. GPT-3 required weeks. Llama-2 70B took roughly 1.7 million GPU-hours. At 2048 A100s, that is 35 days of continuous training. You cannot afford to restart from scratch after a hardware failure on day 28.

Cluster size multiplies failure probability. A single A100 might run for years without failure. But one A100 in a cluster of 4096 is a different statistical reality. With any reasonable hardware failure rate, failures become near-certain over multi-week runs.

All-reduce synchronization creates a single point of failure. Classic data-parallel training requires all GPUs to synchronize gradients every step via all-reduce. If one GPU hangs, all-reduce never completes. The entire cluster stalls. One failure in 4096 stops 4095 healthy GPUs.

The cost of lost progress is enormous. At frontier scale an hour of cluster time can cost $10,000, so every hour between checkpoints puts that much compute at risk, and the checkpoint-frequency decision looks nothing like it does for a laptop experiment.

Before fault tolerance infrastructure existed in any serious form, the approach was brute-force: checkpoint often (every 100-500 steps), monitor manually, restart from the last checkpoint when something breaks. This worked when clusters were small and training runs were short. It does not work when you are running 10,000 GPUs for six weeks.

The solution has three components: (1) smart checkpointing that minimizes both the performance overhead and the recovery cost, (2) failure detection that catches problems quickly instead of letting them stall the cluster for hours, and (3) elastic training that lets you reshape the job around failed nodes without restarting.


Historical Context - How the Industry Learned to Fail Gracefully

The early history of fault tolerance in distributed ML is largely a history of researchers adapting ideas from distributed systems and high-performance computing to the specific needs of neural network training.

Checkpoint-restart (the simplest form) was standard in HPC batch computing from the 1980s. You save state, and if the job dies, you restart from the saved state. The neural network training community borrowed this directly.

Elasticity as a concept came from cloud computing. Jeff Dean and Sanjay Ghemawat's MapReduce paper (2004) introduced the idea of stragglers and backup tasks - if a worker is slow, launch a duplicate and use whoever finishes first. This straggler mitigation idea eventually found its way into ML training frameworks.

TorchElastic was born at Facebook AI Research around 2019-2020. The core insight was simple but important: PyTorch's torch.distributed assumed a fixed world size. If you lost a node, the job died. TorchElastic allowed the world size to change dynamically - you could add or remove workers while training continued. It was integrated into PyTorch core as torch.distributed.elastic in PyTorch 1.9 (2021).

The NCCL watchdog emerged from painful operational experience. Early large-scale training runs would hang silently when a collective communication operation stalled. Engineers would notice the job was drawing GPU power (every GPU spinning on the hung operation) but producing no output. The watchdog, a background thread in PyTorch's NCCL process group rather than a feature of NCCL itself, puts a timeout on every collective: if an operation does not complete in time, it raises an error and triggers recovery instead of hanging indefinitely.

Async checkpointing became necessary as model sizes grew. A 70B parameter model in bf16 takes ~140GB of storage. Saving that synchronously blocks training for 30-60 seconds every checkpoint. At a checkpoint every 500 steps with a 1-second step time, you lose 6-12% of training time just to checkpointing. Async checkpointing moves the I/O to a background thread, overlapping it with the next training steps.

Microsoft's DeepSpeed and NVIDIA's Megatron both developed their own checkpoint formats and recovery mechanisms. The Megatron team in particular published detailed accounts of failure handling during GPT-3-scale experiments, establishing operational patterns that became industry standard.

The "aha moment" for the field came from the OPT-175B training diary published by Meta in 2022. Susan Zhang and the team documented every hardware failure, software crash, and manual intervention over the 33-day run. It made viscerally clear that fault tolerance was not a nice-to-have - it was the difference between finishing a training run and burning your compute budget on repeated restarts.


Core Concepts - The Math of Failure at Scale

Failure Rate Arithmetic

Let $N$ be the number of GPUs in your cluster, $\lambda$ be the daily failure rate per GPU (as a probability), and $T$ be the expected time in days until the first failure.

Assuming GPUs fail independently, the cluster-level daily failure probability is:

$$P(\text{at least one failure today}) = 1 - (1 - \lambda)^N \approx N\lambda \quad \text{for small } N\lambda$$

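Even when $N\lambda$ is not small, as long as the per-GPU rate $\lambda$ is, the cluster behaves like $N$ independent failure sources at about $\lambda$ per day each, for a combined rate of about $N\lambda$ failures per day. The expected time to first failure (in days) is therefore approximately:

$$T \approx \frac{1}{N\lambda}$$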

With $N = 1000$ and $\lambda = 0.01$ (1% daily failure rate per GPU):

$$T \approx \frac{1}{1000 \times 0.01} = \frac{1}{10} \text{ days} = 2.4 \text{ hours}$$

With $N = 4096$ and $\lambda = 0.005$ (0.5% daily rate, modern server-grade hardware):

$$T \approx \frac{1}{4096 \times 0.005} \approx 0.049 \text{ days} \approx 70 \text{ minutes}$$

This is why fault tolerance is not optional at frontier scale. You will see a hardware event roughly every hour to every few hours. The question is not whether you handle it, but how fast and how automatically.
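The two estimates above are easy to check numerically. The following is a minimal sketch that assumes independent failures with a constant per-GPU daily probability; the function names are illustrative, not from any library. It also shows why the exact form $1 - (1-\lambda)^N$ matters once $N\lambda$ is no longer small: the linear approximation $N\lambda$ exceeds 1, while the mean time to first failure $1/(N\lambda)$ remains a good estimate.

```python
def p_any_failure(n_gpus: int, daily_p: float) -> float:
    """Exact probability that at least one of n_gpus fails in a day,
    assuming independent failures."""
    return 1.0 - (1.0 - daily_p) ** n_gpus


def mean_hours_to_first_failure(n_gpus: int, daily_p: float) -> float:
    """T ~ 1 / (N * lambda) days, converted to hours.
    Valid for a small per-GPU daily_p, even when N * daily_p is large."""
    return 24.0 / (n_gpus * daily_p)


for n, lam in [(1000, 0.01), (4096, 0.005), (10_000, 0.01)]:
    print(
        f"N={n:>6}, lambda={lam}: "
        f"N*lambda={n * lam:.2f}, "
        f"exact P(failure today)={p_any_failure(n, lam):.5f}, "
        f"T={mean_hours_to_first_failure(n, lam) * 60:.0f} min"
    )
```

The three rows come out to roughly 144, 70, and 14 minutes. The exact daily probability is essentially 1 in every case, even though $N\lambda$ is 10 or more and so cannot be read as a probability.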

The Checkpoint Recovery Cost Model

Define:

  • CC = checkpoint interval in steps
  • tst_s = time per training step (seconds)
  • tct_c = time to save a checkpoint (seconds)
  • trt_r = time to restore from checkpoint and restart (seconds)
  • λ\lambda = failure rate per step (probability that any given step triggers a failure)

The expected wasted compute per failure is approximately $\frac{C}{2} \cdot t_s$ seconds (on average, you lose half the steps since the last checkpoint).

The expected time lost per step due to checkpointing overhead is $\frac{t_c}{C}$.

The total overhead rate (fraction of time lost) is:

$$\text{overhead} = \frac{t_c}{C \cdot t_s} + \lambda \cdot \frac{C \cdot t_s}{2} + \lambda \cdot t_r$$

Taking the derivative with respect to $C$ and setting it to zero, $-\frac{t_c}{C^2 t_s} + \frac{\lambda t_s}{2} = 0$, gives the optimal checkpoint interval:

$$C^* = \sqrt{\frac{2 t_c}{\lambda \cdot t_s^2}}$$

This is the classic checkpoint optimization formula, known as Young's approximation; in wall-clock terms the optimal interval is $C^* t_s = \sqrt{2 t_c / \lambda}$. In practice, you also have storage costs and the fact that $\lambda$ is not constant (hardware often fails in bursts), so most teams use a mix of frequent lightweight checkpoints and less frequent full checkpoints.
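Plugging numbers into $C^*$ makes the trade-off concrete. The following is a minimal sketch with illustrative inputs, not measurements from any real run: a 30-second blocking save, 1-second steps, a 10-minute restart, and a cluster MTBF of 2.4 hours (the first example from the failure-rate section). With 1-second steps, a per-step and a per-second failure rate coincide, so $\lambda = 1/8640$ either way.

```python
import math


def overhead(C: float, t_s: float, t_c: float, t_r: float, lam: float) -> float:
    """Fraction of wall-clock time lost: checkpoint cost + rework + restart.
    lam is the cluster failure rate per second of wall-clock time."""
    return t_c / (C * t_s) + lam * C * t_s / 2 + lam * t_r


def optimal_interval(t_s: float, t_c: float, lam: float) -> float:
    """C* = sqrt(2 t_c / (lam t_s^2)), in steps."""
    return math.sqrt(2 * t_c / (lam * t_s ** 2))


t_s, t_c, t_r = 1.0, 30.0, 600.0  # seconds (illustrative)
lam = 1 / (2.4 * 3600)            # one failure per 2.4 hours

c_star = optimal_interval(t_s, t_c, lam)
print(f"C* = {c_star:.0f} steps")
for C in (100, 250, round(c_star), 2000, 5000):
    print(f"C={C:>5}: overhead {overhead(C, t_s, t_c, t_r, lam):.1%}")
```

Here $C^* = 720$ steps. At $C^*$ the checkpoint-cost and rework terms are equal (both about 4.2%), which holds in general for this model. The restart term $\lambda t_r$ does not depend on $C$ at all, which is why shortening detection and restart time pays off independently of checkpoint frequency.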

Synchronous vs Asynchronous Checkpointing

Synchronous checkpointing: training stops, all ranks write their model shards to disk, training resumes. Simple, correct, but incurs full I/O latency at every checkpoint interval.

For a 70B model in bf16 across 8 nodes with NVMe local storage writing at 5 GB/s per node: each node writes approximately $\frac{140\text{GB}}{8} = 17.5\text{GB}$, taking roughly 3.5 seconds. With a 1-second step time and checkpointing every 500 steps, synchronous checkpointing costs $\frac{3.5}{500} = 0.7\%$ of training time. Manageable. (These figures cover bf16 weights only; fp32 master weights and Adam moments add roughly 12 bytes per parameter, so a full training-state checkpoint is several times larger.)

For a 540B model (PaLM-scale) at ~1TB total, again across 8 nodes but with 8 GB/s NVMe: each node writes ~125GB, taking ~16 seconds. At 500-step intervals: $\frac{16}{500} = 3.2\%$. Now it matters.
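The same arithmetic can be written as a helper under the section's assumptions: bf16 weights only (2 bytes per parameter), the checkpoint split evenly across nodes, every node writing in parallel at its full local bandwidth, and training fully blocked during the write. The function and argument names are illustrative.

```python
def sync_checkpoint_cost(params_billion: float, nodes: int, node_gbps: float,
                         interval_steps: int, step_s: float,
                         bytes_per_param: int = 2):
    """Return (seconds blocked per checkpoint, fraction of training time).
    Like the text, this ignores the write time in the denominator."""
    total_gb = params_billion * bytes_per_param  # 1e9 params * bytes -> GB
    per_node_gb = total_gb / nodes
    write_s = per_node_gb / node_gbps
    return write_s, write_s / (interval_steps * step_s)


for name, params, gbps in [("70B", 70, 5.0), ("540B", 540, 8.0)]:
    write_s, frac = sync_checkpoint_cost(params, nodes=8, node_gbps=gbps,
                                         interval_steps=500, step_s=1.0)
    print(f"{name}: {write_s:.1f}s per checkpoint, {frac:.1%} of training time")
```

The 540B row comes out slightly above the text's figures (about 16.9 seconds and 3.4%) because 540B parameters in bf16 are about 1.08TB rather than a round 1TB. Raising `bytes_per_param` models saving optimizer state too; for example, fp32 master weights plus two Adam moments add about 12 bytes per parameter.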

Asynchronous checkpointing: immediately after a forward-backward-optimizer step, copy model state to CPU memory (fast, ~1-2 seconds), then write from CPU to disk in a background thread while the next training steps proceed. The critical insight is that CPU-to-disk I/O and GPU compute use largely separate hardware, so they can overlap almost completely; the main contention is for PCIe and host memory bandwidth.

The overlap is not free - you need enough CPU RAM to hold the checkpoint while writing. For a 70B model, that is ~140GB of CPU RAM per node if each node stages a full unsharded copy, which fits comfortably on modern DGX systems (DGX A100 and DGX H100 ship with 1-2TB of system memory). For a 540B model sharded across 128 nodes, each node holds ~8-9GB of parameters - comfortably within CPU RAM.

Elastic checkpointing: designed for variable world sizes. Instead of saving a checkpoint tied to a specific number of ranks and tensor parallel degree, elastic checkpointing saves the full (unsharded or consistently reshardable) model state, so it can be loaded with a different number of workers. This is what enables recovery from a node failure without using a spare - you just continue with N-1 nodes.
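To see why a reshardable format decouples the checkpoint from the world size, here is a toy sketch in plain Python. A flat parameter vector is split across ranks, each shard is saved with its global offset, and each rank of a different-sized group reassembles just the slice it needs. Real systems, for example PyTorch's `torch.distributed.checkpoint`, store comparable offset metadata per tensor; the functions here are hypothetical and only illustrate the idea.

```python
from typing import Dict, List, Tuple


def shard_bounds(total: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Even contiguous split; the first `total % world_size` ranks get one extra."""
    base, extra = divmod(total, world_size)
    start = rank * base + min(rank, extra)
    return start, start + base + (1 if rank < extra else 0)


def save_shards(flat: List[float], world_size: int) -> List[Dict]:
    """Each 'rank' saves its slice plus the global offset it came from."""
    shards = []
    for r in range(world_size):
        lo, hi = shard_bounds(len(flat), world_size, r)
        shards.append({"offset": lo, "values": flat[lo:hi], "total": len(flat)})
    return shards


def load_for_rank(shards: List[Dict], new_world_size: int, rank: int) -> List[float]:
    """Rebuild only the slice this rank needs under the NEW world size,
    using the saved offsets; the old world size is never consulted."""
    total = shards[0]["total"]
    lo, hi = shard_bounds(total, new_world_size, rank)
    out = [0.0] * (hi - lo)
    for s in shards:
        s_lo, s_hi = s["offset"], s["offset"] + len(s["values"])
        # Copy the overlap between this saved shard and the requested range
        for i in range(max(lo, s_lo), min(hi, s_hi)):
            out[i - lo] = s["values"][i - s_lo]
    return out


params = [float(i) for i in range(10)]
saved = save_shards(params, world_size=4)                 # saved by 4 ranks
resumed = [load_for_rank(saved, 3, r) for r in range(3)]  # resumed on 3 ranks
print(resumed)
```

A production implementation would apply the same overlap logic per tensor and read only the byte ranges each rank needs, so that no single rank has to materialize the full model.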


Straggler Detection and Mitigation

A straggler is a worker that is slower than its peers. In synchronous training, the cluster moves at the pace of the slowest worker. A GPU with thermal throttling, a node with a slow NIC, or a host with background I/O competing for PCIe bandwidth can silently slow the entire cluster by 5-15%.

Detection

The simplest signal is step time variance. In a healthy cluster, all ranks complete their forward-backward pass within a few milliseconds of each other. Track the distribution of step completion times across ranks. A rank consistently at the 95th percentile or higher is a straggler candidate.

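You can also instrument collective operations yourself to measure how long each rank waits at a barrier before the all-reduce begins. Ranks that finish their compute early wait the longest; the straggler arrives last and barely waits at all. A rank with consistently near-zero barrier wait while its peers wait much longer is the one finishing its compute late.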

import time

import torch
import torch.distributed as dist


def timed_all_reduce(tensor, op=dist.ReduceOp.SUM, threshold_s=0.1):
    """Wrapper around all_reduce that logs per-rank barrier wait.

    The extra barrier and all_gather add latency, so use this for
    diagnosis rather than on every production step.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Make the host timestamp reflect when this rank's GPU work actually
    # finished (CUDA kernels run asynchronously to the Python thread)
    torch.cuda.synchronize()
    arrival_time = time.perf_counter()

    # Early arrivals wait longest here; the straggler, arriving last,
    # waits almost zero
    dist.barrier()
    barrier_wait = time.perf_counter() - arrival_time

    # Now do the actual all-reduce
    t0 = time.perf_counter()
    dist.all_reduce(tensor, op=op)
    torch.cuda.synchronize()
    allreduce_time = time.perf_counter() - t0

    # Gather durations across all ranks (durations are comparable across
    # hosts; raw perf_counter timestamps are not)
    timing_tensor = torch.tensor(
        [barrier_wait, allreduce_time],
        dtype=torch.float64,
        device=tensor.device,
    )
    all_timings = [torch.zeros_like(timing_tensor) for _ in range(world_size)]
    dist.all_gather(all_timings, timing_tensor)

    if rank == 0:
        barrier_waits = [t[0].item() for t in all_timings]
        # How much later the last rank arrived than the first one
        spread = max(barrier_waits) - min(barrier_waits)
        if spread > threshold_s:
            slow_rank = barrier_waits.index(min(barrier_waits))
            print(f"[WARNING] Straggler detected: rank {slow_rank} "
                  f"arrived ~{spread:.3f}s after the earliest rank "
                  f"(mean barrier wait {sum(barrier_waits) / world_size:.3f}s)")

    return tensor
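The step-time signal from the detection paragraph can also be tracked without an extra barrier, by comparing each rank's recent compute time against the cluster median. The following is a minimal sketch. It assumes each rank's per-step compute time has already been gathered into a list of floats, for example via `dist.all_gather_object`; that time is measured before the gradient all-reduce (and after `torch.cuda.synchronize()`), because in synchronous training the full step time is identical on every rank. The class name, window size, and 1.15 ratio are illustrative choices, not standard values.

```python
from collections import deque
from statistics import median
from typing import Deque, Dict, List


class StragglerTracker:
    """Flags ranks whose recent median step time is well above the cluster's."""

    def __init__(self, world_size: int, window: int = 50, ratio: float = 1.15,
                 min_samples: int = 10):
        self.ratio = ratio
        self.min_samples = min_samples
        # One sliding window of recent step times per rank
        self.history: Dict[int, Deque[float]] = {
            r: deque(maxlen=window) for r in range(world_size)
        }

    def update(self, step_times: List[float]) -> List[int]:
        """step_times[r] = rank r's compute time for this step (seconds).
        Returns the ranks currently flagged as stragglers."""
        for r, t in enumerate(step_times):
            self.history[r].append(t)
        if len(self.history[0]) < self.min_samples:
            return []  # Not enough data yet
        # Medians ignore one-off spikes (e.g. a GC pause) but catch
        # ranks that are slow step after step
        per_rank = {r: median(h) for r, h in self.history.items()}
        cluster = median(per_rank.values())
        return [r for r, m in per_rank.items() if m > self.ratio * cluster]


tracker = StragglerTracker(world_size=8)
for step in range(20):
    times = [1.00] * 8
    times[5] = 1.25  # rank 5 is consistently ~25% slower
    flagged = tracker.update(times)
print(flagged)
```

Using windowed medians rather than single-step comparisons trades a little detection latency for far fewer false alarms, which matters if a flag can trigger node replacement.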

Mitigation

Options for handling a detected straggler:

  1. Log and alert - the simplest approach. Flag the straggler in monitoring, let an operator investigate. Suitable when stragglers are rare.

  2. Node replacement via elastic training - if a node is consistently slow (not transiently), use torchelastic to gracefully remove it from the training group and replace with a fresh node. This requires elastic training to be configured.

  3. Profile and fix - often the root cause is diagnosable: GPU thermal throttling (check nvidia-smi -q -d TEMPERATURE), PCIe contention, or a runaway CPU process. Fix the root cause and the straggler resolves.

  4. Async-SGD / backup workers - borrowed from MapReduce: allow training to proceed with N-k workers' gradients if k workers are slow. This introduces gradient staleness but keeps the cluster moving. Used in some asynchronous distributed training systems but less common in modern synchronous training due to convergence concerns.


NCCL Watchdog - Hang Detection

The most dangerous failure mode in distributed training is a silent hang. A collective operation (all-reduce, all-gather, broadcast) waits for all ranks to participate. If one rank crashes or hangs without the others knowing, the remaining ranks wait indefinitely. Without a watchdog, this is silent: CPUs show the process running, GPUs show compute utilization (the hanging collective consumes GPU cycles), but no training progress happens.

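PyTorch's NCCL watchdog (a background thread in the NCCL process group, configured through environment variables) detects this by putting a timeout on every collective operation. If a collective does not complete within the timeout, the watchdog aborts the NCCL communicator; depending on configuration, PyTorch then raises an exception in the training process or tears the process down, and either outcome can trigger recovery logic.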

Configuring NCCL Timeouts

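# The collective timeout is set in code via init_process_group(timeout=...)
# (10 minutes by default for NCCL in recent PyTorch releases, too long).
# NCCL itself has no collective-timeout environment variable.

# Enable async error handling (critical for hang detection): on a timeout
# the watchdog aborts the communicator and tears down the process.
# Older PyTorch releases read NCCL_ASYNC_ERROR_HANDLING instead.
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

# Verbose NCCL logging for debugging hangs
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

# NCCL socket timeout (for TCP-based connections); verify this variable
# against your NCCL version's documentation before relying on it
export NCCL_SOCKET_TIMEOUT=60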

In PyTorch's torch.distributed.init_process_group, you can pass a timeout directly:

import datetime

import torch.distributed as dist

dist.init_process_group(
    backend="nccl",
    init_method="env://",
    # Collective timeout enforced by PyTorch's NCCL watchdog
    timeout=datetime.timedelta(minutes=5),
)

Watchdog Thread Pattern

For production training, wrap your training loop with a watchdog thread that independently monitors heartbeats:

import os
import threading
import time

import torch.distributed as dist


class TrainingWatchdog:
    """
    Independent thread that monitors training heartbeats.
    If the training loop stops updating the heartbeat for
    longer than `timeout_seconds`, trigger an alert or recovery.
    """

    def __init__(self, timeout_seconds: int = 300, check_interval: int = 30,
                 flag_dir: str = "/tmp/training_flags"):
        self.timeout = timeout_seconds
        self.check_interval = check_interval
        self.flag_dir = flag_dir
        # Monotonic clock: unaffected by NTP or wall-clock adjustments
        self.last_heartbeat = time.monotonic()
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._watch, daemon=True)

    def start(self):
        self._thread.start()

    def heartbeat(self):
        """Call this at the end of each training step."""
        self.last_heartbeat = time.monotonic()

    def stop(self):
        self._stop_event.set()

    def _watch(self):
        # Event.wait doubles as an interruptible sleep between checks
        while not self._stop_event.wait(self.check_interval):
            elapsed = time.monotonic() - self.last_heartbeat
            if elapsed > self.timeout:
                rank = dist.get_rank() if dist.is_initialized() else -1
                print(
                    f"[WATCHDOG] ALERT: No heartbeat for {elapsed:.0f}s "
                    f"on rank {rank}. Possible hang detected."
                )
                # In production: trigger PagerDuty alert, write to shared
                # storage, attempt graceful shutdown
                self._trigger_recovery(rank)
                return  # Fire once; the recovery path takes over

    def _trigger_recovery(self, rank):
        # Write a "hang detected" flag for the job scheduler to pick up.
        # /tmp is node-local; point flag_dir at shared storage in production.
        os.makedirs(self.flag_dir, exist_ok=True)
        path = os.path.join(self.flag_dir, f"hang_detected_rank{rank}")
        with open(path, "w") as f:
            f.write(f"hang_at_{time.time()}")

Elastic Training with TorchElastic

TorchElastic (torch.distributed.elastic) allows a training job to continue when the number of workers changes. Workers can leave (due to failure or preemption) or join (replacements) without resubmitting the job: TorchElastic restarts the worker processes with the new membership, and they resume from checkpoint. This is the foundation of robust large-scale training.

How TorchElastic Works

TorchElastic wraps your training script with a rendezvous mechanism. Workers synchronize at rendezvous points to agree on the current world size and rank assignments. When a worker fails:

  1. The elastic agent (the torchrun process on each node) detects the failure: a worker process exits, for example after an NCCL timeout, or a node stops heartbeating to the rendezvous backend.
  2. The agents stop their local workers and rendezvous again to form a new group with a smaller world size.
  3. The agents relaunch the workers, which resume from the last checkpoint with the new world size.

The key requirement: your training script must be restartable. Because the agents relaunch every worker process after a membership change, any training state held only in memory (including global variables) is lost. Checkpoint frequently, and reload all state from the latest checkpoint on every start.

Training Script Structure for Elastic Training

import datetime
import os
from pathlib import Path

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.elastic.multiprocessing.errors import record


CHECKPOINT_DIR = Path("/shared-storage/checkpoints")

# build_model, build_dataloader and train_step are your own code;
# they are placeholders here.


@record  # Records uncaught exceptions so torchrun can report the failing worker
def main():
    # torchrun sets these environment variables for every worker
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        # Short collective timeout so hangs surface quickly
        timeout=datetime.timedelta(minutes=5),
    )

    # Load model
    model = build_model().cuda(local_rank)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Always restore from latest checkpoint: after an elastic restart,
    # it is the only source of training state
    start_step = load_checkpoint(model, optimizer, rank, local_rank)

    dataloader = build_dataloader(
        start_step=start_step,
        world_size=world_size,
        rank=rank,
    )

    for step, batch in enumerate(dataloader, start=start_step):
        loss = train_step(model, optimizer, batch)

        if step % 100 == 0 and rank == 0:
            print(f"Step {step}, loss: {loss:.4f}")

        # Checkpoint every 500 steps
        if step % 500 == 0:
            save_checkpoint(model, optimizer, step, rank)

    dist.destroy_process_group()


def save_checkpoint(model, optimizer, step, rank):
    """Save model + optimizer state to shared storage (rank 0 writes)."""
    if rank == 0:
        CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
        ckpt_path = CHECKPOINT_DIR / f"step_{step:08d}.pt"
        torch.save({
            "step": step,
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }, ckpt_path)
        # Update the "latest" pointer atomically (write a temp file, then
        # rename) so a crash mid-update never leaves a truncated pointer
        tmp = CHECKPOINT_DIR / "latest.tmp"
        tmp.write_text(str(ckpt_path))
        tmp.replace(CHECKPOINT_DIR / "latest")
        print(f"[Rank 0] Checkpoint saved: {ckpt_path}")

    # All ranks wait for rank 0 to finish writing
    dist.barrier()


def load_checkpoint(model, optimizer, rank, local_rank):
    """Restore from latest checkpoint if it exists; return the next step."""
    latest_ptr = CHECKPOINT_DIR / "latest"
    if not latest_ptr.exists():
        print(f"[Rank {rank}] No checkpoint found, starting from step 0")
        return 0

    ckpt_path = Path(latest_ptr.read_text().strip())
    # Map tensors onto this worker's own GPU (LOCAL_RANK, not RANK % 8)
    checkpoint = torch.load(ckpt_path, map_location=f"cuda:{local_rank}")
    model.module.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_step = checkpoint["step"] + 1
    print(f"[Rank {rank}] Restored from step {start_step - 1}")
    return start_step


if __name__ == "__main__":
    main()

Launching with TorchElastic (torchrun)

# Basic elastic launch: 1 to 8 nodes, 8 worker processes per node
torchrun \
--nnodes=1:8 \
--nproc_per_node=8 \
--rdzv_backend=c10d \
--rdzv_endpoint=master_node:29400 \
--rdzv_id=my_training_job_001 \
--max_restarts=3 \
train.py

# Multi-node with an external etcd rendezvous server (c10d, above, needs no extra service)
torchrun \
--nnodes=4:16 \
--nproc_per_node=8 \
--rdzv_backend=etcd \
--rdzv_endpoint=etcd_server:2379 \
--rdzv_id=my_training_job_001 \
--max_restarts=10 \
train.py

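The --nnodes=4:16 means: minimum 4 nodes, maximum 16 nodes. The job starts once at least 4 nodes have joined the rendezvous, and more nodes can join later, up to 16; each late arrival triggers a new rendezvous in which the workers are restarted with the larger world size.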


Async Checkpointing Implementation

For large models where synchronous checkpointing introduces unacceptable overhead, async checkpointing overlaps I/O with training:

import threading
import time
from pathlib import Path
from queue import Empty, Queue
from typing import Optional

import torch


def _to_cpu(obj):
    """Recursively copy every tensor in a (nested) state dict to CPU.
    copy.deepcopy would instead duplicate optimizer tensors on the GPU."""
    if torch.is_tensor(obj):
        return obj.detach().to("cpu", copy=True)
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_cpu(v) for v in obj)
    return obj


class AsyncCheckpointer:
    """
    Saves checkpoints asynchronously in a background thread.
    Training continues while the checkpoint is being written.

    Usage:
        checkpointer = AsyncCheckpointer(save_dir="/shared/checkpoints")
        checkpointer.start()

        for step, batch in enumerate(dataloader):
            # ... train step ...
            if step % 500 == 0:
                checkpointer.save_async(model, optimizer, step)

        checkpointer.wait_and_stop()  # Block until last checkpoint finishes
    """

    def __init__(self, save_dir: str, max_keep: int = 3):
        self.save_dir = Path(save_dir)
        self.max_keep = max_keep
        self._queue: Queue = Queue(maxsize=2)  # Max 2 pending checkpoints
        self._thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()
        self._saved_checkpoints = []

    def start(self):
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self._thread = threading.Thread(target=self._worker, daemon=False)
        self._thread.start()

    def save_async(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                   step: int, rank: int = 0):
        """
        Copy model and optimizer state to CPU and queue it for an async disk
        write. The device-to-host copy blocks training (~1-2s for 70B on
        DGX); the disk write happens in the background thread.
        """
        if rank != 0:
            return  # Only rank 0 writes (for simple DP; adjust for model parallel)

        t0 = time.perf_counter()

        # Snapshot to CPU so later optimizer steps cannot change what is written
        cpu_state = {
            "step": step,
            "model_state_dict": _to_cpu(model.module.state_dict()),
            "optimizer_state_dict": _to_cpu(optimizer.state_dict()),
        }

        copy_time = time.perf_counter() - t0
        print(f"[AsyncCheckpointer] CPU copy took {copy_time:.2f}s at step {step}")

        # Blocks (up to 60s, then raises queue.Full) if two writes are
        # already pending: back-pressure instead of unbounded CPU memory use
        self._queue.put((step, cpu_state), block=True, timeout=60)

    def _worker(self):
        """Background thread: write checkpoints to disk."""
        while not self._stop_event.is_set():
            try:
                step, state = self._queue.get(timeout=1.0)
            except Empty:
                continue

            try:
                t0 = time.perf_counter()
                ckpt_path = self.save_dir / f"step_{step:08d}.pt"
                # Write under a temp name, then rename, so a crash mid-write
                # never leaves a truncated file under the final name
                tmp_path = ckpt_path.with_name(ckpt_path.name + ".tmp")
                torch.save(state, tmp_path)
                tmp_path.replace(ckpt_path)
                write_time = time.perf_counter() - t0

                # Update latest pointer atomically
                latest_tmp = self.save_dir / "latest.tmp"
                latest_tmp.write_text(str(ckpt_path))
                latest_tmp.replace(self.save_dir / "latest")

                self._saved_checkpoints.append(ckpt_path)
                print(f"[AsyncCheckpointer] Written step {step} in {write_time:.2f}s")

                # Rotate old checkpoints
                self._rotate()
            except Exception as exc:
                # Log and keep going; a failed write must not kill the thread
                print(f"[AsyncCheckpointer] Write of step {step} failed: {exc}")
            finally:
                # Always mark done, or wait_and_stop() would block forever
                self._queue.task_done()

    def _rotate(self):
        """Delete oldest checkpoints if we exceed max_keep."""
        while len(self._saved_checkpoints) > self.max_keep:
            old = self._saved_checkpoints.pop(0)
            if old.exists():
                old.unlink()
                print(f"[AsyncCheckpointer] Deleted old checkpoint: {old}")

    def wait_and_stop(self):
        """Block until the write queue is empty, then stop the worker."""
        self._queue.join()
        self._stop_event.set()
        if self._thread:
            self._thread.join()

Spot Instance Preemption Handling on AWS and GCP

Cloud spot/preemptible instances offer 60-90% discounts over on-demand pricing. For a training run that would cost $500k on-demand, spot pricing can bring this to $75-150k. The catch: the cloud provider can reclaim spot instances with two minutes' warning (AWS) or 30 seconds' warning (GCP).

AWS Spot Instance Interruption Notice

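AWS publishes a Spot interruption notice two minutes before it reclaims the instance, both as an Amazon EventBridge event and as the spot/instance-action item in the EC2 instance metadata service (IMDS). You poll this endpoint and, when the notice appears, trigger a checkpoint and graceful shutdown.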

import sys
import threading
import time

import requests


class SpotInterruptionHandler:
    """
    Polls AWS IMDS for spot interruption notices.
    When detected, calls on_interrupt_callback with the notice's action.
    """

    IMDS_TOKEN_URL = "http://169.254.169.254/latest/api/token"
    # Returns 404 until an interruption is scheduled, then a small JSON
    # document with "action" and "time" fields
    IMDS_SPOT_ACTION_URL = (
        "http://169.254.169.254/latest/meta-data/spot/instance-action"
    )
    POLL_INTERVAL = 5  # seconds

    def __init__(self, on_interrupt_callback):
        """
        on_interrupt_callback: callable taking the action string. It runs
        on the monitor thread, so it should only signal the training loop.
        """
        self.on_interrupt = on_interrupt_callback
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._monitor, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop.set()

    def _get_imds_token(self) -> str:
        """Get IMDSv2 session token (required on modern AWS instances)."""
        resp = requests.put(
            self.IMDS_TOKEN_URL,
            headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
            timeout=1,
        )
        return resp.text

    def _monitor(self):
        token = None
        while not self._stop.is_set():
            try:
                if token is None:
                    token = self._get_imds_token()

                resp = requests.get(
                    self.IMDS_SPOT_ACTION_URL,
                    headers={"X-aws-ec2-metadata-token": token},
                    timeout=1,
                )
                if resp.status_code == 401:
                    token = None  # Token expired; fetch a new one next poll
                elif resp.status_code == 200:
                    try:
                        action = resp.json().get("action", "unknown")
                    except ValueError:
                        action = "unknown"
                    print(
                        f"[SpotHandler] INTERRUPTION NOTICE received: {action}. "
                        "Triggering emergency checkpoint."
                    )
                    self.on_interrupt(action)
                    return  # Stop monitoring after interrupt

            except requests.exceptions.Timeout:
                pass  # IMDS slow or unreachable; try again next poll
            except requests.exceptions.ConnectionError:
                token = None  # Reset token on connection error

            self._stop.wait(self.POLL_INTERVAL)


# Integration with training loop (train_step is your own step function)
def build_spot_aware_training_loop(model, optimizer, dataloader, checkpointer,
                                   rank=0):
    interrupted = threading.Event()

    def on_interrupt(action):
        # Runs on the monitor thread: only set a flag. The training loop
        # saves at the next step boundary, never in the middle of a step.
        print(f"[Emergency] Spot interruption notice ({action}); "
              "checkpointing at the next step boundary")
        interrupted.set()

    handler = SpotInterruptionHandler(on_interrupt_callback=on_interrupt)
    handler.start()

    for step, batch in enumerate(dataloader):
        if interrupted.is_set():
            # step - 1 is the last completed step; saving the real step
            # (not a placeholder) lets the restarted job resume in place
            checkpointer.save_async(model, optimizer, step - 1, rank=rank)
            checkpointer.wait_and_stop()
            print("[Training] Spot interruption detected. Exiting gracefully.")
            handler.stop()
            sys.exit(0)  # Exit cleanly; relaunching is up to your job scheduler

        loss = train_step(model, optimizer, batch)

        if step % 500 == 0:
            checkpointer.save_async(model, optimizer, step, rank=rank)

    handler.stop()
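When a notice is present, the `spot/instance-action` metadata item holds a small JSON document with the action and an approximate UTC time, for example `{"action": "terminate", "time": "2017-09-18T08:22:00Z"}`. A handler can use that time to decide whether an async checkpoint still fits in the remaining window or a smaller emergency save is needed. The following is a minimal sketch; the function names and the 20-second safety margin are illustrative assumptions.

```python
import json
from datetime import datetime, timezone
from typing import Optional


def seconds_until_interruption(notice_body: str,
                               now: Optional[datetime] = None) -> float:
    """Parse a spot instance-action notice and return seconds remaining."""
    notice = json.loads(notice_body)
    # e.g. "2017-09-18T08:22:00Z"; rewrite the trailing Z for fromisoformat
    deadline = datetime.fromisoformat(notice["time"].replace("Z", "+00:00"))
    now = now or datetime.now(timezone.utc)
    return (deadline - now).total_seconds()


def checkpoint_fits(remaining_s: float, expected_save_s: float,
                    safety_margin_s: float = 20.0) -> bool:
    """Only start a save if it can finish before the instance disappears."""
    return remaining_s - safety_margin_s >= expected_save_s


body = '{"action": "terminate", "time": "2017-09-18T08:22:00Z"}'
now = datetime(2017, 9, 18, 8, 20, 10, tzinfo=timezone.utc)
remaining = seconds_until_interruption(body, now)
print(remaining, checkpoint_fits(remaining, expected_save_s=60))
```

Passing `now` explicitly keeps the function testable; in the monitor thread you would omit it and use the current UTC time.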

GCP Preemptible VM Shutdown Signal

GCP sends an ACPI G2 Soft Off signal to a preempted VM with 30 seconds' warning; the guest OS then shuts down, which delivers SIGTERM to running processes. Handle it with a SIGTERM handler:

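import signal
import threading
from pathlib import Path

import torch
import torch.distributed as dist

# Set by the signal handler, read by the training loop
preemption_requested = threading.Event()


def setup_gcp_preemption_handler():
    """
    GCP gives ~30 seconds of warning, delivered to the process as SIGTERM.
    The handler only records the request: running collectives or a large
    save inside a signal handler can deadlock against the training loop.
    Must be called from the main thread.
    """

    def sigterm_handler(signum, frame):
        preemption_requested.set()

    signal.signal(signal.SIGTERM, sigterm_handler)


def should_stop(device) -> bool:
    """Call on every rank once per step. The MAX all-reduce makes all ranks
    stop together even if only one VM was preempted."""
    flag = torch.tensor([1.0 if preemption_requested.is_set() else 0.0],
                        device=device)
    if dist.is_initialized():
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0


def emergency_save(model, optimizer, step,
                   ckpt_dir=Path("/shared-storage/checkpoints")):
    """Synchronous save - with ~30 seconds there is no time for async I/O.
    `step` is the last completed step."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        path = ckpt_dir / "emergency.pt"
        torch.save({
            "step": step,
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "emergency": True,
        }, path)
        # Point "latest" at it so the normal restore path picks it up
        tmp = ckpt_dir / "latest.tmp"
        tmp.write_text(str(path))
        tmp.replace(ckpt_dir / "latest")
    if dist.is_initialized():
        dist.barrier()


# In the training loop:
#     if should_stop(torch.device("cuda", local_rank)):
#         emergency_save(model, optimizer, step - 1)
#         sys.exit(0)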

Checkpoint Recovery Time Optimization

Recovery time - the wall-clock time from failure detection to resumed training - has three components:

  1. Detection time: how long before the system knows something is wrong (NCCL timeout: 5-10 minutes by default, optimized: 1-2 minutes)
  2. Restart time: kill old processes, launch new ones, initialize NCCL communicators (1-3 minutes)
  3. Checkpoint load time: read model weights from storage into GPU memory

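For a 70B model in bf16 (~140GB) read at 5 GB/s, a load takes ~28 seconds. If all 8 nodes read the full checkpoint at once from a single NFS server, they share that server's bandwidth and the load slows accordingly; a parallel filesystem such as Lustre, whose aggregate bandwidth scales with its storage servers, can sustain the per-node rate and keep the effective time near ~28 seconds.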

For a 540B model: ~1TB total, ~125GB per node across 8 nodes. At 5 GB/s per node: ~25 seconds. Modern parallel filesystems (Lustre at 100+ GB/s aggregate) can load much faster.

Optimization strategies:

  • Keep checkpoints in fast storage: NVMe local SSDs vs NFS vs object storage (S3/GCS) have order-of-magnitude throughput differences. Local NVMe at 6 GB/s vs S3 at 0.5 GB/s per connection means 12x faster recovery from local storage.
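  • Checkpoint sharding: each rank saves and loads only its own shard, all in parallel, so load time falls roughly in proportion to node count as long as the storage system's aggregate bandwidth keeps up.
  • In-memory shadow copy: on DGX systems with 1-2TB of CPU RAM, keep the last checkpoint in CPU memory. Restoring it to the GPUs over PCIe is much faster than reading from disk, and it survives a process crash or restart, though not the loss of the node itself.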
  • Reduce NCCL timeout: the biggest leverage on recovery time. Lowering from 10 minutes to 2 minutes saves 8 minutes per failure. At one failure every 2 hours, this is a 6.7% improvement in effective training throughput.
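The recovery components combine into a simple budget. The following is a minimal sketch using the text's example figures (a failure every 2 hours, a 2-minute restart, a 28-second checkpoint load) plus an assumed 500-step, 1-second-per-step checkpoint interval, so that on average 250 seconds of work is replayed. The function and parameter names are illustrative.

```python
def recovery_fraction(mtbf_s: float, detect_s: float, restart_s: float,
                      load_s: float, replay_s: float) -> float:
    """Fraction of wall-clock time spent recovering, one failure per mtbf_s."""
    return (detect_s + restart_s + load_s + replay_s) / mtbf_s


MTBF = 2 * 3600  # one failure every 2 hours
common = dict(restart_s=120, load_s=28, replay_s=250)

default = recovery_fraction(MTBF, detect_s=600, **common)  # 10-minute timeout
tuned = recovery_fraction(MTBF, detect_s=120, **common)    # 2-minute timeout
print(f"10-min timeout: {default:.1%} lost, 2-min timeout: {tuned:.1%} lost, "
      f"saving {default - tuned:.1%}")
```

With these inputs the loss drops from about 13.9% to 7.2% of wall-clock time; the 6.7-point difference is exactly the 8 minutes saved per failure divided by the 2-hour failure interval.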

Mermaid Diagrams

[Diagrams not reproduced: Failure Detection and Recovery Flow; Async Checkpointing Timeline; Spot Preemption Handling Architecture.]


Production Engineering Notes

Checkpoint Storage Architecture

In production, checkpoints need to survive node failures. Storing checkpoints only on local NVMe is fast but dangerous - if the node with the checkpoint fails before it is replicated, you lose the checkpoint. The standard architecture:

  • Level 1 (fast): Local NVMe on each node - used for the most recent checkpoint. Fast write, fast read, survives node restart but not node replacement.
  • Level 2 (safe): Shared parallel filesystem (Lustre, GPFS, WekaFS) - visible from all nodes. Write after each checkpoint. Slower than NVMe but survives any single node failure.
  • Level 3 (archival): Object storage (S3, GCS) - write every Nth checkpoint (e.g., every 5000 steps). Cheap, durable, slow for recovery.

Recovery path: try Level 1 first (fastest), fall back to Level 2, fall back to Level 3.
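The fallback order can be expressed as a short lookup. The following is a minimal sketch that assumes every tier is reachable as a filesystem path and uses the same `latest` pointer-file convention as the training script earlier in this lesson. An object-storage tier would need its own API call in place of the path checks (not shown), and the tier paths are hypothetical.

```python
from pathlib import Path
from typing import List, Optional


def find_recovery_checkpoint(tiers: List[Path]) -> Optional[Path]:
    """Return the first usable checkpoint, trying tiers fastest-first."""
    for tier in tiers:
        pointer = tier / "latest"
        try:
            candidate = Path(pointer.read_text().strip())
        except OSError:
            continue  # Tier missing, unmounted, or no pointer written yet
        if candidate.is_file() and candidate.stat().st_size > 0:
            return candidate
        # Pointer exists but its target is gone (e.g. node replaced): fall through
    return None


# Hypothetical tier order: local NVMe, shared parallel FS, mounted archive
TIERS = [
    Path("/local-nvme/checkpoints"),
    Path("/shared-storage/checkpoints"),
    Path("/archive-mount/checkpoints"),
]
```

A more careful version would compare step numbers across tiers and pick the newest, since a local tier that survived a restart can hold an older checkpoint than the shared filesystem.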

NCCL Environment Configuration for Production

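```bash
# /etc/profile.d/nccl_training.sh - applied to all training nodes

# Reduce hang detection time from 10 minutes to 3 minutes.
# NCCL itself does not read this variable: the training script must pass it to
# torch.distributed.init_process_group(timeout=...).
export NCCL_TIMEOUT=180

# On a collective timeout, abort the NCCL communicators and raise an error
# instead of hanging, so the launcher can restart workers.
# (Renamed TORCH_NCCL_ASYNC_ERROR_HANDLING in newer PyTorch releases.)
export NCCL_ASYNC_ERROR_HANDLING=1

# Keep the InfiniBand transport enabled (0 = do not disable IB)
export NCCL_IB_DISABLE=0
# Disable PCI relaxed ordering for IB verbs for stability
# (slight throughput loss but eliminates a class of rare hangs)
export NCCL_IB_PCI_RELAXED_ORDERING=0
# IB verbs timeout exponent: 4.096 µs * 2^23, about 34 s per attempt
export NCCL_IB_TIMEOUT=23
# IB retry count (7 is the maximum)
export NCCL_IB_RETRY_CNT=7

# Undocumented NCCL tuning knob for per-protocol thread-count thresholds.
# It does NOT pin CPU threads; use numactl or the launcher for CPU affinity.
export NCCL_THREAD_THRESHOLDS="512 512 64"

# Network interface for NCCL bootstrap and the socket transport
# (used when IB is unavailable); eth0 is site-specific
export NCCL_SOCKET_IFNAME=eth0

# Set CUDA device order to match PCI bus (consistent across reboots)
export CUDA_DEVICE_ORDER=PCI_BUS_ID
```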
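Because the collective timeout is enforced by PyTorch's process group rather than by NCCL, the exported value only takes effect if the training script forwards it. A minimal sketch, assuming your launch scripts treat `NCCL_TIMEOUT` (seconds) as a site convention as in the profile above, and that the job is started by torchrun, which supplies `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and `MASTER_PORT`:

```python
import os
from datetime import timedelta


def collective_timeout(default_s: int = 600) -> timedelta:
    """Read the NCCL_TIMEOUT site convention (seconds); fall back to PyTorch's 10-minute default."""
    return timedelta(seconds=int(os.environ.get("NCCL_TIMEOUT", default_s)))


def init_distributed() -> None:
    # Imported here so collective_timeout() can be used without torch installed.
    import torch.distributed as dist

    # With no init_method, init_process_group reads the env:// variables set by torchrun.
    dist.init_process_group(backend="nccl", timeout=collective_timeout())
```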

Monitoring What Actually Matters

Don't just monitor loss and throughput. In a fault-tolerant training setup, also monitor:

  • Checkpoint write latency - should be stable. Sudden increases indicate storage issues.
  • Checkpoint age - how long since the last successful checkpoint. Alert if > 2x normal interval.
  • Failure count and distribution - which nodes fail most often? Repeat failures from the same node or rack indicate hardware issues that need physical intervention.
  • Recovery time per incident - track this over time. Increasing recovery times indicate accumulating issues (failing disks, degraded network).
  • Step time variance across ranks - P99 vs P50 step time. Widening gap indicates straggler development.
  • NCCL timeout events - each timeout event is a potential hang. Track frequency.
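Two of these signals reduce to simple checks. A minimal sketch, assuming you already collect the Unix time of the last successful checkpoint and a recent window of step times gathered from all ranks. The 2x checkpoint-age threshold comes from the list above; the function names are hypothetical:

```python
import statistics
import time
from typing import Optional, Sequence


def checkpoint_too_old(last_ckpt_unix: float, interval_s: float,
                       now: Optional[float] = None) -> bool:
    """Alert when the last successful checkpoint is older than 2x the normal interval."""
    now = time.time() if now is None else now
    return now - last_ckpt_unix > 2 * interval_s


def straggler_ratio(step_times_s: Sequence[float]) -> float:
    """P99 / P50 of step times across ranks; a value drifting above 1 signals stragglers."""
    cuts = statistics.quantiles(step_times_s, n=100, method="inclusive")
    return cuts[98] / cuts[49]
```

Tracking `straggler_ratio` over time matters more than any single value: a slow upward drift is the early warning.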

Checkpoint Format for Model Parallelism

When training with tensor parallelism (TP) or pipeline parallelism (PP), each rank holds only a shard of the model. Checkpoint format matters for recovery:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# `model` and `rank` come from your training setup.

# BAD: saves sharded state (tied to the specific TP/PP degree and rank count).
# If you change the TP degree after recovery, this checkpoint is incompatible.
torch.save(model.state_dict(), f"rank_{rank}_ckpt.pt")

# BETTER with Megatron-LM: save in its distributed checkpoint format
# (megatron.core.dist_checkpointing), which can be reloaded under a
# different TP/PP layout instead of being tied to the saving configuration.

# OR with FSDP: gather a full (unsharded) state dict on rank 0 before saving.
# Supports recovery with a different sharding or parallel configuration.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = model.state_dict()  # collective: every rank must call this
if dist.get_rank() == 0:
    torch.save(state_dict, "full_model_ckpt.pt")
```

Common Mistakes

:::danger Do Not Set NCCL_TIMEOUT Too High

The default NCCL timeout of 10 minutes means a hang can waste 10 minutes of cluster time per failure before being detected. At 1024 GPUs and ~$8/GPU-hour (internal cost), 10 minutes of wasted compute is ~$1,400 per incident. Lower it to 3-5 minutes for production training. Some teams go as low as 60-90 seconds if their collectives consistently complete well within that window.

```bash
# Too high (default - costs money on hangs)
export NCCL_TIMEOUT=600

# Better for production
export NCCL_TIMEOUT=180
```

:::
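The cost of a hang is simply GPUs × hang duration × hourly rate. A quick sketch, assuming the ~$8/GPU-hour internal rate implied by the ~$1,400 figure above (substitute your own):

```python
def hang_cost_usd(num_gpus: int, hang_minutes: float, usd_per_gpu_hour: float) -> float:
    """Compute burned while every GPU waits on a hung collective."""
    return num_gpus * (hang_minutes / 60) * usd_per_gpu_hour


for timeout_min in (10, 3):
    print(timeout_min, round(hang_cost_usd(1024, timeout_min, 8.0)))
# 10 1365
# 3 410
```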

:::danger Never Checkpoint Only to Local Storage

If your checkpoint strategy writes only to local NVMe and the node fails before replication, your checkpoint is gone. You will recover from an older checkpoint, losing more progress than necessary. Always replicate to at least one additional storage tier (shared filesystem or object storage) within a reasonable time window.

:::
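A minimal sketch of the replication step, assuming both tiers are POSIX-mounted paths (local NVMe and a shared filesystem) and using a hypothetical staging-directory convention: the copy lands under a hidden `.partial` name and is renamed into place only when complete, so a reader scanning the shared tier never picks up a half-copied checkpoint. The thread is non-daemon on purpose, so normal interpreter shutdown waits for an in-flight copy.

```python
import shutil
import threading
from pathlib import Path


def replicate(src: Path, dst_root: Path) -> Path:
    """Copy a finished checkpoint directory to another tier and publish it atomically."""
    final = dst_root / src.name
    if final.exists():
        return final  # already replicated
    staging = dst_root / f".{src.name}.partial"
    if staging.exists():
        shutil.rmtree(staging)  # leftover from an interrupted earlier copy
    dst_root.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src, staging)
    staging.rename(final)  # rename within one filesystem is atomic on POSIX
    return final


def replicate_in_background(src: Path, dst_root: Path) -> threading.Thread:
    """Start replication without blocking the training loop; join() before deleting `src`."""
    worker = threading.Thread(target=replicate, args=(src, dst_root))
    worker.start()
    return worker
```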

:::warning Async Checkpointing CPU Memory Requirement

Async checkpointing requires enough CPU RAM to hold a complete copy of the model during the write window. Before enabling async checkpointing, verify:

  • CPU RAM available per node >= (model parameters per node in bytes) * 1.2 (safety margin)
  • For a 70B model on 8 nodes: each node holds ~17.5GB of params. You need at least 21GB free CPU RAM per node dedicated to checkpointing. Modern DGX systems have 384-768GB total, so this is fine. But on smaller servers (e.g., 64GB RAM), async checkpointing of large models will OOM.

:::
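The sizing rule above works as a quick pre-flight check. A minimal sketch, assuming bf16 weights (2 bytes per parameter), an even split across nodes, and a Linux host (it reads `MemAvailable` from `/proc/meminfo`). If you also snapshot optimizer state asynchronously, add its size too:

```python
def required_cpu_ram_bytes(total_params: float, num_nodes: int,
                           bytes_per_param: int = 2, margin: float = 1.2) -> float:
    """CPU RAM one node needs to hold its share of the weights during an async write."""
    return total_params / num_nodes * bytes_per_param * margin


def available_ram_bytes(meminfo: str = "/proc/meminfo") -> int:
    """Linux only: MemAvailable, which counts reclaimable page cache as available."""
    with open(meminfo) as f:
        for line in f:
            if line.startswith("MemAvailable:"):
                return int(line.split()[1]) * 1024  # value is reported in kB
    raise RuntimeError("MemAvailable not found")


need = required_cpu_ram_bytes(70e9, num_nodes=8)
print(f"need {need / 1e9:.0f} GB per node")  # need 21 GB per node
```

On each node, refuse to enable async checkpointing when `available_ram_bytes() < need`.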

:::warning TorchElastic Rendezvous Requires Shared State

The rendezvous backend (c10d or etcd) must be accessible from all nodes in the cluster. If your rendezvous server (the master node or etcd endpoint) is behind a different network segment than some worker nodes, or if the rendezvous port is blocked by a firewall, TorchElastic will fail silently or hang at startup. Verify connectivity before your training run:

```bash
# From each worker node, verify rendezvous is reachable
nc -zv master_node 29400
```

:::
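The same check can run from Python at the top of the training script, so a blocked port fails fast with a clear message instead of a silent rendezvous hang. A minimal sketch; `master_node` and port 29400 follow the example above, and the helper name is hypothetical:

```python
import socket


def rendezvous_reachable(host: str, port: int = 29400, timeout_s: float = 3.0) -> bool:
    """True if a TCP connection to the rendezvous endpoint succeeds within timeout_s."""
    try:
        with socket.create_connection((host, port), timeout=timeout_s):
            return True
    except OSError:  # refused, unreachable, DNS failure, or timeout
        return False
```

For example: `if not rendezvous_reachable("master_node"): raise SystemExit("rendezvous endpoint unreachable")`.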

:::warning Do Not Save Optimizer State With Adam on the Same Schedule as Weights

Adam keeps two moment estimates per parameter, so its state is at least 2x the model parameter count. For a 70B model that is ~280GB with bf16 moments, and roughly 840GB with the fp32 moments plus fp32 master weights typical of mixed-precision training. If you checkpoint every 100 steps and each checkpoint takes 10 minutes to write, you are spending a lot of I/O on checkpoints you will never use. Consider:

  • Saving optimizer state every Nth checkpoint (e.g., every 5th) and saving weights-only checkpoints in between
  • Using float16/bf16 optimizer state compression (reduces by 2x)
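  • Only saving optimizer state from rank 0 when using plain data parallelism such as DDP (optimizer state is replicated across ranks). With ZeRO or FSDP the state is sharded, and every rank must save its own shard.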

:::
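A minimal sketch of the first option, assuming a checkpoint every 100 steps with optimizer state on every 5th one (both hypothetical defaults), plus the size arithmetic for Adam's two moments:

```python
from typing import Tuple


def checkpoint_plan(step: int, interval: int = 100, optimizer_every: int = 5) -> Tuple[str, ...]:
    """What to save at `step`: nothing, weights only, or weights plus optimizer state."""
    if step % interval:
        return ()
    if (step // interval) % optimizer_every == 0:
        return ("weights", "optimizer")
    return ("weights",)


def adam_moments_gb(num_params: float, bytes_per_value: int) -> float:
    """Size of Adam's two moment estimates (excludes any fp32 master weights)."""
    return num_params * 2 * bytes_per_value / 1e9


print(adam_moments_gb(70e9, 2), adam_moments_gb(70e9, 4))  # 280.0 560.0
```

Resuming training still needs the newest checkpoint that includes optimizer state; the weights-only checkpoints in between serve evaluation or a warm restart.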


Interview Q&A

Q1: Why does adding more GPUs to a training cluster increase the probability of hardware failure during a run, and what is the mathematical relationship?

The probability that at least one GPU fails during a training run of $T$ days on $N$ GPUs, each with an independent daily failure probability $\lambda$, is $P = 1 - (1-\lambda)^{N \cdot T}$. The expected number of failures, $N \cdot T \cdot \lambda$, grows linearly with both cluster size and run duration, and it approximates $P$ only while $N \cdot T \cdot \lambda \ll 1$. A 1000-GPU cluster running for 30 days with a 0.1% daily failure rate per GPU expects $1000 \times 30 \times 0.001 = 30$ failures, far outside that regime; the exact formula gives $P = 1 - (1-0.001)^{30000} \approx 1 - e^{-30} \approx 1.0$. Failure is near-certain. This is why fault tolerance is not optional at frontier scale.
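These quantities are easy to evaluate directly. A minimal sketch; `log1p`/`expm1` keep $P = 1 - (1-\lambda)^{N \cdot T}$ accurate when $\lambda$ is tiny, and the last helper converts a per-GPU daily rate into the mean time between failures for the whole cluster:

```python
import math


def p_any_failure(n_gpus: int, days: float, daily_rate: float) -> float:
    """Exact P = 1 - (1 - rate)^(N*T) for independent per-GPU daily failures."""
    return -math.expm1(n_gpus * days * math.log1p(-daily_rate))


def expected_failures(n_gpus: int, days: float, daily_rate: float) -> float:
    return n_gpus * days * daily_rate


def cluster_mtbf_hours(n_gpus: int, daily_rate: float) -> float:
    return 24 / (n_gpus * daily_rate)


print(expected_failures(1000, 30, 0.001), p_any_failure(1000, 30, 0.001))
print(cluster_mtbf_hours(1000, 0.01))  # 2.4: one failure roughly every 2.4 hours at 1%/day
```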

Q2: Explain the tradeoff between checkpoint frequency and checkpoint overhead. What is the optimal checkpoint interval?

Checkpointing too rarely means you lose more work per failure (expected loss = half the checkpoint interval). Checkpointing too frequently creates I/O overhead. Let $t_c$ be the checkpoint write time, $C$ the interval in steps, $t_s$ the step time, and $\lambda$ the per-step failure probability. The fraction of time lost is approximately $f(C) = \frac{t_c}{C \cdot t_s} + \frac{\lambda C}{2}$, where the first term is checkpoint overhead and the second is expected recomputation after a failure. Setting $f'(C) = 0$ gives the optimal interval $C^* = \sqrt{2 t_c / (\lambda \cdot t_s)}$. In wall-clock terms this is $C^* t_s = \sqrt{2 t_c M}$ with $M = t_s / \lambda$ the mean time between failures, which is Young's classic approximation. In practice, async checkpointing decouples $t_c$ from the critical path, shifting the dominant term to the expected recovery work. With async checkpointing, you can checkpoint more frequently (every 100-500 steps) with minimal training overhead, reducing expected recovery loss at negligible cost.
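A quick numeric check of the optimal interval $C^* = \sqrt{2 t_c / (\lambda t_s)}$ against brute-force minimization of the overhead model $f(C) = t_c/(C t_s) + \lambda C/2$ (checkpoint cost plus expected recomputation). The inputs are hypothetical: a 60 s synchronous save, 2 s steps, and one failure every 2 hours, i.e. every 3600 steps, so $\lambda = 1/3600$:

```python
import math


def overhead_fraction(c: int, t_c: float, t_s: float, lam: float) -> float:
    """Checkpoint cost per step plus expected recomputation per step, as a fraction of step time."""
    return t_c / (c * t_s) + lam * c / 2


def optimal_interval(t_c: float, t_s: float, lam: float) -> float:
    return math.sqrt(2 * t_c / (lam * t_s))


t_c, t_s, lam = 60.0, 2.0, 1 / 3600
c_star = optimal_interval(t_c, t_s, lam)
best = min(range(1, 10_000), key=lambda c: overhead_fraction(c, t_c, t_s, lam))
print(f"C* = {c_star:.0f} steps, brute force = {best}, "
      f"overhead = {overhead_fraction(best, t_c, t_s, lam):.1%}")
# C* = 465 steps, brute force = 465, overhead = 12.9%
```

With async checkpointing, the blocking part of $t_c$ shrinks to the device-to-host copy, which pulls $C^*$ down and shrinks the overhead along with it.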

Q3: What is TorchElastic and how does it differ from standard torch.distributed?

Standard torch.distributed requires a fixed world size specified at startup. If any rank dies, the job crashes and must be restarted from scratch (or from the last checkpoint with a manual restart). TorchElastic (torch.distributed.elastic, launched via torchrun) adds a rendezvous mechanism that allows the world size to change dynamically within configured bounds. Workers can leave (due to failure or preemption) or join (replacements or scale-up) without resubmitting the job: the elastic agent on each node restarts its worker processes, the surviving workers re-rendezvous with new rank assignments, and training continues. The one step that is not automatic is state: the training script must reload the last checkpoint every time it starts. This is critical for long training runs on spot instances or large clusters with frequent individual node failures.
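Since the elastic agent restarts worker processes rather than resuming them in place, every start of the training script has to rediscover its place in the new world. A minimal sketch, assuming torchrun's standard environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `TORCHELASTIC_RESTART_COUNT`). The global batch of 15360 sequences is a hypothetical choice that divides evenly at both 1024 GPUs (128 nodes × 8) and 960 GPUs (120 nodes × 8), so gradient accumulation can hold it constant when nodes drop out:

```python
import os
from dataclasses import dataclass


@dataclass
class ElasticContext:
    rank: int
    local_rank: int
    world_size: int
    restart_count: int


def elastic_context() -> ElasticContext:
    """Read the rank layout torchrun assigned after the latest rendezvous."""
    env = os.environ
    return ElasticContext(
        rank=int(env["RANK"]),
        local_rank=int(env["LOCAL_RANK"]),
        world_size=int(env["WORLD_SIZE"]),
        restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", "0")),
    )


def grad_accum_steps(global_batch: int, micro_batch: int, world_size: int) -> int:
    """Accumulation steps that keep the global batch fixed at the current world size."""
    per_pass = micro_batch * world_size
    if global_batch % per_pass:
        raise ValueError(f"global batch {global_batch} is not divisible by {per_pass}")
    return global_batch // per_pass
```

At startup, build the context, recompute the accumulation steps, then load the newest checkpoint before the first step. World sizes in between (e.g. 992 GPUs) do not divide 15360 evenly, so a real run needs a policy for them, such as accepting a slightly different global batch.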

Q4: How does NCCL hang detection work and why is the default timeout too conservative for production?

PyTorch's NCCL process-group watchdog monitors collective operations (all-reduce, all-gather, broadcast) and raises an error if any operation exceeds a timeout (default: 10 minutes). The timeout is necessary because if one rank dies or hangs in the middle of a collective, all other ranks will wait indefinitely - they have no way to know the partner rank is dead. The 10-minute default is conservative to avoid false positives on slow networks or large models. In production, this is too long: a 10-minute silent hang wastes enormous compute. Modern production setups reduce this to 2-5 minutes. You can reduce it safely if you profile your all-reduce operations and know they complete in well under 2 minutes on your network. Pass the timeout to torch.distributed.init_process_group(timeout=...) (for example from an NCCL_TIMEOUT value in seconds exported by your launch scripts), and set NCCL_ASYNC_ERROR_HANDLING=1 (TORCH_NCCL_ASYNC_ERROR_HANDLING in newer PyTorch releases) so a timeout aborts the communicators and propagates a clean error.

Q5: What is the difference between synchronous and asynchronous checkpointing, and when would you use each?

Synchronous checkpointing blocks all training until the checkpoint is written to disk. It is simple and safe but adds I/O latency directly to training time. For small models (sub-30B parameters) with fast NVMe storage, this latency is often under 5 seconds and acceptable. Asynchronous checkpointing copies model state to CPU memory (fast: a few seconds for DMA transfer) and writes to disk in a background thread while training continues. This hides I/O latency but requires enough CPU RAM to hold the model state during the write window, and introduces a subtle risk: if the system crashes after the CPU copy but before the disk write completes, the checkpoint is lost. In practice, async checkpointing is preferred for large models (70B+) where synchronous I/O overhead would exceed 1-2% of training time, and where CPU RAM is abundant (DGX systems with 384GB+ RAM).
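The async pattern boils down to snapshot, hand off, write atomically. A minimal, framework-free sketch: `copy.deepcopy` stands in for the GPU-to-CPU copy (in PyTorch that would be copying tensors to CPU, ideally into pinned buffers), `pickle` stands in for the real serializer, and at most one write is in flight so the extra CPU RAM stays bounded to one snapshot. Writing to a temporary file and renaming it means a crash mid-write leaves the previous checkpoint intact rather than a truncated one.

```python
import copy
import os
import pickle
import threading
from typing import Optional


class AsyncCheckpointer:
    """Snapshot state on the caller's thread; serialize it to disk on a background thread."""

    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None
        self._error: Optional[Exception] = None

    def save(self, state: dict, path: str) -> None:
        self.wait()  # one write in flight at a time bounds the extra memory to one snapshot
        snapshot = copy.deepcopy(state)  # training may mutate `state` after this returns
        self._thread = threading.Thread(target=self._write, args=(snapshot, path))
        self._thread.start()

    def _write(self, snapshot: dict, path: str) -> None:
        tmp = path + ".tmp"
        try:
            with open(tmp, "wb") as f:
                pickle.dump(snapshot, f)
                f.flush()
                os.fsync(f.fileno())  # make sure the bytes reach storage before publishing
            os.replace(tmp, path)  # atomic rename: readers see the old or the new file, never a partial one
        except Exception as exc:  # surfaced to the training loop by wait()
            self._error = exc

    def wait(self) -> None:
        """Block until the in-flight write finishes; re-raise any error it hit."""
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        if self._error is not None:
            error, self._error = self._error, None
            raise error
```

Recent PyTorch releases also ship `torch.distributed.checkpoint.async_save`, which applies the same idea to sharded distributed state.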

Q6: How do you handle spot instance preemption on AWS vs GCP, and what are the differences?

AWS Spot instances post a 2-minute interruption notice to the instance metadata service (IMDS) at http://169.254.169.254/latest/meta-data/spot/instance-action, which returns HTTP 404 until a notice exists. You poll this endpoint every 5 seconds in a background thread. When a notice appears, you have approximately 2 minutes to save a checkpoint and shut down cleanly. GCP Preemptible VMs send SIGTERM with a 30-second warning. You install a SIGTERM handler that immediately triggers a synchronous checkpoint save. The key difference is the warning window: AWS's 2 minutes allows async checkpointing to complete, while GCP's 30 seconds requires your checkpoint write to be very fast (either a very small model, fast NVMe, or a partial/gradient-only save). For large models on GCP, some teams maintain a hot CPU-memory shadow of the last checkpoint specifically to enable a fast SIGTERM-triggered save.
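A minimal sketch of both handlers. The AWS side assumes IMDSv2 (a session token fetched with a PUT to `/latest/api/token`, then sent in the `X-aws-ec2-metadata-token` header) and treats HTTP 404 from `spot/instance-action` as "no notice yet". `on_notice` and `on_term` are hypothetical callbacks into your checkpoint code, and the fetch function is injectable so the polling loop can be exercised off-cloud.

```python
import json
import signal
import threading
import urllib.error
import urllib.request
from typing import Callable, Optional

IMDS = "http://169.254.169.254/latest"


def fetch_spot_action(timeout_s: float = 1.0) -> Optional[dict]:
    """Return the spot interruption notice as a dict, or None if none is pending (HTTP 404)."""
    token_req = urllib.request.Request(
        f"{IMDS}/api/token", method="PUT",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": "60"})
    with urllib.request.urlopen(token_req, timeout=timeout_s) as resp:
        token = resp.read().decode()
    req = urllib.request.Request(
        f"{IMDS}/meta-data/spot/instance-action",
        headers={"X-aws-ec2-metadata-token": token})
    try:
        with urllib.request.urlopen(req, timeout=timeout_s) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as err:
        if err.code == 404:
            return None
        raise


def watch_for_preemption(on_notice: Callable[[dict], None], stop: threading.Event,
                         fetch: Callable[[], Optional[dict]] = fetch_spot_action,
                         poll_s: float = 5.0) -> None:
    """Poll until a notice appears (then call on_notice once) or until `stop` is set."""
    while not stop.is_set():
        try:
            notice = fetch()
        except OSError:  # transient IMDS or network errors: keep polling
            notice = None
        if notice is not None:
            on_notice(notice)
            return
        stop.wait(poll_s)


def install_sigterm_handler(on_term: Callable[[], None]) -> None:
    """GCP path: run `on_term` when SIGTERM arrives. Must be called from the main thread."""
    signal.signal(signal.SIGTERM, lambda signum, frame: on_term())
```

Run `watch_for_preemption` in a daemon thread, e.g. `threading.Thread(target=watch_for_preemption, args=(trigger_checkpoint, stop), daemon=True).start()`. Keep `on_term` short, ideally just setting a flag the training loop checks between steps, because Python runs signal handlers on the main thread.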

Q7: Describe a production fault tolerance architecture for a 1024-GPU training run on spot instances.

The complete architecture: (1) Use TorchElastic with --nnodes=120:128 and etcd rendezvous, so the job can continue with up to 8 fewer nodes. (2) Configure NCCL timeout to 3 minutes with async error handling. (3) Implement async checkpointing every 500 steps to local NVMe, with replication to Lustre parallel filesystem every 1000 steps and to S3 every 5000 steps. (4) Run an IMDS-polling thread on each node that triggers emergency synchronous checkpoint on spot interruption notice, then exits with code 0 to signal the job scheduler. (5) Use a watchdog thread per rank that alerts monitoring if no heartbeat in 5 minutes. (6) Configure the job scheduler (Slurm or Kubernetes) to automatically restart the TorchElastic job when workers exit cleanly (code 0), pointing to the latest checkpoint. (7) Monitor checkpoint age, per-rank step time variance, and NCCL timeout events in real time.
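Item (5) can be as small as a background thread that the training loop feeds once per step. A minimal sketch, assuming a hypothetical `on_stall` callback that pages or emits a metric. It can only fire while the main thread is blocked if the blocking call (such as a wait on a hung collective) releases the GIL:

```python
import threading
import time
from typing import Callable


class HeartbeatWatchdog:
    """Call on_stall(seconds_silent) once if beat() is not called within timeout_s."""

    def __init__(self, timeout_s: float, on_stall: Callable[[float], None],
                 check_every_s: float = 1.0) -> None:
        self._timeout_s = timeout_s
        self._on_stall = on_stall
        self._check_every_s = check_every_s
        self._last_beat = time.monotonic()
        self._lock = threading.Lock()
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def start(self) -> "HeartbeatWatchdog":
        self._thread.start()
        return self

    def beat(self) -> None:
        """Call once per training step."""
        with self._lock:
            self._last_beat = time.monotonic()

    def stop(self) -> None:
        self._stop.set()
        self._thread.join()

    def _run(self) -> None:
        while not self._stop.wait(self._check_every_s):
            with self._lock:
                silent = time.monotonic() - self._last_beat
            if silent > self._timeout_s:
                self._on_stall(silent)
                return  # alert once; recovery belongs to the scheduler or an operator
```

For the 5-minute threshold above: `wd = HeartbeatWatchdog(300, alert).start()`, then `wd.beat()` at the end of every step.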


Summary

Fault tolerance in large cluster training is not a safety net - it is a core infrastructure requirement. The arithmetic is clear: at 1000+ GPUs and multi-week training runs, hardware failures happen every few hours. The difference between a team that ships a model and one that burns their budget on repeated restarts is the quality of their fault tolerance infrastructure.

The key components: NCCL watchdog with aggressive timeouts (3-5 minutes, not the default 10), async checkpointing that hides I/O behind training compute, TorchElastic for elastic recovery without full restarts, and spot preemption handlers that turn forced shutdowns into graceful checkpoints. These are not exotic research ideas - they are table stakes for any serious large-scale training effort.

Build the monitoring first: if you cannot see checkpoint age, per-rank timing variance, and NCCL timeout events in real time, you are flying blind. Then automate the recovery: human-in-the-loop recovery at 3 AM on day 28 of a training run is an organizational failure, not a technical one.

© 2026 EngineersOfAI. All rights reserved.