Training Jobs on Kubernetes
The 6-Hour Training Job That Died at Hour 6
Your team just spent three weeks engineering a dataset pipeline and designing a new transformer-based fraud detection architecture. The training job is estimated at 7 hours on 4 A100 GPUs. You launch it Tuesday evening and go home. At 11:47pm, your phone buzzes: a Slack notification from the CI bot, "Training job fraud-transformer-v3-run12 has failed."
The job ran on spot (preemptible) nodes to save cost - 40% cheaper than on-demand. Six hours and four minutes into the run, GCP reclaimed one of the four nodes. The PyTorch process on that node received SIGTERM with a 30-second warning. Your training code has no checkpoint logic, no graceful shutdown handler, and no restart mechanism. The entire 6-hour run is gone. At 11:47pm the surviving nodes sit idle and the job status reads Failed.
This is not an edge case. It is the standard experience of running ML training on cloud infrastructure without fault-tolerant design. Kubernetes provides the Job primitive and the Training Operator provides PyTorchJob and TFJob - but neither of them makes your training code automatically fault-tolerant. You have to design for failure explicitly. This lesson shows you how.
Kubernetes Jobs - The Foundation
A Kubernetes Job runs a Pod to completion (exit code 0). Unlike a Deployment, which restarts pods that fail, a Job runs a fixed number of completions and then finishes. This is the right model for training: run the training script, complete, exit.
# Simple training job
apiVersion: batch/v1
kind: Job
metadata:
  name: fraud-model-train-v3
  namespace: team-fraud
  labels:
    app: fraud-training
    version: "v3"
    experiment: "transformer-arch"
spec:
  backoffLimit: 2                  # retry up to 2 times on failure
  activeDeadlineSeconds: 28800     # kill job after 8 hours (safety net)
  ttlSecondsAfterFinished: 86400   # auto-delete job 24h after completion
  template:
    spec:
      restartPolicy: Never         # for Jobs: Never or OnFailure
      containers:
        - name: trainer
          image: registry.company.com/fraud-trainer:v3.0
          command:
            - python
            - train.py
            - --epochs=50
            - --batch-size=256
            - --learning-rate=3e-4
            - --checkpoint-dir=/checkpoints
          resources:
            requests:
              cpu: "16"
              memory: "128Gi"
              nvidia.com/gpu: 4
            limits:
              cpu: "16"
              memory: "128Gi"
              nvidia.com/gpu: 4
          volumeMounts:
            - name: checkpoints
              mountPath: /checkpoints
            - name: training-data
              mountPath: /data
      volumes:
        - name: checkpoints
          persistentVolumeClaim:
            claimName: fraud-training-checkpoints
        - name: training-data
          persistentVolumeClaim:
            claimName: fraud-training-data-pvc
restartPolicy: Never vs restartPolicy: OnFailure:
Never: if the container exits non-zero, a new pod is created (up to backoffLimit times). The failed pods are kept, so you can inspect their logs after the job fails. Use this for debugging.
OnFailure: the container is restarted in place inside the same pod. Logs from earlier attempts are mostly lost (only the immediately preceding container's logs remain, via kubectl logs --previous). Use this only if your training code resumes from checkpoint on restart.
Checkpoint-Driven Fault Tolerance
The fundamental pattern for fault-tolerant ML training: checkpoint frequently, resume from checkpoint on restart. Without this, any pod failure means restarting from epoch 0.
# train.py - fault-tolerant training loop
# FraudTransformer, run_epoch, and train_loader are assumed to be defined elsewhere in the project.
import os
import signal
import sys
from pathlib import Path

import torch

CHECKPOINT_DIR = Path(os.environ.get("CHECKPOINT_DIR", "/checkpoints"))
CHECKPOINT_INTERVAL = int(os.environ.get("CHECKPOINT_INTERVAL_EPOCHS", "5"))

checkpoint_requested = False


def handle_sigterm(signum, frame):
    """Spot node reclamation sends SIGTERM 30 seconds before kill."""
    global checkpoint_requested
    print("SIGTERM received - scheduling emergency checkpoint before exit")
    checkpoint_requested = True


signal.signal(signal.SIGTERM, handle_sigterm)


def save_checkpoint(model, optimizer, scheduler, epoch, loss, path):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
    }
    tmp_path = path.with_suffix(".tmp")
    torch.save(checkpoint, tmp_path)
    tmp_path.replace(path)  # atomic rename on the same filesystem - prevents corrupt checkpoints
    print(f"Checkpoint saved: epoch {epoch}, loss {loss:.4f}")


def load_checkpoint(model, optimizer, scheduler):
    checkpoint_path = CHECKPOINT_DIR / "latest.pt"
    if not checkpoint_path.exists():
        return 0  # start from epoch 0
    print(f"Resuming from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cuda")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resuming from epoch {start_epoch}")
    return start_epoch


def train():
    # Move the model to the GPU before building the optimizer, so restored
    # optimizer state ends up on the same device as the parameters.
    model = FraudTransformer().cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    start_epoch = load_checkpoint(model, optimizer, scheduler)

    for epoch in range(start_epoch, 50):
        # Training loop. If one epoch takes longer than the 30-second grace
        # period, check checkpoint_requested inside the batch loop as well.
        train_loss = run_epoch(model, optimizer, train_loader)
        scheduler.step()

        # Regular checkpoint, or emergency checkpoint on SIGTERM. Both write
        # latest.pt, because that is the only file load_checkpoint reads.
        if epoch % CHECKPOINT_INTERVAL == 0 or checkpoint_requested:
            save_checkpoint(
                model, optimizer, scheduler, epoch, train_loss,
                CHECKPOINT_DIR / "latest.pt"
            )

        if checkpoint_requested:
            print("Emergency checkpoint triggered by SIGTERM")
            # Exit non-zero: exit code 0 would mark the Job as succeeded and
            # nothing would resume. A failed pod is retried up to backoffLimit.
            sys.exit(1)

    # Final model save
    torch.save(model.state_dict(), CHECKPOINT_DIR / "model_final.pt")
    print("Training complete.")


if __name__ == "__main__":
    train()
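When a single epoch runs longer than the 30-second preemption window, an epoch-level check of the SIGTERM flag comes too late. The sketch below shows the same two ideas (flag-based signal handling and atomic checkpoint writes) at step granularity, using only the standard library so it runs without a GPU. The JSON state, the function names, and exit code 143 are illustrative assumptions, not part of the lesson's train.py; the preemption is simulated by calling the handler directly.

```python
import json
import os
import signal
import tempfile
from pathlib import Path

stop_requested = False


def request_stop(signum, frame):
    """Only record the signal; the training loop decides when to save."""
    global stop_requested
    stop_requested = True


def save_state_atomically(state: dict, path: Path) -> None:
    """Write a temp file in the same directory, flush it to disk, then swap it in."""
    tmp_path = path.with_name(path.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)  # readers see the old file or the new one, never a partial write


def run_steps(total_steps: int, path: Path, preempt_at_step=None) -> int:
    """Run a stand-in training loop and return the exit code the process should use."""
    state = json.loads(path.read_text()) if path.exists() else {"step": 0}
    for step in range(state["step"], total_steps):
        # ... one optimizer step would run here ...
        if preempt_at_step is not None and step == preempt_at_step:
            request_stop(signal.SIGTERM, None)  # simulate the preemption notice
        state["step"] = step + 1
        if stop_requested:
            save_state_atomically(state, path)
            return 143  # non-zero, so the Job controller schedules a retry
    save_state_atomically(state, path)
    return 0


signal.signal(signal.SIGTERM, request_stop)
checkpoint_file = Path(tempfile.mkdtemp()) / "latest.json"

# First attempt is "preempted" during step 3 and saves before exiting.
first_exit = run_steps(10, checkpoint_file, preempt_at_step=3)
saved_step = json.loads(checkpoint_file.read_text())["step"]

# A real retry is a new process, so the flag starts cleared.
stop_requested = False
second_exit = run_steps(10, checkpoint_file)
final_step = json.loads(checkpoint_file.read_text())["step"]

print(first_exit, saved_step, second_exit, final_step)  # 143 4 0 10
```

The handler itself does no I/O; it only flips a flag, and the loop saves at the next safe point. In a real loop the step counter would be stored alongside the model and optimizer state so the retry can skip the batches it already processed.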
CronJobs - Scheduled Retraining
For production models that need periodic retraining (daily on new data, weekly full retrain), use CronJobs:
apiVersion: batch/v1
kind: CronJob
metadata:
  name: fraud-model-weekly-retrain
  namespace: team-fraud
spec:
  schedule: "0 2 * * 0"             # every Sunday at 2am UTC
  concurrencyPolicy: Forbid         # don't start new job if previous still running
  successfulJobsHistoryLimit: 3     # keep last 3 successful job records
  failedJobsHistoryLimit: 3         # keep last 3 failed job records
  jobTemplate:
    spec:
      backoffLimit: 1
      activeDeadlineSeconds: 43200  # 12 hour max
      template:
        spec:
          restartPolicy: Never
          containers:
            - name: trainer
              image: registry.company.com/fraud-trainer:latest-stable
              command:
                # Kubernetes does not run commands through a shell, so the
                # $(date ...) substitution needs an explicit sh -c wrapper.
                # date -d "30 days ago" assumes GNU date is in the image.
                - /bin/sh
                - -c
                - >-
                  exec python retrain.py
                  --data-start-date="$(date -d '30 days ago' +%Y-%m-%d)"
                  --output-dir=/model-output
              resources:
                requests:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
                limits:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
concurrencyPolicy: Forbid is important for training CronJobs. If a training run takes longer than expected (data pipeline issues, larger dataset), you don't want the next scheduled run to start alongside it, competing for the same GPU quota.
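The 30-day lookback passed to retrain.py in the CronJob above can also be computed inside the script, which avoids depending on shell behavior in the container command. The following is a minimal sketch of that option; the --lookback-days flag and the helper names are hypothetical, and only --data-start-date and --output-dir come from the manifest.

```python
import argparse
from datetime import date, timedelta


def default_start_date(today: date, lookback_days: int = 30) -> str:
    """Return the ISO date `lookback_days` before `today`."""
    return (today - timedelta(days=lookback_days)).isoformat()


def parse_args(argv=None, today=None):
    """Parse retrain.py arguments; derive the start date when it is not given."""
    today = today or date.today()
    parser = argparse.ArgumentParser()
    parser.add_argument("--lookback-days", type=int, default=30)  # hypothetical flag
    parser.add_argument("--data-start-date", default=None)
    parser.add_argument("--output-dir", default="/model-output")
    args = parser.parse_args(argv)
    if args.data_start_date is None:
        args.data_start_date = default_start_date(today, args.lookback_days)
    return args


# Fixed "today" so the example is reproducible.
args = parse_args(["--output-dir=/model-output"], today=date(2024, 3, 31))
print(args.data_start_date)  # 2024-03-01
```

An explicit --data-start-date still wins over the computed default, which keeps manual backfills possible.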
PyTorchJob - Distributed Training with the Training Operator
The Kubeflow Training Operator provides custom resource definitions (CRDs) for distributed ML training. PyTorchJob manages distributed PyTorch training launched with torchrun, handling pod creation, rendezvous configuration, and failure detection automatically.
apiVersion: kubeflow.org/v1
kind: PyTorchJob
metadata:
  name: fraud-transformer-v3
  namespace: team-fraud
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure   # restart this pod if it crashes
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              operator: Exists
              effect: NoSchedule
          containers:
            - name: pytorch
              image: registry.company.com/fraud-trainer:v3.0
              command:
                - torchrun
                - --nproc_per_node=4      # GPUs per node
                - --nnodes=4              # total nodes
                - --node_rank=$(RANK)     # set by Training Operator
                - --master_addr=$(MASTER_ADDR)
                - --master_port=23456
                - train_distributed.py
                - --checkpoint-dir=/checkpoints
              resources:
                requests:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
                limits:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
              volumeMounts:
                - name: checkpoints
                  mountPath: /checkpoints
          volumes:
            - name: checkpoints
              persistentVolumeClaim:
                claimName: fraud-training-checkpoints-rwx   # ReadWriteMany
    Worker:
      replicas: 3
      restartPolicy: OnFailure
      template:
        spec:
          tolerations:
            - key: nvidia.com/gpu
              operator: Exists
              effect: NoSchedule
          containers:
            - name: pytorch
              image: registry.company.com/fraud-trainer:v3.0
              command:
                - torchrun
                - --nproc_per_node=4
                - --nnodes=4
                - --node_rank=$(RANK)
                - --master_addr=$(MASTER_ADDR)
                - --master_port=23456
                - train_distributed.py
                - --checkpoint-dir=/checkpoints
              resources:
                requests:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
                limits:
                  nvidia.com/gpu: 4
                  memory: "128Gi"
              volumeMounts:
                - name: checkpoints
                  mountPath: /checkpoints
          volumes:
            - name: checkpoints
              persistentVolumeClaim:
                claimName: fraud-training-checkpoints-rwx
The Training Operator injects environment variables into each pod:
MASTER_ADDR: DNS name of the master pod
MASTER_PORT: port for rendezvous (default 23456)
RANK: global rank of this pod (0 for master, 1, 2, 3 for workers)
WORLD_SIZE: total number of processes (4 nodes × 4 GPUs = 16)
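The relationship between the pod-level RANK above and the per-process ranks that training code sees can be checked with a little arithmetic. This sketch assumes a fixed-size job with the same number of GPUs on every node, where each process's global rank is its node rank times the processes per node plus its local rank; the function names are illustrative, and elastic rendezvous may assign node ranks differently.

```python
def world_size(nnodes: int, nproc_per_node: int) -> int:
    """Total number of training processes across the job."""
    return nnodes * nproc_per_node


def global_rank(node_rank: int, local_rank: int, nproc_per_node: int) -> int:
    """Global rank of one process, assuming homogeneous nodes and static node ranks."""
    return node_rank * nproc_per_node + local_rank


NNODES = 4          # 1 master pod + 3 worker pods
NPROC_PER_NODE = 4  # one process per GPU

# Enumerate every (node, GPU) pair and the global rank it maps to.
ranks = {
    (node, gpu): global_rank(node, gpu, NPROC_PER_NODE)
    for node in range(NNODES)
    for gpu in range(NPROC_PER_NODE)
}

print(world_size(NNODES, NPROC_PER_NODE))  # 16
print(ranks[(0, 0)], ranks[(3, 3)])        # 0 15
```

Rank 0 is therefore the first GPU on the master pod, which is the process that writes checkpoints in the training code.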
Distributed Training Code Pattern
# train_distributed.py
# FraudTransformer, dataset, run_epoch, and the distributed variants of
# load_checkpoint / save_checkpoint are assumed to be defined elsewhere.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


def setup():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun
    torch.cuda.set_device(local_rank)
    return local_rank


def train():
    local_rank = setup()
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    model = FraudTransformer().cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

    # Restores state and returns a start epoch that all ranks agree on
    start_epoch = load_checkpoint(model, optimizer, rank)

    # Use DistributedSampler to partition data across workers
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=256)  # batch size per process

    for epoch in range(start_epoch, 50):
        sampler.set_epoch(epoch)  # ensures different shuffle each epoch
        train_loss = run_epoch(model, optimizer, loader)

        # Only rank 0 saves checkpoints
        if rank == 0 and epoch % 5 == 0:
            save_checkpoint(model, optimizer, epoch, train_loss)

    dist.destroy_process_group()


if __name__ == "__main__":
    train()
Fault Tolerance - What Happens When a Worker Pod Dies
With restartPolicy: OnFailure, if a worker pod crashes, Kubernetes restarts it. torchrun's elastic mode (a node range such as --nnodes=2:4) lets the surviving workers re-form the group and continue with fewer nodes, then absorb the restarted worker when it comes back.
For non-elastic training (fixed nnodes), when any worker pod dies, all workers stop, waiting for rendezvous, and the PyTorchJob controller restarts the pods. Training then resumes from the last checkpoint, which is why the 5-epoch checkpoint interval is critical.
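How much work a restart costs follows directly from the checkpoint rule. The sketch below is a rough model, assuming zero-indexed epochs, a checkpoint written after every epoch where epoch % interval == 0 (as in the training loop in this lesson), and epochs of equal duration; the function name is illustrative.

```python
def resume_point(failed_during_epoch: int, interval: int):
    """Return (resume_epoch, completed_epochs_lost) for a pod that dies mid-epoch."""
    last_completed = failed_during_epoch - 1
    if last_completed < 0:
        return 0, 0  # nothing finished yet, nothing lost
    last_saved = (last_completed // interval) * interval  # latest epoch with a checkpoint
    resume_epoch = last_saved + 1
    return resume_epoch, failed_during_epoch - resume_epoch


# Opening scenario: 50 epochs estimated at 7 hours (420 minutes),
# node reclaimed 6h04m (364 minutes) into the run.
failed_epoch = 364 * 50 // 420
print(failed_epoch)                   # 43
print(resume_point(failed_epoch, 5))  # (41, 2)

# Worst case for a 5-epoch interval: dying just before the next checkpoint.
print(resume_point(45, 5))            # (41, 4)
```

Under these assumptions the job in the opening story would have repeated two completed epochs plus the partial one, instead of all 43.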
:::tip TorchElastic for Spot Node Resilience
Use torchrun with elastic training parameters for spot-node workloads:
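command:
  - torchrun
  - --nproc_per_node=4
  - --nnodes=2:4                # elastic range: continue with as few as 2 nodes, up to 4
  - --max_restarts=3            # allow worker restarts after failures (value is illustrative)
  - --rdzv_backend=c10d         # K8s-native rendezvous; etcd also works if you run an etcd service
  - --rdzv_endpoint=$(MASTER_ADDR):$(MASTER_PORT)
  - --rdzv_id=fraud-transformer-v3
  - train_distributed.py
When a spot node is reclaimed, torchrun restarts the worker processes on the remaining 3 nodes with the smaller world size, and they resume from the latest checkpoint with 3/4 of the compute. When a new node is provisioned, it joins the elastic group. Training throughput varies, but the job keeps running instead of failing.
:::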
Spot Node Handling with Node Lifecycle Labels
# Pod spec for spot-tolerant training
spec:
  tolerations:
    - key: cloud.google.com/gke-spot                # GKE spot
      operator: Exists
      effect: NoSchedule
    - key: kubernetes.azure.com/scalesetpriority    # AKS spot
      operator: Equal
      value: "spot"
      effect: NoSchedule
  nodeSelector:
    cloud.google.com/gke-spot: "true"               # target spot nodes for cost savings
Production Notes
Job naming for reproducibility: Include experiment parameters in the Job name. fraud-transformer-lr3e4-bs256-epoch50 is more useful than fraud-training-job-12 when reviewing cluster history.
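Names like this are easier to keep consistent when they are generated rather than typed. The helper below is a hypothetical sketch, not an existing tool: it joins a prefix with key/value pairs, replaces characters that are not lowercase alphanumerics or hyphens, and truncates to 63 characters, which is assumed here as a safe length for names that also end up in labels.

```python
import re


def job_name(prefix: str, **params) -> str:
    """Build a lowercase, hyphen-separated Job name from experiment parameters."""
    parts = [prefix] + [f"{key}{value}" for key, value in params.items()]
    name = "-".join(parts).lower()
    name = re.sub(r"[^a-z0-9-]+", "-", name)  # e.g. "." and "_" become "-"
    name = re.sub(r"-+", "-", name).strip("-")
    return name[:63].rstrip("-")              # assumed 63-character limit


print(job_name("fraud-transformer", lr="3e4", bs=256, epoch=50))
# fraud-transformer-lr3e4-bs256-epoch50
```

Passing the learning rate as the string "3e4" mirrors the naming used in this lesson; a float such as 0.0003 would be rendered as lr0-0003.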
Resource cleanup: Use ttlSecondsAfterFinished: 86400 on all Jobs. Without it, completed/failed Jobs accumulate in the cluster and clutter kubectl get jobs output. 24 hours gives the team time to inspect logs before automatic cleanup.
Job parallelism for hyperparameter search: The Job spec.parallelism field runs multiple pods concurrently for the same Job. This is useful for embarrassingly parallel hyperparameter grid searches:
spec:
  parallelism: 8            # run 8 experiments simultaneously
  completions: 16           # total of 16 experiments (8 at a time)
  completionMode: Indexed   # each pod gets a unique index (JOB_COMPLETION_INDEX env var)
Each pod reads its JOB_COMPLETION_INDEX to pick a different hyperparameter configuration from a pre-defined grid.
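One way to map the index to a configuration is to enumerate the grid in a fixed order and index into it. The grid values below are made up for illustration; the only thing taken from Kubernetes is the JOB_COMPLETION_INDEX environment variable, and the 4 × 2 × 2 grid is chosen to match completions: 16.

```python
import itertools
import os

# Hypothetical search space: 4 * 2 * 2 = 16 combinations, one per completion index.
LEARNING_RATES = [1e-4, 3e-4, 1e-3, 3e-3]
BATCH_SIZES = [128, 256]
DROPOUTS = [0.1, 0.3]

# itertools.product yields a deterministic order, so every pod builds the same list.
GRID = list(itertools.product(LEARNING_RATES, BATCH_SIZES, DROPOUTS))


def config_for_index(index: int) -> dict:
    """Return the hyperparameter set assigned to one completion index."""
    if not 0 <= index < len(GRID):
        raise IndexError(f"index {index} is outside the {len(GRID)}-entry grid")
    learning_rate, batch_size, dropout = GRID[index]
    return {"learning_rate": learning_rate, "batch_size": batch_size, "dropout": dropout}


# Inside the pod the index comes from the environment; default to 0 for local runs.
index = int(os.environ.get("JOB_COMPLETION_INDEX", "0"))
config = config_for_index(index)
print(index, config)
```

Because the grid is rebuilt identically in every pod, no coordination between pods is needed, and a retried pod with the same index reruns the same configuration.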
Common Mistakes
:::danger No Checkpoint Logic in Long Training Jobs
Running a 6+ hour training job without checkpoint save/load logic is the most expensive mistake in ML on Kubernetes. Spot node preemption, pod eviction under node resource pressure or priority preemption, node hardware failures, and Out-of-Memory kills all terminate pods unexpectedly. Without checkpoints, every failure means restarting from epoch 0. Implement checkpoint save every 5–10 epochs minimum. On a 50-epoch job with 5-epoch checkpoints, the worst-case restart loses 5 epochs of work, not 50.
:::
:::warning Using restartPolicy: Always for Training Jobs
restartPolicy: Always is for Deployments (long-running services). For Jobs, use Never or OnFailure; the batch/v1 Job API rejects Always at validation time. Where a controller does accept Always (for example, in a PyTorchJob replica spec), a pod that exits successfully is restarted, which can rerun training in a loop. Use OnFailure if you want the container restarted in place on crashes, and Never if you want the Job controller to create a fresh pod for each retry, up to backoffLimit.
:::
:::warning Shared PVC With ReadWriteOnce for Multi-Node Training
A ReadWriteOnce PVC can be mounted read-write by a single node only. If you use one for checkpoint storage in a distributed training job with pods on multiple nodes, only pods on the first node to attach the volume can mount it. Pods on every other node fail to attach it and never start (typically stuck in ContainerCreating or Pending). Use ReadWriteMany (NFS, AWS EFS, GCP Filestore) for any storage shared across multiple training pods.
:::
Interview Q&A
Q1: What is the difference between a Kubernetes Job and a Deployment, and when would you use each for ML?
A Deployment runs pods indefinitely, restarting them when they fail, and is designed for stateless long-running services like model serving APIs. A Job runs pods until a certain number of completions are reached (pods exit with code 0), then the Job is considered complete. Use Deployments for ML serving (the model server should always be running). Use Jobs for ML training (run training once, complete, done). For periodic retraining, wrap a Job in a CronJob to trigger it on a schedule. The key behavioral difference: Deployments restart failed pods indefinitely; Jobs retry up to backoffLimit times then fail permanently.
Q2: A distributed PyTorchJob running on 4 nodes fails 6 hours into training when one spot node is reclaimed. How do you design the system to recover automatically?
Three layers of defense: (1) Checkpoint every N epochs using atomic file writes to a ReadWriteMany PVC. Only rank 0 writes. On restart, all workers load the checkpoint and resume from the last completed epoch. (2) Handle SIGTERM in the training loop - spot nodes send SIGTERM 30 seconds before termination. Save an emergency checkpoint to the PVC in that window, then exit with a non-zero code so the controller treats the pod as failed and restarts it; exiting with code 0 would mark the run as successfully completed and nothing would resume. (3) Use torchrun's elastic mode with --nnodes=2:4 - if a node is lost, the remaining workers re-form the group and continue training, and the replacement node joins automatically. With all three layers, a job that loses a node at hour 6 resumes from its most recent checkpoint, losing at most one checkpoint interval of compute (5 epochs with 5-epoch checkpoints), and less if the emergency checkpoint completes in time.
Q3: How does the Kubeflow Training Operator's PyTorchJob differ from a raw Kubernetes Job for distributed training?
A raw Kubernetes Job can run multiple pods in parallel but provides no mechanism for distributed training coordination. You would have to manually manage rendezvous (how workers find each other), set MASTER_ADDR/MASTER_PORT environment variables, ensure all pods start simultaneously for rendezvous, and handle pod failure and restart coordination. PyTorchJob (Training Operator) automates all of this: it creates a headless Service for worker discovery, injects MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE environment variables automatically, provides separate restart policies for master and worker pods, and integrates with elastic training for fault-tolerant rendezvous. It also provides a higher-level status (PyTorchJobSucceeded, PyTorchJobFailed) that's more meaningful than raw pod states for observability.
Q4: Walk through the checkpoint save and load pattern for fault-tolerant distributed PyTorch training.
On save (called periodically from rank 0 only): gather model.module.state_dict() (unwrapping the DDP wrapper), the optimizer state_dict, the scheduler state_dict, and the current epoch into a dict. Write it to a temporary file first (checkpoint.tmp), then atomically rename it to checkpoint.pt. The atomic rename prevents a restarting worker from reading a partially written checkpoint. On load (called at startup by all ranks): each rank reads checkpoint.pt from the shared volume, mapping tensors to its own GPU, and restores model, optimizer, and scheduler state, so every rank starts with identical optimizer state. Rank 0 then broadcasts start_epoch to all ranks via dist.broadcast, and each rank returns that value to its training loop. The broadcast handles the case where some workers find the checkpoint and others don't (e.g., they were on a different node with a separate cache): every rank starts from the epoch rank 0 reports.
Q5: What is the ttlSecondsAfterFinished field on a Kubernetes Job and why is it important for ML training pipelines?
ttlSecondsAfterFinished specifies how long Kubernetes keeps the Job and its pods after they complete or fail. Without it, all completed/failed training Jobs accumulate in the cluster indefinitely. Over time, a busy ML cluster can have hundreds or thousands of dead Job objects, making kubectl get jobs useless and consuming etcd storage. Setting ttlSecondsAfterFinished: 86400 (24 hours) gives your team enough time to inspect logs and results from recent runs, then automatically cleans up. For critical training runs where you need the Job object for post-mortem analysis, use a longer TTL (7 days) or export logs to a persistent log aggregation system (Loki, Elasticsearch) before the TTL expires.
