
Training Cost Optimization

From $50K to $18K Without Losing Accuracy

The production training pipeline for the e-commerce product ranking model was running on a fixed monthly schedule: every two weeks, a full training run on a fleet of 8 A100 GPUs, 72 hours of runtime, at AWS on-demand pricing. Cost per run: $37,748. Two runs per month: $75,496/month, $905,952/year.

A new ML platform engineer joined the team and asked a question nobody had asked before: "Has anyone ever measured whether we actually need 72 hours? And do we need to use on-demand pricing?"

Nobody had asked because the pipeline had been running this way for 18 months. It worked. It was automated. It was not broken. But "not broken" is not the same as "efficient."

Three weeks of analysis and testing produced a redesigned training pipeline. The same model, trained to the same target evaluation AUC of 0.893, now cost $18,200 per run - a 52% reduction. The changes:

  1. Spot instances with checkpoint-and-restart: $37,748 → $13,512 per run (on-demand to spot pricing)
  2. Mixed precision (BF16): Training time reduced from 72h to 48h, cutting the spot cost by a further 33%
  3. Gradient checkpointing: Memory savings allowed a 25% larger batch size, further reducing runtime
  4. Early stopping: 4 of the 24 runs per year (runs that historically didn't improve the model) were terminated at 60% runtime

The annual cost dropped from $905,952 to $436,800 - savings of $469,152/year with no accuracy loss.

This lesson shows you how to replicate this analysis.



Why This Exists: Training Runs Are Expensive By Default

Training is often the most expensive phase of the ML lifecycle. A single large model training run can cost $50,000–$5,000,000. Even a small model trained nightly at $500/run costs about $180,000/year. Teams that train multiple model variants for experiments multiply these costs quickly.

The default configuration of a training pipeline - on-demand instances, single precision, no early stopping, every activation kept in memory for the backward pass - is optimized for engineering convenience and simplicity, not cost. Each default has a cost-optimized alternative that requires some engineering investment but pays back rapidly.

The training cost optimization problem has three levers:

  1. Price per hour: How much does each GPU-hour cost? (Spot vs. reserved vs. on-demand)
  2. Hours per run: How long does training take? (Mixed precision, batch size, efficient data loading)
  3. Runs to convergence: How many runs are needed to reach target performance? (Compute-optimal training, hyperparameter efficiency)

Optimizing any one of these levers reduces total cost. Optimizing all three achieves multiplicative reductions.
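Because total cost is the product of the three levers, their savings compound rather than add. A minimal sketch; the function names are hypothetical, and the factors in the example (a 60% spot discount, runs one-third shorter, 10% fewer runs, a $100/hour price) are illustrative assumptions rather than the case-study figures:

```python
def training_cost(price_per_hour: float, hours_per_run: float, runs: int) -> float:
    """Total training cost: price per hour x hours per run x number of runs."""
    return price_per_hour * hours_per_run * runs


def remaining_cost_fraction(price_factor: float, hours_factor: float, runs_factor: float) -> float:
    """Fraction of the original cost left after scaling each lever by its factor."""
    return price_factor * hours_factor * runs_factor


baseline = training_cost(price_per_hour=100.0, hours_per_run=72, runs=24)

# Spot at 40% of the on-demand price, 2/3 of the hours, 90% of the runs
fraction = remaining_cost_fraction(0.40, 2 / 3, 0.90)
optimized = baseline * fraction
print(f"Baseline: ${baseline:,.0f}  Optimized: ${optimized:,.0f}  Reduction: {1 - fraction:.0%}")
```

Adding the individual savings (60% + 33% + 10%) would wrongly suggest more than a 100% reduction; multiplying the remaining fractions (0.40 × 2/3 × 0.90 = 0.24) gives the actual 76% reduction.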


Historical Context

Training cost optimization became a serious engineering concern with the rise of large language models. The observation that training GPT-3 cost approximately $4.6M (Lambda Labs estimate, 2020) focused attention on whether training runs were economically efficient.

Mixed precision training was introduced by Micikevicius et al. (NVIDIA, 2018) and became practical with native hardware support on V100 and A100 GPUs. FP16 training using Tensor Cores can achieve 2–8× throughput improvement over FP32 on compatible hardware.

BFloat16 (BF16) was introduced with Google's TPUs and later supported on NVIDIA A100s. BF16 has the same dynamic range as FP32 (8 exponent bits) but lower precision (7 mantissa bits vs. 23), making it more numerically stable than FP16 for training without requiring manual loss scaling.

Gradient checkpointing (also called activation checkpointing) was formalized by Chen et al. (2016) as a memory-compute trade-off: recompute activations during the backward pass instead of storing them, reducing activation memory from O(n) to O(√n) for an n-layer network at the cost of roughly one extra forward pass (~30% more compute).

The Chinchilla paper (Hoffmann et al., DeepMind, 2022) demonstrated empirically that most large language models were under-trained - they had too many parameters for their training token budget. The optimal allocation follows a simple rule: scale training tokens proportionally to parameters.


Core Concepts

Spot Instances with Fault-Tolerant Training

Spot (AWS) and Preemptible (GCP) instances offer 60–80% discounts over on-demand pricing in exchange for the possibility of interruption at short notice: two minutes on AWS, 30 seconds on GCP. For most training workloads, this is an acceptable trade-off.

The requirement: the training job must be able to resume from a checkpoint when interrupted. Without checkpoint-and-restart, a spot interruption loses all progress and the training run must start over - potentially turning a 60% savings into a net loss.

```python
import os
import signal
import sys
from pathlib import Path

import torch


class CheckpointedTrainer:
    """
    Training loop with checkpoint-and-restart support for spot instances.
    Saves checkpoints periodically and resumes automatically on restart.
    """
    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        checkpoint_dir: str,
        checkpoint_interval_steps: int = 500,
        max_checkpoints_to_keep: int = 3,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_interval = checkpoint_interval_steps
        self.max_checkpoints = max_checkpoints_to_keep

        self.global_step = 0
        self.epoch = 0
        self.best_val_metric = float("-inf")

        # SIGTERM is the usual signal a process receives when its instance or
        # container is being shut down, e.g. after a spot interruption notice
        signal.signal(signal.SIGTERM, self._handle_sigterm)

    def _handle_sigterm(self, signum, frame):
        """Save checkpoint on spot interruption signal."""
        print("Received SIGTERM - spot interruption detected. Saving checkpoint...")
        self._save_checkpoint(is_emergency=True)
        sys.exit(0)

    def _save_checkpoint(self, is_emergency: bool = False):
        checkpoint = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_metric": self.best_val_metric,
        }

        if is_emergency:
            path = self.checkpoint_dir / "emergency_checkpoint.pt"
        else:
            path = self.checkpoint_dir / f"checkpoint_step_{self.global_step:08d}.pt"

        # Write to a temporary file, then rename atomically, so an interruption
        # during the write never leaves a truncated checkpoint behind
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
        print(f"Checkpoint saved: {path}")

        # Clean up old periodic checkpoints
        if not is_emergency:
            self._cleanup_old_checkpoints()

    def _periodic_checkpoints(self):
        """Periodic checkpoints, oldest first (sorted by step number)."""
        return sorted(
            self.checkpoint_dir.glob("checkpoint_step_*.pt"),
            key=lambda p: int(p.stem.split("_")[-1]),
        )

    def _cleanup_old_checkpoints(self):
        for old_ckpt in self._periodic_checkpoints()[:-self.max_checkpoints]:
            old_ckpt.unlink()

    def load_latest_checkpoint(self) -> bool:
        """Load the most recent checkpoint if one exists. Returns True if loaded."""
        candidates = []
        periodic = self._periodic_checkpoints()
        if periodic:
            candidates.append(torch.load(periodic[-1], map_location="cpu"))
        emergency = self.checkpoint_dir / "emergency_checkpoint.pt"
        if emergency.exists():
            candidates.append(torch.load(emergency, map_location="cpu"))
        if not candidates:
            return False

        # An emergency checkpoint from an earlier interruption can be older than
        # the latest periodic one, so resume from whichever has the higher step
        checkpoint = max(candidates, key=lambda c: c["global_step"])
        print(f"Resuming from step {checkpoint['global_step']}")

        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]
        self.best_val_metric = checkpoint["best_val_metric"]
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        return True

    def train_step(self, batch) -> float:
        self.model.train()
        loss = self.model(**batch).loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

        # Periodic checkpoint: bounds how much work an interruption can destroy
        if self.global_step % self.checkpoint_interval == 0:
            self._save_checkpoint()

        return loss.item()
```
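How much an interruption costs depends on how much work it destroys. A rough cost model, not a measurement: `expected_spot_cost` is a hypothetical helper, and it assumes that with checkpointing each interruption repeats half a checkpoint interval on average plus a fixed restart overhead, while without checkpointing each interruption lands mid-run and discards half the run. The prices are normalized placeholders.

```python
from typing import Optional


def expected_spot_cost(
    run_hours: float,
    on_demand_price_per_hour: float,
    spot_discount: float,
    interruptions: int,
    checkpoint_interval_hours: Optional[float],
    restart_overhead_hours: float = 0.25,
) -> float:
    """Rough spot-run cost including work repeated after interruptions.

    Assumptions (illustrative): with checkpointing, each interruption repeats
    half a checkpoint interval on average plus a fixed restart overhead.
    Without checkpointing (checkpoint_interval_hours=None), each interruption
    is assumed to land mid-run and discard run_hours / 2 of progress.
    """
    if checkpoint_interval_hours is None:
        lost_per_interruption = run_hours / 2 + restart_overhead_hours
    else:
        lost_per_interruption = checkpoint_interval_hours / 2 + restart_overhead_hours
    billed_hours = run_hours + interruptions * lost_per_interruption
    return billed_hours * on_demand_price_per_hour * (1 - spot_discount)


on_demand = 48 * 1.0  # a 48-hour run at a normalized $1/hour on-demand price
with_ckpt = expected_spot_cost(48, 1.0, 0.60, interruptions=3, checkpoint_interval_hours=0.5)
without_ckpt = expected_spot_cost(48, 1.0, 0.60, interruptions=3, checkpoint_interval_hours=None)
print(f"on-demand: {on_demand:.1f}  spot+checkpoints: {with_ckpt:.1f}  "
      f"spot, no checkpoints: {without_ckpt:.1f}")
```

Under these assumptions, three interruptions cost the checkpointed run about half an hour each, while the run without checkpoints ends up slightly more expensive than simply paying on-demand.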

Mixed Precision Training: FP16 and BF16

Mixed precision training keeps model weights in FP32 for numerical stability during gradient updates but runs most of the forward and backward computation (matrix multiplications, convolutions) in FP16 or BF16. Modern GPUs have Tensor Core hardware that executes these reduced-precision matrix multiplications 4–8× faster than FP32 equivalents; V100s support FP16 only, while A100s and H100s support both FP16 and BF16.

FP16 vs. BF16 trade-off:

  • FP16 has 10 mantissa bits (higher precision) but only 5 exponent bits (smaller dynamic range → overflow/underflow risk → requires loss scaling)
  • BF16 has 7 mantissa bits (lower precision) but 8 exponent bits (same dynamic range as FP32 → numerically stable, no loss scaling needed)
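The bit counts above translate directly into range and precision. A standard-library sketch (the helper name is hypothetical) that derives each format's largest finite value, smallest normal value, and machine epsilon from its exponent and mantissa widths, assuming the usual IEEE-754 layout, which BF16 also follows. FP16 tops out at 65,504, while BF16 reaches about 3.4 × 10^38 like FP32, at the price of a much coarser epsilon (2^-7 vs. 2^-10).

```python
def float_format_stats(exponent_bits: int, mantissa_bits: int) -> dict:
    """Range and precision of an IEEE-754-style binary floating-point format."""
    bias = 2 ** (exponent_bits - 1) - 1
    max_exponent = bias  # the all-ones exponent is reserved for inf/NaN
    return {
        # largest finite value: all mantissa bits set, largest usable exponent
        "max_finite": (2 - 2 ** -mantissa_bits) * 2.0 ** max_exponent,
        # smallest normal value; anything below becomes subnormal or zero
        "min_normal": 2.0 ** (1 - bias),
        # gap between 1.0 and the next representable value
        "epsilon": 2.0 ** -mantissa_bits,
    }


for name, e_bits, m_bits in [("FP32", 8, 23), ("FP16", 5, 10), ("BF16", 8, 7)]:
    s = float_format_stats(e_bits, m_bits)
    print(f"{name}: max={s['max_finite']:.3e}  "
          f"min_normal={s['min_normal']:.3e}  eps={s['epsilon']:.1e}")
```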
```python
import time

import torch


class MixedPrecisionTrainer:
    """
    Training loop with automatic mixed precision (AMP).
    Supports both FP16 (with loss scaling) and BF16 (no scaling needed).
    """
    def __init__(self, model, optimizer, dtype: str = "bfloat16"):
        self.model = model
        self.optimizer = optimizer
        self.dtype = getattr(torch, dtype)

        # Loss scaler only needed for FP16 (torch.amp.GradScaler requires
        # PyTorch 2.3+; older releases use torch.cuda.amp.GradScaler)
        self.use_scaler = (dtype == "float16")
        self.scaler = torch.amp.GradScaler("cuda") if self.use_scaler else None

        # Per-step wall-clock times for benchmarking
        self._step_times = []

    def train_step(self, batch) -> float:
        start = time.perf_counter()

        self.optimizer.zero_grad()

        # Forward pass under autocast: eligible ops run in reduced precision
        # while the weights themselves stay in FP32
        with torch.autocast(device_type="cuda", dtype=self.dtype):
            output = self.model(**batch)
            loss = output.loss

        # Backward pass (outside autocast)
        if self.use_scaler:
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)  # unscale before clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            # BF16: no scaling needed
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

        # .item() waits for the GPU to finish this step, so the timing measures
        # real step time rather than just kernel-launch time
        loss_value = loss.item()
        self._step_times.append(time.perf_counter() - start)

        return loss_value

    def throughput_report(self) -> dict:
        if not self._step_times:
            return {}
        avg_step_time = sum(self._step_times) / len(self._step_times)
        return {
            "avg_step_time_ms": avg_step_time * 1000,
            "steps_per_second": 1 / avg_step_time,
        }
```
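Why the FP16 path needs the scaler: gradient values below FP16's smallest subnormal (2^-24, about 6e-8) round to zero. A standard-library sketch that emulates FP16 rounding with Python's `struct` half-precision format (`"e"`); the helper name and the gradient value are illustrative, and the scale of 2**16 matches GradScaler's default initial scale. Multiplying before the FP16 cast and dividing afterward in full precision is the same idea GradScaler applies when it scales the loss and later unscales the gradients.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8          # a small gradient value (illustrative)
scale = 2.0 ** 16    # loss scale; GradScaler's default initial scale

lost = to_fp16(grad)                        # underflows to 0.0 in FP16
recovered = to_fp16(grad * scale) / scale   # survives in FP16, unscaled in FP64
print(f"without scaling: {lost}  with scaling: {recovered:.4e}")
```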

Gradient Checkpointing

Gradient checkpointing trades compute for memory by recomputing activations during the backward pass instead of storing them. This allows training with larger batch sizes on the same hardware, which can improve convergence and reduce total training time.

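For a sequential model with n layers, storing every activation takes O(n) memory; keeping only about √n checkpointed activations and recomputing the rest during the backward pass brings this down to O(√n). For a 24-layer transformer (√24 ≈ 4.9), that is a reduction on the order of 5×. The compute overhead is approximately 30%, because each training step runs one additional forward pass.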

```python
import torch
from torch.utils.checkpoint import checkpoint


class TransformerLayer(torch.nn.Module):
    """Minimal stand-in layer so the example is self-contained.
    The config fields hidden_size and num_heads are illustrative assumptions."""
    def __init__(self, config):
        super().__init__()
        self.block = torch.nn.TransformerEncoderLayer(
            d_model=config.hidden_size, nhead=config.num_heads, batch_first=True
        )

    def forward(self, hidden_states, attention_mask=None):
        # attention_mask: (batch, seq_len) bool tensor, True marks padding
        return self.block(hidden_states, src_key_padding_mask=attention_mask)


class MemoryEfficientTransformer(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [TransformerLayer(config) for _ in range(config.num_layers)]
        )
        self.use_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True
        print("Gradient checkpointing enabled: less activation memory, ~30% more compute")

    def forward(self, hidden_states, attention_mask=None):
        for layer in self.layers:
            if self.use_gradient_checkpointing and self.training:
                # Don't keep this layer's internal activations; recompute them
                # during the backward pass instead
                hidden_states = checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    use_reentrant=False,  # recommended for modern PyTorch
                )
            else:
                hidden_states = layer(hidden_states, attention_mask)
        return hidden_states


# In practice, use Hugging Face's built-in gradient checkpointing
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
model.gradient_checkpointing_enable()  # one line - enables it for the full model

# Memory comparison (approximate, BERT-base 12 layers, seq_len=512, batch_size=32)
# Without checkpointing: ~24 GB activation memory
# With checkpointing: ~5 GB activation memory
# Allows: larger batch size (32 → 64 → 128) → fewer steps → lower cost
```

Compute-Optimal Training: Chinchilla Scaling

The Chinchilla paper (Hoffmann et al., DeepMind, 2022) established that there is an optimal allocation of compute between model parameters and training tokens. Their key finding: for every doubling of model parameters, you should double the training tokens. Most models at the time of the paper were significantly under-trained.

The practical implication for cost optimization: if you are training a large model with a limited compute budget, you may get better performance by training a smaller model for longer than by training the larger model briefly.

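A widely used approximation of the Chinchilla result combines the standard estimate of training compute, $C \approx 6ND$ FLOPs, with a ratio of roughly 20 training tokens per parameter:

$$
C \approx 6ND, \qquad D_{\text{opt}} \approx 20\,N_{\text{opt}} \quad\Longrightarrow\quad N_{\text{opt}} \approx \sqrt{\frac{C}{120}}, \qquad D_{\text{opt}} \approx \sqrt{\frac{10\,C}{3}}
$$

where $N$ is the number of parameters, $D$ is the number of training tokens, and $C$ is the total compute budget in FLOPs. Both optimal quantities grow as the square root of the compute budget.

```python
import math

# Pricing/throughput assumptions carried over from the example (illustrative only)
P4D_SPOT_PRICE_PER_HOUR = 9.83    # us-east-1 p4d.24xlarge (8x A100) spot, $/instance-hour
GPUS_PER_INSTANCE = 8
A100_BF16_PEAK_FLOPS = 312e12     # dense BF16 Tensor Core peak
ACHIEVED_FRACTION_OF_PEAK = 0.30  # ~30% of peak achieved in practice


def chinchilla_optimal_allocation(compute_flops: float) -> dict:
    """
    Compute-optimal model size and training token count for a given
    compute budget, using C ≈ 6·N·D with D ≈ 20·N.

    compute_flops: total FLOPs budget (~1.4e22 is roughly $50K under the
    pricing assumptions above)
    Returns: optimal parameters and training tokens, plus a cost estimate
    """
    optimal_params = math.sqrt(compute_flops / 120)
    optimal_tokens = 20 * optimal_params

    # GPU-hours needed at the assumed effective throughput per A100
    effective_flops_per_gpu = A100_BF16_PEAK_FLOPS * ACHIEVED_FRACTION_OF_PEAK
    a100_hours = compute_flops / effective_flops_per_gpu / 3600

    # The spot price is per 8-GPU instance, so convert GPU-hours to instance-hours
    spot_cost_usd = a100_hours / GPUS_PER_INSTANCE * P4D_SPOT_PRICE_PER_HOUR

    return {
        "compute_flops": compute_flops,
        "optimal_params_billions": optimal_params / 1e9,
        "optimal_tokens_billions": optimal_tokens / 1e9,
        "estimated_a100_hours": a100_hours,
        "estimated_spot_cost_usd": spot_cost_usd,
        "compute_efficiency_note": (
            "Train a model this size for this many tokens to maximize performance/cost"
        ),
    }


# Example: $50K compute budget -> instance-hours -> FLOPs
budget_usd = 50_000
instance_hours = budget_usd / P4D_SPOT_PRICE_PER_HOUR
budget_flops = (instance_hours * 3600 * GPUS_PER_INSTANCE
                * A100_BF16_PEAK_FLOPS * ACHIEVED_FRACTION_OF_PEAK)
result = chinchilla_optimal_allocation(budget_flops)
print(f"Optimal model size: {result['optimal_params_billions']:.1f}B parameters")
print(f"Optimal training: {result['optimal_tokens_billions']:.0f}B tokens")
```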

Efficient Data Loading

Slow data loading can starve the GPU: it sits idle waiting for the next batch while you are still billed for every GPU-hour.

```python
import os
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


def create_efficient_dataloader(
    dataset: Dataset,
    batch_size: int,
    num_workers: Optional[int] = None,  # default: 4 per GPU, at least 4, capped at CPU count
    prefetch_factor: int = 4,           # batches to prefetch per worker
    pin_memory: bool = True,            # page-locked memory for faster CPU→GPU transfer
) -> DataLoader:
    """
    Create a DataLoader optimized for GPU training throughput.
    """
    if num_workers is None:
        gpu_count = torch.cuda.device_count()
        num_workers = min(max(4, gpu_count * 4), os.cpu_count() or 1)

    use_workers = num_workers > 0
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        # prefetch_factor and persistent_workers are only valid with worker processes
        prefetch_factor=prefetch_factor if use_workers else None,
        pin_memory=pin_memory and torch.cuda.is_available(),
        drop_last=True,                  # avoid a small, inefficient final batch
        persistent_workers=use_workers,  # keep workers alive between epochs
    )
```
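A quick way to check whether the input pipeline is the bottleneck is to measure how long the training loop blocks waiting for the next batch. A minimal sketch; `LoaderWaitTimer` is a hypothetical helper, and `slow_batches` is a fake loader standing in for a real `DataLoader`, which you would wrap the same way (`for batch in LoaderWaitTimer(dataloader): ...`). Because CUDA kernels run asynchronously, the measured fraction is a rough signal rather than a precise GPU-idle measurement.

```python
import time
from typing import Any, Iterable, Iterator


class LoaderWaitTimer:
    """Wraps an iterable of batches and accumulates time spent blocked in next()."""

    def __init__(self, batches: Iterable[Any]):
        self._iterator: Iterator[Any] = iter(batches)
        self.wait_seconds = 0.0
        self.batches_seen = 0

    def __iter__(self) -> "LoaderWaitTimer":
        return self

    def __next__(self) -> Any:
        start = time.perf_counter()
        batch = next(self._iterator)  # StopIteration propagates and ends the loop
        self.wait_seconds += time.perf_counter() - start
        self.batches_seen += 1
        return batch


def slow_batches(n: int, delay_s: float):
    """Fake loader that takes delay_s seconds to produce each batch."""
    for i in range(n):
        time.sleep(delay_s)
        yield i


timed = LoaderWaitTimer(slow_batches(5, 0.01))
loop_start = time.perf_counter()
for batch in timed:
    pass  # the training step would run here
total = time.perf_counter() - loop_start
print(f"Waited on data for {timed.wait_seconds / total:.0%} of the loop")
```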

Distributed Training Overhead

Distributed training (across multiple GPUs or nodes) introduces communication overhead. The efficiency of a distributed training run is measured by the scaling efficiency: how close the throughput (samples/second) is to linear with the number of GPUs.

A 4-GPU training run with 85% scaling efficiency trains at 3.4× the speed of a single GPU, not 4×. The 15% overhead comes from gradient synchronization (all-reduce operations) between GPUs.

```python
def estimate_distributed_efficiency(
    single_gpu_throughput: float,  # samples/second on 1 GPU
    n_gpus: int,
    inter_gpu_bandwidth_gb_per_s: float = 600,  # A100 NVLink: ~600 GB/s; PCIe 4.0 x16: ~64 GB/s
    model_params_billion: float = 1.0,
    batch_size_per_gpu: int = 32,
) -> dict:
    """
    Estimate data-parallel training efficiency and overhead.
    Simplified model: assumes communication does not overlap with compute,
    so it is pessimistic compared with frameworks that overlap the two.
    """
    # Gradient size = 4 bytes per parameter (FP32 gradients)
    gradient_size_bytes = model_params_billion * 1e9 * 4

    # Ring all-reduce time: 2 * (n-1)/n * gradient_size / bandwidth
    bandwidth_bytes_per_sec = inter_gpu_bandwidth_gb_per_s * 1e9  # GB/s -> bytes/s
    allreduce_time_sec = 2 * ((n_gpus - 1) / n_gpus) * gradient_size_bytes / bandwidth_bytes_per_sec

    # Compute time per step: each GPU processes its own batch in parallel
    compute_time_per_step = batch_size_per_gpu / single_gpu_throughput

    # Scaling efficiency
    communication_overhead = allreduce_time_sec / (compute_time_per_step + allreduce_time_sec)
    scaling_efficiency = 1 - communication_overhead

    actual_throughput = single_gpu_throughput * n_gpus * scaling_efficiency

    return {
        "n_gpus": n_gpus,
        "single_gpu_throughput": single_gpu_throughput,
        "theoretical_throughput": single_gpu_throughput * n_gpus,
        "actual_throughput": actual_throughput,
        "scaling_efficiency_pct": scaling_efficiency * 100,
        "communication_overhead_pct": communication_overhead * 100,
        "allreduce_time_ms": allreduce_time_sec * 1000,
        "recommendation": (
            "Good scaling efficiency (>80%)" if scaling_efficiency > 0.8 else
            "Moderate scaling - consider gradient compression" if scaling_efficiency > 0.6 else
            "Poor scaling - use gradient checkpointing or reduce model size per GPU"
        ),
    }


# NVLink-connected A100s (efficient)
nvlink_efficiency = estimate_distributed_efficiency(
    single_gpu_throughput=850,  # samples/sec on 1 A100
    n_gpus=8,
    inter_gpu_bandwidth_gb_per_s=600,  # NVLink A100
    model_params_billion=1.3,
)
print(f"NVLink 8xA100: {nvlink_efficiency['scaling_efficiency_pct']:.1f}% efficient")

# PCIe-connected T4s (inefficient for large models)
pcie_efficiency = estimate_distributed_efficiency(
    single_gpu_throughput=250,
    n_gpus=4,
    inter_gpu_bandwidth_gb_per_s=64,  # PCIe bandwidth (illustrative)
    model_params_billion=1.3,
)
print(f"PCIe 4xT4: {pcie_efficiency['scaling_efficiency_pct']:.1f}% efficient")
```
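Since you pay for every GPU in the job, what matters for the bill is cost per unit of work. A minimal sketch; `cost_per_million_samples` is a hypothetical helper and the $/GPU-hour figure is a placeholder. The GPU count cancels out of the formula, so cost per sample depends only on scaling efficiency: at 50% efficiency each sample costs twice what it costs on one GPU, and the extra GPUs buy wall-clock time, not savings.

```python
def cost_per_million_samples(
    price_per_gpu_hour: float,
    n_gpus: int,
    single_gpu_throughput: float,  # samples/second on 1 GPU
    scaling_efficiency: float,     # 0-1, e.g. from estimate_distributed_efficiency
) -> float:
    """Dollars to process one million training samples on n_gpus."""
    cluster_throughput = single_gpu_throughput * n_gpus * scaling_efficiency
    samples_per_hour = cluster_throughput * 3600
    hourly_cost = price_per_gpu_hour * n_gpus  # you pay for every GPU
    return hourly_cost / samples_per_hour * 1e6


price = 4.10  # hypothetical $/GPU-hour
for n, eff in [(1, 1.00), (4, 0.90), (8, 0.60)]:
    cost = cost_per_million_samples(price, n, single_gpu_throughput=850, scaling_efficiency=eff)
    print(f"{n} GPU(s) at {eff:.0%} efficiency: ${cost:.2f} per 1M samples")
```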

The $50K to $18K Breakdown

| Optimization | Before | After | Savings |
|---|---|---|---|
| Spot instance pricing | $37,748/run (on-demand) | $13,512/run (spot) | $24,236/run (-64%) |
| Mixed precision (BF16) | 72 hours | 48 hours | 33% of spot cost |
| Larger batch size (gradient checkpointing) | 48 hours | 40 hours | 17% additional |
| Early stopping (4 of 24 runs/year) | 24 runs × 40h | 20 runs × 40h | ~$4,512 |
| Total | $75,496/month | $21,800/month | $53,696/month (-71%) |

Common Mistakes

:::danger Not implementing checkpoint-and-restart before using spot instances
Using spot instances without checkpoint-and-restart support means a single interruption restarts your entire training run from epoch 0. If your 48-hour training run is interrupted at hour 40 and restarts from scratch, you pay for 88 hours to get 48 hours of progress, which erases much of the spot discount, and a few more interruptions can push the run past the on-demand cost. Always implement checkpointing before enabling spot.
:::

:::warning Using FP16 instead of BF16 on A100s without loss scaling
FP16 has limited dynamic range and can produce inf/NaN losses on some models without a loss scaler. BF16, supported on A100s and H100s, has the same dynamic range as FP32 and doesn't require a loss scaler. If you have A100 or newer hardware, prefer BF16 for simplicity and stability. Use FP16, with a loss scaler, only on GPUs such as V100s and T4s that don't support BF16.
:::

:::tip Profile GPU utilization before optimizing training speed
Before spending engineering time on mixed precision and gradient checkpointing, profile your training loop. If GPU utilization is below 70%, the bottleneck may be data loading or preprocessing rather than compute. Optimizing compute when data loading is the bottleneck produces no throughput improvement.
:::


Interview Q&A

Q: How would you reduce the cost of a training run that currently costs $50,000?

A: Work through the three levers in order of impact. First, price per hour: switch from on-demand to spot instances with checkpoint-and-restart. This typically reduces cost by 60–80% immediately. Second, hours per run: enable mixed precision (BF16 on A100s, FP16 with loss scaling on older GPUs) for a 2–4× throughput improvement. Enable gradient checkpointing to allow larger batch sizes, which reduces total steps to convergence. Ensure data loading is not the bottleneck (profile GPU utilization - it should be above 80%). Third, runs to convergence: implement early stopping based on validation metrics to terminate runs that aren't converging. Use learning rate warmup and proper scheduling to ensure you converge efficiently. Combine these and a $50K run often drops to $15–20K without any accuracy loss.

Q: What is Chinchilla scaling and how does it affect training cost decisions?

A: The Chinchilla paper (Hoffmann et al., 2022) showed that for a fixed compute budget, model size and training data should be scaled together, which for most models of that era meant training a smaller model on more data rather than a larger model on less data. The optimal model size scales as the square root of the compute budget, and so does the optimal number of training tokens (roughly 20 tokens per parameter). The paper's own example: Chinchilla (70B parameters, 1.4T tokens) used about the same compute as Gopher (280B parameters, 300B tokens) and outperformed it. Knowing this lets you make economically rational decisions: instead of always defaulting to larger models, choose model size and training duration together to maximize performance per dollar.

Q: What is gradient checkpointing and when would you enable it?

A: Gradient checkpointing recomputes activations during the backward pass instead of storing them in memory. It reduces activation memory from O(n) to O(√n) for n layers, at the cost of ~30% more compute (one additional forward pass). Enable it when: your model doesn't fit in GPU memory at the target batch size, you want to increase batch size to improve convergence or throughput, or you're training on GPUs with limited memory (V100 16GB, T4 16GB) and can't afford to store all activations. Don't enable it when memory is not a constraint - it adds compute overhead for no benefit. As a rough rule of thumb, it tends to pay off for models with more than 300M parameters trained with batch sizes above 16.

Q: What is the scaling efficiency of distributed training and why does it matter for cost?

A: Scaling efficiency is the ratio of actual training throughput to theoretical linear throughput. An 8-GPU run with 85% efficiency trains at 6.8× single-GPU speed, not 8×. The gap comes from gradient synchronization (all-reduce) overhead between GPUs. For NVLink-connected GPUs (A100 SXM), all-reduce is fast and efficiency is 85–95%. For PCIe-connected GPUs or multi-node training over Ethernet, efficiency drops to 50–75%. Efficiency matters for cost because you pay for all 8 GPUs regardless. At 60% efficiency, you're paying for 8 GPUs but getting 4.8 GPUs' worth of useful work - the remaining 3.2 GPU-equivalents go to communication overhead. If single-node efficiency is poor, using fewer GPUs per job and running more jobs in parallel may be more cost-effective than scaling up per job.

© 2026 EngineersOfAI. All rights reserved.