

Training Cost Optimization

The $50K Training Run

The model was a 3B parameter transformer trained on a proprietary dataset. The first cost estimate came back at $50,000 - two months of compute budget, gone in one training run. The VP of ML signed off reluctantly, with a clear message: "This had better work, and we'd better figure out how to do it cheaper next time."

The run took 18 days on 32 A100s. The model achieved the target validation metrics. But when the team sat down to audit what had actually happened, the results were uncomfortable. GPU utilization averaged 52% - nearly half the compute had been wasted waiting on data loading and CPU preprocessing. The optimizer was AdamW in full float32, burning memory that forced a smaller batch size, which hurt both efficiency and convergence. Three checkpoint restarts after preemptions had wasted roughly 18 GPU-hours each time. And nobody had checked whether 3B parameters was actually the compute-optimal choice for the dataset size - it had just been the biggest model they could train in the available time.

Four weeks later, with a proper optimization plan, the team ran the same experiment: same dataset, same validation targets, same quality bar. The bill was $11,800. Not because they cut corners - the model performed slightly better than the first run - but because they applied systematic optimization at every layer of the training pipeline.

This lesson teaches that systematic process. Each technique is explained with its cost impact, implementation detail, and the tradeoffs you need to understand before applying it.


Why Training Cost is Different from Inference Cost

Inference cost is continuous - every request costs money, forever. Training cost is episodic - each run has a fixed price tag, and the goal is to minimize that price while achieving the target outcome (a trained model of sufficient quality).

This episodic nature creates a different set of optimization targets:

  • Utilization: Maximize GPU utilization per dollar
  • Convergence efficiency: Reach target performance in fewer steps
  • Compute-optimal configuration: Choose model size and data size correctly
  • Fault tolerance: Avoid losing work on preemptions without expensive redundancy
  • Memory efficiency: Fit larger effective batch sizes within GPU memory budget

The interaction between these factors is complex. Optimizing one often affects others - for example, reducing memory footprint (via gradient checkpointing) allows larger batches, which improves convergence speed, which may allow fewer total steps. A systematic approach addresses all layers together.


Historical Context

The modern understanding of training cost optimization coalesced around 2020–2022. Before the LLM era, training runs were relatively small and cost optimization was mostly about GPU utilization. The GPT-3 paper (Brown et al., 2020) changed the conversation by making training costs explicit - 175B parameters, 3.14 × 10²³ FLOPs - and forcing the community to think about cost as a first-class design constraint.

The Chinchilla paper (Hoffmann et al., 2022, DeepMind) was the most important contribution: it showed that most large language models at the time were undertrained - they had too many parameters for their compute budget and should have been trained on much more data. This shifted the entire field toward compute-optimal training.

Mixed precision training was developed at NVIDIA and popularized through Apex and later PyTorch AMP starting around 2018. Gradient checkpointing (Chen et al., 2016) existed earlier but became mainstream as models grew beyond GPU memory limits. The 8-bit Adam optimizer (Dettmers et al., 2022) and Adafactor (Shazeer & Stern, 2018) represent the frontier of memory-efficient optimization.


Technique 1: Spot Instance Strategy

The Opportunity

Spot instances (AWS) or preemptible VMs (GCP) are spare cloud capacity sold at a 60–90% discount. An A100 that costs $3.06/hr on-demand costs $0.90–$1.20/hr on spot. The catch: they can be reclaimed with 2 minutes' notice (AWS) or 30 seconds' notice (GCP) when demand exceeds supply.

For training jobs - which can run for hours or days - spot interruptions are the key risk. The mitigation is checkpoint-based fault tolerance: save model state frequently enough that a preemption only loses a small amount of work.

Cost Impact Calculation

For an 18-day run on 32 A100s:

  • On-demand: 32 × $3.06 × 18 × 24 ≈ $42,301
  • Spot at 60% discount: 32 × $1.22 × 18 × 24 ≈ $16,865
  • Savings: ≈ $25,436 (60%)

Even with two checkpoint restarts losing 2 hours each, total lost compute is 4 hours: 32 × $1.22 × 4 ≈ $156. Spot is still about 60% cheaper.
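The arithmetic above generalizes to a small helper. This is a sketch under the prices quoted in this section; the function name `spot_vs_on_demand` and its `lost_hours` parameter are illustrative and not part of any cloud SDK.

```python
def spot_vs_on_demand(
    num_gpus: int,
    days: float,
    on_demand_rate: float,
    spot_rate: float,
    lost_hours: float = 0.0,
) -> dict:
    """Compare on-demand and spot cost for a fixed-length run.

    lost_hours is wall-clock time repeated after preemptions (an assumption
    you estimate from checkpoint interval and preemption frequency).
    """
    hours = days * 24
    on_demand = num_gpus * on_demand_rate * hours
    spot = num_gpus * spot_rate * (hours + lost_hours)
    return {
        "on_demand_usd": on_demand,
        "spot_usd": spot,
        "savings_usd": on_demand - spot,
        "savings_fraction": 1 - spot / on_demand,
    }


# 32 A100s for 18 days, with 4 hours of work repeated after preemptions
result = spot_vs_on_demand(32, 18, on_demand_rate=3.06, spot_rate=1.22, lost_hours=4)
print({k: round(v, 2) for k, v in result.items()})
```

Raising `lost_hours` shows how much repeated work the discount can absorb before spot stops paying off.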

Implementation: Checkpoint-Based Fault Tolerance

```python
import signal
import torch
from pathlib import Path
from datetime import datetime, timezone


class SpotInstanceTrainer:
    """
    Training loop with spot instance fault tolerance.
    Checkpoints every N steps and handles SIGTERM gracefully.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        checkpoint_dir: str,
        checkpoint_every_steps: int = 500,
    ):
        self.model = model
        self.optimizer = optimizer
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_every_steps = checkpoint_every_steps
        self.global_step = 0
        self._shutdown_requested = False
        self._emergency_saved = False

        # Register SIGTERM handler. SIGTERM arrives when the node shuts down;
        # on AWS the 2-minute notice itself is published via instance metadata,
        # so a node agent or orchestrator usually forwards it as SIGTERM.
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _handle_sigterm(self, signum, frame):
        """Request an emergency checkpoint on spot preemption signal."""
        # Only set a flag: saving inside the handler could capture the
        # optimizer mid-update. should_stop() saves at a step boundary.
        print("SIGTERM received - emergency checkpoint will be saved after this step")
        self._shutdown_requested = True

    @staticmethod
    def _step_of(path: Path) -> int:
        # "checkpoint-step-500.pt" -> 500 (numeric, not lexicographic, order)
        return int(path.stem.rsplit("-", 1)[1])

    def _save_checkpoint(self, emergency: bool = False):
        suffix = "emergency" if emergency else f"step-{self.global_step}"
        path = self.checkpoint_dir / f"checkpoint-{suffix}.pt"

        torch.save({
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }, path)

        # Keep only the last 3 periodic checkpoints to manage storage costs
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint-step-*.pt"), key=self._step_of
        )
        for old_ckpt in checkpoints[:-3]:
            old_ckpt.unlink()
            print(f"Removed old checkpoint: {old_ckpt}")

        print(f"Saved checkpoint: {path}")

    def load_latest_checkpoint(self) -> int:
        """Resume from the most recently written checkpoint. Returns starting step."""
        checkpoints = (
            list(self.checkpoint_dir.glob("checkpoint-step-*.pt")) +
            list(self.checkpoint_dir.glob("checkpoint-emergency.pt"))
        )

        if not checkpoints:
            return 0

        # Pick by modification time so an emergency checkpoint can win
        latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
        state = torch.load(latest, map_location="cpu")
        self.model.load_state_dict(state["model_state_dict"])
        self.optimizer.load_state_dict(state["optimizer_state_dict"])
        self.global_step = state["global_step"]
        print(f"Resumed from step {self.global_step} ({latest})")
        return self.global_step

    def train_step(self, batch) -> float:
        """Single training step. Returns loss value."""
        self.optimizer.zero_grad()
        loss = self.model(**batch).loss
        loss.backward()
        self.optimizer.step()

        self.global_step += 1

        # Periodic checkpoint
        if self.global_step % self.checkpoint_every_steps == 0:
            self._save_checkpoint()

        return loss.item()

    def should_stop(self) -> bool:
        """Call after each train_step; saves the emergency checkpoint once."""
        if self._shutdown_requested and not self._emergency_saved:
            self._save_checkpoint(emergency=True)
            self._emergency_saved = True
        return self._shutdown_requested
```

:::tip Checkpoint frequency vs storage cost
Checkpoint every 500–1000 steps for long training runs. Each checkpoint for a 7B model is ~28 GB for the FP32 weights alone, and roughly three times that once Adam optimizer state is included. Checkpointing every 100 steps can write 300+ GB/day. Balance interruption recovery time (longer interval = more wasted work on preemption) against storage cost.
:::
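The trade-off in the tip can be put into rough numbers. The sketch below assumes a preemption lands uniformly at random inside a checkpoint interval, so half an interval is lost on average; the 80 seconds per step and one preemption per day are made-up inputs chosen only for illustration, and `checkpoint_tradeoff` is a hypothetical helper.

```python
def checkpoint_tradeoff(
    interval_steps: int,
    seconds_per_step: float,
    checkpoint_gb: float,
    preemptions_per_day: float,
) -> dict:
    """Rough daily storage writes and expected lost work for an interval."""
    steps_per_day = 86_400 / seconds_per_step
    checkpoints_per_day = steps_per_day / interval_steps
    # Expected loss per preemption: half an interval of wall-clock time
    lost_hours = preemptions_per_day * (interval_steps / 2) * seconds_per_step / 3600
    return {
        "checkpoints_per_day": checkpoints_per_day,
        "gb_written_per_day": checkpoints_per_day * checkpoint_gb,
        "expected_lost_hours_per_day": lost_hours,
    }


for interval in (100, 500, 1000):
    stats = checkpoint_tradeoff(
        interval, seconds_per_step=80.0, checkpoint_gb=28.0, preemptions_per_day=1.0
    )
    print(interval, {k: round(v, 2) for k, v in stats.items()})
```

Multiply the lost hours by GPU count and hourly rate to compare them against the storage bill for your own run.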


Technique 2: Mixed Precision Training

The Opportunity

Modern GPUs have hardware support for lower-precision arithmetic. The NVIDIA A100 delivers:

  • FP32 (float32), standard CUDA cores: 19.5 TFLOPS
  • BF16 (bfloat16), tensor cores: 312 TFLOPS dense - 16× the peak FP32 throughput

Training in BF16 instead of FP32 gives roughly 1.5–2× throughput improvement in practice, well below the 16× peak ratio, because the actual speedup is limited by memory bandwidth, not just compute throughput. It also halves memory usage for activations, allowing larger batch sizes or longer sequences.

The Risk: Numerical Stability

Not all operations are safe in FP16/BF16. Sums over long sequences, softmax, and layer norms can lose small values to underflow or overflow to inf and then NaN. The solution is automatic mixed precision (AMP): keep weights in FP32 for gradient updates, run the forward and backward passes in reduced precision while autocast keeps sensitive operations in FP32, and, for FP16 only, use dynamic loss scaling to prevent gradient underflow.

```python
import torch
from torch.cuda.amp import GradScaler


def train_with_amp(
    model: torch.nn.Module,
    dataloader,
    optimizer: torch.optim.Optimizer,
    num_epochs: int,
    device: str = "cuda",
) -> list[float]:
    """
    Training loop with automatic mixed precision.
    BF16 is preferred over FP16 for stability on A100/H100.
    """
    model = model.to(device)

    # Use BF16 on Ampere/Hopper GPUs - more numerically stable than FP16
    # Falls back to FP16 on older GPUs
    use_bf16 = torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    print(f"Using {'BF16' if use_bf16 else 'FP16'} mixed precision")

    # GradScaler only needed for FP16 (BF16 has larger dynamic range)
    scaler = GradScaler() if not use_bf16 else None

    losses = []

    for epoch in range(num_epochs):
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()

            # Forward pass in reduced precision
            with torch.autocast(device_type="cuda", dtype=dtype):
                outputs = model(**batch)
                loss = outputs.loss

            if scaler is not None:
                # FP16 path: scale loss to prevent gradient underflow
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                # BF16 path: no scaling needed
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            losses.append(loss.item())

    return losses
```

Expected gains:

  • Throughput: +40–80% tokens/sec
  • Memory: −30–40% activation memory
  • Cost reduction: roughly 30–45% for the same number of training steps (a 1.4–1.8× speedup cuts cost by 1 − 1/1.4 ≈ 29% to 1 − 1/1.8 ≈ 44%)
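To see why FP16 needs loss scaling while BF16 usually does not, the following standard-library sketch round-trips a tiny gradient value through both formats. `to_fp16` and `to_bf16` are illustrative helpers; BF16 is emulated by truncating a float32 to its top 16 bits, whereas real hardware rounds to nearest. The factor 2**16 matches the default initial scale of PyTorch's `GradScaler`.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision; overflow is reported as inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf")


def to_bf16(x: float) -> float:
    """Emulate BF16 by keeping the top 16 bits of the float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


tiny_grad = 1e-8
print(to_fp16(tiny_grad))              # 0.0: below FP16's smallest subnormal
print(to_bf16(tiny_grad) > 0)          # True: BF16 shares FP32's exponent range
print(to_fp16(tiny_grad * 2**16) > 0)  # True: loss scaling rescues the value
print(to_fp16(70000.0))                # inf: above FP16's maximum of 65504
print(to_bf16(70000.0))                # finite, but with only ~3 significant digits
```

The last line shows the other side of the trade: BF16 keeps the range but gives up mantissa precision.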

Technique 3: Gradient Checkpointing

The Memory-Compute Tradeoff

During the backward pass, PyTorch needs to access the activations computed during the forward pass to calculate gradients. By default, it stores all activations in GPU memory - expensive for large models and long sequences. Gradient checkpointing discards intermediate activations during the forward pass and recomputes them on-demand during the backward pass.

The tradeoff: ~33% extra compute for ~60% memory savings. This is almost always worth it for large models because it allows:

  1. Training larger models on the same hardware
  2. Larger effective batch sizes (fewer gradient accumulation steps needed)
  3. Longer sequence lengths

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class MemoryEfficientTransformer(nn.Module):
    """Sketch only: TransformerBlock is assumed to be defined elsewhere."""

    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.use_gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True
        # Also enable for HuggingFace models:
        # model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask=None):
        x = self.embed(input_ids)

        if self.use_gradient_checkpointing and self.training:
            # checkpoint_sequential divides the layers into `segments` chunks.
            # Activations inside a chunk are discarded and recomputed on backward.
            # It passes a single tensor from layer to layer, so attention_mask
            # is not forwarded on this path.
            x = checkpoint_sequential(
                self.layers,
                segments=4,  # 4 segments
                input=x,
                use_reentrant=False,
            )
        else:
            for layer in self.layers:
                x = layer(x, attention_mask)

        x = self.norm(x)
        return self.lm_head(x)
```

HuggingFace one-liner:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model.gradient_checkpointing_enable()

# Memory comparison (7B model, sequence length 2048, batch size 4):
# Without checkpointing: ~80 GB activation memory
# With checkpointing: ~12 GB activation memory
# Extra compute cost: ~33%
```
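How many segments to use is itself a trade-off. The sketch below uses a deliberately simplified model, counting one unit per layer output and ignoring activations inside a layer, to show why the stored-activation count is smallest near the square root of the layer count (the result from Chen et al., 2016). `peak_activation_units` is a hypothetical helper, not a PyTorch function.

```python
import math


def peak_activation_units(num_layers: int, segments: int) -> int:
    """Peak stored layer outputs under segment checkpointing.

    Keeps one boundary activation per segment, plus every layer output of
    the single segment currently being recomputed during backward.
    """
    layers_per_segment = math.ceil(num_layers / segments)
    return segments + layers_per_segment


num_layers = 32
baseline = num_layers  # no checkpointing: every layer output is stored
for segments in (2, 4, 8, 16):
    units = peak_activation_units(num_layers, segments)
    print(f"segments={segments:>2}  stored={units:>3}  saving={1 - units / baseline:.0%}")

# Search all segment counts for the smallest footprint
best = min(range(1, num_layers + 1), key=lambda k: peak_activation_units(num_layers, k))
print("best segment count:", best)
```

In this simplified model, too few segments leave large chunks to recompute and hold, and too many store nearly every boundary.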

Technique 4: Memory-Efficient Optimizers

Adam is the default optimizer for transformer training. It stores two moment tensors per parameter (m and v), so with FP32 weights the optimizer state is 2× the model size. For a 7B model (14 GB of weights in BF16), Adam adds 56 GB of optimizer state in FP32: two 4-byte values per parameter. That is 70 GB for weights plus optimizer state before counting gradients or activations, which forces smaller batch sizes, slower convergence, and higher cost.

Adafactor: Near-Zero Optimizer Memory

Adafactor (Shazeer & Stern, 2018) approximates Adam's second moment with a factored rank-1 representation. For an n × m weight matrix, this reduces second-moment memory from O(nm) to O(n + m):

$$\hat{v}_t \approx \frac{\hat{r}_t \hat{c}_t^T}{\mathbf{1}^T \hat{r}_t}$$

where $\hat{r}_t$ and $\hat{c}_t$ are the row and column factor vectors (running row sums and column sums of the squared gradients) and $\mathbf{1}$ is a vector of ones. Memory reduction: roughly 90% for large weight matrices.
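The factorization can be checked by hand on a small matrix. This pure-Python sketch reconstructs a matrix from its row and column sums; it illustrates only the rank-1 approximation, not the full optimizer (Adafactor maintains running averages of these factors over time). The input values are made up, and `factored_second_moment` is a hypothetical name.

```python
def factored_second_moment(v: list[list[float]]) -> list[list[float]]:
    """Rank-1 reconstruction of a non-negative matrix from row/column sums."""
    row_sums = [sum(row) for row in v]        # r: one value per row
    col_sums = [sum(col) for col in zip(*v)]  # c: one value per column
    total = sum(row_sums)                     # 1^T r
    return [[r * c / total for c in col_sums] for r in row_sums]


# Squared gradients for a 3 x 4 weight matrix (made-up, exactly rank 1)
v = [
    [1.0, 2.0, 3.0, 4.0],
    [2.0, 4.0, 6.0, 8.0],
    [0.5, 1.0, 1.5, 2.0],
]
approx = factored_second_moment(v)

stored_full = len(v) * len(v[0])      # n * m values
stored_factored = len(v) + len(v[0])  # n + m values
print(approx[0])
print(f"stored values: {stored_full} -> {stored_factored}")
```

Because this example matrix is exactly rank 1, the reconstruction is exact. Real second-moment matrices are only approximately rank 1, which is one source of the quality gap relative to Adam.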

```python
from transformers.optimization import Adafactor

# Option A: Adafactor with adaptive learning rate (no lr scheduling needed)
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=True,  # adapt lr to parameter scale
    relative_step=True,    # lr schedule built-in
    warmup_init=True,      # warm up lr from small value
)

# Option B: Adafactor with fixed learning rate (when you have a scheduler)
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
)
```

8-bit Adam: Near-Full Quality, a Quarter of the Optimizer Memory

bitsandbytes 8-bit Adam (Dettmers et al., 2022) stores optimizer states in 8-bit quantized form, cutting optimizer memory by 75% with negligible quality loss:

```python
import bitsandbytes as bnb

# Drop-in replacement for torch.optim.Adam
# (use bnb.optim.AdamW8bit to replace torch.optim.AdamW)
optimizer = bnb.optim.Adam8bit(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.999),
)
```

Memory savings comparison for 7B model:

| Optimizer | Optimizer State Memory | Quality |
| --- | --- | --- |
| AdamW FP32 | 56 GB | Baseline |
| AdamW BF16 | 28 GB | -0.1% perplexity |
| 8-bit Adam | 14 GB | -0.2% perplexity |
| Adafactor | 6 GB | -0.5–1% perplexity |
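The first three memory figures follow directly from bytes per stored value. A minimal sketch (the helper name `adam_state_gb` is illustrative; it ignores the small per-block scaling constants that 8-bit optimizers also keep):

```python
def adam_state_gb(num_params: float, bytes_per_value: float) -> float:
    """Adam keeps two values (first and second moment) per parameter."""
    return 2 * num_params * bytes_per_value / 1e9


num_params = 7e9
for name, nbytes in [("AdamW FP32", 4), ("AdamW BF16", 2), ("8-bit Adam", 1)]:
    print(f"{name:<11} {adam_state_gb(num_params, nbytes):.0f} GB")
```

Adafactor does not fit this formula because its second-moment state scales with the sum of matrix dimensions, not the parameter count.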

Technique 5: Compute-Optimal Training (Chinchilla)

This is the highest-leverage technique - it determines whether you're building the right model before you spend a dollar on compute.

The Chinchilla Insight

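The 2022 DeepMind paper "Training Compute-Optimal Large Language Models" (Hoffmann et al.) showed that for a given compute budget $C$, the optimal model size $N^*$ and number of training tokens $D^*$ should grow together, each roughly as $C^{0.5}$. Combining the training-cost approximation $C \approx 6ND$ with the paper's ratio of about 20 tokens per parameter gives:

$$N^* \approx \left(\frac{C}{120}\right)^{0.5}$$

$$D^* \approx 20\,N^* = 20\left(\frac{C}{120}\right)^{0.5}$$

In plain English: for every parameter, you need approximately 20 tokens of training data to use your compute budget optimally. The Chinchilla 70B model, trained on 1.4T tokens, outperformed Gopher (280B parameters, 300B tokens) at the same compute budget and also beat GPT-3 (175B parameters, 300B tokens) on most benchmarks, with a model 2.5–4× smaller and therefore much cheaper to serve.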

Before Chinchilla, teams built bigger models and trained them for less time. After Chinchilla, the right question is: "given my compute budget, what is the optimal model-data tradeoff?"

```python
def chinchilla_optimal_config(
    compute_budget_flops: float,
    a: float = 0.5,
    b: float = 0.5,
    c: float = 6.0,
    tokens_per_param: float = 20.0,
) -> dict:
    """
    Compute Chinchilla-optimal model size and training tokens.

    Args:
        compute_budget_flops: Total FLOPs budget (e.g., 1e21)
        a, b: Scaling exponents (0.5 from Chinchilla paper)
        c: FLOPs per parameter per token (6 = 2 forward + 4 backward)
        tokens_per_param: Target tokens per parameter (~20)

    Returns:
        Optimal N (parameters) and D (tokens)
    """
    # Solve C = c * N * D with D = tokens_per_param * N:
    #   N* = (C / (c * tokens_per_param))^0.5,  D* = tokens_per_param * N*
    scaled_budget = compute_budget_flops / (c * tokens_per_param)
    n_optimal = scaled_budget ** a
    d_optimal = tokens_per_param * scaled_budget ** b

    return {
        "compute_budget_flops": compute_budget_flops,
        "optimal_parameters": n_optimal,
        "optimal_tokens": d_optimal,
        "tokens_per_parameter": d_optimal / n_optimal,
        "parameter_billions": n_optimal / 1e9,
        "token_billions": d_optimal / 1e9,
    }


# Example: 32 A100s for 5 days at 45% MFU
# 32 GPUs × 312 TFLOPS × 0.45 MFU × 5 days × 86400 sec ≈ 1.94 × 10^21 FLOPs
budget = 32 * 312e12 * 0.45 * 5 * 86400

config = chinchilla_optimal_config(budget)
print(f"Optimal parameters: {config['parameter_billions']:.1f}B")
print(f"Optimal tokens: {config['token_billions']:.0f}B")
print(f"Tokens per parameter: {config['tokens_per_parameter']:.0f}")

# Chinchilla optimal for this budget: roughly a 4B model on 80B tokens.
# Common mistake: choosing the model size first. A much larger model on the
# same budget sees too few tokens per parameter (undertrained); a much smaller
# one spends compute on tokens that a larger model would use more efficiently.
```
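Often the dataset, not the compute budget, is the fixed quantity. Under the same 20 tokens-per-parameter heuristic and a single pass over the data (both assumptions), the question can be inverted. The helper names below are illustrative.

```python
def model_size_for_dataset(num_tokens: float, tokens_per_param: float = 20.0) -> dict:
    """Largest model a fixed dataset supports at the target tokens/param ratio."""
    n_params = num_tokens / tokens_per_param
    return {
        "parameters": n_params,
        "train_flops": 6 * n_params * num_tokens,  # C ≈ 6·N·D
    }


def tokens_per_parameter(num_params: float, num_tokens: float) -> float:
    """Quick sanity check to run before launching a training job."""
    return num_tokens / num_params


plan = model_size_for_dataset(60e9)  # a 60B-token dataset
print(f"{plan['parameters'] / 1e9:.1f}B parameters, {plan['train_flops']:.2e} FLOPs")

# A 10B-parameter model on 1B tokens is far below the ~20 tokens/param target
print(tokens_per_parameter(10e9, 1e9))
```

A ratio far below 20 suggests shrinking the model or gathering more data before spending compute.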

Putting It All Together: The $12K Training Plan

Here is the optimization plan that reduced the $50K run to $11,800:

| Optimization | Mechanism | Cost Reduction |
| --- | --- | --- |
| Spot instances | 60% discount on compute | -$25,000 |
| Mixed precision | 1.5× throughput on same hardware | -$8,000 |
| Gradient checkpointing | 2× batch size → 20% fewer steps | -$3,000 |
| 8-bit Adam | Larger batches → faster convergence | -$2,000 |
| Chinchilla rescaling | Right model size for data budget | -$8,000 |

Not all savings are independent - they compound in some places and substitute in others. That is why the line items, which sum to $46,000, overstate the actual reduction of $38,200. The final result was $11,800, with the model performing slightly better than the original on validation metrics because the Chinchilla-optimal configuration improved sample efficiency.
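A quick way to see why standalone savings cannot simply be added: each optimization multiplies the cost that remains after the previous one. The sketch below uses only two of the techniques, with multipliers taken from the figures quoted in this lesson (60% spot discount, 1.5× throughput); treating them as fully independent is an assumption, and `combined_cost` is a hypothetical helper.

```python
from functools import reduce


def combined_cost(base_cost: float, multipliers: dict[str, float]) -> float:
    """Apply independent cost multipliers (0.40 means 'costs 40% of before')."""
    return reduce(lambda cost, m: cost * m, multipliers.values(), base_cost)


base = 50_000.0
multipliers = {
    "spot": 0.40,                # 60% discount
    "mixed_precision": 1 / 1.5,  # 1.5× throughput on the same hardware
}

final = combined_cost(base, multipliers)
naive_sum = sum(base * (1 - m) for m in multipliers.values())

print(f"compounded cost:     ${final:,.0f}")
print(f"compounded saving:   ${base - final:,.0f}")
print(f"sum of standalone:   ${naive_sum:,.0f}")
```

The standalone savings add up to more than the compounded saving, because the second optimization only acts on the already-discounted bill.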


Production Engineering Notes

Gradient Accumulation for Effective Batch Size

When you can't fit your target batch size in GPU memory even with gradient checkpointing, use gradient accumulation: accumulate gradients over N micro-batches before the optimizer step. Effective batch size = micro-batch × accumulation steps × number of GPUs.

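```python
import torch

# Assumes model, optimizer, and dataloader are already defined.
accumulation_steps = 8  # target effective batch = 8 × micro_batch × n_gpus

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        # Normalize so the accumulated gradient is the mean over micro-batches
        loss = model(**batch).loss / accumulation_steps

    loss.backward()  # gradients add up in .grad across micro-batches

    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```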
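The division by `accumulation_steps` is what makes the accumulated gradient match a single large batch. This can be verified without a GPU on a one-parameter least-squares model; the data and the `mse_grad` helper are made up for illustration, and the equality assumes equal-sized micro-batches.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient with respect to w of mean((w * x - y)^2) over the samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.1, 5.9, 8.2, 9.8, 12.1, 14.0, 15.9]
w = 0.5

accumulation_steps = 4
micro = len(xs) // accumulation_steps  # micro-batch size of 2

full_batch_grad = mse_grad(w, xs, ys)

accumulated = 0.0
for i in range(accumulation_steps):
    chunk = slice(i * micro, (i + 1) * micro)
    # Same normalization as `loss / accumulation_steps` in the training loop
    accumulated += mse_grad(w, xs[chunk], ys[chunk]) / accumulation_steps

print(full_batch_grad, accumulated)
```

The two printed gradients agree up to floating-point rounding; dropping the division would make the accumulated gradient `accumulation_steps` times too large.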

Data Loading as a Bottleneck

GPU starvation from slow data loading is common and expensive. Profile with:

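```python
import time
import torch

# Measure data loading time vs compute time.
# Assumes model and dataloader are already defined.
data_time, compute_time = 0.0, 0.0
t_prev = time.perf_counter()

for batch in dataloader:
    # Everything since the last step ended is time spent waiting on the
    # DataLoader, plus the host-to-device copy below.
    batch = {k: v.cuda() for k, v in batch.items()}
    torch.cuda.synchronize()
    t_data = time.perf_counter()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    model.zero_grad(set_to_none=True)  # profiling only: discard gradients

    # CUDA kernels run asynchronously; wait for them before reading the clock
    torch.cuda.synchronize()
    t_compute = time.perf_counter()

    data_time += t_data - t_prev
    compute_time += t_compute - t_data
    t_prev = t_compute

ratio = data_time / (data_time + compute_time)
print(f"Data loading: {ratio:.1%} of total time")
# If > 20%, increase num_workers, use prefetching, or cache to local SSD
```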

Common Mistakes

:::danger Not checkpointing before a long run
Starting a 20-hour training run with no checkpointing means a crash at hour 19 wastes all 19 hours of compute. Always implement checkpointing before starting any run over 2 hours. The checkpoint overhead is less than 0.5% of total compute for 500-step intervals.
:::

:::danger Training a massive model on a tiny dataset
If you have 1B tokens of training data and you're training a 10B parameter model, you're using ~0.1 tokens per parameter - drastically undertrained by Chinchilla standards. The model will memorize the training set rather than generalize. Check your tokens-per-parameter ratio. Target 20× for general-purpose models; domain-specific fine-tuning can be lower.
:::

:::warning Using FP16 instead of BF16 on modern hardware
FP16 has a smaller dynamic range than BF16 and requires loss scaling to prevent gradient underflow. On A100/H100/TPUv4+, BF16 is usually the better choice: same tensor-core throughput, the same exponent range as FP32, and no loss scaling needed, at the cost of fewer mantissa bits (7 vs 10). Only use FP16 on pre-Ampere GPUs (such as V100 or T4) that don't support BF16.
:::

:::warning Forgetting to validate on holdout set during long runs
A training run that goes wrong early - overfitting, learning rate too high, data contamination - wastes all subsequent compute. Run validation every 1,000–5,000 steps and set up early stopping. Don't wait for the full run to finish to discover the model diverged at step 2,000.
:::


Interview Q&A

Q: How do you estimate training cost before starting a run?

A: I use the 6ND formula: FLOPs ≈ 6 × N × D, where N is parameter count and D is token count. Then divide by effective GPU throughput (peak FLOPS × MFU × number of GPUs) to get wall-clock time, and multiply by the number of GPUs and the per-GPU hourly cost. I always add 20% for overhead (checkpointing, validation, data loading). For a 7B model on 100B tokens: 6 × 7e9 × 1e11 = 4.2 × 10²¹ FLOPs. On 8 A100s at 40% MFU: 8 × 312e12 × 0.4 = 998 TFLOPS effective. Time: 4.2e21 / 998e12 / 3600 ≈ 1,169 hours. At $3/GPU-hr across 8 GPUs: about $28K, or about $34K with the 20% overhead. That's the first estimate - then I apply optimizations.
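The estimate in this answer is easy to script. A minimal sketch, with the function name and the default 20% overhead as assumptions taken from the answer above:

```python
def estimate_training_cost(
    num_params: float,
    num_tokens: float,
    num_gpus: int,
    peak_flops_per_gpu: float,
    mfu: float,
    usd_per_gpu_hour: float,
    overhead: float = 0.20,
) -> dict:
    """First-pass training cost from the 6ND approximation."""
    total_flops = 6 * num_params * num_tokens
    effective_flops_per_sec = num_gpus * peak_flops_per_gpu * mfu
    wall_clock_hours = total_flops / effective_flops_per_sec / 3600
    base_cost = wall_clock_hours * num_gpus * usd_per_gpu_hour
    return {
        "total_flops": total_flops,
        "wall_clock_hours": wall_clock_hours,
        "base_cost_usd": base_cost,
        "cost_with_overhead_usd": base_cost * (1 + overhead),
    }


# 7B parameters, 100B tokens, 8 A100s (312 TFLOPS peak BF16) at 40% MFU, $3/GPU-hr
estimate = estimate_training_cost(7e9, 1e11, 8, 312e12, 0.40, 3.0)
for key, value in estimate.items():
    print(f"{key}: {value:,.0f}")
```

MFU is the most uncertain input; measuring it on a short trial run before committing to the full budget tightens the estimate considerably.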

Q: What is the Chinchilla scaling law and why does it matter for cost?

A: Chinchilla (Hoffmann et al., 2022) showed that for a given compute budget, there is an optimal model-data tradeoff: approximately 20 training tokens per parameter. Before this paper, the field over-invested in model size - GPT-3 was 175B parameters trained on 300B tokens (under 2 tokens per parameter), when the Chinchilla-optimal allocation of the same compute (about 3.14 × 10²³ FLOPs) would have been roughly a 50B model trained on about 1T tokens. This matters enormously for cost because: (1) smaller models have lower inference cost forever, and (2) if your training data budget is fixed, you may be training a model that's 5–10× larger than necessary. I check the tokens-per-parameter ratio before every training run.

Q: Walk me through how to reduce a training job from FP32 to BF16.

A: Three steps. First, wrap the forward pass with torch.autocast(device_type="cuda", dtype=torch.bfloat16); the older torch.cuda.amp.autocast spelling is deprecated in recent PyTorch releases. Second, if on FP16 (not BF16), add a GradScaler and scale the loss before backward. Third, verify the model converges correctly on a short debugging run - compare loss curves between FP32 and BF16 for the first 1,000 steps. The main risk is numerical instability in operations with large dynamic range - usually gradients in early training with high learning rates. If you see NaN losses, reduce learning rate slightly or use gradient clipping. BF16 is generally safe on A100 without any scaling.

Q: When is gradient checkpointing worth it?

A: Almost always for models over 1B parameters, and often for smaller models when training on long sequences. The math: ~33% extra compute cost, ~60% memory reduction. The indirect benefit is that the memory savings allow larger batch sizes. If you can double your batch size, you improve convergence efficiency enough to train in fewer total steps, which reduces cost. The only case where I'd skip it: when memory isn't a constraint and you need maximum throughput for a short, fixed training run. Check first whether memory is actually the constraint: watch GPU memory usage in nvidia-smi during training, and if it stays below about 80% of capacity at your target batch size, you probably don't need it.

Q: How do you handle spot instance preemptions in distributed training?

A: Checkpoint-based restart is the standard approach. Save model state, optimizer state, global step, and random seeds every 500–1000 steps to durable storage (S3/GCS, not local disk). Register a SIGTERM handler that triggers an immediate checkpoint - AWS gives 2 minutes warning before preemption, which is enough time to checkpoint even a 65B model. For distributed training, use elastic training frameworks (PyTorch Elastic, Horovod with elastic mode) that can handle a node leaving and rejoining. The key metrics to track: average wasted compute per preemption (should be under 5% of total training compute) and average preemption frequency (varies by instance type and region - p3 spot is preempted less than p4d spot in my experience).

© 2026 EngineersOfAI. All rights reserved.