torch.compile and XLA
The 40% That Was Hiding in Plain Sight
In December 2022, the PyTorch team announced PyTorch 2.0 with one headline number: torch.compile achieves a 43% average speedup on its standard benchmark suites without changing model code. One line of code. No custom CUDA kernels. No quantization. No model architecture changes. Just wrapping your existing model with torch.compile(model).
The benchmarks were real. Production deployments at Meta, which drove this work, saw 30-50% throughput improvements on transformer models. For a company spending hundreds of millions of dollars on GPU compute, 30% is not a small number. It is the cost of an entire data center.
What changed? Not the operations. Not the hardware. The change was that PyTorch could now look at a sequence of operations - attention, layer norm, residual add, MLP - and instead of executing them one by one with separate GPU kernel launches, fuse them into fewer, larger kernels that read and write GPU memory far fewer times. The bottleneck in modern GPU computing is almost never compute; it is memory bandwidth. Every time you read a tensor from GPU HBM (high-bandwidth memory) and write it back, you pay a cost. Fusion eliminates intermediate reads and writes.
Understanding torch.compile deeply matters beyond just calling it. You need to understand when it helps, when it hurts, how to debug graph breaks (which silently defeat the optimization), and how the choice of compile mode (default, reduce-overhead, max-autotune) affects your deployment. You need to understand why some models compile successfully and others hit recompilation loops. And you need to understand XLA - the Google compiler system that takes the same ideas further and powers all of JAX and TensorFlow on TPUs.
This lesson traces the entire stack: from your Python model code, through Python bytecode manipulation, through graph capture and differentiation, through Triton kernel generation, to the actual machine instructions running on your GPU. It is a lot of machinery. Understanding it will make you a significantly better ML engineer.
Why This Exists - The Eager Mode Tax
Before torch.compile, PyTorch worked in "eager mode." Every operation executed immediately when called. y = x + 1 ran the CUDA kernel right now, returned a tensor, and moved on. This design was the reason PyTorch beat TensorFlow in researcher adoption: you could use Python debuggers, print tensors anywhere, write conditional logic based on tensor values, and everything worked intuitively.
The cost of eager mode is visible in a profiler trace. For a transformer forward pass, you see hundreds of individual kernel launches: one for each matrix multiply, one for each softmax, one for each layer norm, one for each element-wise operation. Between kernels you pay launch latency and host-side dispatch overhead, and crucially, each result tensor is written to HBM and then read back from HBM by the next kernel.
The math is brutal. A modern GPU like the H100 has 3.35 TB/s of HBM bandwidth and roughly 1,000 TFLOPS of dense BF16 compute (nearly double that with structured sparsity), so a kernel needs on the order of 300 FLOPs per byte moved just to keep the compute units busy. The arithmetic intensity (FLOPs per byte of memory access) of many transformer operations is far below that - especially element-wise operations like GeLU or residual adds. These operations are completely bandwidth-bound: the GPU is waiting for memory, not compute.
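To make that concrete, here is a back-of-the-envelope sketch (the hardware numbers are illustrative, taken from the H100 figures above) comparing the GPU's compute-to-bandwidth ratio with the arithmetic intensity of a residual add:

```python
# Back-of-the-envelope roofline arithmetic (illustrative H100 SXM numbers).
hbm_bandwidth = 3.35e12      # bytes/s of HBM bandwidth
bf16_compute = 0.99e15       # FLOP/s of dense BF16 compute

balance = bf16_compute / hbm_bandwidth
print(f"Compute-bound only above ~{balance:.0f} FLOPs per byte moved")  # ~295

# BF16 residual add: 1 FLOP per element, 3 tensors touched x 2 bytes each
residual_add_intensity = 1 / (3 * 2)
print(f"Residual add: {residual_add_intensity:.2f} FLOPs/byte -> purely bandwidth-bound")
```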
Fusion is the answer. If you can fuse layer_norm + matmul + gelu + residual_add into a single kernel, you read the input tensor once, do all the computation, write the output once. You eliminate three intermediate tensor materializations. For a large model, this alone accounts for most of the 43% speedup.
TensorFlow 1.x solved this with static graphs: tf.Session, tf.placeholder. You described the entire computation graph ahead of time, TensorFlow compiled it, and then you ran it. Fusion was trivial because the compiler had the full graph. But the developer experience was terrible - no Python debuggers, no dynamic shapes, imperative-style control flow required tf.cond and tf.while_loop.
torch.compile attempts to have it both ways: write code in eager mode (full Python flexibility), get static graph performance (fusion, operator scheduling, kernel optimization). The mechanism for achieving this is genuinely novel.
Historical Context - The Road to torch.compile
TensorFlow's graph execution mode established the value of ML compilation in 2015. XLA (Accelerated Linear Algebra), introduced in 2017, took the idea further: a compiler that fuses HLO (High Level Operations) graphs into efficient GPU/TPU code.
PyTorch went the opposite direction: eager execution from day one (2016). By 2020, PyTorch dominated ML research and was rapidly taking production workloads. But the performance gap versus TF/XLA in production was real.
Several approaches were tried: torch.jit.trace (graph capture by running the model on example inputs), TorchScript via torch.jit.script (compiling a restricted, statically analyzable subset of Python), and torch.fx (symbolic tracing that builds a graph by intercepting operator calls). All required compromising the programming model.
torch.compile in PyTorch 2.0 (2023) was the breakthrough. The key insight was TorchDynamo: instead of restricting what you can write, intercept Python bytecode execution at the CPython level. This lets torch.compile capture graphs from arbitrary Python code - including loops, conditionals, and function calls - by watching what PyTorch operations are called and building a graph from that observation.
JAX, developed by Google Brain (2018), took an even more aggressive approach: make JIT compilation the primary programming model from the start. jax.jit transforms entire functions through XLA. The restriction is that JIT-traced functions must be functionally pure (no side effects, static shapes for maximum performance). The reward is aggressive XLA optimization including autotuned hardware-specific kernels.
The torch.compile Architecture
TorchDynamo - The Frontend
TorchDynamo is not a static analyzer. It does not read your source code and parse it into a graph. Instead, it hooks into CPython's frame evaluation mechanism (PEP 523, _PyEval_EvalFrameDefault) to intercept bytecode execution.
When TorchDynamo encounters a torch.compile-decorated function, it replaces CPython's frame evaluator for that function with its own. As the function executes, TorchDynamo:
- Watches the bytecode instructions being executed
- When it sees PyTorch tensor operations, records them symbolically (with placeholder tensors instead of real data)
- When it encounters Python control flow (if/else, loops) that depends on tensor values, it hits a "graph break" - it cannot capture this part statically
- Emits an FX graph (torch.fx's graph intermediate representation) for the captured operations
The genius of this approach: you do not need to rewrite your model. TorchDynamo handles arbitrary Python code by capturing what it can and falling back to eager execution for what it cannot.
import torch
import torch._dynamo as dynamo
# Simple model to illustrate torch.compile
class SimpleMLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(512, 1024)
self.fc2 = torch.nn.Linear(1024, 512)
self.act = torch.nn.GELU()
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
model = SimpleMLP().cuda()
x = torch.randn(32, 512, device='cuda')
# Basic compilation
compiled_model = torch.compile(model)
output = compiled_model(x) # First call: compilation happens here
# Explain what torch.compile would do without running it
explanation = dynamo.explain(model, x)
print(explanation)
# Shows: graphs captured, graph breaks, operations in each graph
# Check if the compilation was "full graph" (no breaks)
print(f"Number of graphs: {explanation.graph_count}")
print(f"Graph break reasons: {explanation.break_reasons}")
Understanding Graph Breaks
Graph breaks are the most common source of underperformance with torch.compile. When dynamo cannot trace through a piece of code, it:
- Ends the current graph
- Emits the captured graph for compilation
- Falls back to eager execution for the untraceable part
- Starts a new graph after
Each break introduces overhead (kernel launch synchronization) and defeats fusion across the break.
import torch
model = torch.nn.Linear(10, 10).cuda()
# CAUSES GRAPH BREAK: print inside forward
class BadModel1(torch.nn.Module):
def forward(self, x):
print(f"Input shape: {x.shape}") # print = graph break
return torch.relu(x)
# CAUSES GRAPH BREAK: tensor.item() - converts to Python scalar
class BadModel2(torch.nn.Module):
def forward(self, x):
if x.mean().item() > 0: # .item() forces graph break
return torch.relu(x)
return torch.sigmoid(x)
# PROBLEMATIC: Python loop over a tensor dimension
class BadModel3(torch.nn.Module):
    def forward(self, x):
        results = []
        for i in range(x.shape[0]):   # Loop is unrolled and specialized on x.shape[0];
            results.append(x[i] * 2)  # a new shape forces recompilation (or breaks)
        return torch.stack(results)
# NO GRAPH BREAK: static control flow
class GoodModel1(torch.nn.Module):
def __init__(self, use_bias):
super().__init__()
self.use_bias = use_bias # Set at init, not from tensor data
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
out = self.linear(x)
if self.use_bias: # Static Python bool - NOT a graph break
out = out + 1.0
return out
# NO GRAPH BREAK: tensor operations without .item()
class GoodModel2(torch.nn.Module):
def forward(self, x):
mask = (x > 0).float()
return x * mask # Element-wise conditional - stays in graph
# Detect graph breaks with explain()
import torch._dynamo as dynamo
bad2 = BadModel2()
x = torch.randn(10, 10)
explanation = dynamo.explain(bad2, x)
print(f"Graphs due to breaks: {explanation.graph_count}")
for reason in explanation.break_reasons:
print(f" Break: {reason.reason}")
AOTAutograd - Capturing Forward and Backward Together
AOTAutograd (Ahead-of-Time Autograd) takes the FX graph from TorchDynamo and transforms it in a critical way: it traces through both the forward pass AND the backward pass together, before any compilation happens.
Why this matters: if you compile the forward pass without considering the backward, the compiler may make choices that are inefficient for the backward (like not saving activations that the backward will need, or saving too many). AOTAutograd produces a joint graph that includes both forward and backward operations. The backend compiler can then make holistic optimization decisions.
import torch
from torch._functorch.aot_autograd import aot_module_simplified
# See what AOTAutograd produces
def my_fn(x, w):
y = torch.mm(x, w)
return torch.relu(y)
x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)
# AOTAutograd produces two graphs:
# 1. "Forward with saved residuals" - forward pass saving only what backward needs
# 2. "Backward" - backward pass using saved residuals
# These are compiled separately but designed together
def fw_compiler(gm, example_inputs):
print("=== Forward Graph ===")
gm.print_readable()
return gm
def bw_compiler(gm, example_inputs):
print("=== Backward Graph ===")
gm.print_readable()
return gm
# In practice, the backend (Inductor) handles compilation;
# here the stub compilers just print the decomposed forward/backward graphs.
compiled = aot_module_simplified(
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()),
    (x,),
    fw_compiler=fw_compiler,
    bw_compiler=bw_compiler,
)
out = compiled(x)         # runs the AOT-compiled forward
out.sum().backward()      # triggers the AOT-compiled backward as well
TorchInductor - The Backend
TorchInductor is the default backend for torch.compile. It takes the FX graph (post-AOTAutograd) and generates:
- Triton kernels for GPU operations (Python-based GPU kernel language)
- C++/OpenMP code for CPU operations
The key optimization TorchInductor performs is operator fusion. Multiple FX graph nodes that could be separate GPU kernels are fused into one Triton kernel.
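One way to see what Inductor actually emitted is its output_code artifact log - a minimal sketch, with the caveat that the logging knobs and output format vary across PyTorch versions:

```python
import torch

# Print the Triton/C++ code Inductor generates when it compiles (assumes a CUDA device)
torch._logging.set_logs(output_code=True)

@torch.compile
def pointwise_chain(x):
    return torch.nn.functional.gelu(x) * x + 1.0  # pointwise chain - should fuse into one kernel

pointwise_chain(torch.randn(1024, 1024, device="cuda"))
```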
import torch
import time
# Demonstrate fusion benefit: fused vs unfused operations
def unfused_ops(x, w, b):
"""Typical eager mode execution - separate kernels."""
y = torch.nn.functional.layer_norm(x, x.shape[-1:])
z = torch.nn.functional.linear(y, w, b)
a = torch.nn.functional.gelu(z)
return a + x # Residual add
# torch.compile will fuse these into fewer Triton kernels
compiled_ops = torch.compile(unfused_ops)
x = torch.randn(32, 512, 512, device='cuda', dtype=torch.float16)
w = torch.randn(512, 512, device='cuda', dtype=torch.float16)
b = torch.randn(512, device='cuda', dtype=torch.float16)
# Warmup compiled version
for _ in range(5):
compiled_ops(x, w, b)
torch.cuda.synchronize()
N = 100
# Time eager
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(N): unfused_ops(x, w, b)
torch.cuda.synchronize()
eager_time = (time.perf_counter() - t) / N * 1000
# Time compiled
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(N): compiled_ops(x, w, b)
torch.cuda.synchronize()
compiled_time = (time.perf_counter() - t) / N * 1000
print(f"Eager: {eager_time:.2f}ms")
print(f"Compiled: {compiled_time:.2f}ms")
print(f"Speedup: {eager_time/compiled_time:.2f}x")
Compile Modes and fullgraph
import torch
import torch.nn as nn
class TransformerLayer(nn.Module):
def __init__(self, d_model, nhead):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.GELU(),
nn.Linear(d_model * 4, d_model)
)
def forward(self, x):
attn_out, _ = self.attn(x, x, x)
x = self.norm1(x + attn_out)
ff_out = self.ff(x)
return self.norm2(x + ff_out)
model = TransformerLayer(512, 8).cuda()
x = torch.randn(16, 64, 512, device='cuda')
# Mode 1: default - balance compilation time and runtime performance
model_default = torch.compile(model, mode="default")
# Mode 2: reduce-overhead - minimize kernel launch overhead (good for small batches)
model_reduced = torch.compile(model, mode="reduce-overhead")
# Uses CUDA graphs to batch kernel launches into a single replay
# Mode 3: max-autotune - maximize runtime performance (slow compilation)
# Profiles multiple kernel implementations, picks the fastest
model_autotuned = torch.compile(model, mode="max-autotune")
# fullgraph=True - fail hard if there are any graph breaks
# Use this to ensure the entire model compiles
try:
model_fullgraph = torch.compile(model, fullgraph=True)
y = model_fullgraph(x)
print("Full graph compilation succeeded")
except Exception as e:
print(f"Graph break detected: {e}")
# dynamic=True - use symbolic shapes so new input shapes do not force recompiles
# Default (dynamic=None): specialize on the first shape seen, then mark a dimension
# dynamic when its size changes (exact behavior varies by PyTorch version)
model_dynamic = torch.compile(model, dynamic=True)
# Debugging with backend="eager" - runs graph through eager PyTorch
# Verifies graph capture is correct without actual compilation
model_debug = torch.compile(model, backend="eager")
# See the generated Triton code (Inductor only)
import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"
# Subsequent compiles write debug artifacts (a torch_compile_debug/ directory by default)
Symbolic Shapes and Dynamic Shapes
By default, torch.compile specializes the generated code on the exact input shapes it sees. This is called "static shapes" mode. For a model that processes batches of varying size, this triggers recompilations as new shapes appear (recent releases automatically mark a dimension dynamic after its first size change, but each recompilation still costs time).
dynamic=True enables symbolic shapes: the compiler uses symbolic variables for certain dimensions (batch size, sequence length) and generates code that works for any value of those dimensions.
import torch
class DynamicModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(512, 256)
def forward(self, x):
return torch.relu(self.linear(x))
model = DynamicModel().cuda()
# Static compilation: force shape specialization - recompiles for each new batch size
static_model = torch.compile(model, dynamic=False)
static_model(torch.randn(32, 512, device='cuda'))  # Compiles for (32, 512)
static_model(torch.randn(64, 512, device='cuda'))  # Recompiles for (64, 512)!
# Dynamic compilation: handles any batch size
dynamic_model = torch.compile(model, dynamic=True)
dynamic_model(torch.randn(32, 512, device='cuda')) # Compiles symbolically
dynamic_model(torch.randn(64, 512, device='cuda')) # No recompilation
dynamic_model(torch.randn(128, 512, device='cuda')) # No recompilation
# Monitor recompilations
import logging
torch._logging.set_logs(recompiles=True)
# Now prints a message whenever a recompilation happens with the reason
torch.export for Deployment
torch.export produces a fully serializable, portable representation of your model's traced computation graph. It is intended for production deployment where you want guaranteed no graph breaks, a fully traced computation, and an artifact deployable to non-Python runtimes.
import torch
from torch.export import export, ExportedProgram
class DeployableModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(256, 512),
torch.nn.ReLU(),
torch.nn.Linear(512, 10)
)
def forward(self, x):
return self.layers(x)
model = DeployableModel()
x = torch.randn(1, 256)
# Export with dynamic batch size
from torch.export import Dim
batch_dim = Dim("batch", min=1, max=256)
dynamic_shapes = {"x": {0: batch_dim}}
exported: ExportedProgram = export(
model,
args=(x,),
dynamic_shapes=dynamic_shapes
)
# The exported program is fully serializable
torch.export.save(exported, "model.pt2")
loaded = torch.export.load("model.pt2")
# Verify equivalence
y1 = exported.module()(x)
y2 = loaded.module()(x)
print(f"Max diff: {(y1 - y2).abs().max():.2e}") # ~0
# Inspect the exported graph
print(exported.graph_module.print_readable())
# Run the exported model
y = exported.module()(torch.randn(64, 256)) # Dynamic batch works
print(f"Output shape: {y.shape}")
XLA - Accelerated Linear Algebra
XLA is Google's compiler for ML computations. It was designed primarily for TPUs but works on GPUs too. JAX uses XLA as its exclusive backend. TensorFlow uses XLA optionally.
HLO - The XLA IR
HLO (High Level Operations) is the intermediate representation XLA operates on. It is a computation graph: nodes are operations (Dot, Add, Reshape, etc.), edges are tensors.
Example HLO for a simple linear layer + ReLU:
HloModule simple_model
ENTRY main {
input = f32[32,512]{1,0} parameter(0)
weights = f32[512,256]{1,0} parameter(1)
bias = f32[256]{0} parameter(2)
dot = f32[32,256]{1,0} dot(input, weights),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
broadcast_bias = f32[32,256]{1,0} broadcast(bias), dimensions={1}
add = f32[32,256]{1,0} add(dot, broadcast_bias)
zero = f32[] constant(0)
zero_broadcast = f32[32,256]{1,0} broadcast(zero), dimensions={}
ROOT relu = f32[32,256]{1,0} maximum(add, zero_broadcast)
}
XLA's compiler fuses the add and maximum into a single kernel that reads the dot product once and applies both operations.
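You rarely write HLO by hand - JAX emits it for you. A minimal sketch of inspecting what jax.jit hands to XLA (the .lower()/.as_text() API shown here is from recent JAX releases and may print StableHLO rather than classic HLO text):

```python
import jax
import jax.numpy as jnp

def linear_relu(x, w, b):
    return jax.nn.relu(jnp.dot(x, w) + b)

x = jnp.ones((32, 512))
w = jnp.ones((512, 256))
b = jnp.ones((256,))

lowered = jax.jit(linear_relu).lower(x, w, b)  # trace and lower without executing
print(lowered.as_text())                       # the IR handed to XLA
compiled = lowered.compile()                   # runs XLA's fusion/optimization passes
```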
jax.jit - XLA JIT in Practice
import jax
import jax.numpy as jnp
import numpy as np
import time
# JAX uses functional style - functions must be pure (no side effects)
# No in-place operations: jnp arrays are immutable
def forward_pass(params, x):
"""A simple MLP forward pass."""
w1, b1, w2, b2 = params
h = jnp.dot(x, w1) + b1
h = jax.nn.relu(h)
return jnp.dot(h, w2) + b2
# jax.jit compiles the function through XLA
jit_forward = jax.jit(forward_pass)
# Initialize parameters
key = jax.random.PRNGKey(0)
w1 = jax.random.normal(key, (512, 1024)) * 0.02
b1 = jnp.zeros(1024)
w2 = jax.random.normal(key, (1024, 256)) * 0.02
b2 = jnp.zeros(256)
params = (w1, b1, w2, b2)
x = jax.random.normal(key, (32, 512))
# First call: XLA compilation
t = time.perf_counter()
y = jit_forward(params, x)
y.block_until_ready() # JAX is async by default - block to get real timing
first_call = time.perf_counter() - t
# Second call: compiled code
t = time.perf_counter()
y = jit_forward(params, x)
y.block_until_ready()
second_call = time.perf_counter() - t
print(f"First call (compile): {first_call*1000:.1f}ms")
print(f"Second call (run): {second_call*1000:.2f}ms")
# JAX makes tracing observable - the Python body runs only while tracing
def traced_fn(x):
    print(f"Python sees: {x}")  # x is an abstract tracer here, not concrete data
    return jnp.sum(x ** 2)
traced_jit = jax.jit(traced_fn)
y1 = traced_jit(jnp.array([1.0, 2.0]))  # Prints something like Traced<ShapedArray(float32[2])>
y2 = traced_jit(jnp.array([3.0, 4.0]))  # No print - reuses the compiled code
XLA Device Setup and Comparison with torch.compile
# JAX XLA device management
import jax
import jax.numpy as jnp
import numpy as np  # used below to build host-side arrays
# List available devices
print(jax.devices())  # e.g. [CpuDevice(id=0)] or [GpuDevice(id=0, ...)]
print(jax.default_backend()) # 'cpu', 'gpu', or 'tpu'
# Move computation to GPU
@jax.jit
def gpu_matmul(a, b):
return jnp.dot(a, b)
# Data on the right device
a = jax.device_put(jnp.ones((1024, 1024)), jax.devices('gpu')[0])
b = jax.device_put(jnp.ones((1024, 1024)), jax.devices('gpu')[0])
c = gpu_matmul(a, b)
# Parallel execution across multiple GPUs
@jax.pmap # Parallelizes over the first axis across all GPUs
def parallel_forward(params, x):
return jnp.dot(x, params['w']) + params['b']
# Replicate params across GPUs, shard x across GPUs
n_gpus = len(jax.devices('gpu'))
params_replicated = jax.device_put_replicated(
{'w': jnp.eye(512), 'b': jnp.zeros(512)},
jax.devices('gpu')
)
x_sharded = jnp.array(
    np.random.randn(n_gpus, 32, 512)  # First axis = number of GPUs
)
out = parallel_forward(params_replicated, x_sharded)  # out has shape (n_gpus, 32, 512)
# Comparing JAX vs PyTorch compilation styles
# JAX approach: explicit JIT, functional, static shapes by default
import jax
import jax.numpy as jnp
@jax.jit
def jax_model(params, x):
for w, b in zip(params['weights'], params['biases']):
x = jax.nn.relu(jnp.dot(x, w) + b)
return x
# PyTorch approach: implicit JIT via torch.compile, OOP, dynamic-friendly
import torch
class TorchModel(torch.nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def forward(self, x):
for layer in self.layers:
x = torch.relu(layer(x))
return x
compiled_torch = torch.compile(TorchModel([torch.nn.Linear(512, 512) for _ in range(4)]))
# Key differences:
# JAX: functional, pure functions, explicit PRNG keys, explicit device placement
# PyTorch: OOP, stateful modules, integrated autograd, easier debugging
# JAX lowers everything through XLA; torch.compile lowers through Inductor to Triton/C++
# JAX is more aggressive about compilation; torch.compile is more permissive
tf.function vs jax.jit vs torch.compile
# TensorFlow tf.function - XLA optional
import tensorflow as tf
import numpy as np
@tf.function
def tf_forward(x, w, b):
return tf.nn.relu(tf.matmul(x, w) + b)
# With XLA compilation
@tf.function(jit_compile=True) # Enables XLA
def tf_xla_forward(x, w, b):
return tf.nn.relu(tf.matmul(x, w) + b)
x = tf.constant(np.random.randn(32, 512).astype(np.float32))
w = tf.constant(np.random.randn(512, 256).astype(np.float32))
b = tf.constant(np.random.randn(256).astype(np.float32))
# First call: compilation
y1 = tf_forward(x, w, b)
y2 = tf_xla_forward(x, w, b) # More aggressive XLA compilation
# Summary table:
"""
Feature                  | tf.function   | jax.jit           | torch.compile
-------------------------|---------------|-------------------|----------------------
Frontend                 | TF graph      | JAX/NumPy         | PyTorch
Backend                  | XLA optional  | XLA               | Inductor (Triton/C++)
Data-dep. control flow   | AutoGraph     | No (use lax.cond) | Yes (graph break)
Dynamic shapes           | Limited       | Recompile         | Optional (dynamic=True)
Eager fallback           | No            | No                | Yes (graph break)
Debug story              | Poor          | Fair              | Best
Production maturity      | Very high     | High              | Growing fast
"""
When torch.compile Helps vs Hurts
import torch
import torch.nn as nn
import time
def benchmark(model, input_fn, n_warmup=10, n_runs=100, device='cuda'):
"""Proper GPU benchmark with synchronization."""
model = model.to(device)
compiled = torch.compile(model)
# Warmup
for _ in range(n_warmup):
x = input_fn()
_ = compiled(x)
torch.cuda.synchronize()
# Eager timing
t = time.perf_counter()
for _ in range(n_runs):
x = input_fn()
y = model(x)
torch.cuda.synchronize()
eager_ms = (time.perf_counter() - t) / n_runs * 1000
# Compiled timing (after warmup)
t = time.perf_counter()
for _ in range(n_runs):
x = input_fn()
y = compiled(x)
torch.cuda.synchronize()
compiled_ms = (time.perf_counter() - t) / n_runs * 1000
return eager_ms, compiled_ms
# CASE 1: Large transformer layer - torch.compile helps significantly
class LargeTransformer(nn.Module):
def __init__(self):
super().__init__()
self.attn = nn.MultiheadAttention(1024, 16, batch_first=True)
self.norm1 = nn.LayerNorm(1024)
self.ff = nn.Sequential(
nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)
)
self.norm2 = nn.LayerNorm(1024)
def forward(self, x):
x = self.norm1(x + self.attn(x, x, x)[0])
return self.norm2(x + self.ff(x))
# CASE 2: Single large matrix multiply - compile does NOT help
class SingleMatmul(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(4096, 4096)
def forward(self, x):
return self.linear(x) # Single op, nothing to fuse
# CASE 3: Lots of tiny ops - compile helps a lot
class TinyOps(nn.Module):
def forward(self, x):
# Many element-wise ops - lots of intermediate tensors
x = x * 2.0
x = x + 1.0
x = torch.sigmoid(x)
x = x * x
x = torch.tanh(x)
x = x / (x.abs() + 1.0)
return x
# WHEN compile HURTS:
# - Very small models (compilation overhead > runtime benefit)
# - Models with many graph breaks (overhead without fusion benefit)
# - Dynamic shapes with frequent recompilation
# - One-shot inference (pay compile cost, run once)
print("""
torch.compile helps most when:
1. Model has many fusible element-wise ops (norm, activation, residual)
2. Batch size is large (amortizes kernel launch overhead)
3. Model runs thousands of forward passes (amortizes compile cost)
4. Static shapes (no recompilation)
torch.compile helps least when:
1. Single large compute op (matmul, conv) - already optimized by cuBLAS
2. Many graph breaks - defeats fusion
3. Dynamic shapes with variable sequence lengths
4. Small models / small batch sizes
""")
:::warning Graph Break Silent Performance Loss
The most dangerous torch.compile failure mode is silent graph breaks. Your code runs correctly. The compiler runs. But due to a .item() call, a print statement, or a Python built-in that dynamo cannot trace, the graph is split into 5 pieces with eager execution between them. You get 10% speedup instead of 40%. Always run torch._dynamo.explain(model, example_input) to check the number of graphs and break reasons before deploying compiled models.
:::
:::danger Recompilation Loops
If torch.compile sees a new input shape on every call, it recompiles on every call. This is catastrophically slow. Common causes: variable-length sequence processing without padding to fixed length, training loops where batch size changes, or using dynamic=False (default) with truly variable shapes. Monitor recompilations with torch._logging.set_logs(recompiles=True).
:::
Production Engineering Notes
Compilation budget in training. For large models trained for days, the one-time compilation cost is irrelevant. For short experiments or hyperparameter sweeps, compilation cost may dominate. Use torch.compile in production training runs; skip it in 5-minute experiments.
Mode selection. mode="reduce-overhead" uses CUDA graphs which record and replay kernel sequences. This eliminates the CPU overhead of launching each kernel individually - critical for small batch inference where CPU overhead is the bottleneck. It requires static input shapes.
Saving and loading compiled models. torch.compile artifacts are not saved by torch.save(model.state_dict()); the next process will recompile. For deployment, use torch.export, which serializes the captured graph. For training checkpoints, save the eager model's state_dict and recompile after loading, as sketched below.
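A minimal sketch of that checkpoint pattern (the model and file name are illustrative):

```python
import torch

model = torch.nn.Linear(512, 512).cuda()   # keep a handle to the eager module
compiled = torch.compile(model)            # the compiled wrapper shares the same parameters

# ... train using `compiled` ...
torch.save(model.state_dict(), "ckpt.pt")  # save from the eager module: plain keys, no compile state

# Later, in a fresh process:
restored = torch.nn.Linear(512, 512).cuda()
restored.load_state_dict(torch.load("ckpt.pt"))
restored = torch.compile(restored)         # compilation happens again on the first forward
```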
Debugging compiled models. Use backend="eager" to run the captured graph without optimization - this verifies graph capture correctness. Use backend="aot_eager" to also test AOTAutograd. Use TORCH_COMPILE_DEBUG=1 to write detailed debug files. Use torch._dynamo.explain() before every deployment to catch graph breaks.
import torch
# Production checklist for torch.compile deployment
model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8).cuda().eval()  # eval() so dropout is off for the correctness check
example_input = torch.randn(10, 32, 512, device='cuda')  # (seq, batch, d_model)
# Step 1: Check for graph breaks
import torch._dynamo as dynamo
explanation = dynamo.explain(model, example_input)
if explanation.graph_count > 1:
print(f"WARNING: {explanation.graph_count} graphs due to breaks:")
for r in explanation.break_reasons:
print(f" {r.reason} at {r.user_stack}")
# Step 2: Verify correctness with eager backend
model_eager_check = torch.compile(model, backend="eager")
y_eager = model_eager_check(example_input)
y_normal = model(example_input)
assert torch.allclose(y_eager, y_normal, atol=1e-5), "Compilation changed output!"
# Step 3: Compile with appropriate mode
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
# Step 4: Warmup to trigger compilation
for _ in range(10):
_ = compiled(example_input)
torch.cuda.synchronize()
print("Warmup complete, model ready for serving")
# Step 5: Monitor for recompilation in serving
torch._logging.set_logs(recompiles=True)
Interview Q&A
Q1: Explain what torch.compile does at a high level. What are the three main components?
torch.compile is a JIT compiler for PyTorch models that captures a graph of operations, optimizes them (primarily through fusion), and generates efficient GPU code. The three main components are:
- TorchDynamo: the frontend. It hooks into CPython's frame evaluation (PEP 523) to intercept bytecode execution. Rather than requiring you to write code in a special restricted format, Dynamo watches your Python code run and records PyTorch operations as an FX graph. When it encounters Python control flow it cannot trace (like if tensor.item() > 0), it hits a "graph break" and falls back to eager execution for that portion.
- AOTAutograd: captures the forward and backward pass together before compilation. This allows the backend compiler to make holistic decisions - for example, understanding which tensors from the forward pass need to be saved for the backward, avoiding unnecessary saves or enabling fused forward-backward kernels.
- TorchInductor: the backend. Takes the FX graph and generates Triton kernels (for GPU) or C++/OpenMP code (for CPU). The key optimization is operator fusion: multiple operations that would be separate GPU kernel launches are fused into a single kernel that reads input once and computes all operations before writing output.
Q2: What is a graph break in torch.compile? What causes them and how do you detect and fix them?
A graph break occurs when TorchDynamo cannot continue tracing a computation graph through a piece of Python code. When a break happens, Dynamo: ends the current graph, compiles what it has captured so far, falls back to eager PyTorch execution for the untraceable portion, then starts a new graph after.
Common causes:
- tensor.item() - converts a tensor to a Python scalar, making control flow data-dependent
- print(tensor) - accesses tensor data for display
- Calls to unsupported Python libraries
- Data-dependent control flow: if tensor.mean() > 0.5: ...
- Dynamic shapes with unknown dimensions
Detection: torch._dynamo.explain(model, example_input) - prints the number of graphs, the break reasons, and the stack traces where breaks occur.
Fixes:
- Replace tensor.item() with tensor operations where possible
- Remove debug prints from forward methods
- Replace data-dependent branching with masked tensor operations
- Use fullgraph=True, which raises an error on any break, forcing you to fix them
Q3: What is XLA and how does it differ from TorchInductor?
XLA (Accelerated Linear Algebra) is Google's compiler for ML computations, developed primarily for TPUs but also supporting GPUs and CPUs. It:
- Takes HLO (High Level Operations) as its IR - a graph of operations
- Applies algebraic simplifications, operation fusion, layout optimization
- Generates code for TPU systolic arrays, GPU PTX, or CPU LLVM
- Is the exclusive backend for JAX and optional for TensorFlow
Key differences from TorchInductor:
| Aspect | XLA | TorchInductor |
|---|---|---|
| Primary language | JAX, TensorFlow | PyTorch |
| IR format | HLO graphs | FX graphs |
| Code generation | Custom backends per target | Triton (GPU), C++ (CPU) |
| Dynamic shapes | Recompiles by default | Optional dynamic mode |
| Control flow | Static only in JIT path | Python allowed (graph breaks) |
| TPU support | First-class | Not supported |
| Maturity | ~7 years production | ~2 years |
XLA is more aggressive about requiring static shapes and pure functions. TorchInductor is more permissive and integrates better with Python.
Q4: Why does operator fusion improve GPU performance? Give a concrete example.
Modern GPUs have two bottlenecks: compute (TFLOPS) and memory bandwidth (TB/s). For many ML operations - especially element-wise ones like activations, normalization, residual adds - the bottleneck is memory bandwidth, not compute. The operation is trivially fast once the data is in registers; the cost is reading from and writing to HBM (high bandwidth memory).
Without fusion, layer_norm(x) followed by linear(layer_norm_output) does:
- Read x from HBM for the layer_norm computation
- Write the layer_norm result to HBM
- Read the layer_norm result from HBM for the linear
- Write the linear result to HBM
With fusion:
- Read x from HBM
- Compute layer_norm AND linear in registers/shared memory
- Write the final result to HBM
Four HBM accesses become two. For large tensors and bandwidth-bound operations, this ~2x reduction in memory traffic translates directly to ~2x speedup.
For a model with 100 such fusible operation pairs, fusion is the difference between 100 kernel launches reading 200 tensors from HBM vs 50 kernel launches reading 100 tensors from HBM.
Q5: How does jax.jit handle tracing? What is a traced value and what are the implications for Python control flow?
jax.jit uses tracing: when first called, it executes the function with abstract "tracer" values instead of real arrays. Tracers record what operations are applied to them but do not have concrete values. The result is an XLA computation graph.
A traced value (jax.core.ShapedArray) has a known dtype and shape but no concrete data. Operations on tracers produce more tracers. At the end of the function, JAX has a complete XLA computation graph which it compiles.
Implications for Python control flow:
import jax
import jax.numpy as jnp
@jax.jit
def conditional_fn(x):
if x.sum() > 0: # PROBLEM: x.sum() is a tracer - no concrete value!
return x * 2
return x * -1
# This fails: tracing raises a concretization error because Python 'if' needs a concrete bool
JAX's solution: use jax.lax.cond for data-dependent branching:
@jax.jit
def safe_conditional(x):
return jax.lax.cond(
x.sum() > 0,
lambda x: x * 2,
lambda x: x * -1,
x
)
This contrasts with torch.compile / TorchDynamo which handles Python control flow by either tracing through static branches or inserting graph breaks for data-dependent ones.
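For contrast, a minimal sketch of the same data-dependent branch under torch.compile: Dynamo inserts a graph break around the if and falls back to eager for it, so the code still runs, just with less fusion:

```python
import torch

@torch.compile
def torch_conditional(x):
    if x.sum() > 0:   # data-dependent branch -> graph break + eager fallback
        return x * 2
    return x * -1

print(torch_conditional(torch.randn(4)))
```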
Q6: When should you use torch.compile vs torch.export for deployment?
torch.compile:
- Development and training workflows
- When your model has Python control flow that you want to keep
- When input shapes may vary
- When you want easy debugging (graph breaks fall back gracefully)
- When you accept the warmup cost on each process restart
torch.export:
- Production serving systems
- When you need guaranteed zero Python overhead (no fallback to eager)
- When deploying to non-Python runtimes (C++ servers, mobile, edge)
- When you want serializable, portable model representation
- When you need static shape guarantees for serving
Use torch.export when: you have finalized the model, you know the input shape constraints, you are deploying to a production system with strict latency requirements, and you want the deployment artifact to be independent of the Python development environment.
torch.compile is for iteration speed; torch.export is for deployment confidence.
Q7: Explain what symbolic shapes are in torch.compile and why they matter.
By default, torch.compile treats tensor shapes as concrete values. If you call the model with batch size 32, it compiles code specialized for batch size exactly 32. If you then call with batch size 64, it recompiles - triggering another 1-3 seconds of Triton kernel generation.
Symbolic shapes (enabled with dynamic=True or torch.compile(model, dynamic=True)) make certain dimensions symbolic variables. The compiler generates code that works for any value of those dimensions:
# Static shapes: compiler assumes batch=32, seq_len=128
# Compiles loop bounds as constants: for i in range(32): ...
# Symbolic shapes: compiler treats batch as symbolic 'b'
# Generates: for i in range(b): ... with b as a runtime parameter
The tradeoff: symbolic shape code is less optimized than shape-specialized code (you lose opportunities to unroll loops based on known sizes, to apply specific SIMD patterns for known dimensions, etc.). The benefit is no recompilation for different shapes.
In practice: use static shapes for inference with known, fixed input shapes (most production API servers). Use dynamic shapes for training with variable sequence lengths (NLP) or variable batch sizes (dynamic batching inference).
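If only one dimension actually varies, you can also mark it explicitly rather than compiling the whole model dynamically - a sketch using torch._dynamo.mark_dynamic (exact behavior differs slightly across PyTorch versions):

```python
import torch

model = torch.compile(torch.nn.Linear(512, 256).cuda())

x = torch.randn(32, 512, device="cuda")
torch._dynamo.mark_dynamic(x, 0)             # treat dim 0 (batch) as symbolic from the start
model(x)                                     # compiles once with a symbolic batch dimension
model(torch.randn(64, 512, device="cuda"))   # reuses the compiled code - no recompile
```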
