

Training Infrastructure at Scale

The 0.02% Daily Failure Problem

The infrastructure lead was staring at a utilization chart. Their 500-GPU cluster had been running a 70B parameter pretraining job for 11 days. The job was scheduled for 21 days. But it had checkpointed 47 times - much more than the scheduled hourly checkpoint. Investigation revealed: with 500 GPUs in a job, the probability that at least one GPU encounters a hardware fault in a 24-hour period was roughly:

$$P(\text{at least one failure}) = 1 - (1 - 0.0002)^{500} \approx 1 - e^{-0.1} \approx 9.5\%$$

At a 0.02% daily failure rate per GPU, a 500-GPU job had roughly a 10% chance of a hardware failure on any given day. Over a 21-day run, the probability of at least one failure was $1 - (1 - 0.095)^{21} \approx 88\%$. That is close enough to certain that it had to be planned for. Without automatic fault recovery, every hardware failure would mean manually restarting from the last checkpoint, wasting hours of compute and requiring human intervention at 2 AM.
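The arithmetic above can be reproduced in a few lines of Python. This is a minimal sketch that assumes independent failures at a constant per-GPU rate. That is a simplification: it ignores correlated failures, such as a whole node or switch going down. `p_any_failure` and `mean_hours_between_failures` are hypothetical helper names.

```python
def p_any_failure(n_gpus: int, p_per_gpu: float, periods: float = 1.0) -> float:
    """Probability that at least one of n_gpus fails within `periods` periods,
    assuming each GPU fails independently with probability p_per_gpu per period."""
    return 1.0 - (1.0 - p_per_gpu) ** (n_gpus * periods)


def mean_hours_between_failures(n_gpus: int, p_per_gpu_per_day: float) -> float:
    """Expected hours between job-level failures: per-GPU rates add across GPUs."""
    return 24.0 / (n_gpus * p_per_gpu_per_day)


print(f"P(failure in one day):  {p_any_failure(500, 0.0002):.3f}")      # 0.095
print(f"P(failure in 21 days):  {p_any_failure(500, 0.0002, 21):.3f}")  # 0.878
print(f"Hours between failures: {mean_hours_between_failures(500, 0.0002):.0f}")  # 240
```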

The team had not designed for this. They had built a training loop, added checkpoint saving every hour, and called it "fault tolerant." But they had no mechanism to automatically restart from checkpoint after a GPU failure, no elastic training to continue with fewer GPUs while the failed GPU was replaced, and no monitoring to distinguish "job completed" from "job silently hung."

This lesson covers the full production training infrastructure picture: networking, scheduling, fault tolerance, and the operational patterns that keep long training runs alive.


Why Training Infrastructure Is Engineering, Not Configuration

Running a training job on a single machine is configuration: install drivers, install PyTorch, run the script. Running training on a 100+ GPU cluster for weeks is engineering: you are building a distributed system with all the failure modes of distributed systems - network partitions, hardware failures, state consistency after restart, resource scheduling, and monitoring.

Most ML teams discover this gap the hard way: a promising training run fails on day 8 of 14, the team loses 3 days of compute, and the delivery date slips. A team that anticipates the problem builds automatic fault recovery and completes the same run reliably.


Cluster Networking: InfiniBand vs Ethernet

Network bandwidth between GPUs determines how quickly gradients can be synchronized. Two technologies compete:

Ethernet (100GbE/400GbE): Standard data center networking. 100 GbE = 12.5 GB/s per port. RDMA over Converged Ethernet (RoCE) enables low-latency GPU-to-GPU communication over standard switches. Cloud providers also build RDMA-capable fabrics on Ethernet; AWS EFA, for example, uses its own SRD transport rather than RoCE.

InfiniBand (HDR 200Gb, NDR 400Gb): Purpose-built for HPC cluster communication. HDR InfiniBand = 200 Gbps = 25 GB/s per port. NDR = 400 Gbps = 50 GB/s. Lower latency than Ethernet (~1 µs vs 2–5 µs for RDMA Ethernet). Used in on-premises HPC clusters and some cloud services (e.g., Azure's ND-series GPU VMs).

For training a 70B parameter model across 64 GPUs:

  • Gradient size: 70B × 4 bytes (fp32 gradients) = 280 GB
  • Ring all-reduce across 64 GPUs: each GPU sends and receives ~2 × 280 GB × 63/64 ≈ 551 GB
  • Over 100 GbE (12.5 GB/s): ~44 seconds - catastrophically slow
  • Over HDR InfiniBand (25 GB/s): ~22 seconds - still very slow
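  • But these figures treat the all-reduce as a single serial transfer over one link. In practice, DDP buckets gradients and FSDP reduces per wrapped unit, and both overlap that communication with the backward pass. NCCL also keeps intra-node traffic on NVLink, and many GPU nodes (e.g., NVIDIA DGX systems) have one high-speed NIC per GPU. The effective bottleneck is per-node aggregate bandwidth relative to compute time, not the total byte count.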

The practical impact: for 8-GPU single-node training, NVLink (600 GB/s on A100) dominates and the network does not matter. For multi-node training, InfiniBand HDR/NDR or a comparable RDMA fabric such as AWS EFA is required for efficient training at scale.

```python
def estimate_communication_overhead(
    n_parameters: int,
    n_nodes: int,
    gpus_per_node: int = 8,
    nvlink_bandwidth_gbs: float = 600.0,
    inter_node_bandwidth_gbs: float = 25.0,  # HDR InfiniBand per link
) -> dict:
    """
    Rough estimate of all-reduce communication time for DDP training.

    Bandwidth-only model: ignores latency, overlap with the backward pass,
    and nodes with one NIC per GPU, so treat it as an order-of-magnitude guide.
    """
    n_gpus = n_nodes * gpus_per_node

    # Gradient bytes per step (fp16/bf16 gradients: 2 bytes per parameter)
    grad_bytes = n_parameters * 2

    # Intra-node all-reduce (NVLink)
    # Each node does a local reduce before inter-node communication;
    # a ring moves roughly 2x the gradient bytes
    intra_node_time_ms = (grad_bytes / (nvlink_bandwidth_gbs * 1e9)) * 1000 * 2

    # Inter-node all-reduce time (InfiniBand)
    # Ring all-reduce: each node sends/receives ~2 * grad_bytes * (n_nodes-1)/n_nodes
    inter_node_bytes = 2 * grad_bytes * (n_nodes - 1) / n_nodes
    inter_node_time_ms = (inter_node_bytes / (inter_node_bandwidth_gbs * 1e9)) * 1000

    # Assumes the two phases are pipelined, so the slower one dominates;
    # run strictly back to back they would take roughly the sum
    total_comm_ms = max(intra_node_time_ms, inter_node_time_ms)

    return {
        "n_gpus": n_gpus,
        "gradient_gb": round(grad_bytes / 1e9, 2),
        "intra_node_allreduce_ms": round(intra_node_time_ms, 1),
        "inter_node_allreduce_ms": round(inter_node_time_ms, 1),
        "estimated_total_allreduce_ms": round(total_comm_ms, 1),
        "note": (
            "Communication is bottlenecked by inter-node bandwidth"
            if inter_node_time_ms > intra_node_time_ms
            else "Communication is bottlenecked by intra-node NVLink"
        ),
    }
```

NCCL: GPU Collective Operations

NCCL (NVIDIA Collective Communications Library) implements the all-reduce, all-gather, reduce-scatter, and broadcast primitives that distributed training depends on. PyTorch's dist.all_reduce() calls NCCL under the hood.

NCCL automatically selects an algorithm (e.g., ring or tree, plus NVLink SHARP on NVSwitch systems that support it). It bases the choice on the hardware topology it discovers at initialization and on the size of each message.

```python
import os
import time

import torch
import torch.distributed as dist


def profile_nccl_allreduce(tensor_size_gb: float = 1.0, link_peak_gbs: float = 25.0) -> float:
    """
    Measure NCCL all-reduce bandwidth - run this on your actual cluster
    before starting a long training job to baseline communication performance.
    Launch with torchrun so RANK, WORLD_SIZE and LOCAL_RANK are set.
    """
    # Bind each process to its own GPU; otherwise every rank on a node uses cuda:0
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    if not dist.is_initialized():
        dist.init_process_group("nccl")

    rank = dist.get_rank()
    n_bytes = int(tensor_size_gb * 1e9)
    n_floats = n_bytes // 4  # fp32 = 4 bytes per element

    # Create test tensor
    tensor = torch.randn(n_floats, device="cuda")

    # Warm up (the first calls include NCCL connection setup)
    for _ in range(5):
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()

    # Benchmark
    n_trials = 20
    dist.barrier()
    start = time.perf_counter()
    for _ in range(n_trials):
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()

    elapsed = time.perf_counter() - start
    avg_time_s = elapsed / n_trials

    # Bus bandwidth (nccl-tests convention): ring all-reduce sends 2*(N-1)/N * data_size
    # per rank, so this figure is comparable to the per-link hardware peak
    n = dist.get_world_size()
    bus_bytes = 2 * (n - 1) / n * n_bytes
    bus_bandwidth_gbs = bus_bytes / (avg_time_s * 1e9)

    if rank == 0:
        print(f"All-reduce benchmark ({tensor_size_gb:.1f} GB tensor, {n} GPUs):")
        print(f"  Mean time:      {avg_time_s * 1000:.1f} ms")
        print(f"  Bus bandwidth:  {bus_bandwidth_gbs:.1f} GB/s")
        print(f"  Per-link peak:  {link_peak_gbs:.0f} GB/s (e.g., one HDR InfiniBand NIC)")
    return bus_bandwidth_gbs
```

Job Scheduling: Slurm

Slurm is the dominant HPC job scheduler for on-premises GPU clusters. Understanding Slurm is essential for working with institutional computing clusters (universities, national labs, many large companies).

```bash
#!/bin/bash
# Slurm batch script for distributed PyTorch training
#SBATCH --job-name=llm-pretrain
# 8 nodes x 8 GPUs = 64 GPUs total
#SBATCH --nodes=8
# One task per node: torchrun itself spawns the 8 per-GPU worker processes
#SBATCH --ntasks-per-node=1
# Request 8 GPUs per node
#SBATCH --gres=gpu:8
# 8 CPU cores per GPU x 8 GPUs, all owned by the single torchrun task
#SBATCH --cpus-per-task=64
# 480GB RAM per node
#SBATCH --mem=480G
# 7-day time limit
#SBATCH --time=7-00:00:00
#SBATCH --partition=gpu-high
#SBATCH --output=logs/pretrain_%j.out
#SBATCH --error=logs/pretrain_%j.err
#SBATCH --mail-type=FAIL,END

# Load modules (names and versions are cluster-specific)
module load cuda/12.1 nccl/2.18 openmpi/4.1

# Set NCCL environment for InfiniBand
export NCCL_IB_DISABLE=0
export NCCL_IB_HCA=mlx5_0      # InfiniBand HCA device (list every HCA on multi-NIC nodes)
export NCCL_NET_GDR_LEVEL=4    # Use GPUDirect RDMA
export NCCL_DEBUG=WARN

# The first node in the allocation hosts the rendezvous
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR
export MASTER_PORT=29500

echo "Training on $SLURM_NNODES nodes, $((SLURM_NNODES * 8)) GPUs total"
echo "Master: $MASTER_ADDR:$MASTER_PORT"

# Launch one torchrun per node. The c10d rendezvous assigns node ranks, so no
# --node_rank is needed ($SLURM_NODEID would be expanded once by this shell,
# giving every node the same value).
srun torchrun \
    --nproc_per_node=8 \
    --nnodes="$SLURM_NNODES" \
    --rdzv_id="$SLURM_JOB_ID" \
    --rdzv_backend=c10d \
    --rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \
    train.py \
    --model_size 70b \
    --batch_size 4 \
    --gradient_accumulation_steps 16 \
    --checkpoint_dir /scratch/checkpoints/llm-pretrain
```

Fault Tolerance and Automatic Restart

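The critical insight: failure rates add across GPUs. A job on N GPUs therefore sees a hardware failure on average every 24 / (N × p) hours, where p is the per-GPU daily failure probability. At the rate from the opening calculation (p = 0.0002, or 0.02% per day), 500 GPUs fail about once every 24 / (500 × 0.0002) = 240 hours (10 days). At a pessimistic 2% daily rate, the interval shrinks to 24 / (500 × 0.02) = 2.4 hours. Either way, a multi-week run should expect failures, and your training loop must handle them automatically.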

```python
import os
import signal
import time
from pathlib import Path

import torch
import torch.distributed as dist


def _rank() -> int:
    return dist.get_rank() if dist.is_initialized() else 0


class FaultTolerantTrainer:
    """
    Training loop with automatic checkpoint save/restore and
    graceful handling of SIGTERM (preemption) and hardware failures.

    Written for DDP or single-GPU models: rank 0 saves the full state dict.
    FSDP-sharded models need torch.distributed.checkpoint instead.
    """
    def __init__(
        self,
        model,
        optimizer,
        lr_scheduler,
        checkpoint_dir: str,
        checkpoint_interval_steps: int = 500,
        max_epochs: int = 100,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = lr_scheduler
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_interval = checkpoint_interval_steps
        self.max_epochs = max_epochs
        self.global_step = 0
        self.epoch = 0
        self._shutdown_requested = False

        # Register signal handler for graceful preemption
        signal.signal(signal.SIGTERM, self._handle_sigterm)

    def _handle_sigterm(self, signum, frame):
        """Only set a flag: saving inside the handler could capture a half-finished
        optimizer step. The training loop saves at the next step boundary."""
        if _rank() == 0:
            print("SIGTERM received - saving emergency checkpoint after this step")
        self._shutdown_requested = True

    def _unwrapped_model(self):
        # DDP wraps the real model in .module
        return self.model.module if hasattr(self.model, "module") else self.model

    def _save_checkpoint(self, is_emergency: bool = False):
        """Save checkpoint. Only rank 0 saves to avoid race conditions."""
        if _rank() != 0:
            return

        checkpoint = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "model_state_dict": self._unwrapped_model().state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "timestamp": time.time(),
        }

        # Save to temporary file first, then rename atomically.
        # This prevents corrupt checkpoints if the process dies mid-write.
        tmp_path = self.checkpoint_dir / "checkpoint_tmp.pt"
        final_path = self.checkpoint_dir / f"checkpoint_step_{self.global_step}.pt"
        latest_path = self.checkpoint_dir / "checkpoint_latest.pt"

        with open(tmp_path, "wb") as f:
            torch.save(checkpoint, f)
            f.flush()
            os.fsync(f.fileno())  # bytes must be on disk before the rename
        os.replace(tmp_path, final_path)  # atomic within one POSIX filesystem

        # Atomic "latest" pointer: create a new symlink, then rename it over the old
        # one (unlink + recreate would leave a window with no "latest" at all).
        # The relative target keeps the link valid wherever the directory lives.
        tmp_link = self.checkpoint_dir / "checkpoint_latest.tmp"
        tmp_link.unlink(missing_ok=True)
        tmp_link.symlink_to(final_path.name)
        os.replace(tmp_link, latest_path)

        prefix = "EMERGENCY" if is_emergency else "checkpoint"
        print(f"[Rank 0] {prefix} saved at step {self.global_step}: {final_path}")

    def load_latest_checkpoint(self) -> bool:
        """Attempt to restore from latest checkpoint. Returns True if successful."""
        latest_path = self.checkpoint_dir / "checkpoint_latest.pt"
        if not latest_path.exists():
            if _rank() == 0:
                print("No checkpoint found - starting from scratch")
            return False

        checkpoint = torch.load(latest_path, map_location="cpu")
        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]

        # Load model state (unwrap DDP if needed)
        self._unwrapped_model().load_state_dict(checkpoint["model_state_dict"])

        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.scheduler and checkpoint.get("scheduler_state_dict"):
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        if _rank() == 0:
            saved_time = time.strftime(
                "%Y-%m-%d %H:%M:%S",
                time.localtime(checkpoint["timestamp"])
            )
            print(f"Restored from checkpoint at step {self.global_step} (saved {saved_time})")
        return True

    def train(self, data_loader):
        """Main training loop with fault tolerance."""
        steps_per_epoch = len(data_loader)
        # Derive the resume position from global_step, so a checkpoint written on the
        # last batch of an epoch resumes at the start of the next epoch
        start_epoch = self.global_step // steps_per_epoch
        steps_to_skip = self.global_step % steps_per_epoch

        for epoch in range(start_epoch, self.max_epochs):
            self.epoch = epoch
            # Same shuffle order as before the restart (DistributedSampler)
            if hasattr(data_loader.sampler, "set_epoch"):
                data_loader.sampler.set_epoch(epoch)

            for step, batch in enumerate(data_loader):
                if step < steps_to_skip:
                    continue  # fast-forward past already-processed data (still loads it)
                steps_to_skip = 0

                # Standard training step
                self._train_step(batch)
                self.global_step += 1

                # Periodic checkpoint
                if self.global_step % self.checkpoint_interval == 0:
                    self._save_checkpoint()

                # Handle preemption signal at a clean step boundary
                if self._shutdown_requested:
                    self._save_checkpoint(is_emergency=True)
                    if _rank() == 0:
                        print(f"Shutdown requested at step {self.global_step}")
                    return

    def _train_step(self, batch):
        """Single training step - implement with your model logic."""
        raise NotImplementedError
```

Elastic Training

Standard distributed training requires a fixed number of workers. If one node fails, the entire job dies. Elastic training allows the job to continue with fewer workers.

PyTorch Elastic (torchelastic / torchrun) supports this natively:

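```python
# Training script designed for elastic execution. Launch on each node with, e.g.:
#   torchrun --nnodes=4:8 --nproc_per_node=8 --max_restarts=3 \
#       --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 --rdzv_id=$JOB_ID train.py
# --nnodes=MIN:MAX lets the job run with anywhere from 4 to 8 nodes.

import os

import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record


@record  # records uncaught exceptions in an error file that torchrun reports
def main():
    """
    Elastic training: handles node membership changes.
    torchrun restarts all workers with an updated world_size if nodes join/leave.
    """
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    print(f"Elastic training: rank {rank}/{world_size}")

    # Build model, optimizer, scheduler and data_loader here (elided),
    # then restore from checkpoint
    trainer = FaultTolerantTrainer(...)
    trainer.load_latest_checkpoint()

    # torchrun handles rendezvous: when a worker or node fails, the agents stop
    # all workers, re-rendezvous with the nodes still available (at least MIN),
    # and run main() again from the top with the new world_size. Nothing is saved
    # at that moment; progress since the last periodic checkpoint is lost, and
    # data sharding and the global batch must adapt to the new world_size.
    trainer.train(data_loader)


if __name__ == "__main__":
    main()
```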
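One consequence the launcher does not handle for you: when the world size changes, the global batch size changes with it unless the script compensates. The sketch below assumes you want to hold the global batch fixed by adjusting gradient accumulation; rescaling the learning rate instead is another option. `accumulation_steps` is a hypothetical helper. The numbers reuse the Slurm script's micro-batch of 4, 64 GPUs and 16 accumulation steps, which gives a global batch of 4,096 sequences.

```python
def accumulation_steps(global_batch: int, micro_batch: int, world_size: int) -> int:
    """Gradient-accumulation steps that keep the global batch size fixed.

    Raises instead of rounding: silently changing the effective batch size
    after an elastic restart changes the optimization trajectory.
    """
    per_step = micro_batch * world_size
    if global_batch % per_step != 0:
        raise ValueError(
            f"global_batch={global_batch} is not divisible by "
            f"micro_batch * world_size = {per_step}"
        )
    return global_batch // per_step


print(accumulation_steps(4096, 4, 64))  # 8 nodes: 16
print(accumulation_steps(4096, 4, 32))  # shrunk to 4 nodes: 32
```

Dropping to 7 nodes (56 GPUs) cannot preserve 4,096 exactly. In that case you would either restrict the elastic range to node counts that divide the batch evenly or accept a slightly different global batch.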

Monitoring Long Training Runs

A training job running for 7+ days requires active monitoring to detect silent failures: training loss not decreasing, GPU utilization dropping, gradient norms exploding, or memory leaks causing gradual slowdown.

```python
import math
from collections import deque


class TrainingMonitor:
    """Monitor training health and alert on anomalies.

    Expects record_step() every step. Thresholds are illustrative: late in
    pretraining, loss often improves far less than 2% per few hundred steps,
    so tune or disable the stagnation check for your run.
    """

    def __init__(self, alert_fn=None, max_history: int = 1000):
        # Bounded histories keep memory flat over a multi-week run
        self.loss_history = deque(maxlen=max_history)
        self.grad_norm_history = deque(maxlen=max_history)
        self.step_time_history = deque(maxlen=max_history)
        self.alert_fn = alert_fn or print

    def record_step(self, step: int, loss: float, grad_norm: float, step_time_ms: float):
        # NaN/Inf would silently defeat the averaged comparisons below
        if not (math.isfinite(loss) and math.isfinite(grad_norm)):
            self.alert_fn(f"ERROR: non-finite value at step {step}: "
                          f"loss={loss}, grad_norm={grad_norm}")

        self.loss_history.append((step, loss))
        self.grad_norm_history.append((step, grad_norm))
        self.step_time_history.append((step, step_time_ms))

        # Check for anomalies every 100 steps, once 200 steps of history exist
        if step % 100 == 0 and len(self.loss_history) >= 200:
            self._check_for_anomalies(step, loss, grad_norm, step_time_ms)

    def _check_for_anomalies(self, step, loss, grad_norm, step_time_ms):
        # deques do not support slicing, so copy to lists first
        losses = [l for _, l in self.loss_history]
        times = [t for _, t in self.step_time_history]

        avg_recent = sum(losses[-50:]) / 50
        avg_older = sum(losses[-200:-100]) / 100

        # Loss not decreasing
        if avg_recent >= avg_older * 0.98:  # less than 2% improvement
            self.alert_fn(f"WARNING: Loss stagnation at step {step}. "
                          f"Recent avg: {avg_recent:.4f}, older: {avg_older:.4f}")

        # Gradient explosion (absolute threshold; depends on model and clipping)
        if grad_norm > 10.0:
            self.alert_fn(f"WARNING: Gradient norm spike at step {step}: {grad_norm:.2f}")

        # Training slowdown (memory leak or thermal throttling)
        avg_recent_time = sum(times[-20:]) / 20
        avg_older_time = sum(times[-100:-20]) / 80

        if avg_recent_time > avg_older_time * 1.15:  # 15% slowdown
            self.alert_fn(f"WARNING: Step time increased 15%+ at step {step}. "
                          f"Recent: {avg_recent_time:.0f}ms, older: {avg_older_time:.0f}ms. "
                          "Check for memory leak or thermal throttling.")
```

Production Engineering Notes

Use NCCL_DEBUG=WARN during training, not INFO. INFO logging is verbose on large jobs: every rank logs its topology and connection setup, and enabling subsystems such as COLL via NCCL_DEBUG_SUBSYS logs every collective call. The result is bloated logs and extra I/O overhead. Keep WARN level in production; switch to INFO only when debugging communication issues.

Checkpoint to fast local NVMe, then async-copy to object storage. A 70B model's bf16 weights alone are 140 GB, and with mixed-precision Adam state the full checkpoint is several times larger. Saving that synchronously to NFS or S3 stalls all GPUs for the duration of the save. Save first to local NVMe, which is much faster, then spawn a background process to upload to S3. The training job continues while the upload happens.
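Below is a minimal sketch of the save-locally-then-upload pattern, using only the standard library. Two substitutions keep it short. First, a directory stands in for object storage. Second, the copy runs in a background thread rather than a separate process; file I/O releases Python's GIL, so the training thread keeps running. `upload_async` is a hypothetical helper. With S3, you would call your upload client inside `_copy` instead of `shutil.copy2`.

```python
import os
import shutil
import threading
from pathlib import Path


def upload_async(local_path: str, remote_dir: str) -> threading.Thread:
    """Copy a finished local checkpoint to durable storage in the background.

    `remote_dir` stands in for object storage (hypothetical setup). The copy is
    written under a temporary name and renamed at the end, so readers of
    remote_dir never see a half-uploaded checkpoint.
    """
    def _copy() -> None:
        dest_dir = Path(remote_dir)
        dest_dir.mkdir(parents=True, exist_ok=True)
        dest = dest_dir / Path(local_path).name
        partial = dest.with_name(dest.name + ".partial")
        shutil.copy2(local_path, partial)  # slow write happens off the training thread
        os.replace(partial, dest)          # publish only the complete file

    thread = threading.Thread(target=_copy, name=f"upload-{Path(local_path).name}")
    thread.start()
    return thread  # join() before deleting the local copy or exiting
```

Call it right after the local save returns. Keep the local file until the returned thread has finished, since it is also your fastest restore source.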

Set training timeout alarms. If no new checkpoint appears within 2× the expected checkpoint interval, page the on-call engineer. A hung NCCL collective (a GPU stuck in all-reduce with no progress) can freeze the entire job silently.
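The checkpoint-staleness alarm can be sketched as a small check run from cron or a sidecar process. It assumes the `checkpoint_step_*.pt` naming used by `FaultTolerantTrainer` above. `is_stalled` is a hypothetical helper, and the paging itself is left to whatever alerting hook you already use.

```python
import time
from pathlib import Path
from typing import Optional


def is_stalled(
    checkpoint_dir: str,
    expected_interval_s: float,
    job_start_time: float,
    factor: float = 2.0,
    now: Optional[float] = None,
) -> bool:
    """True if no checkpoint has appeared for `factor` x the expected interval.

    Progress is measured from the newest checkpoint_step_*.pt file, or from the
    job start if none exists yet (a job can also hang before its first save).
    """
    now = time.time() if now is None else now
    mtimes = [p.stat().st_mtime for p in Path(checkpoint_dir).glob("checkpoint_step_*.pt")]
    last_progress = max(mtimes + [job_start_time])
    return now - last_progress > factor * expected_interval_s
```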


Common Mistakes

:::danger Not testing checkpoint restore before starting a multi-week job Many teams only discover their checkpoint/restore logic is broken when they need to use it - after a GPU failure on day 8. Before launching any training run longer than 4 hours, test: save a checkpoint, kill the job, restore from checkpoint, verify training continues correctly with no loss spike. This 30-minute test prevents countless 2 AM incidents. :::
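Here is a toy version of that save, kill, restore test, sketched under the assumption that you use PyTorch. It trains a tiny `nn.Linear` with SGD momentum and checkpoints model and optimizer state. It then rebuilds fresh objects, as a restarted process would, and checks that the next step is bit-identical. If the optimizer state were left out of the checkpoint, the check would fail, because the momentum buffer would be lost. A real test should also cover the LR scheduler, data-loader position and RNG state, which this toy omits. `check_resume_matches` is a hypothetical name.

```python
import os
import tempfile

import torch


def check_resume_matches(steps_before: int = 3) -> bool:
    """Train, checkpoint, restore into fresh objects, and compare the next step."""
    torch.manual_seed(0)
    x, y = torch.randn(8, 4), torch.randn(8, 2)

    def make():
        model = torch.nn.Linear(4, 2)
        return model, torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    def step(model, opt) -> float:
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        return loss.item()

    model, opt = make()
    for _ in range(steps_before):
        step(model, opt)

    path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
    torch.save({"model": model.state_dict(), "optimizer": opt.state_dict()}, path)

    # Simulate a restart: brand-new objects whose state comes only from the file
    restored, restored_opt = make()
    ckpt = torch.load(path, map_location="cpu")
    restored.load_state_dict(ckpt["model"])
    restored_opt.load_state_dict(ckpt["optimizer"])

    same_loss = step(model, opt) == step(restored, restored_opt)
    same_params = all(torch.equal(a, b) for a, b in zip(model.parameters(), restored.parameters()))
    return same_loss and same_params


print(check_resume_matches())  # True
```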

:::warning Using network storage (NFS/GPFS) for frequent checkpoints Saving checkpoints over a network filesystem adds 5–60 seconds to each save, during which all training GPUs are paused. For a job saving every 30 minutes, this can waste 30+ minutes of GPU time per day. Always save to local NVMe SSD first, then copy to permanent storage asynchronously. :::

:::tip Add "canary" validation runs to detect training instability early Before committing to a 7-day training run, run 500 steps and check: does loss decrease smoothly? Is gradient norm stable? Is GPU utilization above 70%? Running 500 steps takes under 30 minutes and validates your training loop, data pipeline, and cluster configuration before you invest weeks of GPU time. :::
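Those three questions can be turned into an automatic gate. This is a sketch under stated assumptions: losses, gradient norms and GPU utilization (as a 0–1 fraction) have already been collected for the canary steps. The 70% utilization floor comes from the tip above. The first-vs-last-quarter loss comparison and the 5× max-to-median gradient-norm ratio are illustrative choices, not established thresholds. `canary_report` is a hypothetical helper.

```python
import math
from statistics import mean, median
from typing import List


def canary_report(
    losses: List[float],
    grad_norms: List[float],
    gpu_utils: List[float],
    min_util: float = 0.70,
    max_norm_ratio: float = 5.0,
) -> List[str]:
    """Return a list of problems found in a short canary run (empty list = pass)."""
    if len(losses) < 4 or not grad_norms or not gpu_utils:
        raise ValueError("need at least 4 losses plus grad-norm and utilization samples")
    if not all(math.isfinite(v) for v in losses + grad_norms):
        return ["non-finite loss or gradient norm"]

    problems = []
    # Does loss decrease? Compare the first and last quarter of the run
    quarter = len(losses) // 4
    first, last = mean(losses[:quarter]), mean(losses[-quarter:])
    if last >= first:
        problems.append(f"loss did not decrease: {first:.3f} -> {last:.3f}")

    # Is the gradient norm stable? Flag any spike far above the median
    typical = median(grad_norms)
    if max(grad_norms) > max_norm_ratio * typical:
        problems.append(f"gradient norm spike: max {max(grad_norms):.2f} vs median {typical:.2f}")

    # Is the GPU kept busy?
    util = mean(gpu_utils)
    if util < min_util:
        problems.append(f"mean GPU utilization {util:.0%} below {min_util:.0%}")
    return problems
```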


Interview Questions

Q1: Why does GPU cluster training fail so frequently, and how do you engineer for it?

With 500 GPUs at ~0.02% hourly failure rate per GPU, expected time to first failure = 1 / (500 × 0.0002) = 10 hours, so a multi-week run is virtually guaranteed to hit multiple hardware failures. Engineering for this requires: (1) frequent atomic checkpointing to local NVMe + async backup to object storage, (2) automatic restart from latest checkpoint via the job scheduler, (3) elastic training to continue with fewer GPUs while a failed node is replaced, (4) monitoring with alerts for silent hangs (no new checkpoint after 2× the expected checkpoint interval). Without these, every hardware failure costs hours of GPU time and requires human intervention.

Q2: What is the difference between InfiniBand and Ethernet for GPU clusters, and when does it matter?

InfiniBand (HDR/NDR) provides 200–400 Gbps bandwidth with ~1 µs latency and native RDMA. Ethernet with RoCE provides 100–400 Gbps with 2–5 µs latency. For training, what matters is the effective bandwidth each GPU achieves during all-reduce. InfiniBand's lower latency and mature RDMA stack typically give better all-reduce efficiency than Ethernet at the same line rate. The size of the gap depends heavily on fabric tuning, and it is largest for small, latency-bound messages. For small models or single-node training, this is irrelevant. For multi-node training of 10B+ parameter models, communication can take 20–40% of step time, and there InfiniBand can meaningfully reduce training time.
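The standard alpha-beta cost model for a ring all-reduce shows why latency matters. Each of the 2(n − 1) steps pays a per-message latency α, and each link moves 2(n − 1)/n of the data S at bandwidth B, giving $T = 2(n-1)\alpha + \frac{2(n-1)}{n}\frac{S}{B}$. This is a textbook approximation, not how NCCL actually schedules traffic (NCCL pipelines chunks and may pick a tree instead). The latencies plugged in below are the ~1 µs and ~5 µs figures from the answer, and `ring_allreduce_seconds` is a hypothetical helper.

```python
def ring_allreduce_seconds(
    size_bytes: float, n_ranks: int, bandwidth_bytes_per_s: float, latency_s: float
) -> float:
    """Alpha-beta estimate: T = 2(n-1)*alpha + 2(n-1)/n * S/B."""
    steps = 2 * (n_ranks - 1)
    return steps * latency_s + (steps / n_ranks) * size_bytes / bandwidth_bytes_per_s


# Same 25 GB/s links and 64 ranks; only the per-message latency differs
for label, alpha in [("InfiniBand, ~1 us", 1e-6), ("RDMA Ethernet, ~5 us", 5e-6)]:
    small = ring_allreduce_seconds(1e6, 64, 25e9, alpha)    # one 1 MB gradient bucket
    large = ring_allreduce_seconds(280e9, 64, 25e9, alpha)  # fp32 gradients of a 70B model
    print(f"{label}: 1 MB bucket {small * 1e6:.0f} us, 280 GB {large:.2f} s")
```

For a 1 MB bucket, the latency term dominates (roughly 0.2 ms vs 0.7 ms). For the full 280 GB, both fabrics land at about 22 s. Latency therefore matters most when gradients travel in many small buckets.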

Q3: Explain the atomic checkpoint save pattern and why it matters.

Writing a 140 GB checkpoint file takes 30–60 seconds even to fast NVMe. If the process dies halfway through writing, you are left with a corrupt checkpoint that cannot be restored. Atomic checkpointing: write to checkpoint_tmp.pt, flush and fsync it, then os.rename() it to checkpoint_step_N.pt. Within a single POSIX filesystem, the rename is atomic: it either succeeds completely or does not happen. Then update the checkpoint_latest.pt symlink the same way: create a new symlink under a temporary name and rename it over the old one. Deleting and recreating the link would leave a window in which no "latest" exists. On restart, the training job reads checkpoint_latest.pt. It points either to the most recent complete checkpoint or, if the last write was interrupted, to the previous one.

Q4: Your training loss suddenly spikes at step 8,432 after running smoothly for 2 days. What are the likely causes and how do you investigate?

Likely causes: (1) corrupted gradient (NaN/Inf) propagating through the model - check gradient norm history for a spike at that exact step; (2) bad batch in the training data - log which data shard and batch index was being processed at step 8,432; (3) learning rate scheduler misconfiguration causing an unexpectedly large step; (4) checkpoint restore issue if there was a restart near that step - verify the checkpoint was clean. Investigation: restore from the last good checkpoint (before the spike), reproduce the spike on a single node with debug logging enabled, print gradient norms per layer for the problematic batch to localize where the NaN originated. If it is a data issue, mark that batch as corrupted and skip it.
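The per-layer gradient check can be sketched as two helpers (hypothetical names). One collects gradient norms from a PyTorch module after `loss.backward()`. The other lists the non-finite entries and the largest finite ones. The sketch assumes parameters are registered in forward order, as in a typical stack of transformer blocks. A NaN that first appears during the backward pass spreads only toward the input, so the non-finite parameter closest to the output is the first suspect. If every gradient is non-finite, the NaN more likely arose in the forward pass, and forward hooks on activations are the next step.

```python
import math
from typing import Dict, List, Optional, Tuple


def grad_norms_by_param(model) -> Dict[str, float]:
    """Per-parameter gradient L2 norms of a PyTorch module; call after loss.backward()."""
    return {
        name: p.grad.detach().float().norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }


def localize_bad_gradients(
    norms: Dict[str, float], top_k: int = 5
) -> Tuple[Optional[str], List[str], List[Tuple[str, float]]]:
    """Return (first suspect, all non-finite names in order, top_k largest finite norms)."""
    bad = [name for name, v in norms.items() if not math.isfinite(v)]
    # Backward-pass NaNs spread toward the input, so the non-finite entry
    # closest to the output (last in registration order) is the first suspect
    suspect = bad[-1] if bad else None
    finite = sorted(((v, name) for name, v in norms.items() if math.isfinite(v)), reverse=True)
    return suspect, bad, [(name, v) for v, name in finite[:top_k]]
```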

Q5: How do you handle a training job where one node runs consistently slower than others (straggler problem)?

First, profile the slow node: run nvidia-smi dmon -s puc to check power, temperature, utilization, and clock speeds; check CPU utilization, memory bandwidth, and network interface statistics. Common causes: (1) thermal throttling - one GPU or CPU running too hot, throttling clock speeds; (2) a hardware fault causing a GPU to run slower; (3) a degraded network interface (partial link failure giving half-speed). Mitigation: for thermal issues, improve cooling or cap power limits on that node. For hardware faults, migrate the job away from that node and file a hardware ticket. For persistent stragglers in DDP (where the slowest GPU determines each step's duration), set the timeout argument of dist.init_process_group to detect and fail fast when a node stalls. Then restart with that node excluded from the elastic training group.
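A straggler check can be sketched by gathering a per-rank timing and comparing it with the median. One subtlety: in synchronous DDP every rank finishes the step at roughly the same time, so full step times hide the straggler. Measure something the collectives do not equalize instead, for example time spent in data loading or in the forward pass. The 1.2× threshold and both helper names are illustrative assumptions; `dist.all_gather_object` is a standard torch.distributed call.

```python
from statistics import median
from typing import List


def find_stragglers(times_ms: List[float], threshold: float = 1.2) -> List[int]:
    """Ranks whose timing exceeds `threshold` x the median across all ranks."""
    typical = median(times_ms)
    return [rank for rank, t in enumerate(times_ms) if t > threshold * typical]


def gather_times(local_ms: float) -> List[float]:
    """Collect one timing per rank; every rank must call this (it is a collective)."""
    import torch.distributed as dist  # lazy import so find_stragglers needs no torch

    gathered: List[float] = [0.0] * dist.get_world_size()
    dist.all_gather_object(gathered, local_ms)
    return gathered
```

Every few hundred steps, all ranks can call `gather_times` and rank 0 can log `find_stragglers` on the result. A rank that shows up repeatedly is the candidate for the profiling steps above.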

© 2026 EngineersOfAI. All rights reserved.