Module 4: Kernel Optimization
FlashAttention is not faster because it does less arithmetic. It does the same arithmetic as standard attention. It is faster because it restructures the computation to minimize the number of times data moves between HBM and the SM's shared memory. That is the central insight of kernel optimization: the bottleneck is usually memory, not compute.
This module teaches you to think like a kernel author. You will learn why certain operations are slow, how to reason about memory access patterns, how to use Triton to write custom kernels without writing CUDA, and how torch.compile and XLA fit into the optimization story.
Why Kernels Matter
When you call F.scaled_dot_product_attention() in PyTorch, you are calling a kernel - a piece of GPU code that has been carefully written to use the hardware efficiently. When you instead write attention by hand - scores = q @ k.transpose(-2, -1), then a softmax, then another matmul with v - you dispatch a matmul kernel, a softmax kernel, and a second matmul kernel, with HBM writes and reads of the intermediate results between each one.
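A minimal sketch of that gap, assuming a CUDA GPU and recent PyTorch (the shapes are arbitrary): the hand-written version materializes the full seq x seq score matrix in HBM between kernels, while the fused call does not.

```python
import torch
import torch.nn.functional as F

q = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)  # (batch, heads, seq, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Naive attention: each line dispatches at least one kernel, and the
# intermediate (seq x seq) score matrix round-trips through HBM.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # matmul kernel
weights = torch.softmax(scores, dim=-1)                   # softmax kernel
out_naive = weights @ v                                    # matmul kernel

# Fused attention: dispatches to a single kernel (FlashAttention on
# supported GPUs) that never writes the full score matrix to HBM.
out_fused = F.scaled_dot_product_attention(q, k, v)

print((out_naive - out_fused).abs().max())  # small fp16 rounding difference
```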
FlashAttention fuses these operations and tiles the computation to keep the working set in SRAM. The result is 2-4x faster attention and O(N) memory instead of O(N^2). The math is identical. The hardware utilization is completely different.
The same principle applies everywhere: every round trip to HBM that you avoid saves time. Operator fusion, tiling, and memory layout choices are all variations on this theme.
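As a small illustration of fusion (a sketch assuming torch.compile is available; the function and shapes are hypothetical): run eagerly, each pointwise op below launches its own kernel with an HBM round trip for its intermediate, while the compiled version lets TorchInductor generate one fused kernel for the whole chain.

```python
import torch
import torch.nn.functional as F

def bias_gelu_scale(x, bias, scale):
    # Eagerly, each of these three pointwise ops is a separate kernel.
    y = x + bias
    y = F.gelu(y)
    return y * scale

# torch.compile can fuse the elementwise chain into a single generated
# kernel, so intermediates stay in registers instead of HBM.
fused = torch.compile(bias_gelu_scale)

x = torch.randn(4096, 4096, device="cuda")
bias = torch.randn(4096, device="cuda")
out = fused(x, bias, 0.5)
```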
The Optimization Stack
Lessons in This Module
| # | Lesson | Key Concept |
|---|---|---|
| 1 | What Is a Kernel and Why It Matters | Kernel dispatch, the gap between PyTorch and hardware |
| 2 | FlashAttention Deep Dive | Tiling, online softmax, memory bandwidth analysis |
| 3 | Triton Language for Custom Kernels | Writing Triton kernels, blocked programming model |
| 4 | Fused Operations and Operator Fusion | Why fusion helps, what torch.compile fuses |
| 5 | Memory Bandwidth vs Compute Bound | Arithmetic intensity, roofline for real operations |
| 6 | Kernel Benchmarking and Profiling | triton.testing.do_bench, Nsight, comparing kernels |
| 7 | torch.compile and TorchInductor | Compilation modes, when it helps, when it hurts |
| 8 | XLA and JAX Compilation | XLA's approach, graph-level optimization, TPU targeting |
Key Concepts You Will Master
- Arithmetic intensity and the roofline - determining whether an operation is memory or compute bound (see the worked example after this list)
- Tiled matrix multiply - the technique that makes efficient GPU matrix ops possible
- FlashAttention algorithm - how online softmax enables attention with O(N) memory instead of O(N^2)
- Triton programming model - blocked iteration, loading/storing tiles, writing your first custom kernel (a minimal kernel sketch follows below)
- torch.compile - understanding the dynamo/inductor pipeline and interpreting compilation failures
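To make the arithmetic-intensity bullet concrete, here is a back-of-the-envelope calculation. The FLOP and byte counts are the standard idealized ones (ignoring caching), and the ~156 FLOP/byte balance point is an illustrative A100-class figure (~312 fp16 TFLOP/s over ~2 TB/s HBM), not a measured value.

```python
def matmul_intensity(m, n, k, bytes_per_elem=2):
    flops = 2 * m * n * k                                    # multiply-adds
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)   # read A and B, write C
    return flops / bytes_moved

def elementwise_intensity(numel, bytes_per_elem=2):
    flops = numel                              # one op per element
    bytes_moved = bytes_per_elem * 2 * numel   # read input, write output
    return flops / bytes_moved

print(matmul_intensity(4096, 4096, 4096))  # ~1365 FLOP/byte -> compute bound
print(elementwise_intensity(4096 * 4096))  # 0.25 FLOP/byte  -> memory bound
```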
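And for the Triton bullet, a minimal kernel sketch in the style of the standard vector-add tutorial (assumes a CUDA GPU with triton installed; BLOCK_SIZE=1024 is an arbitrary choice).

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide tile of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements          # guard the final, partial tile
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)       # one program per 1024-element tile
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

x = torch.randn(100_000, device="cuda")
y = torch.randn(100_000, device="cuda")
torch.testing.assert_close(add(x, y), x + y)
```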
Prerequisites
- GPU Architecture
- CUDA Programming - helpful but not required
- PyTorch proficiency
