LLVM and MLIR
The Infrastructure Beneath Everything
Every major ML framework compiles its computations through LLVM or MLIR at some point in the stack. When you call torch.compile, TorchInductor generates Triton code which compiles through LLVM's NVPTX backend. When you run jax.jit, XLA generates HLO which routes through either LLVM for CPUs or custom GPU backends. When numba compiles your Python loop, it generates LLVM IR and calls the LLVM optimizer. When the Rust compiler builds PyO3 bindings, it uses LLVM as its backend.
LLVM is not a single compiler. It is a compiler infrastructure - a collection of reusable components for building compilers. The LLVM IR (Intermediate Representation) is the common language. The optimization passes are the algorithms. The backends are the final translation to machine code. Any language that can compile to LLVM IR gets access to decades of optimization research for free.
MLIR (Multi-Level Intermediate Representation) extends this idea. LLVM IR sits at one abstraction level - close to machine instructions. ML compilers need to optimize at much higher levels: linear algebra (matrix operations), tensor operations, loop nests with polyhedral transformations. MLIR introduces "dialects" - domain-specific IR extensions - that allow optimization at each abstraction level before lowering to LLVM IR.
Understanding LLVM and MLIR is not academic for ML engineers. When you hit a performance cliff in your Triton kernel, you need to understand how LLVM's vectorizer works to know what code patterns enable auto-vectorization. When you are debugging why torch.compile produces suboptimal code for your model, reading the MLIR output shows exactly what fusions and transformations occurred. When you are evaluating whether to target a new accelerator, MLIR dialects determine how much reuse of existing optimization work is possible.
This lesson builds from LLVM fundamentals through MLIR's dialect system to real-world examples of how PyTorch, TensorFlow, and Triton use this infrastructure to generate code for GPUs and TPUs.
Why This Exists - The Compiler Reinvention Problem
Before LLVM (2000), every new language needed to build its own backend. If you invented a new language in 1995, you would write: a lexer, a parser, a type checker, a code generator for x86, a code generator for SPARC, a code generator for PowerPC, and an optimizer. The optimizer alone was a PhD-level effort. Every language team was reinventing the same algorithms: constant folding, dead code elimination, register allocation, instruction scheduling.
Chris Lattner's insight (from his 2002 master's thesis at UIUC): build a shared compiler infrastructure with a well-defined IR. Any language that compiles to this IR gets all the optimizations for free. Any backend that can read this IR gets access to all the languages. You write the optimization once. Everyone benefits.
LLVM launched as an open-source project in 2003. Clang (2007) made it the default C/C++ compiler on macOS and later Linux. Rust adopted LLVM as its exclusive backend. Swift (2014) was designed by Chris Lattner specifically to use LLVM. Today, LLVM powers essentially every modern compiled language.
MLIR (2019, Google) extended LLVM's philosophy upward in the abstraction hierarchy. ML compilers had a specific problem: the gap between "matrix multiply" and "LLVM IR" is enormous. A direct translation loses all the structure that enables tiling, vectorization, and fusion at the linear algebra level. MLIR allows optimizations to happen at each level - from tensor operations down through loop nests to memory operations to machine instructions - without losing information prematurely.
LLVM Architecture - From Source to Machine Code
The three-phase structure is what makes LLVM reusable:
- Frontend: language-specific parsing, type checking, semantic analysis. Produces LLVM IR.
- Middle-end (optimizer): target-independent transformations on LLVM IR.
- Backend: target-specific code generation from optimized LLVM IR.
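To make the three phases concrete, here is a minimal sketch using llvmlite (the Python LLVM binding used later in this lesson): hand-written LLVM IR stands in for a frontend's output, the pass manager plays the middle-end, and the target machine emits native assembly as the backend.
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# "Frontend" output: LLVM IR for `int square(int x) { return x * x; }`
ir_text = """
define i32 @square(i32 %x) {
entry:
  %r = mul i32 %x, %x
  ret i32 %r
}
"""

mod = llvm.parse_assembly(ir_text)
mod.verify()

# Middle-end: target-independent optimization passes
pmb = llvm.PassManagerBuilder()
pmb.opt_level = 2
pm = llvm.ModulePassManager()
pmb.populate(pm)
pm.run(mod)

# Backend: target-specific code generation (here: native assembly text)
target_machine = llvm.Target.from_default_triple().create_target_machine()
print(target_machine.emit_assembly(mod))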
LLVM IR Syntax and Structure
LLVM IR looks like a typed assembly language. It is explicit about types, explicit about memory operations, and uses SSA (Static Single Assignment) form where every variable is assigned exactly once.
; A simple function: int add(int a, int b) { return a + b; }
define i32 @add(i32 %a, i32 %b) {
entry:
%result = add i32 %a, %b
ret i32 %result
}
; A loop: float sum(float* arr, int n)
define float @sum(float* %arr, i32 %n) {
entry:
br label %loop_header
loop_header:
; PHI nodes implement SSA for variables that merge at join points
; %i and %acc have different values depending on which predecessor we came from
%i = phi i32 [ 0, %entry ], [ %i_next, %loop_body ]
%acc = phi float [ 0.0, %entry ], [ %acc_next, %loop_body ]
; Loop condition
%cond = icmp slt i32 %i, %n
br i1 %cond, label %loop_body, label %exit
loop_body:
; Load arr[i]
%ptr = getelementptr float, float* %arr, i32 %i
%val = load float, float* %ptr
; Accumulate
%acc_next = fadd float %acc, %val
; Increment loop counter
%i_next = add i32 %i, 1
br label %loop_header
exit:
ret float %acc
}
Key LLVM IR concepts:
- Types: i8, i32, i64 (integers), float, double, ptr (pointer), [4 x float] (array), {i32, float} (struct)
- SSA: every %name is assigned exactly once. For values that change (loop variables), PHI nodes select the right value based on which basic block we came from.
- Basic blocks: sequences of instructions with no branches except at the end. entry:, loop_header:, loop_body:, exit: are basic blocks.
- Instructions: add, fadd, mul, load, store, getelementptr (pointer arithmetic), call, br (branch), phi
SSA Form - Why It Matters for Optimization
SSA (Static Single Assignment) is the foundational property that makes most optimizations efficient. Every variable is defined exactly once. This creates an explicit data flow graph: each use of %x points to exactly one definition of %x.
; Without SSA (pseudo-assembly):
x = 5
y = x + 3
x = 10 ; x redefined! Optimizers need to track which x is used where
z = x * 2 ; which x is this?
; In SSA form:
%x0 = 5 ; x is defined as %x0
%y = add %x0, 3 ; clearly uses first x
%x1 = 10 ; this is a NEW variable %x1
%z = mul %x1, 2 ; clearly uses second definition
With SSA, the optimizer can immediately see that %y and %z use different definitions of x. Constant propagation is trivial: substitute the constant 5 wherever %x0 is used. Dead code elimination is trivial: if %y has no uses, remove add %x0, 3. The entire optimization pipeline is built on SSA.
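Because every SSA name has exactly one definition, these analyses reduce to dictionary lookups. A toy illustration in plain Python (not real LLVM - the instruction tuples and the assumption that only %z is live at the end are invented for the example):
# Toy SSA optimizer: each instruction is (result, op, operands).
instrs = [
    ("%x0", "const", (5,)),
    ("%y",  "add",   ("%x0", 3)),
    ("%x1", "const", (10,)),
    ("%z",  "mul",   ("%x1", 2)),
]

# Constant propagation/folding: substitute known constants, fold pure ops.
consts = {}
for name, op, args in instrs:
    vals = [consts.get(a, a) for a in args]
    if op == "const":
        consts[name] = vals[0]
    elif all(isinstance(v, int) for v in vals):
        consts[name] = vals[0] + vals[1] if op == "add" else vals[0] * vals[1]

print(consts)  # {'%x0': 5, '%y': 8, '%x1': 10, '%z': 20}

# Dead code elimination: keep only definitions whose result is (transitively) used,
# pretending only %z is live at the end of the block.
live, kept = {"%z"}, []
for name, op, args in reversed(instrs):
    if name in live:
        kept.append((name, op, args))
        live.update(a for a in args if isinstance(a, str))
print(list(reversed(kept)))  # %x1 and %z survive; %x0 and %y are dead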
LLVM with Python - llvmlite
llvmlite is a Python binding for LLVM designed for numba. You can use it directly to generate and execute LLVM IR from Python:
import llvmlite.binding as llvm
import llvmlite.ir as ir
import ctypes
import numpy as np
# Initialize LLVM
llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
# Build LLVM IR using the Python API
def build_vector_sum_ir():
"""Build LLVM IR for: float sum_array(float* data, int32 n)"""
module = ir.Module(name="vector_sum")
float_type = ir.FloatType()
int32_type = ir.IntType(32)
float_ptr = float_type.as_pointer()
# Function signature: float(float*, i32)
func_type = ir.FunctionType(float_type, [float_ptr, int32_type])
func = ir.Function(module, func_type, name="sum_array")
# Name the arguments
data_ptr, n = func.args
data_ptr.name = "data"
n.name = "n"
# Create basic blocks
entry_bb = func.append_basic_block("entry")
loop_bb = func.append_basic_block("loop")
body_bb = func.append_basic_block("body")
exit_bb = func.append_basic_block("exit")
# entry: initialize and branch to loop
builder = ir.IRBuilder(entry_bb)
zero_i32 = ir.Constant(int32_type, 0)
zero_f32 = ir.Constant(float_type, 0.0)
builder.branch(loop_bb)
# loop: PHI nodes for i and acc, check condition
builder = ir.IRBuilder(loop_bb)
i_phi = builder.phi(int32_type, name="i")
acc_phi = builder.phi(float_type, name="acc")
# PHI incoming values filled in after body block is built
cond = builder.icmp_signed("<", i_phi, n, name="cond")
builder.cbranch(cond, body_bb, exit_bb)
# body: load arr[i], add to acc, increment i, branch back
builder = ir.IRBuilder(body_bb)
elem_ptr = builder.gep(data_ptr, [i_phi], name="elem_ptr")
val = builder.load(elem_ptr, name="val")
acc_next = builder.fadd(acc_phi, val, name="acc_next")
i_next = builder.add(i_phi, ir.Constant(int32_type, 1), name="i_next")
builder.branch(loop_bb)
# Now set PHI predecessors
i_phi.add_incoming(zero_i32, entry_bb)
i_phi.add_incoming(i_next, body_bb)
acc_phi.add_incoming(zero_f32, entry_bb)
acc_phi.add_incoming(acc_next, body_bb)
# exit: return acc
builder = ir.IRBuilder(exit_bb)
builder.ret(acc_phi)
return module
# Compile and execute
module_ir = build_vector_sum_ir()
print("Generated LLVM IR:")
print(str(module_ir))
# Create execution engine with JIT compilation
target = llvm.Target.from_default_triple()
target_machine = target.create_target_machine(opt=2) # O2 optimization
mod = llvm.parse_assembly(str(module_ir))
mod.verify()
with llvm.create_mcjit_compiler(mod, target_machine) as ee:
ee.finalize_object()
ee.run_static_constructors()
# Get function pointer
func_ptr = ee.get_function_address("sum_array")
# Create ctypes function
c_func = ctypes.CFUNCTYPE(
ctypes.c_float,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int
)(func_ptr)
# Test with real data
data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
result = c_func(data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), len(data))
print(f"\nsum([1,2,3,4,5]) = {result}") # 15.0
print(f"numpy check: {data.sum()}")
# Inspect LLVM's optimization passes on real IR
import llvmlite.binding as llvm
simple_ir = """
define i32 @redundant(i32 %x) {
%a = add i32 %x, 0 ; adding 0 - should be eliminated
%b = mul i32 %a, 1 ; multiplying by 1 - should be eliminated
%c = add i32 %b, %b ; can be replaced by shift
ret i32 %c
}
"""
llvm.initialize()
llvm.initialize_native_target()
mod = llvm.parse_assembly(simple_ir)
print("Before optimization:")
print(str(mod))
# Apply optimization passes
target = llvm.Target.from_default_triple()
tm = target.create_target_machine()
pmb = llvm.PassManagerBuilder()
pmb.opt_level = 3 # O3 equivalent
pm = llvm.ModulePassManager()
pmb.populate(pm)
pm.run(mod)
print("\nAfter O3 optimization:")
print(str(mod))
# %a = add i32 %x, 0 --> eliminated (x + 0 = x)
# %b = mul i32 %x, 1 --> eliminated (x * 1 = x)
# %c = add i32 %x, %x --> becomes: %c = shl i32 %x, 1 (shift left = *2)
LLVM Optimization Passes
LLVM's optimizer consists of modular "passes" that each transform the IR in some way. Key passes:
Pass Name | What It Does
--------------------|------------------------------------------
mem2reg | Promotes stack variables to SSA registers
instcombine | Algebraic simplification (x+0=x, x*1=x)
simplifycfg | Removes unreachable basic blocks, merges blocks
gvn | Global Value Numbering - eliminates redundant computation
licm | Loop Invariant Code Motion - moves invariants out of loops
loop-vectorize | Auto-vectorizes loops using SIMD (AVX, SSE, NEON)
slp-vectorize | Vectorizes independent adjacent operations
inline | Inlines function calls
dce | Dead Code Elimination
sccp | Sparse Conditional Constant Propagation
For ML workloads, the most impactful passes are:
- loop-vectorize: converts scalar loops to SIMD. Crucial for CPU inference.
- licm: moves weight/bias loading out of inner loops.
- inline: inlines small helper functions, enabling further optimization.
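If you drive LLVM from llvmlite, these passes are not toggled one by one; the legacy PassManagerBuilder exposes the most important knobs as attributes. A short sketch (assumes llvmlite's legacy pass-manager API):
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Configure the (legacy) pass pipeline with the passes that matter most for ML code.
pmb = llvm.PassManagerBuilder()
pmb.opt_level = 3            # enables instcombine, gvn, licm, simplifycfg, ...
pmb.loop_vectorize = True    # SIMD-vectorize loops
pmb.slp_vectorize = True     # vectorize independent straight-line operations
pmb.inlining_threshold = 225 # controls how aggressively `inline` fires

pm = llvm.ModulePassManager()
pmb.populate(pm)
# pm.run(module) then returns True if any pass changed the IR.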
MLIR - Multi-Level Intermediate Representation
MLIR addresses a fundamental problem: compiler optimizations need to happen at multiple abstraction levels. A single-level IR (like LLVM IR) is too low-level for high-level optimizations.
In MLIR, each abstraction level is a dialect. Passes that operate on the linalg dialect can do high-level transformations (tiling, fusion) that would be invisible at the llvm level. Optimizations at each level are preserved as you lower to the next level.
MLIR Dialects
MLIR ships with a set of standard dialects. Each dialect defines:
- Operations (ops): the instructions of that dialect
- Types: the data types that dialect ops work on
- Attributes: static properties attached to ops (like constants)
- Interfaces: contracts that ops can implement (like MemoryEffectsOpInterface)
Key dialects for ML:
| Dialect | Purpose | Example ops |
|---|---|---|
func | Function definitions and calls | func.func, func.call, func.return |
arith | Scalar arithmetic | arith.addf, arith.muli, arith.cmpf |
tensor | Immutable tensor operations | tensor.extract, tensor.insert_slice, tensor.empty |
linalg | Named linear algebra | linalg.matmul, linalg.conv_2d, linalg.generic |
affine | Affine loop transformations | affine.for, affine.load, affine.store |
gpu | GPU compute | gpu.launch, gpu.thread_id, gpu.barrier |
nvgpu | NVIDIA-specific | nvgpu.warpgroup.mma (tensor core operations) |
memref | Memory reference model | memref.alloc, memref.load, memref.store |
vector | SIMD-style operations | vector.dot, vector.contract, vector.transfer_read |
MLIR Syntax - Reading and Writing MLIR
MLIR has a text format that is human-readable (though verbose):
// A simple matrix multiplication in MLIR linalg dialect
// linalg.matmul operates on MemRef (mutable memory references)
func.func @matmul(%A: memref<128x256xf32>,
%B: memref<256x512xf32>,
%C: memref<128x512xf32>) {
// linalg.matmul: C += A * B
// This is a named op - the compiler knows this is matrix multiply
// and can apply tiling, packing, vectorization optimizations
linalg.matmul ins(%A, %B : memref<128x256xf32>, memref<256x512xf32>)
outs(%C : memref<128x512xf32>)
func.return
}
// After tiling transformation (tile sizes 32x32x32):
// linalg.matmul gets tiled into loop nests with affine.for
func.func @matmul_tiled(%A: memref<128x256xf32>,
%B: memref<256x512xf32>,
%C: memref<128x512xf32>) {
affine.for %i0 = 0 to 128 step 32 {
affine.for %j0 = 0 to 512 step 32 {
affine.for %k0 = 0 to 256 step 32 {
// Extract tile
%At = memref.subview %A[%i0, %k0][32, 32][1, 1] : ...
%Bt = memref.subview %B[%k0, %j0][32, 32][1, 1] : ...
%Ct = memref.subview %C[%i0, %j0][32, 32][1, 1] : ...
linalg.matmul ins(%At, %Bt : ...) outs(%Ct : ...)
}
}
}
func.return
}
// Element-wise operation with broadcasting in tensor dialect
func.func @elementwise(%a: tensor<32x512xf32>,
%b: tensor<512xf32>) -> tensor<32x512xf32> {
// Broadcast %b from (512) to (32, 512) then add
%init = tensor.empty() : tensor<32x512xf32>
%broadcast = linalg.broadcast ins(%b : tensor<512xf32>)
outs(%init : tensor<32x512xf32>)
dimensions = [0]
%result = linalg.add ins(%a, %broadcast : tensor<32x512xf32>, tensor<32x512xf32>)
outs(%init : tensor<32x512xf32>) -> tensor<32x512xf32>
func.return %result : tensor<32x512xf32>
}
Running mlir-opt
mlir-opt is the MLIR optimizer tool. You pipe MLIR text through it with transformation passes:
# Tile a matmul for L1 cache locality
echo "..." | mlir-opt \
--linalg-tile-and-fuse-tensor-ops="tile-sizes=32,32,32" \
--convert-linalg-to-loops \
--lower-affine \
--convert-scf-to-cf \
--convert-cf-to-llvm \
--convert-arith-to-llvm \
--finalize-memref-to-llvm \
--reconcile-unrealized-casts \
-o output.mlir
# Vectorize and lower to LLVM IR
echo "..." | mlir-opt \
--linalg-vectorize \
--canonicalize \
--convert-vector-to-llvm \
--convert-linalg-to-llvm \
-o vectorized.mlir
# Emit LLVM IR from MLIR
mlir-translate --mlir-to-llvmir output.mlir -o output.ll
How PyTorch Uses MLIR
TorchInductor (the default torch.compile backend) lowers the FX graph through its own loop-level IR and emits Triton source; MLIR enters that path inside the Triton compiler itself, which is built on MLIR. A separate project, torch-mlir, converts PyTorch models directly into MLIR dialects:
# torch-mlir: convert PyTorch models to MLIR
# pip install torch-mlir
import torch
import torch_mlir
class SimpleMLP(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(512, 1024)
self.fc2 = torch.nn.Linear(1024, 256)
def forward(self, x):
return torch.relu(self.fc2(torch.relu(self.fc1(x))))
model = SimpleMLP()
example_input = torch.randn(32, 512)
# Convert to MLIR in torch dialect
mlir_module = torch_mlir.compile(
model,
example_input,
output_type=torch_mlir.OutputType.TORCH,
use_tracing=True
)
print("=== Torch Dialect MLIR ===")
print(mlir_module.operation.get_asm())
# Convert to linalg (more optimizable form)
mlir_linalg = torch_mlir.compile(
model,
example_input,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=True
)
print("\n=== Linalg MLIR ===")
print(mlir_linalg.operation.get_asm())
# torch-mlir stops at MLIR (TORCH, LINALG_ON_TENSORS, TOSA, STABLEHLO output types);
# the final lowering to LLVM IR happens downstream, e.g. with mlir-opt/mlir-translate
# pipelines or an execution backend such as IREE.
Examining torch.compile's MLIR Output
import torch
import os
# Enable debug output from TorchInductor to see generated code
os.environ["TORCH_COMPILE_DEBUG"] = "1"
os.environ["INDUCTOR_MAX_AUTOTUNE"] = "0"
class SimpleModel(torch.nn.Module):
def forward(self, x, w):
return torch.relu(torch.mm(x, w) + x.sum(dim=1, keepdim=True))
model = SimpleModel()
compiled = torch.compile(model)
x = torch.randn(32, 32, device='cpu')
w = torch.randn(32, 32, device='cpu')
y = compiled(x, w)
# Debug output (readable generated code, including the Triton kernels) is written to
# ./torch_compile_debug/ in the working directory; compiled artifacts are cached
# under /tmp/torchinductor_<user>/ and, for GPU kernels, in ~/.triton/cache
# (which holds the lowered LLVM IR and PTX).
import glob
debug_files = glob.glob("torch_compile_debug/run_*/**/output_code.py", recursive=True)
for f in debug_files[:2]:
print(f"Debug file: {f}")
with open(f) as fp:
print(fp.read()[:2000])
Triton - MLIR-Based GPU Kernel Language
Triton is a Python-based language for writing GPU kernels. It was created by Philippe Tillet (first described in a 2019 paper and open-sourced by OpenAI in 2021). TorchInductor generates Triton code rather than raw CUDA for GPU compilation.
Triton's key insight: program in terms of "blocks" of data, not individual threads. The Triton compiler handles thread assignment, memory coalescing, and vectorization automatically.
import triton
import triton.language as tl
import torch
import time
# Triton kernel: vector addition
# Programs execute as a grid of blocks; each block processes BLOCK_SIZE elements
@triton.jit
def vector_add_kernel(
a_ptr, b_ptr, c_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr # Must be known at compile time
):
# Program ID: which block am I?
pid = tl.program_id(axis=0)
# Compute which elements this block handles
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Mask for out-of-bounds elements (last block may be partial)
mask = offsets < n_elements
# Load from global memory (coalesced access pattern)
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
# Compute
c = a + b
# Store result
tl.store(c_ptr + offsets, c, mask=mask)
def triton_vector_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
c = torch.empty_like(a)
n = a.numel()
BLOCK_SIZE = 1024
# Grid: number of blocks needed
grid = (triton.cdiv(n, BLOCK_SIZE),)
vector_add_kernel[grid](a, b, c, n, BLOCK_SIZE=BLOCK_SIZE)
return c
# More interesting: fused softmax kernel
@triton.jit
def fused_softmax_kernel(
output_ptr, input_ptr,
input_row_stride, output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr
):
"""Compute softmax row-by-row with fusion."""
row_idx = tl.program_id(0)
# Pointer to this row
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask = col_offsets < n_cols
# Load row with -inf for out-of-bounds (safe for max reduction)
row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
# Numerically stable softmax:
# 1. Subtract max to prevent overflow
row_max = tl.max(row, axis=0)
row = row - row_max
# 2. Exponentiate
row_exp = tl.exp(row)
# 3. Divide by sum
row_sum = tl.sum(row_exp, axis=0)
softmax_output = row_exp / row_sum
# Store result
output_row_start_ptr = output_ptr + row_idx * output_row_stride
tl.store(output_row_start_ptr + col_offsets, softmax_output, mask=mask)
def triton_softmax(x: torch.Tensor) -> torch.Tensor:
n_rows, n_cols = x.shape
BLOCK_SIZE = triton.next_power_of_2(n_cols)
y = torch.empty_like(x)
fused_softmax_kernel[(n_rows,)](
y, x,
x.stride(0), y.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE
)
return y
# Benchmark: Triton vs PyTorch eager softmax
if torch.cuda.is_available():
x = torch.randn(4096, 1024, device='cuda', dtype=torch.float32)
# Warmup
for _ in range(5):
triton_softmax(x)
torch.softmax(x, dim=-1)
torch.cuda.synchronize()
N = 100
t = time.perf_counter()
for _ in range(N): torch.softmax(x, dim=-1)
torch.cuda.synchronize()
eager_ms = (time.perf_counter() - t) / N * 1000
t = time.perf_counter()
for _ in range(N): triton_softmax(x)
torch.cuda.synchronize()
triton_ms = (time.perf_counter() - t) / N * 1000
print(f"PyTorch eager softmax: {eager_ms:.3f}ms")
print(f"Triton fused softmax: {triton_ms:.3f}ms")
print(f"Speedup: {eager_ms/triton_ms:.2f}x")
# Verify correctness
ref = torch.softmax(x, dim=-1)
out = triton_softmax(x)
print(f"Max error: {(ref - out).abs().max():.2e}")
IREE - Inference on Edge and Mobile
IREE (Intermediate Representation Execution Environment) is an ML compiler and runtime designed for deployment on edge devices, mobile, and embedded systems. It takes MLIR as input and produces highly optimized binaries for ARM, x86, RISC-V, and Vulkan/Metal GPUs.
IREE's compilation pipeline:
- Accept linalg or torch-mlir MLIR
- Apply polyhedral transformations (tiling for cache locality)
- Vectorize using MLIR vector dialect
- Lower to LLVM for CPU targets
- Lower to SPIR-V for Vulkan targets
- Package with a lightweight runtime
# IREE example (conceptual - install iree-runtime and iree-compiler)
# pip install iree-runtime iree-compiler
"""
import iree.compiler as ireec
import iree.runtime as iree_rt
import numpy as np
# Compile a model from MLIR
mlir_source = '''
func.func @add(%a: tensor<4xf32>, %b: tensor<4xf32>) -> tensor<4xf32> {
%result = arith.addf %a, %b : tensor<4xf32>
return %result : tensor<4xf32>
}
'''
# Compile for CPU (llvm-cpu backend)
flatbuffer = ireec.compile_str(
mlir_source,
target_backends=["llvm-cpu"]
)
# Create runtime instance
config = iree_rt.Config("local-sync")
context = iree_rt.SystemContext(config=config)
vm_module = iree_rt.VmModule.from_flatbuffer(context.instance, flatbuffer)
context.add_vm_module(vm_module)
# Run inference
a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
b = np.array([5.0, 6.0, 7.0, 8.0], dtype=np.float32)
result = context.modules.module["add"](a, b)
print(result) # [6. 8. 10. 12.]
"""
print("""IREE targets:
llvm-cpu: x86, ARM, RISC-V via LLVM
vulkan-spirv: Vulkan GPU (Android, desktop)
metal: Apple GPU (iOS, macOS)
cuda: NVIDIA GPU via NVPTX
webgpu: Web browser GPU via WGSL
""")
TableGen - Defining Ops and Passes
MLIR uses TableGen, a domain-specific language for defining operations, types, and passes. TableGen records are the "source of truth" from which C++ headers, documentation, and serialization code are generated.
// Defining a new MLIR operation in TableGen
// File: ops.td
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
// Define a custom "ml" dialect
def ML_Dialect : Dialect {
let name = "ml";
let summary = "A simple ML ops dialect";
let cppNamespace = "::mlir::ml";
}
// Define a ReLU operation
def ML_ReluOp : Op<ML_Dialect, "relu",
[Pure, SameOperandsAndResultType]> {
let summary = "Rectified Linear Unit activation";
let description = [{
Computes element-wise ReLU: max(0, x)
Example:
%result = ml.relu %input : tensor<32x512xf32>
}];
let arguments = (ins AnyTensor:$input);
let results = (outs AnyTensor:$result);
// Build C++ implementation
let builders = [
OpBuilder<(ins "Value":$input)>
];
// Custom verifier
let verifier = [{ return verifyRelu(*this); }];
// Assembly format for text serialization
let assemblyFormat = "$input attr-dict `:` type($input)";
}
// Define a MatMul operation with tiling attribute
def ML_MatMulOp : Op<ML_Dialect, "matmul",
[AttrSizedOperandSegments]> {
let summary = "Matrix multiplication C += A * B";
let arguments = (ins
AnyMemRef:$A, // Input matrix A
AnyMemRef:$B, // Input matrix B
AnyMemRef:$C, // Output matrix C (accumulates)
OptionalAttr<I64ArrayAttr>:$tile_sizes // Optional tiling hints
);
let results = (outs); // In-place operation
let assemblyFormat = [{
`ins` `(` $A `,` $B `:` type($A) `,` type($B) `)`
`outs` `(` $C `:` type($C) `)`
(`tile_sizes` `=` $tile_sizes^)?
attr-dict
}];
}
# Generate C++ from TableGen
mlir-tblgen --gen-op-decls ops.td -I $(llvm-config --includedir) > ops.h.inc
mlir-tblgen --gen-op-defs ops.td -I $(llvm-config --includedir) > ops.cpp.inc
mlir-tblgen --gen-dialect-decls ops.td -I $(llvm-config --includedir) > dialect.h.inc
How Compiler Passes Improve ML Performance
The performance benefits from compiler passes are concrete and measurable. Let us trace a specific example: layer_norm(x) + bias in an attention block.
Transformation 1: Fusion (Inductor/Triton)
- Before: layer_norm kernel, add kernel, bias_add kernel
- After: one fused Triton kernel
- Benefit: 3 HBM reads become 1; intermediate results stay in registers
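You can observe this fusion effect directly with torch.compile. The sketch below is illustrative rather than definitive - the exact kernels Inductor generates and the measured speedup depend on the PyTorch version and GPU:
import time
import torch
import torch.nn.functional as F

def layernorm_bias(x, weight, bias, extra_bias):
    # Two logical ops: layer_norm followed by an element-wise bias add
    return F.layer_norm(x, x.shape[-1:], weight, bias) + extra_bias

compiled = torch.compile(layernorm_bias)

if torch.cuda.is_available():
    x = torch.randn(4096, 1024, device="cuda")
    w = torch.ones(1024, device="cuda")
    b = torch.zeros(1024, device="cuda")
    eb = torch.randn(1024, device="cuda")

    for _ in range(5):                       # warmup (also triggers compilation)
        layernorm_bias(x, w, b, eb)
        compiled(x, w, b, eb)
    torch.cuda.synchronize()

    def bench(fn, reps=100):
        t = time.perf_counter()
        for _ in range(reps):
            fn(x, w, b, eb)
        torch.cuda.synchronize()
        return (time.perf_counter() - t) / reps * 1000

    print(f"eager (separate kernels): {bench(layernorm_bias):.3f} ms")
    print(f"torch.compile (fused):    {bench(compiled):.3f} ms")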
Transformation 2: Tiling (MLIR linalg)
- Before: matmul reads entire matrices from HBM per row of output
- After: tiled matmul loads 32x32 blocks into L1/L2 cache and reuses them
- Benefit: HBM reads become block loads
- The arithmetic: the FLOP count is unchanged, but each 32x32 tile of A and B is loaded from HBM once and reused 32 times from cache, cutting memory traffic by roughly the tile dimension.
Transformation 3: Vectorization (LLVM loop-vectorize)
- Before: scalar loop for i: result[i] = a[i] + b[i]
- After: vaddps ymm0, ymm1, ymm2 - processes 8 floats per instruction with AVX2
- Benefit: 8x throughput for element-wise operations on CPU
# Demonstrate the benefit of tiling empirically
import numpy as np
import time
def naive_matmul(A, B):
"""Pure Python - shows the algorithm."""
m, k = A.shape
k2, n = B.shape
C = np.zeros((m, n), dtype=np.float32)
for i in range(m):
for j in range(n):
for l in range(k):
C[i, j] += A[i, l] * B[l, j]
return C
def tiled_matmul(A, B, tile_m=32, tile_n=32, tile_k=32):
"""Tiled matmul - same algorithm, better cache behavior."""
m, k = A.shape
k2, n = B.shape
C = np.zeros((m, n), dtype=np.float32)
for i0 in range(0, m, tile_m):
for j0 in range(0, n, tile_n):
for k0 in range(0, k, tile_k):
# Load tile into local cache (L1/L2 will hold these)
i_end = min(i0 + tile_m, m)
j_end = min(j0 + tile_n, n)
k_end = min(k0 + tile_k, k)
# This inner matmul reuses the A/B tiles many times
C[i0:i_end, j0:j_end] += (
A[i0:i_end, k0:k_end] @
B[k0:k_end, j0:j_end]
)
return C
# For real comparison, use NumPy (BLAS) which does tiling internally
A = np.random.randn(512, 256).astype(np.float32)
B = np.random.randn(256, 512).astype(np.float32)
t = time.perf_counter()
for _ in range(10): C = np.dot(A, B) # BLAS - tiled + vectorized
blas_time = (time.perf_counter() - t) / 10
print(f"NumPy BLAS (tiled + vectorized): {blas_time*1000:.2f}ms")
print(f"This is what MLIR tiling achieves automatically for ML ops")
# LLVM auto-vectorization: show generated SIMD for a simple loop
import numba as nb # numba uses LLVM's vectorizer
@nb.njit
def vectorized_add(a, b, out):
"""LLVM will auto-vectorize this with AVX2/AVX-512."""
for i in range(len(a)):
out[i] = a[i] + b[i]
# To see the generated SIMD instructions:
a = np.random.randn(1024).astype(np.float32)
b = np.random.randn(1024).astype(np.float32)
out = np.empty_like(a)
vectorized_add(a, b, out)
# Check the assembly for SIMD instructions (ymm/zmm registers = AVX2/AVX-512)
asm_dict = vectorized_add.inspect_asm()
for sig, asm in asm_dict.items():
ymm_count = asm.count('ymm') # AVX2 256-bit SIMD
zmm_count = asm.count('zmm') # AVX-512 512-bit SIMD
print(f"SIMD instructions: {ymm_count} YMM (AVX2) + {zmm_count} ZMM (AVX-512)")
Production Engineering Notes
When to read LLVM IR directly. Reading the LLVM IR output is useful when: debugging why a loop is not auto-vectorizing (look for <4 x float> types indicating successful vectorization vs scalar float indicating failure), understanding memory access patterns, and verifying that inlining happened.
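With numba-compiled functions, that check can be done from Python: inspect_llvm() returns the optimized LLVM IR, and vector types such as <8 x float> indicate the loop vectorizer fired. A small sketch (the reported vector width depends on your CPU):
import numba as nb
import numpy as np

@nb.njit
def scale(a, out):
    for i in range(a.shape[0]):
        out[i] = a[i] * 2.0

a = np.random.randn(1024).astype(np.float32)
out = np.empty_like(a)
scale(a, out)  # trigger compilation

# inspect_llvm() returns the LLVM IR per compiled signature.
for sig, ir_text in scale.inspect_llvm().items():
    # Lines containing vector types like <8 x float> mean loop-vectorize succeeded;
    # only scalar `float` operations would mean it did not fire.
    vectorized = [line for line in ir_text.splitlines() if "x float>" in line]
    print(sig, "vectorized lines:", len(vectorized))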
The Triton autotuning system. TorchInductor uses Triton's @triton.autotune to try multiple kernel implementations with different tile sizes and occupancy configurations, then selects the fastest. This is what mode="max-autotune" in torch.compile triggers.
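Written by hand, that autotuning looks roughly like the sketch below; the kernel and the configuration values are illustrative, not tuned for any particular GPU:
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256},  num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 2048}, num_warps=8),
    ],
    key=["n_elements"],   # re-tune when this argument changes
)
@triton.jit
def scaled_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0 + y, mask=mask)

# On first launch Triton benchmarks each config and caches the fastest one
# for this value of n_elements.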
MLIR debugging tools. mlir-opt --mlir-print-ir-after-all prints the IR after every pass. mlir-opt --debug-only=dialect-conversion shows the conversion rules that fire. mlir-opt --dump-pass-pipeline prints the textual form of the pass pipeline being run.
Hardware-specific lowering. MLIR's nvgpu dialect exposes NVIDIA Tensor Core operations (nvgpu.warpgroup.mma) which map to ldmatrix/mma.sync PTX instructions. These give 2-4x speedup over naive matmul on Ampere/Hopper GPUs. TorchInductor uses these for large matrix multiplies.
# Check what Triton sees as its compilation target
import torch
import triton
import triton.language as tl
# The target determines which MLIR lowering path is used
target = triton.runtime.driver.active.get_current_target()
print(f"Triton target: {target}")
# E.g.: GPUTarget(backend='cuda', arch=90, warp_size=32)
# arch=90 = H100 (sm_90) - enables hopper-specific features
# Examine a compiled Triton kernel's LLVM IR
@triton.jit
def simple_kernel(x_ptr, y_ptr, n: tl.constexpr):
pid = tl.program_id(0)
offs = pid * n + tl.arange(0, n)
x = tl.load(x_ptr + offs)
tl.store(y_ptr + offs, x * 2.0)
if torch.cuda.is_available():
x = torch.randn(1024, device='cuda')
y = torch.empty_like(x)
simple_kernel[(1,)](x, y, n=1024)
# Get the compiled LLVM IR and PTX
# (Triton caches compilation artifacts)
print("Compilation succeeds - check ~/.triton/cache for artifacts")
:::tip Reading MLIR Output from torch.compile
Set TORCH_COMPILE_DEBUG=1 before your Python script. Inductor writes readable debug output (FX graphs and the generated Triton/C++ source in output_code.py) to ./torch_compile_debug/ and caches compiled artifacts under /tmp/torchinductor_<user>/; Triton's own cache in ~/.triton/cache holds the lowered LLVM IR and PTX for each kernel. Looking at the generated code and LLVM IR is the fastest way to understand why a particular optimization did or did not fire.
:::
:::danger MLIR Dialect Version Mismatches
MLIR dialects evolve rapidly. Code that serializes MLIR (saves .mlir files or uses MLIR for model storage) can break between versions. The MLIR team does not guarantee serialization compatibility between versions. If you are building a deployment pipeline around MLIR, pin your LLVM/MLIR version and test carefully on version updates.
:::
Interview Q&A
Q1: What is LLVM IR and why is it important? What key properties does it have?
LLVM IR (Intermediate Representation) is a typed, target-independent assembly-like language that sits between language frontends (Clang for C++, rustc for Rust, numba for Python) and machine code backends (x86, ARM, NVPTX). Its importance: any language that compiles to LLVM IR gets access to 50+ optimization passes and code generation for every supported target architecture. Write the optimization once; every language benefits.
Key properties:
- SSA (Static Single Assignment): every variable is assigned exactly once. Uses PHI nodes at control flow join points to represent variables with multiple definitions. SSA makes dataflow analysis trivially explicit.
- Strong typing: every value and operation has an explicit type (i32, float, <4 x float> for SIMD). This enables type-based aliasing analysis (TBAA).
- Target-independent: the IR describes operations without specifying machine registers. The backend handles register allocation and instruction selection.
- Explicit memory model: loads and stores are explicit operations. There is no implicit mutation. This enables memory dependency analysis and alias analysis.
Q2: What is SSA form and why do compilers use it?
SSA (Static Single Assignment) requires that every variable has exactly one assignment in the program text. A variable defined in multiple branches (like a loop counter that increments) is split into separate SSA values, with a PHI node at the merge point to select the right value.
Why compilers use it: SSA makes the data flow graph explicit. Every use of %x points to exactly one definition. This enables efficient algorithms for:
- Constant propagation: if %x is the constant 42, substitute 42 everywhere %x is used
- Dead code elimination: if %x has no uses, the assignment can be removed
- Redundant computation elimination: if %x = a + b and %y = a + b with no intervening modification, %y can be replaced by %x
- Register allocation: the SSA def-use chains directly correspond to live ranges
Without SSA, these analyses require iterative dataflow algorithms over all possible assignment-use paths. SSA makes them linear-time in most cases.
Q3: What is MLIR and how does it extend LLVM's design philosophy?
MLIR (Multi-Level Intermediate Representation) extends LLVM's core insight (shared IR enables shared optimizations) to multiple abstraction levels. LLVM IR sits at one level - close to machine instructions, without knowledge of tensor operations or loop structure. ML compilers need to optimize at higher levels: tensor algebra, named linear algebra operations, polyhedral loop nests.
MLIR introduces "dialects" - extensible sets of operations, types, and attributes that coexist in the same IR. A linalg.matmul operation can be tiled and fused at the linalg level (when its structure as a matrix multiply is still visible), then lowered to affine.for loops (when loop structure is needed for polyhedral analysis), then to vector operations (for SIMD), then to llvm dialect (for final codegen).
The key innovation: the same IR infrastructure handles all these levels. The same pass infrastructure, the same verifier, the same textual format. You write transformations at the most natural level. Information is not lost prematurely.
Q4: Explain what Triton is and how it relates to MLIR and LLVM.
Triton is a Python-based language for writing GPU kernels at a higher level than CUDA. Instead of thinking in individual GPU threads, you think in "blocks" of data. The Triton compiler handles thread assignment, memory coalescing, and shared memory management.
The compilation pipeline: Python Triton code is parsed into Triton IR (MLIR-based). Triton IR is lowered through multiple MLIR dialects: triton dialect to triton_gpu dialect (with explicit GPU thread/warp/block structure), to LLVM dialect, to LLVM IR, to NVPTX (NVIDIA GPU assembly), which the CUDA driver JIT-compiles to device-specific binary.
Triton's importance for ML: TorchInductor generates Triton code as its primary GPU backend. When you use torch.compile, the fused operations are compiled to Triton kernels. Triton provides an abstraction layer between PyTorch's graph-level IR and raw CUDA: safe enough to generate automatically, low enough to be fast.
Q5: How does MLIR enable operator fusion? Walk through a concrete example.
Operator fusion in MLIR happens by transforming the IR before individual operations are lowered to kernels.
Example: softmax(relu(linear(x))) without fusion generates 3 kernels. With MLIR fusion:
- At the linalg dialect level, the ops are linalg.matmul (for linear), linalg.generic (for relu), linalg.generic (for softmax).
- MLIR's element-wise fusion analysis detects that relu's output is consumed only by softmax, and they have compatible access patterns.
- A "producer-consumer fusion" pass merges the two linalg.generic ops into one, with the relu computation inlined into the softmax loop body.
- After lowering to affine.for loops, the fused loop reads the matmul output once, applies relu and softmax in a single pass, and writes the result.
Without fusion: the matmul output is written to HBM, relu reads it back and writes its result to HBM, and softmax reads that result from HBM again. Two intermediate tensors are each written to and read back from HBM.
With fusion: matmul output is consumed immediately in registers. Only the final softmax result is written to HBM. One HBM write.
For bandwidth-bound operations on modern GPUs (H100 SXM: 3.35 TB/s HBM3 bandwidth, 1979 TFLOPS FP16 with sparsity), eliminating the extra HBM round trips can be worth more than optimizing the compute.
Q6: What is the difference between the linalg, affine, and vector MLIR dialects? Why are all three needed?
Each dialect captures different structural information needed for different optimization levels:
linalg dialect: named linear algebra operations with explicit semantics (linalg.matmul knows it is a matrix multiply, not just a triple-nested loop). This structural knowledge enables: tiling for cache blocking, packing for layout optimization, fusion analysis, parallelism extraction. A pass that sees linalg.matmul knows it can tile it and choose tile sizes based on L1/L2 cache sizes.
affine dialect: structured loop nests with affine access patterns (array indices that are linear functions of loop induction variables: A[2*i + 3*j + 1]). This level enables polyhedral analysis: dependency analysis, loop transformations (interchange, skewing, unrolling), and automatic parallelization. An affine.for loop is more constrained than a general scf.for loop, and that constraint enables stronger analysis.
vector dialect: fixed-length SIMD operations (vector.contract, vector.transfer_read). This is the abstraction for auto-vectorization. A vector.contract maps directly to SIMD hardware instructions (AVX-512 vfmadd, Tensor Cores mma.sync). This level is above specific SIMD instruction sets but below general loop structure.
You need all three because information at the higher level enables analysis that is impossible or expensive at the lower level. By the time you reach vector ops, the loop structure that enabled fusion is gone. By the time you reach affine, the high-level op semantics that enabled fusion are gone.
Q7: How does a compiler pass like loop-vectorize work? What code patterns enable auto-vectorization?
LLVM's loop-vectorize pass attempts to convert scalar loops into SIMD loops. It:
- Analyzes loop dependences: can iterations run in parallel? (No loop-carried dependence through memory)
- Analyzes memory access patterns: are accesses sequential (stride-1)? Stride-1 enables efficient vmovups loads.
- Determines safe vector width: what SIMD width is profitable on the target (128-bit SSE, 256-bit AVX2, 512-bit AVX-512)?
- Generates vector code with peeling for the leftover elements (when n is not a multiple of vector width).
Code patterns that enable vectorization:
// GOOD: sequential access, no dependence
for (int i = 0; i < n; i++) c[i] = a[i] + b[i];
// BAD: loop-carried dependence (each iteration depends on previous)
for (int i = 1; i < n; i++) a[i] = a[i-1] * 2;
// BAD: indirect addressing (gather/scatter - expensive)
for (int i = 0; i < n; i++) c[i] = a[idx[i]]; // idx is variable
// BAD: data-dependent branching (SIMD has limited predication)
for (int i = 0; i < n; i++) if (a[i] > 0) c[i] = a[i]; else c[i] = 0;
// FIX: use branchless form: c[i] = max(a[i], 0)
For ML workloads: ensure element-wise operations on contiguous memory with no cross-element dependencies. Avoid indirect indexing in hot loops. Use float32 or float16 (not float64) to fit more elements in SIMD registers.
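A quick micro-benchmark of that last point, using numba so LLVM's vectorizer is in the loop (the helper names and sizes are invented for illustration; exact ratios depend on the CPU and cache behavior):
import time
import numba as nb
import numpy as np

@nb.njit
def direct_add(a, b, out):
    for i in range(a.shape[0]):
        out[i] = a[i] + b[i]          # stride-1 access: vectorizes

@nb.njit
def gather_add(a, b, idx, out):
    for i in range(a.shape[0]):
        out[i] = a[idx[i]] + b[i]     # indirect load: gather or scalar code

n = 1 << 20
a = np.random.randn(n).astype(np.float32)
b = np.random.randn(n).astype(np.float32)
idx = np.random.permutation(n).astype(np.int64)
out = np.empty_like(a)

direct_add(a, b, out); gather_add(a, b, idx, out)   # compile/warmup

def bench(fn, *args, reps=20):
    t = time.perf_counter()
    for _ in range(reps):
        fn(*args)
    return (time.perf_counter() - t) / reps * 1000

print(f"direct: {bench(direct_add, a, b, out):.2f} ms")
print(f"gather: {bench(gather_add, a, b, idx, out):.2f} ms")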
