

Batch Inference Pipelines

The Production Scenario

At 11:00 PM every night, Netflix kicks off one of the largest machine learning jobs in the world. It needs to generate personalized row rankings - "Continue Watching," "Top Picks For You," "Trending Now sorted for your taste" - for over 230 million subscribers, across hundreds of candidate sets, using models that consider watch history, time of day, device type, and dozens of other signals. By 6:00 AM, before the morning viewing peak, every subscriber's homepage must reflect a freshly computed recommendation set. The job runs on hundreds of GPU nodes processing 230 million inference requests in seven hours.

This is batch inference at scale. No single user is waiting for a result right now. The freshness window is measured in hours, not milliseconds. The economic constraint is not latency but throughput and cost: how many GPU-hours does it take to score the entire user base? On the same hardware, a pipeline that finishes in 7 hours costs a third as much per daily run as one that takes 21 hours, because it consumes a third of the GPU-hours.
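To make the throughput framing concrete, here is a back-of-the-envelope calculation using the figures from the scenario (230 million requests, a 7-hour window). The function names, the perfectly even load, and the 100-GPU cluster size are illustrative assumptions, not details of any real pipeline.

```python
def required_throughput(total_requests: int, window_hours: float) -> float:
    """Average requests/second needed to finish inside the window."""
    return total_requests / (window_hours * 3600)


def gpu_hours(num_gpus: int, runtime_hours: float) -> float:
    """GPU-hours consumed by one run; cost scales linearly with this number."""
    return num_gpus * runtime_hours


rate = required_throughput(230_000_000, 7)
print(f"Required cluster throughput: {rate:,.0f} requests/sec")

# Hypothetical 100-GPU cluster: a 21-hour run burns 3x the GPU-hours of a 7-hour run.
cost_ratio = gpu_hours(100, 21) / gpu_hours(100, 7)
print(f"21h run vs 7h run: {cost_ratio:.1f}x the GPU-hours")
```

The same two functions can be reused to ask the inverse question: given a per-GPU throughput, how many GPUs are needed to hit the window.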

The engineers who built this pipeline did not think about it as a model serving problem. They thought about it as a data engineering problem with a GPU step in the middle: read from the data warehouse, preprocess at scale, run the model efficiently, write predictions back to a store where the serving layer can look them up in real time. The model itself is 5% of the problem. The other 95% is orchestration, chunking, failure recovery, monitoring, and ensuring that a single corrupt user record does not abort a 7-hour job.

This lesson covers how to design, build, and operate batch inference pipelines - from the architectural patterns that make them reliable, to the specific tricks for maximizing GPU utilization, to the failure recovery strategies that keep a 230-million-record job from needing to restart from scratch.

Why This Exists - The Case for Batch Over Real-Time

Real-time inference is expensive. Every request requires the model to be loaded (or kept hot), the features to be computed, and the result to be returned within milliseconds. At high throughput, this means a GPU cluster running 24/7, with redundancy, autoscaling, and the full operational complexity of a real-time system.

Batch inference exists for use cases where the prediction is not needed at the moment of the request but can be pre-computed and stored. When a user opens Netflix at 7:00 AM, the homepage shows recommendations that were computed at 3:00 AM. The user does not know or care. The important thing is that the recommendations are fresh enough to be relevant, and for viewing habits a few hours of staleness is acceptable.

The cost difference is dramatic. A GPU running 24/7 to serve real-time inference costs roughly $730/month (A10G at AWS on-demand pricing). The same inference workload run as a nightly batch job on spot GPU instances might cost $50/month - a 15x cost reduction - because you only pay for the hours you use, at spot pricing.
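The arithmetic behind that comparison can be sketched as follows. The hourly rates, the 70% spot discount, and the 5.5 hours per night are assumptions chosen to land near the figures quoted above (roughly 730 and 50 dollars a month); check current cloud pricing before relying on them.

```python
HOURS_PER_MONTH = 730  # average hours in a month (8,760 / 12)


def monthly_cost(hourly_rate: float, hours_per_day: float, days: int = 30) -> float:
    """Cost of running one instance for a fixed number of hours each day."""
    return hourly_rate * hours_per_day * days


# Assumed rates, for illustration only.
on_demand_rate = 1.00  # USD/hour, consistent with ~$730/month when running 24/7
spot_rate = 0.30       # USD/hour, assuming a 70% spot discount

always_on = on_demand_rate * HOURS_PER_MONTH
nightly = monthly_cost(spot_rate, hours_per_day=5.5)

print(f"24/7 on-demand: ${always_on:,.0f}/month")
print(f"Nightly spot:   ${nightly:,.0f}/month")
print(f"Reduction:      {always_on / nightly:.1f}x")
```

Two independent factors drive the saving: fewer billed hours, and a lower price per hour.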

Batch inference trades latency for cost and throughput. Make that trade knowingly.

Historical Context

Batch processing predates modern ML by decades. IBM mainframe batch jobs in the 1960s would run overnight to process the previous day's transactions. The pattern is identical: collect data, process it in bulk, store results for later use.

The application of batch processing to ML model scoring emerged from the data warehouse world. MapReduce (described by Google in 2004 and popularized by Hadoop) let engineers run distributed computations on petabytes of data. Spark (started in 2009) made those computations dramatically faster by keeping intermediate data in memory. As ML models became valuable in production, engineers naturally applied the Spark paradigm to model scoring: run the model as a UDF (user-defined function) in a distributed computation.

The GPU era changed the economics. Before deep learning, most production models were logistic regression, gradient boosting, or shallow neural networks - fast enough to score millions of records on CPU in minutes. Deep learning models changed the latency profile: a single forward pass through a large transformer takes tens of milliseconds on a single CPU core, making pure CPU batch inference prohibitively slow. The modern batch inference pipeline runs CPU preprocessing at scale (Spark, Ray), then moves data to GPU for the model forward pass, then writes results back to storage.

Core Concepts

The Batch Inference Architecture

A batch inference pipeline has five components:

  1. Data source: The input to score - a table in a data warehouse, a set of S3 files, a Kafka topic snapshot
  2. Preprocessing: Feature computation, normalization, tokenization - typically distributed across CPU workers
  3. Model execution: The forward pass - typically on GPU, with careful batching for utilization
  4. Post-processing: Score calibration, filtering, business logic - typically on CPU
  5. Output sink: Writing predictions - to a feature store, a database, or S3 for downstream use
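The five components above can be sketched as a chain of generators, so records stream through without the whole dataset sitting in memory. Everything here is a toy stand-in: `toy_model`, the scaling in `preprocess`, the 0.3 threshold, and the dict used as the sink are hypothetical placeholders for a real model, feature logic, business rule, and store.

```python
from typing import Callable, Iterable, Iterator

Record = dict


def read_source(rows: Iterable[Record]) -> Iterator[Record]:
    """1. Data source: stand-in for a warehouse table or a set of files."""
    yield from rows


def preprocess(records: Iterable[Record]) -> Iterator[Record]:
    """2. Preprocessing: turn raw values into a feature vector."""
    for r in records:
        yield {"id": r["id"], "features": [v / 10.0 for v in r["raw"]]}


def batched(records: Iterable[Record], batch_size: int) -> Iterator[list[Record]]:
    """Group the stream into fixed-size batches (the last one may be smaller)."""
    batch: list[Record] = []
    for r in records:
        batch.append(r)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def run_model(batches: Iterable[list[Record]], model: Callable) -> Iterator[Record]:
    """3. Model execution: one model call per batch."""
    for batch in batches:
        scores = model([r["features"] for r in batch])
        for r, s in zip(batch, scores):
            yield {"id": r["id"], "score": s}


def postprocess(preds: Iterable[Record], threshold: float = 0.3) -> Iterator[Record]:
    """4. Post-processing: apply business logic to raw scores."""
    for p in preds:
        yield {**p, "label": "show" if p["score"] >= threshold else "hide"}


def write_sink(preds: Iterable[Record], sink: dict) -> int:
    """5. Output sink: a dict stands in for a feature store or database."""
    written = 0
    for p in preds:
        sink[p["id"]] = p
        written += 1
    return written


def toy_model(feature_batch: list[list[float]]) -> list[float]:
    """Placeholder model: mean of the features."""
    return [sum(f) / len(f) for f in feature_batch]


rows = [{"id": i, "raw": [i, i + 1]} for i in range(5)]
sink: dict = {}
stream = postprocess(run_model(batched(preprocess(read_source(rows)), 2), toy_model))
written = write_sink(stream, sink)
print(f"Wrote {written} predictions")
```

Keeping each stage a separate function makes it easier to swap one of them (for example, replacing `toy_model` with a GPU actor) without touching the others.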

Optimal Batch Sizes for GPU Utilization

GPU utilization is the key metric in batch inference. A GPU processes a batch of N examples in nearly the same time as it processes 1 example - until the batch saturates GPU memory or compute. The goal is to find the batch size where the GPU spends more than 90% of its time on actual matrix multiplications, not on memory transfers or idle wait.

A simple utilization sweep:

```python
# gpu_batch_sweep.py
import time
from typing import Sequence

import torch


def measure_throughput(
    model: torch.nn.Module,
    feature_dim: int,
    batch_size: int,
    n_warmup: int = 5,
    n_measure: int = 20,
) -> float:
    """Returns examples/second at the given batch size."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()

    def sync() -> None:
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        # Guarded so the sweep also runs on CPU-only machines.
        if device.type == "cuda":
            torch.cuda.synchronize()

    # Warmup
    with torch.no_grad():
        for _ in range(n_warmup):
            x = torch.randn(batch_size, feature_dim, device=device)
            _ = model(x)

    sync()
    start = time.perf_counter()

    with torch.no_grad():
        for _ in range(n_measure):
            x = torch.randn(batch_size, feature_dim, device=device)
            _ = model(x)

    sync()
    elapsed = time.perf_counter() - start

    total_examples = batch_size * n_measure
    return total_examples / elapsed


def find_optimal_batch_size(
    model: torch.nn.Module,
    feature_dim: int,
    batch_sizes: Sequence[int] = (1, 8, 16, 32, 64, 128, 256, 512),
) -> int:
    """Find the batch size that maximizes throughput."""
    results: dict[int, float] = {}
    for bs in batch_sizes:
        try:
            throughput = measure_throughput(model, feature_dim, bs)
            results[bs] = throughput
            print(f"  batch_size={bs:4d}: {throughput:8.0f} examples/sec")
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print(f"  batch_size={bs:4d}: OOM - stopping sweep")
                break
            raise

    if not results:
        raise RuntimeError("No batch size in the sweep fit in memory")

    optimal = max(results, key=results.get)
    print(f"\nOptimal batch size: {optimal} ({results[optimal]:.0f} ex/sec)")
    return optimal
```

Implementation: Ray for Distributed Batch Inference

Ray is the recommended tool for Python-native batch inference. It handles work distribution, parallelism, and failure recovery while keeping the code simple:

```python
# ray_batch_inference.py
import json
import time
from pathlib import Path
from typing import Iterator, Optional

import numpy as np
import pandas as pd
import ray
import torch


@ray.remote(num_gpus=1)
class ModelActor:
    """A Ray actor that holds the model in GPU memory and processes batches."""

    def __init__(self, model_path: str, batch_size: int = 128):
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load model once at actor startup - kept in GPU memory
        print(f"Loading model from {model_path} on {self.device}")
        self.model = torch.jit.load(model_path, map_location=self.device)
        self.model.eval()

        self.processed = 0
        self.start_time = time.time()

    def predict_batch(self, features: np.ndarray) -> np.ndarray:
        """Run inference on a numpy array. Returns predictions as numpy array."""
        with torch.no_grad():
            x = torch.from_numpy(features).float().to(self.device)
            logits = self.model(x)
            probs = torch.softmax(logits, dim=-1)
            preds = probs.cpu().numpy()

        self.processed += len(features)
        return preds

    def get_stats(self) -> dict:
        elapsed = time.time() - self.start_time
        return {
            "processed": self.processed,
            "throughput": self.processed / elapsed if elapsed > 0 else 0,
        }


def chunk_dataframe(
    df: pd.DataFrame,
    chunk_size: int,
) -> Iterator[pd.DataFrame]:
    """Yield DataFrame chunks of specified size."""
    for i in range(0, len(df), chunk_size):
        yield df.iloc[i : i + chunk_size]


def run_batch_inference(
    input_path: str,
    model_path: str,
    output_path: str,
    batch_size: int = 128,
    num_gpus: int = 1,
    feature_cols: Optional[list[str]] = None,
    checkpoint_path: str = "/tmp/batch_checkpoint.json",
):
    """
    Full batch inference pipeline with checkpointing.

    Reads a parquet file, runs model inference, writes predictions.
    Resumes from checkpoint if interrupted.
    """
    ray.init(ignore_reinit_error=True)

    # Load data
    print(f"Loading data from {input_path}")
    df = pd.read_parquet(input_path)
    total_rows = len(df)
    print(f"Total rows: {total_rows:,}")

    if feature_cols is None:
        feature_cols = [c for c in df.columns if c.startswith("feature_")]

    # Load checkpoint - resume from where we left off
    start_chunk = 0
    all_results: list[dict] = []
    if Path(checkpoint_path).exists():
        with open(checkpoint_path) as f:
            checkpoint = json.load(f)
        start_chunk = checkpoint["completed_chunks"]
        print(f"Resuming from chunk {start_chunk}")
        if "partial_output_path" in checkpoint:
            all_results = pd.read_parquet(
                checkpoint["partial_output_path"]
            ).to_dict("records")

    # Start model actors - one per GPU
    actors = [ModelActor.remote(model_path, batch_size) for _ in range(num_gpus)]

    chunks = list(chunk_dataframe(df, batch_size))

    print(f"Processing {len(chunks)} chunks with {num_gpus} GPU actor(s)")
    start_time = time.time()

    # Process chunks in waves of num_gpus. Calling ray.get() on each chunk
    # individually would block after every submission and keep only one
    # actor busy at a time; submitting a whole wave first lets all actors
    # run in parallel.
    for wave_start in range(start_chunk, len(chunks), num_gpus):
        wave = list(
            enumerate(chunks[wave_start : wave_start + num_gpus], start=wave_start)
        )
        # Round-robin across actors for load balancing
        futures = [
            actors[chunk_idx % num_gpus].predict_batch.remote(
                chunk[feature_cols].values.astype(np.float32)
            )
            for chunk_idx, chunk in wave
        ]

        for (chunk_idx, chunk), predictions in zip(wave, ray.get(futures)):
            # Combine IDs with predictions (fall back to the index if no id column)
            ids = chunk["id"] if "id" in chunk.columns else chunk.index
            for row_id, scores in zip(ids, predictions):
                all_results.append({
                    "id": row_id,
                    "prediction": int(np.argmax(scores)),
                    "confidence": float(np.max(scores)),
                    "scores": scores.tolist(),
                })

            # Checkpoint every 100 chunks
            if (chunk_idx + 1) % 100 == 0:
                partial_path = output_path.replace(".parquet", "_partial.parquet")
                pd.DataFrame(all_results).to_parquet(partial_path, index=False)
                with open(checkpoint_path, "w") as f:
                    json.dump({
                        "completed_chunks": chunk_idx + 1,
                        "total_chunks": len(chunks),
                        "partial_output_path": partial_path,
                    }, f)

                elapsed = time.time() - start_time
                rows_done = min((chunk_idx + 1) * batch_size, total_rows)
                rate = (rows_done - start_chunk * batch_size) / elapsed
                remaining = (total_rows - rows_done) / rate
                print(
                    f"Chunk {chunk_idx + 1}/{len(chunks)} | "
                    f"{rate:.0f} rows/sec | "
                    f"ETA: {remaining / 60:.1f} min"
                )

    # Write final output
    print(f"Writing {len(all_results):,} predictions to {output_path}")
    pd.DataFrame(all_results).to_parquet(output_path, index=False)

    # Clean up checkpoint
    if Path(checkpoint_path).exists():
        Path(checkpoint_path).unlink()

    total_elapsed = time.time() - start_time
    print(f"Done: {total_rows:,} rows in {total_elapsed / 60:.1f} minutes")
    print(f"Throughput: {total_rows / total_elapsed:.0f} rows/sec")

    # Print actor stats
    for i, actor in enumerate(actors):
        stats = ray.get(actor.get_stats.remote())
        print(f"Actor {i}: {stats['processed']:,} rows, {stats['throughput']:.0f}/sec")

    ray.shutdown()
```

Implementation: Spark for Distributed Batch Inference

Spark is the right choice when your data already lives in a data lake and you want to leverage existing Spark infrastructure:

```python
# spark_batch_inference.py
import io
from typing import Iterator, Optional

import numpy as np
import pandas as pd
import torch
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import FloatType, StringType, StructField, StructType


def create_inference_udf(model_path: str, feature_cols: list[str]):
    """
    Create a Pandas UDF for distributed model inference.

    Pandas UDFs are vectorized - Spark passes a pandas Series/DataFrame
    to the function instead of one row at a time. The model bytes are
    broadcast to the executors and deserialized once per task (partition),
    not once per row.
    """

    # Broadcast the model bytes to all executors
    spark = SparkSession.getActiveSession()
    with open(model_path, "rb") as f:
        model_bytes = f.read()
    model_broadcast = spark.sparkContext.broadcast(model_bytes)

    @pandas_udf(
        returnType=StructType([
            StructField("prediction", StringType()),
            StructField("confidence", FloatType()),
        ])
    )
    def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        # With the iterator form, this setup runs once per task, before the
        # first batch. The broadcast holds the raw TorchScript file bytes, and
        # torch.jit.load accepts a file-like object, so wrap them in BytesIO.
        model = torch.jit.load(
            io.BytesIO(model_broadcast.value),
            map_location="cpu",
        )
        model.eval()
        label_map = {0: "safe", 1: "unsafe", 2: "review"}

        # Each batch_df is the struct column unpacked into a pandas DataFrame
        for batch_df in iterator:
            features = batch_df[feature_cols].to_numpy(dtype=np.float32)

            with torch.no_grad():
                x = torch.from_numpy(features)
                logits = model(x)
                probs = torch.softmax(logits, dim=-1).numpy()

            predictions = np.argmax(probs, axis=1)
            confidences = np.max(probs, axis=1)

            yield pd.DataFrame({
                "prediction": [label_map[p] for p in predictions],
                "confidence": confidences.astype(np.float32),
            })

    return predict_udf


def run_spark_batch_inference(
    input_table: str,
    output_table: str,
    model_path: str,
    feature_cols: list[str],
    spark_config: Optional[dict] = None,
):
    """Run batch inference using Spark with a Pandas UDF."""
    builder = (
        SparkSession.builder
        .appName("BatchInference")
        .config("spark.sql.execution.arrow.pyspark.enabled", "true")
        .config("spark.sql.execution.arrow.maxRecordsPerBatch", "512")
    )

    if spark_config:
        for k, v in spark_config.items():
            builder = builder.config(k, v)

    spark = builder.getOrCreate()

    # Read input data
    print(f"Reading from {input_table}")
    df = spark.table(input_table)
    print(f"Input rows: {df.count():,}")

    # Create the inference UDF
    predict_udf = create_inference_udf(model_path, feature_cols)

    # Apply inference - Spark handles parallelism automatically.
    # The partition count bounds the number of concurrent model instances.
    # Note: coalesce can only reduce the partition count; use repartition
    # if the input has fewer partitions than you want.
    n_partitions = 32  # tune based on cluster size
    result_df = (
        df
        .coalesce(n_partitions)
        .withColumn(
            "prediction_result",
            # Pack the feature columns into one struct so the UDF receives
            # a single pandas DataFrame per batch.
            predict_udf(struct(*[col(c) for c in feature_cols])),
        )
        .withColumn("prediction", col("prediction_result.prediction"))
        .withColumn("confidence", col("prediction_result.confidence"))
        .drop("prediction_result")
    )

    # Write output - overwrite for idempotency
    print(f"Writing to {output_table}")
    result_df.write.mode("overwrite").saveAsTable(output_table)

    final_count = spark.table(output_table).count()
    print(f"Output rows: {final_count:,}")

    spark.stop()
```

Failure Recovery: Checkpointing and Idempotency

A batch inference job scoring 50 million records takes hours. Without checkpointing, any failure restarts from zero. The right design makes every chunk idempotent and tracks progress:

```python
# checkpoint_manager.py
import hashlib
import json
import os
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional


@dataclass
class ChunkCheckpoint:
    job_id: str
    total_chunks: int
    completed_chunks: int
    failed_chunks: list[int]
    started_at: str
    last_updated: str
    partial_output_path: Optional[str] = None


class CheckpointManager:
    def __init__(self, checkpoint_dir: str):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _checkpoint_path(self, job_id: str) -> Path:
        return self.checkpoint_dir / f"{job_id}.json"

    def save(self, checkpoint: ChunkCheckpoint):
        # datetime.utcnow() is deprecated; use an explicit UTC timezone
        checkpoint.last_updated = datetime.now(timezone.utc).isoformat()
        path = self._checkpoint_path(checkpoint.job_id)
        # Write to a temp file and rename, so a crash mid-write cannot
        # leave a truncated checkpoint behind.
        tmp_path = path.with_suffix(".json.tmp")
        with open(tmp_path, "w") as f:
            json.dump(asdict(checkpoint), f, indent=2)
        os.replace(tmp_path, path)

    def load(self, job_id: str) -> Optional[ChunkCheckpoint]:
        path = self._checkpoint_path(job_id)
        if not path.exists():
            return None
        with open(path) as f:
            data = json.load(f)
        return ChunkCheckpoint(**data)

    def clear(self, job_id: str):
        path = self._checkpoint_path(job_id)
        if path.exists():
            path.unlink()

    @staticmethod
    def compute_job_id(input_path: str, model_version: str) -> str:
        """Deterministic job ID based on input + model - same job = same ID."""
        key = f"{input_path}:{model_version}"
        return hashlib.sha256(key.encode()).hexdigest()[:16]
```
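Progress tracking covers only half of the design; the other half is making each chunk write idempotent. A minimal sketch of one way to do that, assuming each chunk is written to its own file at a deterministic path (the file naming scheme and JSON format here are hypothetical choices, not part of the checkpoint manager above):

```python
import json
import os
import tempfile
from pathlib import Path


def chunk_output_path(output_dir: Path, job_id: str, chunk_idx: int) -> Path:
    """Deterministic path: the same chunk always maps to the same file."""
    return output_dir / f"{job_id}-chunk-{chunk_idx:06d}.json"


def write_chunk_idempotent(path: Path, predictions: list[dict]) -> bool:
    """Write a chunk atomically. Returns False if the chunk was already written."""
    if path.exists():
        return False  # a previous attempt finished this chunk; nothing to do
    path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(predictions, f)
    os.replace(tmp_name, path)  # atomic rename within the same filesystem
    return True


def completed_chunks(output_dir: Path, job_id: str) -> set[int]:
    """Recover progress from the files on disk rather than from memory."""
    return {
        int(p.stem.rsplit("-", 1)[1])
        for p in output_dir.glob(f"{job_id}-chunk-*.json")
    }


with tempfile.TemporaryDirectory() as d:
    out = Path(d)
    target = chunk_output_path(out, "job1", 3)
    first = write_chunk_idempotent(target, [{"id": 1, "prediction": 0}])
    second = write_chunk_idempotent(target, [{"id": 1, "prediction": 0}])
    done = completed_chunks(out, "job1")
    print(first, second, done)
```

Because a retried chunk either finds its file already present or replaces nothing, re-running the job after a crash cannot produce duplicate or half-written output.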

Dynamic Batching for GPU Efficiency

The key insight for GPU batch inference: you want to send the largest possible batch to the GPU without exceeding memory. Too small and you waste GPU compute. Too large and you OOM. Dynamic batching accumulates examples up to a threshold and then flushes:

```python
# dynamic_batcher.py
import asyncio
import time
from dataclasses import dataclass, field

import numpy as np
import torch


@dataclass
class PendingRequest:
    features: np.ndarray
    future: asyncio.Future
    arrival_time: float = field(default_factory=time.time)


class DynamicBatcher:
    """
    Accumulates inference requests and flushes them as batches to the GPU.

    This is the same idea as the dynamic batching offered by serving
    frameworks such as Triton. It provides much better GPU utilization
    than serving one request at a time.

    Call start() once from inside a running event loop; otherwise batches
    smaller than max_batch_size are never flushed.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        max_batch_size: int = 128,
        max_wait_ms: float = 10.0,
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.pending: list[PendingRequest] = []
        self.lock = asyncio.Lock()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device).eval()

    def start(self) -> asyncio.Task:
        """Start the background timeout flush loop."""
        return asyncio.create_task(self._flush_loop())

    async def predict(self, features: np.ndarray) -> np.ndarray:
        """Submit a single request for batched processing."""
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        async with self.lock:
            self.pending.append(PendingRequest(features=features, future=future))

            if len(self.pending) >= self.max_batch_size:
                # Batch is full - flush immediately
                asyncio.create_task(self._flush())

        return await future

    async def _flush_loop(self):
        """Background loop that flushes on timeout."""
        while True:
            await asyncio.sleep(self.max_wait_ms / 1000)
            async with self.lock:
                if self.pending:
                    oldest = self.pending[0].arrival_time
                    if (time.time() - oldest) * 1000 >= self.max_wait_ms:
                        await self._flush()

    async def _flush(self):
        """Process up to max_batch_size pending requests as a single GPU batch."""
        if not self.pending:
            return

        batch = self.pending[:self.max_batch_size]
        self.pending = self.pending[self.max_batch_size:]

        features_np = np.stack([r.features for r in batch])
        try:
            with torch.no_grad():
                x = torch.from_numpy(features_np).float().to(self.device)
                logits = self.model(x)
                probs = torch.softmax(logits, dim=-1).cpu().numpy()

            for i, req in enumerate(batch):
                if not req.future.done():
                    req.future.set_result(probs[i])
        except Exception as e:
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)
```
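In an offline job, where there are no per-request deadlines, a simpler way to approach "the largest batch that does not OOM" is to start large and halve on failure. The sketch below illustrates that backoff idea only; `fake_model` raises Python's built-in `MemoryError` to simulate running out of GPU memory, and a real pipeline would catch its framework's out-of-memory error instead. All names are hypothetical.

```python
from typing import Callable, Sequence


def run_with_backoff(
    items: Sequence,
    run_batch: Callable[[Sequence], list],
    initial_batch_size: int,
    min_batch_size: int = 1,
) -> tuple[list, int]:
    """Process all items, halving the batch size whenever run_batch raises MemoryError.

    Returns (results, final_batch_size).
    """
    results: list = []
    batch_size = initial_batch_size
    i = 0
    while i < len(items):
        batch = items[i : i + batch_size]
        try:
            results.extend(run_batch(batch))
            i += len(batch)  # advance only after the batch succeeded
        except MemoryError:
            if batch_size <= min_batch_size:
                raise  # even the smallest batch does not fit
            batch_size = max(min_batch_size, batch_size // 2)
    return results, batch_size


def fake_model(batch: Sequence[int]) -> list[int]:
    """Stand-in for a forward pass that only fits 40 examples in memory."""
    if len(batch) > 40:
        raise MemoryError("simulated OOM")
    return [x * 2 for x in batch]


outputs, final_size = run_with_backoff(list(range(100)), fake_model, initial_batch_size=128)
print(f"Processed {len(outputs)} items, settled on batch size {final_size}")
```

The failed batch is retried at the smaller size rather than skipped, so no records are lost when the batch size shrinks.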

The Netflix Recommendation Pipeline Pattern

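Netflix's batch scoring pipeline follows a specific pattern worth understanding in detail: a nightly job reads from the data warehouse, preprocesses features at scale, scores them on GPU, and writes the ranked lists to Cassandra, where the serving layer looks them up.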

The key insight from this architecture: the serving layer is just a Cassandra lookup. When a user opens Netflix, the API reads a pre-computed ranked list from Cassandra. There is no model inference happening at request time - just a key-value read that takes under 1ms. The entire complexity of model inference, feature computation, and ranking is pushed to the batch pipeline.

This is the defining characteristic of batch inference: it separates computation time from serving time. Complex models that would take 500ms to run in real time produce results that can be served in under 1ms because the results are pre-computed.
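The separation of computation time from serving time can be shown with a toy example. A plain dict stands in for the key-value store, and `slow_rank` is a hypothetical placeholder for an expensive model; neither reflects how any real system is implemented.

```python
import time
import zlib

# Stand-in for a key-value store such as Cassandra: user_id -> ranked item ids.
prediction_store: dict[str, list[str]] = {}


def slow_rank(user_id: str, candidates: list[str]) -> list[str]:
    """Pretend model: an expensive ranking that runs only inside the batch job."""
    time.sleep(0.01)  # simulated inference cost
    # Deterministic pseudo-score so the example is reproducible.
    return sorted(
        candidates,
        key=lambda item: zlib.crc32(f"{user_id}:{item}".encode()),
        reverse=True,
    )


def batch_job(user_ids: list[str], candidates: list[str]) -> None:
    """Offline path: pay the model cost once per user and store the result."""
    for uid in user_ids:
        prediction_store[uid] = slow_rank(uid, candidates)


def serve_homepage(user_id: str, fallback: list[str]) -> list[str]:
    """Request path: a key-value read only, no model call."""
    return prediction_store.get(user_id, fallback)


candidates = ["drama", "comedy", "documentary"]
batch_job(["u1", "u2"], candidates)

start = time.perf_counter()
row = serve_homepage("u1", fallback=candidates)
lookup_ms = (time.perf_counter() - start) * 1000
print(f"Served {row} in {lookup_ms:.3f} ms")
```

The fallback argument matters in practice: users who signed up after the last batch run have no stored row, so the serving layer needs a default.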

Production Engineering Notes

Monitor batch job SLA separately from model quality: A batch inference job that fails silently and writes zero predictions is worse than one that writes stale predictions. Set alerts on: job completion time, output row count (should match input), prediction distribution shift.
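Two of those alerts can be sketched with the standard library alone. The 0.99 row-count ratio and the 0.1 shift threshold used in the example are illustrative defaults, and the label lists are made-up data, not measurements.

```python
from collections import Counter


def row_count_ok(input_count: int, output_count: int, min_ratio: float = 0.99) -> bool:
    """True if the job wrote predictions for (almost) every input row."""
    return input_count > 0 and output_count / input_count >= min_ratio


def label_distribution(labels: list[str]) -> dict[str, float]:
    """Share of each predicted label in a run."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: n / total for label, n in counts.items()}


def max_share_shift(baseline: dict[str, float], current: dict[str, float]) -> float:
    """Largest absolute change in any label's share between two runs."""
    labels = set(baseline) | set(current)
    return max(
        (abs(baseline.get(k, 0.0) - current.get(k, 0.0)) for k in labels),
        default=0.0,
    )


# Made-up example: yesterday's and today's predicted labels.
yesterday = ["safe"] * 90 + ["unsafe"] * 10
today = ["safe"] * 70 + ["unsafe"] * 30

shift = max_share_shift(label_distribution(yesterday), label_distribution(today))
print(f"Row count OK: {row_count_ok(1000, 995)}")
print(f"Max label share shift: {shift:.2f}")
if shift > 0.1:  # hypothetical alert threshold
    print("ALERT: prediction distribution shifted")
```

A zero-row output fails the row-count check immediately, which is exactly the silent failure the note above warns about.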

Cost optimization with spot instances: Batch inference is a perfect workload for spot (preemptible) GPU instances - jobs can checkpoint and resume. AWS EC2 Spot or GCP Preemptible instances are 60-80% cheaper than on-demand. With checkpointing every 100 chunks, a spot instance preemption only loses the progress since the last checkpoint.

Partition-aware writing: When writing predictions back to a database, batch by partition key. Writing 50M rows in random order to Cassandra saturates the write path. Write each partition's data sequentially.

```python
# Write predictions partitioned by user shard for efficient DB writes
df_predictions.repartition("user_shard").write.partitionBy("user_shard") \
    .parquet("s3://predictions/2024-01-15/")
```

:::warning Model Artifact Versioning in Batch Jobs

Your batch pipeline must load the exact model version that passed offline evaluation. A common bug: the pipeline loads the latest model from a path like models/production/model.pt that gets updated by another team mid-run. Half the predictions use model v1 and half use model v2. Use immutable versioned paths: models/v47/model.pt. Log the model version in every output record.

:::

:::danger The Silent Partial Failure

Depending on configuration, Spark and Ray can both report success even if a fraction of partitions fail. A job that "completes" but has 15% of partitions silently skipped is worse than a clear failure - you have predictions for 85% of users and stale data for 15%, and you may not notice for days. Always validate output count against input count. Alert if output_count / input_count < 0.99.

:::

:::danger Stale Predictions After Schema Change

If you add a new feature column to your model but forget to add it to the batch pipeline's feature extraction, the pipeline may still run successfully (for example, by filling the missing feature with zeros), and model quality degrades silently. Feature schema validation at pipeline startup - comparing expected feature names and types against the actual data schema - prevents this class of bug.

:::
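A startup schema check of the kind described above can be very small. This is a minimal sketch that assumes schemas are represented as name-to-dtype-string mappings; the feature names and dtypes in the example are invented.

```python
def validate_feature_schema(expected: dict[str, str], actual: dict[str, str]) -> list[str]:
    """Return a list of human-readable problems; an empty list means the schema matches."""
    problems = []
    for name, dtype in expected.items():
        if name not in actual:
            problems.append(f"missing feature: {name}")
        elif actual[name] != dtype:
            problems.append(
                f"type mismatch for {name}: expected {dtype}, got {actual[name]}"
            )
    return problems


# Hypothetical schemas: what the model was trained on vs. what the pipeline produced.
expected_schema = {"watch_hours": "float32", "device_type": "int64", "hour_of_day": "int64"}
actual_schema = {"watch_hours": "float64", "device_type": "int64"}

problems = validate_feature_schema(expected_schema, actual_schema)
for p in problems:
    print(p)
if problems:
    print("Refusing to start: fix the feature schema first")
```

With pandas, the `actual` mapping could be built as `{c: str(t) for c, t in df.dtypes.items()}`. Failing at startup costs seconds; discovering the mismatch after the run costs the whole batch window.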

Interview Q&A

Q: When should you use batch inference versus real-time inference?

Use batch inference when: the prediction is not needed at the exact moment of the request (recommendations computed overnight, content tagging, credit risk scored nightly), the result can be pre-computed and stored for sub-millisecond lookup, or the model is too slow or expensive to run synchronously at request time. Use real-time inference when: the prediction depends on context only available at request time (fraud detection on a live transaction), latency must be under a few hundred milliseconds, or personalization requires up-to-the-minute data. The economic argument is also significant: batch inference on spot GPU instances can be 10-15x cheaper than a 24/7 real-time serving cluster.

Q: How do you handle failures in a long-running batch inference job?

Checkpoint frequently - every N chunks, write a progress file recording how many chunks have been completed and where the partial output lives. If the job crashes, it reads the checkpoint and resumes from the last completed chunk. Make each chunk write idempotent: writing the same predictions twice should produce identical output. Use a dead-letter queue for chunks that fail repeatedly - do not let a few corrupt records block the entire job. Log and alert when chunk failure rate exceeds 0.1%.
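The retry-then-dead-letter idea can be sketched as follows. The in-memory dict used as the dead-letter queue, the three-attempt limit, and the `score` function are all illustrative assumptions; a production job would persist failed chunks somewhere durable.

```python
from typing import Callable


def process_with_dead_letter(
    chunks: list[list],
    process_chunk: Callable[[list], list],
    max_attempts: int = 3,
) -> tuple[dict[int, list], dict[int, str]]:
    """Process every chunk; chunks that keep failing go to a dead-letter dict.

    Returns (results keyed by chunk index, dead_letter keyed by chunk index).
    """
    results: dict[int, list] = {}
    dead_letter: dict[int, str] = {}
    for idx, chunk in enumerate(chunks):
        for attempt in range(1, max_attempts + 1):
            try:
                results[idx] = process_chunk(chunk)
                break
            except Exception as exc:
                if attempt == max_attempts:
                    # Give up on this chunk but keep the job running.
                    dead_letter[idx] = repr(exc)
    return results, dead_letter


def score(chunk: list) -> list:
    """Toy scorer that fails on corrupt (None) records."""
    if any(v is None for v in chunk):
        raise ValueError("corrupt record")
    return [v * 2 for v in chunk]


chunks = [[1, 2], [3, None], [5, 6]]
results, dead = process_with_dead_letter(chunks, score)
failure_rate = len(dead) / len(chunks)
print(f"Completed chunks: {sorted(results)}, dead-lettered: {sorted(dead)}")
print(f"Chunk failure rate: {failure_rate:.1%}")
```

The failure rate computed at the end is the number to alert on; in this toy run it is far above the 0.1% threshold mentioned above.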

Q: What batch size maximizes GPU utilization and how do you find it?

GPU utilization increases with batch size up to the point where the batch saturates GPU memory bandwidth or compute. To find the optimal size empirically: run a sweep of batch sizes from 1 to the OOM boundary, measuring throughput in examples per second. The throughput curve typically increases rapidly from batch size 1 to 32-64, then plateaus. The optimal batch size is usually at the knee of that curve - maximum throughput per unit of GPU memory. For a typical transformer model on an A10G GPU, this is often batch size 64-256 depending on sequence length.
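One simple way to pick the knee from sweep results is to take the smallest batch size whose throughput is within a few percent of the peak. The 5% tolerance is an arbitrary assumption, and the sweep numbers below are made up for illustration, not measurements from any GPU.

```python
def knee_batch_size(throughput_by_batch: dict[int, float], tolerance: float = 0.05) -> int:
    """Smallest batch size whose throughput is within `tolerance` of the peak."""
    if not throughput_by_batch:
        raise ValueError("empty sweep")
    peak = max(throughput_by_batch.values())
    for bs in sorted(throughput_by_batch):
        if throughput_by_batch[bs] >= (1 - tolerance) * peak:
            return bs
    raise AssertionError("unreachable: the peak itself always qualifies")


# Made-up sweep results: batch size -> examples/sec.
sweep = {1: 900, 8: 6_500, 16: 11_800, 32: 19_000, 64: 24_500, 128: 25_600, 256: 26_000}

knee = knee_batch_size(sweep)
print(f"Knee of the curve: batch size {knee}")
```

Choosing the knee rather than the absolute peak leaves GPU memory headroom for longer inputs at almost no throughput cost.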

Q: How do you prevent a batch inference pipeline from using a different model version than intended?

Never load models from mutable paths like models/production/latest. Always reference versioned, immutable paths like models/v47/model.pt in the pipeline configuration. Store the model version as metadata in each output record so downstream consumers can verify the lineage. Use a model registry that tracks which version is "approved for batch" separately from "approved for real-time" - this prevents an untested model from accidentally being used in batch production. Add an assertion at pipeline startup that the loaded model's hash matches the expected hash from the registry.
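The startup assertion mentioned above can be sketched with `hashlib`. The temporary file stands in for a model artifact, and `registry_hash` stands in for a value that would come from a model registry; both are placeholders.

```python
import hashlib
import tempfile
from pathlib import Path


def file_sha256(path: Path, block_size: int = 1 << 20) -> str:
    """SHA-256 of a file, read in blocks so large artifacts do not fill memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def assert_model_matches(path: Path, expected_sha256: str) -> None:
    """Raise before any inference runs if the artifact is not the expected one."""
    actual = file_sha256(path)
    if actual != expected_sha256:
        raise RuntimeError(
            f"Model hash mismatch for {path}: "
            f"expected {expected_sha256[:12]}..., got {actual[:12]}..."
        )


with tempfile.TemporaryDirectory() as d:
    model_path = Path(d) / "model.pt"
    model_path.write_bytes(b"fake model weights")
    registry_hash = hashlib.sha256(b"fake model weights").hexdigest()

    assert_model_matches(model_path, registry_hash)  # passes silently

    # Simulate another team overwriting a mutable path mid-deployment.
    model_path.write_bytes(b"someone overwrote the file")
    try:
        assert_model_matches(model_path, registry_hash)
        mismatch_detected = False
    except RuntimeError as exc:
        mismatch_detected = True
        print(exc)
```

Running the check once at startup catches a swapped artifact before the first prediction is written, rather than after half the output has been produced.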

Q: Describe how you would design a batch inference pipeline to score 500 million users nightly in under 6 hours.

Assuming the model takes 50 ms per batch of 256 and 8 A100 GPUs are available: first, partition the 500M users into 8 equal shards. Run 8 Ray actors in parallel, one per GPU, so each actor processes 62.5M users. At 50 ms per batch of 256, each actor processes 256/0.05 = 5,120 users/sec, or about 18.4M users/hour, so each actor needs 62.5M / 18.4M ≈ 3.4 hours. Because the 8 actors run in parallel, total wall-clock time is also about 3.4 hours - well within the 6-hour window. The preprocessing step (feature computation in Spark) and the output write step (Cassandra inserts) need to be profiled separately - they often become the bottleneck at this scale. Co-locate the partitioned feature data with the GPU workers to minimize network transfer.
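The capacity estimate in this answer can be reproduced in a few lines. The function is a sketch that assumes perfectly even sharding and ignores preprocessing and write time, as the answer itself notes.

```python
def hours_to_score(
    total_users: int,
    num_gpus: int,
    batch_size: int,
    seconds_per_batch: float,
) -> float:
    """Wall-clock hours for the GPU step, assuming even shards and no stalls."""
    users_per_sec_per_gpu = batch_size / seconds_per_batch
    users_per_gpu = total_users / num_gpus
    return users_per_gpu / users_per_sec_per_gpu / 3600


hours = hours_to_score(
    total_users=500_000_000,
    num_gpus=8,
    batch_size=256,
    seconds_per_batch=0.05,
)
print(f"Estimated GPU time: {hours:.2f} hours")
print(f"Fits the 6-hour window: {hours < 6}")
```

Changing `num_gpus` or `seconds_per_batch` shows how much slack the design has: halving the GPU count to 4 doubles the estimate to roughly 6.8 hours and misses the window.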
