Triton for Custom Kernels
Reading time: ~45 min · Interview relevance: High · Target roles: ML Systems Engineer, CUDA Developer, Research Engineer
The Flash Attention 2 algorithm has a well-known Triton implementation, included in Triton's official tutorials. If that is not enough to make you learn Triton, consider the development-time gap: a kernel that would take a CUDA expert weeks can often be prototyped in Triton in days.
The Relative Position Bias Problem
A research team is building a custom attention variant. Standard multi-head attention uses absolute position embeddings, which encode each token's own index rather than the offset between two tokens. Their model needs relative position bias: for any two positions $i$ and $j$, add a learned scalar $b_{i-j}$ to the attention logit before softmax. Because the bias depends only on the offset $i - j \in [-(L-1),\, L-1]$, there are only $2L - 1$ learned bias values for a sequence of length $L$.
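The operation is straightforward mathematically:

$$\text{Attn}(Q, K, V)_i = \sum_{j} \operatorname{softmax}_j\!\left(\frac{q_i \cdot k_j}{\sqrt{d}} + b_{i-j}\right) v_j$$

where $d$ is the head dimension.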
The problem is fusing this into a kernel. The standard PyTorch approach materializes the full attention score matrix, adds the bias, then applies softmax. For a 4096-length sequence with 32 heads, that is $4096 \times 4096 \times 32$ scores $\times$ 2 bytes (FP16) $= 2^{30}$ bytes $\approx$ 1 GB of intermediate tensors per batch element, all written to HBM and read back. This is exactly the memory bandwidth problem that Flash Attention was designed to solve for standard attention.
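To make the cost concrete, here is a NumPy sketch of the unfused computation for one batch element. It is our own reference, not the team's code, and it assumes the bias vector has length $2L - 1$ and is indexed by $(i - j) + (L - 1)$, matching the Triton kernel later in this lesson; no causal mask is applied. The `scores` array is the `[heads, L, L]` tensor that an unfused implementation writes to HBM.

```python
import numpy as np


def score_tensor_bytes(seq_len: int, heads: int, bytes_per_elem: int = 2) -> int:
    """Size of the materialized [heads, seq_len, seq_len] score tensor (FP16 by default)."""
    return seq_len * seq_len * heads * bytes_per_elem


def rel_bias_attention_reference(q, k, v, bias):
    """Unfused attention with relative position bias.

    q, k, v: [heads, L, d] arrays; bias: [2L - 1] array where bias[(i - j) + (L - 1)]
    is the learned scalar for query i attending to key j.
    """
    heads, L, d = q.shape
    # Materialize the full score tensor: [heads, L, L]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    # Relative offset matrix (i - j), shifted into [0, 2L - 2] to index the bias vector
    rel = np.arange(L)[:, None] - np.arange(L)[None, :] + (L - 1)
    scores = scores + bias[rel]          # [L, L] bias broadcasts over heads
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v                     # [heads, L, d]


print(score_tensor_bytes(4096, 32) / 2**30, "GiB")  # 1.0 GiB
```

A fused kernel computes the same result while keeping each `[BLOCK_M, BLOCK_N]` slice of `scores` on chip.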
Writing this in CUDA from scratch is a genuine two-week project. You need to manage thread blocks, shared memory allocation, warp scheduling, tensor core invocation, and numerical stability for the online softmax - all while keeping the relative bias lookup correct across tile boundaries.
In Triton, the team has a working kernel in two days, and it runs within 10% of a hand-tuned CUDA implementation. The reason is not magic: Triton compiles to PTX that uses the same Tensor Core and memory instructions a CUDA expert would reach for. The difference is that Triton abstracts away the parts that require deep hardware expertise (thread indexing, memory coalescing, shared memory bank conflicts) and lets the programmer focus on the algorithm.
This lesson teaches you how that works.
Why Triton Exists
The CUDA Expertise Barrier
CUDA programming requires mastering a stack of increasingly low-level concepts:
- Thread/block/grid organization and indexing
- Memory hierarchy: registers, L1 cache, shared memory, L2, HBM
- Shared memory management: bank conflicts, padding strategies
- Warp scheduling: divergence, occupancy, latency hiding
- Tensor Core invocation: wmma API or inline PTX
- Asynchronous memory copies: cp.async for pipelining
- Register pressure: keeping register count low enough for high occupancy
Each layer takes weeks to months to master. Most researchers and many ML engineers do not have this background. Before Triton, the options were: (a) wait for a CUDA expert to write the kernel, (b) use cuBLAS/cuDNN (which only covers standard operations), or (c) accept the performance cost of Python-level fusion.
The Triton Abstraction
Triton's key insight is that most high-performance GPU kernels share a common structure: partition the work into rectangular tiles, load tiles into fast memory, compute on tiles, write results back. The tricky parts of CUDA - making memory accesses coalesce, avoiding shared memory bank conflicts, choosing tile sizes for high occupancy - can be handled automatically if you describe what data each tile needs rather than how each thread accesses individual elements.
Triton's programming model is tile-oriented: you write a function that processes one tile at a time, using operations on blocks of values rather than individual elements. The compiler handles the thread-level details.
This does not mean Triton can always match hand-tuned CUDA. For kernels that rely on explicit warp-level communication (__shfl_sync), custom memory layouts, or hand-orchestrated hardware features such as custom cp.async pipelines, CUDA remains necessary. But for the large class of kernels that follow the tile-compute-store pattern, Triton is competitive.
Historical Context
Triton began as Philippe Tillet's PhD research at Harvard. His 2019 paper, "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations" (MAPL 2019), introduced the idea of a blocked programming model for GPU kernels with automatic shared memory management. The compiler demonstrated that a sufficiently smart autotuner could close most of the performance gap between hand-tuned CUDA and a higher-level abstraction.
OpenAI released Triton 1.0 in 2021. Early versions lowered Triton's own IR directly to LLVM; Triton 2.0 (2023) rebuilt the compiler on MLIR (Multi-Level Intermediate Representation), enabling cleaner code generation and better optimization passes.
Community adoption accelerated once it became clear that Flash Attention, one of the most performance-critical kernels in all of deep learning, could be implemented cleanly in Triton: the Flash Attention repository shipped a Triton variant alongside its CUDA kernels, and Triton's own tutorials include a fused attention kernel implementing the Flash Attention 2 algorithm. Triton attention kernels now appear in several production LLM training and inference stacks.
As of 2024, Triton supports NVIDIA GPUs (Volta through Hopper), AMD ROCm GPUs, and has experimental Intel GPU support. The compiler is actively maintained by OpenAI with contributions from NVIDIA, AMD, and academic groups.
The Triton Programming Model
From Threads to Blocks
The central shift in Triton is moving from thread-level thinking to block-level thinking.
In CUDA, you write a function that runs once per thread:
__global__ void add_cuda(float* a, float* b, float* c, int N) {
// This function runs once per thread
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < N) {
c[tid] = a[tid] + b[tid];
}
}
In Triton, you write a function that runs once per block of threads:
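import torch
import triton
import triton.language as tl

@triton.jit
def add_triton(a_ptr, b_ptr, c_ptr, N, BLOCK_SIZE: tl.constexpr):
    # This function runs once per BLOCK (a "program"), not per thread
    # pid is the program (block) index, not a thread index
    pid = tl.program_id(0)
    # Generate indices for all elements in this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Bounds check: the last block may extend past N
    mask = offsets < N
    # One tl.load fetches the whole BLOCK_SIZE-element vector; the compiler
    # splits it into coalesced loads across the block's threads
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    # Compute (operates on the entire block at once)
    c = a + b
    # Store result
    tl.store(c_ptr + offsets, c, mask=mask)

def add(a: torch.Tensor, b: torch.Tensor, BLOCK_SIZE: int = 1024) -> torch.Tensor:
    c = torch.empty_like(a)
    N = c.numel()
    # Launch one program per BLOCK_SIZE-element chunk
    grid = (triton.cdiv(N, BLOCK_SIZE),)
    add_triton[grid](a, b, c, N, BLOCK_SIZE=BLOCK_SIZE)
    return c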
The conceptual difference: in CUDA, you are writing code for a single worker (thread) and the GPU runs it massively in parallel. In Triton, you are writing code for a team foreman (block program) and describing what the whole team computes together.
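One way to internalize the block-level view without a GPU is to emulate the launch grid on the host. The sketch below uses a helper of our own (not a Triton API) that runs the body of add_triton once per program id, each time operating on a whole BLOCK_SIZE vector with a bounds mask. A GPU runs these programs in parallel; the loop here runs them one after another.

```python
import numpy as np


def emulate_add_grid(a: np.ndarray, b: np.ndarray, block_size: int = 4) -> np.ndarray:
    """Sequential NumPy emulation of launching add_triton over a 1D grid."""
    n = a.size
    c = np.empty_like(a)
    num_programs = -(-n // block_size)            # ceil division, like triton.cdiv
    for pid in range(num_programs):               # on a GPU these run in parallel
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n                        # the last block may run past n
        idx = offsets[mask]                       # masked lanes neither load nor store
        c[idx] = a[idx] + b[idx]                  # one vector op per program
    return c


a = np.arange(10, dtype=np.float32)
b = np.ones(10, dtype=np.float32)
print(emulate_add_grid(a, b))
```

With 10 elements and a block size of 4, three programs run and the third one masks off its last two lanes.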
The Key Primitives
tl.program_id(axis): Returns the index of the current block along the given axis. Analogous to blockIdx.x in CUDA. The grid of blocks is defined by the launch configuration.
tl.arange(start, end): Generates a 1D tensor of integers from start to end (exclusive). The range must be a power of 2. Used to create index vectors for the current block.
tl.load(ptr, mask, other): Loads a block of values from global memory. ptr is a tensor of pointers. mask is a boolean tensor - elements where mask is False load other instead. Automatically generates coalesced memory access patterns.
tl.store(ptr, value, mask): Stores a block of values to global memory. Same coalescing behavior as load.
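tl.dot(a, b): Matrix multiply of two 2D blocks. Maps to Tensor Core matrix-multiply-accumulate instructions. FP16 and BF16 inputs use Tensor Cores at full rate; FP32 inputs use TF32 Tensor Cores by default on Ampere and later. Returns FP32 for floating-point inputs.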
tl.sum(x, axis): Reduction sum along an axis.
tl.max(x, axis): Reduction maximum along an axis.
BLOCK_SIZE: tl.constexpr: When a parameter is annotated tl.constexpr, Triton knows it at compile time. This enables loop unrolling, register allocation optimization, and specialized code generation for each block size.
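Because tl.arange needs a power-of-2 length, launchers round the problem size up and mask the tail. The pure-Python sketch below reimplements what triton.next_power_of_2 returns for positive integers (so it runs without Triton installed) and counts how many lanes the mask disables.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of 2 >= n, for n >= 1 (same result as triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def masked_lanes(n_cols: int) -> tuple[int, int]:
    """Return (BLOCK_SIZE, number of lanes switched off by `offsets < n_cols`)."""
    block = next_power_of_2(n_cols)
    return block, block - n_cols


for n in (100, 128, 4000, 4096):
    print(n, masked_lanes(n))
```

A 4000-column row gets a 4096-wide block with 96 idle lanes, while a 100-column row wastes 28 of 128, which is why small odd widths are comparatively expensive.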
Kernel 1: Softmax (Memory-Bandwidth Bound)
Softmax is the canonical Triton example because it illustrates the most important pattern: a kernel that requires two passes over the data (max for numerical stability, then normalize). Implemented as separate PyTorch ops (max, subtract, exp, sum, divide), this means several round trips through HBM; in a fused Triton kernel, each element is read from HBM once and written once.
import torch
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr,
input_ptr,
input_row_stride, # stride between rows (= number of columns)
output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
Numerically stable softmax over rows of a 2D matrix.
Each program handles one row.
Memory pattern: 2 reads (input twice: for max, then for exp),
but only 1 read if we use online softmax (compute max and sum together).
Here we use the simpler 2-pass approach first.
"""
# Each program handles one row
row_idx = tl.program_id(0)
# Pointer to the start of this row
row_start_ptr = input_ptr + row_idx * input_row_stride
# Load the row into registers
# Triton requires BLOCK_SIZE to be a power of 2 and known at compile time
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
# Bounds-check load: out-of-bounds elements get -inf (safe for max)
mask = col_offsets < n_cols
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Numerically stable softmax: subtract max before exp
row_max = tl.max(row, axis=0)
row_minus_max = row - row_max
numerator = tl.exp(row_minus_max)
# Sum of exp values (ignore masked elements which are exp(-inf) = 0)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Store result
output_row_start_ptr = output_ptr + row_idx * output_row_stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=mask)
def softmax(x: torch.Tensor) -> torch.Tensor:
"""
Launch the Triton softmax kernel over rows.
x: [n_rows, n_cols] tensor on CUDA
"""
assert x.is_cuda and x.is_contiguous()
n_rows, n_cols = x.shape
# BLOCK_SIZE must be >= n_cols and must be a power of 2
BLOCK_SIZE = triton.next_power_of_2(n_cols)
# Each row is handled by one program (one "team")
grid = (n_rows,)
output = torch.empty_like(x)
softmax_kernel[grid](
output, x,
x.stride(0), # row stride (= n_cols for contiguous)
output.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
# num_warps controls how many warps collaborate on each block
# More warps = more parallelism within a block
num_warps=4 if BLOCK_SIZE <= 1024 else 8,
)
return output
# Correctness check
x = torch.randn(1024, 4096, device='cuda', dtype=torch.float32)
y_torch = torch.softmax(x, dim=1)
y_triton = softmax(x)
print(f"Max error: {(y_torch - y_triton).abs().max():.2e}")
# Expected: < 1e-6
# Performance
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[128, 512, 1024, 2048, 4096, 8192],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'PyTorch'],
styles=[('blue', '-'), ('red', '--')],
ylabel='GB/s',
plot_name='softmax-bandwidth',
args={'M': 4096},
)
)
def benchmark_softmax(M, N, provider):
x = torch.randn(M, N, device='cuda', dtype=torch.float32)
if provider == 'triton':
fn = lambda: softmax(x)
else:
fn = lambda: torch.softmax(x, dim=1)
ms = triton.testing.do_bench(fn)
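# Minimum traffic: read the input once and write the output once (2 * numel * bytes).
# The fused kernel achieves this minimum; bytes / ms * 1e-6 converts to GB/s
gbps = 2 * x.numel() * x.element_size() / ms * 1e-6
return gbps
# Run the sweep and print a table of GB/s per N
benchmark_softmax.run(show_plots=True, print_data=True)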
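The kernel above needs a whole row in one block. When a row is too long for that, a kernel has to stream it in chunks, and the online softmax trick computes the max and the normalizer together in a single pass over the data (a second pass then writes the normalized output). Below is a NumPy sketch of that recurrence; the chunk size is an arbitrary parameter of ours. The same update reappears in the attention kernel later in this lesson.

```python
import numpy as np


def online_softmax(row: np.ndarray, chunk: int = 4) -> np.ndarray:
    """Softmax of a 1D row, computing max and sum in one streaming pass."""
    m = -np.inf   # running max
    s = 0.0       # running sum of exp(x - m)
    for start in range(0, row.size, chunk):
        x = row[start:start + chunk]
        m_new = max(m, x.max())
        # Rescale the old sum to the new max, then add this chunk's contribution
        s = s * np.exp(m - m_new) + np.exp(x - m_new).sum()
        m = m_new
    # Output pass: normalize each element with the final max and sum
    return np.exp(row - m) / s


row = np.array([1.0, 3.0, -2.0, 0.5, 7.0, 2.0, 2.0])
print(online_softmax(row))
```

On the first chunk, `m` is `-inf`, so the rescale factor `exp(m - m_new)` is 0 and the empty running sum is discarded cleanly.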
Kernel 2: Fused Layer Norm
Layer normalization requires computing mean and variance of each row, then normalizing. The naive approach (three separate kernels: mean, variance, normalize) makes three passes over the data. A fused kernel makes one pass.
@triton.jit
def layer_norm_kernel(
X_ptr, # [M, N] input
Y_ptr, # [M, N] output
W_ptr, # [N] weight (gamma)
B_ptr, # [N] bias (beta)
Mean_ptr, # [M] output means (for backward pass)
Rstd_ptr, # [M] output 1/std (for backward pass)
stride, # row stride of X and Y
N, # number of columns
eps: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused layer norm: compute mean, variance, and normalize in one pass.
Each program handles one row.
"""
row = tl.program_id(0)
X_row = X_ptr + row * stride
Y_row = Y_ptr + row * stride
# Load input row
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
x = tl.load(X_row + cols, mask=mask, other=0.0).to(tl.float32)
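# Compute mean. Because the whole row sits in registers, the two-pass formula
# (mean, then centered variance) is simple and numerically stable; Welford's
# online algorithm matters when a row must be streamed in chunks.
# Note: tl.sum(x) / N is correct only because masked elements were loaded as 0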
mean = tl.sum(x, axis=0) / N
# Compute variance
x_centered = tl.where(mask, x - mean, 0.0)
var = tl.sum(x_centered * x_centered, axis=0) / N
# Reciprocal standard deviation
rstd = 1.0 / tl.sqrt(var + eps)
# Store mean and rstd for backward pass
tl.store(Mean_ptr + row, mean)
tl.store(Rstd_ptr + row, rstd)
# Normalize and apply affine transform
w = tl.load(W_ptr + cols, mask=mask)
b = tl.load(B_ptr + cols, mask=mask)
y = (x_centered * rstd) * w + b
# Store output
tl.store(Y_row + cols, y, mask=mask)
class TritonLayerNorm(torch.autograd.Function):
"""Layer norm with Triton-accelerated forward pass."""
@staticmethod
def forward(ctx, x, weight, bias, eps=1e-5):
# x: [M, N]
M, N = x.shape
assert x.is_contiguous() and x.is_cuda
BLOCK_SIZE = triton.next_power_of_2(N)
y = torch.empty_like(x)
mean = torch.empty(M, dtype=torch.float32, device=x.device)
rstd = torch.empty(M, dtype=torch.float32, device=x.device)
layer_norm_kernel[(M,)](
x, y, weight, bias, mean, rstd,
x.stride(0), N, eps=eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=8,
)
ctx.save_for_backward(x, weight, mean, rstd)
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
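# Backward pass left as an exercise: it follows the same per-row reduction
# pattern, using the saved mean and rstd.
x, weight, mean, rstd = ctx.saved_tensors
# Fail loudly: returning None here would silently give x, weight, and bias no gradients
raise NotImplementedError("TritonLayerNorm.backward is not implemented")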
def triton_layer_norm(x, weight, bias, eps=1e-5):
return TritonLayerNorm.apply(x, weight, bias, eps)
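Before trusting the kernel, it helps to have a host-side reference that mirrors it step by step, including the saved mean and 1/std that a backward pass would consume. This NumPy sketch is our own and follows the same formulas as layer_norm_kernel (FP32 math, biased variance, eps added inside the square root).

```python
import numpy as np


def layer_norm_reference(x, weight, bias, eps=1e-5):
    """Row-wise layer norm matching layer_norm_kernel; returns (y, mean, rstd)."""
    x = x.astype(np.float32)                  # the kernel also computes in FP32
    mean = x.mean(axis=1)                     # [M]
    x_centered = x - mean[:, None]
    var = (x_centered ** 2).mean(axis=1)      # biased variance (divide by N), as in the kernel
    rstd = 1.0 / np.sqrt(var + eps)           # [M]
    y = x_centered * rstd[:, None] * weight + bias
    return y, mean, rstd


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8)) * 5 + 2
y, mean, rstd = layer_norm_reference(x, np.ones(8), np.zeros(8))
print(y.mean(axis=1), y.std(axis=1))  # ~0 and ~1 per row
```

Comparing `mean` and `rstd` against the kernel's `Mean_ptr` and `Rstd_ptr` outputs catches indexing bugs that the final `y` alone can hide.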
Kernel 3: Tiled Matrix Multiply
The tiled GEMM kernel demonstrates tl.dot (Tensor Core operations), grouped program ordering for L2 cache reuse (GROUP_M), and multi-stage software pipelining (num_stages) for latency hiding.
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8},
num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 4},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 4},
num_stages=5, num_warps=2),
],
key=['M', 'N', 'K'], # autotune separately for each (M, N, K) shape
)
@triton.jit
def matmul_kernel(
A_ptr, B_ptr, C_ptr,
M, N, K,
stride_am, stride_ak, # strides of A
stride_bk, stride_bn, # strides of B
stride_cm, stride_cn, # strides of C
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, # for L2 cache optimization (grouped ordering)
):
"""
Tiled matrix multiply: C = A @ B
A: [M, K], B: [K, N], C: [M, N]
Each program computes one [BLOCK_M, BLOCK_N] tile of C.
Iterates over K dimension in BLOCK_K chunks.
"""
# Compute program ID with grouped ordering for L2 cache reuse
# Standard ordering: row-major over output tiles
# Grouped ordering: compute a GROUP_M x (N/BLOCK_N) rectangle together
# so adjacent blocks share A tiles in L2 cache
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# Compute starting indices for this tile
offs_m = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_n = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
offs_k = tl.arange(0, BLOCK_K)
# Pointers to first tiles of A and B
a_ptrs = A_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
# Accumulate in FP32 for numerical precision
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# K-loop: iterate over K dimension
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Bounds-checked loads
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
# tl.dot: 2D matrix multiply on Tensor Cores
# a: [BLOCK_M, BLOCK_K], b: [BLOCK_K, BLOCK_N]
# Result: [BLOCK_M, BLOCK_N] accumulated in float32
acc += tl.dot(a, b)
# Advance pointers to next K tile
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
# Convert accumulator to output dtype
c = acc.to(tl.float16)
# Store output tile
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = C_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""Matrix multiply using the autotuned Triton kernel."""
assert A.shape[1] == B.shape[0], "Inner dimensions must match"
assert A.is_cuda and B.is_cuda
assert A.dtype == torch.float16 and B.dtype == torch.float16
M, K = A.shape
K, N = B.shape
C = torch.empty(M, N, device=A.device, dtype=torch.float16)
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_M']) * triton.cdiv(N, meta['BLOCK_N']),)
matmul_kernel[grid](
A, B, C,
M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
)
return C
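The grouped ordering at the top of matmul_kernel is easiest to see by running the same index arithmetic on the host. This pure-Python sketch copies the pid-to-tile mapping from the kernel (with small made-up grid sizes) and lists which output tile each program id computes.

```python
def grouped_tile_order(num_pid_m: int, num_pid_n: int, group_m: int):
    """Replay matmul_kernel's pid -> (pid_m, pid_n) mapping on the host."""
    order = []
    num_pid_in_group = group_m * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_m
        # The last group may be shorter if num_pid_m is not a multiple of group_m
        group_size_m = min(num_pid_m - first_pid_m, group_m)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        order.append((pid_m, pid_n))
    return order


# 4x4 grid of output tiles, GROUP_M = 2
print(grouped_tile_order(4, 4, 2)[:8])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
```

Programs 0-7 walk down a two-row band column by column, so consecutive programs reuse the same two rows of A tiles, and each B tile is requested twice in quick succession while it is likely still in L2. Plain row-major order would sweep an entire row of C first, touching every B tile before reusing any.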
How tl.dot Maps to Tensor Cores
When Triton compiles a tl.dot(a, b) operation, it generates Tensor Core matrix-multiply-accumulate instructions in PTX: mma.sync variants on Volta through Ampere, and warpgroup-level wgmma on Hopper. The exact instruction shape depends on the data types and the target GPU.
The key requirement: Tensor Core instructions operate on small fixed-size fragments (for FP16 on Ampere, 16x8x16), so BLOCK_M, BLOCK_N, and BLOCK_K must be multiples of 16, and recent Triton versions reject tl.dot operands with any dimension smaller than 16 at compile time. Combined with the power-of-2 rule for tl.arange, this is why you always see block sizes like 64, 128, 256 in production Triton kernels, never 60 or 100.
For INT8 tensor cores, BLOCK_K must be a multiple of 32. For FP8 on Hopper, BLOCK_K must be a multiple of 32 for E4M3 inputs.
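These constraints are easy to check on the host before launching. The helper below is hypothetical (not part of Triton) and encodes only the rules stated in this section plus the power-of-2 requirement of tl.arange; treat it as a sanity check, since exact limits vary by Triton version and GPU.

```python
def check_dot_block_sizes(block_m: int, block_n: int, block_k: int,
                          dtype: str = "fp16") -> list[str]:
    """Hypothetical pre-launch check; returns a list of problems (empty means OK)."""
    problems = []
    for name, size in (("BLOCK_M", block_m), ("BLOCK_N", block_n), ("BLOCK_K", block_k)):
        if size <= 0 or size & (size - 1):
            problems.append(f"{name}={size} is not a power of 2 (tl.arange requirement)")
        elif size % 16:
            problems.append(f"{name}={size} is not a multiple of 16")
    # INT8 and FP8 (E4M3) Tensor Core fragments are 32 deep along K
    if dtype in ("int8", "fp8e4m3") and block_k % 32:
        problems.append(f"BLOCK_K={block_k} must be a multiple of 32 for {dtype}")
    return problems


print(check_dot_block_sizes(128, 64, 32))          # []
print(check_dot_block_sizes(100, 64, 16, "int8"))
```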
Kernel 4: Custom Attention with Relative Position Bias
This is the motivating example from the opening. Standard Flash Attention in Triton plus a relative bias lookup:
@triton.jit
def rel_attn_kernel(
Q_ptr, K_ptr, V_ptr, # [batch, heads, seq_len, head_dim]
Bias_ptr, # [2*seq_len - 1] relative bias values
Out_ptr, # [batch, heads, seq_len, head_dim]
stride_qb, stride_qh, stride_qm, stride_qk, # Q strides
stride_kb, stride_kh, stride_kn, stride_kk, # K strides
stride_vb, stride_vh, stride_vn, stride_vk, # V strides
stride_ob, stride_oh, stride_om, stride_ok, # Out strides
batch, heads, seq_len, head_dim,
scale, # 1 / sqrt(head_dim)
BLOCK_M: tl.constexpr, # tile size in Q dimension
BLOCK_N: tl.constexpr, # tile size in K/V dimension
BLOCK_D: tl.constexpr, # head_dim (constexpr for tl.dot)
):
"""
Flash Attention with relative position bias.
Uses online softmax, as in Flash Attention (Dao et al.), to avoid materializing the full attention matrix.
Each program handles BLOCK_M queries and iterates over all K/V blocks.
"""
# Program handles one (batch, head, query_tile) tuple
start_m = tl.program_id(0)
off_bh = tl.program_id(1)
off_b = off_bh // heads
off_h = off_bh % heads
# Starting row in Q for this program
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_D)
# Load query block: [BLOCK_M, head_dim]
q_ptrs = (Q_ptr
+ off_b * stride_qb
+ off_h * stride_qh
+ offs_m[:, None] * stride_qm
+ offs_d[None, :] * stride_qk)
q = tl.load(q_ptrs, mask=offs_m[:, None] < seq_len, other=0.0)
# Initialize online softmax state
# lse_i = running sum of exp(qk - m_i), i.e. the softmax denominator
# (despite the name, it is not stored as a log)
# m_i = running max (for numerical stability)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
# Iterate over K/V blocks (this is the "flash" loop)
for start_n in range(0, seq_len, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
# Load key block: [BLOCK_D, BLOCK_N] (transposed for matmul)
k_ptrs = (K_ptr
+ off_b * stride_kb
+ off_h * stride_kh
+ offs_n[None, :] * stride_kn
+ offs_d[:, None] * stride_kk)
k = tl.load(k_ptrs, mask=offs_n[None, :] < seq_len, other=0.0)
# Compute attention scores: [BLOCK_M, BLOCK_N]
qk = tl.dot(q, k) * scale # q [BLOCK_M, BLOCK_D] @ k [BLOCK_D, BLOCK_N]
# Add relative position bias
# For query position i and key position j, bias index is (i - j) + (seq_len - 1)
# offs_m[:, None] - offs_n[None, :] gives the [BLOCK_M, BLOCK_N] relative offset matrix
rel_offs = offs_m[:, None] - offs_n[None, :] + (seq_len - 1)
bias_mask = (offs_m[:, None] < seq_len) & (offs_n[None, :] < seq_len)
bias = tl.load(Bias_ptr + rel_offs, mask=bias_mask, other=0.0)
qk = qk + bias
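# Mask keys past the end of the sequence, plus future positions (causal).
# For bidirectional attention, drop the causal term but keep the bounds term;
# otherwise zero-padded keys would still receive softmax weight.
valid = (offs_n[None, :] < seq_len) & (offs_m[:, None] >= offs_n[None, :])
qk = tl.where(valid, qk, float('-inf'))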
# Online softmax update (Dao et al., Flash Attention algorithm)
# m_new = max(m_i, row_max(qk))
m_new = tl.maximum(m_i, tl.max(qk, axis=1))
# Rescale accumulated output and normalizer for new max
alpha = tl.exp(m_i - m_new)
# Unnormalized probabilities for this tile (computed once, used twice)
p = tl.exp(qk - m_new[:, None])
lse_i = lse_i * alpha + tl.sum(p, axis=1)
acc = acc * alpha[:, None]
# Load value block and accumulate
v_ptrs = (V_ptr
+ off_b * stride_vb
+ off_h * stride_vh
+ offs_n[:, None] * stride_vn
+ offs_d[None, :] * stride_vk)
v = tl.load(v_ptrs, mask=offs_n[:, None] < seq_len, other=0.0)
# Weighted sum: p @ v, with p cast to FP16 to match V for Tensor Cores
acc += tl.dot(p.to(tl.float16), v)
m_i = m_new
# Normalize accumulator
acc = acc / lse_i[:, None]
# Store output
out_ptrs = (Out_ptr
+ off_b * stride_ob
+ off_h * stride_oh
+ offs_m[:, None] * stride_om
+ offs_d[None, :] * stride_ok)
tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < seq_len)
def relative_attention(Q, K, V, bias):
"""
Fused attention with relative position bias.
Q, K, V: [batch, heads, seq_len, head_dim] FP16
bias: [2*seq_len - 1] FP32 relative bias values
"""
batch, heads, seq_len, head_dim = Q.shape
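assert head_dim in [16, 32, 64, 128], "head_dim must be a power of 2 >= 16 for tl.arange and tl.dot"
assert bias.numel() == 2 * seq_len - 1, "bias must hold one value per relative offset"
assert Q.dtype == torch.float16, "kernel casts probabilities to FP16 before tl.dot with V"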
scale = 1.0 / (head_dim ** 0.5)
output = torch.empty_like(Q)
BLOCK_M, BLOCK_N = 128, 64
grid = (triton.cdiv(seq_len, BLOCK_M), batch * heads)
rel_attn_kernel[grid](
Q, K, V, bias, output,
Q.stride(0), Q.stride(1), Q.stride(2), Q.stride(3),
K.stride(0), K.stride(1), K.stride(2), K.stride(3),
V.stride(0), V.stride(1), V.stride(2), V.stride(3),
output.stride(0), output.stride(1), output.stride(2), output.stride(3),
batch, heads, seq_len, head_dim,
scale=scale,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_D=head_dim,
num_warps=4, num_stages=2,
)
return output
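To check the kernel's logic without a GPU, it helps to replay the same tiled loop in NumPy. The sketch below is ours and handles a single (batch, head) pair, mirroring rel_attn_kernel's structure: for each query tile it walks the key/value tiles, adds `bias[(i - j) + (L - 1)]`, applies the causal mask, and maintains the running max and running sum (`lse_i` in the kernel, `l_i` here). The tile sizes are small arbitrary values; comparing the output against a straightforward masked softmax is a quick way to validate the bias indexing.

```python
import numpy as np


def tiled_rel_attention(q, k, v, bias, block_m=4, block_n=3):
    """Single-head, causal attention with relative bias via online softmax.

    q, k, v: [L, d]; bias: [2L - 1]. Mirrors the tile loop of rel_attn_kernel.
    """
    L, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(q, dtype=np.float64)
    for start_m in range(0, L, block_m):
        offs_m = np.arange(start_m, min(start_m + block_m, L))
        m_i = np.full(offs_m.size, -np.inf)      # running row max
        l_i = np.zeros(offs_m.size)              # running sum of exp(qk - m_i)
        acc = np.zeros((offs_m.size, d))
        for start_n in range(0, L, block_n):
            offs_n = np.arange(start_n, min(start_n + block_n, L))
            qk = q[offs_m] @ k[offs_n].T * scale
            qk += bias[offs_m[:, None] - offs_n[None, :] + (L - 1)]
            qk = np.where(offs_m[:, None] >= offs_n[None, :], qk, -np.inf)
            m_new = np.maximum(m_i, qk.max(axis=1))
            alpha = np.exp(m_i - m_new)          # rescale factor for the old state
            p = np.exp(qk - m_new[:, None])
            l_i = l_i * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ v[offs_n]
            m_i = m_new
        out[offs_m] = acc / l_i[:, None]
    return out
```

Because the first key tile always contains key 0, every row's running max becomes finite on the first iteration, so fully masked later tiles contribute `exp(-inf) = 0` rather than NaNs.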
Autotune: Automatic Configuration Search
The @triton.autotune decorator is one of Triton's most powerful features. It benchmarks multiple kernel configurations at first call and caches the best one.
@triton.autotune(
configs=[
# Each config sets constexpr kernel params plus compile options (num_warps, num_stages)
triton.Config(
{'BLOCK_M': 128, 'BLOCK_N': 128},
num_warps=4, # number of warps per block (4 warps = 128 threads)
num_stages=3, # pipeline stages for double-buffering (latency hiding)
),
triton.Config(
{'BLOCK_M': 64, 'BLOCK_N': 256},
num_warps=8,
num_stages=4,
),
triton.Config(
{'BLOCK_M': 256, 'BLOCK_N': 64},
num_warps=8,
num_stages=2,
),
],
key=['M', 'N'], # re-autotune when the value of M or N changes
warmup=25, # warmup time per config, in milliseconds (passed to do_bench)
rep=100, # measurement time per config, in milliseconds
)
@triton.jit
def my_kernel(A_ptr, B_ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
pass
# What autotune does behind the scenes:
# 1. On the first call with a new value of `key` (here, a new (M, N)), compile every config
# 2. Time each one with triton.testing.do_bench (CUDA events, ~warmup ms of warmup, ~rep ms of timing)
# 3. Pick the fastest config
# 4. Store the choice in an in-memory dict keyed by the `key` values
# 5. Later calls in the same process with the same key reuse it and skip benchmarking
# IMPORTANT: the on-disk cache (~/.triton/cache/ by default, or TRITON_CACHE_DIR)
# holds compiled kernel binaries, not autotuning decisions. A fresh process
# re-benchmarks every key it sees unless your Triton release supports persisting
# autotune results and you opt in (check your version's docs).
# Editing the kernel source changes its hash, so stale compiled binaries are never reused.
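The selection logic itself is simple. The pure-Python sketch below is a conceptual model, not Triton's implementation: it times each candidate config with time.perf_counter (Triton uses do_bench with CUDA events), keeps the fastest per key in an in-memory dict, and reuses it on later calls with the same key. `ToyAutotuner` and the fake kernel are hypothetical stand-ins.

```python
import time
from typing import Callable


class ToyAutotuner:
    """Conceptual model of @triton.autotune: benchmark once per key, then cache."""

    def __init__(self, configs: list[dict], key: Callable[..., tuple]):
        self.configs = configs
        self.key = key
        self.cache: dict[tuple, dict] = {}   # in-memory: lives as long as the process
        self.benchmarks_run = 0

    def _time(self, fn: Callable[[], object], reps: int = 5) -> float:
        start = time.perf_counter()
        for _ in range(reps):
            fn()
        return (time.perf_counter() - start) / reps

    def __call__(self, run: Callable[..., object], *args):
        k = self.key(*args)
        if k not in self.cache:              # first call with this key: benchmark all configs
            timings = []
            for cfg in self.configs:
                timings.append((self._time(lambda: run(*args, **cfg)), cfg))
                self.benchmarks_run += 1
            self.cache[k] = min(timings, key=lambda t: t[0])[1]
        return run(*args, **self.cache[k])   # later calls go straight to the winner


def fake_kernel(n, BLOCK):
    # Stand-in "kernel": larger BLOCK means fewer Python iterations
    return sum(range(0, n, BLOCK))


tuner = ToyAutotuner([{"BLOCK": 1}, {"BLOCK": 64}], key=lambda n: (n,))
tuner(fake_kernel, 100_000)
tuner(fake_kernel, 100_000)                  # cache hit: no re-benchmarking
print(tuner.cache, tuner.benchmarks_run)
```

Restarting the process empties `cache`, which mirrors why autotuned kernels need pre-warming after every restart.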
What num_warps and num_stages Control
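num_warps: How many warps (groups of 32 threads) collaborate on each block. Typical range 2-8. More warps spread the tile across more threads, so each thread holds fewer elements and the block has more parallelism to hide latency, but each block then occupies more of the SM's thread slots. Small tiles rarely benefit from more than 2-4 warps; large GEMM and attention tiles commonly use 4-8. The right value depends on the tile size and the kernel's arithmetic intensity.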
num_stages: Controls software pipelining. With num_stages=3, Triton generates code that loads tiles for upcoming iterations while the current iteration is computing (on Ampere and later, via cp.async copies into shared memory). This hides global memory latency. Typical range 2-5. Each extra stage holds another copy of the loaded tiles in shared memory, so higher values raise shared memory usage, reduce occupancy, and can exceed the per-block shared memory limit entirely.
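A back-of-the-envelope estimate shows why num_stages is bounded. In a GEMM K-loop, each stage buffers one A tile and one B tile, so the pipeline needs roughly `(BLOCK_M*BLOCK_K + BLOCK_K*BLOCK_N) * bytes_per_element * num_stages` bytes of shared memory. The sketch ignores any other buffers the compiler allocates, so treat it as a lower bound; the A100 limit of about 163 KiB per thread block is an assumption to double-check for your GPU.

```python
def gemm_pipeline_smem_bytes(block_m: int, block_n: int, block_k: int,
                             num_stages: int, bytes_per_elem: int = 2) -> int:
    """Rough shared-memory footprint of the A/B tile buffers in a pipelined K-loop."""
    per_stage = (block_m * block_k + block_k * block_n) * bytes_per_elem
    return per_stage * num_stages


A100_SMEM_PER_BLOCK = 163 * 1024   # assumption: A100 opt-in maximum per thread block

for cfg in [(128, 256, 64, 3), (128, 256, 64, 4), (64, 128, 32, 4)]:
    need = gemm_pipeline_smem_bytes(*cfg)
    verdict = "fits" if need <= A100_SMEM_PER_BLOCK else "too big"
    print(cfg, f"{need / 1024:.0f} KiB", verdict)
```

The 128x256x64 tile needs 144 KiB at three stages but 192 KiB at four, which is consistent with the first config in the matmul autotune list pairing its largest tiles with num_stages=3 rather than 4 or 5.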
Debugging Triton Kernels
Correctness: Compare Against Reference
import torch

def verify_kernel(triton_fn, reference_fn, *args, atol=1e-2, rtol=1e-2):
    """
    Compare Triton kernel output against a reference implementation.
    Use atol/rtol appropriate for FP16 (not FP32 defaults).
    An element passes if |triton - ref| <= atol + rtol * |ref| (the allclose rule);
    torch.testing.assert_close is a stricter, raising alternative.
    """
    triton_out = triton_fn(*args).float()
    ref_out = reference_fn(*args).float()
    abs_diff = (triton_out - ref_out).abs()
    bad = abs_diff > atol + rtol * ref_out.abs()
    max_diff = abs_diff.max().item()
    if bad.any():
        print(f"FAIL: {bad.sum().item()} elements out of tolerance, max_abs={max_diff:.2e}")
        # Find where the error is largest
        idx = abs_diff.argmax()
        print(f"Worst element: triton={triton_out.flatten()[idx].item():.6f}, "
              f"ref={ref_out.flatten()[idx].item():.6f}")
    else:
        print(f"PASS: max_abs={max_diff:.2e}")

# Test softmax
x = torch.randn(512, 4096, device='cuda', dtype=torch.float32)
verify_kernel(
    softmax,
    lambda x: torch.softmax(x, dim=1),
    x,
    atol=1e-5, rtol=1e-5,
)
Performance: do_bench
import torch
import triton.testing

def benchmark_vs_reference(triton_fn, ref_fn, *args, warmup_ms=25, rep_ms=100):
    """Compare runtime of a Triton kernel against a reference."""
    # Call each once up front so JIT compilation (and any autotuning) is not timed
    triton_fn(*args)
    ref_fn(*args)
    # do_bench's warmup and rep are time budgets in milliseconds, not iteration
    # counts; it times with CUDA events and returns the runtime in ms
    triton_ms = triton.testing.do_bench(lambda: triton_fn(*args), warmup=warmup_ms, rep=rep_ms)
    ref_ms = triton.testing.do_bench(lambda: ref_fn(*args), warmup=warmup_ms, rep=rep_ms)
    print(f"Triton: {triton_ms:.3f} ms")
    print(f"Reference: {ref_ms:.3f} ms")
    print(f"Speedup: {ref_ms / triton_ms:.2f}x")

# Test matmul: pass the functions themselves, since args are forwarded as fn(*args)
M, N, K = 4096, 4096, 4096
A = torch.randn(M, K, device='cuda', dtype=torch.float16)
B = torch.randn(K, N, device='cuda', dtype=torch.float16)
benchmark_vs_reference(matmul, torch.mm, A, B)
The Interpret Mode (CPU Debug)
When a Triton kernel produces wrong results and you cannot debug it from the GPU side, use interpreter mode:
import os

# TRITON_INTERPRET is read when @triton.jit decorates a function, so set it
# before importing triton and defining kernels (or export it in the shell).
os.environ['TRITON_INTERPRET'] = '1'

import triton
import triton.language as tl

# Kernels defined from here on run on the CPU: tl operations execute via NumPy,
# so ordinary print() calls inside a @triton.jit function work.
# Keep problem and block sizes small (e.g., BLOCK_SIZE=16 instead of 1024).
:::warning Interpreter Mode Limitations Interpreter mode runs on CPU and is extremely slow. Use it only for small test cases (N=64 or smaller). Also, interpreter mode does not catch all race conditions or memory alignment issues that only appear on GPU. Always run both the interpreter (for logic correctness) and the GPU version (for numerical correctness). :::
When to Use Triton vs CUDA
Triton Cannot Do These Things
Warp shuffle operations: Operations like __shfl_down_sync move data directly between threads within a warp without going through shared memory. They are the building block of fast warp-level reductions and scans. Triton lowers its own reductions (tl.sum, tl.max) to warp shuffles internally, but it does not let you issue shuffles yourself, so algorithms that depend on custom lane-to-lane exchange patterns (warp votes, bespoke sorting networks) are hard or impossible to express.
Custom asynchronous copy patterns: The cp.async instruction on Ampere allows overlapping global memory loads with computation by issuing copies to shared memory ahead of when data is needed. Triton's num_stages generates such pipelines automatically, but only for loops the compiler knows how to pipeline, chiefly the standard K-loop tiling pattern. For irregular access patterns, you need CUDA.
TMA (Tensor Memory Accelerator) on Hopper: H100's TMA hardware can move tiles of data from global memory to shared memory in a single instruction without consuming compute threads. This is a significant throughput improvement for memory-bandwidth-bound kernels. Recent Triton releases have experimental TMA support, but full control requires CUDA.
Custom memory layouts: Some kernels require non-standard tensor layouts in shared memory (e.g., swizzled layouts to eliminate bank conflicts for non-standard access patterns). CUTLASS provides these as composable components. In Triton, you get automatic conflict avoidance but cannot specify the exact layout.
Production Engineering Notes
Pre-warming the JIT Cache
The first call to a Triton kernel incurs JIT compilation overhead (50-500 ms depending on kernel complexity). In production, pre-warm all kernels before the first user request:
import torch

def prewarm_triton_kernels():
    """
    Run each kernel once with representative shapes to trigger JIT compilation
    (and autotuning, for autotuned kernels). Call this during server startup,
    not in the request path.
    """
    print("Pre-warming Triton kernels...")
    # softmax is not autotuned; it compiles once per distinct BLOCK_SIZE
    # (one per power-of-2 bucket of n_cols), so vary N rather than M
    for N in (1024, 4096, 8192):
        x = torch.randn(16, N, device='cuda', dtype=torch.float32)
        _ = softmax(x)
    # matmul is autotuned on (M, N, K): each distinct shape is benchmarked once per process
    llm_shapes = [(1, 4096, 4096), (64, 4096, 4096), (1, 4096, 11008)]
    for M, N, K in llm_shapes:
        A = torch.randn(M, K, device='cuda', dtype=torch.float16)
        B = torch.randn(K, N, device='cuda', dtype=torch.float16)
        _ = matmul(A, B)
    torch.cuda.synchronize()
    print("Triton kernels pre-warmed.")

# In server startup:
prewarm_triton_kernels()
Setting the Triton Cache Directory
# Default cache: ~/.triton/cache/ (compiled kernels, not autotuning results)
# For production, point it at a persistent volume so compiled kernels survive restarts
export TRITON_CACHE_DIR=/opt/ml/triton-cache
Profiling with Nsight Compute
Triton kernels profile identically to CUDA kernels:
# Nsight Systems (timeline with NVTX ranges and kernel names):
#   nsys profile --trace=cuda,nvtx python my_script.py
# Nsight Compute (per-kernel metrics), filtered to kernels whose name matches:
#   ncu --set full --kernel-name regex:softmax python my_script.py
import torch

# Add NVTX markers to identify Triton kernels in the trace
torch.cuda.nvtx.range_push("my_triton_softmax")
output = softmax(x)
torch.cuda.nvtx.range_pop()
Common Mistakes
:::danger BLOCK_SIZE Not a Power of 2
Triton requires BLOCK_SIZE (any length passed to tl.arange) to be a power of 2. A BLOCK_SIZE of 200 fails at compile time, and rounding the size down instead of up silently skips the tail of each row. Always use triton.next_power_of_2(n) to compute the block size from input dimensions. For small blocks, also keep num_warps small: a 64-element block only has work for 2 warps (64 / 32), and extra warps add no parallelism.
:::
:::danger tl.dot Requires Multiples of 16
tl.dot(a, b) maps to Tensor Core instructions whose fragments require BLOCK_M, BLOCK_N, and BLOCK_K to be multiples of 16 (for FP16), and recent Triton versions reject tl.dot operands with any dimension smaller than 16 at compile time. Prefer block sizes of 64, 128, or 256. To confirm Tensor Cores are in use, look for mma (or wgmma on Hopper) instructions in the generated PTX, which Triton writes to its cache directory alongside the compiled kernel, or check Tensor Core pipe utilization in Nsight Compute.
:::
:::danger Autotune First-Call Latency in Production Autotune benchmarks all configurations on the first call for each new key, which can take seconds. Never let the first production request trigger autotuning. Pre-warm all kernels during server startup with representative shapes. Set TRITON_CACHE_DIR to a persistent directory so compiled kernels survive container restarts; by default the autotuning choices themselves live in process memory and are re-benchmarked after every restart, which is why pre-warming still matters. :::
:::warning Mixing Dtypes in tl.dot
If you pass FP32 tensors to tl.dot, you lose full-rate FP16/BF16 Tensor Core throughput: on Ampere and later, Triton uses TF32 Tensor Cores by default (lower precision than true FP32), and with TF32 disabled it falls back to FP32 FMA instructions. For maximum throughput, ensure inputs to tl.dot are FP16 or BF16, keep the accumulator in FP32 (tl.float32) for numerical precision, and convert inputs with .to(tl.float16) before tl.dot if needed.
:::
:::warning Stride Calculation for Non-Contiguous Tensors
A Triton kernel receives only a base pointer for each tensor, so it knows the layout only through the strides you pass as arguments. If your input tensor is not contiguous (e.g., after a transpose or slice), do one of two things: pass its real strides and use them in every offset calculation, or call .contiguous() first. PyTorch handles this transparently, but a Triton kernel that assumes the wrong strides silently reads from the wrong memory locations. In debug mode, assert tensor.is_contiguous() in the launcher whenever the kernel assumes a contiguous layout.
:::
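To see why wrong strides fail silently, the plain-Python sketch below reads a transposed view through its real strides and then through the strides a contiguous-layout kernel would assume. The storage list and helper names are illustrative. In PyTorch, a contiguous 2x3 tensor x gives x.t().stride() == (1, 3).

```python
def element_offset(row: int, col: int, stride_row: int, stride_col: int) -> int:
    """Flat storage offset of element (row, col), computed the way a Triton kernel does."""
    return row * stride_row + col * stride_col


# Underlying storage of the contiguous 2x3 tensor [[0, 1, 2], [3, 4, 5]]
storage = [0, 1, 2, 3, 4, 5]

# Its transpose is a 3x2 view of the same storage with strides (1, 3)
t_shape, t_strides = (3, 2), (1, 3)


def read(shape, strides):
    """Gather every element of a 2-D view using explicit strides."""
    rows, cols = shape
    return [[storage[element_offset(r, c, *strides)] for c in range(cols)] for r in range(rows)]


correct = read(t_shape, t_strides)  # honours the real strides
wrong = read(t_shape, (2, 1))       # assumes a contiguous 3x2 layout: no error, wrong data
```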
Interview Questions and Answers
Q1: How does Triton's tile-based programming model differ from CUDA's thread-based model, and what does this abstraction cost you?
In CUDA, you write code for a single thread. The programmer explicitly computes each thread's responsibility (via threadIdx, blockIdx), manages shared memory allocation and bank conflicts, and coordinates warp-level synchronization. In Triton, you write code for a block of threads working together. You specify what data the block needs (via tl.load with a range of offsets), what the block computes, and where to store results. Triton handles thread-level indexing, shared memory staging, and memory coalescing automatically.
The cost of the abstraction: Triton cannot express warp-level operations (shuffles, votes), custom shared memory layouts, or asynchronous memory pipelines that do not fit the standard tile-load-compute-store pattern. For kernels that rely on these features (like radix sort using warp shuffle, or custom GEMM layouts), CUDA provides more control. For the large class of kernels that ARE tile-compute-store patterns (softmax, layer norm, attention, GEMM, elementwise ops), Triton typically achieves 80-95% of hand-tuned CUDA performance at 20-30% of the development time.
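Here is a plain-Python emulation of a Triton-style vector add that makes the tile-based model concrete without a GPU. add_kernel and launch are hypothetical names. The per-element loop inside add_kernel stands in for what is a single vectorized tile operation in real Triton code, where the compiler decides how the BLOCK_SIZE lanes map onto threads.

```python
import math


def add_kernel(program_id, x, y, out, n_elements, BLOCK_SIZE):
    """One Triton 'program': it owns a whole tile of BLOCK_SIZE elements."""
    block_start = program_id * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]  # like tl.arange(0, BLOCK_SIZE)
    mask = [o < n_elements for o in offsets]               # guards the ragged last tile
    # In Triton this is one tl.load / add / tl.store over the whole tile
    for o, m in zip(offsets, mask):
        if m:
            out[o] = x[o] + y[o]


def launch(x, y, BLOCK_SIZE=4):
    """Emulate the launch grid; on a GPU the programs run in parallel."""
    n = len(x)
    out = [0.0] * n
    grid = math.ceil(n / BLOCK_SIZE)  # like triton.cdiv(n, BLOCK_SIZE)
    for pid in range(grid):
        add_kernel(pid, x, y, out, n, BLOCK_SIZE)
    return out
```

The programmer reasons about tiles and masks. Which thread touches which element is never written down.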
Q2: What does tl.dot actually generate on an A100, and what are the requirements for it to use Tensor Cores?
tl.dot(a, b), where a is [BLOCK_M, BLOCK_K] and b is [BLOCK_K, BLOCK_N], lowers to mma.sync PTX instructions on Volta through Ampere (m16n8k16 for FP16 on an A100) and to warpgroup-level wgmma instructions on Hopper. Each instruction multiplies small fixed-size fragments, so the compiler splits the [BLOCK_M, BLOCK_N] output tile across the kernel's warps and issues a sequence of MMA instructions along BLOCK_K.
Requirements for Tensor Core execution:
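- Input dtype must be FP16, BF16, FP32 (executed as TF32), INT8, or FP8 (depending on architecture)
- BLOCK_M, BLOCK_N, BLOCK_K must be at least 16, i.e., multiples of 16 (for FP16 on sm_80)
- The accumulator is normally float32 (int32 for INT8); FP16 accumulation also runs on Tensor Cores but loses precision
- Both inputs must be 2D blocks whose shapes are compile-time constants (tl.constexpr)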
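If the inputs cannot take an MMA path (for example, FP32 inputs with input_precision="ieee"), Triton emits scalar FMA instructions and you lose the ~16x throughput advantage of Tensor Cores. In recent releases, blocks smaller than 16 are rejected at compile time instead. Check this by profiling with Nsight Compute: tensor pipe utilization in the Compute Workload Analysis section should be well above zero, and on an A100 the kernel's SASS should contain HMMA instructions.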
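As rough arithmetic for how a tile breaks into MMA instructions, the sketch below assumes the m16n8k16 FP16 instruction on an A100 and an even split of work across warps. Actual layouts vary by Triton version and architecture, and mma_instructions is a hypothetical helper.

```python
def mma_instructions(block_m, block_n, block_k, mma_m=16, mma_n=8, mma_k=16):
    """Number of m16n8k16 MMA instructions for one [BLOCK_M, BLOCK_N] x BLOCK_K step."""
    assert block_m % mma_m == 0 and block_n % mma_n == 0 and block_k % mma_k == 0
    return (block_m // mma_m) * (block_n // mma_n) * (block_k // mma_k)


total = mma_instructions(128, 128, 32)  # 8 * 16 * 2 = 256 instructions
per_warp = total // 4                   # assuming num_warps=4 and an even split
flops_per_mma = 2 * 16 * 8 * 16         # one multiply and one add per MAC
```

The count times the per-instruction FLOPs recovers the tile's 2 * BLOCK_M * BLOCK_N * BLOCK_K FLOPs, which is a quick sanity check on the decomposition.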
Q3: Describe the autotune process. How does Triton decide which configuration is best, and what are the production implications?
The first call to an @triton.autotune kernel with a new combination of the arguments listed in key (for example key=['M', 'N']) triggers tuning. Triton compiles and benchmarks every configuration in the configs list. Each configuration is timed with triton.testing.do_bench, which runs the kernel for a warmup budget and then a measurement budget (both in milliseconds) and measures with CUDA events. The fastest configuration is selected and cached for that key.
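Two caches are involved:
- Selected configuration: cached in the process's memory, keyed by the values of the key arguments. Later calls with the same key skip benchmarking until the process exits.
- Compiled kernel binaries: cached on disk (default ~/.triton/cache/, overridable with TRITON_CACHE_DIR), keyed by a hash that includes the kernel source. A restarted process avoids recompilation but, by default, still re-runs the benchmarks.

Newer Triton releases add opt-in on-disk caching of autotune results; check the triton.autotune documentation for your version.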
Production implications: (1) The first call for each key is slow (5-60 seconds for complex kernels with many configs, much of it compilation), so always pre-warm during startup. (2) The on-disk cache is per-machine and per-kernel-source-version: updating the kernel or Triton invalidates it. (3) Keep TRITON_CACHE_DIR on persistent storage so container restarts skip recompilation. (4) For LLM inference, every distinct key value (each batch size, each sequence length) triggers its own tuning run, so pre-warm with the full distribution of shapes that will appear in production.
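The toy model below shows keyed autotune caching and start-up pre-warming in plain Python. It is not Triton's implementation: ToyAutotuner, the config dicts, and the cost model are all hypothetical. It shows why each new key costs a full sweep of benchmarks, and why pre-warming moves that cost from the first request to server startup.

```python
class ToyAutotuner:
    """Toy model of keyed autotuning: benchmark all configs once per distinct key."""

    def __init__(self, configs, key_fn, bench_fn):
        self.configs = configs    # candidate configs, e.g. dicts of block sizes
        self.key_fn = key_fn      # maps call arguments to the tuning key
        self.bench_fn = bench_fn  # returns a runtime for (config, args)
        self.cache = {}           # key -> best config
        self.bench_calls = 0

    def best_config(self, **args):
        key = self.key_fn(**args)
        if key not in self.cache:  # cache miss: benchmark every config
            timings = []
            for cfg in self.configs:
                self.bench_calls += 1
                timings.append((self.bench_fn(cfg, **args), cfg))
            self.cache[key] = min(timings, key=lambda t: t[0])[1]
        return self.cache[key]


configs = [{"BLOCK": b} for b in (64, 128, 256)]
tuner = ToyAutotuner(
    configs,
    key_fn=lambda M, N: (M, N),                        # like key=['M', 'N']
    bench_fn=lambda cfg, M, N: abs(cfg["BLOCK"] - N),  # made-up cost model
)

# Pre-warm at startup with every shape production will see
for M, N in [(1, 128), (8, 128), (8, 256)]:
    tuner.best_config(M=M, N=N)

# A later request with a warmed shape hits the cache: no new benchmarks
calls_before = tuner.bench_calls
cfg = tuner.best_config(M=8, N=256)
```

An unseen shape at request time would trigger a full sweep on the hot path, which is exactly what pre-warming avoids.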
Q4: When would you choose CUDA over Triton? Give a concrete example where Triton genuinely cannot match CUDA performance.
The clearest case is kernels that depend on explicit warp-level primitives. Consider building a histogram kernel: you have an input array of integers in [0, 255] and want to count occurrences of each value. The efficient CUDA approach uses warp-level shuffles to perform partial reductions within each warp before writing to shared memory, then does a tree reduction across warps. With __shfl_down_sync, each lane reads a value from the lane delta positions above it. Using delta = 16, 8, 4, 2, 1 in successive steps reduces a warp's 32 values into lane 0 in five register-to-register exchanges, with no shared memory traffic.
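Triton exposes no equivalent to __shfl_down_sync. Its compiler does use warp shuffles internally when lowering tl.sum and tl.max, but the programmer cannot control lane-level data movement or manage shared memory directly. Schemes like per-warp partial results followed by a cross-warp tree therefore cannot be expressed. For this histogram kernel, a carefully written CUDA implementation with warp shuffles can be 2-3x faster than the equivalent Triton kernel.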
Another case: the H100's TMA (Tensor Memory Accelerator) can load tiles from global to shared memory in a single instruction, freeing all compute threads from memory-copy work. CUTLASS uses TMA extensively for its Hopper GEMM kernels, achieving near-theoretical peak throughput. Triton has partial TMA support but not full programmable control.
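The following plain-Python emulation of the __shfl_down_sync tree reduction makes the five-step pattern concrete. shfl_down and warp_reduce_sum are illustrative names, not CUDA APIs. As CUDA documents for shuffle-down, lanes whose source would fall beyond the warp keep their own value.

```python
WARP_SIZE = 32


def shfl_down(values, delta):
    """Emulate __shfl_down_sync: lane i receives the value held by lane i + delta.

    Source lanes past the end of the warp do not wrap; those lanes keep their own value.
    """
    return [values[i + delta] if i + delta < WARP_SIZE else values[i] for i in range(WARP_SIZE)]


def warp_reduce_sum(values):
    """Tree reduction with delta = 16, 8, 4, 2, 1; the full sum ends up in lane 0."""
    vals = list(values)
    assert len(vals) == WARP_SIZE
    delta = WARP_SIZE // 2
    steps = 0
    while delta > 0:
        shifted = shfl_down(vals, delta)
        vals = [a + b for a, b in zip(vals, shifted)]  # every lane adds its partner's value
        delta //= 2
        steps += 1
    return vals[0], steps  # only lane 0 holds the complete sum
```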
Q5: Walk through the online softmax algorithm used in Flash Attention. Why is it needed and how does Triton implement it efficiently?
Standard softmax requires two passes over the data: one to find the max (for numerical stability), and one to compute exp and normalize. For an attention score matrix of shape [seq_len, seq_len] for large seq_len, materializing the entire matrix and making two passes means reading and writing hundreds of MB to HBM.
Online softmax (Milakov & Gimelshein, 2018) maintains running statistics that allow single-pass computation. For each query row you keep two state variables: $m_i$ (the running max) and $\ell_i$ (the running sum of exponentials). When a new block of scores arrives with block max $\tilde{m}$, the new max is $m_{\text{new}} = \max(m_i, \tilde{m})$. Because the max changed, you rescale the existing sum and accumulator by $e^{m_i - m_{\text{new}}}$ before adding the new block's contribution. This keeps the result mathematically equivalent to the two-pass version.
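In the Triton attention kernel, the loop over K/V blocks does the following at each iteration:
- Compute attention scores for this K block: qk = tl.dot(q, tl.trans(k)), scaled by 1/sqrt(head_dim)
- Compute the new running max: m_new = tl.maximum(m_i, tl.max(qk, axis=1))
- Compute the block's unnormalized weights: p = tl.exp(qk - m_new[:, None])
- Rescale the accumulator: acc = acc * tl.exp(m_i - m_new)[:, None]
- Update the running sum of exponentials: l_i = l_i * tl.exp(m_i - m_new) + tl.sum(p, axis=1)
- Accumulate weighted values: acc += tl.dot(p.to(v.dtype), v)
- Update the running max: m_i = m_new

After the loop, normalize once: acc = acc / l_i[:, None].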
The efficiency comes from two choices. First, the accumulator acc stays in registers (a float32 array of size BLOCK_M × head_dim). Second, only Q, K, and V tiles are streamed from HBM. The full attention matrix is never materialized, so its HBM traffic disappears and the memory footprint drops from $O(S^2)$ to $O(S \cdot d)$, where S is sequence length and d is head dimension.
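The loop can be checked numerically with a plain-Python sketch for a single query row. The simplifying assumption is head_dim = 1, so each value is a scalar; in the real kernel acc is a [BLOCK_M, head_dim] tile. The single-pass recurrence matches the two-pass reference for any block size, including a ragged last block.

```python
import math


def attention_row_two_pass(scores, values):
    """Reference: materialize every weight, then normalize (what online softmax avoids)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


def attention_row_online(scores, values, block_n=2):
    """Single pass over key blocks, keeping only m_i, l_i and acc between blocks."""
    m_i, l_i, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), block_n):
        qk = scores[start:start + block_n]     # scores for this K block
        v = values[start:start + block_n]      # matching V block
        m_new = max(m_i, max(qk))
        alpha = math.exp(m_i - m_new)          # rescale factor for the old state (0 on block 1)
        p = [math.exp(s - m_new) for s in qk]  # unnormalized block weights
        l_i = l_i * alpha + sum(p)
        acc = acc * alpha + sum(pi * vi for pi, vi in zip(p, v))
        m_i = m_new
    return acc / l_i                           # normalize once at the end


scores = [0.5, 2.0, -1.0, 3.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
```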
Q6: What are the most common causes of numerical mismatch between a Triton kernel and a PyTorch reference?
The four main sources of numerical mismatch:
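First, accumulation order. Floating-point addition is not associative, and Triton's tl.dot accumulates in a different order than PyTorch's cuBLAS matmul, so results can differ slightly even in FP32. This is expected and not a bug: use atol=1e-2 and rtol=1e-2 for FP16, and atol=1e-4 for FP32, in your correctness tests. For FP32 inputs, also note that tl.dot defaults to TF32 while PyTorch matmul does not by default (torch.backends.cuda.matmul.allow_tf32 is False). Either compare with input_precision="ieee" or loosen the FP32 tolerance.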
Second, dtype mismatch. If your accumulator is FP16 instead of FP32, you will overflow or lose precision for large matrices. Always initialize the accumulator with tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) and convert to the output dtype only at the store step.
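Third, masks applied incorrectly. If mask=mask is missing from tl.load, out-of-bounds lanes read whatever lies past the end of the tensor: garbage values, or an illegal-address fault. The softmax kernel would then compute over garbage values for the last block. The fill value matters too. For a max or softmax, masked lanes need other=-float('inf') so they contribute exp(-inf) = 0; a fill of 0.0 would add spurious exp(0 - max) terms to the denominator. Always use masks when the tensor size is not an exact multiple of BLOCK_SIZE.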
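Fourth, stride errors for non-contiguous tensors. A kernel that assumes a contiguous layout (for example, by hard-coding the inner stride as 1) reads the wrong elements when it receives a transposed or sliced tensor. Either pass every stride from tensor.stride() and use all of them in the offset arithmetic, or call .contiguous() in the launcher. While debugging, add assert tensor.is_contiguous() at the start of the Python launcher function.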
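Two points from this list fit in a few lines of plain Python. Reordering a sum changes the result, and a tolerance check of the form |a - e| <= atol + rtol * |e| (the rule torch.allclose uses) accepts that difference. The allclose helper here is an illustrative stand-in, not a library function.

```python
# Floating-point addition is not associative: same three numbers, two orders
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)


def allclose(actual, expected, atol, rtol):
    """Elementwise |a - e| <= atol + rtol * |e|, the same rule as torch.allclose."""
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))
```

Exact equality fails on the two results, while a tolerance-based check with atol=1e-4 passes. That is why correctness tests for kernels compare against a reference with tolerances rather than ==.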
Summary
Triton lets you write production GPU kernels in Python that reach near-CUDA performance for the broad class of tile-compute-store operations. The Flash Attention family of kernels (arguably the most impactful GPU optimization in LLM training and inference) has widely used Triton implementations, including the fused-attention tutorial in the Triton repository.
The programming model is tile-based: write a function indexed by block ID that processes one tile of data, using tl.load/tl.store for memory access, tl.dot for Tensor Core GEMM, and tl.sum/tl.max for reductions. The compiler handles thread-level indexing and memory coalescing. Autotune handles block size selection.
Use Triton for any operation that fits the tile pattern. Fall back to CUDA only for warp-level shuffles, hardware-specific async copy features, or custom shared memory layouts. In practice, roughly 80% of custom kernel requirements in production ML systems can be satisfied with Triton.
The skills that matter most: understanding block size and Tensor Core alignment requirements, the online softmax pattern for memory-efficient attention, and knowing how to debug numerical mismatches. Master those and you can write kernels that were previously accessible only to CUDA experts.
