PyTorch DataLoaders and Datasets

It is 2021. The Tesla Autopilot team is training a new perception model for lane detection and object recognition. The training set: 1.4 million labeled video clips, each containing 8 camera angles, at 36 frames per clip. The raw data is 180 TB. The model is running on 64 A100 GPUs, and the budget pressure is obvious: at cloud GPU prices, every hour of training costs thousands of dollars.

The first profiling run reveals something alarming: GPU utilization is averaging 40%. The A100 is doing nothing for 60% of the time it is being billed. A PyTorch profiler trace shows the training step takes 8ms, but the data loading gap between batches is 12ms. The GPUs are starved. They are sitting idle, waiting for the next batch to arrive from CPU memory.

The data loading pipeline is a standard map-style Dataset reading JPEG files from local NVMe. The num_workers value: 4. The images are being decoded one at a time by PIL, resized, normalized, and then assembled into a batch. This works fine at small scale. At 1.4 million samples and 64 GPUs, it is a bottleneck that costs millions per training run.

The solution involves three changes: moving from individual JPEG files to HDF5 archives (sequential reads instead of random seeks), increasing num_workers to 16 per GPU, and adding prefetch_factor=4 to overlap I/O with computation. GPU utilization rises to 92%. Training cost drops by 30%. The model, the data, and the hardware are all the same. The data pipeline was the only thing that changed.

A training job is running at 40% GPU utilization. The GPU should be at 95%+. You open the profiler. The training step takes 12ms, but the data loading between batches takes 30ms. The DataLoader has num_workers=0 - all data loading is happening on the main Python thread, starving the GPU.

Adding num_workers=4 and pin_memory=True brings GPU utilization to 92% and cuts total training time by 60%. The model did not change. The hardware did not change. Data pipeline efficiency is often the bottleneck.

The Dataset Abstraction

Map-Style Datasets: The Contract

A PyTorch Dataset must implement two methods:

  • __len__: returns the number of samples (integer)
  • __getitem__: returns a single sample given an integer index

The DataLoader calls these methods to assemble batches. The contract is simple, but the implementation decisions are not.

import torch
from torch.utils.data import Dataset
import numpy as np

class TabularDataset(Dataset):
    """Simple dataset for tabular data stored as NumPy arrays."""

    def __init__(self, X: np.ndarray, y: np.ndarray):
        # Store as float32 tensors immediately (eager loading)
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# Usage
X = np.random.randn(10_000, 20).astype(np.float32)
y = np.random.randint(0, 3, size=10_000)

dataset = TabularDataset(X, y)
print(len(dataset)) # 10000
print(dataset[0]) # (tensor of shape (20,), tensor scalar)
print(dataset[0][0].shape) # torch.Size([20])

Lazy Loading vs Eager Loading

Eager loading (shown above): convert all data to tensors in __init__. Fast access in __getitem__, but uses memory proportional to dataset size. Fine for datasets that fit in RAM (typically up to ~10GB).

Lazy loading: store file paths or indices in __init__, and load from disk in __getitem__. Handles arbitrarily large datasets but incurs I/O cost on every access.

from PIL import Image

class LazyImageDataset(Dataset):
    """Load images from disk on demand - handles datasets larger than RAM."""

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # just a list of strings
        self.labels = labels            # in memory (small)
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # I/O happens HERE, in the worker process, not __init__
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

Memory Mapping for Large Arrays

np.memmap opens a file as if it were a NumPy array without loading it into RAM. The OS handles paging - only the accessed parts are loaded into memory. This is the standard approach for datasets that are too large for RAM but structured as contiguous arrays (e.g., pre-tokenized text).

import numpy as np
import torch
from torch.utils.data import Dataset

class MemmapDataset(Dataset):
    """
    Memory-mapped dataset for large pre-tokenized arrays.
    Common for LLM pre-training (e.g., 50B tokens stored as uint16).
    """

    def __init__(self, data_path, seq_len, dtype=np.uint16):
        # Memory-mapped: the OS pages in only what is accessed
        self.data = np.memmap(data_path, dtype=dtype, mode='r')
        self.seq_len = seq_len
        # Number of complete sequences; the -1 leaves room for the shifted target
        self.n_samples = (len(self.data) - 1) // seq_len

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        start = idx * self.seq_len
        end = start + self.seq_len
        # x: input tokens, y: targets shifted by 1 (next-token prediction)
        # astype copies the slice, so the read-only memmap is never wrapped directly
        x = torch.from_numpy(self.data[start:end].astype(np.int64))
        y = torch.from_numpy(self.data[start+1:end+1].astype(np.int64))
        return x, y
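
The index arithmetic in MemmapDataset is easy to get wrong by one token, so it can help to check it without any file at all. The following is a minimal sketch in plain Python; the helper names and the sizes (1,000 tokens, seq_len 128) are hypothetical and chosen only for illustration.

```python
def n_samples(n_tokens, seq_len):
    # Reserve one trailing token so the last shifted target window stays in bounds.
    return (n_tokens - 1) // seq_len


def window_bounds(seq_len, idx):
    """Return (x_start, x_end, y_start, y_end), mirroring the slicing in __getitem__."""
    start = idx * seq_len
    end = start + seq_len
    return start, end, start + 1, end + 1


n_tokens, seq_len = 1_000, 128             # hypothetical sizes
count = n_samples(n_tokens, seq_len)       # (1000 - 1) // 128
last = window_bounds(seq_len, count - 1)   # bounds of the final sample
print(count, last)                         # 7 (768, 896, 769, 897)
```

The last target window ends at index 897, which is within the 1,000 available tokens. Without the `- 1`, a file holding an exact multiple of seq_len tokens would produce a final target slice that is one token short.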

Custom Dataset Patterns

Pattern 1: Image Folder Dataset with Albumentations

The classic torchvision transforms operate mainly on PIL images. Albumentations works on NumPy arrays and is often faster for complex augmentation pipelines.

from pathlib import Path
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset

class AlbumentationsImageDataset(Dataset):
    """
    Directory structure:
        root/
            class_0/img1.jpg ...
            class_1/img1.jpg ...
    """

    def __init__(self, root: str, transform=None):
        root = Path(root)
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}

        # Filter to directories first so stray files do not leave gaps in class indices
        class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
        for cls_idx, cls_dir in enumerate(class_dirs):
            self.class_to_idx[cls_dir.name] = cls_idx
            for img_path in sorted(cls_dir.glob('*.jpg')):
                self.samples.append((img_path, cls_idx))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # Load as NumPy array for albumentations
        image = np.array(Image.open(img_path).convert('RGB'))
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']  # ToTensorV2 returns a tensor
        return image, label


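# Albumentations augmentation pipeline (often faster than torchvision)
# NOTE: argument names follow older 1.x releases; newer releases changed some
# signatures (e.g. RandomResizedCrop, GaussNoise), so check your installed version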
train_transform = A.Compose([
A.RandomResizedCrop(224, 224),
A.HorizontalFlip(p=0.5),
A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.8),
A.GaussNoise(var_limit=(10, 50), p=0.2),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(), # converts HWC NumPy → CHW torch.Tensor
])

val_transform = A.Compose([
A.Resize(256, 256),
A.CenterCrop(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2(),
])

Pattern 2: HDF5 Dataset for Scientific Data

HDF5 provides hierarchical storage with chunking and compression. It is the standard for scientific datasets (medical images, genomics, physics simulations). The key advantage: sequential reads within a chunk are fast even from spinning disk.

import h5py
import torch
from torch.utils.data import Dataset

class HDF5Dataset(Dataset):
    """
    Reads samples from an HDF5 file. Assumes:
        file['X'][i] - feature array for sample i
        file['y'][i] - integer label for sample i
    """

    def __init__(self, h5_path: str, transform=None):
        self.h5_path = h5_path
        self.transform = transform
        # Open in __init__ just to read the length
        # Do NOT store the file handle - worker processes cannot share it
        with h5py.File(h5_path, 'r') as f:
            self.length = len(f['X'])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Open the file here - each worker has its own handle
        # h5py files are not picklable across processes if kept open
        with h5py.File(self.h5_path, 'r') as f:
            x = torch.tensor(f['X'][idx], dtype=torch.float32)
            y = int(f['y'][idx])
        if self.transform:
            x = self.transform(x)
        return x, y

:::warning HDF5 and Multiprocessing Never store an open h5py.File handle as an instance attribute in __init__ when using num_workers > 0. An open handle cannot be pickled for spawned workers, and a handle inherited through fork is not safe to share between processes. Open the file inside __getitem__ instead, either on every call (simple but slower) or lazily once per worker. For performance, organize data in chunks that match your batch size. :::
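
Reopening the file on every `__getitem__` call is safe but pays the open cost per sample. A common alternative is to open the handle lazily on first access and drop it whenever the dataset object is serialized. The sketch below illustrates that pattern with a plain binary file as a stand-in for HDF5, so it runs with the standard library only. The class name, record format, and file contents are hypothetical, and the same structure would apply to an `h5py.File`.

```python
import os
import struct
import tempfile


class LazyHandleDataset:
    """Map-style dataset sketch that opens its file lazily, once per process."""

    RECORD = struct.Struct("<f")  # one little-endian float32 per sample

    def __init__(self, path):
        self.path = path
        self.length = os.path.getsize(path) // self.RECORD.size
        self._file = None  # opened on first access, never in __init__

    def _ensure_open(self):
        if self._file is None:
            self._file = open(self.path, "rb")
        return self._file

    def __getstate__(self):
        # Called by pickle: ship the path, never the open handle.
        state = self.__dict__.copy()
        state["_file"] = None
        return state

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        f = self._ensure_open()
        f.seek(idx * self.RECORD.size)
        (value,) = self.RECORD.unpack(f.read(self.RECORD.size))
        return value


# Write a tiny data file, then read it back.
with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as tmp:
    for v in (0.5, 1.5, 2.5):
        tmp.write(LazyHandleDataset.RECORD.pack(v))

ds = LazyHandleDataset(tmp.name)
first = ds[0]                      # opens the file in this process

# Rebuild a copy from the pickled state, as a spawned worker would receive it.
state = ds.__getstate__()
clone = LazyHandleDataset.__new__(LazyHandleDataset)
clone.__dict__.update(state)
values = [clone[i] for i in range(len(clone))]  # clone opens its own handle
print(first, values)               # 0.5 [0.5, 1.5, 2.5]

ds._file.close()
clone._file.close()
os.remove(tmp.name)
```

With the fork start method the dataset object is not pickled, so this pattern only helps if nothing reads from the dataset in the main process before the workers are created.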

Pattern 3: Multi-Modal Dataset (Image + Text)

from PIL import Image
from torch.utils.data import Dataset

class ImageCaptionDataset(Dataset):
    """
    Multi-modal dataset pairing images with text captions.
    Used for CLIP-style contrastive training.
    """

    def __init__(self, image_paths, captions, tokenizer, image_transform=None, max_text_len=77):
        assert len(image_paths) == len(captions)
        self.image_paths = image_paths
        self.captions = captions
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_text_len = max_text_len

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.image_transform:
            image = self.image_transform(image)

        # Tokenize caption (assumes a Hugging Face style tokenizer)
        tokens = self.tokenizer(
            self.captions[idx],
            max_length=self.max_text_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        return {
            'image': image,                                # (C, H, W)
            'input_ids': tokens['input_ids'].squeeze(0),   # (max_text_len,)
            'attention_mask': tokens['attention_mask'].squeeze(0),
        }

Pattern 4: Dataset Wrapping Another Dataset

Wrap a base dataset to apply test-time augmentation (TTA) - averaging predictions over multiple augmented views of each sample.

class TTADataset(Dataset):
    """
    Test-time augmentation: returns N augmented views of each sample.
    Average model predictions across views for better accuracy.
    """

    def __init__(self, base_dataset: Dataset, tta_transforms: list):
        self.base_dataset = base_dataset
        self.tta_transforms = tta_transforms

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        image, label = self.base_dataset[idx]
        # Albumentations pipelines take keyword arguments and return a dict
        views = [transform(image=image)['image'] for transform in self.tta_transforms]
        # Return all augmented views + original label
        return torch.stack(views), label  # (n_views, C, H, W), scalar

# Usage: at inference time
# val_ds must yield raw HWC NumPy images here (built with transform=None),
# because every TTA pipeline below already ends with val_transform
tta_ds = TTADataset(val_ds, tta_transforms=[
    val_transform,
    A.Compose([A.HorizontalFlip(p=1.0), val_transform]),
    A.Compose([A.Rotate(limit=10, p=1.0), val_transform]),
])

DataLoader Deep Dive

from torch.utils.data import DataLoader

# Production-quality DataLoader
train_loader = DataLoader(
train_ds,
batch_size=64,
shuffle=True, # reshuffles the indices at the start of every epoch
num_workers=8, # parallel data loading processes
pin_memory=True, # faster CPU→GPU transfer
drop_last=True, # discard incomplete final batch
persistent_workers=True, # keep worker processes between epochs
prefetch_factor=2, # prefetch 2 batches per worker
)

num_workers: Choosing the Right Number

num_workers is the single most impactful DataLoader parameter. With num_workers=0, all data loading blocks the main thread and starves the GPU. With workers, data for the next batch is loaded in parallel while the GPU processes the current batch.

import os
import time
import torch
from torch.utils.data import DataLoader

# Rule of thumb: num_workers = 4 * num_GPUs, capped at CPU count - 1
cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None
n_gpus = torch.cuda.device_count()
recommended = min(4 * max(n_gpus, 1), cpu_count - 1)

# Profile to find the sweet spot for your workload
for nw in [0, 2, 4, 8, 16]:
    loader = DataLoader(train_ds, batch_size=64, num_workers=nw, pin_memory=True)
    t0 = time.perf_counter()
    for _ in loader:
        pass
    elapsed = time.perf_counter() - t0
    print(f"num_workers={nw:2d}: {elapsed:.2f}s per epoch")
# Typical pattern: time drops until you hit diminishing returns, often at 4-8 workers for SSD workloads

I/O-bound vs CPU-bound workloads: For JPEG images with heavy augmentation, the bottleneck is CPU (decoding + augmentation). Use more workers. For simple arrays (NumPy, HDF5), the bottleneck is I/O throughput. Workers beyond 4–8 give no benefit.

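macOS/Windows caveat: Python's default multiprocessing start method is spawn on macOS (since Python 3.8) and on Windows, where it is the only option. spawn is slower to start than fork and requires the dataset, collate_fn, and worker_init_fn to be picklable, but it is safer. On Linux, fork has historically been the default and is faster. With spawn, create the DataLoader and run the training loop under an if __name__ == "__main__": guard. If you see crashes on macOS with num_workers > 0, you can try multiprocessing_context='fork' (but be careful - fork can cause issues with certain libraries like OpenMP).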

# macOS: prefer spawn for safety
loader = DataLoader(
dataset,
num_workers=4,
multiprocessing_context='spawn', # explicit
)
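
Before choosing a multiprocessing_context, it can be useful to check what the current platform offers. This short sketch uses only the standard multiprocessing module; the printed values depend on the operating system and Python version, so no fixed output is shown.

```python
import multiprocessing as mp
import sys

available = mp.get_all_start_methods()   # start methods supported on this platform
default = mp.get_start_method()          # the method used when none is requested
print(f"platform={sys.platform} available={available} default={default}")

# 'spawn' exists on every platform, which makes it the portable explicit choice.
ctx = mp.get_context("spawn")
print(ctx.get_start_method())
```

Passing the string `'spawn'` as multiprocessing_context, as in the DataLoader call above, selects the same method by name.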

pin_memory: Page-Locked Memory for Fast GPU Transfer

Without pin_memory, CPU memory is pageable - the OS can swap it to disk. When a pageable tensor is transferred to the GPU via DMA, the CUDA driver must first copy it to a pinned (page-locked) staging buffer, then DMA from there. This is two copies.

With pin_memory=True, the DataLoader copies each collated batch into pinned memory on a background thread before handing it to the training loop. The DMA engine then copies directly from pinned CPU memory to GPU memory with no staging copy on the critical path, and the transfer can overlap with GPU computation when combined with non_blocking=True.

# Without pin_memory: the staging copy happens during the transfer
# CPU pageable → CPU pinned (CPU copy) → GPU memory (DMA)

# With pin_memory=True: the batch is already pinned when .to(device) is called
# CPU pinned (filled by the DataLoader) → GPU memory (DMA, async)

use_pin_memory = torch.cuda.is_available()
loader = DataLoader(dataset, batch_size=64, pin_memory=use_pin_memory)

# In the training loop: use non_blocking=True for async transfer
# This overlaps the CPU→GPU copy with the previous batch's GPU computation
batch_x = batch_x.to(device, non_blocking=True)
batch_y = batch_y.to(device, non_blocking=True)

The transfer speedup can reach 30–40% for large batches, though the exact gain depends on batch size and hardware, so measure it. Only use pin_memory=True when training on GPU - it wastes memory on CPU-only systems.

prefetch_factor and persistent_workers

# prefetch_factor: how many batches each worker pre-loads ahead
# Default: 2 (each worker has 2 batches ready before they are requested)
# Increase to 4 if GPU is still idle even with num_workers set correctly

loader = DataLoader(
dataset,
num_workers=8,
prefetch_factor=4, # 8 workers × 4 prefetches = 32 batches buffered
persistent_workers=True, # keep workers alive between epochs
)

# persistent_workers=True: worker processes are reused between epochs
# Without it: workers are spawned and destroyed every epoch
# The startup overhead for 8 workers can be 1-2 seconds per epoch
# Always use persistent_workers=True when num_workers > 0
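
Raising prefetch_factor trades host memory for latency hiding, because every buffered batch lives in RAM until it is consumed. The sketch below is a back-of-the-envelope estimate of that buffer size. The function name is hypothetical, and the sample shape (3×224×224 float32 images, batch size 64) is an assumption chosen for illustration.

```python
def prefetch_buffer_bytes(num_workers, prefetch_factor, batch_size,
                          sample_shape, bytes_per_element=4):
    """Upper bound on memory held by prefetched batches (inputs only)."""
    elements_per_sample = 1
    for dim in sample_shape:
        elements_per_sample *= dim
    batch_bytes = batch_size * elements_per_sample * bytes_per_element
    return num_workers * prefetch_factor * batch_bytes


# 8 workers x 4 prefetched batches, as in the configuration above
total = prefetch_buffer_bytes(8, 4, 64, (3, 224, 224))
print(f"{total / 2**30:.2f} GiB")   # 1.15 GiB
```

For these assumed shapes the buffer is modest, but with larger samples (video clips, 3D volumes) the same setting can consume tens of gigabytes.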

Collate Functions

The collate_fn receives a list of samples (the outputs of batch_size separate __getitem__ calls) and must combine them into a single batch: a tensor, or a tuple or dict of tensors.

Default Collate Behavior

The default collate stacks tensors of the same shape along a new dimension. It handles Python scalars (converts to tensors), NumPy arrays (converts to tensors), and nested structures.

# Default collate: works when all samples have identical shapes
# samples = [(x1, y1), (x2, y2), ..., (x64, y64)]
# batch = (torch.stack([x1, x2, ...], dim=0), torch.tensor([y1, y2, ...]))

# If __getitem__ returns a dict:
# samples = [{'img': tensor, 'label': 0}, {'img': tensor, 'label': 1}, ...]
# batch = {'img': stacked_tensor, 'label': tensor_of_labels}
# The default collate handles this automatically

Custom Collate for Variable-Length Sequences

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_text(batch):
    """
    batch: list of (token_ids, label) tuples
    token_ids: 1D tensor of VARIABLE length (no padding applied yet)
    """
    token_ids_list, labels = zip(*batch)

    # Pad to the length of the longest sequence IN THIS BATCH
    # Dynamic padding: much more efficient than global max-length padding
    # pad_sequence: pads with 0 by default, batch_first=True → (B, T)
    token_ids_padded = pad_sequence(
        token_ids_list,
        batch_first=True,
        padding_value=0
    )

    # Attention mask: 1 where real tokens, 0 where padding
    # Assumes token id 0 is reserved for padding and never appears as a real token
    attention_mask = (token_ids_padded != 0).long()

    # Return a dict - works seamlessly with Hugging Face models
    return {
        'input_ids': token_ids_padded,                     # (B, T_max)
        'attention_mask': attention_mask,                  # (B, T_max)
        'labels': torch.tensor(labels, dtype=torch.long),  # (B,)
    }


# Dataset returning variable-length token tensors
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=512):
        # Tokenize all texts eagerly (fine for most NLP datasets)
        self.data = [
            (torch.tensor(tokenizer.encode(t)[:max_len], dtype=torch.long), l)
            for t, l in zip(texts, labels)
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Register the custom collate
loader = DataLoader(
text_dataset,
batch_size=32,
collate_fn=collate_text,
shuffle=True,
num_workers=4,
)

for batch in loader:
    # batch['input_ids'].shape: (32, T_max_in_this_batch)
    # T_max varies per batch - no wasted padding beyond the batch's longest seq
    print(batch['input_ids'].shape)
    break
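
To see why dynamic padding matters, it helps to count tokens. The sketch below compares padding every sequence to a global maximum with padding each batch to its own longest sequence. It is plain Python, and the function name and sequence lengths are made up for illustration.

```python
def padded_tokens(lengths, batch_size, pad_to=None):
    """Total tokens after padding; pad_to=None pads each batch to its own longest sequence."""
    total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        target = pad_to if pad_to is not None else max(batch)
        total += target * len(batch)
    return total


lengths = [12, 40, 7, 512, 33, 18, 25, 9]    # hypothetical token counts
real = sum(lengths)
fixed = padded_tokens(lengths, batch_size=4, pad_to=512)
dynamic = padded_tokens(lengths, batch_size=4)
print(real, fixed, dynamic)                  # 656 4096 2180
```

One long sequence still inflates its whole batch (the first batch pads to 512), which is the motivation for grouping sequences of similar length.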

Custom Collate for Nested Structures

def collate_graph_batch(batch):
    """
    Collate for graph data where nodes and edges vary per sample.
    Returns stacked node features, edge lists with offset indices.
    """
    node_features_list, edge_index_list, labels = [], [], []
    cumulative_nodes = 0

    for node_feats, edges, label in batch:
        node_features_list.append(node_feats)
        # Offset edge indices for this graph's nodes in the batched graph
        edge_index_list.append(edges + cumulative_nodes)
        labels.append(label)
        cumulative_nodes += node_feats.size(0)

    return {
        'node_features': torch.cat(node_features_list, dim=0),
        'edge_index': torch.cat(edge_index_list, dim=1),  # each edges tensor is (2, num_edges)
        'labels': torch.tensor(labels),
        'batch_idx': torch.cat([
            torch.full((n.size(0),), i, dtype=torch.long)
            for i, n in enumerate(node_features_list)
        ]),
    }
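
The edge-index offset is the part of graph batching that is easiest to misread. The following sketch reproduces the same bookkeeping with plain Python lists instead of tensors; the function name and the two toy graphs are hypothetical.

```python
def batch_edges(graphs):
    """graphs: list of (num_nodes, edges), with edges as (src, dst) pairs local to each graph."""
    merged, batch_idx, offset = [], [], 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        # Shift local node ids so they index into the concatenated node list
        merged.extend((s + offset, d + offset) for s, d in edges)
        batch_idx.extend([graph_id] * num_nodes)
        offset += num_nodes
    return merged, batch_idx


graphs = [(3, [(0, 1), (1, 2)]), (2, [(0, 1)])]
edges, batch_idx = batch_edges(graphs)
print(edges)      # [(0, 1), (1, 2), (3, 4)]
print(batch_idx)  # [0, 0, 0, 1, 1]
```

The second graph's edge (0, 1) becomes (3, 4) because three nodes precede it in the batch, and batch_idx records which graph each node came from.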

Samplers

Samplers control the order in which indices are drawn from the dataset. The DataLoader uses a sampler internally, but you can provide a custom one.

Standard Samplers

from torch.utils.data import (
RandomSampler, SequentialSampler, WeightedRandomSampler,
BatchSampler, DistributedSampler
)

# RandomSampler: shuffle all indices at the start of each epoch
# (This is what shuffle=True does internally)
sampler = RandomSampler(dataset)

# SequentialSampler: indices 0, 1, 2, ...
# (This is what shuffle=False does internally)
sampler = SequentialSampler(dataset)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# NOTE: you cannot set both shuffle=True and sampler - they conflict

WeightedRandomSampler for Class Imbalance

Class imbalance (e.g., 95% negative, 5% positive) can bias a model toward always predicting the majority class. Two solutions: class-weighted loss (simpler), or oversampling the minority class with WeightedRandomSampler (often better when imbalance is severe).

from collections import Counter
from torch.utils.data import WeightedRandomSampler

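# Compute per-sample weights: samples from rare classes get higher weights
# NOTE: this calls __getitem__ for every sample, which is slow for lazy datasets;
# read labels from metadata instead if the dataset exposes them
labels = [dataset[i][1] for i in range(len(dataset))] # get all labels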
class_counts = Counter(labels)
n_samples = len(labels)
n_classes = len(class_counts)

# Weight of class c = n_samples / (n_classes * count(c))
# Rare classes get higher weight → more likely to be sampled
class_weights = {
cls: n_samples / (n_classes * count)
for cls, count in class_counts.items()
}
sample_weights = torch.tensor([class_weights[label] for label in labels])

# num_samples: how many samples per epoch (usually len(dataset) for replacement=True)
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(dataset),
replacement=True, # sample with replacement (required for balancing)
)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
# Now each batch will have approximately balanced class distribution
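
The weight formula can be checked by hand. With the weights defined above, the total sampling mass assigned to each class comes out equal, which is what produces roughly balanced batches. The sketch below verifies this in plain Python for a hypothetical 95/5 split, without using the sampler itself.

```python
from collections import Counter

labels = [0] * 95 + [1] * 5              # hypothetical 95% / 5% split
counts = Counter(labels)
n_samples, n_classes = len(labels), len(counts)

# Same formula as above: n_samples / (n_classes * count(c))
class_weights = {c: n_samples / (n_classes * k) for c, k in counts.items()}
sample_weights = [class_weights[y] for y in labels]

# Share of the total sampling probability that lands on each class
total = sum(sample_weights)
class_mass = {
    c: sum(w for w, y in zip(sample_weights, labels) if y == c) / total
    for c in counts
}
print(class_weights[1])                  # 10.0
print({c: round(m, 3) for c, m in class_mass.items()})  # {0: 0.5, 1: 0.5}
```

Each of the 5 positive samples is drawn about 19 times as often as a negative one, so with replacement=True the same minority samples repeat many times per epoch. That is worth keeping in mind when judging overfitting.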

DistributedSampler for Multi-GPU Training

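In distributed training, each GPU process gets a different shard of the data. DistributedSampler handles this partitioning so that, within an epoch, processes draw disjoint sets of indices. The exception is the tail: with drop_last=False, a few samples are repeated to give every process the same count.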

from torch.utils.data.distributed import DistributedSampler

# In distributed setup (DDP), each process has a different rank
sampler = DistributedSampler(
dataset,
num_replicas=world_size, # total number of GPU processes
rank=rank, # this process's index (0 to world_size-1)
shuffle=True,
drop_last=True, # drop the uneven tail instead of padding it with repeated samples
)

loader = DataLoader(
dataset,
batch_size=per_gpu_batch_size,
sampler=sampler,
num_workers=4,
pin_memory=True,
)

# CRITICAL: reshuffle at the start of each epoch
for epoch in range(n_epochs):
    sampler.set_epoch(epoch)  # different shuffle seed per epoch
    for batch in loader:
        ...
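
The partitioning idea can be illustrated without a distributed setup. The function below is a simplified sketch of the scheme, not the library's code: every rank builds the same shuffled index list from a shared seed plus the epoch, evens out the length, and then takes a strided slice. The function name and the sizes are hypothetical.

```python
import math
import random


def shard_indices(n, world_size, rank, epoch=0, seed=0, shuffle=True, drop_last=True):
    """Simplified sketch of DistributedSampler-style partitioning."""
    indices = list(range(n))
    if shuffle:
        # Same seed + epoch on every rank, so all ranks agree on the order
        random.Random(seed + epoch).shuffle(indices)
    if drop_last:
        per_rank = n // world_size
        indices = indices[:per_rank * world_size]   # drop the uneven tail
    else:
        per_rank = math.ceil(n / world_size)
        pad = per_rank * world_size - n
        indices += indices[:pad]                    # repeat samples to even out ranks
    return indices[rank::world_size]                # strided split across ranks


parts = [shard_indices(10, world_size=4, rank=r) for r in range(4)]
padded = [shard_indices(10, world_size=4, rank=r, drop_last=False) for r in range(4)]
print([len(p) for p in parts])    # [2, 2, 2, 2]
print([len(p) for p in padded])   # [3, 3, 3, 3]
```

This also shows why set_epoch matters: if the epoch never changes, the shuffle seed never changes, and every epoch replays the same order on every rank.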

BatchSampler for Custom Batch Composition

import math

# Group samples by sequence length for efficient NLP batching
# (similar-length sequences in same batch → less padding wasted)
class BucketBatchSampler:
    """
    Groups sequences of similar length into batches.
    Reduces padding waste by up to 50% vs random batching.
    Pass an instance to DataLoader via batch_sampler=... (not sampler=...).
    """
    def __init__(self, lengths, batch_size, drop_last=False):
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Sort indices by sequence length
        self.sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

    def __iter__(self):
        # Yield batches of contiguous indices (similar lengths grouped)
        # NOTE: batches come out in the same ascending-length order every epoch;
        # for training, consider shuffling the order of the batches
        batch = []
        for idx in self.sorted_indices:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sorted_indices) // self.batch_size
        return math.ceil(len(self.sorted_indices) / self.batch_size)

Data Augmentation in PyTorch

torchvision.transforms.v2 (Modern API)

torchvision.transforms.v2 is the updated API that works with tensors directly (not just PIL Images) and supports batch augmentation. It is the recommended API for new code.

import torch
import torchvision.transforms.v2 as v2

train_transform = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.08, 1.0)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    v2.RandomGrayscale(p=0.2),
    v2.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    # v2.ToTensor() is deprecated; ToImage + ToDtype is the v2 replacement
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# RandomApply: apply a transform with probability p
strong_augment = v2.Compose([
    v2.RandomApply([
        v2.ColorJitter(0.8, 0.8, 0.8, 0.2)
    ], p=0.8),
    v2.RandomGrayscale(p=0.2),
    v2.RandomApply([v2.GaussianBlur(23)], p=0.5),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

GPU Augmentation with v2

A key advantage of v2: transforms can be applied to tensors already on the GPU, removing the CPU augmentation bottleneck entirely for large-scale training.

# Apply augmentation on GPU after moving batch
normalize = v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
cutmix = v2.CutMix(num_classes=num_classes)

for batch_x, batch_y in train_loader:
    batch_x = batch_x.to(device)  # (B, C, H, W) float32
    batch_y = batch_y.to(device)

    # Augmentation happens on GPU - no CPU-GPU bottleneck
    batch_x = normalize(batch_x)
    batch_x, batch_y = cutmix(batch_x, batch_y)  # CutMix on GPU

IterableDataset for Streaming

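When random access by index is impossible or too expensive (a huge text file, a network stream, sharded archives), use IterableDataset. It generates samples on-the-fly rather than requiring random access by index. A dataset that is merely larger than RAM can still be map-style if it loads lazily, as shown earlier.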

import torch
from torch.utils.data import IterableDataset

class StreamingTextDataset(IterableDataset):
    """
    Stream lines from a large text file - no full file load.
    As written, use with num_workers=0: with several workers, each one would
    read the whole file and every line would be yielded once per worker.
    """

    def __init__(self, file_path: str, tokenizer, max_len: int = 128):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __iter__(self):
        with open(self.file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                ids = self.tokenizer.encode(line)[:self.max_len]
                yield torch.tensor(ids, dtype=torch.long)


class ShardedDataset(IterableDataset):
    """Handle multiple data shards with proper multi-worker distribution."""

    def __init__(self, shard_paths: list):
        self.shard_paths = shard_paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            shard_list = self.shard_paths
        else:
            # Each worker handles a disjoint subset of shards
            per_worker = len(self.shard_paths) // worker_info.num_workers
            start = worker_info.id * per_worker
            if worker_info.id < worker_info.num_workers - 1:
                end = start + per_worker
            else:
                end = len(self.shard_paths)  # last worker takes the remainder
            shard_list = self.shard_paths[start:end]

        for shard_path in shard_list:
            with open(shard_path, 'r') as f:
                for line in f:
                    yield self.process_line(line)

    def process_line(self, line):
        return torch.tensor([int(x) for x in line.split()], dtype=torch.long)

:::warning Duplicate Samples with IterableDataset With num_workers > 0, all workers iterate the same __iter__ independently by default. If you do not partition the data source across workers, every sample is yielded N times (once per worker). Always check torch.utils.data.get_worker_info() inside __iter__ and shard the data source explicitly. :::
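
The sharding logic can be tested in isolation by passing the worker id and worker count as plain arguments. The sketch below mirrors the contiguous split used in ShardedDataset and compares it with a strided (round-robin) split, a common alternative that balances the load more evenly when the shard count does not divide by the number of workers. Function names and shard names are hypothetical.

```python
def shards_contiguous(shard_paths, worker_id, num_workers):
    """Contiguous split; the last worker also takes the remainder."""
    per_worker = len(shard_paths) // num_workers
    start = worker_id * per_worker
    end = start + per_worker if worker_id < num_workers - 1 else len(shard_paths)
    return shard_paths[start:end]


def shards_round_robin(shard_paths, worker_id, num_workers):
    """Strided split: worker k takes shards k, k + num_workers, k + 2*num_workers, ..."""
    return shard_paths[worker_id::num_workers]


shards = [f"shard-{i:03d}.txt" for i in range(10)]
contiguous = [shards_contiguous(shards, w, 4) for w in range(4)]
strided = [shards_round_robin(shards, w, 4) for w in range(4)]
print([len(p) for p in contiguous])  # [2, 2, 2, 4]
print([len(p) for p in strided])     # [3, 3, 2, 2]
```

Both splits are disjoint and cover every shard. With the contiguous split, the last worker here processes twice as many shards as the others, so it finishes later and the tail of the epoch runs on a single worker.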

WebDataset: Streaming from Cloud Storage

Individual file access at scale is slow. Downloading 1.4M JPEG files from S3 creates 1.4M separate HTTP requests. WebDataset addresses this by storing data in TAR archives - sequential reads that can be streamed from any URL without random access.

# pip install webdataset
import webdataset as wds

# Data is stored as sharded TAR files on S3:
# s3://my-bucket/imagenet/shard-{000000..001023}.tar
# Each TAR contains: 000000.jpg, 000000.cls, 000001.jpg, 000001.cls, ...

dataset = (
wds.WebDataset("s3://my-bucket/imagenet/shard-{000000..001023}.tar",
shardshuffle=True)
.shuffle(1000) # shuffle within a buffer of 1000 samples
.decode("pil") # decode bytes → PIL Image automatically
.to_tuple("jpg", "cls") # extract fields from each sample
.map_tuple(train_transform, lambda x: int(x)) # apply transforms
.batched(64) # batch INSIDE the dataset (faster than DataLoader batching)
)

loader = wds.WebLoader(
dataset,
batch_size=None, # batching is done inside the dataset
num_workers=8,
pin_memory=True,
)

for images, labels in loader:
    # images: (64, 3, 224, 224), labels: (64,)
    pass

Why WebDataset is an order of magnitude faster than individual file downloads:

| Approach | 1M files, 100KB each | Bottleneck |
|---|---|---|
| Random S3 GET | ~28 hours | 1M HTTP requests |
| TAR shards (1000 files/TAR) | ~1 hour | 1000 HTTP requests, streaming reads |
| Local NVMe TAR | ~15 minutes | Disk throughput only |

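# Creating WebDataset shards from local data
import io
import webdataset as wds

with wds.TarWriter("shard-000000.tar") as sink:
    for i, (image, label) in enumerate(local_dataset):
        # Assumes `image` is a PIL Image; encode it to JPEG bytes in memory
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG")
        sink.write({
            "__key__": f"{i:06d}",
            "jpg": buffer.getvalue(),  # raw JPEG bytes
            "cls": str(label),         # string label
        })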

Data Pipeline Architecture Diagram

Profiling Data Loading

Before optimizing, measure. The PyTorch profiler shows exactly where time is spent: CPU augmentation, I/O, GPU kernel launches, or memory transfers.

import torch
from torch.profiler import profile, ProfilerActivity, record_function, schedule

# Profile 6 active training steps (after 2 skipped and 2 warm-up steps)
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=2, warmup=2, active=6),  # skip 2, warmup 2, profile 6
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_logs'),
    record_shapes=True,
    with_stack=True,
) as prof:
    for step, (batch_x, batch_y) in enumerate(train_loader):
        with record_function("data_to_device"):
            batch_x = batch_x.to(device, non_blocking=True)

        with record_function("forward"):
            logits = model(batch_x)

        with record_function("loss"):
            loss = loss_fn(logits, batch_y.to(device))

        with record_function("backward"):
            loss.backward()

        with record_function("optimizer"):
            optimizer.step()
            optimizer.zero_grad()

        prof.step()  # tell the profiler that one step has finished
        if step >= 12:
            break

# Print top CPU operations by total time
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=15))

Reading the Profiler Output

-------------------------- --------------- --------------- -------
Name CPU time total CUDA time total Calls
-------------------------- --------------- --------------- -------
data_to_device 125.3ms 0.2ms 10 ← bottleneck
forward 45.1ms 890.3ms 10
backward 82.4ms 2341.2ms 10
optimizer 12.1ms 45.2ms 10

How to read the table:

  • If data_to_device time is large: add pin_memory=True and non_blocking=True.
  • If there is a large gap before forward starts: the DataLoader is slow (increase num_workers, prefetch_factor).
  • If forward CUDA time is small but CPU time is large: the model is CPU-bound (unusual for GPU training, may indicate a synchronization point).
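
Once the profiler gives a compute time per step and an idle gap between steps, a rough utilization estimate follows from simple arithmetic. The sketch below assumes that loading and compute do not overlap at all, which is the worst case. The function names are hypothetical, and the 8 ms / 12 ms figures are the ones from the opening example.

```python
def gpu_utilization(step_ms, gap_ms):
    """Fraction of wall time the GPU is busy when compute and loading do not overlap."""
    return step_ms / (step_ms + gap_ms)


def max_gap_for_target(step_ms, target_utilization):
    """Largest idle gap per step that still reaches the target utilization."""
    return step_ms * (1 - target_utilization) / target_utilization


current = gpu_utilization(step_ms=8, gap_ms=12)
print(f"{current:.0%}")                                # 40%

budget = max_gap_for_target(step_ms=8, target_utilization=0.92)
print(f"{budget:.2f} ms")                              # 0.70 ms
```

Reaching 92% utilization with an 8 ms step leaves well under a millisecond of exposed loading time per batch. That is only achievable when nearly all loading is hidden behind compute through parallel workers and prefetching.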

Debugging DataLoader Workers

Workers run in separate processes. Errors raised in workers can be opaque and hard to trace.

# 1. Set num_workers=0 to run everything in main process
# Errors will be raised directly, with full tracebacks
loader = DataLoader(dataset, num_workers=0)

# 2. Test __getitem__ directly before using DataLoader
sample = dataset[42]
print(type(sample), sample[0].shape)

# 3. Test a single batch
loader = DataLoader(dataset, batch_size=4, num_workers=0)
batch = next(iter(loader))
print(batch[0].shape, batch[1].shape)

# 4. Memory leak: if RAM grows during training, worker processes may be
# accumulating state. Use persistent_workers=False (workers restart per epoch)

# 5. File descriptor limit: on Linux, each worker opens files
# Check: ulimit -n (should be at least 1024 * num_workers)
# Fix: ulimit -n 65536

# 6. macOS "Too many open files": set num_workers=0 and increase file descriptor limit

# 7. CUDA tensors in Dataset: never return CUDA tensors from __getitem__
# Workers cannot share the CUDA context. Return CPU tensors only.

:::danger CUDA Tensors in Workers Never create or return CUDA tensors inside Dataset.__getitem__. Worker processes do not share the main process's CUDA context: CUDA cannot be re-initialized in a forked worker, so the call typically fails with a runtime error. Always return CPU tensors from __getitem__ and call .to(device) inside the training loop. :::

Production-Ready DataLoader Configuration

import os
import random
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def seed_worker(worker_id: int) -> None:
    # Module-level function rather than a lambda: lambdas cannot be pickled,
    # which breaks num_workers > 0 under the spawn start method (macOS, Windows).
    # Inside a worker, torch.initial_seed() already differs per worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_dataloader(
    dataset: Dataset,
    batch_size: int,
    shuffle: bool = True,
    num_workers: Optional[int] = None,
    is_train: bool = True,
    collate_fn=None,
) -> DataLoader:
    """
    Production-ready DataLoader with sensible defaults.
    Picks a heuristic num_workers if not specified.
    """
    if num_workers is None:
        cpu_count = os.cpu_count() or 1
        n_gpus = max(torch.cuda.device_count(), 1)
        num_workers = min(4 * n_gpus, cpu_count - 1)

    kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        drop_last=is_train,
        persistent_workers=(num_workers > 0),
        prefetch_factor=2 if num_workers > 0 else None,
        worker_init_fn=seed_worker,
    )
    if collate_fn is not None:
        kwargs['collate_fn'] = collate_fn

    return DataLoader(dataset, **kwargs)


# Usage
train_loader = make_dataloader(train_ds, batch_size=64, shuffle=True, is_train=True)
val_loader = make_dataloader(val_ds, batch_size=128, shuffle=False, is_train=False)

Dataset Transforms and Compose

import torchvision.transforms as T

# Compose chains transforms
transform = T.Compose([
T.Resize((128, 128)),
T.RandomHorizontalFlip(p=0.5),
T.ToTensor(), # converts PIL Image [0-255] to tensor [0-1]
T.Normalize([0.5], [0.5]), # per-channel normalize
])

# Custom transforms as callable classes
class AddGaussianNoise:
    """Adds Gaussian noise to a tensor (for data augmentation)."""

    def __init__(self, mean=0.0, std=0.1):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return tensor + torch.randn_like(tensor) * self.std + self.mean

    def __repr__(self):
        return f"AddGaussianNoise(mean={self.mean}, std={self.std})"


noisy_transform = T.Compose([
T.ToTensor(),
AddGaussianNoise(std=0.05),
])

Common Mistakes

:::danger Storing Open File Handles in Dataset Do not store an open file handle (HDF5, SQLite, etc.) as an instance attribute when num_workers > 0. Open handles cannot be pickled for spawned workers, and a handle inherited through fork is not safe to share between processes. Open and close files inside __getitem__, or use worker_init_fn to open a per-worker file handle. :::

:::warning Wrong num_workers Default The default num_workers=0 is fine for development but can be very costly in production. When data loading is the bottleneck, GPU utilization can sit at 40–60% even with a fast model. Always profile with a few different num_workers values before submitting a long training run. :::

:::danger CUDA Tensors Returned from __getitem__

Never return CUDA tensors from Dataset.__getitem__. Worker processes do not share the main process CUDA context. Always return CPU tensors and call .to(device) in the training loop.

:::

:::warning shuffle=True with a Sampler

DataLoader raises an error if you pass shuffle=True together with a sampler. shuffle=True is shorthand for sampler=RandomSampler(dataset). If you provide a custom sampler (e.g., WeightedRandomSampler), set shuffle=False (or omit it).

:::

YouTube Resources

| Video | Creator | What It Covers |
| --- | --- | --- |
| PyTorch Datasets and DataLoaders | Python Engineer | Complete DataLoader guide from basics to production |
| Custom Dataset Tutorial | Aladdin Persson | Building custom datasets for images and tabular data |
| Data Loading Performance | PyTorch | Profiling data pipelines and fixing bottlenecks |
| WebDataset Tutorial | NVIDIA | Streaming datasets at scale from cloud storage |

Interview Q&A

Q1: What is the PyTorch Dataset contract, and what two methods must you implement?

A PyTorch Dataset is any class that implements __len__ (returns the number of samples as an integer) and __getitem__ (takes an integer index and returns a single sample - typically a tuple of input tensor and label tensor). The DataLoader calls these methods internally to collect samples into batches. Map-style datasets (the standard case) require both methods. Iterable-style datasets (for streaming) implement only __iter__ and are used when random access by index is not possible, such as reading from a file or network stream. The choice between eager loading (converting everything to tensors in __init__) and lazy loading (loading from disk in __getitem__) depends on dataset size: if the dataset fits in RAM, eager loading is faster; if not, lazy loading is required.
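The two contracts can be sketched side by side. The class names and the toy data here (`SquaresDataset`, `CountingStream`) are made up for illustration; only `Dataset`, `IterableDataset`, and `DataLoader` come from PyTorch.

```python
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset


class SquaresDataset(Dataset):
    """Map-style: random access by integer index."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx ** 2)])
        return x, y  # one (input, label) sample


class CountingStream(IterableDataset):
    """Iterable-style: samples arrive in order, no indexing."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield torch.tensor([float(i)])


# The DataLoader batches both kinds; only the map-style one can be shuffled
# or driven by a sampler, because only it supports lookup by index.
map_batches = list(DataLoader(SquaresDataset(10), batch_size=4))
stream_batches = list(DataLoader(CountingStream(10), batch_size=4))
```

With 10 samples and batch_size=4, both loaders produce three batches, the last one holding the remaining two samples.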

Q2: How do you choose the right value for num_workers?

num_workers controls how many worker subprocesses load and prefetch data in parallel. With num_workers=0, all loading happens in the main process, so the GPU sits idle while each batch is prepared. The right value depends on the workload type. For CPU-bound workloads (complex image augmentation), you need more workers - typically 4 to 8 per GPU. For I/O-bound workloads (simple array loading), the bottleneck is disk throughput and 2–4 workers is often sufficient - more workers do not help if the disk is saturated. The practical rule: start with 4 * num_GPUs, cap at cpu_count - 1, then time an epoch with 0, 2, 4, and 8 workers to find the diminishing-returns point. On macOS and Windows, the default start method is spawn (not fork), which makes starting workers more expensive - use persistent_workers=True to avoid restarting them every epoch.
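A rough way to find that point is to time one pass over the loader for each candidate value. The helper names below (`time_one_epoch`, `sweep_num_workers`) are hypothetical, and the sketch measures loading only, with no model in the loop.

```python
import time

from torch.utils.data import DataLoader


def time_one_epoch(dataset, num_workers, batch_size=64):
    """Iterate the dataset once and return (seconds, number of batches)."""
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    n_batches = 0
    for _ in loader:  # loading only; no forward or backward pass
        n_batches += 1
    return time.perf_counter() - start, n_batches


def sweep_num_workers(dataset, candidates=(0, 2, 4, 8), batch_size=64):
    """Return {num_workers: seconds per epoch} for each candidate value."""
    return {
        n: time_one_epoch(dataset, n, batch_size)[0]
        for n in candidates
    }
```

Calling `sweep_num_workers(train_ds)` on the real dataset and picking the smallest value after which the time stops improving is usually enough. On platforms that use spawn, the call should sit under an `if __name__ == "__main__":` guard, and the first epoch includes worker startup, so very small datasets can make extra workers look slower than they are.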

Q3: What does pin_memory=True actually do, and why does it speed up training?

pin_memory=True tells the DataLoader to place each batch in page-locked (pinned) CPU memory instead of standard pageable memory. Normally, the CUDA DMA engine cannot copy from pageable memory directly because the OS can swap it to disk at any moment. The driver must first copy the data into a pinned staging buffer, then DMA from staging to GPU - two copies. With pin_memory=True, the data is already in pinned memory, so DMA goes directly from CPU to GPU - one copy. Combined with .to(device, non_blocking=True) in the training loop, the transfer can overlap with the GPU computing the previous batch, hiding most of the transfer latency. For large batches the speedup can reach 30–40%, though the gain depends on how transfer-bound the job is. Only use it when training on a GPU - on CPU-only systems it brings no benefit.
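The pairing of pin_memory with non_blocking transfers looks roughly like this. The synthetic tensors are stand-in data, and the sketch falls back to the CPU when no GPU is available, in which case pinning is simply switched off.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Synthetic stand-in data (an assumption for illustration).
dataset = TensorDataset(
    torch.randn(256, 3, 32, 32),
    torch.randint(0, 10, (256,)),
)

# Pin batches only when a GPU is present.
loader = DataLoader(dataset, batch_size=64, pin_memory=use_cuda)

pinned_flags = []
n_batches = 0
for images, targets in loader:
    pinned_flags.append(images.is_pinned())  # True only when pinning is on
    # non_blocking=True lets the copy from pinned memory run asynchronously,
    # so it can overlap with GPU work that is already queued.
    images = images.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    n_batches += 1
    # forward / backward would go here
```

non_blocking=True only has an effect when the source tensor is in pinned memory; from pageable memory the copy is effectively synchronous.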

Q4: How do you implement WeightedRandomSampler for class imbalance, and when is it better than class-weighted loss?

WeightedRandomSampler assigns a weight to every sample, then draws indices with replacement so that each sample is picked with probability proportional to its weight. With inverse-frequency weights, every class is drawn about equally often in expectation. Implementation: compute per-class weights as n_samples / (n_classes * class_count), assign each sample its class weight, then create WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True). This is better than class-weighted loss when the imbalance is extreme (less than 1% minority class) - with plain random sampling the model rarely sees minority examples, so increasing the loss weight does not help much. WeightedRandomSampler makes the minority class appear in roughly every batch. For moderate imbalance (5–30% minority), class-weighted loss (nn.CrossEntropyLoss(weight=...)) is simpler and usually sufficient.
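The recipe above can be sketched on a synthetic 99:1 dataset. The data and variable names are made up for illustration; the weighting formula is the one given in the answer.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Synthetic imbalanced labels: 990 majority (0) vs 10 minority (1).
labels = torch.cat([
    torch.zeros(990, dtype=torch.long),
    torch.ones(10, dtype=torch.long),
])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# Per-class weight: n_samples / (n_classes * class_count)
class_counts = torch.bincount(labels)
n_samples, n_classes = len(labels), len(class_counts)
class_weights = n_samples / (n_classes * class_counts.float())

# Each sample inherits the weight of its class.
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    sample_weights, num_samples=len(dataset), replacement=True
)
# Note: no shuffle=True here; the sampler already decides the order.
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

drawn = torch.cat([y for _, y in loader])
minority_fraction = (drawn == 1).float().mean().item()
```

`minority_fraction` should land near 0.5 instead of 0.01. Because sampling is with replacement, the 10 minority samples are each repeated many times per epoch, so augmentation matters more than usual to limit overfitting on them.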

Q5: Write a custom collate_fn for variable-length sequences. What does pad_sequence do?

pad_sequence(sequences, batch_first=True, padding_value=0) takes a list of 1D tensors of different lengths and pads them with padding_value on the right so all have the same length as the longest tensor in the list. With batch_first=True, the output shape is (B, T_max). The custom collate function receives a list of (token_ids, label) tuples from __getitem__. It unzips them into two lists, calls pad_sequence on the token lists, creates an attention mask (padded != 0, which assumes token ID 0 is reserved for padding), and returns a dict with input_ids, attention_mask, and labels. The critical advantage of doing this per-batch (dynamic padding) rather than globally: each batch is padded only to its own longest sequence, so a batch of short sequences wastes very little memory, whereas global padding always pads to the dataset maximum length.
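A minimal sketch of such a collate function follows. The name `collate_variable_length` and the toy batch are assumptions for illustration. One deliberate difference from the description above: the mask is built from the original sequence lengths rather than from `padded != 0`, so it stays correct even if 0 is a real token ID.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def collate_variable_length(batch):
    """Collate a list of (token_ids, label) pairs with dynamic padding."""
    token_ids, labels = zip(*batch)  # unzip into two tuples
    lengths = torch.tensor([len(t) for t in token_ids])

    # Pad on the right up to the longest sequence in THIS batch: (B, T_max)
    input_ids = pad_sequence(list(token_ids), batch_first=True, padding_value=0)

    # 1 for real tokens, 0 for padding, derived from the lengths
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)  # (1, T_max)
    attention_mask = (positions < lengths.unsqueeze(1)).long()

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": torch.tensor(labels),
    }


# Toy batch: three sequences of lengths 3, 1, and 2.
batch = [
    (torch.tensor([5, 6, 7]), 1),
    (torch.tensor([8]), 0),
    (torch.tensor([9, 10]), 1),
]
out = collate_variable_length(batch)
```

It plugs into a loader as `DataLoader(dataset, batch_size=32, collate_fn=collate_variable_length)`.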

Q6: When should you use WebDataset instead of a standard map-style Dataset, and what are the tradeoffs?

Use WebDataset when data is stored in cloud object storage (S3, GCS, Azure Blob) and downloading individual files is too slow, or when the dataset is too large for local storage and must be streamed. WebDataset stores data as TAR archives (typically 1000–10000 samples per shard), enabling sequential streaming reads. A single HTTP request fetches a full shard, avoiding the per-file HTTP overhead. The tradeoffs: WebDataset does not support random access by index, so you cannot use standard WeightedRandomSampler or reproducible shuffling - it uses approximate shuffling with a buffer. Epoch-level shuffling requires reshuffling shard order, not sample order. For local NVMe storage with a few hundred thousand files, standard map-style Dataset is simpler and equally fast. WebDataset is the right choice at scale (millions of files) or when reading from cloud storage in a training cluster.
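The "approximate shuffling with a buffer" idea can be illustrated in plain Python. This is a generic shuffle-buffer sketch, not WebDataset's actual implementation; `buffered_shuffle` is a hypothetical helper.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    stream: Iterable[T], buffer_size: int, seed: int = 0
) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Emit a random buffered item and put the new one in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Stream exhausted: flush what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(20), buffer_size=5))
```

Every item is emitted exactly once, but an item can never appear more than `buffer_size` positions earlier than its place in the stream. Samples from a late shard therefore never mix with samples from an early one, which is why shard order is reshuffled each epoch as well.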

Q7: How do you debug slow data loading, and what does the profiler tell you?

First, measure: time a full epoch with num_workers=0 vs num_workers=4 vs num_workers=8. If adding workers helps significantly, the bottleneck is CPU-side processing (loading + augmentation). If workers help only marginally, the bottleneck is I/O (disk throughput). Use torch.profiler to measure the wall time of the host-to-device copy, the forward pass, and the backward pass (for example, regions you label data_to_device, forward, and backward with record_function). If data_to_device is large relative to the others, add pin_memory=True and non_blocking=True. If there is a long idle gap before forward starts, the DataLoader is not prefetching fast enough - increase num_workers and prefetch_factor. If CPU augmentation is the bottleneck (high CPU time in the profiler), consider switching from PIL + torchvision to Albumentations, which is often faster, or move augmentation to the GPU using torchvision.transforms.v2 applied on GPU tensors.
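Before reaching for the full profiler, a coarse split between "waiting for data" and "computing" is often enough to tell which side is the bottleneck. The sketch below uses `time.perf_counter` instead of torch.profiler; `profile_loop`, the tiny linear model, and the synthetic data are assumptions for illustration.

```python
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def profile_loop(loader, model, loss_fn, optimizer, device, max_steps=20):
    """Return (avg seconds waiting for a batch, avg seconds computing)."""
    data_time = compute_time = 0.0
    steps = 0
    end = time.perf_counter()
    for x, y in loader:
        start = time.perf_counter()
        data_time += start - end  # time spent blocked on the DataLoader

        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if device.type == "cuda":
            torch.cuda.synchronize()  # wait for queued GPU work before timing

        end = time.perf_counter()
        compute_time += end - start
        steps += 1
        if steps >= max_steps:
            break
    return data_time / steps, compute_time / steps


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
demo_ds = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
demo_loader = DataLoader(demo_ds, batch_size=32)
demo_model = nn.Linear(8, 2).to(device)
demo_opt = torch.optim.SGD(demo_model.parameters(), lr=0.1)

avg_data, avg_compute = profile_loop(
    demo_loader, demo_model, nn.CrossEntropyLoss(), demo_opt, device
)
```

If `avg_data` is comparable to or larger than `avg_compute`, the loader is the bottleneck. The first step also includes iterator and worker startup, so a longer run or a discarded first step gives a cleaner number.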


© 2026 EngineersOfAI. All rights reserved.