Pretraining at Scale
The $4.6 Million Question
In 2020, OpenAI trained GPT-3 on approximately 300 billion tokens using roughly 3,640 petaflop/s-days of compute. At cloud GPU prices of the time, a single training run was estimated to cost about $4.6 million. If something went wrong (a bug in the training code, a hyperparameter choice that proved suboptimal), you ran it again and paid again.
This is the world of large-scale pretraining. The decisions made before you start a run - how to shard the model across GPUs, what precision to use, how to prepare the training data, how to detect and recover from loss spikes - these decisions are worth millions of dollars and months of compute time. A 10% improvement in training efficiency is worth $460,000 on a GPT-3-scale run.
An engineer named Wang Jun at a startup is tasked with pretraining a 7B model on 1 trillion tokens. She has a cluster of 512 A100 GPUs. The cluster costs $20,480 per hour ($40 per GPU-hour) - roughly $10 million. A single day of bad training due to a misconfigured learning rate scheduler costs about $500,000 in wasted compute.
Wang Jun reads every paper on distributed training twice before writing a single line of code.
Why This Exists: A Single GPU Is Not Enough
A 70B parameter model in FP32 (4 bytes per parameter) requires 280GB just to store the weights. A single A100 has 80GB of memory. You cannot fit the model at all, let alone train it.
Even a 7B model in FP32 requires 28GB for weights alone. Add optimizer states (Adam requires 8 bytes per parameter: 4 for the first moment, 4 for the second) and you need another 56GB. Plus gradient storage (28GB more). Total: 112GB for a 7B model - still more than a single 80GB A100.
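A quick sketch of the arithmetic above. The function name is ours, and it counts only weights, gradients, and Adam states, not activations. The common mixed-precision layout (2 bytes of 16-bit weights, 2 of gradients, 12 of FP32 master weights plus moments) happens to total the same 16 bytes per parameter.

```python
def training_state_gb(n_params: float, weight_bytes=4, grad_bytes=4, optim_bytes=8) -> float:
    """Memory in GB for weights + gradients + Adam states (activations excluded)."""
    return n_params * (weight_bytes + grad_bytes + optim_bytes) / 1e9

print(training_state_gb(7e9))              # FP32 everything: 112.0
print(training_state_gb(7e9, 2, 2, 12))    # mixed-precision layout: also 112.0
print(training_state_gb(70e9))             # 1120.0
```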
The solution is distributed training - spread the model, data, and optimizer states across many GPUs. How you spread them determines efficiency and feasibility.
The Three Forms of Parallelism
Data Parallelism
Each GPU holds a complete copy of the model. The training batch is split across GPUs. Each GPU computes gradients on its shard of the batch. Gradients are synchronized across GPUs (all-reduce) and all GPUs update their weights identically.
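Works when: the model, its gradients, and its optimizer states all fit on a single GPU. With Adam at roughly 16 bytes per parameter, that limits plain data parallelism to models of a few billion parameters on an 80GB A100 before activations are even counted (the 7B example above already needs 112GB).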
Problem: for large models, you cannot fit the model on a single GPU. Also, naive data parallelism requires storing the full optimizer state on each GPU.
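To see why averaging gradients across GPUs is equivalent to training on the whole batch, here is a NumPy sketch that simulates data-parallel "GPUs" on a toy linear-regression loss. The all-reduce is modeled as a plain mean (an assumption: real systems use NCCL collectives), and the shards must be equal-sized for the equivalence to be exact.

```python
import numpy as np

def mse_grad(w, X, y):
    # Gradient of mean((X @ w - y)^2) with respect to w
    return 2.0 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 5))
y = rng.normal(size=64)
w = np.zeros(5)

num_gpus = 4
# Each simulated GPU gets an equal shard of the batch and computes a local gradient
shard_grads = [mse_grad(w, Xs, ys)
               for Xs, ys in zip(np.array_split(X, num_gpus), np.array_split(y, num_gpus))]
# All-reduce (mean): every replica ends up with the same averaged gradient
all_reduced = np.mean(shard_grads, axis=0)
full_batch = mse_grad(w, X, y)
print(np.allclose(all_reduced, full_batch))  # True
```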
Tensor Parallelism (Megatron-LM)
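Split individual weight matrices across GPUs. For a feed-forward layer with weight matrix $W_1 \in \mathbb{R}^{d \times d_{ff}}$, split it column-wise across $N$ GPUs so each GPU holds $W_1^{(i)} \in \mathbb{R}^{d \times d_{ff}/N}$. The following down-projection $W_2$ is split row-wise, so each GPU produces a partial output and a single all-reduce (sum) combines them.
For attention: split the $W_Q$, $W_K$, $W_V$ projection matrices column-wise so each GPU handles a subset of the attention heads. The output projection $W_O$ is split row-wise.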
Communication pattern: each forward pass requires an all-reduce after the attention and FFN layers (two all-reduces per transformer layer). This requires high-bandwidth interconnect - NVLink between GPUs on the same node (600GB/s) rather than PCIe (32GB/s) or ethernet between nodes.
Practical limit: tensor parallelism works well within a node (8 GPUs connected by NVLink). Across nodes, the communication overhead makes it less efficient.
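The column-then-row split can be checked numerically. This NumPy sketch simulates tensor-parallel shards of a two-matrix FFN on one machine; it uses ReLU instead of the GeLU found in Megatron-style models (an assumption for simplicity; any elementwise activation works the same way). Each shard computes its partial output independently, and one sum, standing in for the all-reduce, recovers the unsharded result.

```python
import numpy as np

rng = np.random.default_rng(0)
d, ffn, tp = 8, 32, 4            # hidden size, FFN size, tensor-parallel degree
x = rng.normal(size=(3, d))      # 3 tokens
W1 = rng.normal(size=(d, ffn))   # up-projection
W2 = rng.normal(size=(ffn, d))   # down-projection

def relu(z):
    return np.maximum(z, 0.0)

reference = relu(x @ W1) @ W2

# Column-parallel W1, row-parallel W2: each "GPU" holds 1/tp of each matrix
W1_shards = np.split(W1, tp, axis=1)
W2_shards = np.split(W2, tp, axis=0)
# No communication needed here: the activation is elementwise
partials = [relu(x @ a) @ b for a, b in zip(W1_shards, W2_shards)]
# One all-reduce (sum) recovers the full output
output = np.sum(partials, axis=0)
print(np.allclose(output, reference))  # True
```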
Pipeline Parallelism
Split the model's layers across GPUs. GPU 1 holds layers 1-8, GPU 2 holds layers 9-16, and so on. Data flows through the pipeline.
The bubble problem: with a single microbatch, at any given time only one GPU is computing and the rest are idle. GPipe (Huang et al., 2019) solves this by splitting each batch into $m$ microbatches - while GPU 2 processes microbatch 1, GPU 1 is already processing microbatch 2. The bubble ratio (fraction of idle time) is $\frac{p-1}{m+p-1}$, where $p$ is the number of pipeline stages. Using enough microbatches ($m \gg p$) reduces the bubble to near zero.
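A small sketch of the bubble formula for a 4-stage pipeline, showing how quickly more microbatches shrink the idle fraction (the function name is ours):

```python
def bubble_fraction(p: int, m: int) -> float:
    """GPipe-style idle fraction for p pipeline stages and m microbatches."""
    return (p - 1) / (m + p - 1)

for m in (1, 4, 16, 64):
    print(m, round(bubble_fraction(4, m), 3))
# 1 0.75
# 4 0.429
# 16 0.158
# 64 0.045
```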
Mixed Precision Training
Training in FP32 (32-bit floating point, 4 bytes per value) is wasteful. Modern GPUs have specialized hardware for FP16 and BF16 (16-bit formats) that is 2-4x faster and uses half the memory.
FP16 (IEEE 754 half-precision): 5 bits for exponent, 10 bits for mantissa. Range: approximately $6 \times 10^{-8}$ (smallest subnormal) to $65504$. Problem: large values overflow (anything above 65504 becomes infinity) and small gradients underflow to zero. Requires loss scaling: multiply the loss by a large scale factor before the backward pass, then divide the gradients by the same factor before the optimizer step.
BF16 (Brain Float 16): 8 bits for exponent, 7 bits for mantissa. Same exponent range as FP32 (about $1.2 \times 10^{-38}$ to $3.4 \times 10^{38}$), but less precision. Overflow is practically never an issue. Used by modern LLM training pipelines because it requires no loss scaling.
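The overflow/underflow behavior and the loss-scaling fix can be demonstrated in NumPy. NumPy has no built-in BF16 type, so this sketch emulates it by keeping the top 16 bits of each float32 (an assumption: this truncates, whereas hardware typically rounds to nearest-even). The scale factor of $2^{16}$ is an illustrative choice, not what any particular GradScaler uses.

```python
import numpy as np

def to_bf16(x):
    """Emulate BF16 by keeping the top 16 bits of each float32 (round toward zero)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

values = np.array([70000.0, 1e-8], dtype=np.float32)
with np.errstate(over="ignore"):
    fp16 = values.astype(np.float16)
print(fp16)             # 70000 overflows to inf, 1e-8 underflows to 0
print(to_bf16(values))  # both stay finite and nonzero (about 2-3 significant digits)

# FP16 loss scaling: scale a tiny gradient into range, then unscale in FP32
grad = np.float32(1e-8)
scale = np.float32(2.0 ** 16)
scaled_fp16 = np.float16(grad * scale)       # ~6.55e-4 is representable in FP16
recovered = np.float32(scaled_fp16) / scale  # unscale in FP32
print(recovered)                             # close to 1e-8
```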
Mixed precision recipe (Micikevicius et al., 2018):
- Maintain master weights in FP32 (for numerical stability of weight updates)
- Cast weights to FP16/BF16 for forward and backward passes (faster compute, less memory bandwidth)
- Accumulate gradients in FP32
- Update master weights in FP32
Memory saving: activations and the 16-bit working copy of the weights take half the space of FP32 (the FP32 master weights are still kept). Compute speedup: 2-4x on modern hardware.
# PyTorch automatic mixed precision (AMP)
import torch
# For FP16 only: GradScaler handles loss scaling automatically.
# Not needed for BF16.
scaler = torch.cuda.amp.GradScaler()
# Training loop (model, optimizer, dataloader defined elsewhere)
for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):  # BF16 forward pass
        outputs = model(**batch)
        loss = outputs.loss
    # For BF16: no scaler needed
    loss.backward()
    optimizer.step()
    # For FP16 (with scaler), replace the three lines above with:
    # scaler.scale(loss).backward()
    # scaler.step(optimizer)
    # scaler.update()
Gradient Checkpointing
The standard backpropagation algorithm stores all intermediate activations (the outputs of each layer) during the forward pass so they can be used during the backward pass. Consider a transformer with 80 layers, hidden dimension 8192, batch size 2, and sequence length 2048 in BF16. Storing just one hidden-state tensor per layer takes 80 layers × 2048 tokens × 8192 hidden dim × 2 bytes × 2 batch ≈ 5.4GB. But each layer keeps many intermediate tensors for the backward pass (attention scores, softmax outputs, FFN intermediates), so the full model might require 200-400GB for activations alone.
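A sketch of both numbers. The first function reproduces the one-tensor-per-layer count above. The second uses an approximate per-layer estimate, $s \cdot b \cdot h \cdot (34 + 5as/h)$ bytes for 16-bit training without recomputation or parallelism, which we take from Korthikanti et al. (2022); treat it as an assumption and a rough guide, not an exact count for any specific implementation (Flash Attention, for example, removes the $as^2$-style term).

```python
def one_tensor_per_layer_gb(layers, seq, hidden, batch, bytes_per_value=2):
    # The simplified count: one hidden-state tensor per layer
    return layers * seq * hidden * batch * bytes_per_value / 1e9

def estimated_activation_gb(layers, seq, hidden, batch, heads):
    # Assumed per-layer estimate (16-bit, no recomputation): s*b*h*(34 + 5*a*s/h) bytes
    per_layer = seq * batch * hidden * (34 + 5 * heads * seq / hidden)
    return layers * per_layer / 1e9

print(round(one_tensor_per_layer_gb(80, 2048, 8192, 2), 2))     # 5.37
print(round(estimated_activation_gb(80, 2048, 8192, 2, 64)))    # 306
```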
Gradient checkpointing (also called activation recomputation) trades compute for memory. Instead of storing all activations:
- During the forward pass, store only activations at "checkpoint" boundaries (e.g., every N layers)
- During the backward pass, recompute the non-stored activations from the checkpointed ones on the fly
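Memory saving: placing a checkpoint every $\sqrt{L}$ layers cuts activation memory from $O(L)$ to $O(\sqrt{L})$ for an $L$-layer network. Checkpointing at every transformer layer boundary, the common choice, stores only each layer's input plus the full activations of the one layer currently being recomputed.
Compute cost: approximately 33% more total FLOPs per step. The backward pass costs about twice the forward, so a step is roughly 3 forward-equivalents of work; recomputing the forward adds a fourth (4/3 ≈ 1.33).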
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
# Gradient checkpointing for a transformer layer
class TransformerLayerWithCheckpoint(nn.Module):
    def __init__(self, attention_block: nn.Module):
        super().__init__()
        self.attention_block = attention_block
    def forward(self, x):
        # Recompute attention activations during backward
        # instead of storing them
        return checkpoint(self.attention_block, x, use_reentrant=False)
# In HuggingFace, enable with one line (model is a transformers PreTrainedModel):
model.gradient_checkpointing_enable()
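To make the recompute-from-checkpoints mechanism concrete, here is a from-scratch NumPy sketch on a toy 16-layer network of `tanh(W @ x)` layers (a hypothetical model chosen for brevity, with hand-written backward passes). With a checkpoint every 4 = $\sqrt{16}$ layers, it produces the same gradient while holding 8 activations at peak instead of 17.

```python
import numpy as np

def layer(W, x):
    return np.tanh(W @ x)

def layer_backward(W, x_in, grad_out):
    # For y = tanh(W x): dL/dx = W^T (dL/dy * (1 - y^2)); y is recomputed here for simplicity
    y = np.tanh(W @ x_in)
    return W.T @ (grad_out * (1.0 - y ** 2))

def grad_full_storage(Ws, x):
    """Standard backprop: keep every layer's activation for the backward pass."""
    acts = [x]
    for W in Ws:
        acts.append(layer(W, acts[-1]))
    grad = np.ones_like(acts[-1])  # gradient of sum(output)
    for i in reversed(range(len(Ws))):
        grad = layer_backward(Ws[i], acts[i], grad)
    return grad, len(acts)

def grad_checkpointed(Ws, x, every):
    """Keep only every `every`-th activation; recompute each segment during backward."""
    checkpoints = {0: x}
    h = x
    for i, W in enumerate(Ws):
        h = layer(W, h)
        if (i + 1) % every == 0 and i + 1 < len(Ws):
            checkpoints[i + 1] = h
    grad = np.ones_like(h)
    peak = 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(Ws))
        acts = [checkpoints[start]]  # recompute this segment's forward pass
        for i in range(start, end):
            acts.append(layer(Ws[i], acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grad = layer_backward(Ws[i], acts[i - start], grad)
    return grad, peak

rng = np.random.default_rng(0)
Ws = [rng.normal(scale=0.3, size=(8, 8)) for _ in range(16)]
x = rng.normal(size=8)
g_full, stored_full = grad_full_storage(Ws, x)
g_ckpt, stored_ckpt = grad_checkpointed(Ws, x, every=4)
print(np.allclose(g_full, g_ckpt), stored_full, stored_ckpt)  # True 17 8
```

The extra cost is visible in the code: every layer's forward runs twice, once in the initial pass and once during the segment recompute.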
Gradient Accumulation
The global batch size affects training stability and final model quality. GPT-3 (175B) was trained with a batch size of approximately 3.2 million tokens per step. A single A100 can handle maybe 32,000 tokens per step (batch size 16, sequence length 2048). To reach the desired large batch size without more memory, you need gradient accumulation.
Instead of updating weights at every step, accumulate gradients over $k$ steps before updating:
accumulation_steps = 32  # Simulate 32x larger batch
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    outputs = model(**batch)
    loss = outputs.loss / accumulation_steps  # Scale loss
    loss.backward()  # Accumulate gradients
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()  # Update once every k steps
        optimizer.zero_grad()
The effective batch size is per_device_batch_size * num_gpus * accumulation_steps. This is mathematically equivalent to training with a larger batch size (assuming no batch normalization, which LLMs do not use).
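A NumPy sketch of that equivalence on a toy linear-regression loss (our own example, not from any library): summing micro-batch gradients, each divided by `accumulation_steps` exactly as the loop above divides the loss, reproduces the full-batch gradient when micro-batches are equal-sized.

```python
import numpy as np

def mse_grad(w, X, y):
    # Gradient of mean((X @ w - y)^2) with respect to w
    return 2.0 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(1)
X = rng.normal(size=(128, 6))
y = rng.normal(size=128)
w = rng.normal(size=6)

accumulation_steps = 8
accumulated = np.zeros_like(w)
for Xm, ym in zip(np.array_split(X, accumulation_steps), np.array_split(y, accumulation_steps)):
    # Mirrors loss / accumulation_steps followed by loss.backward()
    accumulated += mse_grad(w, Xm, ym) / accumulation_steps

print(np.allclose(accumulated, mse_grad(w, X, y)))  # True
```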
ZeRO Optimizer: Partitioning at Scale
Adam keeps two extra FP32 tensors the size of the model: the first moment (running mean of gradients) and the second moment (running mean of squared gradients). Together with FP32 weights, that is 3x the model size: for a 70B parameter model, 70B × 4 bytes × 3 = 840GB for weights plus Adam states. No single 8-GPU node (640GB of GPU memory with 80GB A100s) can hold this.
DeepSpeed ZeRO (Zero Redundancy Optimizer, Rajbhandari et al., 2020) eliminates this redundancy by partitioning optimizer states, gradients, and parameters across GPUs.
ZeRO-1: Partition optimizer states only. Each GPU holds optimizer states for a subset of parameters. Weights and gradients still replicated. Memory saving: ~4x for Adam.
ZeRO-2: Partition optimizer states AND gradients. Memory saving: ~8x.
ZeRO-3: Partition optimizer states, gradients, AND model parameters. Each GPU holds only a shard of the weights. Memory saving: ~64x for a 64-GPU run. This means a 70B model can be trained on 64 GPUs without tensor parallelism.
Communication overhead: ZeRO-3 requires all-gather operations to assemble full parameters before each layer's computation, and reduce-scatter after each layer's backward pass. The communication volume is higher than ZeRO-1/2 but the memory savings often justify the cost.
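A sketch of where the ~4x / ~8x / ~64x figures come from, assuming the mixed-precision Adam layout of 2 bytes (16-bit weights) + 2 (gradients) + 12 (FP32 master weights and two moments) per parameter, as in the ZeRO paper's accounting. The function name is ours; activations and communication buffers are ignored.

```python
def zero_bytes_per_param(stage: int, n_gpus: int) -> float:
    """Per-GPU bytes per parameter under ZeRO stages 0-3 (mixed-precision Adam)."""
    weights, grads, optim = 2.0, 2.0, 12.0
    if stage >= 1:
        optim /= n_gpus    # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n_gpus    # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus  # ZeRO-3: also shard parameters
    return weights + grads + optim

n_gpus = 64
baseline = zero_bytes_per_param(0, n_gpus)
for stage in (1, 2, 3):
    per_param = zero_bytes_per_param(stage, n_gpus)
    print(f"ZeRO-{stage}: {baseline / per_param:.1f}x saving, "
          f"{70e9 * per_param / 1e9:.1f} GB per GPU for a 70B model")
```

With 64 GPUs, stages 1 and 2 approach but do not quite reach 4x and 8x (about 3.8x and 7.2x), while ZeRO-3 reaches the full 64x, or 17.5GB per GPU for 70B parameters.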
# Using DeepSpeed ZeRO-3 with HuggingFace
from transformers import TrainingArguments
training_args = TrainingArguments(
    deepspeed="ds_config_zero3.json",  # DeepSpeed config file
    # ... other training arguments
)
# ds_config_zero3.json (remove these comments in the real file; JSON has no comments):
# {
#   "zero_optimization": {
#     "stage": 3,
#     "offload_optimizer": {"device": "cpu"},  # ZeRO-Offload
#     "offload_param": {"device": "cpu"}
#   },
#   "bf16": {"enabled": true},
#   "gradient_clipping": 1.0
# }
Flash Attention: IO-Aware Exact Attention
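Standard attention computes:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
This requires materializing the $N \times N$ attention matrix in GPU memory (HBM - high bandwidth memory). For $N = 8192$ tokens: 8192^2 × 2 bytes = 134MB per head per layer. A 70B model with 80 layers and 64 heads would need ~687GB per sequence just to keep these matrices for the backward pass, far more than any GPU holds.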
Flash Attention (Dao et al., 2022) reformulates attention computation to never materialize the full attention matrix. Instead, it processes the computation in tiles that fit in SRAM (the fast on-chip memory). Key insight: the bottleneck in attention is not compute (FLOPs) but memory bandwidth (reading/writing large matrices to HBM).
By using tiled computation with online softmax normalization, Flash Attention achieves:
- Memory: $O(N)$ instead of $O(N^2)$ for the attention matrix
- Speed: 2-4x faster wall-clock time due to reduced HBM reads/writes
- Exact results: not an approximation - identical output to standard attention
Flash Attention 2 (Dao, 2023): Further optimization of the tiling and parallelism strategy. ~2x faster than Flash Attention 1 on A100 GPUs.
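The core trick, online softmax over tiles, fits in a few lines of NumPy. This is a sketch of the idea, not the Flash Attention kernel: it tiles only over keys/values and runs on the CPU, whereas the real kernel also tiles queries, keeps tiles in SRAM, and fuses everything into one GPU kernel. It keeps a running row maximum and running softmax denominator, rescaling earlier partial results whenever the maximum grows, and matches standard attention.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Reference: materializes the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))  # numerically stable softmax
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Online softmax over key/value blocks; never holds more than an N x block tile."""
    n, d = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full((n, 1), -np.inf)  # running max of scores per query
    row_sum = np.zeros((n, 1))          # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)
        new_max = np.maximum(row_max, S.max(axis=1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescales results from earlier blocks
        P = np.exp(S - new_max)
        row_sum = row_sum * correction + P.sum(axis=1, keepdims=True)
        out = out * correction + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 32)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), standard_attention(Q, K, V)))  # True
```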
# Flash Attention is integrated into HuggingFace models
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="flash_attention_2",  # Requires flash-attn package
    torch_dtype=torch.bfloat16,
)
# Install: pip install flash-attn --no-build-isolation
Training Data at Scale
What goes into training a frontier LLM? The data is the model - get this wrong and no amount of infrastructure excellence saves you.
The Pile (Gao et al., 2020): 825GB of diverse text from 22 sources including Common Crawl, GitHub, arXiv, PubMed, FreeLaw, and OpenWebText. Created by EleutherAI and used for GPT-Neo, GPT-J, GPT-NeoX.
C4 (Raffel et al., 2019): 750GB of cleaned Common Crawl text. Filtering keeps only lines ending in terminal punctuation and removes pages with fewer than 3 sentences and pages containing words from a list of offensive terms. Used for T5.
RedPajama (Together AI, 2023): Open replication of LLaMA's training data. ~1.2 trillion tokens. Components: Common Crawl (878B tokens), C4 (175B), GitHub (59B), Books (26B), Arxiv (28B), Wikipedia (24B), StackExchange (20B).
FineWeb (Hugging Face, 2024): 15 trillion tokens of high-quality filtered Common Crawl. Shows that aggressive quality filtering outperforms simply using more raw data.
Data quality over quantity: Chinchilla scaling laws (Hoffmann et al., 2022) showed that optimally trained models need about 20 tokens per parameter. But quality matters - Falcon (Technology Innovation Institute, 2023) trained on 1 trillion tokens with aggressive quality filtering and outperformed much larger models trained on more but lower-quality data.
Data mixture: the proportion of each data source matters. Code data (GitHub) improves reasoning ability even for non-coding tasks. Web text provides breadth. Books provide long-form coherence. Wikipedia and academic papers provide factual accuracy. Most frontier models train on a carefully designed mixture.
Training Instabilities and Loss Spikes
Even with the best hardware and code, large training runs encounter instabilities.
Loss spikes: sudden large increases in loss, often followed by partial recovery. Causes include:
- Bad batches (a batch that happened to contain many difficult or corrupted examples)
- Learning rate too high
- Gradient norm explosion
- Numerical issues with FP16
Gradient clipping: clip gradients to a maximum L2 norm (typically 1.0). This prevents any single step from making a catastrophically large weight update.
import torch
# Applied after loss.backward() and before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
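What global-norm clipping does, sketched in NumPy: all gradient tensors are scaled by the same factor so their combined L2 norm is at most `max_norm`, which preserves the update direction. The `eps` term follows what we understand PyTorch's `clip_grad_norm_` to do; treat the exact constant as an assumption.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients together so their combined L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 4.0]), np.array([12.0])]  # global norm = sqrt(9 + 16 + 144) = 13
clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
print(norm)     # 13.0
print(clipped)  # same direction, combined norm ~1.0
```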
Recovery from spikes: Many teams maintain checkpoints every few hundred steps. When a spike occurs, roll back to before the spike and resume training with:
- A small batch of "clean" data to stabilize
- Slightly reduced learning rate
- Investigation of what caused the spike
PaLM training instability: Chowdhery et al. (2022) report roughly 20 loss spikes while training PaLM 540B. Their fix was to restart from a checkpoint about 100 steps before the spike and skip roughly 200-500 data batches. This is normal at scale.
Rough numbers on pretraining costs
- GPT-3 (175B, 300B tokens): ~$4.6M (third-party estimate, 2020)
- LLaMA-2 (70B, 2T tokens): estimated ~$10-15M
- GPT-4: estimated $50-100M (speculative, architecture not disclosed)
- Training your own 7B model on 1T tokens on 512 A100s: ~$1-2M
- Training a 7B model from scratch on 100B tokens: ~$150K-250K
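Figures like these can be sanity-checked with the $6ND$ approximation for training FLOPs. In this sketch, the MFU (40%) and price ($2 per GPU-hour) are hypothetical inputs, not quotes; the result is very sensitive to the hourly price, the achieved MFU, and whether failed runs and experiments are counted.

```python
def training_cost(n_params, n_tokens, n_gpus, peak_flops, mfu, usd_per_gpu_hour):
    """Wall-clock hours, GPU-hours, and dollars for one run under the 6*N*D approximation."""
    total_flops = 6 * n_params * n_tokens           # forward + backward
    seconds = total_flops / (n_gpus * peak_flops * mfu)
    gpu_hours = n_gpus * seconds / 3600
    return seconds / 3600, gpu_hours, gpu_hours * usd_per_gpu_hour

# Hypothetical: 7B params, 1T tokens, 512 A100s (312 TFLOP/s BF16), 40% MFU, $2/GPU-hour
hours, gpu_hours, cost = training_cost(7e9, 1e12, 512, 312e12, 0.4, 2.0)
print(f"{hours:.0f} h wall-clock, {gpu_hours:,.0f} GPU-hours, ${cost:,.0f}")
```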
Code: A Production-Ready Training Configuration
"""
Production-scale training configuration for a 7B model.
Uses DeepSpeed ZeRO-3, Flash Attention, gradient checkpointing.
"""
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
)
import torch
def get_training_args(output_dir: str) -> TrainingArguments:
"""
Production training arguments for a 7B model.
Adjust batch size and accumulation for your GPU memory.
"""
return TrainingArguments(
output_dir=output_dir,
# Effective batch size = 4 * 8 GPUs * 16 accum = 512 sequences
# At 2048 tokens/sequence: ~1M tokens/step
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
# Learning rate schedule
learning_rate=3e-4, # Higher for pretraining than fine-tuning
lr_scheduler_type="cosine",
warmup_ratio=0.01, # 1% of steps for warmup
num_train_epochs=1, # Usually 1 epoch over pretraining data
# Precision
bf16=True, # BF16 for modern GPUs (A100, H100)
tf32=True, # Use TF32 for matrix multiplications
# Memory optimization
gradient_checkpointing=True, # Recompute activations in backward pass
# Stability
max_grad_norm=1.0, # Gradient clipping
# Logging and saving
logging_steps=10,
save_steps=500,
save_total_limit=5, # Keep only last 5 checkpoints
# DeepSpeed integration
deepspeed="ds_config.json", # ZeRO-3 configuration
# Performance
dataloader_num_workers=4,
dataloader_pin_memory=True,
report_to="wandb",
)
def load_model_for_pretraining(model_name: str):
"""Load model with Flash Attention and gradient checkpointing."""
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 2-4x attention speedup
use_cache=False, # Must disable KV cache when using gradient checkpointing
)
model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={"use_reentrant": False}
)
return model
# DeepSpeed config for ZeRO-3
DS_CONFIG_ZERO3 = {
"zero_optimization": {
"stage": 3,
"overlap_comm": True,
"contiguous_gradients": True,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"gather_16bit_weights_on_model_save": True,
# Offload optimizer states to CPU if GPU memory is tight
# "offload_optimizer": {"device": "cpu", "pin_memory": True},
},
"bf16": {"enabled": True},
"gradient_clipping": 1.0,
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"steps_per_print": 100,
"wall_clock_breakdown": False,
}
def monitor_training_health(trainer_state):
"""
Monitor loss curve for training instabilities.
Returns True if a loss spike is detected.
"""
if len(trainer_state.log_history) < 10:
return False
recent_losses = [
log.get("loss", None)
for log in trainer_state.log_history[-10:]
if "loss" in log
]
if len(recent_losses) < 2:
return False
# Loss spike: current loss is 2x the recent average
recent_avg = sum(recent_losses[:-1]) / len(recent_losses[:-1])
current_loss = recent_losses[-1]
if current_loss > 2 * recent_avg:
print(f"WARNING: Loss spike detected! Current: {current_loss:.4f}, "
f"Recent avg: {recent_avg:.4f}")
return True
return False
Common Mistakes
Using FP16 instead of BF16 on modern hardware
FP16 can overflow (max value 65504) during LLM training when gradients or activations are large. This causes NaN values and training crashes. BF16 has the same exponent range as FP32 and practically never overflows. If you have A100, H100, or newer GPUs, use BF16 for pretraining. Only use FP16 with careful loss scaling on older hardware without BF16 support (V100, T4).
Forgetting to disable KV cache with gradient checkpointing
When you enable gradient_checkpointing=True and also have use_cache=True (the default), HuggingFace will throw a warning and disable the KV cache automatically. But some configurations silently fail. Always explicitly set use_cache=False when gradient checkpointing is enabled - they are incompatible because gradient checkpointing recomputes forward passes and the KV cache would be stale.
Under-preparing training data
Data quality is the most important factor in final model quality, and its effects are hard to reverse - once the model is trained on bad data, you cannot simply un-train it. Common issues: near-duplicate documents (encourage memorization and waste compute), low-quality text (autogenerated, machine-translated back and forth), toxic content (shows up in model outputs), data contamination (test set data in training, which inflates benchmark scores). Invest as much time in data curation as in the training infrastructure.
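One common approach to near-duplicate detection is MinHash over word shingles. The sketch below is a minimal illustration, not what any particular pipeline uses: the shingle size, 128 hash functions (built from seeded BLAKE2b), and the O(n²) comparison loop are our own assumptions; production systems bucket signatures with locality-sensitive hashing instead of comparing every pair.

```python
import hashlib
import re

def shingles(text, n=5):
    """Set of overlapping n-word shingles from lowercased text."""
    words = re.findall(r"\w+", text.lower())
    return {" ".join(words[i:i + n]) for i in range(max(1, len(words) - n + 1))}

def minhash(shingle_set, num_perm=128):
    """One minimum hash per seeded hash function; matching slots estimate Jaccard similarity."""
    return [
        min(int.from_bytes(hashlib.blake2b(f"{seed}:{s}".encode(), digest_size=8).digest(), "big")
            for s in shingle_set)
        for seed in range(num_perm)
    ]

def near_dedup(docs, threshold=0.8):
    """Keep a document only if it is not a near-duplicate of one already kept."""
    kept, signatures = [], []
    for i, doc in enumerate(docs):
        sig = minhash(shingles(doc))
        if all(sum(a == b for a, b in zip(sig, s)) / len(sig) < threshold for s in signatures):
            kept.append(i)
            signatures.append(sig)
    return kept

docs = [
    "the quick brown fox jumps over the lazy dog near the river bank today",
    "the quick brown fox jumps over the lazy dog near the river bank tomorrow",
    "gradient checkpointing trades extra compute for a smaller activation memory footprint",
]
print(near_dedup(docs, threshold=0.5))  # [0, 2]: the second document is a near-duplicate
```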
Tracking only loss, not GPU utilization
During distributed training, monitor both GPU utilization (nvidia-smi should show the GPUs busy 90%+ of the time) and MFU (Model FLOP Utilization; 40-60% is strong for large runs). Low values usually indicate a bottleneck: data loading (increase num_workers), communication overhead (reduce model sharding), or CPU-bound preprocessing. Tools: nvidia-smi dmon, W&B system monitoring, DeepSpeed's built-in profiler.
Interview Q&A
Q1: Explain tensor parallelism, pipeline parallelism, and data parallelism. When would you use each?
Data parallelism replicates the full model on each GPU and splits the batch - works when the model fits on a single GPU, scales to any number of GPUs with all-reduce communication. Tensor parallelism splits individual weight matrices across GPUs - used when the model does not fit on one GPU, requires high-bandwidth interconnect (NVLink), typically applied within a node (8 GPUs). Pipeline parallelism splits layers across GPUs - each GPU holds some layers and computation flows through them like a pipeline. In practice, large-scale pretraining uses all three simultaneously ("3D parallelism" in Megatron-LM): tensor parallel within a node, pipeline parallel across nodes, data parallel across pipeline replicas.
Q2: What is ZeRO-3 and what are its trade-offs?
ZeRO-3 (Zero Redundancy Optimizer stage 3) partitions all three components - optimizer states, gradients, and model parameters - across GPUs. Each GPU holds only 1/N of each. Memory saving is approximately 64x for 64 GPUs with Adam. The trade-off is communication: ZeRO-3 requires an all-gather before each layer's forward pass (to assemble the full weight tensor) and a reduce-scatter after each layer's backward pass. Communication volume is higher than ZeRO-1 or ZeRO-2. Whether ZeRO-3 is worth using depends on the communication bandwidth vs memory pressure trade-off. On fast interconnects (InfiniBand HDR, 200Gb/s), ZeRO-3 is efficient. On slower networks, ZeRO-2 with tensor parallelism is often better.
Q3: What is Flash Attention and why does it matter?
Flash Attention (Dao et al., 2022) is an IO-aware exact attention algorithm. Standard attention materializes the full $N \times N$ attention matrix in GPU HBM, which takes $O(N^2)$ memory and creates a memory-bandwidth bottleneck. Flash Attention computes attention in tiles that fit in SRAM (on-chip fast memory), fusing the $QK^\top$ matrix multiply, the softmax, and the multiplication by $V$ into a single kernel. The result: the full attention matrix is never materialized, memory is $O(N)$, and wall-clock speedups of 2-4x on modern hardware. It is exact, not an approximation of standard attention. Flash Attention 2 extends this with better parallelism and is now supported by all major LLM training frameworks.
Q4: What is gradient checkpointing and what does the 33% compute overhead mean in practice?
Gradient checkpointing (activation recomputation) solves the memory bottleneck of storing all intermediate activations for backpropagation. Instead of storing all activations (~200-400GB for a 70B model), only checkpoint activations at layer boundaries. During the backward pass, recompute non-stored activations from the checkpointed ones. The 33% overhead means: for every 3 FLOPs you do normally, you do 4 FLOPs with checkpointing (1/3 extra), because the forward pass, about a third of a step's work, runs twice. In practice, the freed memory allows larger micro-batches or less model parallelism (and therefore less communication), which can recover part of that overhead. For pretraining large models, gradient checkpointing is almost always used.
Q5: How do Chinchilla scaling laws change how you decide model size and training compute budget?
Chinchilla (Hoffmann et al., 2022) showed that previous large models (GPT-3, Gopher) were significantly undertrained - they had too many parameters for the amount of data they were trained on. The Chinchilla-optimal allocation given a compute budget $C$: model size scales as $N_{\text{opt}} \propto C^{a}$ and training tokens as $D_{\text{opt}} \propto C^{b}$, with $a \approx b \approx 0.5$. Practically, this means you should train on roughly 20 tokens per parameter. For a 7B model: train on ~140B tokens minimum. For a 70B model: train on ~1.4T tokens. LLaMA (Touvron et al., 2023) went further, training its 7B and 13B models on 1T tokens - far more than Chinchilla-optimal - and LLaMA-13B outperforms GPT-3 (175B) on most benchmarks. The lesson: if inference cost dominates (serving millions of requests), train a smaller model longer on more data.
Advanced: Distributed Training Architecture
To run 3D parallelism at scale, you need to understand how the communication patterns interact. Here is a worked example for a 70B model on 256 A100 GPUs:
"""
Distributed training configuration for a 70B model on 256 A100 GPUs.
Uses Megatron-LM 3D parallelism.
Config dimensions:
- Tensor parallel (TP): 8 GPUs per node
- Pipeline parallel (PP): 4 stages (32 GPU nodes / 4)
- Data parallel (DP): 256 / (8 * 4) = 8 replicas
Each GPU sees: 70B / (8 TP * 4 PP) = 2.2B parameters
"""
# Megatron-LM configuration (conceptual - actual config is YAML)
MEGATRON_CONFIG = {
# Model
"num_layers": 80,
"hidden_size": 8192,
"num_attention_heads": 64,
"seq_length": 4096,
# Parallelism
"tensor_model_parallel_size": 8, # Split weight matrices 8 ways
"pipeline_model_parallel_size": 4, # 4 pipeline stages = 20 layers each
"micro_batch_size": 1, # Per-GPU batch (extremely small for 70B)
"global_batch_size": 1024, # Effective batch = 1 * 8 * 4 * 32 * 1 = large
# Memory
"use_flash_attn": True,
"recompute_activations": True, # Gradient checkpointing
# Precision
"bf16": True,
"loss_scale": 1.0, # No loss scaling needed for BF16
# Optimizer
"optimizer": "adam",
"adam_beta1": 0.9,
"adam_beta2": 0.95,
"adam_eps": 1e-8,
"weight_decay": 0.1,
# Learning rate
"lr": 3e-4,
"lr_decay_style": "cosine",
"lr_warmup_fraction": 0.01,
"min_lr": 3e-5,
}
def compute_memory_per_gpu(
    total_params: float,  # Billions
    tp: int,
    pp: int,
    dp: int = 8,
    zero_stage: int = 0,
    optimizer_bytes_per_param: int = 12,  # Adam: 4 bytes FP32 master weight + 4+4 bytes moments
) -> float:
    """
    Estimate memory per GPU in GB for weights, gradients, and optimizer states.
    Activations are not included.
    """
    params_per_gpu = (total_params * 1e9) / (tp * pp)
    # Model weights in BF16 (2 bytes per param)
    weight_memory = params_per_gpu * 2
    # Gradients in BF16
    gradient_memory = params_per_gpu * 2
    # Optimizer states (Adam): FP32 master weights + two moments
    optimizer_memory = params_per_gpu * optimizer_bytes_per_param
    # ZeRO shards state across the DP replicas: stage 1 shards optimizer
    # states, stage 2 also gradients, stage 3 also weights
    if zero_stage >= 1:
        optimizer_memory /= dp
    if zero_stage >= 2:
        gradient_memory /= dp
    if zero_stage >= 3:
        weight_memory /= dp
    total_bytes = weight_memory + gradient_memory + optimizer_memory
    return total_bytes / 1e9
print("70B model memory per GPU (excluding activations):")
print(f"  TP=8, PP=4, no ZeRO: {compute_memory_per_gpu(70, 8, 4):.1f} GB")
print(f"  TP=8, PP=4, ZeRO-1:  {compute_memory_per_gpu(70, 8, 4, zero_stage=1):.1f} GB")
# Output:
#   TP=8, PP=4, no ZeRO: 35.0 GB (fits on an 80GB A100, leaving ~45GB for activations)
#   TP=8, PP=4, ZeRO-1:  12.0 GB (much more room for activations)
Data Pipeline Engineering
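An often-overlooked bottleneck in large-scale training: data loading. At full speed, a cluster of 256 A100 GPUs consumes tokens faster than a naive pipeline (reading raw text, tokenizing on the fly, making small random reads from shared storage) can deliver them.
The data throughput calculation: a 256-GPU cluster at 300 TFLOP/s per GPU (an upper bound; real MFU is lower), processing a 70B model at 6 FLOPs per token per parameter, consumes ≈ 300 TF × 256 GPUs / (70B × 6) ≈ 183,000 tokens/second. Stored as pre-tokenized 2-byte token IDs, that is only about 0.37MB/s, well within what a single disk can stream sequentially. The risk lies elsewhere: on-the-fly tokenization is CPU-bound, small random reads for shuffling are slow on HDDs and network filesystems, and because training is synchronous, a stall in any one worker's data loader idles every GPU at the next gradient synchronization.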
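A sketch of that arithmetic, converting the cluster's token consumption into bytes per second for pre-tokenized uint16 token IDs (the function name is ours; 300 TFLOP/s per GPU is the same optimistic upper bound used above):

```python
def required_tokens_per_second(n_gpus, flops_per_gpu, n_params):
    # Each training token costs about 6 * n_params FLOPs (forward + backward)
    return n_gpus * flops_per_gpu / (6 * n_params)

tokens_per_s = required_tokens_per_second(256, 300e12, 70e9)
bytes_per_s_uint16 = tokens_per_s * 2  # 2 bytes per pre-tokenized uint16 token ID
print(f"{tokens_per_s:,.0f} tokens/s, {bytes_per_s_uint16 / 1e6:.2f} MB/s of token IDs")
```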
Solutions in practice:
- Store training data as binary token ID files (no re-tokenization at runtime)
- Use NVMe SSDs with 3-7 GB/s throughput
- Parallel data loading: DataLoader with many workers
- Memory mapping (np.memmap) for instant access to any position in huge files
- Pre-shuffle during data preparation so training can stream sequentially
"""
Efficient data loading for pretraining with memory-mapped files.
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class MemmapTokenDataset(Dataset):
"""
Memory-mapped dataset for large tokenized corpora.
Allows training on datasets that don't fit in RAM.
Supports random access in O(1) time.
"""
def __init__(
self,
data_path: str, # Path to .bin file of token IDs (uint16 or uint32)
seq_length: int = 2048,
dtype=np.uint16, # uint16 for vocab ≤ 65535 tokens
):
self.seq_length = seq_length
self.data = np.memmap(data_path, dtype=dtype, mode='r')
# Number of complete sequences in the file
self.num_sequences = (len(self.data) - 1) // seq_length
def __len__(self):
return self.num_sequences
def __getitem__(self, idx):
start = idx * self.seq_length
end = start + self.seq_length + 1 # +1 for the next-token labels
chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
x = chunk[:-1] # Input: tokens 0..T-1
y = chunk[1:] # Labels: tokens 1..T
return {"input_ids": x, "labels": y}
def create_data_loader(
data_paths: list, # List of .bin files to train on
seq_length: int = 2048,
batch_size: int = 4,
num_workers: int = 8,
shuffle: bool = True,
) -> DataLoader:
"""Create a DataLoader from multiple binary token files."""
from torch.utils.data import ConcatDataset
datasets = [MemmapTokenDataset(p, seq_length) for p in data_paths]
combined = ConcatDataset(datasets)
return DataLoader(
combined,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=True, # Faster GPU transfer
prefetch_factor=4, # Pre-load next batches while GPU is busy
)
def prepare_tokenized_corpus(
    raw_text_path: str,
    output_path: str,
    tokenizer_name: str = "meta-llama/Llama-2-7b-hf",
):
    """
    Pre-tokenize a text corpus and save as binary token IDs.
    Run this once before training - saves re-tokenization overhead.
    Assumes one document per line of the input file.
    """
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # Use uint16 if every token ID fits in 16 bits (IDs 0..65535), else uint32
    dtype = np.uint16 if len(tokenizer) <= 65536 else np.uint32
    all_tokens = []  # Held in RAM; shard the output for corpora larger than memory
    with open(raw_text_path, 'r', encoding='utf-8') as f:
        for line in f:
            text = line.strip()
            if not text:
                continue
            # Tokenize whole documents so no word is split at an arbitrary chunk boundary
            all_tokens.extend(tokenizer.encode(text, add_special_tokens=False))
            # EOS between documents gives clean sequence boundaries
            all_tokens.append(tokenizer.eos_token_id)
    token_array = np.array(all_tokens, dtype=dtype)
    token_array.tofile(output_path)
    print(f"Tokenized {len(all_tokens):,} tokens saved to {output_path}")
    print(f"File size: {len(all_tokens) * np.dtype(dtype).itemsize / 1e6:.1f} MB")
    return len(all_tokens)
Model FLOP Utilization (MFU): The Real Efficiency Metric
During distributed training, the fraction of theoretical peak performance actually used is called Model FLOP Utilization (MFU). It is the single most important metric for training efficiency.
Theoretical peak FLOP/s: for A100 SXM4 in BF16: 312 TFLOP/s. For H100 in BF16: 989 TFLOP/s.
Achieved MFU: top-tier implementations achieve 40-60% MFU. PaLM (Chowdhery et al., 2022) reported 46.2% MFU on 6,144 TPU v4 chips. Llama 3 (Dubey et al., 2024) reported 38-43% MFU on up to 16,384 H100s.
What kills MFU:
- Communication overhead (all-reduce, pipeline bubbles)
- Data loading bottleneck
- Memory bandwidth bottleneck (especially for attention)
- CPU-GPU synchronization points
Theoretical FLOPs for a forward+backward pass on an LLM: approximately $6ND$ FLOPs, where $N$ is the number of model parameters and $D$ is the number of tokens in the batch. At 50% MFU on 256 A100s (312 TFLOP/s each), a 70B model processes $\frac{0.5 \times 256 \times 312 \times 10^{12}}{6 \times 70 \times 10^{9}} \approx 95{,}000$ tokens per second.
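Running the formula in reverse gives MFU from a measured throughput. A simplified sketch (our own function): it uses only $6ND$ and ignores attention FLOPs and recomputation, so it slightly understates MFU compared with definitions that count attention.

```python
def mfu(tokens_per_second, n_params, n_gpus, peak_flops_per_gpu):
    """Model FLOP Utilization from observed training throughput (6ND approximation)."""
    achieved = 6 * n_params * tokens_per_second
    return achieved / (n_gpus * peak_flops_per_gpu)

# 70B model on 256 A100s (312 TFLOP/s BF16 peak) at ~95,000 tokens/s
print(round(mfu(95_086, 70e9, 256, 312e12), 3))  # 0.5
```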
Handling Training Instabilities
Loss spikes and training divergences are a common operational challenge in large-scale pretraining. Understanding their causes and mitigations is essential for anyone running or debugging long training runs.
Loss Spike Detection and Recovery
```python
from collections import deque

import numpy as np


class TrainingStabilityMonitor:
    """Monitor training runs for loss spikes and other instabilities."""

    def __init__(
        self,
        spike_threshold: float = 1.5,  # Loss spike if > 1.5x rolling average
        window_size: int = 100,        # Rolling window for baseline
        max_grad_norm: float = 1.0,    # Clipping threshold used by the training loop
    ):
        self.spike_threshold = spike_threshold
        self.window_size = window_size
        self.max_grad_norm = max_grad_norm
        self.loss_history = deque(maxlen=window_size)
        self.grad_norm_history = deque(maxlen=window_size)
        self.spike_steps = []

    def update(self, step: int, loss: float, grad_norm: float) -> list[dict]:
        """Update monitor with current step metrics. Returns a (possibly empty) list of alerts.

        grad_norm should be the global norm measured before clipping: the
        post-clip norm never exceeds max_grad_norm, so explosions would go unseen.
        """
        alerts = []
        # Check for loss spike against the rolling baseline (needs >= 10 points)
        if len(self.loss_history) >= 10:
            rolling_avg = float(np.mean(list(self.loss_history)))
            if loss > rolling_avg * self.spike_threshold:
                alerts.append({
                    "type": "loss_spike",
                    "step": step,
                    "current_loss": loss,
                    "rolling_avg": rolling_avg,
                    "ratio": loss / rolling_avg,
                })
                self.spike_steps.append(step)
                print(f"[Step {step}] LOSS SPIKE: {loss:.4f} vs avg {rolling_avg:.4f}")
        # Check for gradient explosion (pre-clip norm well above the clipping threshold)
        if grad_norm > self.max_grad_norm * 3:
            alerts.append({
                "type": "gradient_explosion",
                "step": step,
                "grad_norm": grad_norm,
                "threshold": self.max_grad_norm * 3,
            })
            print(f"[Step {step}] GRADIENT EXPLOSION: {grad_norm:.2f}")
        # Check for training plateau: compare the older and newer halves of a full window
        if len(self.loss_history) >= self.window_size:
            history = list(self.loss_history)
            half = self.window_size // 2
            first_half = float(np.mean(history[:half]))
            second_half = float(np.mean(history[half:]))
            if second_half >= first_half * 0.99:  # Less than 1% improvement
                alerts.append({
                    "type": "plateau",
                    "step": step,
                    "first_half_avg": first_half,
                    "second_half_avg": second_half,
                })
        self.loss_history.append(loss)
        self.grad_norm_history.append(grad_norm)
        return alerts

    def suggest_mitigation(self, alert_type: str) -> str:
        """Return mitigation strategy for each alert type."""
        mitigations = {
            "loss_spike": (
                "1. Check data batch - may have corrupted examples or outlier documents\n"
                "2. Try reducing learning rate by 10x for 100 steps, then restore\n"
                "3. Check if spike correlates with data source transitions\n"
                "4. If using µP (maximal update parameterization), check LR scaling"
            ),
            "gradient_explosion": (
                "1. Verify gradient clipping is enabled (max_norm=1.0)\n"
                "2. Reduce learning rate by 2-5x\n"
                "3. Check for NaN/Inf in inputs (corrupted data)\n"
                "4. If using BF16, consider switching to FP32 for the embedding layer"
            ),
            "plateau": (
                "1. Learning rate may be too low - try increasing by 2x\n"
                "2. Check if learning rate schedule has decayed too aggressively\n"
                "3. Model may be undertrained - verify you're at the Chinchilla-optimal point\n"
                "4. Data may be exhausted - check dataset iteration"
            ),
        }
        return mitigations.get(alert_type, "Unknown alert type")
```
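The gradient-explosion check in `update` is only meaningful if `grad_norm` is the norm measured before clipping. As an illustration, here is a minimal NumPy sketch of global-norm clipping that mirrors the behavior documented for PyTorch's `torch.nn.utils.clip_grad_norm_` (one L2 norm over all gradients, a scale factor capped at 1, the pre-clip norm returned); `clip_by_global_norm` is a hypothetical helper, not a library function:

```python
import numpy as np


def clip_by_global_norm(grads: list[np.ndarray], max_norm: float = 1.0, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their combined L2 norm is at most max_norm.

    Returns the pre-clipping global norm, which is the value a stability
    monitor should receive (the post-clip norm can never exceed max_norm).
    """
    # One norm over every parameter's gradient, computed in float64 for safety
    total_norm = float(np.sqrt(sum(float(np.sum(g.astype(np.float64) ** 2)) for g in grads)))
    # Only ever shrink the gradients, never enlarge them
    clip_coef = min(1.0, max_norm / (total_norm + eps))
    for g in grads:
        g *= clip_coef
    return total_norm


grads = [np.array([3.0, 0.0]), np.array([0.0, 4.0])]  # global norm = 5
pre_clip = clip_by_global_norm(grads, max_norm=1.0)
print(pre_clip)                                      # 5.0
print(np.sqrt(sum(np.sum(g ** 2) for g in grads)))   # ~1.0 after clipping
```

Returning the pre-clip norm lets the same call both protect the update and feed the monitor.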
Common Causes of Loss Spikes
During large-scale pretraining, loss spikes at specific steps are often traced to:
- Batch contamination: A batch containing corrupted text (malformed encoding, repeated tokens, SQL injection patterns in web-crawl data) can cause a spike. Solution: online data quality filtering during training.
- LR warmup too short: Without sufficient warmup, early gradient updates are too large and can destabilize weight norms. Solution: a common rule of thumb is to warm up over at least 1,000 steps for models above 1B parameters.
- Accumulation precision: When gradients are accumulated across micro-batches in BF16, its 7-bit mantissa rounds away small contributions once the running sum grows. Solution: accumulate in FP32.
- Attention instability: Attention logits can grow very large, especially with long sequences, and overflow in FP16. Solution: use Flash Attention 2, which keeps the attention scores in FP32 on-chip and computes the softmax with a running maximum (this prevents overflow, though not the logit growth itself), or add QK-normalization (used in ViT-22B and OLMo 2), which normalizes queries and keys so the logits cannot grow unchecked.
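The first two mitigations can be sketched in a few lines. Below, `lr_with_warmup` is a linear-warmup-then-cosine schedule and `looks_degenerate` is a crude per-sequence check for repeated-token garbage; both names and all default values (2,000 warmup steps, the 0.3 and 0.1 thresholds) are illustrative assumptions rather than recommended settings:

```python
import math
from collections import Counter


def lr_with_warmup(step: int, peak_lr: float = 3e-4, warmup_steps: int = 2_000,
                   total_steps: int = 100_000, min_lr_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay down to min_lr_ratio * peak_lr."""
    if step < warmup_steps:
        # Ramp up linearly so early updates stay small while weight norms settle
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)


def looks_degenerate(token_ids: list[int], max_top_share: float = 0.3,
                     min_distinct_ratio: float = 0.1) -> bool:
    """Heuristic flag for sequences dominated by one token or a tiny vocabulary."""
    if not token_ids:
        return True
    counts = Counter(token_ids)
    top_share = counts.most_common(1)[0][1] / len(token_ids)
    distinct_ratio = len(counts) / len(token_ids)
    return top_share > max_top_share or distinct_ratio < min_distinct_ratio


print(lr_with_warmup(0), lr_with_warmup(1_999), lr_with_warmup(100_000))
print(looks_degenerate([42] * 512), looks_degenerate(list(range(512))))  # True False
```

In a real pipeline, a check like `looks_degenerate` would run in the data loader so flagged sequences never reach the model; the thresholds would need tuning against real data.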
Chinchilla Scaling Laws in Practice
Hoffmann et al. (2022) at DeepMind showed that most large models at the time were undertrained - too many parameters, too few tokens. The Chinchilla optimal ratio: ~20 tokens per parameter. For a 7B model: train on approximately 140B tokens.
```python
def chinchilla_optimal_tokens(num_params: int) -> int:
    """Compute Chinchilla-optimal training tokens for a given model size."""
    # Hoffmann et al. (2022): optimal tokens ≈ 20 × params
    return 20 * num_params


def chinchilla_optimal_params(num_tokens: int) -> int:
    """Compute Chinchilla-optimal model size for a given token budget."""
    return num_tokens // 20


def compute_training_flops(num_params: int, num_tokens: int) -> float:
    """Estimate training FLOPs. Kaplan et al. (2020): C ≈ 6 × N × T FLOPs."""
    return 6 * num_params * num_tokens


def recommend_training_config(compute_budget_flops: float) -> dict:
    """
    Given a FLOP budget, recommend Chinchilla-optimal model size and token count.
    Reference: Hoffmann et al. (2022)
    """
    # Combine C ≈ 6·N·T with the ~20 tokens/param rule (T ≈ 20·N):
    # C ≈ 120·N², so optimal N = (C / 6 / 20) ^ 0.5, with C in FLOPs
    optimal_params = (compute_budget_flops / 6 / 20) ** 0.5
    optimal_tokens = 20 * optimal_params
    return {
        "optimal_params": f"{optimal_params / 1e9:.1f}B",
        "optimal_tokens": f"{optimal_tokens / 1e9:.0f}B",
        "training_flops": f"{compute_budget_flops:.2e}",
        "note": "Chinchilla-optimal assumes training compute is the bottleneck. "
                "For inference-heavy deployments, use more tokens than optimal "
                "to train a smaller model (LLaMA strategy).",
    }


# Examples:
print(recommend_training_config(3.14e23))  # GPT-3 budget: 3,640 petaflop-days ≈ 3.14e23 FLOPs
# → ~51B params, ~1.0T tokens (GPT-3 used 175B params on 300B tokens: massively undertrained by this metric)
print(recommend_training_config(1e24))  # Roughly LLaMA 2 70B (6 × 70B × 2T ≈ 8.4e23)
# → ~91B params, ~1.8T tokens
```
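`compute_training_flops` can be turned into a wall-clock estimate by dividing by the cluster's sustained throughput. The sketch below applies it to Wang Jun's scenario (a 7B model, 1T tokens, 512 A100s). The A100's dense BF16 peak of 312 TFLOP/s is a published spec; the 40% model FLOPs utilization (MFU) is an assumption, and real runs can land well above or below it. `training_days` is a hypothetical helper:

```python
def training_days(num_params: float, num_tokens: float, num_gpus: int,
                  peak_flops_per_gpu: float, mfu: float) -> float:
    """Wall-clock days for a run, using C ≈ 6·N·T and an assumed utilization (MFU)."""
    total_flops = 6 * num_params * num_tokens
    sustained_flops = num_gpus * peak_flops_per_gpu * mfu  # cluster throughput in FLOP/s
    return total_flops / sustained_flops / 86_400          # seconds -> days


# Hypothetical: 7B params, 1T tokens, 512 A100s at 312 TFLOP/s dense BF16 peak, 40% MFU
days = training_days(7e9, 1e12, num_gpus=512, peak_flops_per_gpu=312e12, mfu=0.40)
print(f"{days:.1f} days")  # 7.6 days
```

Since time scales as 1/MFU, each percentage point of utilization near 40% is worth roughly 4.5 hours of cluster time in this scenario.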
The LLaMA "deliberate underparameterization" strategy
LLaMA (Touvron et al., 2023) deliberately trained its 7B model on 1 trillion tokens, about 7x the Chinchilla-optimal ~140B. The logic: training is a one-time cost, but inference cost scales with model size and is paid on every request, so a smaller model trained on more tokens is cheaper to serve for its entire deployment. The results backed this up: LLaMA-13B, also trained on 1T tokens and far past its Chinchilla-optimal point, outperformed GPT-3 (175B) on most benchmarks. This "efficient inference" philosophy became the dominant approach for open-source models. LLaMA 3 pushed further: an 8B model trained on 15T tokens.
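The arithmetic behind this strategy is easy to check. The sketch below expresses each run as a multiple of the ~20 tokens/param Chinchilla budget and estimates serving cost with the common approximation of ~2 FLOPs per parameter per generated token (forward pass only, ignoring attention over the context); the helper names are hypothetical:

```python
def chinchilla_multiple(num_params: float, num_tokens: float) -> float:
    """How many times the ~20 tokens/param Chinchilla budget a run used."""
    return num_tokens / (20 * num_params)


def inference_flops_per_token(num_params: float) -> float:
    """Forward-pass cost per generated token, using the common ≈ 2·N approximation."""
    return 2 * num_params


runs = {
    "GPT-3 175B (300B tokens)": (175e9, 300e9),
    "LLaMA-7B (1T tokens)": (7e9, 1e12),
    "LLaMA-13B (1T tokens)": (13e9, 1e12),
    "LLaMA 3 8B (15T tokens)": (8e9, 15e12),
}
for name, (n, t) in runs.items():
    print(f"{name}: {chinchilla_multiple(n, t):.2f}x Chinchilla, "
          f"{inference_flops_per_token(n):.1e} FLOPs/token to serve")
```

LLaMA-13B serves each token with roughly 13x fewer FLOPs than GPT-3 (175B / 13B), and that saving repeats on every request, while its extra training tokens were paid for once.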
Interview Q&A
Q1: Why is BF16 preferred over FP16 for pretraining large language models?
BF16 (bfloat16) and FP16 both use 16 bits, but allocate them differently. FP16 uses 5 bits for the exponent (max value ~65,504) and 10 bits for the mantissa (precision). BF16 uses 8 exponent bits (the same as FP32, max ~3.4×10³⁸) and 7 mantissa bits. The larger exponent range means BF16 rarely overflows or underflows, and overflow (producing inf, which then turns into NaN) is the primary cause of NaNs in FP16 training. The smaller mantissa costs some precision, but in practice the noise in stochastic gradients dwarfs this rounding error. Practical result: BF16 training generally needs no loss scaling (effectively mandatory for FP16), rarely produces NaNs, and converges nearly identically to FP32. FP16 is still common in inference serving because some GPUs (e.g., V100 and T4) lack native BF16 support.
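The range and precision figures above follow from the bit layouts. This sketch derives them from the exponent and mantissa widths, uses NumPy's `float16` for FP16, and emulates BF16 by rounding a float32 to its top 16 bits (round-to-nearest-even, NaN handling omitted), since NumPy has no native bfloat16 type; `float_format_stats` and `round_to_bf16` are illustrative helpers:

```python
import struct

import numpy as np


def float_format_stats(exponent_bits: int, mantissa_bits: int) -> dict:
    """Largest finite value, smallest normal value, and machine epsilon of an IEEE-style format."""
    bias = 2 ** (exponent_bits - 1) - 1
    return {
        "max": (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias,
        "min_normal": 2.0 ** (1 - bias),
        "epsilon": 2.0 ** -mantissa_bits,
    }


def round_to_bf16(x: float) -> float:
    """Round a value to bfloat16 via its float32 bits, returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)     # round-to-nearest-even on the 16 dropped bits
    bits = (bits & 0xFFFFFFFF) >> 16 << 16  # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(float_format_stats(5, 10))  # FP16: max 65504.0, epsilon ~9.8e-4
print(float_format_stats(8, 7))   # BF16: max ~3.39e38, epsilon 0.0078125
print(np.float16(70000.0), round_to_bf16(70000.0))  # FP16 overflows to inf; BF16 -> 70144.0
print(np.float16(1e-8), round_to_bf16(1e-8))        # FP16 underflows to 0.0; BF16 keeps ~1e-8
print(round_to_bf16(1.0 + 2 ** -9))                 # 1.0: a tiny update is lost
```

The last line shows the precision side of the trade-off: an update smaller than half of BF16's epsilon vanishes when added to 1.0, which is why optimizer states and accumulators are usually kept in FP32.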
Q2: What is gradient checkpointing and what is the memory/compute trade-off?
During the backward pass, computing gradients for each layer requires the activations from the forward pass. Normally all of these activations stay in GPU memory at once, which can take as much memory as the model weights, or more. Gradient checkpointing (Chen et al., 2016) saves memory by storing only a subset of activations (the checkpoints) and recomputing the rest during the backward pass, running a partial forward pass again from the nearest checkpoint. The trade-off: with a checkpoint every √L layers (L = number of layers), activation memory falls from O(L) to O(√L), at the cost of roughly one extra forward pass. Because the backward pass costs about twice the forward pass, total compute rises from about 3 units to 4, i.e. about 33% extra. For training large models on memory-constrained hardware, this is almost always worth it.
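A simplified cost model makes the √L result concrete. Assume L layers with equal activation size, keep only the input of each k-layer segment, and recompute one segment at a time during backward: peak memory is about L/k checkpoints plus k live activations, which is minimized near k = √L. Compute uses the same assumptions as above (backward ≈ 2× forward, recomputation ≈ one extra forward). `checkpointing_cost` is a hypothetical helper, not the algorithm from Chen et al. verbatim:

```python
import math


def checkpointing_cost(num_layers: int, segment_len: int) -> tuple[int, float]:
    """Peak activation memory (in units of one layer's activations) and relative compute.

    segment_len <= 1 means no checkpointing. Otherwise only each segment's input is
    stored, and segments are re-materialized one at a time during backward.
    Compute model: forward = 1, backward ≈ 2, recomputation ≈ one extra forward.
    """
    if segment_len <= 1:
        return num_layers, 1.0
    stored_checkpoints = math.ceil(num_layers / segment_len)
    peak_memory = stored_checkpoints + segment_len
    return peak_memory, (1 + 1 + 2) / (1 + 2)


num_layers = 64
baseline = checkpointing_cost(num_layers, 1)
best_seg = min(range(2, num_layers + 1), key=lambda s: checkpointing_cost(num_layers, s)[0])
print(baseline)                                            # (64, 1.0)
print(best_seg, checkpointing_cost(num_layers, best_seg))  # 8 (16, 1.333...)
```

For 64 layers, activation memory drops from 64 units to 16 (a 4x saving) for about 33% more compute.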
Q3: Explain data parallelism, tensor parallelism, and pipeline parallelism. When do you use each?
Data parallelism (DP): each GPU holds a full copy of the model and processes a different slice of the batch. Gradients are averaged across GPUs with all-reduce. Simple to implement and works well when the model fits on a single GPU, but it does not reduce per-GPU memory. Tensor parallelism (TP): individual weight matrices are split across GPUs, and each GPU computes part of every matrix multiply. Because partial results must be combined (all-reduce or all-gather) inside every layer, TP needs high-bandwidth intra-node links (NVLink). Use it when the model, or a single layer, does not fit in one GPU's memory. Pipeline parallelism (PP): different layers are placed on different GPUs, which process micro-batches in a pipeline. This introduces pipeline "bubbles": idle time while the pipeline fills at the start of each batch and drains at the end; more micro-batches per batch shrink the bubble. Use it for models with many layers, especially across nodes, since it only communicates activations at stage boundaries. In practice the three are combined: DP × TP × PP = 3D parallelism. Megatron-LM uses TP within a node (fast NVLink) and PP across nodes (slower InfiniBand or Ethernet), with DP across replica groups.
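The two tensor-parallel splits and the 3D layout arithmetic can be checked on a single machine. In this NumPy sketch, each list element stands in for one GPU's shard: a column split needs an all-gather of output slices, while a row split needs an all-reduce (sum) of partial results. `layout_3d` uses the GPipe-style bubble fraction (p - 1)/(m + p - 1) for p stages and m micro-batches, assuming equal-cost stages; the 8 × 4 split of a 512-GPU cluster is purely illustrative:

```python
import numpy as np


def column_parallel_matmul(x: np.ndarray, w: np.ndarray, tp: int) -> np.ndarray:
    """Column split: each 'GPU' owns w[:, cols] and computes its slice of the output."""
    shards = np.split(w, tp, axis=1)
    partial_outputs = [x @ shard for shard in shards]  # would run on separate GPUs
    return np.concatenate(partial_outputs, axis=1)     # all-gather of output slices


def row_parallel_matmul(x: np.ndarray, w: np.ndarray, tp: int) -> np.ndarray:
    """Row split: each 'GPU' owns w[rows, :] plus the matching columns of x."""
    w_shards = np.split(w, tp, axis=0)
    x_shards = np.split(x, tp, axis=1)
    partial_sums = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    return np.sum(partial_sums, axis=0)                # all-reduce (sum) across GPUs


def layout_3d(world_size: int, tp: int, pp: int, micro_batches: int) -> dict:
    """Derive the data-parallel degree and the pipeline bubble fraction."""
    if world_size % (tp * pp):
        raise ValueError("world_size must be divisible by tp * pp")
    return {
        "dp": world_size // (tp * pp),
        "bubble_fraction": (pp - 1) / (micro_batches + pp - 1),
    }


rng = np.random.default_rng(0)
x, w = rng.standard_normal((4, 16)), rng.standard_normal((16, 32))
print(np.allclose(x @ w, column_parallel_matmul(x, w, tp=4)))  # True
print(np.allclose(x @ w, row_parallel_matmul(x, w, tp=4)))     # True
print(layout_3d(512, tp=8, pp=4, micro_batches=32))            # dp=16, bubble = 3/35 ≈ 0.086
```

With 32 micro-batches, about 9% of each step is bubble; doubling the micro-batch count would roughly halve it.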
Q4: What is the Chinchilla finding and how did it change how we train models?
Hoffmann et al. (2022) showed that models like GPT-3 (175B params, 300B training tokens) were drastically undertrained. The compute-optimal trade-off: as compute grows, model size and training tokens should grow in equal proportion, so doubling compute means scaling each by about √2 ≈ 1.4x. The practical rule: train on roughly 20 tokens per parameter. Before Chinchilla, the field focused on maximizing model size. After Chinchilla, the focus shifted to training on far more data: at the same compute budget, a smaller model trained longer often outperforms a larger model trained on fewer tokens (Chinchilla itself, at 70B, outperformed the 280B Gopher). The LLaMA family went further: deliberately train smaller models on more data than Chinchilla-optimal, because serving cost scales with model size while training is a one-time cost.
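Combining C ≈ 6·N·D with D = 20·N gives C ≈ 120·N², so N and D each scale as √C. A short sketch confirms that doubling the budget multiplies both by √2; `compute_optimal_split` is a hypothetical helper:

```python
import math


def compute_optimal_split(compute_flops: float, tokens_per_param: float = 20.0) -> tuple[float, float]:
    """Return (params, tokens) that spend compute_flops under C ≈ 6·N·D and D = k·N."""
    params = math.sqrt(compute_flops / (6 * tokens_per_param))
    return params, tokens_per_param * params


n1, d1 = compute_optimal_split(1e23)
n2, d2 = compute_optimal_split(2e23)  # double the budget
print(f"params x{n2 / n1:.3f}, tokens x{d2 / d1:.3f}")  # params x1.414, tokens x1.414
```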
Q5: What causes loss spikes during LLM pretraining and how do you recover?
Loss spikes (sudden jumps in training loss) have several common causes: (1) data quality issues, such as a batch with corrupted or degenerate text; (2) a learning rate too high relative to the batch gradient signal; (3) precision issues, such as FP16 overflow causing NaN propagation; (4) gradient norm explosion on specific sequence types (very long or very repetitive sequences). Recovery strategies: reduce the learning rate by 2-5x for 100-500 steps, then restore it; checkpoint frequently enough to roll back to a step before the spike; add online data quality filtering to remove degenerate batches; enable gradient norm clipping (max_norm=1.0). At the largest scale, infrastructure matters as much as the optimizer: the Llama 3 paper (2024) reports frequent job interruptions during its 15T token pretraining run, most of them caused by hardware failures, so robust monitoring, frequent checkpointing, and automatic recovery were essential.
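The rollback-and-retry recipe can be written down as a small plan object. This is a sketch under stated assumptions: checkpoints every `checkpoint_every` steps, a 4x LR cut (`lr_factor=0.25`) for 200 steps, and skipping every batch between the reloaded checkpoint and the spike plus a 100-batch margin, with values picked from the ranges above or invented for illustration. `RecoveryPlan` and `plan_recovery` are hypothetical names, and a real training loop would need to track the data offset separately from the optimizer step:

```python
from dataclasses import dataclass


@dataclass
class RecoveryPlan:
    resume_step: int     # checkpoint step to reload
    data_skip: int       # batches to drop from the data stream, starting at resume_step
    lr_factor: float     # temporary multiplier on the normal LR schedule
    cooldown_steps: int  # how long the reduced LR lasts after resuming

    def lr_multiplier(self, step: int) -> float:
        """Multiplier to apply on top of the regular LR schedule at a given step."""
        in_cooldown = self.resume_step <= step < self.resume_step + self.cooldown_steps
        return self.lr_factor if in_cooldown else 1.0


def plan_recovery(spike_step: int, checkpoint_every: int, margin_batches: int = 100,
                  lr_factor: float = 0.25, cooldown_steps: int = 200) -> RecoveryPlan:
    """Roll back to the last checkpoint strictly before the spike and avoid the suspect data."""
    resume_step = ((spike_step - 1) // checkpoint_every) * checkpoint_every
    # Drop every batch between the checkpoint and the spike, plus a margin after it
    data_skip = (spike_step - resume_step) + margin_batches
    return RecoveryPlan(resume_step, data_skip, lr_factor, cooldown_steps)


plan = plan_recovery(spike_step=12_345, checkpoint_every=1_000)
print(plan.resume_step, plan.data_skip)                        # 12000 445
print(plan.lr_multiplier(12_100), plan.lr_multiplier(12_200))  # 0.25 1.0
```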
