Kernel Fusion Strategies
Reading time: ~40 min · Interview relevance: High · Target roles: ML Systems Engineer, CUDA Developer, Performance Engineer
Every unfused CUDA kernel reads inputs from HBM and writes outputs back to HBM. Fusing three elementwise ops into one kernel turns three HBM round-trips into one - a 3x bandwidth reduction with zero change to output values.
The Production Scenario
The inference team benchmarks a new transformer variant. The forward pass takes 28ms per batch. The hardware team estimates, based on FLOP counts, that the A100 should finish this workload in roughly 8ms. Something is wrong.
They profile with Nsight Compute. The result is revealing: 8ms of actual compute, 20ms of HBM traffic between operations. The model has 47 separate CUDA kernel launches in its forward pass - 34 of them are cheap operations like ReLU, bias add, dropout, and layer normalization that each launch their own kernel, read from HBM, do minimal work, and write back to HBM. The GPU spends most of its time loading and storing data rather than computing.
They apply torch.compile. It fuses 40% of those kernels automatically, reducing the launch count to 28. The forward pass drops to 18ms.
They are not satisfied. They write a custom Triton kernel that fuses the entire layer normalization sequence (mean, variance, normalize, scale, shift) into one pass. A second custom kernel fuses the attention bias and dropout. Total kernel launches: 19. Forward pass: 11ms.
The remaining gap between 11ms and the 8ms theoretical minimum comes from unavoidable kernel launch overhead and compute-bound matmuls. The workload is now memory-efficient.
This lesson explains exactly why this happened and how to do it yourself.
Why This Exists - The Cost of Every Kernel Launch
HBM Round-Trips are Expensive
Modern GPUs have a strict memory hierarchy. Computation happens in registers and shared memory (SRAM) on the chip. Data lives in HBM off-chip. Every time a value is produced by one kernel and consumed by the next, it makes a round-trip to HBM and back.
The A100 SXM4 has 2 TB/s HBM bandwidth. That sounds fast. But consider a simple three-operation chain:
x = relu(linear(x) + bias)
Without fusion, PyTorch launches three separate kernels:
- linear: reads the input and weight from HBM, writes the output to HBM
- bias_add: reads that output from HBM, reads the bias, writes the result back to HBM
- relu: reads from HBM, writes back to HBM
For a single token with d = 4096 at fp16 (a typical LLM hidden size):
- Activation size: 4096 elements × 2 bytes = 8 KB per operation
- Three operations: 3 reads + 3 writes = 6 HBM accesses = 48 KB total
With fusion into one kernel:
- Read input once, compute linear + bias + relu in registers, write output once
- 2 HBM accesses = 16 KB total
A 3x reduction in HBM traffic for free, with zero change to the result.
Why Elementwise Operations Are Especially Wasteful
For elementwise operations (add, multiply, ReLU, sigmoid, etc.), the arithmetic intensity is near zero. The operation does 1-2 FLOPs per element while reading and writing 4-8 bytes. An isolated fp32 ReLU on a 4096-element vector performs 4096 FLOPs while moving 32 KB - an arithmetic intensity of 0.125 FLOP/byte.
The A100's roofline ridge point is about 156 FLOP/byte (312 fp16 TFLOPS / 2 TB/s). A standalone ReLU therefore runs at 1/1248th of potential throughput, spending more than 99.9% of its time waiting for HBM.
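A quick back-of-envelope check of those numbers (the fp32 ReLU and the 156 FLOP/byte ridge point are the assumptions stated above):
import torch  # not needed for the arithmetic, just for context

n = 4096
flops = n                            # ~1 FLOP per element for ReLU
bytes_moved = n * (4 + 4)            # read 4 B + write 4 B per fp32 element
intensity = flops / bytes_moved      # 0.125 FLOP/byte
ridge = 156                          # A100 ridge point in FLOP/byte (312 TFLOPS / 2 TB/s)
print(intensity, ridge / intensity)  # 0.125  1248.0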
When you fuse five elementwise ops together, the FLOP count per byte loaded goes up by 5x. The data moves once, five operations happen in registers, and the result is written once. This is the essence of fusion.
Historical Context
Before Fusion: The Framework Tax
Early deep learning frameworks (Caffe, early Theano) executed each layer as a separate kernel. As models grew deeper and more complex, the overhead of kernel boundaries became a significant fraction of total runtime - sometimes 30-50% of wall clock time on fast hardware.
The first systematic approach to fusion was XLA (Accelerated Linear Algebra), developed at Google for TensorFlow in 2016. XLA's fusion pass analyzed the computation graph and identified "fusible subgraphs" - chains of pointwise operations, reductions, and broadcasts that could be compiled into single CUDA kernels. The results were dramatic for inference workloads.
PyTorch's approach was different. Its dynamic computation graph made static compilation harder. TorchScript attempted fusion via just-in-time compilation, but its adoption was limited. The real breakthrough came with torch.compile (introduced in PyTorch 2.0, 2023), which brought a production-quality fusion compiler to the PyTorch ecosystem through TorchInductor.
Meanwhile, the research community was writing manual fused kernels for specific high-value operations. FlashAttention (2022) is the canonical example: a manually written CUDA kernel that fuses the entire attention computation. Apex (NVIDIA, 2018) provided fused layer norm and fused Adam optimizer. xFormers (Meta, 2021) provided fused attention variants for different hardware profiles.
Types of Fusion
Vertical Fusion (Producer-Consumer Chains)
Vertical fusion connects operations where the output of one feeds directly into the next. This is the most common and highest-value type.
Without fusion:
input (HBM) ──> linear ──> HBM
HBM ──> bias_add ──> HBM
HBM ──> layernorm ──> HBM
HBM ──> relu ──> HBM
With fusion:
input (HBM) ──> fused_kernel ──> HBM
(all intermediates stay in registers)
The condition for vertical fusion: each intermediate result is used exactly once. If an intermediate is used by multiple downstream ops, you cannot fuse the entire chain without duplicating work.
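As a minimal illustration of the multi-consumer condition (a hypothetical snippet, not from the scenario above):
import torch

def multi_consumer(x):
    a = torch.relu(x)   # 'a' has two consumers below
    b = a * 2.0
    c = a + 1.0         # fusing relu into both chains would recompute it
    return b, c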
Classic examples:
- bias + layernorm + activation (common in transformer blocks)
- dropout + add + layernorm (residual connections)
- matmul + bias + activation (standard dense layer)
- mean + variance + normalize + scale + shift (the full layer norm sequence)
Horizontal Fusion (Independent Parallel Operations)
Horizontal fusion combines independent operations that happen on the same data. Instead of launching two separate kernels that both read the same tensor from HBM, one kernel handles both.
Without fusion:
x ──> relu(x) ──> HBM
x ──> gelu(x) ──> HBM
(x is read from HBM twice)
With fusion:
x ──> single kernel ──> relu_out (HBM)
                └─────> gelu_out (HBM)
(x is read once)
Less common than vertical fusion but valuable in:
- Multi-head attention: all attention heads are independent
- MoE routing: expert computations are independent
- Multi-task models: task-specific heads operating on shared features
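A minimal sketch of a horizontally fusible pattern (whether Inductor actually merges these into one kernel depends on version and heuristics):
import torch
import torch.nn.functional as F

def dual_activation(x):
    # Two independent activations over the same input. A horizontal-fusion-capable
    # compiler can emit one kernel that loads x once and writes both outputs.
    return torch.relu(x), F.gelu(x)

fused = torch.compile(dual_activation)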
Reduction Fusion (Elementwise + Reduce)
Reductions (sum, mean, max) are expensive because they require all threads to cooperate. But if the reduction is preceded or followed by elementwise operations, fusing them eliminates intermediate writes.
Layer normalization is the canonical example:
Without fusion (5 separate kernels):
1. mean(x) ──> HBM
2. x - mean ──> HBM
3. variance(x - mean) ──> HBM
4. normalize: (x - mean) / sqrt(var + eps) ──> HBM
5. scale + shift: gamma * norm + beta ──> HBM
With fusion (1 kernel):
1. Compute mean (reduction) in shared memory
2. Compute variance (reduction) in shared memory
3. Normalize + scale + shift in registers
→ Write final output once to HBM
A 5x reduction in HBM writes, plus elimination of 4 intermediate allocations.
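For reference, the unfused sequence written as eager PyTorch, where each line corresponds roughly to one kernel launch (a sketch for illustration):
import torch

def layer_norm_unfused(x, gamma, beta, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)               # 1. reduction -> HBM
    centered = x - mean                               # 2. elementwise -> HBM
    var = centered.pow(2).mean(dim=-1, keepdim=True)  # 3. reduction -> HBM
    norm = centered * torch.rsqrt(var + eps)          # 4. normalize -> HBM
    return gamma * norm + beta                        # 5. scale + shift -> HBM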
torch.compile Automatic Fusion
How torch.compile Works
When you call torch.compile(model), PyTorch traces the model's forward pass and builds a computation graph. The default backend is TorchInductor, which:
- Partitions the graph into fusible regions and non-fusible barriers (e.g., operations with complex data dependencies, in-place ops, certain reductions)
- Generates code for the fusible regions (TorchInductor emits Triton kernels for GPUs and C++/OpenMP code for CPUs)
- Compiles to native GPU code and caches the result for reuse
import torch
import torch.nn as nn
# Simple model with many small operations
class TransformerBlock(nn.Module):
def __init__(self, d_model=1024, num_heads=16):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
)
def forward(self, x):
attn_out, _ = self.attn(x, x, x)
x = self.norm1(x + attn_out) # residual + layernorm: fusible
ff_out = self.ff(x)
x = self.norm2(x + ff_out) # residual + layernorm: fusible
return x
model = TransformerBlock().cuda().to(torch.bfloat16)
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.bfloat16)
# Eager mode (no fusion)
eager_out = model(x)
# Compiled mode (with automatic fusion)
compiled_model = torch.compile(model, backend='inductor')
compiled_out = compiled_model(x) # First call: compiles (slow)
compiled_out = compiled_model(x) # Second call: uses cached compiled kernel
print(torch.allclose(eager_out, compiled_out, atol=1e-2)) # True - same result
Benchmarking torch.compile Fusion
import torch
import time
def benchmark_fusion(model, x, n_warmup=10, n_iters=100, label=""):
"""Measure throughput with and without fusion."""
# Warmup
for _ in range(n_warmup):
_ = model(x)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_iters):
_ = model(x)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
ms_per_iter = (elapsed / n_iters) * 1000
print(f"{label}: {ms_per_iter:.2f}ms per forward pass")
return ms_per_iter
model = TransformerBlock().cuda().to(torch.bfloat16)
x = torch.randn(4, 2048, 1024, device='cuda', dtype=torch.bfloat16)
eager_ms = benchmark_fusion(model, x, label="Eager (unfused)")
compiled = torch.compile(model, backend='inductor')
_ = compiled(x) # trigger compilation
compiled_ms = benchmark_fusion(compiled, x, label="torch.compile (fused)")
print(f"Speedup from fusion: {eager_ms / compiled_ms:.2f}x")
Inspecting Fusion Decisions
torch.compile provides debugging tools to see exactly what was fused:
import os
# Enable debug output - shows fusion groups and generated code
os.environ['TORCH_COMPILE_DEBUG'] = '1'
# Alternatively, use the explain API
from torch._dynamo import explain
explanation = explain(model)(x)
print(explanation.graphs) # Captured computation graphs
print(explanation.break_reasons) # Why certain ops broke the graph
print(explanation.ops_per_graph) # Ops in each fused group
# Older PyTorch releases also exposed a dynamo log level directly
# (removed in newer versions): torch._dynamo.config.log_level = logging.DEBUG
# TorchInductor-specific debug (shows generated triton code):
os.environ['TORCH_LOGS'] = '+inductor'
compiled_model = torch.compile(model, backend='inductor')
_ = compiled_model(x)
# Look in /tmp/torchinductor_*/output_code.py for generated Triton kernels
What torch.compile Can and Cannot Fuse
# FUSIBLE by torch.compile / TorchInductor:
# 1. Pointwise op chains
y = torch.relu(x + bias) # Fused: bias_add + relu in one kernel
y = x * scale + shift # Fused: two elementwise ops
y = torch.sigmoid(x) * x # Fused: SiLU activation
# 2. Reduction followed by broadcast
mean = x.mean(dim=-1, keepdim=True)
y = x - mean # Fused: mean + subtract in one kernel
# 3. Some matmul + pointwise patterns
y = torch.nn.functional.linear(x, w, b) # matmul + bias: may fuse
# NOT FUSIBLE without manual work:
# 1. Ops with complex indexing patterns
y = x[indices] # Gather/scatter breaks fusion graphs
# 2. Operations that require synchronization between tiles
y = F.softmax(x, dim=-1) # Softmax needs global reduction, hard to fuse
# 3. In-place operations that write to inputs
x.add_(y) # In-place can break data dependency analysis
# 4. Graph breaks caused by Python control flow
if x.sum() > 0: # Python-level condition: graph break
y = x * 2
else:
y = x * 3
Manual Fusion with Triton
When torch.compile cannot fuse a pattern, or when you need maximum performance for a critical operation, manual fusion with Triton gives you full control.
Layer Normalization: The Canonical Fusion Example
Standard layer norm executes as 5 separate operations. A fused kernel does it in one pass:
import triton
import triton.language as tl
import torch
@triton.jit
def fused_layer_norm_kernel(
X_ptr, W_ptr, B_ptr, Y_ptr, # input, weight (gamma), bias (beta), output
N, # number of elements per row (hidden_dim)
eps, # epsilon for numerical stability
stride_x, # stride between rows
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be power of 2)
):
"""
Fused layer norm: compute mean, variance, normalize, scale, shift in one pass.
Each program instance handles one row of X.
"""
# Which row are we processing?
row_idx = tl.program_id(0)
# Pointer to start of our row
X_row_ptr = X_ptr + row_idx * stride_x
Y_row_ptr = Y_ptr + row_idx * stride_x
# --- Pass 1: Compute mean ---
    # Each program handles BLOCK_SIZE elements per iteration; we loop if N > BLOCK_SIZE
mean = tl.zeros([1], dtype=tl.float32)
for block_start in range(0, N, BLOCK_SIZE):
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
x = tl.load(X_row_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
mean += tl.sum(x, axis=0)
mean = mean / N
# --- Pass 2: Compute variance ---
var = tl.zeros([1], dtype=tl.float32)
for block_start in range(0, N, BLOCK_SIZE):
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
x = tl.load(X_row_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
diff = x - mean
var += tl.sum(diff * diff, axis=0)
var = var / N
rstd = 1.0 / tl.sqrt(var + eps)
# --- Pass 3: Normalize, scale (gamma), and shift (beta) ---
for block_start in range(0, N, BLOCK_SIZE):
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < N
x = tl.load(X_row_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
w = tl.load(W_ptr + offsets, mask=mask, other=1.0).to(tl.float32)
b = tl.load(B_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
y = (x - mean) * rstd * w + b
tl.store(Y_row_ptr + offsets, y.to(tl.float16), mask=mask)
def triton_layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
eps: float = 1e-5) -> torch.Tensor:
"""
Layer norm implemented as a single fused Triton kernel.
x: [batch * seq_len, hidden_dim] - must be contiguous
"""
assert x.is_contiguous(), "Input must be contiguous"
assert x.dtype == torch.float16, "Input must be fp16"
rows, N = x.shape
y = torch.empty_like(x)
# Choose block size: must be power of 2, <= N, trade-off between
# registers and parallelism
BLOCK_SIZE = min(triton.next_power_of_2(N), 2048)
# Launch one program per row
grid = (rows,)
fused_layer_norm_kernel[grid](
x, weight, bias, y,
N=N,
eps=eps,
stride_x=x.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)
return y
# Verify correctness
def verify_layer_norm_fusion():
torch.manual_seed(42)
batch, seq, hidden = 4, 512, 1024
x = torch.randn(batch * seq, hidden, device='cuda', dtype=torch.float16)
gamma = torch.ones(hidden, device='cuda', dtype=torch.float16)
beta = torch.zeros(hidden, device='cuda', dtype=torch.float16)
# PyTorch reference
ln = torch.nn.LayerNorm(hidden, device='cuda', dtype=torch.float16)
ref = ln(x.view(batch, seq, hidden)).view(batch * seq, hidden)
# Our fused kernel
out = triton_layer_norm(x, gamma, beta)
max_err = (ref - out).abs().max().item()
print(f"Max absolute error vs PyTorch LayerNorm: {max_err:.6f}")
assert max_err < 0.01, f"Error too large: {max_err}"
print("Correctness check passed!")
if torch.cuda.is_available():
verify_layer_norm_fusion()
Fused Bias + Activation (GELU/SiLU)
The linear layer pattern W @ x + b followed by activation is extremely common and worth fusing:
@triton.jit
def fused_bias_gelu_kernel(
X_ptr, # Input: output of matmul, shape [rows, cols]
B_ptr, # Bias vector, shape [cols]
Y_ptr, # Output, shape [rows, cols]
cols, # Number of columns (hidden dim)
stride, # Row stride of X
BLOCK: tl.constexpr,
):
"""
Fuses: Y = gelu(X + bias)
One program per row.
"""
row = tl.program_id(0)
X_ptr = X_ptr + row * stride
Y_ptr = Y_ptr + row * stride
for col_start in range(0, cols, BLOCK):
offsets = col_start + tl.arange(0, BLOCK)
mask = offsets < cols
x = tl.load(X_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
b = tl.load(B_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
val = x + b
        # Tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        # (matches PyTorch's nn.GELU(approximate='tanh'); the default GELU uses the exact erf form)
c = 0.044715 * val * val * val
tanh_arg = 0.7978845608 * (val + c)
gelu_val = 0.5 * val * (1.0 + tl.math.tanh(tanh_arg))
tl.store(Y_ptr + offsets, gelu_val.to(tl.float16), mask=mask)
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""x: [rows, cols] fp16, bias: [cols] fp16"""
rows, cols = x.shape
y = torch.empty_like(x)
BLOCK = min(triton.next_power_of_2(cols), 1024)
fused_bias_gelu_kernel[(rows,)](x, bias, y, cols, x.stride(0), BLOCK=BLOCK)
return y
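A quick correctness check against PyTorch's tanh-approximated GELU (loose tolerances for fp16; assumes the wrapper above):
import torch
import torch.nn.functional as F

x = torch.randn(512, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, device='cuda', dtype=torch.float16)
ref = F.gelu((x + b).float(), approximate='tanh').to(torch.float16)
torch.testing.assert_close(fused_bias_gelu(x, b), ref, atol=1e-2, rtol=1e-2)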
Fused Softmax + Dropout
Combining softmax and dropout is valuable in attention computation during training:
@triton.jit
def fused_softmax_dropout_kernel(
X_ptr,
Y_ptr,
N, # row length
dropout_prob, # dropout probability (probability of zeroing)
seed, # random seed for reproducibility
stride,
BLOCK: tl.constexpr,
):
"""
Fuses softmax followed by dropout into a single kernel.
Avoids writing the post-softmax matrix to HBM before dropout reads it.
"""
row = tl.program_id(0)
X_ptr = X_ptr + row * stride
Y_ptr = Y_ptr + row * stride
# Step 1: Find row max (for numerically stable softmax)
row_max = tl.full([1], float('-inf'), dtype=tl.float32)
for start in range(0, N, BLOCK):
offsets = start + tl.arange(0, BLOCK)
mask = offsets < N
x = tl.load(X_ptr + offsets, mask=mask, other=float('-inf')).to(tl.float32)
row_max = tl.maximum(row_max, tl.max(x, axis=0))
# Step 2: Compute sum of exp(x - max)
row_sum = tl.zeros([1], dtype=tl.float32)
for start in range(0, N, BLOCK):
offsets = start + tl.arange(0, BLOCK)
mask = offsets < N
x = tl.load(X_ptr + offsets, mask=mask, other=float('-inf')).to(tl.float32)
row_sum += tl.sum(tl.exp(x - row_max), axis=0)
# Step 3: Normalize and apply dropout in one pass
scale = 1.0 / (1.0 - dropout_prob) # compensate for dropped values
for start in range(0, N, BLOCK):
offsets = start + tl.arange(0, BLOCK)
mask = offsets < N
x = tl.load(X_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
softmax_val = tl.exp(x - row_max) / row_sum
# Philox random number generator for dropout
rand = tl.rand(seed, tl.program_id(0) * N + offsets)
keep = rand > dropout_prob
y = tl.where(keep, softmax_val * scale, 0.0)
tl.store(Y_ptr + offsets, y.to(tl.float16), mask=mask)
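The kernel above still needs a host-side launcher; a minimal sketch (the wrapper name and defaults are assumptions):
def fused_softmax_dropout(x: torch.Tensor, p: float, seed: int = 0) -> torch.Tensor:
    """x: [rows, N] fp16, contiguous. p: dropout probability."""
    assert x.is_contiguous()
    rows, N = x.shape
    y = torch.empty_like(x)
    BLOCK = min(triton.next_power_of_2(N), 1024)
    fused_softmax_dropout_kernel[(rows,)](
        x, y, N, p, seed, x.stride(0), BLOCK=BLOCK,
    )
    return y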
xFormers: Pre-Fused Operations
Meta's xFormers library provides pre-written fused kernels for common transformer operations:
# pip install xformers
from xformers.ops import memory_efficient_attention, LowerTriangularMask
# Fused memory-efficient attention (alternative to FlashAttention)
def xformers_attention(q, k, v, causal=True):
"""
q, k, v: [batch, seq, heads, head_dim] - note the heads-last format
Returns: [batch, seq, heads, head_dim]
"""
attn_bias = LowerTriangularMask() if causal else None
return memory_efficient_attention(q, k, v, attn_bias=attn_bias)
# xFormers also provides a fused layer norm (wrapping Apex's implementation)
# xFormers MLP: fuses the linear + bias + activation in its forward when available
class FusedFeedForward(torch.nn.Module):
    def __init__(self, d_model, d_ff, activation='gelu'):
        super().__init__()
        try:
            # Import inside the try so a missing xformers install hits the fallback
            from xformers.components.feedforward import MLP
            # xFormers handles fusion internally when available
            self.ff = MLP(dim_model=d_model, dropout=0.0,
                          activation=activation,
                          hidden_layer_multiplier=d_ff // d_model)
        except ImportError:
# Fallback to standard
self.ff = torch.nn.Sequential(
torch.nn.Linear(d_model, d_ff),
torch.nn.GELU(),
torch.nn.Linear(d_ff, d_model)
)
def forward(self, x):
return self.ff(x)
FlashAttention as the Canonical Fusion Example
FlashAttention is the most famous fused kernel in deep learning - it fuses the entire attention computation into one pass. Understanding why it is a fusion example, not just a tiling optimization, is important.
Standard attention has three separate HBM read/write cycles:
- Compute S = QK^T / sqrt(d), write S to HBM
- Read S, compute P = softmax(S), write P to HBM
- Read P and V, compute O = PV, write O to HBM
FlashAttention fuses all three into a single kernel that:
- Reads Q, K, V once from HBM
- Computes S, P, O entirely in SRAM/registers using tiling
- Writes O once to HBM
The intermediate matrices S and P never touch HBM. This is vertical fusion at the algorithm level - the same principle as fusing bias + relu, just applied to the entire attention mechanism.
# The fusion principle in pseudocode:
# UNFUSED (3 separate kernels, 3 HBM write-read pairs):
S = q @ k.transpose(-1, -2) / math.sqrt(d) # → HBM
P = F.softmax(S, dim=-1) # HBM → HBM
O = P @ v # HBM →
# FUSED (1 kernel via FlashAttention):
O = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Internally: S, P never written to HBM. 3x fewer HBM bytes.
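One way to confirm the flash backend is actually being used is to restrict SDPA to it and see whether the call still succeeds (a sketch; torch.backends.cuda.sdp_kernel is the pre-2.3 spelling of this API and raises if no allowed backend can run):
import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 2048, 64, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Allow only the flash backend: unsupported dtype/head_dim now raises
# instead of silently falling back to the slow math path.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)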
Measuring Fusion Impact in Production
Memory Bandwidth Profiling
import torch
from torch.profiler import profile, ProfilerActivity

def profile_fusion_impact(model_eager, model_compiled, x):
    """
    Compare CUDA kernel time and launch counts between eager and compiled models.
    """
# Profile eager model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
with_flops=True,
) as prof_eager:
for _ in range(5):
_ = model_eager(x)
# Profile compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_flops=True,
) as prof_compiled:
for _ in range(5):
_ = model_compiled(x)
# Summarize
print("=== Eager Model ===")
print(prof_eager.key_averages().table(sort_by='cuda_time_total', row_limit=15))
print("\n=== Compiled Model ===")
print(prof_compiled.key_averages().table(sort_by='cuda_time_total', row_limit=15))
    # Count CUDA kernel events
    from torch.autograd import DeviceType
    eager_kernels = len([e for e in prof_eager.events() if e.device_type == DeviceType.CUDA])
    compiled_kernels = len([e for e in prof_compiled.events() if e.device_type == DeviceType.CUDA])
print(f"\nEager CUDA launches: {eager_kernels}")
print(f"Compiled CUDA launches: {compiled_kernels}")
print(f"Kernel reduction: {eager_kernels / compiled_kernels:.1f}x")
Quick Bandwidth Estimation
def estimate_memory_traffic(ops_sequence, hidden_dim, seq_len, dtype_bytes=2):
"""
Estimate HBM bytes transferred for a sequence of ops, fused vs unfused.
    ops_sequence: list of dicts with a 'name' key
"""
unfused_bytes = 0
fused_bytes = 0
print(f"Sequence: {len(ops_sequence)} operations")
print(f"Tensor: [{seq_len}, {hidden_dim}], dtype_bytes={dtype_bytes}")
print()
tensor_size = seq_len * hidden_dim * dtype_bytes
for i, op in enumerate(ops_sequence):
# Each unfused op: read input + write output
unfused_bytes += 2 * tensor_size
print(f" Op {i+1}: {op['name']} - reads {tensor_size/1024:.1f}KB, writes {tensor_size/1024:.1f}KB")
# Fused: read once, write once
fused_bytes = 2 * tensor_size
print(f"\nUnfused total HBM: {unfused_bytes/1024:.1f} KB ({unfused_bytes/1e6:.2f} MB)")
print(f"Fused total HBM: {fused_bytes/1024:.1f} KB ({fused_bytes/1e6:.2f} MB)")
print(f"Bandwidth reduction: {unfused_bytes/fused_bytes:.1f}x")
# Layer norm sequence example
layernorm_ops = [
{'name': 'mean reduction'},
{'name': 'subtract mean'},
{'name': 'variance reduction'},
{'name': 'normalize'},
{'name': 'scale (gamma) + shift (beta)'},
]
estimate_memory_traffic(layernorm_ops, hidden_dim=1024, seq_len=2048)
Production Engineering Notes
When to Use Each Approach
| Situation | Recommended Approach |
|---|---|
| Standard transformer ops (layernorm, attention, mlp) | torch.compile with inductor backend |
| Need maximum performance for attention | FlashAttention (flash_attn package) |
| Custom op sequence torch.compile won't fuse | Manual Triton kernel |
| JAX/XLA users | XLA auto-fusion (already aggressive) |
| Non-critical paths | Don't bother - correctness first |
torch.compile in Practice
# Production-grade torch.compile configuration
import torch
model = MyTransformerModel().cuda().to(torch.bfloat16)
# Option 1: default inductor (best for most cases)
compiled = torch.compile(model, backend='inductor')
# Option 2: reduce-overhead mode (minimizes kernel launch overhead,
# useful when model has many small ops)
compiled = torch.compile(model, mode='reduce-overhead')
# Option 3: max-autotune (tries many kernel configurations, slow to compile
# but finds optimal for your specific hardware and shapes)
compiled = torch.compile(model, mode='max-autotune')
# Option 4: dynamic shapes (if input shapes vary between batches)
compiled = torch.compile(model, dynamic=True)
# Practical tip: wrap in try/except for CI safety
def safe_compile(model, **kwargs):
"""Compile with fallback to eager on failure."""
try:
return torch.compile(model, **kwargs)
except Exception as e:
print(f"torch.compile failed: {e}, falling back to eager")
return model
Debugging Fusion Failures
# The most common reason torch.compile does NOT fuse:
# 1. Data-dependent control flow (Python-level conditionals)
def bad_forward(x):
if x.max() > 1.0: # Graph break! Python can't trace through this
return x * 2
return x
# Fix: keep the condition inside the graph as a tensor operation
def good_forward(x):
    # torch.where with a 0-dim bool tensor preserves the original semantics
    return torch.where(x.max() > 1.0, x * 2, x)
# 2. In-place operations that alias inputs
def bad_inplace(x, y):
x.add_(y) # In-place: may break alias analysis
return x
# Fix: use out-of-place
def good_inplace(x, y):
return x + y
# 3. Non-tensor Python objects mid-computation
def bad_mixed(x, python_list):
for val in python_list: # Loop over Python objects: graph break
x = x + val
return x
# Fix: convert to tensor first
def good_mixed(x, python_list):
offsets = torch.tensor(python_list, device=x.device, dtype=x.dtype)
return x + offsets.sum()
Common Mistakes
:::danger Assuming torch.compile Always Fuses Everything
torch.compile is not magic. It cannot fuse through graph breaks (data-dependent control flow, Python-level conditions, non-tensor objects). It cannot fuse arbitrary attention masks. It does not fuse sequences where intermediate outputs are needed by multiple consumers. Always profile before and after to confirm fusion occurred. Use TORCH_COMPILE_DEBUG=1 to see exactly what was fused.
:::
:::danger Register Pressure from Excessive Fusion
Fusing many operations increases the number of live values a thread must hold in registers simultaneously. Each intermediate result in a fused kernel consumes registers until it is consumed. Extreme fusion (fusing 10+ operations) can exceed the register file capacity, causing register spilling to local memory - which is backed by device memory and served through the L1/L2 caches, far slower than true register access. Profile with Nsight Compute and check "Register count" and "Local memory transactions". If register usage climbs toward the 255-per-thread limit, consider splitting the fused kernel.
:::
:::warning Compilation Overhead at Startup
torch.compile performs compilation on the first forward pass with a new input shape. For a large transformer model, this can take 30-120 seconds. This is unacceptable for production inference servers that must start quickly. Solutions: (1) pre-compile with torch.compile at startup with a dummy input matching production shapes, (2) use torch.jit.script for faster but less aggressive fusion, (3) export to ONNX + TensorRT which compiles ahead-of-time, (4) use torch.compile(dynamic=True) to handle shape variation without recompilation.
:::
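A minimal warmup sketch along the lines of solution (1), assuming the MyTransformerModel from the configuration block below; the shapes are placeholders that should match production traffic:
import torch

model = MyTransformerModel().cuda().to(torch.bfloat16)
compiled = torch.compile(model)
BATCH, SEQ, D_MODEL = 8, 2048, 1024   # assumption: match your serving shapes
dummy = torch.randn(BATCH, SEQ, D_MODEL, device='cuda', dtype=torch.bfloat16)
with torch.no_grad():
    _ = compiled(dummy)  # pays the compilation cost at startup, not on the first request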
:::warning Fused Kernels are Harder to Debug
When a fused kernel produces NaN or incorrect results, identifying which operation is the source is much harder than with separate kernels. You cannot add intermediate print statements. Debugging strategy: (1) bisect by temporarily running in eager mode with torch.compile disabled via TORCH_COMPILE_DISABLE=1, (2) use torch.autograd.detect_anomaly(check_nan=True) in eager mode first, (3) add torch.use_deterministic_algorithms(True) to catch non-determinism, (4) narrow down with @torch.no_grad() to eliminate backward pass issues.
:::
:::tip FlashAttention is the Highest-Value Fusion
If you can only apply one fusion technique to a transformer model, make it FlashAttention. It reduces attention memory from O(N^2) to O(N) and provides 2-4x speedup over unfused attention. Everything else (layer norm fusion, bias-GELU fusion, etc.) is incremental by comparison. Ensure F.scaled_dot_product_attention is actually using the flash backend by confirming fp16/bf16 dtype, a supported head dimension, and a mask pattern the flash kernel supports (such as causal masking).
:::
Interview Questions
Q1: Why does fusing elementwise operations reduce memory bandwidth usage?
Every CUDA kernel launch follows this pattern: read inputs from HBM into chip-local registers, perform the computation, write outputs back to HBM. For elementwise ops (add, relu, multiply, etc.), the computation is trivial - often 1-2 FLOPs per element - but the memory traffic is unavoidable per-kernel.
Consider y = dropout(relu(x + bias)). Unfused:
- Kernel 1 (bias_add): read x (M bytes), read bias (B bytes), write result (M bytes)
- Kernel 2 (relu): read result (M bytes), write relu_out (M bytes)
- Kernel 3 (dropout): read relu_out (M bytes), write y (M bytes)
- Total: 6 HBM accesses of M bytes each, plus the small bias read ≈ 6M bytes
Fused into one kernel:
- Read x once (M bytes), read bias once (B bytes)
- Compute bias_add, relu, dropout entirely in registers
- Write y once (M bytes)
- Total: 2M + B bytes, approximately 3x reduction
The key insight is that intermediate values (the result of bias_add, the result of relu) never need to exist in HBM - they exist transiently in registers and are consumed immediately by the next operation. Fusion keeps them in registers.
Q2: How does torch.compile identify operations to fuse, and what are its limitations?
torch.compile uses a multi-stage pipeline:
1. Dynamo tracing: a Python-level tracer records all PyTorch operations as an FX graph (computation graph). It hits "graph breaks" at Python control flow, non-PyTorch operations, or data-dependent conditionals.
2. Inductor lowering: the FX graph is lowered to Inductor's IR (intermediate representation), which identifies "scheduler nodes" - groups of operations with no intermediate dependencies that can be fused.
3. Fusion heuristics: Inductor applies rules to identify fusible subgraphs. Pointwise ops (add, relu, multiply) always fuse. Reductions fuse with preceding or following pointwise ops. Some matmul + bias patterns fuse. Complex dependencies break fusion.
4. Codegen: fused groups become single Triton kernels, compiled to GPU instructions.
Limitations:
- Graph breaks: any Python control flow, print(), non-tensor operations, or dynamic shapes between traced ops create graph breaks. Ops across a graph break cannot be fused.
- Data-dependent conditions: if tensor.sum() > 0 cannot be traced through.
- Multiple consumers: if an intermediate value is consumed by two different downstream ops, it cannot be fused without duplication.
- Some reduction patterns: complex multi-dimensional reductions are not always fusible.
- Custom CUDA extensions: ops implemented as Python-wrapped C++/CUDA are opaque to Inductor and cannot be fused with surrounding ops.
Q3: Give an example of a high-value manual kernel fusion and explain why torch.compile would not handle it.
Layer normalization followed by the first linear layer of a feed-forward network is a high-value fusion target. The sequence is:
y = gamma * (x - mean(x)) / sqrt(var(x) + eps) + beta # Layer norm
z = y @ W_ff + b_ff # First FF linear
a = gelu(z) # Activation
Why this is valuable: the layer norm output y is [seq, hidden] = 2048 x 1024 elements, which is 4 MB at fp16. Without fusion, y is written to HBM after layer norm, then immediately read back for the matmul. Fusing them saves a 4 MB write + 4 MB read = 8 MB of HBM traffic per layer.
Why torch.compile may not fuse it: the layer norm involves a two-pass reduction (mean then variance), each requiring all threads to synchronize and share partial sums. The matmul is a separate high-throughput operation with different tiling requirements. Inductor typically fuses the elementwise post-processing of layer norm but does not fuse the reduction with the downstream matmul because the optimal tile sizes and thread organization conflict.
Manual approach: write a Triton kernel that uses two shared memory reduction passes for mean/variance, then immediately multiplies the normalized output by the weight matrix without writing the normalized intermediate to HBM. This requires careful tiling to keep both the normalization and first matmul tile in SRAM simultaneously.
Q4: What is register pressure and how does it limit the degree of fusion you can apply?
Registers are the fastest memory on the GPU - essentially instantaneous access, no cache miss possible. Each CUDA thread has a limited private register file (typically 256 registers per thread on modern GPUs). When you fuse N operations into one kernel, all intermediate values from all N operations must live in registers simultaneously (until they are consumed).
For a fused sequence a = f1(x); b = f2(a); c = f3(b, y):
- x, a, b, c, y must all be live in registers when the thread executes
- Each fp16 value takes 1 register, fp32 takes 1 register (on modern hardware)
- With 32 values in flight, you need 32 registers per thread
When a kernel exceeds the register limit, the compiler "spills" values to local memory, which is backed by device memory and served through the L1/L2 caches. A cached spill access is roughly 50x slower than a register but still ~10x faster than raw HBM. Heavy register spilling can negate the bandwidth benefits of fusion.
How to measure: in Nsight Compute, check "Registers per thread" in the Launch Statistics section. If this approaches 255 and you see non-trivial "Local Memory Transactions," you have register pressure.
Mitigation: (1) split fused kernels at natural boundaries (e.g., separate the layer norm from the linear layer), (2) use fewer intermediate variables in the fused kernel, (3) recompute instead of storing (trading compute for registers), (4) use __launch_bounds__ in CUDA C to hint the compiler about expected occupancy.
Q5: How would you debug a fused kernel that produces incorrect results?
Step 1: Isolate to eager mode. Set TORCH_COMPILE_DISABLE=1 environment variable and re-run. If the bug disappears, it is a fusion-related issue. If it remains, the bug is in the underlying op, not the fusion.
Step 2: Bisect the fused region. If using Triton manually, add tl.device_print() statements to print intermediate values for a single thread block. For torch.compile, temporarily split the operation sequence into two torch.compile regions to isolate which sub-graph introduces the error.
Step 3: Test correctness against eager. For every fused kernel, maintain an unfused reference implementation and add an assertion:
def fused_layernorm_debug(x, gamma, beta, eps=1e-5):
fused_out = triton_layer_norm(x, gamma, beta, eps)
ref_out = F.layer_norm(x.float(), x.shape[-1:],
gamma.float(), beta.float(), eps).to(x.dtype)
max_err = (fused_out - ref_out).abs().max().item()
if max_err > 0.01:
raise ValueError(f"Fused layernorm error: {max_err:.4f}")
return fused_out
Step 4: Check for numerical precision issues. Fused kernels that accumulate in fp16 (instead of fp32) are prone to numerical error. In layer norm, always accumulate mean and variance in fp32 even if inputs are fp16. The Triton kernels above use .to(tl.float32) for exactly this reason.
Step 5: Test edge cases. Fused kernels often fail at boundary conditions: when N is not a multiple of BLOCK_SIZE, when batch size is 1, when the tensor is not contiguous. Add explicit tests for these cases using torch.testing.assert_close.
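A sketch of such edge-case tests for the fused layer norm from earlier (shapes chosen to hit non-power-of-2 lengths, batch 1, and N larger than one block):
import torch

for rows, n in [(1, 1000), (7, 777), (4, 4097)]:
    x = torch.randn(rows, n, device='cuda', dtype=torch.float16)
    g = torch.ones(n, device='cuda', dtype=torch.float16)
    b = torch.zeros(n, device='cuda', dtype=torch.float16)
    ref = torch.nn.functional.layer_norm(x.float(), (n,), g.float(), b.float()).to(x.dtype)
    torch.testing.assert_close(triton_layer_norm(x, g, b), ref, atol=1e-2, rtol=1e-2)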
Q6: A new transformer architecture uses head_dim=96 instead of 64 or 128. How does this affect fusion and what would you do about it?
This is a common practical problem. The impact:
FlashAttention falls back to standard attention. FlashAttention kernels are compiled for specific head dimensions (commonly powers of two such as 64, 128, 256; exact support varies by version), and head_dim=96 can hit the fallback path: the N x N attention matrix is materialized in HBM, memory usage scales as O(N^2), and throughput drops 3-5x for long sequences. This is silent - no error is raised.
torch.compile may still fuse pointwise ops, but the matmul tiling is less optimal because 96 is not a power of 2 and does not align well with Tensor Core matrix shapes (16 x 16). You may see 10-20% lower matmul throughput compared to d=64 or d=128.
Solutions, in order of preference:
1. Change head_dim to 128. Architecturally equivalent for most purposes. If your model uses 12 heads at d=96 (total 1152), try 9 heads at d=128 (total 1152). Same total width, better hardware alignment.
2. Pad to 128 inside the attention kernel. Zero-pad Q, K, V from d=96 to d=128 before the attention computation, use FlashAttention, then unpad the output (see the sketch below). The padding adds minimal overhead but unlocks full FlashAttention acceleration. xFormers supports padding internally.
3. Use xFormers memory_efficient_attention, which has broader head_dim support than flash_attn and may support 96.
4. Accept the performance hit if d=96 is architecturally important (e.g., you're loading a pretrained checkpoint with this shape) and the sequence length is short enough that the O(N^2) cost is acceptable.
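A sketch of option 2 (a hypothetical helper; the scale kwarg of scaled_dot_product_attention requires a recent PyTorch):
import math
import torch
import torch.nn.functional as F

def sdpa_pad96(q, k, v):
    # Zero columns in K leave QK^T unchanged, and zero columns in V produce
    # zero output columns, so padding does not change the math - as long as
    # we keep the original 1/sqrt(96) softmax scale.
    q, k, v = (F.pad(t, (0, 128 - 96)) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True,
                                         scale=1.0 / math.sqrt(96))
    return out[..., :96]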
Summary
Kernel fusion is one of the highest-leverage performance optimizations available to ML systems engineers, requiring no changes to model architecture or output quality.
The core insight is simple: every unfused operation pays a full HBM round-trip tax. Reading a tensor from HBM, doing minimal work, and writing it back is wasteful when the next operation immediately reads that same data. Fusion keeps intermediate values in registers and eliminates unnecessary HBM traffic.
The practical toolkit has three tiers:
1. torch.compile handles the common case automatically. For standard transformer components (layernorm, bias-add, activation, residual connections), torch.compile(model, backend='inductor') typically delivers 30-50% speedup with zero code changes.
2. FlashAttention handles the highest-value case - attention - by fusing the entire O(N^2) attention computation into one pass. Non-optional for sequences above 2048 tokens.
3. Manual Triton kernels handle cases torch.compile cannot fuse: complex reduction-plus-elementwise chains, multi-stage operations with data-dependent shapes, or performance-critical paths that need hand-tuned tile sizes.
The debugging and measurement discipline matters as much as the optimization itself. Always profile before and after. Always verify correctness against an eager reference. Always check that the fused kernel is actually being invoked rather than silently falling back.
