Storage IO for Training Pipelines

Reading time: ~35 min - Interview relevance: High - Target roles: ML Infrastructure, MLOps Engineer, ML Platform

The Production Scenario: 30% GPU Utilization on a $40,000 Server

The GPU utilization dashboard shows 30%. You stare at it for a moment, convinced something is wrong with the metric. It is not. Your eight A100s, each capable of 312 TFLOPS of bfloat16 matrix math, are spending 70% of their time idle. Waiting.

You check the model. Fine. You check the network. NCCL collectives are fast, NVLink is at full bandwidth during the brief moments the GPUs are actually computing. You check memory. Plenty of VRAM headroom. You run iostat -x 1 on the training server. The disk IO column makes your stomach drop: %util is at 100% for every storage device, read throughput is at 650 MB/s, and your training script is in a tight loop of read() syscalls.

The culprit is the data loading pipeline. Your dataset - 14 million ImageNet-scale images stored as individual JPEG files - is sitting on a network-attached storage volume with a spinning disk array behind it. The metadata server is overwhelmed by millions of stat() calls checking file existence. The actual image data is streaming in at a rate that cannot keep pace with the GPUs' appetite. Eight A100s need roughly 5,000 images per second to stay busy at a batch size of 2048. Your storage is delivering 200 images per second.

This scenario is not rare. It is the default outcome when an ML engineer spins up a powerful GPU cluster without treating the storage subsystem as a first-class engineering concern. The GPU is the star of every spec sheet, but storage IO is the invisible tax that determines whether you actually use it.

The problem compounds as model sizes grow. Pre-training a large language model on a trillion tokens requires reading every byte of that data one or more times. A trillion tokens is roughly 4 TB of UTF-8 text at the common ~4 bytes per token for English, or about 2 TB once stored as 16-bit token IDs. At 7 GB/s per NVMe drive, you need real infrastructure to read that fast enough to feed a 1024-GPU cluster. At 80% utilization instead of 95% utilization, you are paying for roughly 20% more compute time to finish the same job - or roughly $100,000 extra for a month-long training run.

Understanding storage IO for ML is not about knowing the specs of every filesystem. It is about building the intuition to identify where time goes, what the actual hardware constraints are, and how to design data pipelines that eliminate IO as the bottleneck. That is what this lesson covers.


Why This Exists - The Impedance Mismatch Between Storage and Compute

The Speed Gap

Modern GPU clusters have an extreme mismatch between compute throughput and storage bandwidth:

  • A single H100 GPU: 3.35 TB/s GPU memory bandwidth, 989 TFLOPS FP16
  • A single NVMe SSD: 7 GB/s sequential read, ~1M random IOPS
  • Ratio: GPU memory bandwidth is roughly 480x faster than a single NVMe

A training server with 8 H100s needs approximately:

  • For image classification (ResNet-50, batch 2048): 5,000 images/s x 150 KB avg = 750 MB/s continuous read
  • For LLM pre-training (sequence length 4096, batch 512): about 2M tokens per step, which at 2 bytes per token is only ~4 MB per step, so even at one step per second the continuous read is a few MB/s
  • For video training (224x224x16 frames): 10,000 clip sequences/s = 5 GB/s continuous read

One NVMe at 7 GB/s can theoretically handle these rates. But "theoretical" storage specs are not production storage reality.

Where Theoretical Bandwidth Goes

The gap between "7 GB/s NVMe" and actual training throughput has several sources:

Small file access patterns. Deep learning datasets are often stored as millions of individual files. An ImageNet dataset has 1.28 million JPEG files. Accessing them individually means one open(), stat(), read(), close() system call sequence per sample. Each sequence involves multiple filesystem metadata operations. On local NVMe, this might hit 5-7 GB/s for large sequential reads but only 500-800 MB/s for small random files.
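A quick way to see per-file overhead on your own machine is to read the same bytes twice: once as many small files (with the stat() a per-file loader typically performs) and once as a single packed file. The sketch below is a standard-library microbenchmark; the function name and default sizes are illustrative assumptions. Because the files were just written, both reads are usually served from the page cache, so the gap mostly reflects per-file syscall and metadata cost rather than device latency; use a cold cache or a larger-than-RAM dataset to measure the device itself.

```python
import os
import tempfile
import time


def compare_small_vs_packed(num_files: int = 500, file_size: int = 64 * 1024) -> dict:
    """Read identical bytes as many small files and as one packed file."""
    payload = os.urandom(file_size)
    with tempfile.TemporaryDirectory() as d:
        names = [os.path.join(d, f"{i:06d}.bin") for i in range(num_files)]
        # One file per sample, like an ImageFolder dataset
        for name in names:
            with open(name, "wb") as f:
                f.write(payload)
        # The same samples concatenated into one shard-like file
        packed = os.path.join(d, "packed.bin")
        with open(packed, "wb") as f:
            for _ in range(num_files):
                f.write(payload)

        # Per-file path: stat + open + read + close for every sample
        t0 = time.perf_counter()
        small_bytes = 0
        for name in names:
            os.stat(name)
            with open(name, "rb") as f:
                small_bytes += len(f.read())
        small_s = time.perf_counter() - t0

        # Packed path: one open, then large sequential reads
        t0 = time.perf_counter()
        packed_bytes = 0
        with open(packed, "rb") as f:
            while chunk := f.read(8 * 1024 * 1024):
                packed_bytes += len(chunk)
        packed_s = time.perf_counter() - t0

    return {"small_bytes": small_bytes, "packed_bytes": packed_bytes,
            "small_s": small_s, "packed_s": packed_s}
```

On a network filesystem, where each open() and stat() becomes a metadata server round trip, expect the ratio between small_s and packed_s to widen considerably.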

Shuffling requirements. Training requires random sampling from the full dataset each epoch. This means the access pattern is completely random across the entire dataset size. For a 1 TB dataset on spinning disk, random IO throughput is determined by seek time (5-10 ms per seek), limiting effective throughput to 10-20 MB/s even though sequential reads would achieve 200+ MB/s. SSD random IO is better but still far below sequential.
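The 10-20 MB/s figure follows from charging one seek per file. A worked version, assuming roughly 100 KB files and the 200 MB/s sequential rate mentioned above (both illustrative assumptions):

```python
def hdd_random_read_mb_s(file_kb: float, seek_ms: float, seq_mb_s: float = 200.0) -> float:
    """Effective MB/s when every file costs one seek plus its transfer time."""
    file_mb = file_kb / 1000
    per_file_s = seek_ms / 1000 + file_mb / seq_mb_s  # seek, then stream the file
    return file_mb / per_file_s


for seek in (5, 10):
    print(f"{seek} ms seeks: {hdd_random_read_mb_s(100, seek):.1f} MB/s")
```

This prints about 18.2 and 9.5 MB/s: the seek dominates, so the disk delivers under a tenth of its sequential bandwidth.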

Network filesystem overhead. NFS, Lustre, and other distributed filesystems add metadata server round trips for each file operation. Accessing one million files can mean one million metadata server requests. At 1 ms per request, issued one at a time, that is 1000 seconds just for metadata.

CPU preprocessing bottlenecks. JPEG decoding, image augmentation (random crop, flip, color jitter), and tokenization all run on CPU. For a typical vision training job, CPU preprocessing can need 3-6 CPU cores per GPU. Even an 8-GPU training server running 48 preprocessing workers may find that the workers cannot produce data as fast as the GPUs consume it.

The Solution Space

The ML community solved the storage IO problem through several complementary approaches that we will cover in depth:

  1. Sharded container formats - pack many samples into large sequential files (WebDataset, FFCV, TFRecord)
  2. CPU preprocessing offload - move decoding and augmentation to the GPU (DALI)
  3. Aggressive prefetching - overlap IO with compute using DataLoader workers
  4. Distributed storage - scale-out file systems that provide enough aggregate bandwidth
  5. Asynchronous checkpointing - prevent checkpoint writes from stalling compute

Historical Context - From ImageFolder to WebDataset

The ImageFolder Era (2012-2016)

When AlexNet won ImageNet in 2012, and for several years afterwards, the standard data layout was a directory full of subdirectories, one per class, each containing thousands of individual JPEG files; PyTorch later codified this layout as torchvision.datasets.ImageFolder. This was intuitive and required no preprocessing. It also worked fine when you had one GPU and a local SSD.

As GPU counts and dataset sizes grew, the ImageFolder pattern hit walls. The file system metadata overhead grew linearly with file count. Shuffling millions of files on spinning disk was catastrophically slow. Network filesystems struggled with millions of small file operations.

The TFRecord Pattern (2016-2018)

Google's TensorFlow team recognized this problem and introduced TFRecord: a simple binary file format where multiple samples are packed sequentially. Instead of one file per sample, you have hundreds of shards, each containing thousands of samples. Sequential reads from shards are fast. Distributing shards across multiple storage nodes is natural.
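To make "samples packed sequentially" concrete, here is a minimal length-prefixed record format. It is a simplified stand-in, not TFRecord itself: real TFRecord files also store a masked CRC32C checksum after the length and after each record's data, and the payload is usually a serialized tf.train.Example protobuf. The function names are hypothetical.

```python
import io
import struct
from typing import Iterable, Iterator


def write_records(stream, records: Iterable[bytes]) -> None:
    """Append each record as an 8-byte little-endian length followed by its bytes."""
    for rec in records:
        stream.write(struct.pack("<Q", len(rec)))
        stream.write(rec)


def read_records(stream) -> Iterator[bytes]:
    """Yield records back in order using purely sequential reads."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        data = stream.read(length)
        if len(data) < length:
            raise ValueError("truncated record body")
        yield data


buf = io.BytesIO()
write_records(buf, [b"sample-0", b"sample-1"])
buf.seek(0)
print(list(read_records(buf)))
```

Reading never seeks, which is exactly the access pattern that spinning disks, NVMe read-ahead, and object stores all handle well.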

TFRecord solved the small-file problem but had a significant weakness: the format was TensorFlow-specific, required schema definition with protobuf, and was not easily inspectable or compatible with other tools.

WebDataset (2019-present)

Thomas Breuel at NVIDIA and the open-source community developed WebDataset: pack samples into .tar archives (called shards), store shards on any storage system (local, NFS, S3), and stream from them sequentially. The key insight was using the standard tar format - any tool can read the shards, any storage system can host them, and the sequential read pattern is maximally friendly to every storage technology.
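WebDataset's grouping rule is simple: consecutive tar members that share a basename (everything before the first dot in the file name) form one sample, keyed by extension. A standard-library sketch of that rule, reading the tar in streaming mode so it never seeks, might look like the following; it illustrates the convention and is not WebDataset's actual implementation.

```python
import io
import tarfile
from typing import Any, Dict, Iterator


def iter_tar_samples(fileobj) -> Iterator[Dict[str, Any]]:
    """Group consecutive tar members that share a basename into sample dicts."""
    current: Dict[str, Any] = {}
    with tarfile.open(fileobj=fileobj, mode="r|*") as tar:  # stream mode: no seeks
        for member in tar:
            if not member.isfile():
                continue
            prefix, slash, base = member.name.rpartition("/")
            stem, _, ext = base.partition(".")  # "000001.seg.png" -> "000001", "seg.png"
            if not ext:
                continue  # a file without an extension carries no field name
            key = prefix + slash + stem
            if current and current["__key__"] != key:
                yield current  # a new basename starts the next sample
                current = {}
            current["__key__"] = key
            current[ext] = tar.extractfile(member).read()
    if current:
        yield current


# Example: two samples, each with an image and a label file
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as t:
    for name, data in [("000000.jpg", b"A"), ("000000.json", b"{}"),
                       ("000001.jpg", b"B"), ("000001.json", b"{}")]:
        info = tarfile.TarInfo(name)
        info.size = len(data)
        t.addfile(info, io.BytesIO(data))
buf.seek(0)
for sample in iter_tar_samples(buf):
    print(sample["__key__"], sorted(sample))
```

Because grouping depends on members being adjacent, shard writers must emit all files of a sample consecutively.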

WebDataset has since become one of the most widely used formats for large-scale ML training data, particularly for image-text datasets.

DALI - GPU Data Preprocessing (2018-present)

NVIDIA's DALI (Data Loading Library) addressed the CPU preprocessing bottleneck. JPEG decoding, random cropping, color jitter, normalization - all of it runs on the GPU, using nvJPEG (including the dedicated hardware JPEG decoder on A100 and H100 GPUs) and CUDA kernels. This frees CPU cores from preprocessing duty, and only the compressed JPEG bytes cross PCIe instead of the much larger decoded float tensors that the standard PyTorch path copies from CPU-pinned memory to the GPU.


Core Concepts: NVMe Characteristics That Matter for ML

Sequential vs. Random IO

NVMe SSDs have very different performance profiles for sequential and random access:

| Access pattern | Throughput | Notes |
|---|---|---|
| Sequential read (large files) | 5-7 GB/s | PCIe 4.0 x4 NVMe limit |
| Sequential write | 4-6 GB/s | Slightly slower than read |
| Random read 4K (deep queue) | 0.5-1.5 GB/s | IOPS limited |
| Random read 4K (queue depth 1) | 50-100 MB/s | Latency limited |

The key insight: NVMe achieves its rated bandwidth only for sequential access with deep IO queues. A dataset of individual small files accessed in random order - the standard ImageFolder pattern - achieves closer to "queue depth 1 random read" performance, which is 50-70x slower than sequential reads.

Queue depth is crucial. NVMe controllers can service multiple IO requests concurrently. At queue depth 32 (32 outstanding IO requests), a modern NVMe delivers near peak throughput. At queue depth 1 (one request at a time), you are bound by per-request latency, not bandwidth.
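Little's law makes the queue-depth effect quantitative: sustained IOPS equals outstanding requests divided by per-request latency, and throughput is IOPS times request size, up to the device limit. The sketch assumes a fixed 80 microsecond read latency, an illustrative figure; on real drives latency rises as the queue deepens, so treat deep-queue results as upper bounds.

```python
def nvme_read_mb_s(queue_depth: int, block_kb: float, latency_us: float,
                   device_max_mb_s: float = 7000.0) -> float:
    """Little's law: IOPS = outstanding requests / latency; MB/s = IOPS x size, capped."""
    iops = queue_depth / (latency_us / 1e6)
    return min(iops * block_kb / 1000, device_max_mb_s)


# Assumed fixed 80 us latency; real latency grows with queue depth
for qd in (1, 8, 32):
    print(f"QD{qd:>2}, 4 KB reads: {nvme_read_mb_s(qd, 4, 80):7.0f} MB/s")
```

At queue depth 1 this gives 50 MB/s, the bottom of the latency-limited row in the table; the only lever that raises it without a faster drive is more requests in flight.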

PyTorch's DataLoader with num_workers=8 implicitly creates 8 workers that each issue IO requests independently - this provides queue depth through concurrency. This is one reason increasing num_workers often improves throughput even when the raw CPU preprocessing is not the bottleneck.

The IOPS Budget for Training

Let's compute the IOPS budget for a concrete example. Training ResNet-50 on ImageNet:

  • Batch size: 256 images per GPU, 8 GPUs = 2048 images per training step
  • Step time at high GPU utilization: ~200ms per step
  • Required sample throughput: 2048 / 0.2 = 10,240 images/second
  • Average ImageNet JPEG size: ~130 KB
  • Required read throughput: 10,240 x 130 KB = 1.33 GB/s
  • Required IOPS: 10,240 IOPS (one IO per file)

A modern NVMe at 700,000 random IOPS has enormous headroom for 10,240 IOPS. The bottleneck is not IOPS - it is the filesystem metadata overhead and the single-threaded Python DataLoader worker overhead. This is why switching to WebDataset (which reads large sequential shards rather than individual files) often provides 3-5x throughput improvement even though the underlying storage device is the same.
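The budget above can be written down directly, together with the difference sharding makes to the number of file opens. A minimal sketch, assuming 1,000 samples per shard (an illustrative shard size):

```python
def io_budget(images_per_s: float, avg_kb: float, samples_per_shard: int) -> dict:
    """Read bandwidth and file-open rates for per-file vs sharded layouts."""
    return {
        "mb_per_s": images_per_s * avg_kb / 1000,               # decimal MB, as in the text
        "file_opens_per_s": images_per_s,                       # one open/stat/read/close per JPEG
        "shard_opens_per_s": images_per_s / samples_per_shard,  # one open per tar shard
    }


images_per_s = 2048 / 0.2  # global batch / step time
print(io_budget(images_per_s, avg_kb=130, samples_per_shard=1000))
```

The bytes per second are identical in both layouts; what changes is that metadata operations drop from about 10,240 per second to about 10.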


Math: Storage Bandwidth Required for Different Workloads

Let's formalize the storage bandwidth calculation.

For a training job with:

  • $N_{\text{GPU}}$ GPUs
  • $B$ local batch size per GPU
  • $t_{\text{step}}$ seconds per training step
  • $S$ average sample size in bytes

Required storage read bandwidth:

$$\text{BW}_{\text{required}} = \frac{N_{\text{GPU}} \times B \times S}{t_{\text{step}}}$$

For LLM pre-training with sequence packing:

  • $N_{\text{GPU}} = 512$ GPUs across 64 nodes
  • $B = 4$ sequences per GPU (global batch = 2048)
  • Sequence length = 4096 tokens, 2 bytes per token (uint16, enough for vocabularies up to 65,536 tokens) = 8 KB per sequence
  • $t_{\text{step}} = 2.0$ seconds (for a 7B parameter model)

$$\text{BW}_{\text{required}} = \frac{512 \times 4 \times 8192}{2.0} = 8{,}388{,}608 \text{ B/s} \approx 8.4 \text{ MB/s}$$

That rate is trivially achievable: a single NVMe drive, or even one hard disk reading sequentially, covers it with a wide margin. Tokenized text is compact, so text-only LLM pre-training is rarely bound by read bandwidth. For image and video datasets, where S is hundreds of kilobytes or more, the numbers scale proportionally and storage becomes the constraint.
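The formula is easy to keep as a helper and to check against the worked examples in this lesson. The function name is illustrative; the inputs are the symbols defined above, with S in bytes.

```python
def required_read_bandwidth(n_gpu: int, local_batch: int,
                            sample_bytes: float, t_step_s: float) -> float:
    """BW_required = N_GPU * B * S / t_step, in bytes per second."""
    return n_gpu * local_batch * sample_bytes / t_step_s


# LLM example: 512 GPUs, 4 sequences each, 4096 uint16 tokens, 2 s steps
llm = required_read_bandwidth(512, 4, 4096 * 2, 2.0)
# ResNet-50 example from the IOPS budget: 8 GPUs x 256 images, 130 KB, 0.2 s steps
resnet = required_read_bandwidth(8, 256, 130_000, 0.2)
print(f"LLM text:  {llm / 1e6:8.1f} MB/s")
print(f"ResNet-50: {resnet / 1e6:8.1f} MB/s")
```

The roughly 160x gap between the two results is the point: per-sample size S, not GPU count alone, decides whether storage bandwidth matters.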

The more important insight is the compute-to-storage ratio: if your GPUs process samples faster (newer hardware, a smaller model) or your samples are larger (video, high-res images), the required bandwidth grows proportionally. Profiling IO usage at the beginning of a new training setup - before scaling to hundreds of nodes - is essential.


[Diagram: Data pipeline stages and bottleneck points]


Code: Profiling IO Before Optimizing

Never optimize without measuring. These tools tell you exactly where time is going.

```python
import subprocess
import time

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


def profile_dataloader_throughput(
    dataset_path: str,
    batch_size: int = 256,
    num_workers: int = 8,
    prefetch_factor: int = 2,
    max_batches: int = 100,
    gpu_batch_ms: float = 100.0,
) -> dict:
    """
    Profile actual data loading throughput.
    Returns samples/sec, decoded GB/sec, and per-batch wait times.

    gpu_batch_ms is the assumed GPU compute time per batch (~100 ms is
    typical for ResNet-50); set it from your own profiler numbers.
    """
    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    dataset = datasets.ImageFolder(dataset_path, transform=transform)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        pin_memory=True,
        persistent_workers=True,
    )

    total_samples = 0
    total_bytes = 0
    batch_times = []  # seconds between consecutive ready batches (load + GPU copy)

    start = time.perf_counter()
    last = start
    for i, (images, labels) in enumerate(loader):
        # Move to GPU - this is what the training loop would do
        images = images.cuda(non_blocking=True)
        torch.cuda.synchronize()  # include the copy in the measurement

        now = time.perf_counter()
        batch_times.append(now - last)
        last = now

        total_samples += images.shape[0]
        # Decoded float32 tensors: batch x 3 x 224 x 224 x 4 bytes
        total_bytes += images.numel() * images.element_size()

        if i + 1 >= max_batches:
            break

    elapsed = time.perf_counter() - start
    samples_per_sec = total_samples / elapsed
    gb_per_sec = total_bytes / elapsed / 1e9

    # The first batch includes worker start-up; use enough batches to amortize it
    avg_batch_ms = sum(batch_times) / len(batch_times) * 1000
    p99_batch_ms = sorted(batch_times)[int(len(batch_times) * 0.99)] * 1000

    print("=== DataLoader Throughput Report ===")
    print(f"Configuration: workers={num_workers}, prefetch={prefetch_factor}")
    print(f"Throughput: {samples_per_sec:.0f} samples/sec")
    print(f"Throughput: {gb_per_sec:.2f} GB/sec (decoded tensors)")
    print(f"Avg batch load time: {avg_batch_ms:.1f} ms")
    print(f"P99 batch load time: {p99_batch_ms:.1f} ms")

    # Rough GPU stall estimate: if the loader needs longer than the GPU to
    # produce a batch, the difference is time the GPU sits idle.
    stall_fraction = max(0.0, avg_batch_ms - gpu_batch_ms) / avg_batch_ms
    if stall_fraction > 0.05:
        print(f"WARNING: IO likely stalling GPU ({stall_fraction:.0%} of batch time is IO)")
        print("Consider: increasing num_workers, switching to WebDataset, using DALI")
    else:
        print("IO is not the primary bottleneck at this configuration")

    return {
        "samples_per_sec": samples_per_sec,
        "gb_per_sec": gb_per_sec,
        "avg_batch_ms": avg_batch_ms,
        "p99_batch_ms": p99_batch_ms,
    }


def monitor_disk_io_during_training(interval_sec: int = 5, duration_sec: int = 60):
    """
    Run `iostat -x` for duration_sec and print its report.
    subprocess.run blocks until iostat exits, so start this in a separate
    terminal or thread while training runs. High %util and long read wait
    times (r_await) mean storage is saturated.
    """
    count = max(1, duration_sec // interval_sec)
    print(f"Monitoring disk IO for {duration_sec}s (interval {interval_sec}s)...")
    print("Run your training script now.")
    print()

    try:
        result = subprocess.run(
            ["iostat", "-x", str(interval_sec), str(count)],
            capture_output=True,
            text=True,
            timeout=duration_sec + 10,
        )
        print(result.stdout)
    except FileNotFoundError:
        print("iostat not available. Install with: apt-get install sysstat")
        print()
        print("Alternative: parse /proc/diskstats (cumulative per-device counters)")
        with open("/proc/diskstats", "r") as f:
            print(f.read()[:500])
```

Code: WebDataset for High-Throughput Data Loading

WebDataset is the standard solution for high-throughput training data. Here is how to set it up properly.

```python
import io
import json
import os
import random
import tarfile
from typing import Optional

import torchvision.transforms as transforms
import webdataset as wds
from torch.utils.data import DataLoader

# === Step 1: Create WebDataset shards from an existing dataset ===

def create_webdataset_shards(
    image_dir: str,
    output_pattern: str,
    samples_per_shard: int = 1000,
    max_shards: Optional[int] = None,
    seed: int = 0,
):
    """
    Convert an ImageFolder-style dataset to WebDataset tar shards.

    output_pattern: e.g. "dataset/train-%06d.tar"
    Creates shards like: dataset/train-000000.tar, train-000001.tar, etc.
    """
    # Find all image files; the parent directory name is the class label
    extensions = (".jpg", ".jpeg", ".png")
    image_files = []
    for root, _dirs, files in os.walk(image_dir):
        for f in files:
            if f.lower().endswith(extensions):
                label = os.path.basename(root)
                image_files.append((os.path.join(root, f), label))

    print(f"Found {len(image_files)} images")

    # Build label-to-index mapping
    labels = sorted({label for _, label in image_files})
    label_to_idx = {label: idx for idx, label in enumerate(labels)}

    # Shuffle once before writing. os.walk yields one class directory at a
    # time, so without this each shard would hold a single class and the
    # loader's shuffle buffer could not mix classes within a batch.
    random.Random(seed).shuffle(image_files)

    # Write shards
    shard_idx = 0
    sample_idx = 0
    tar = None
    output_path = None

    for img_path, label in image_files:
        if sample_idx % samples_per_shard == 0:
            if tar is not None:
                tar.close()
                print(f"Wrote shard: {output_path}")
                tar = None
            if max_shards is not None and shard_idx >= max_shards:
                break
            output_path = output_pattern % shard_idx
            tar = tarfile.open(output_path, "w")
            shard_idx += 1

        # Each sample gets a key like "000001234"; files sharing a key form one sample
        key = f"{sample_idx:09d}"

        # Add image bytes unchanged. PNGs are also stored under ".jpg" so every
        # sample has the same fields; PIL detects the real format when decoding.
        with open(img_path, "rb") as f:
            img_data = f.read()
        img_info = tarfile.TarInfo(name=f"{key}.jpg")
        img_info.size = len(img_data)
        tar.addfile(img_info, io.BytesIO(img_data))

        # Add label as JSON
        label_json = json.dumps({"cls": label_to_idx[label]}).encode()
        label_info = tarfile.TarInfo(name=f"{key}.json")
        label_info.size = len(label_json)
        tar.addfile(label_info, io.BytesIO(label_json))

        sample_idx += 1

    if tar is not None:
        tar.close()
        print(f"Wrote shard: {output_path}")

    print(f"Created {shard_idx} shards with {sample_idx} total samples")


# === Step 2: Load with WebDataset ===

def create_webdataset_loader(
    shard_pattern: str,
    batch_size: int = 256,
    num_workers: int = 8,
    shuffle_buffer: int = 5000,
    is_training: bool = True,
) -> DataLoader:
    """
    Create a high-throughput DataLoader using WebDataset.

    shard_pattern uses brace expansion, not shell globs, e.g.
    "data/train-{000000..000999}.tar" or
    "pipe:aws s3 cp s3://bucket/train-{000000..000999}.tar -".
    For a glob, pass sorted(glob.glob("data/train-*.tar")) instead.

    With is_training=True shards are resampled with replacement, which
    yields an endless stream: bound the training loop by step count.
    """
    # Augmentation pipeline
    if is_training:
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    # Build dataset pipeline
    dataset = wds.WebDataset(shard_pattern, resampled=is_training)
    if is_training:
        # Sample-level shuffle inside a rolling buffer of shuffle_buffer samples
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = (
        dataset
        .decode("pil")                                    # .jpg -> RGB PIL image, .json -> dict
        .to_tuple("jpg", "json")                          # extract specific keys
        .map_tuple(transform, lambda meta: meta["cls"])   # transform image, extract label
        .batched(batch_size, partial=not is_training)     # batch inside the workers
    )

    loader = DataLoader(
        dataset,
        batch_size=None,  # batching done in WebDataset pipeline
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
    )

    return loader


# === Step 3: Streaming from S3 ===

def create_s3_streaming_loader(
    bucket: str,
    prefix: str,
    batch_size: int = 256,
    num_workers: int = 8,
    num_shards: int = 1000,
) -> DataLoader:
    """
    Stream WebDataset shards directly from S3.
    Requires webdataset >= 0.2.0 and the AWS CLI on PATH with credentials.
    Assumes shards named shard-000000.tar ... whose samples contain
    "jpg" and "cls" (integer class) entries.
    """
    # "pipe:" URLs run the command and read the shard from its stdout, so each
    # download happens in a separate aws CLI process rather than in Python.
    shard_urls = [
        f"pipe:aws s3 cp s3://{bucket}/{prefix}/shard-{i:06d}.tar -"
        for i in range(num_shards)
    ]

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    dataset = (
        wds.WebDataset(shard_urls, resampled=True)
        .shuffle(5000)
        .decode("pil")
        .to_tuple("jpg", "cls")       # .cls is decoded to int by the default handlers
        .map_tuple(transform, int)
        .batched(batch_size)
    )

    return DataLoader(
        dataset,
        batch_size=None,
        num_workers=num_workers,
        pin_memory=True,
    )
```

Code: DALI for GPU-Side Data Preprocessing

NVIDIA DALI eliminates the CPU preprocessing bottleneck by running image decode and augmentation on the GPU.

```python
# DALI requires a CUDA-specific wheel: pip install nvidia-dali-cuda120

try:
    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import pipeline_def
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

    @pipeline_def
    def imagenet_pipeline(data_dir: str, is_training: bool):
        """
        DALI pipeline for ImageNet training.
        All augmentation runs on GPU after JPEG decode.
        batch_size, num_threads and device_id are Pipeline arguments: pass them
        when calling imagenet_pipeline(...) and @pipeline_def consumes them.
        Typical speedup vs CPU pipeline: 2-4x for large batch sizes.
        """
        # Read compressed files from disk (CPU)
        jpegs, labels = fn.readers.file(
            file_root=data_dir,
            random_shuffle=is_training,
            pad_last_batch=not is_training,
            name="Reader",
        )

        # "mixed" device: compressed bytes come from CPU memory and decoded
        # images land in GPU memory (nvJPEG, hardware decoder where available)
        if is_training:
            images = fn.decoders.image_random_crop(
                jpegs,
                device="mixed",
                output_type=types.RGB,
                random_aspect_ratio=[0.75, 1.33],
                random_area=[0.08, 1.0],
                num_attempts=100,
            )
            images = fn.resize(images, device="gpu", resize_x=224, resize_y=224)
        else:
            images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
            images = fn.resize(images, device="gpu", resize_shorter=256)
            images = fn.crop(images, device="gpu", crop=[224, 224])

        # Augmentation on GPU
        if is_training:
            images = fn.flip(images, device="gpu",
                             horizontal=fn.random.coin_flip(probability=0.5))
            images = fn.color_twist(
                images,
                device="gpu",
                brightness=fn.random.uniform(range=[0.6, 1.4]),
                contrast=fn.random.uniform(range=[0.6, 1.4]),
                saturation=fn.random.uniform(range=[0.6, 1.4]),
            )

        # Normalize to float32 CHW with standard ImageNet statistics (0-255 scale)
        images = fn.crop_mirror_normalize(
            images,
            device="gpu",
            dtype=types.FLOAT,
            output_layout="CHW",
            crop=(224, 224),
            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        )

        labels = labels.gpu()
        return images, labels

    def create_dali_loader(
        data_dir: str,
        batch_size: int = 256,
        num_threads: int = 4,
        device_id: int = 0,
        is_training: bool = True,
    ):
        """
        Create a DALI-powered data loader for ImageNet.
        Each iteration yields [{"data": images, "label": labels}] with both
        tensors already on the GPU: unpack as batch[0]["data"], batch[0]["label"].
        """
        pipe = imagenet_pipeline(
            data_dir=data_dir,
            is_training=is_training,
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=device_id,
        )
        pipe.build()

        loader = DALIClassificationIterator(
            pipe,
            reader_name="Reader",
            last_batch_policy=LastBatchPolicy.PARTIAL if not is_training else LastBatchPolicy.FILL,
            auto_reset=True,
        )

        return loader

except ImportError:
    print("DALI not installed. Install with:")
    print("  pip install nvidia-dali-cuda120")
    print()
    print("DALI key benefits:")
    print("  - JPEG decode on GPU (nvjpeg): 5-10x faster than Pillow")
    print("  - All augmentation on GPU: frees CPU for other work")
    print("  - Only compressed JPEG bytes cross PCIe, not decoded float tensors")
    print("  - Async prefetch: pipeline overlap across multiple streams")
```

Code: PyTorch DataLoader Tuning

Before jumping to DALI or WebDataset, tune the standard PyTorch DataLoader correctly.

```python
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset


def find_optimal_num_workers(
    dataset: Dataset,
    batch_size: int = 256,
    max_workers: int = None,
    test_batches: int = 50,
) -> int:
    """
    Benchmark different num_workers settings to find the optimal value.
    Run this once during initial setup of a new training environment.
    """
    if max_workers is None:
        max_workers = min(os.cpu_count() or 1, 32)

    # Candidate values, deduplicated and sorted
    workers_to_test = sorted({w for w in [0, 1, 2, 4, 8, 16, max_workers] if w <= max_workers})

    results = {}

    for num_workers in workers_to_test:
        # prefetch_factor and persistent_workers are only valid with worker processes
        worker_kwargs = (
            dict(prefetch_factor=2, persistent_workers=True) if num_workers > 0 else {}
        )
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            **worker_kwargs,
        )

        # Warmup: start the workers and fill their prefetch queues
        it = iter(loader)
        for _ in range(3):
            next(it)

        # Benchmark on the same iterator so the warmup is not thrown away
        start = time.perf_counter()
        n_samples = 0
        for _ in range(test_batches):
            try:
                batch = next(it)
            except StopIteration:
                break
            x = batch[0] if isinstance(batch, (list, tuple)) else batch
            x = x.cuda(non_blocking=True)  # simulate the host-to-GPU transfer
            n_samples += x.shape[0]
        torch.cuda.synchronize()  # wait for outstanding copies before stopping the clock

        elapsed = time.perf_counter() - start
        samples_per_sec = n_samples / elapsed

        results[num_workers] = samples_per_sec
        print(f"num_workers={num_workers:3d}: {samples_per_sec:8.0f} samples/sec")

        del it, loader  # shut down this configuration's workers

    best_workers = max(results, key=results.get)
    print(f"\nOptimal num_workers: {best_workers}")
    return best_workers


# === DataLoader configuration for different scenarios ===

def create_training_loader(dataset, batch_size=256, num_workers=8):
    """Production-ready DataLoader for training."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,          # Page-locked buffers make non_blocking copies truly async
        prefetch_factor=2,        # Each worker prefetches 2 batches ahead
        persistent_workers=True,  # Keep workers alive between epochs
                                  # Avoids re-spawning workers (saves 10-30s per epoch)
        drop_last=True,           # Avoid partial batches for cleaner training
    )


def create_evaluation_loader(dataset, batch_size=512, num_workers=4):
    """DataLoader for evaluation - larger batches, no shuffle."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        prefetch_factor=3,        # Larger prefetch for sequential access
        persistent_workers=True,
        drop_last=False,
    )


# === Direct IO for bypassing page cache ===

def explain_direct_io():
    """
    O_DIRECT bypasses the OS page cache for file reads.
    Relevant when:
    1. Dataset is too large to fit in RAM (prevents cache pollution)
    2. You need predictable IO latency (cache misses cause spikes)
    3. You are reading each file exactly once (caching is wasteful)

    For ML training, O_DIRECT is often counterproductive because:
    - It requires aligned IO (512-byte or 4096-byte aligned buffers)
    - It eliminates read-ahead prefetching by the kernel
    - PyTorch's DataLoader workers cannot use it easily

    The WebDataset approach is better: shard-based sequential reads
    naturally benefit from kernel read-ahead without O_DIRECT.
    """
    print("O_DIRECT: use for database-style workloads, not ML training.")
    print("WebDataset sequential shard reads + kernel read-ahead is better.")
```
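If you do want to influence kernel read-ahead for shard files, the usual tool is an access hint rather than O_DIRECT. Below is a sketch using os.posix_fadvise with POSIX_FADV_SEQUENTIAL, which on Linux enlarges the read-ahead window for that file; the function name and 4 MB chunk size are assumptions, and the hint is skipped on platforms that lack posix_fadvise (such as macOS).

```python
import os


def read_shard_sequentially(path: str, chunk_bytes: int = 4 * 1024 * 1024) -> int:
    """Stream a file front to back after hinting sequential access; return bytes read."""
    fd = os.open(path, os.O_RDONLY)
    try:
        if hasattr(os, "posix_fadvise"):
            # offset=0, len=0 applies the hint to the whole file
            os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_SEQUENTIAL)
        total = 0
        while True:
            chunk = os.read(fd, chunk_bytes)
            if not chunk:
                break
            total += len(chunk)
        return total
    finally:
        os.close(fd)
```

The hint is advisory: the kernel may ignore it, and the data still goes through the page cache, which is usually what you want for shards read once per epoch by several workers.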

[Diagram: Distributed storage options comparison]


Code: Asynchronous Checkpointing

Synchronous checkpointing stalls training. For a 70B parameter model in bfloat16, saving to disk takes 30-120 seconds depending on the storage system. Asynchronous checkpointing overlaps the save with continued training.

```python
import os
import queue
import threading
import time
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


def _to_cpu(obj: Any) -> Any:
    """Recursively copy tensors in a (nested) state dict to CPU memory.

    copy=True matters for tensors already on the CPU: .cpu() would return the
    same storage, which the optimizer keeps mutating while the write runs.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().to("cpu", copy=True)
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_to_cpu(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(_to_cpu(v) for v in obj)
    return obj


class AsyncCheckpointManager:
    """
    Non-blocking checkpoint manager.

    Saves checkpoints in a background thread while training continues.
    Each file is written to a temporary path and renamed into place, so a
    crash mid-write never leaves a truncated file under a checkpoint name.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        max_checkpoints: int = 3,
        compression: bool = False,
    ):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        self.compression = compression  # reserved; not used by this implementation

        # At most one snapshot waits in the queue while another is being written
        self._queue: queue.Queue = queue.Queue(maxsize=1)
        self._thread: Optional[threading.Thread] = None
        self._error: Optional[Exception] = None
        self._saved_checkpoints = []
        self._lock = threading.Lock()

        self._start_save_thread()

    def _start_save_thread(self):
        """Start the background save thread."""
        self._thread = threading.Thread(
            target=self._save_worker,
            daemon=True,
            name="checkpoint-saver",
        )
        self._thread.start()

    def _save_worker(self):
        """Background thread: drain queue and write checkpoints."""
        while True:
            item = self._queue.get()
            if item is None:  # Sentinel for shutdown
                self._queue.task_done()
                break

            state_dict, path, metadata = item
            try:
                save_obj = {"state_dict": state_dict, **metadata}
                # Write to temp file first, then rename (atomic on POSIX, same filesystem)
                tmp_path = str(path) + ".tmp"
                torch.save(save_obj, tmp_path)
                os.replace(tmp_path, path)

                with self._lock:
                    self._saved_checkpoints.append(str(path))
                    # Prune old checkpoints
                    while len(self._saved_checkpoints) > self.max_checkpoints:
                        old = self._saved_checkpoints.pop(0)
                        try:
                            os.remove(old)
                        except FileNotFoundError:
                            pass

                print(f"[Checkpoint] Saved: {path}")

            except Exception as e:
                self._error = e
                print(f"[Checkpoint] ERROR saving {path}: {e}")

            finally:
                self._queue.task_done()

    def save(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        step: int,
        epoch: int,
        extra: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Initiate async checkpoint save.

        Copies model and optimizer state to CPU in the calling thread
        (typically seconds for a 7B model, bounded by PCIe bandwidth), then
        queues the CPU copy for background writing.

        Returns True if the save was queued, False if another snapshot is
        already waiting behind an in-progress write.
        """
        if self._error:
            raise RuntimeError(f"Background save failed: {self._error}")

        if self._queue.full():
            print(f"[Checkpoint] Skipping step {step} - previous save still queued")
            return False

        # Copy state to CPU in the training thread. This is GPU-to-host memory
        # traffic over PCIe, not disk IO, so it is much faster than the write.
        t0 = time.perf_counter()
        cpu_state = _to_cpu(model.state_dict())
        cpu_opt_state = _to_cpu(optimizer.state_dict())
        copy_time = time.perf_counter() - t0

        print(f"[Checkpoint] State dict CPU copy: {copy_time:.2f}s")

        path = self.checkpoint_dir / f"checkpoint-step{step:09d}.pt"
        metadata = {
            "step": step,
            "epoch": epoch,
            "optimizer_state": cpu_opt_state,
        }
        if extra:
            metadata.update(extra)

        # Only this thread puts and the worker only gets, so the queue cannot
        # have filled since the check above: put_nowait will succeed.
        self._queue.put_nowait((cpu_state, path, metadata))
        return True

    def wait(self):
        """Block until all pending saves complete."""
        self._queue.join()

    def shutdown(self):
        """Graceful shutdown: finish pending saves then stop thread."""
        self.wait()
        self._queue.put(None)  # Sentinel
        self._thread.join()


# === Integration example: training loop with async checkpointing ===

def training_loop_with_async_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader,
    checkpoint_dir: str,
    checkpoint_every_n_steps: int = 1000,
    total_steps: int = 100000,
):
    """
    Training loop that checkpoints without stalling GPU work.
    """
    checkpointer = AsyncCheckpointManager(
        checkpoint_dir=checkpoint_dir,
        max_checkpoints=3,
    )

    model.train()
    step = 0

    try:
        for epoch in range(1000):
            for batch in train_loader:
                inputs, labels = batch
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = torch.nn.functional.cross_entropy(outputs, labels)

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                step += 1

                if step % checkpoint_every_n_steps == 0:
                    # Blocks only for the CPU copy; the disk write runs in the background
                    checkpointer.save(
                        model=model,
                        optimizer=optimizer,
                        step=step,
                        epoch=epoch,
                        extra={"loss": loss.item()},
                    )

                if step >= total_steps:
                    break

            if step >= total_steps:
                break

    finally:
        # Ensure the final checkpoint is written before exiting
        print("Waiting for final checkpoint to complete...")
        checkpointer.shutdown()
```
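One gap in the manager above: the rename makes the swap atomic, but without fsync a power loss shortly after the rename can, on some filesystems, leave a renamed but empty or incomplete file. A common hardening, sketched here with a hypothetical helper name, is to fsync the temporary file before os.replace and then fsync the directory. This is POSIX-specific (directory fsync is not supported on Windows), and parallel filesystems such as Lustre have their own durability semantics, so treat it as a starting point rather than a guarantee.

```python
import os
import tempfile


def atomic_write_bytes(path: str, data: bytes) -> None:
    """Write data so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # data reaches stable storage before the rename
        os.replace(tmp_path, path)  # atomic within one POSIX filesystem
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    dir_fd = os.open(directory, os.O_RDONLY)
    try:
        os.fsync(dir_fd)  # persist the directory entry created by the rename
    finally:
        os.close(dir_fd)
```

To use it with torch.save, serialize into an io.BytesIO buffer first and pass buffer.getvalue(); this briefly holds a second in-memory copy of the checkpoint.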

Production Engineering Notes

Choosing Storage for Different Scenarios

Local NVMe (single node or small cluster) - Best absolute performance. Each training node reads from its own local drives. No network contention. The problem is data distribution: when you have 64 training nodes, each needs its own copy of the dataset, or you need complex shard distribution logic. Local NVMe shines for single-node training or when you can replicate data cheaply.

NFS (Network File System) - The default for small academic clusters. Easy to set up. Has serious scalability problems for ML: the NFS server becomes a bottleneck at scale, metadata operations (open, stat, readdir) are particularly slow for datasets with millions of files, and throughput typically tops out at 200-500 MB/s per client due to protocol overhead. Acceptable for datasets under 100 GB and clusters under 8 nodes. Not suitable for production-scale training.

Lustre / GPFS - Purpose-built distributed filesystems for HPC and ML clusters. Lustre separates metadata servers (MDS) from object storage servers (OSS), allowing aggregate bandwidth of hundreds of GB/s across a cluster. GPFS (IBM Spectrum Scale, now IBM Storage Scale) provides similar capabilities with commercial support from IBM. The major clouds offer managed filesystems in this family, such as Amazon FSx for Lustre and Azure Managed Lustre, which are commonly attached to large training jobs. Both require expertise to configure correctly: stripe count, stripe size, and client tuning all matter.

S3 / GCS (Object Storage) - Effectively unlimited scale, pay-per-use, no maintenance. Per-request latency is high (20-100ms) but throughput scales with the number of parallel readers. The key is using WebDataset shards: instead of millions of random small-file requests, you make hundreds of large sequential shard reads. A well-designed WebDataset pipeline can sustain 2-5 GB/s from S3 per node with enough parallel connections.

Checkpoint Frequency Tradeoffs

Checkpointing too rarely means more work lost on failure. Checkpointing too often means more storage IO overhead and potential stalls if async checkpointing falls behind.

A practical rule: checkpoint at most about once per hour of wall-clock training time. For a 1024-GPU training job running at full speed, checkpointing a 70B parameter model takes approximately:

  • State dict copy to CPU: 15-30 seconds (dominated by PCIe bandwidth; 70B parameters x 4 bytes = 280 GB of fp32 weights, and roughly 4x that, 1.12 TB, once optimizer state is included)
  • Serialization and write to Lustre at an aggregate 10 GB/s: ~112 seconds for the full 1.12 TB CPU copy (about 14 seconds if only 140 GB of bf16 weights are written)

With async checkpointing, the CPU copy is the training-blocking portion. 15-30 seconds every hour is a 0.4-0.8% overhead - acceptable.
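The arithmetic above can be reproduced with a small calculator. This is a minimal sketch: the function name, the 4x state multiplier, and the 30-second copy time are assumptions taken from the estimates above rather than measurements, so substitute numbers from your own job.

```python
GB = 1e9  # decimal gigabytes, matching the figures in the text


def checkpoint_times(params: float, bytes_per_param: int, state_multiplier: float,
                     copy_seconds: float, write_gbps: float,
                     interval_seconds: float) -> dict:
    """Estimate checkpoint size, background write time, and blocking overhead.

    params           -- number of model parameters
    bytes_per_param  -- 4 for fp32 weights
    state_multiplier -- total checkpoint size / weight size (assumed 4 with optimizer state)
    copy_seconds     -- GPU->CPU copy time, the only training-blocking part when async
    write_gbps       -- aggregate storage write bandwidth in GB/s
    interval_seconds -- wall-clock time between checkpoints
    """
    weights_gb = params * bytes_per_param / GB
    total_gb = weights_gb * state_multiplier
    write_seconds = total_gb / write_gbps
    return {
        "weights_gb": weights_gb,
        "total_gb": total_gb,
        "write_seconds": write_seconds,
        # Fraction of wall-clock time training is blocked on the CPU copy.
        "blocking_overhead": copy_seconds / interval_seconds,
        # The background write must finish before the next checkpoint is due.
        "write_fits_interval": write_seconds < interval_seconds,
    }


est = checkpoint_times(70e9, 4, 4, copy_seconds=30, write_gbps=10, interval_seconds=3600)
print(f"{est['total_gb']:.0f} GB total, background write {est['write_seconds']:.0f} s, "
      f"blocking overhead {est['blocking_overhead']:.2%}")
```

The `write_fits_interval` check is the one to watch when shortening the interval: if the background write takes longer than the gap between checkpoints, async saves queue up and eventually stall training anyway.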

FFCV - The Alternative to WebDataset for Vision

FFCV (Fast Forward Computer Vision) takes a different approach than WebDataset. Instead of tar files, it uses a custom binary format (.beton files) with a small index designed for random access. This allows true random shuffling across the dataset file without the shuffle-buffer approximation that WebDataset requires. In some benchmarks FFCV reaches vision training throughput that exceeds even DALI, by storing pre-processed (for example, pre-resized) images in this compact format and running decoding and augmentation in JIT-compiled pipelines.

The tradeoff: FFCV requires converting your dataset to the FFCV format upfront (a few hours for ImageNet) and is vision-specific. It does not generalize to text, audio, or video datasets the way WebDataset does.
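To make the "index enables true random access" idea concrete, here is a toy packed-file format. It is a hypothetical illustration of the principle, not FFCV's actual .beton layout or API: samples are written back to back, followed by an (offset, length) index and a fixed-size footer, so a reader can seek straight to any sample and shuffle the whole file without a buffer.

```python
import os
import random
import struct

ENTRY = struct.Struct("<QQ")  # (offset, length) or (index_start, count), 16 bytes


def pack_samples(path: str, samples: list) -> None:
    """Write samples back to back, then an index of (offset, length) pairs, then a footer."""
    offsets = []
    with open(path, "wb") as f:
        for blob in samples:
            offsets.append((f.tell(), len(blob)))
            f.write(blob)
        index_start = f.tell()
        for off, length in offsets:
            f.write(ENTRY.pack(off, length))
        # Footer: where the index starts and how many entries it holds.
        f.write(ENTRY.pack(index_start, len(offsets)))


def read_index(f) -> list:
    """Load the whole index with two small reads (footer, then index)."""
    f.seek(-ENTRY.size, os.SEEK_END)
    index_start, count = ENTRY.unpack(f.read(ENTRY.size))
    f.seek(index_start)
    raw = f.read(ENTRY.size * count)
    return [ENTRY.unpack_from(raw, ENTRY.size * i) for i in range(count)]


def iter_shuffled(path: str, seed: int):
    """Yield every sample exactly once, in a fully random order."""
    with open(path, "rb") as f:
        index = read_index(f)
        order = list(range(len(index)))
        random.Random(seed).shuffle(order)  # global shuffle, no shuffle buffer
        for i in order:
            off, length = index[i]
            f.seek(off)
            yield f.read(length)
```

The price of this design is random IO inside the file, which is cheap on local NVMe but expensive on high-latency object storage; that is the same tradeoff the text describes between FFCV and sequential tar shards.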


Common Mistakes

:::danger Using num_workers=0 in production training

num_workers=0 means data loading happens in the main Python process, synchronously blocking the training loop. The GPU sits idle while the CPU loads and processes the next batch. For any dataset with more than trivial preprocessing, this is a significant slowdown - often 3-10x slower than with proper worker configuration.

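A reasonable starting point is num_workers=max(4, os.cpu_count() // 2) for a single training process; when you run one process per GPU, divide the CPU cores among those processes instead. Then benchmark to find the actual optimum.

from torch.utils.data import DataLoader

# WRONG: synchronous, kills GPU utilization
loader = DataLoader(dataset, batch_size=256, num_workers=0)

# RIGHT: async workers overlap IO with GPU compute
loader = DataLoader(
    dataset,
    batch_size=256,
    num_workers=8,            # worker processes load batches in parallel
    prefetch_factor=2,        # batches each worker prepares ahead of time
    pin_memory=True,          # copy batches into page-locked memory for fast H2D
    persistent_workers=True,  # keep workers alive between epochs
)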

:::

:::danger Saving checkpoints synchronously in a distributed job

In a distributed training job (DDP or FSDP), if rank 0 saves a checkpoint synchronously while other ranks wait at a barrier, you have just stalled your entire cluster. The effective checkpoint overhead, in GPU-seconds, is (checkpoint size) / (storage bandwidth) multiplied by the number of ranks that are blocked waiting.

For a 512-GPU job checkpointing a 70B model, 30 seconds of checkpoint IO stalls 512 GPUs simultaneously. That is 512 x 30 s = 15,360 GPU-seconds, or about 4.3 GPU-hours wasted per checkpoint. At $3/GPU-hour on a cloud cluster, each checkpoint costs ~$13.

Always use async checkpointing that lets training continue while the checkpoint serializes. :::

:::warning Storing datasets as millions of individual files

Datasets stored as individual files (one JPEG per sample, one .npy per tensor) work fine for quick experiments on a single GPU. They fail in production for two reasons: (1) filesystem metadata overhead scales linearly with file count - a directory with 10M files takes minutes to list and seconds to stat-check, and (2) random access to individual files cannot saturate storage bandwidth because each access is a separate IO operation with its own overhead.

Convert to WebDataset shards (1000-5000 samples per shard) before any serious training run. The conversion is a one-time cost that pays for itself in the first training job. :::
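A one-time conversion to tar shards can be sketched with only the standard library's tarfile module. This assumes WebDataset's naming convention, where all files belonging to one sample share a basename (the key) and differ by extension; the function names and the shard filename pattern are hypothetical, and the webdataset package provides its own writer (ShardWriter) for real pipelines.

```python
import io
import os
import tarfile


def write_shards(samples, out_dir: str, samples_per_shard: int = 1000) -> list:
    """Pack (key, {ext: bytes}) samples into tar shards; returns shard paths."""
    paths, tar = [], None
    for i, (key, fields) in enumerate(samples):
        if i % samples_per_shard == 0:  # start a new shard
            if tar is not None:
                tar.close()
            path = os.path.join(out_dir, f"shard-{len(paths):06d}.tar")
            tar = tarfile.open(path, "w")
            paths.append(path)
        for ext, data in fields.items():
            info = tarfile.TarInfo(name=f"{key}.{ext}")  # e.g. 000042.jpg
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    if tar is not None:
        tar.close()
    return paths


def read_shard(path: str):
    """Stream one shard front to back, regrouping consecutive files by key."""
    current_key, sample = None, {}
    with tarfile.open(path, "r|") as tar:  # "r|" = sequential stream, no seeking
        for member in tar:
            if not member.isfile():
                continue
            key, ext = member.name.split(".", 1)
            if key != current_key and sample:
                yield current_key, sample
                sample = {}
            current_key = key
            sample[ext] = tar.extractfile(member).read()
    if sample:
        yield current_key, sample
```

Because `read_shard` opens the archive in stream mode, the storage system sees one long sequential read per shard, which is exactly the access pattern the warning above recommends.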

:::warning Forgetting to pin memory when using GPU training

Without pin_memory=True in the DataLoader, CPU tensors live in pageable memory. When CUDA performs an H2D (host-to-device) transfer from pageable memory, the driver first copies the data into an internal pinned staging buffer - a double copy - and the transfer is not asynchronous with respect to the host even if you pass non_blocking=True. With pin_memory=True, a pin-memory thread in the main process copies each batch into pinned (page-locked) memory, so the GPU can read it directly via DMA and .cuda(non_blocking=True) can overlap with other work. For large batch sizes and high-throughput pipelines, the difference in H2D transfer speed is 2-3x.

Always use pin_memory=True when training on GPU. The memory cost (pinned memory is more expensive to allocate) is trivial compared to the transfer speedup. :::


Interview Questions and Answers

Q1: Explain why storing a dataset as individual JPEG files can cause low GPU utilization, even when the total data size fits comfortably in a fast SSD.

The problem is access pattern, not capacity. Modern NVMe SSDs reach their rated ~7 GB/s only with large sequential reads or with many requests in flight (deep IO queues). Accessing millions of individual files means millions of separate IO operations, each paying open/stat/read/close syscall overhead and separate kernel scheduling. NVMe can serve even small random reads quickly if the queue is kept deep, but a PyTorch DataLoader issues one synchronous read at a time per worker (and the default num_workers=0 means a single reader), which is low-queue-depth random access - 10-50x slower than sequential reads.

The solution is container formats (WebDataset, FFCV, TFRecord) that pack thousands of samples into large sequential binary files. The GPU never sees the difference - it still gets JPEG-decoded tensors. But the storage system now sees sequential reads of 100 MB shard files instead of thousands of individual 150 KB JPEG reads.

Q2: A training job uses 8x A100 GPUs and achieves 95% GPU utilization in the first epoch but drops to 60% in later epochs. The model and batch size have not changed. What is the likely cause and how do you diagnose it?

The most likely cause is filesystem caching effects. The first epoch can be served largely from the OS page cache if the data was read or written recently (for example, during preprocessing or a previous run). When the dataset is larger than available RAM, each epoch reads samples in a new random order, so cached pages are evicted before they are reused: the hit rate falls, misses go to the storage device, and storage IO becomes the bottleneck.

To diagnose: run iostat -x 1 during training and compare IO utilization between early and late epochs. Check free -h to see available RAM and cache usage. Monitor DataLoader worker queue depth and stall time.

The fix depends on the cause: if the dataset fits in RAM, explicitly warm the page cache before training by pre-reading all files. If the dataset is too large, switch to WebDataset shards which have more predictable sequential access patterns that the kernel read-ahead subsystem handles better than random file access.
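Pre-warming the page cache is just reading every file once before training starts. A minimal sketch, assuming the dataset is a directory tree of regular files that fits in free RAM; several threads keep multiple reads in flight so the device sees a deeper queue. The function name, thread count, and 8 MiB chunk size are illustrative choices, not tuned values.

```python
import os
from concurrent.futures import ThreadPoolExecutor

CHUNK = 8 * 1024 * 1024  # 8 MiB reads keep each request large


def _read_file(path: str) -> int:
    """Read a file to EOF, discarding the data; returns bytes read."""
    total = 0
    with open(path, "rb", buffering=0) as f:
        while True:
            chunk = f.read(CHUNK)
            if not chunk:  # b"" at end of file
                return total
            total += len(chunk)


def warm_page_cache(root: str, threads: int = 16) -> int:
    """Read every file under root once so later reads can hit the page cache.

    Only helps when the data fits in free RAM; otherwise later files evict
    earlier ones. Returns the total number of bytes read.
    """
    paths = [os.path.join(d, name) for d, _, names in os.walk(root) for name in names]
    with ThreadPoolExecutor(max_workers=threads) as pool:
        return sum(pool.map(_read_file, paths))
```

Comparing the returned byte count with the "buff/cache" column of free -h afterwards is a quick way to check whether the data actually stayed resident.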

Q3: What is prefetch_factor in PyTorch's DataLoader and how does setting it too high or too low affect performance?

prefetch_factor=N means each DataLoader worker prepares N batches ahead of what the training loop is currently consuming. With 8 workers and prefetch_factor=2, there are up to 16 batches being prepared in parallel at any time.

Too low (prefetch_factor=1): the GPU may finish its current batch before the next one is ready, causing idle GPU cycles. The training loop blocks waiting for data.

Too high: each prefetched batch occupies host memory (and pinned memory once the pin-memory thread copies it). With prefetch_factor=4, batch_size=256, and 3x224x224 float32 images, each worker holds 4 batches x 256 x 3 x 224 x 224 x 4 bytes = ~0.62 GB. With 8 workers that is ~4.9 GB per DataLoader, and with one DataLoader per GPU on an 8-GPU node, ~40 GB - potentially exhausting system RAM or causing memory pressure that slows the system.

The practical optimum is usually prefetch_factor=2 for GPU training, occasionally 3-4 for slower storage systems. Start at 2 and increase only if you see the GPU stalling on batch retrieval.
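The memory arithmetic reduces to prefetch_factor x batch bytes x workers x DataLoader processes. A sketch that counts only the image tensor (labels and allocator overhead are ignored); the function name is hypothetical.

```python
def prefetch_memory_bytes(prefetch_factor: int, batch_size: int,
                          channels: int, height: int, width: int,
                          bytes_per_elem: int = 4, num_workers: int = 1,
                          num_processes: int = 1) -> int:
    """Host memory held by prefetched image batches, in bytes."""
    per_batch = batch_size * channels * height * width * bytes_per_elem
    return prefetch_factor * per_batch * num_workers * num_processes


per_worker = prefetch_memory_bytes(4, 256, 3, 224, 224)
per_loader = prefetch_memory_bytes(4, 256, 3, 224, 224, num_workers=8)
per_node = prefetch_memory_bytes(4, 256, 3, 224, 224, num_workers=8, num_processes=8)
print(f"{per_worker / 1e9:.2f} GB per worker, {per_loader / 1e9:.1f} GB per DataLoader, "
      f"{per_node / 1e9:.1f} GB per 8-GPU node")
```

Storing uint8 images and converting to float on the GPU cuts these numbers by 4x, which is often a cheaper fix than lowering prefetch_factor.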

Q4: Describe the difference between synchronous and asynchronous checkpointing. In a 512-GPU distributed training job, why does the choice matter significantly?

Synchronous checkpointing: rank 0 (or every rank) calls torch.save(), which blocks until the write completes. In a distributed job, the other ranks wait at a barrier for rank 0 to finish saving, so training is stopped for the entire duration of the checkpoint write.

Asynchronous checkpointing: the state dict is copied from GPU to CPU memory (fast, a few seconds over PCIe), then a background thread writes to disk while training continues. The training-blocking portion is only the state dict copy.

For a 512-GPU job, a 60-second synchronous checkpoint stalls 512 GPUs simultaneously. At a typical A100 cost of $3/GPU-hour, that is 512 x (60/3600) x $3 = $25.60 per checkpoint. Checkpointing every 15 minutes means 4 checkpoints/hour, or about $102/hour wasted on checkpoint IO. Over a week-long training run, that is about $17,000 in wasted compute purely from synchronous checkpointing overhead. Async checkpointing brings this cost to near zero.
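The stall cost scales with GPU count, stall length, and checkpoint frequency. The helper below is a hypothetical sketch that reproduces the figures above; the $3/GPU-hour rate is the same assumption used in the text.

```python
def sync_checkpoint_cost(num_gpus: int, stall_seconds: float,
                         dollars_per_gpu_hour: float,
                         checkpoints_per_hour: float = 1.0,
                         run_hours: float = 0.0) -> dict:
    """Wasted GPU time and cost when every GPU waits on a synchronous save."""
    gpu_hours = num_gpus * stall_seconds / 3600  # idle GPU-hours per checkpoint
    per_ckpt = gpu_hours * dollars_per_gpu_hour
    per_hour = per_ckpt * checkpoints_per_hour
    return {
        "gpu_hours_per_ckpt": gpu_hours,
        "dollars_per_ckpt": per_ckpt,
        "dollars_per_hour": per_hour,
        "dollars_per_run": per_hour * run_hours,
    }


# Scenario from the answer: 512 GPUs, 60 s stall, every 15 minutes, one week.
q4 = sync_checkpoint_cost(512, 60, 3.0, checkpoints_per_hour=4, run_hours=7 * 24)
print(f"${q4['dollars_per_ckpt']:.2f}/checkpoint, ${q4['dollars_per_hour']:.2f}/hour, "
      f"${q4['dollars_per_run']:,.0f}/week")
```

Calling `sync_checkpoint_cost(512, 30, 3.0)` gives the ~4.3 GPU-hours and ~$13 per checkpoint quoted in the Common Mistakes section.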

Q5: You are designing a data pipeline for training a large language model on 2 trillion tokens of text. The data is stored in S3. How do you structure the storage and data loading to feed a 512-GPU cluster without IO bottlenecks?

Start with the bandwidth math: 512 GPUs, a global batch of 2M tokens, and a step time of ~3 seconds for a 7B parameter model give 2M / 3 ≈ 667K tokens/second. Stored as 2-byte token IDs, that is only ~1.3 MB/s of data. This is easily achievable from S3 with proper organization.
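Writing the estimate down explicitly guards against slipping a factor of 1000 between tokens, bytes, and nodes. A sketch using the assumptions above (2M-token global batch, 3-second step, 2-byte token IDs, 64 nodes); the function name is hypothetical.

```python
def required_read_bandwidth(global_batch_tokens: float, step_seconds: float,
                            bytes_per_token: int, num_nodes: int) -> dict:
    """Steady-state input bandwidth needed to keep training fed."""
    tokens_per_s = global_batch_tokens / step_seconds
    bytes_per_s = tokens_per_s * bytes_per_token
    return {
        "tokens_per_s": tokens_per_s,
        "bytes_per_s": bytes_per_s,
        "bytes_per_s_per_node": bytes_per_s / num_nodes,
    }


bw = required_read_bandwidth(2_000_000, 3.0, bytes_per_token=2, num_nodes=64)
print(f"{bw['tokens_per_s']:,.0f} tokens/s, {bw['bytes_per_s'] / 1e6:.2f} MB/s total, "
      f"{bw['bytes_per_s_per_node'] / 1e3:.1f} KB/s per node")
```

For text, the hard part is rarely raw bandwidth; it is keeping shard access sequential and the CPU-side work (decoding, tokenization) off the critical path.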

Data organization: convert the 2 trillion tokens to WebDataset tar shards of 100-500 MB each (containing pre-tokenized sequences as binary arrays). Store shards in S3 with a flat prefix structure and randomize shard order during each training run.

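Per-node loading: each of the 64 training nodes (8 GPUs each) reads its own shards, either by partitioning the shard list so every node gets a disjoint subset (for example, with WebDataset's split_by_node node splitter) or by using WebDataset's resampled=True mode, which samples shards at random with replacement so nodes need no coordination. Each node needs only ~21 KB/s of read throughput (1.3 MB/s / 64 nodes) - trivially achievable from S3.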
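Disjoint partitioning can be sketched without any library: every node shuffles the full shard list with the same per-epoch seed, takes a strided slice for itself, and then slices again per DataLoader worker. This is a hypothetical helper illustrating the idea behind node and worker splitting, not WebDataset's implementation.

```python
import random


def shards_for(shards: list, node_rank: int, num_nodes: int,
               worker_id: int, num_workers: int, epoch: int) -> list:
    """Return the shards one DataLoader worker on one node should stream this epoch."""
    order = list(shards)
    # Same seed on every node -> identical order everywhere -> disjoint slices.
    random.Random(epoch).shuffle(order)
    per_node = order[node_rank::num_nodes]
    return per_node[worker_id::num_workers]
```

With 1000 shards over 64 nodes x 4 workers, each stream gets 3 or 4 shards, so ranks can finish an epoch at slightly different times; avoiding that kind of epoch-boundary imbalance is one reason long runs often use resampled=True instead.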

DataLoader configuration: 4-8 workers per node, prefetch_factor=3, pin_memory=True. Each worker streams one shard sequentially - this matches S3's optimal access pattern (large sequential reads, not random access).

The critical detail is pre-tokenization: if you stream raw text and tokenize on the fly, the tokenizer (especially BPE) becomes the CPU bottleneck. Pre-tokenize once and store the token sequences in the shards as uint16 IDs (2 bytes per token, enough for vocabularies up to 65,536 entries; larger vocabularies need 4 bytes). At training time you just read binary arrays - no decode, no tokenization.
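Pre-tokenized samples can be stored as raw little-endian uint16 bytes. A sketch using the standard library's array module; the function names are hypothetical, and the byte order is fixed explicitly so shards written on one machine decode identically on another.

```python
import sys
from array import array


def encode_tokens(token_ids: list, vocab_size: int) -> bytes:
    """Serialize token IDs as little-endian uint16 (2 bytes per token)."""
    if vocab_size > 2**16:
        raise ValueError("vocab_size exceeds uint16; use a 4-byte type instead")
    arr = array("H", token_ids)  # 'H' = unsigned short, 2 bytes on mainstream platforms
    if sys.byteorder == "big":
        arr.byteswap()  # normalize to little-endian on disk
    return arr.tobytes()


def decode_tokens(blob: bytes) -> list:
    """Inverse of encode_tokens."""
    arr = array("H")
    arr.frombytes(blob)
    if sys.byteorder == "big":
        arr.byteswap()
    return arr.tolist()
```

At training time the decode step is a single memory copy, which is why this format removes the tokenizer from the data loading hot path.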

Q6: What is the iostat output you should look for to determine whether storage IO is the training bottleneck, and what does each metric mean?

Run iostat -x 1 to get extended statistics every second. The key columns:

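%util - percentage of elapsed time during which the device had IO requests outstanding. On a device that serves requests one at a time, such as a single spinning disk, 100% means it is saturated and is the bottleneck. NVMe SSDs and RAID arrays serve many requests in parallel, so they can show 100% while still having spare capacity; on those devices, confirm saturation with the throughput and latency columns below. Values below 80% generally indicate the device is not the constraint.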

r_await and w_await - average latency in milliseconds for read and write requests. For NVMe this should be under 1ms for reads. High latency (10ms+) indicates either queue depth issues or storage system congestion.

rMB/s and wMB/s - actual read and write throughput (iostat reports these as rkB/s and wkB/s unless you pass -m). Compare against the device's rated sequential throughput to see how efficiently it is being used.

rkB/s divided by r/s - effective average read request size (newer sysstat versions report it directly as rareq-sz, in kB). If this is around 4 KB (the typical IOPS benchmark size), you are doing random small-file access. If it is 512 KB or more, you are doing efficient sequential streaming. For WebDataset shards, you should see large average request sizes.

During training, you want to see: %util well below 100% (meaning storage is not the constraint), rMB/s at a sustainable rate matching your required throughput calculation, and large average read sizes indicating sequential access patterns.
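These checks can be scripted. The sketch below parses an iostat -x report by header name (column order differs across sysstat versions) and derives %util and average read size per device. The function names and the 16 kB "small random reads" threshold are illustrative assumptions, not sysstat conventions.

```python
def parse_iostat_x(text: str) -> dict:
    """Parse `iostat -x` text into {device: {column: value}}; the last report wins.

    Assumes the device header row starts with "Device" ("Device:" in older versions).
    """
    stats, header = {}, None
    for line in text.splitlines():
        fields = line.split()
        if not fields:  # a blank line ends the current section
            header = None
            continue
        if fields[0].rstrip(":") == "Device":
            header = fields[1:]
            continue
        if header and len(fields) == len(header) + 1:
            try:
                stats[fields[0]] = dict(zip(header, (float(v) for v in fields[1:])))
            except ValueError:  # skip rows that are not numeric
                pass
    return stats


def diagnose(dev: dict) -> dict:
    """Derive the signals discussed above for one device."""
    reads = dev.get("r/s", 0.0)
    # Prefer sysstat's own average request size column when it exists.
    avg_kb = dev.get("rareq-sz", dev.get("rkB/s", 0.0) / reads if reads else 0.0)
    return {
        "util_pct": dev.get("%util", 0.0),
        "avg_read_kB": avg_kb,
        "small_random_reads": 0 < avg_kb <= 16,  # heuristic threshold (assumption)
    }
```

To use it, capture the output of iostat -x 1 2 with subprocess.run(..., capture_output=True, text=True) and pass .stdout to parse_iostat_x. The second report is the useful one: iostat's first report shows averages since boot, which is why the parser keeps the last value it sees for each device.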


Summary

Storage IO is the invisible tax on GPU utilization. The gap between a well-optimized data pipeline and a naive one is often 3-5x in effective training throughput - which translates directly to training time and compute cost.

The three highest-impact changes for most teams, in order:

  1. Switch from individual files to WebDataset shards. This single change typically doubles throughput for vision and text workloads by converting random small-file IO to sequential large-file streaming.

  2. Configure num_workers correctly, and set pin_memory=True and persistent_workers=True. Too few workers causes GPU stalls. Proper worker configuration often provides another 2-3x improvement over default settings.

  3. Use asynchronous checkpointing. For runs longer than a day, synchronous checkpointing wastes a measurable percentage of your compute budget waiting for disk writes.

DALI provides additional gains for compute-intensive vision workloads where CPU preprocessing is the bottleneck. S3 streaming with WebDataset is the right architecture for cloud-native training where data lives in object storage.

Measure before optimizing. iostat -x 1 and torch.utils.data.DataLoader throughput benchmarks give you the data needed to identify where time actually goes. Intuition about storage bottlenecks is often wrong - measure first.
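A DataLoader throughput benchmark needs no model: time how fast the loader yields batches on its own and compare that with the rate your training step consumes them. A minimal sketch that works with any iterable, including a DataLoader; the function name and default batch counts are assumptions.

```python
import time
from typing import Iterable, Optional


def measure_throughput(loader: Iterable, num_batches: int = 200, warmup: int = 20,
                       batch_size: Optional[int] = None) -> dict:
    """Time batch production from an iterable, with no model attached."""
    it = iter(loader)
    for _ in range(warmup):  # let workers start and fill their prefetch queues
        next(it)
    start = time.perf_counter()
    n = 0
    for _ in range(num_batches):
        try:
            next(it)
        except StopIteration:  # loader ran out before num_batches
            break
        n += 1
    elapsed = time.perf_counter() - start
    result = {"batches": n, "seconds": elapsed,
              "batches_per_s": n / elapsed if elapsed > 0 else float("inf")}
    if batch_size is not None:
        result["samples_per_s"] = result["batches_per_s"] * batch_size
    return result
```

If the loader alone cannot deliver batches faster than one per training step, no GPU-side tuning will help; run iostat -x 1 while this benchmark is running to see whether the storage device or the CPU preprocessing is the limit.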

© 2026 EngineersOfAI. All rights reserved.