
Module 4: Kernel Optimization

FlashAttention is not faster because it does less arithmetic. It does the same arithmetic as standard attention. It is faster because it restructures the computation to minimize the number of times data moves between HBM and the SM's shared memory. That is the central insight of kernel optimization: the bottleneck is usually memory, not compute.

This module teaches you to think like a kernel author. You will learn why certain operations are slow, how to reason about memory access patterns, how to use Triton to write custom kernels without writing CUDA, and how torch.compile and XLA fit into the optimization story.

Why Kernels Matter

When you call F.scaled_dot_product_attention() in PyTorch, you are calling a kernel - a piece of GPU code that has been carefully written to use the hardware efficiently. When you instead spell attention out yourself - a matmul for the scores, a softmax, and a second matmul with the values - you are dispatching three separate kernels, with HBM writes and reads of the intermediate results between each one.
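
To make the contrast concrete, here is a minimal sketch, assuming a CUDA device and a PyTorch 2.x build that ships scaled_dot_product_attention; the shapes are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, heads, seq_len, head_dim), for illustration only.
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

def naive_attention(q, k, v):
    # Three separate kernels; the seq_len x seq_len score matrix is written
    # to HBM after the first matmul, read back for the softmax, then written
    # and read again before the second matmul.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    probs = scores.softmax(dim=-1)
    return probs @ v

# One fused kernel; PyTorch picks a FlashAttention-style backend when it can.
fused = F.scaled_dot_product_attention(q, k, v)

# Same math, very different data movement.
torch.testing.assert_close(naive_attention(q, k, v), fused, atol=2e-2, rtol=1e-2)
```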

FlashAttention fuses these operations and tiles the computation to keep the working set in SRAM. The result is 2-4x faster attention and O(N) memory instead of O(N^2), because the N x N score matrix is never materialized in HBM. The math is identical. The hardware utilization is completely different.
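
A rough way to see the memory side of that claim, assuming a CUDA device: the sketch below uses PyTorch's peak-memory counters and is an illustration, not a rigorous benchmark.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def peak_mib(fn):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# Naive attention materializes the 4096 x 4096 score matrix per head in HBM.
naive = peak_mib(lambda: ((q @ k.transpose(-2, -1) / 8.0).softmax(-1) @ v))
# The fused kernel only ever keeps tiles of that matrix in on-chip memory.
fused = peak_mib(lambda: F.scaled_dot_product_attention(q, k, v))
print(f"naive peak: {naive:.0f} MiB, fused peak: {fused:.0f} MiB")
```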

The same principle applies everywhere: every time you can avoid writing to HBM and reading back from HBM, you save time. Operator fusion, tiling, and memory layout choices are all variations on this theme.
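
For example, a chain of elementwise ops is a textbook fusion target. The sketch below assumes PyTorch 2.x with torch.compile and a CUDA device; the bias-plus-GELU-plus-scale function is just an illustrative stand-in.

```python
import torch
import torch.nn.functional as F

def bias_gelu_scale(x, bias):
    # Eagerly, each line launches its own kernel, and each intermediate
    # makes a full round trip through HBM.
    y = x + bias
    y = F.gelu(y)
    return y * 2.0

# torch.compile (TorchInductor) can fuse the chain into one kernel, so the
# inputs are read from HBM once and only the final result is written back.
fused = torch.compile(bias_gelu_scale)

x = torch.randn(4096, 4096, device="cuda")
bias = torch.randn(4096, device="cuda")
torch.testing.assert_close(fused(x, bias), bias_gelu_scale(x, bias),
                           rtol=1e-4, atol=1e-4)
```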

The Optimization Stack

Lessons in This Module

#  Lesson                                Key Concept
1  What Is a Kernel and Why It Matters   Kernel dispatch, the gap between PyTorch and hardware
2  FlashAttention Deep Dive              Tiling, online softmax, memory bandwidth analysis
3  Triton Language for Custom Kernels    Writing Triton kernels, blocked programming model (sketch below)
4  Fused Operations and Operator Fusion  Why fusion helps, what torch.compile fuses
5  Memory Bandwidth vs Compute Bound     Arithmetic intensity, roofline for real operations
6  Kernel Benchmarking and Profiling     triton.testing.do_bench, Nsight, comparing kernels
7  torch.compile and TorchInductor       Compilation modes, when it helps, when it hurts
8  XLA and JAX Compilation               XLA's approach, graph-level optimization, TPU targeting
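
As a preview of the Triton material in Lesson 3, here is the canonical vector-add kernel in Triton's blocked programming model: each program instance owns one block of elements and a mask guards the ragged tail. A sketch, assuming triton is installed alongside a CUDA-enabled PyTorch.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                    # which block this instance owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                    # guard the final partial block
    x = tl.load(x_ptr + offsets, mask=mask)        # load one tile from HBM
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)  # write one tile back

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)                 # one program per 1024-element block
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

x = torch.randn(10_000, device="cuda")
y = torch.randn_like(x)
torch.testing.assert_close(add(x, y), x + y)
```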

Key Concepts You Will Master

  • Arithmetic intensity and the roofline - determining whether an operation is memory or compute bound (see the worked sketch after this list)
  • Tiled matrix multiply - the technique that makes efficient GPU matrix ops possible
  • FlashAttention algorithm - how online softmax enables O(N)-memory attention without materializing the full score matrix
  • Triton programming model - blocked iteration, loading/storing tiles, writing your first custom kernel
  • torch.compile - understanding the dynamo/inductor pipeline and interpreting compilation failures
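
To make the roofline idea concrete, here is a back-of-envelope sketch. The peak-throughput and bandwidth numbers are illustrative placeholders, not the specs of any particular GPU.

```python
# Hypothetical accelerator specs (placeholders, not real measurements).
PEAK_FLOPS = 300e12               # 300 TFLOP/s of matrix throughput (assumed)
PEAK_BYTES = 2e12                 # 2 TB/s of HBM bandwidth (assumed)
RIDGE = PEAK_FLOPS / PEAK_BYTES   # FLOPs per byte where the roofline bends

def matmul_intensity(m, n, k, bytes_per_elem=2):
    """Arithmetic intensity of an (m,k) @ (k,n) matmul in FLOPs per byte of HBM traffic."""
    flops = 2 * m * n * k                                   # one multiply-add per m*n*k
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)  # read A and B, write C once
    return flops / bytes_moved

for size in (128, 1024, 8192):
    ai = matmul_intensity(size, size, size)
    verdict = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{size:>5}^3 matmul: {ai:7.1f} FLOPs/byte vs ridge {RIDGE:.0f} -> {verdict}")
```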

Prerequisites

© 2026 EngineersOfAI. All rights reserved.