Flash Attention Kernel Deep Dive
Reading time: ~45 min · Interview relevance: Very High · Target roles: ML Systems Engineer, CUDA Developer, Research Engineer
Standard attention materializes the full N x N attention matrix in HBM, which costs O(N²) memory and limits context length. FlashAttention never writes that matrix to HBM at all, reducing the extra memory to O(N) while computing exactly the same result. The same 80GB A100 that runs out of memory at 16k tokens with standard attention can handle 128k tokens with FlashAttention.
The Production Scenario
It is 11pm and the inference team is trying to push context length from 4096 to 16384 tokens for the new document-summarization product. The A100 has 80GB of HBM. Someone runs the benchmark. It OOMs immediately.
The math is simple and brutal. Standard attention with N=16384 tokens and d=64 per head produces an attention matrix S of shape [16384, 16384]. In float32 that is $16384^2 \times 4$ bytes = 1.07 GB per attention head. With 32 heads and 32 transformer layers, the attention matrices alone need about 1.1 TB. The GPU has 80GB.
Even at reasonable context lengths the situation is bad. N=4096, d=64, 32 heads: $4096^2 \times 4$ bytes × 32 ≈ 2.1 GB just for the S matrices in a single layer. A single pass of 2.1 GB through HBM at 2 TB/s takes about 1 ms, and standard attention makes several such passes per layer, purely for memory traffic on a matrix the model never needs to keep.
FlashAttention (Dao et al., 2022) solved this by asking a different question: what if we never wrote the full N x N matrix at all? What if we computed the output directly from Q, K, V using tiles, keeping all intermediate values in SRAM? The result is exact - not approximate - and uses O(N) memory instead of O(N²). The same A100 that failed at 16k tokens now runs 128k context.
This lesson explains exactly how that works: the online softmax trick, the tiling algorithm, IO complexity analysis, and the FlashAttention 2 and 3 improvements.
Why This Exists - The Problem with Standard Attention
The Standard Attention Algorithm
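Recall the attention mechanism. Given queries Q, keys K, values V of shape $N \times d$, where N is sequence length and d is head dimension:
$$S = QK^\top \in \mathbb{R}^{N \times N}, \qquad P = \operatorname{softmax}(S) \in \mathbb{R}^{N \times N}, \qquad O = PV \in \mathbb{R}^{N \times d}$$
with the softmax applied row-wise (the usual $1/\sqrt{d}$ scaling of S is omitted for brevity).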
This is three separate operations. Each one reads from HBM and writes back to HBM. Here is the memory traffic breakdown:
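| Operation | Reads from HBM | Writes to HBM | Note |
|---|---|---|---|
| $S = QK^\top$ | Q: $N \times d$, K: $N \times d$ | S: $N \times N$ | S is the problem |
| $P = \operatorname{softmax}(S)$ | S: $N \times N$ | P: $N \times N$ | Reads and re-writes $N^2$ |
| $O = PV$ | P: $N \times N$, V: $N \times d$ | O: $N \times d$ | Reads $N^2$ again |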
The S and P matrices each cost $N^2$ floats. For N=4096 and float32, that is $4096^2 \times 4$ bytes ≈ 67 MB per matrix per head.
With 32 heads: about 2.1 GB for a single N x N matrix per head in one layer. For a 32-layer model: about 69 GB. That is roughly 86% of an A100's 80 GB of HBM, and that's before weights, activations, or anything else.
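The figures above follow from simple byte counting. A minimal sketch, assuming float32 (4 bytes per element) and counting a single N x N matrix per head and layer; `attention_matrix_bytes` is a hypothetical helper, not part of any library:

```python
def attention_matrix_bytes(n_tokens, bytes_per_elem=4, n_heads=1, n_layers=1):
    """Bytes to hold one N x N score matrix for every head in every layer."""
    return n_tokens * n_tokens * bytes_per_elem * n_heads * n_layers

GB = 1e9
HBM_BYTES = 80 * GB  # A100 80GB

per_head = attention_matrix_bytes(4096)                            # 1 head, fp32
per_layer = attention_matrix_bytes(4096, n_heads=32)               # 32 heads
per_model = attention_matrix_bytes(4096, n_heads=32, n_layers=32)  # 32 layers

print(f"per head : {per_head / 1e6:.1f} MB")
print(f"per layer: {per_layer / GB:.2f} GB")
print(f"per model: {per_model / GB:.1f} GB ({per_model / HBM_BYTES:.0%} of 80 GB)")
print(f"N=16384, one head: {attention_matrix_bytes(16384) / GB:.2f} GB")
```

It prints about 67.1 MB per head, 2.15 GB per layer, 68.7 GB (86% of 80 GB) per model, and 1.07 GB per head for the 16k case.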
Why Standard Attention is Memory-Bound, Not Compute-Bound
The roofline model tells us when a kernel is compute-bound versus memory-bound. The arithmetic intensity of standard attention is low because most of the work is moving the N x N matrix around.
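For $S = QK^\top$: the operation requires $N^2 d$ multiply-adds ($2N^2 d$ FLOPs), but reads $2Nd$ floats and writes $N^2$ floats. When N is large and d is small (d=64 is typical), the write dominates and intensity drops (float32, 4 bytes per element):
$$I = \frac{2N^2 d}{4\,(2Nd + N^2)} \approx \frac{2N^2 d}{4N^2} = \frac{d}{2} = 32 \ \text{FLOP/byte}$$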
The A100 can do 312 TFLOPS but only sustains 2 TB/s HBM bandwidth. The ratio is 156 FLOP/byte. Standard attention's 32 FLOP/byte is well below this - the kernel is memory-bound by a factor of 5x. The GPU's compute units sit idle, waiting for HBM to deliver the next chunk of the N x N matrix.
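The intensity estimate can be reproduced directly. This sketch counts only the $S = QK^\top$ step, uses 4-byte float32 elements as the text does, and compares against the A100's dense FP16 Tensor Core peak and roughly 2 TB/s of HBM bandwidth; `matmul_qk_intensity` is an illustrative helper:

```python
def matmul_qk_intensity(n, d, bytes_per_elem=4):
    """FLOP/byte for S = Q @ K^T: 2*N^2*d FLOPs over reading Q, K and writing S."""
    flops = 2 * n * n * d                                # one multiply + one add per term
    bytes_moved = (2 * n * d + n * n) * bytes_per_elem   # read Q and K, write S
    return flops / bytes_moved

A100_PEAK_FLOPS = 312e12  # dense FP16 Tensor Core peak
A100_HBM_BW = 2.0e12      # ~2 TB/s HBM bandwidth
ridge = A100_PEAK_FLOPS / A100_HBM_BW  # FLOP/byte needed to become compute-bound

for n in (1024, 4096, 16384):
    ai = matmul_qk_intensity(n, d=64)
    print(f"N={n:6d}: {ai:5.1f} FLOP/byte, ridge={ridge:.0f}, "
          f"memory-bound by ~{ridge / ai:.1f}x")
```

At N=4096 it reports about 31 FLOP/byte against a ridge point of 156, and the value approaches d/2 = 32 as N grows.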
:::note What Memory-Bound Really Means
"Memory-bound" means the GPU computes so fast that it is starved for data. The arithmetic logic units finish their work and then wait - sometimes hundreds of nanoseconds - for HBM to deliver the next batch of numbers. Reducing the FLOP count does nothing. The only way to go faster is to reduce the number of bytes that cross the HBM-to-chip bus.
:::
Historical Context
The Pre-FlashAttention Era
Before 2022, the standard approach to long-context attention was approximation. Sparse attention (Child et al., 2019), a variant of which GPT-3 used in alternating layers, restricted each token to a fixed sparse pattern of positions. Linear attention (Katharopoulos et al., 2020) replaced softmax with kernel feature maps to get O(N) complexity. Longformer (Beltagy et al., 2020) used sliding windows plus global tokens.
All of these changed what the model computed. The outputs were not the same as standard attention - they were approximations trading accuracy for speed.
The Online Softmax Insight
The key mathematical insight that makes FlashAttention possible was actually discovered earlier. Milakov and Gimelshein (2018) at NVIDIA published "Online normalizer calculation for softmax" - a way to compute softmax incrementally without seeing all values first. You process inputs in chunks, maintaining a running maximum and running sum, and produce the exact same result as batch softmax.
FlashAttention's core contribution was recognizing that this online softmax trick, combined with careful tiling, makes it possible to fuse the entire attention computation into a single kernel that never materializes the N x N matrix.
FlashAttention Timeline
| Year | Work | Speedup |
|---|---|---|
| 2018 | Online softmax (Milakov & Gimelshein) | Foundation |
| 2022 | FlashAttention v1 (Dao et al.) | 2-4x over PyTorch |
| 2023 | FlashAttention v2 (Dao) | 2x over FA1 |
| 2024 | FlashAttention v3 (Shah et al., H100) | 1.5-2x over FA2 |
Core Concept: Online Softmax
Before understanding FlashAttention, you need to understand how to compute softmax without seeing all values at once.
Naive Softmax (Requires Full Row)
Standard softmax for a row $x \in \mathbb{R}^N$:
$$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{k=1}^{N} e^{x_k}}$$
This requires two passes over all N values: first to find the denominator, then to compute each output. If N = 4096 and x lives in HBM, that's two full HBM reads of a 4096-element row.
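In practice, numerical stability requires subtracting the max first:
$$m = \max_k x_k, \qquad \operatorname{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{k=1}^{N} e^{x_k - m}}$$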
This makes three passes: one to find max, one to compute the denominator, one to compute outputs.
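A minimal pure-Python sketch of the three passes, next to the unstable version without the max subtraction; the function names are illustrative. Subtracting the max keeps every exponent at or below zero, which is what prevents overflow:

```python
import math

def softmax_three_pass(x):
    """Numerically stable softmax using three passes over x."""
    m = max(x)                                    # pass 1: row max
    denom = sum(math.exp(v - m) for v in x)       # pass 2: normalizer
    return [math.exp(v - m) / denom for v in x]   # pass 3: outputs

def softmax_naive(x):
    """Unstable version: exp(x) overflows for large inputs."""
    denom = sum(math.exp(v) for v in x)
    return [math.exp(v) / denom for v in x]

x = [1000.0, 1001.0, 1002.0]
print(softmax_three_pass(x))
try:
    softmax_naive(x)
except OverflowError as err:
    print("naive softmax overflowed:", err)
```

Here `math.exp(1000.0)` exceeds the float range, so the naive version raises `OverflowError`, while the three-pass version returns the same probabilities as softmax([0, 1, 2]).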
Online Softmax (Processes One Tile at a Time)
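The online softmax algorithm processes x in blocks without storing all values. For each new block, it updates a running maximum $m$ and running sum $\ell$.
Initialize:
$$m_0 = -\infty, \qquad \ell_0 = 0$$
After processing block $j$ with values $x^{(j)}$:
$$m_j = \max\big(m_{j-1},\ \max_k x^{(j)}_k\big), \qquad \ell_j = e^{m_{j-1} - m_j}\,\ell_{j-1} + \sum_k e^{x^{(j)}_k - m_j}$$
The rescaling factor $e^{m_{j-1} - m_j}$ corrects the previous partial sum for the new maximum. When $m_j > m_{j-1}$, the previous exponentials were computed relative to a smaller maximum, so they are too large by a factor of $e^{m_j - m_{j-1}}$ and we scale them down.
Final output: divide each $e^{x_i - m_{\text{final}}}$ by $\ell_{\text{final}}$.
This produces exactly the same result as the three-pass version, but computes the max and the normalizer in a single pass over tiles that fit in SRAM.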
```python
import numpy as np

def online_softmax_demo(x_blocks):
    """
    Demonstrate online softmax - processes blocks without storing the full input.
    x_blocks: list of 1D numpy arrays (the tiles)
    Returns: exact softmax over the concatenated x_blocks
    """
    m = float('-inf')  # running max
    l = 0.0            # running sum of exp(x - m)
    # Single pass: update the running max and rescaled running sum per block.
    # For the demo we also keep each block's exponentials together with the
    # max they were computed against, so they can be normalized at the end.
    block_exps = []
    for block in x_blocks:
        m_new = max(m, float(np.max(block)))
        # Rescale previous sum to account for the new max
        l = np.exp(m - m_new) * l + np.sum(np.exp(block - m_new))
        m = m_new
        block_exps.append((np.exp(block - m), m))

    # Normalize: re-express every stored exp relative to the FINAL max,
    # then divide by the final sum. (In actual FlashAttention, the output
    # accumulator is rescaled inline instead of storing exps.)
    result = np.concatenate([e * np.exp(m_b - m) for e, m_b in block_exps]) / l

    # Verification against the standard (three-pass) softmax
    x_full = np.concatenate(x_blocks)
    reference = np.exp(x_full - np.max(x_full))
    reference = reference / reference.sum()
    assert np.allclose(result, reference, atol=1e-6), "Online softmax result mismatch!"
    print(f"Max absolute error: {np.max(np.abs(result - reference)):.2e}")
    return result

# Test with 4 tiles of 16 elements each (simulating a 64-element sequence)
np.random.seed(42)
x_blocks = [np.random.randn(16).astype(np.float32) for _ in range(4)]
result = online_softmax_demo(x_blocks)
print(f"Output sum: {result.sum():.6f}")  # should be 1.000000
```
FlashAttention Tiling Algorithm
Now we put the online softmax together with tiling to compute attention without materializing N x N.
The Tiling Strategy
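Split Q into row blocks of size $B_r$ and split K, V into column blocks of size $B_c$. For each query block $Q_i$, iterate over all K, V blocks to accumulate the output.
The sizes are chosen so that the working tiles ($Q_i$ and the output accumulator $O_i$, each $B_r \times d$, plus $K_j$ and $V_j$, each $B_c \times d$) fit in SRAM:
$$(2B_r + 2B_c)\,d \le M$$
where M is the SRAM capacity in elements (physically, typically 96 KB - 200 KB per SM on modern GPUs).
Solving with $B_r = B_c$: $B_r = B_c = M/(4d)$ (a common heuristic when query and key tiles are the same size).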
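Plugging in numbers makes the heuristic concrete. A minimal sketch, assuming fp16 (2 bytes per element) and the 96 KB tiling budget used later in this lesson; `flash_block_size` is a hypothetical helper, and real kernels instead pick fixed power-of-two tile shapes tuned per GPU and keep $O_i$ in registers:

```python
def flash_block_size(sram_bytes, d, bytes_per_elem=2):
    """Square-tile heuristic B_r = B_c = M / (4d), with M = SRAM capacity in elements.

    Four B x d tiles (Q_i, O_i, K_j, V_j) must fit: 4 * B * d <= M.
    """
    m_elems = sram_bytes // bytes_per_elem
    return m_elems // (4 * d)

for d in (64, 128, 256):
    b = flash_block_size(96 * 1024, d)
    print(f"d={d:3d}: B_r = B_c = {b:3d}, tiles use {4 * b * d * 2 / 1024:.0f} KB")
```

For d=64 this gives 192-row tiles, halving each time d doubles.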
The Algorithm (Pseudocode)
```
Input:  Q, K, V in HBM, shape [N, d]
Output: O in HBM, shape [N, d]

Initialize O = zeros[N, d] in HBM
Initialize l = zeros[N] in HBM    (running denominator)
Initialize m = -inf[N] in HBM     (running max)

For each block Q_i (rows i*Br to (i+1)*Br):
    Load Q_i from HBM to SRAM
    Initialize O_i = zeros[Br, d] in registers
    Initialize l_i = zeros[Br] in registers
    Initialize m_i = -inf[Br] in registers

    For each block K_j, V_j (cols j*Bc to (j+1)*Bc):
        Load K_j, V_j from HBM to SRAM

        # Compute tile of attention scores
        S_ij = Q_i @ K_j^T                  # shape [Br, Bc], stays in SRAM/registers

        # Online softmax update
        m_ij = rowmax(S_ij)                 # shape [Br]
        P_ij = exp(S_ij - m_ij[:, None])    # shape [Br, Bc]
        l_ij = rowsum(P_ij)                 # shape [Br]

        # Update running statistics
        m_i_new = max(m_i, m_ij)
        l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij

        # Rescale accumulated output and add new contribution
        O_i = diag(exp(m_i - m_i_new)) @ O_i
              + diag(exp(m_ij - m_i_new)) @ (P_ij @ V_j)
        m_i = m_i_new
        l_i = l_i_new

    # Final normalization
    O_i = diag(1/l_i) @ O_i
    Write O_i to HBM
```
The entire N x N attention matrix is never written to HBM. The only HBM writes are the final output O, which is $N \times d$ - linear in N.
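To check that the tiled loop really reproduces exact attention, here is a minimal NumPy sketch of the forward pass above, compared against a naive implementation that materializes S. It assumes small float64 inputs on the CPU, includes the usual $1/\sqrt{d}$ scaling, and makes no attempt to model SRAM or registers; `naive_attention` and `flash_attention_forward` are illustrative names, not a library API:

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def flash_attention_forward(Q, K, V, Br=16, Bc=16):
    """Tiled forward pass with online softmax; never builds an N x N array."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    for i in range(0, N, Br):
        Qi = Q[i:i + Br]
        Oi = np.zeros((Qi.shape[0], d))        # unnormalized output accumulator
        li = np.zeros(Qi.shape[0])             # running denominator
        mi = np.full(Qi.shape[0], -np.inf)     # running max
        for j in range(0, N, Bc):
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]
            Sij = (Qi @ Kj.T) * scale          # [Br, Bc] tile only
            mij = Sij.max(axis=1)
            Pij = np.exp(Sij - mij[:, None])
            lij = Pij.sum(axis=1)
            mi_new = np.maximum(mi, mij)
            alpha = np.exp(mi - mi_new)        # rescales the old accumulators
            beta = np.exp(mij - mi_new)        # rescales this tile's contribution
            li = alpha * li + beta * lij
            Oi = alpha[:, None] * Oi + beta[:, None] * (Pij @ Vj)
            mi = mi_new
        O[i:i + Br] = Oi / li[:, None]         # final normalization
    return O

rng = np.random.default_rng(0)
N, d = 100, 32  # N is not a multiple of the block sizes, to exercise ragged tiles
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
out = flash_attention_forward(Q, K, V, Br=16, Bc=24)
print("max abs diff vs naive:", np.abs(out - naive_attention(Q, K, V)).max())
```

The difference is at floating-point round-off level: the tiling changes the order of operations, not the result.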
Memory Complexity Analysis
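Concretely for N=4096, d=64, float16: standard attention stores S at $4096^2 \times 2$ bytes = 32 MiB per head, while FlashAttention only needs $O(Nd)$ storage, such as the $N \times d$ output at $4096 \times 64 \times 2$ bytes = 512 KiB per head (plus the $O(N)$ row statistics).
That is a 64x memory reduction ($N/d$). For N=32768, the reduction is 512x.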
IO Complexity Analysis
How many bytes travel between HBM and chip? Let M = SRAM size, N = sequence length, d = head dimension.
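Standard attention HBM reads/writes:
- Write S: $N^2$ elements
- Read S, write P: $2N^2$ elements
- Read P, read V, write O: $N^2 + 2Nd$ elements
- Total (including reading Q and K): $4N^2 + 4Nd$ elements $= \Theta(Nd + N^2)$
FlashAttention HBM reads/writes:
- Read Q: $Nd$ elements
- For each of the $N/B_r$ query blocks, read all of K and V: $2Nd \cdot N/B_r$ elements... but $B_r = \Theta(M/d)$, so this is $\Theta(N^2 d^2 / M)$
- Write O: $Nd$ elements
- Total: $\Theta(N^2 d^2 / M)$ elements
For d=64, M=96KB (49,152 float16 elements), and $B_r = M/(4d) = 192$:
$$\frac{\text{standard IO}}{\text{FlashAttention IO}} \approx \frac{4N^2}{2N^2 d / B_r} = \frac{2B_r}{d} = \frac{2 \times 192}{64} = 6$$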
FlashAttention reads/writes about 6x fewer bytes from HBM. In practice, measured speedups are 2-4x because there is overhead in the tiling logic and the roofline is not perfectly tight.
```python
def compute_flash_attention_io(N, d, M_sram_bytes, dtype_bytes=2):
    """
    Estimate HBM IO for standard attention vs FlashAttention.
    Args:
        N: sequence length
        d: head dimension
        M_sram_bytes: SRAM capacity in bytes
        dtype_bytes: 2 for fp16, 4 for fp32
    """
    M = M_sram_bytes // dtype_bytes  # SRAM capacity in elements

    # Standard attention HBM IO (element counts):
    # write S, read S, write P, read P = 4*N^2; read Q, K, V and write O = 4*N*d
    standard_io = 4 * N * N + 4 * N * d

    # FlashAttention HBM IO: square tiles sized so four B x d tiles fit in SRAM
    Br = M // (4 * d)
    n_blocks_q = (N + Br - 1) // Br
    flash_io = (N * d +               # each Q block read once
                n_blocks_q * N * d +  # all of K re-read for every query block
                n_blocks_q * N * d +  # all of V re-read for every query block
                N * d)                # O written once

    standard_bytes = standard_io * dtype_bytes
    flash_bytes = flash_io * dtype_bytes
    print(f"Sequence length N={N}, head dim d={d}")
    print(f"SRAM capacity: {M_sram_bytes/1024:.0f} KB = {M} elements")
    print(f"Block size Br=Bc={Br}")
    print(f"Standard attention HBM IO: {standard_bytes/1e6:.1f} MB")
    print(f"FlashAttention HBM IO: {flash_bytes/1e6:.1f} MB")
    print(f"IO reduction: {standard_bytes/flash_bytes:.1f}x")
    print()

# An A100 SM has 192 KB of combined L1/shared memory (up to 164 KB usable as
# shared memory); 96 KB is used here as the effective tiling budget
compute_flash_attention_io(N=4096, d=64, M_sram_bytes=96*1024)
compute_flash_attention_io(N=8192, d=128, M_sram_bytes=96*1024)
compute_flash_attention_io(N=32768, d=64, M_sram_bytes=96*1024)
```
FlashAttention 2: Warp-Level Improvements
FlashAttention v1 was already a major step. FlashAttention 2 (Dao, 2023) extracted another 2x by fixing inefficiencies in how work was distributed across GPU warps.
Problem 1: Parallelism was Sequence-Length Limited
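FA1 parallelized across heads and batch size (one thread block per head), but within each head, work was sequential across Q blocks. When batch size times number of heads is small, which is typical for long sequences where memory forces small batches, there are not enough thread blocks to keep all SMs busy (an A100 has 108). FA2 adds parallelism across Q blocks, so even small-batch, long-sequence workloads can saturate the GPU.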
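A back-of-the-envelope count of thread blocks shows why this matters. The sketch assumes one thread block per (batch, head) for FA1, one per (batch, head, query block) for FA2 with a hypothetical query block size of 128, and the A100's 108 SMs; it ignores occupancy and how many blocks fit on one SM:

```python
import math

A100_SMS = 108  # streaming multiprocessors on an A100

def fa1_grid(batch, heads):
    """FA1-style launch: one thread block per (batch, head) pair."""
    return batch * heads

def fa2_grid(batch, heads, seq_len, block_q=128):
    """FA2-style launch: one thread block per (batch, head, query block)."""
    return batch * heads * math.ceil(seq_len / block_q)

# Long sequences usually force small batches, which is where FA1 under-fills the GPU.
for batch, heads, n in [(1, 16, 16384), (2, 32, 8192), (32, 32, 1024)]:
    g1, g2 = fa1_grid(batch, heads), fa2_grid(batch, heads, n)
    print(f"batch={batch:2d} heads={heads} N={n:5d}: "
          f"FA1 {g1:5d} blocks ({g1 / A100_SMS:5.2f}/SM), "
          f"FA2 {g2:5d} blocks ({g2 / A100_SMS:5.2f}/SM)")
```

With batch 1 and 16 heads at N=16384, the FA1-style grid has only 16 thread blocks for 108 SMs, while the FA2-style grid has 2048.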
Problem 2: Non-Matmul FLOPS Killed Tensor Core Utilization
The A100's Tensor Cores are 16x faster for matmul than for other ops. FA1 had too many non-matmul operations: the online softmax updates, rescaling factors, and normalization. These ran on regular CUDA cores at 19.5 TFLOPS instead of Tensor Cores at 312 TFLOPS.
FA2 reduced non-matmul FLOPs by reorganizing the computation. The key change: rather than computing the softmax normalization tile by tile (which requires many exp and division ops per tile), FA2 batches the rescaling and defers normalization to the end of each Q block.
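The difference between normalizing per tile and normalizing once can be illustrated for a single query row. This NumPy sketch is a simplified model, not FA1 or FA2 code: both modes produce the same output, but the per-tile mode spends extra element-wise divisions on every tile, which is the kind of non-matmul work FA2 trims. `tiled_row` is an illustrative name:

```python
import numpy as np

def tiled_row(scores, values, tile, deferred):
    """One query row of attention, streamed over key tiles.

    deferred=False: keep the accumulator normalized after every tile (FA1-style).
    deferred=True:  keep it unnormalized and divide once at the end (FA2-style).
    Returns (output, number of element-wise divisions performed).
    """
    m, l = -np.inf, 0.0
    acc = np.zeros(values.shape[1])
    divisions = 0
    for j in range(0, len(scores), tile):
        s, v = scores[j:j + tile], values[j:j + tile]
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)          # rescales everything seen so far
        p = np.exp(s - m_new)              # this tile, relative to the new max
        l_new = alpha * l + p.sum()
        if deferred:
            acc = alpha * acc + p @ v      # unnormalized running sum
        else:
            # undo the old normalization, rescale, add the tile, renormalize
            acc = (alpha * l * acc + p @ v) / l_new
            divisions += acc.size
        m, l = m_new, l_new
    if deferred:
        acc = acc / l                      # single normalization at the end
        divisions += acc.size
    return acc, divisions

rng = np.random.default_rng(1)
scores = rng.standard_normal(256)
values = rng.standard_normal((256, 64))
o1, n1 = tiled_row(scores, values, tile=32, deferred=False)
o2, n2 = tiled_row(scores, values, tile=32, deferred=True)
print("outputs match:", np.allclose(o1, o2))
print("divisions FA1-style:", n1, " FA2-style:", n2)
```

With 256 keys in tiles of 32 and 64 value dimensions, the per-tile mode performs 8 × 64 = 512 divisions versus 64 for the deferred one.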
Problem 3: Warp Communication Overhead
FA1 split K/V tiles across warps, requiring warps to communicate their partial results via shared memory. This created synchronization barriers. FA2 splits Q tiles across warps instead - each warp owns its query rows completely and independently accumulates the output. No inter-warp synchronization needed for the inner loop.
```
FA1 warp strategy (shared K/V):
  Warp 0: owns K_j[:Bc/2], V_j[:Bc/2] → partial O
  Warp 1: owns K_j[Bc/2:], V_j[Bc/2:] → partial O
  → must sum partial O across warps (shared memory communication)

FA2 warp strategy (independent Q rows):
  Warp 0: owns Q_i[:Br/4] rows → full O rows, no coordination needed
  Warp 1: owns Q_i[Br/4:Br/2] rows → full O rows, no coordination
  → warps are completely independent, no synchronization barriers
```
FA2 Speedup in Practice
The FlashAttention-2 paper reports about a 2x speedup over FA1 on A100, reaching 50-73% of the theoretical maximum FLOPs/s, and up to 225 TFLOPS per A100 (72% model FLOPs utilization) when training GPT-style models end to end. All such figures are bounded by the A100's 312 TFLOPS dense FP16 peak.
FlashAttention 3: H100-Specific Optimizations
FlashAttention 3 (Shah et al., 2024) targets the H100 specifically, exploiting hardware features that did not exist on A100.
Asynchronous Memory Copies (TMA)
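H100 has the Tensor Memory Accelerator (TMA), a dedicated hardware unit that moves tiles between global memory (HBM) and shared memory, with the hardware rather than every thread handling address calculation. In FA2, the same threads that compute also issue the loads. In FA3, TMA loads K and V asynchronously in the background while the Tensor Cores run the matmul for the previous tile.
The loads are double-buffered (more generally, FA3 cycles through a circular buffer of shared-memory stages): while the GEMM processes tile j, TMA prefetches tile j+1 into another shared-memory buffer, so by the time the GEMM finishes tile j, tile j+1 is already waiting in SRAM. FA3 additionally uses "ping-pong" scheduling between two warpgroups, so that one warpgroup's softmax overlaps with the other's GEMM.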
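The benefit of prefetching can be seen with a toy timing model. This sketch assumes fixed, arbitrary per-tile load and compute times, perfect overlap, and room for two buffers; it illustrates double buffering in general, not TMA's actual behavior:

```python
def pipeline_time(load_times, compute_times, overlapped):
    """Toy timing model for streaming K/V tiles through shared memory.

    overlapped=False: load tile j, then compute on it, strictly in sequence.
    overlapped=True:  double buffering - tile j+1 loads while tile j computes.
    """
    if not overlapped:
        return sum(load_times) + sum(compute_times)
    t = load_times[0]  # the first load has nothing to hide behind
    for j, compute in enumerate(compute_times):
        next_load = load_times[j + 1] if j + 1 < len(load_times) else 0.0
        t += max(compute, next_load)  # whichever overlapped step is slower
    return t

loads = [1.0] * 8     # arbitrary time units per K/V tile load
computes = [1.5] * 8  # arbitrary time units per tile of GEMM + softmax
print("serial:   ", pipeline_time(loads, computes, overlapped=False))
print("pipelined:", pipeline_time(loads, computes, overlapped=True))
```

With these numbers the serial schedule takes 20.0 units and the pipelined one 13.0: once loads are shorter than compute, only the first load remains exposed.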
Warp Specialization
FA3 dedicates different warps to different tasks. "Producer warps" exclusively run TMA copy operations. "Consumer warps" exclusively run GEMM and softmax. This specialization allows both to run at full speed concurrently, compared to FA2 where each warp alternated between copying and computing.
FP8 Support
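H100 supports FP8 Tensor Core operations (E4M3 and E5M2 formats) at 2x the throughput of FP16. FA3 adds FP8 attention. To limit the accuracy loss of FP8, it combines block quantization (one scaling factor per block) with "incoherent processing": multiplying Q and K by a random orthogonal matrix (a Hadamard transform with random signs) before quantizing, which spreads outlier values across dimensions without changing $QK^\top$.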
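The rotation idea behind incoherent processing can be checked in a few lines. This NumPy sketch only illustrates why multiplying Q and K by the same orthogonal matrix is safe (the scores $QK^\top$ are unchanged) and how it spreads an outlier channel across all dimensions. It does not simulate FP8 rounding or block scaling, and the Sylvester construction with random column signs is an assumed way to build the matrix, not FA3's code:

```python
import numpy as np

def random_hadamard(d, rng):
    """Orthogonal d x d matrix: normalized Sylvester Hadamard with random column signs.

    Assumes d is a power of two.
    """
    H = np.array([[1.0]])
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])  # Sylvester construction doubles the size
    signs = rng.choice([-1.0, 1.0], size=d)
    return H * signs / np.sqrt(d)        # flip column signs, scale to orthonormal

rng = np.random.default_rng(0)
d = 64
Q = rng.standard_normal((8, d))
K = rng.standard_normal((8, d))
Q[:, 3] *= 50.0                          # simulate one outlier feature channel

R = random_hadamard(d, rng)
Qr, Kr = Q @ R, K @ R                    # rotate both Q and K

print("R is orthogonal:", np.allclose(R @ R.T, np.eye(d)))
print("scores unchanged:", np.allclose(Q @ K.T, Qr @ Kr.T))
print(f"max |Q| = {np.abs(Q).max():.1f}, max |QR| = {np.abs(Qr).max():.1f}")
```

A smaller maximum magnitude means a per-block scaling factor can use more of FP8's limited range for the typical values.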
FA3 Performance on H100
The FlashAttention-3 paper reports 1.5-2.0x speedups over FA2 on H100 in FP16, reaching up to 740 TFLOPS (about 75% of the H100's theoretical peak), and close to 1.2 PFLOPS with FP8.
Using FlashAttention in PyTorch
The Easy Path: scaled_dot_product_attention
Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention automatically dispatches to FlashAttention when conditions are met:
```python
import torch
import torch.nn.functional as F

# Restrict SDPA to the FlashAttention backend; raises RuntimeError if inputs are not eligible
def efficient_attention(q, k, v, causal=False):
    """
    q, k, v: [batch, heads, seq_len, head_dim]
    All must be fp16 or bf16 for FlashAttention to activate.
    """
    # PyTorch dispatches to FlashAttention when (among other checks):
    # 1. the tensors are on a supported CUDA GPU
    # 2. dtype is fp16 or bf16 (NOT float32)
    # 3. head_dim is in the supported range (up to 256 for FA2-based builds)
    # 4. no explicit attn_mask is passed (is_causal=True is fine)
    # Note: torch.backends.cuda.sdp_kernel is deprecated in recent releases
    # in favor of torch.nn.attention.sdpa_kernel.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,          # disable fallback math kernel
        enable_mem_efficient=False  # disable memory-efficient kernel
    ):
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal,
            scale=None  # defaults to 1/sqrt(head_dim); `scale` needs torch >= 2.1
        )
    return output

# Verify FlashAttention is being used
def check_flash_attention_active(seq_len=2048, head_dim=64, num_heads=8, batch=2):
    device = 'cuda'
    dtype = torch.float16  # required for FlashAttention
    q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)

    # These report whether each backend is *enabled* (global flags),
    # not whether it is eligible for these particular inputs
    print("Flash attention enabled:", torch.backends.cuda.flash_sdp_enabled())
    print("Math kernel enabled:", torch.backends.cuda.math_sdp_enabled())
    print("Mem-efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())

    out = None  # stays None if the Flash backend is rejected
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                        enable_mem_efficient=False):
        try:
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            print(f"FlashAttention succeeded. Output shape: {out.shape}")
        except RuntimeError as e:
            print(f"FlashAttention NOT used: {e}")
    return out

if torch.cuda.is_available():
    check_flash_attention_active()
```
When FlashAttention Does NOT Activate
```python
import torch
import torch.nn.functional as F

# Case 1: float32 - the Flash backend requires fp16/bf16
q_fp32 = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float32)
# SDPA silently picks another backend: the memory-efficient kernel if eligible,
# otherwise the math kernel, which materializes the N x N matrix.

# Case 2: explicit attention mask - the Flash backend only handles is_causal
custom_mask = torch.zeros(1024, 1024, device='cuda', dtype=torch.bool)
custom_mask[:512, 512:] = True  # some custom pattern
# Passing attn_mask=custom_mask rules out the Flash backend.

# Case 3: head dim outside the range supported by the installed version
q_bad_dim = torch.randn(1, 8, 1024, 96, device='cuda', dtype=torch.float16)
# head_dim=96 was not supported by the earliest FlashAttention releases;
# FA2-based builds accept head dims up to 256.

# How to detect which kernel was used: torch.profiler or Nsight.
# In torch.profiler, look for the op "aten::_scaled_dot_product_flash_attention"
# or CUDA kernels whose names contain "flash_fwd", versus
# "aten::_scaled_dot_product_efficient_attention" (memory-efficient kernel)
# or "aten::_softmax" plus "aten::bmm" (math kernel).

def audit_attention_kernel(q, k, v):
    """Profile which kernel actually runs."""
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,    # aten op names
                    torch.profiler.ProfilerActivity.CUDA],  # GPU kernel names
        record_shapes=True
    ) as prof:
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Look for "flash" in op and kernel names
    events = prof.key_averages()
    flash_found = any('flash' in e.key.lower() for e in events)
    print(f"FlashAttention kernel detected: {flash_found}")
    for e in events:
        if 'attention' in e.key.lower() or 'flash' in e.key.lower():
            print(f"  {e.key}: {e.cuda_time_total / 1e3:.2f} ms")
    return flash_found
```
The flash_attn Package
The flash_attn package provides direct access to the FlashAttention kernels with more control:
```python
# Install: pip install flash-attn --no-build-isolation
import torch
import torch.nn as nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.modules.mha import FlashSelfAttention

# Basic usage - qkv packed (common for self-attention)
def flash_self_attention_packed(qkv, causal=True, softmax_scale=None):
    """
    qkv: [batch, seq_len, 3, num_heads, head_dim], fp16 or bf16, on CUDA
    Returns: [batch, seq_len, num_heads, head_dim]
    """
    return flash_attn_qkvpacked_func(
        qkv,
        dropout_p=0.0,
        softmax_scale=softmax_scale,  # None -> 1/sqrt(head_dim)
        causal=causal,
    )

# Separate Q, K, V (cross-attention or when shapes differ)
def flash_cross_attention_explicit(q, k, v, causal=False, softmax_scale=None):
    """
    q: [batch, seq_len_q, num_heads, head_dim]
    k: [batch, seq_len_k, num_heads, head_dim]
    v: [batch, seq_len_k, num_heads, head_dim]
    Returns: [batch, seq_len_q, num_heads, head_dim]
    """
    return flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=(-1, -1),  # full attention
    )

# Sliding window attention (flash_attn >= 2.3) - useful for long context
def flash_sliding_window_attention(q, k, v, window_size=512):
    """
    Query i attends to keys i - window_size .. i (window_size + 1 keys,
    including itself). Compute scales as O(N * window_size), not O(N^2).
    """
    return flash_attn_func(
        q, k, v,
        causal=True,
        window_size=(window_size, 0),  # (left, right) window
    )

# Drop-in module replacement
class FlashMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, causal=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.causal = causal
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.flash_attn = FlashSelfAttention(causal=causal)

    def forward(self, x):
        B, T, C = x.shape  # x: [B, T, C]
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)  # [B, T, 3, H, D]
        # flash_attn expects fp16 or bf16; keep bf16 inputs as they are
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            qkv = qkv.to(torch.float16)
        out = self.flash_attn(qkv)  # [B, T, H, D]
        out = out.reshape(B, T, C).to(x.dtype)
        return self.out_proj(out)
```
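The `window_size=(window_size, 0)` convention used by the sliding-window helper is easiest to check against a dense reference mask on a tiny input. This NumPy sketch assumes equal query and key lengths and follows the flash_attn documentation's rule that query i sees keys from i - left to i + right, inclusive; `sliding_window_mask` is a hypothetical testing helper, and building this dense mask is exactly what the real kernel avoids:

```python
import numpy as np

def sliding_window_mask(seq_len, window_left, window_right=0):
    """Boolean [seq_len, seq_len] mask: True where query i may attend to key j.

    For equal query/key lengths, key j is visible iff i - left <= j <= i + right.
    """
    i = np.arange(seq_len)[:, None]  # query positions as a column
    j = np.arange(seq_len)[None, :]  # key positions as a row
    return (j >= i - window_left) & (j <= i + window_right)

mask = sliding_window_mask(6, window_left=2)
print(mask.astype(int))
print("keys visible per query:", mask.sum(axis=1))
```

For 6 tokens and a left window of 2, the per-query key counts are 1, 2, 3, 3, 3, 3: each query sees at most window_left + 1 keys, itself included.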
Production Engineering Notes
Memory Savings vs Standard Attention at Scale
```python
def attention_memory_comparison(N, d, num_heads, num_layers, batch_size, dtype_bytes=2):
    """
    Compare attention activation memory between standard attention and
    FlashAttention, assuming every layer's activations are held at once
    (as in training, where they are saved for the backward pass).
    During inference only one layer is live at a time: divide by num_layers.
    All values in GB.
    """
    # Standard attention: must hold the S and P matrices
    # Per head per layer: 2 * N^2 elements (S and P)
    standard_attn_bytes = (2 * N * N * dtype_bytes *
                           num_heads * num_layers * batch_size)
    # FlashAttention: only Q, K, V, O (the O(N) softmax statistics are ignored)
    # Per head per layer: 4 * N * d elements
    flash_attn_bytes = (4 * N * d * dtype_bytes *
                        num_heads * num_layers * batch_size)
    print(f"N={N}, d={d}, {num_heads} heads, {num_layers} layers, batch={batch_size}")
    print(f"Standard attention matrices: {standard_attn_bytes/1e9:.1f} GB")
    print(f"FlashAttention activations: {flash_attn_bytes/1e9:.1f} GB")
    print(f"Memory reduction: {standard_attn_bytes/flash_attn_bytes:.0f}x")
    print()

# LLaMA-2 7B shapes: 32 heads, head_dim 128, 32 layers
attention_memory_comparison(N=4096, d=128, num_heads=32, num_layers=32, batch_size=1)
attention_memory_comparison(N=32768, d=128, num_heads=32, num_layers=32, batch_size=1)
attention_memory_comparison(N=128000, d=128, num_heads=32, num_layers=32, batch_size=1)
```
Verifying FlashAttention is Active in Training
```python
import torch
from contextlib import contextmanager

@contextmanager
def require_flash_attention():
    """Restrict SDPA to the Flash backend so ineligible inputs raise instead of falling back."""
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,          # disable math fallback
        enable_mem_efficient=False  # disable memory-efficient fallback
    ):
        yield

def verify_model_uses_flash(model, sample_input):
    """
    Run a forward pass with only the Flash backend enabled.
    Raises RuntimeError if any SDPA call in the model is not Flash-eligible.
    Note: attention written by hand (matmul + softmax) never goes through
    SDPA, so it passes this check while still using O(N^2) memory.
    """
    with require_flash_attention():
        try:
            output = model(sample_input)
            print("FlashAttention verified - forward pass succeeded under flash-only mode")
            return output
        except RuntimeError as e:
            print(f"FALLBACK DETECTED: {e}")
            print("Check: dtype must be fp16/bf16, head_dim supported by your version "
                  "(<= 256 for FA2), no explicit attn_mask")
            raise
```
Common Production Failure: Wrong dtype
The single most common reason FlashAttention silently falls back to a slower kernel is float32 inputs. Models often use float32 by default, and the fallback is silent - you see no error, just slower performance.
```python
# WRONG: float32 silently falls back to a non-Flash kernel
model = TransformerModel(...)  # placeholder for your model class
output = model(input_ids)      # if model weights are fp32, attention runs in fp32 too

# RIGHT: explicitly cast to bf16 or fp16 (and move to the GPU)
model = model.to(device='cuda', dtype=torch.bfloat16)  # bf16 is preferred for training stability
output = model(input_ids.cuda())

# Or use automatic mixed precision:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    output = model(input_ids)
```
Common Mistakes
:::danger FlashAttention Does Not Activate for float32
F.scaled_dot_product_attention never selects the Flash backend for float32 inputs. It silently uses another backend instead: the memory-efficient kernel if it is enabled and eligible, otherwise the math kernel, which materializes the full N x N matrix. This is one of the most common production mistakes - the code runs, the output is correct, but throughput drops and, on the math-kernel path, memory grows as O(N^2). Always use fp16 or bf16 for attention. Use torch.autocast or explicitly cast model parameters.
:::
:::danger Using Unsupported Head Dimensions
Supported head dimensions depend on the FlashAttention version and on how PyTorch was built: the earliest releases covered only a few sizes (such as 64 and 128), while FA2 supports head dimensions up to 256. An unsupported value makes SDPA silently fall back to another kernel. This is especially common when people experiment with custom model architectures. Always print the head dimension and check it against your installed version's constraints before benchmarking.
:::
:::warning Custom Attention Masks Break FlashAttention
In PyTorch's SDPA, the Flash backend supports causal masking through is_causal=True only. Passing any explicit attn_mask (e.g., block-diagonal masks for document packing, sliding windows built by hand, ALiBi-style biases added before softmax) rules it out, and SDPA falls back to another kernel. The flash_attn package covers several of these cases natively: window_size for sliding windows, alibi_slopes for ALiBi, and flash_attn_varlen_func for packed variable-length sequences. Use the built-in causal flag rather than constructing a manual causal mask.
:::
:::warning Backward Pass Does More Work Than Forward
FlashAttention saves only the output O and the per-row softmax statistics from the forward pass, so the backward pass recomputes the S and P tiles for each block on the fly (no materialization). Its extra memory therefore stays O(N), but it performs roughly 2-3x the FLOPs of the forward pass (five matmuls per tile instead of two). Peak activation memory during training is still much lower than with standard attention, but the speedup is smaller in backward than in forward.
:::
:::tip Checking FlashAttention Version at Runtime
```python
import flash_attn
print(flash_attn.__version__)  # should be 2.x for production use

# Also check:
import torch
print(torch.__version__)  # torch >= 2.0 required for F.scaled_dot_product_attention dispatch
```
:::
Interview Questions
Q1: Why does standard attention have O(N²) memory and how does FlashAttention reduce this to O(N)?
Standard attention computes $S = QK^\top$, where S has shape [N, N]. With N=4096 in fp16, S alone is $4096^2 \times 2$ bytes = 32 MB per head. This matrix must be written to HBM (because it's too large for SRAM), read back for the softmax, and the resulting P is read again for the $PV$ product. Total HBM traffic is approximately $4N^2$ elements.
FlashAttention avoids materializing S entirely. It tiles Q into blocks of size $B_r$ and K, V into blocks of size $B_c$, chosen so that a single set of tiles fits in SRAM. The key enabler is online softmax: you can accumulate the softmax output incrementally as tiles arrive, maintaining only a running max and running sum, without ever seeing the full row at once.
The only thing written to HBM is the final output O of shape [N, d], which is O(N). The N x N matrix S exists only transiently in SRAM/registers, never in HBM.
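The IO complexity is $\Theta(N^2 d^2 / M)$, where M is SRAM size, versus $\Theta(Nd + N^2)$ for standard attention - a reduction on the order of $M/d^2$ (about 6x for d=64 and 96 KB of SRAM with $B_r = B_c = M/(4d)$).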
Q2: Derive the online softmax update rule. Why is the rescaling factor exp(m_prev - m_new) necessary?
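Say you have computed partial softmax statistics over the first j tiles: running max $m^{(j)}$ and running denominator $\ell^{(j)}$.
Now you see tile j+1 with max $\tilde m^{(j+1)}$ and its own sum $\tilde\ell^{(j+1)} = \sum_{k \in \text{tile } j+1} e^{x_k - \tilde m^{(j+1)}}$.
The new global max is $m^{(j+1)} = \max\big(m^{(j)}, \tilde m^{(j+1)}\big)$.
The previous partial sum was computed relative to the old max $m^{(j)}$:
$$\ell^{(j)} = \sum_{k \in \text{tiles } 1..j} e^{x_k - m^{(j)}}$$
But we need the partial sum relative to the new max $m^{(j+1)}$:
$$\sum_{k \in \text{tiles } 1..j} e^{x_k - m^{(j+1)}} = e^{m^{(j)} - m^{(j+1)}} \sum_{k \in \text{tiles } 1..j} e^{x_k - m^{(j)}} = e^{m^{(j)} - m^{(j+1)}}\,\ell^{(j)}$$
so the full update is
$$\ell^{(j+1)} = e^{m^{(j)} - m^{(j+1)}}\,\ell^{(j)} + e^{\tilde m^{(j+1)} - m^{(j+1)}}\,\tilde\ell^{(j+1)}$$
This is the rescaling factor: $e^{m^{(j)} - m^{(j+1)}}$. When the new tile has a higher maximum, $m^{(j+1)} > m^{(j)}$, so the factor is less than 1 - we scale down the old partial sum because its terms were exponentiated against a maximum that was too small, which over-weights them.
The accumulated (unnormalized) output must be rescaled the same way:
$$O^{(j+1)} = e^{m^{(j)} - m^{(j+1)}}\,O^{(j)} + e^{\tilde m^{(j+1)} - m^{(j+1)}}\,\tilde P^{(j+1)} V_{j+1}$$
After processing all tiles, we divide O by the final $\ell$ to get the normalized output. The result is identical to computing full softmax first and then doing the weighted sum.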
Q3: What specific improvements did FlashAttention 2 make over FlashAttention 1?
Three categories of improvement:
1. Parallelism across Q blocks - FA1 parallelized over batch and head dimensions only. Within a head, it processed Q blocks sequentially. FA2 assigns different Q blocks to different thread blocks, which run in parallel on different SMs. This keeps the GPU busy for long sequences, where memory limits force a small batch size or few heads.
2. Warp partitioning for K/V processing - FA1 split K/V tiles across warps, requiring warps to sum their partial outputs via shared memory (inter-warp communication). FA2 assigns non-overlapping rows of Q to each warp. Each warp independently computes its output rows with full access to all K/V. No inter-warp communication in the inner loop.
3. Reduced non-matmul FLOPs - The A100 Tensor Cores run at 312 TFLOPS for fp16 matmul but only 19.5 TFLOPS for non-matmul ops. FA1 had many exp, division, and comparison operations in the inner loop. FA2 restructures to minimize these operations - for example, deferring the final normalization to after the K/V loop rather than applying it each tile.
Combined, these produce about 2x throughput improvement over FA1 on A100.
Q4: How would you verify that FlashAttention is actually being used in your model?
Three complementary approaches:
Method 1: Force flash-only mode and observe failures
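```python
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                    enable_mem_efficient=False):
    out = model(x)
# If this raises RuntimeError, at least one SDPA call could not use FlashAttention.
# enable_mem_efficient=False matters: otherwise SDPA can silently use that kernel instead.
```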
Method 2: Profile with PyTorch profiler
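```python
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,
                                        torch.profiler.ProfilerActivity.CUDA]) as p:
    out = model(x)
# Look for 'flash' in op and kernel names (e.g. 'flash_fwd'). The math path shows 'aten::_softmax'.
print(p.key_averages().table(row_limit=20))
```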
Method 3: Memory usage signature - FlashAttention should use roughly $O(Nd)$ bytes for attention activations, not $O(N^2)$ bytes. At N=4096 with 32 heads, standard attention uses ~2GB, FlashAttention uses ~32MB. Call torch.cuda.reset_peak_memory_stats() before the attention forward pass and read torch.cuda.max_memory_allocated() afterward to measure.
Common reasons FlashAttention is not used:
- dtype is float32 (must be fp16 or bf16)
- head_dim is outside the range supported by the installed version (FA2 supports up to 256)
- Custom attention mask is provided (only causal mask is natively supported)
- PyTorch version is older than 2.0 (no SDPA dispatch)
- flash_attn package version mismatch with installed CUDA version
Q5: What is the IO complexity of FlashAttention and how does it compare to standard attention?
Let N = sequence length, d = head dimension, M = SRAM capacity in elements.
Standard attention:
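- Write S: $N^2$ elements
- Read S, compute softmax, write P: $2N^2$ elements
- Read P and V, write O: $N^2 + 2Nd$ elements
- Total (including reading Q and K): $4N^2 + 4Nd = \Theta(Nd + N^2)$ HBM IO
FlashAttention: The outer loop has $T_r = N/B_r$ iterations. In each, we load $Q_i$ once ($B_r d$ elements) and iterate over $T_c = N/B_c$ K/V blocks. Each inner iteration loads $2B_c d$ elements. Total, including writing O:
$$T_r\,\big(B_r d + T_c \cdot 2B_c d\big) + Nd = 2Nd + \frac{2N^2 d}{B_r}$$
With $B_r = \Theta(M/d)$: total $= \Theta(N^2 d^2 / M)$.
For A100 with $M \approx$ 96 KB of SRAM per SM (49,152 fp16 elements), d=64, and $B_r = M/(4d) = 192$:
$$\frac{4N^2}{2N^2 d / B_r} = \frac{2B_r}{d} = \frac{2 \times 192}{64} = 6$$
FlashAttention uses roughly 5-6x fewer HBM bytes. In practice (accounting for tiling overhead and non-matmul ops), measured speedups are 2-4x.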
Q6: FlashAttention 3 targets H100 specifically. What H100 features does it exploit that A100 lacks?
Three H100-specific hardware features:
1. Tensor Memory Accelerator (TMA) - H100 has a dedicated hardware unit for asynchronous data movement between global memory and shared memory. FA3 uses TMA to prefetch the next K/V tile into another shared memory buffer while the current tile's GEMM runs, overlapping memory transfer with compute instead of FA2's "stall and load" pattern. On top of this, FA3's "ping-pong" schedule alternates two warpgroups so that one runs its softmax while the other runs its GEMM.
2. Warp specialization - H100's asynchronous warpgroup-level matmul instructions (WGMMA) and its ability to reallocate registers between warpgroups make producer/consumer kernels efficient. FA3 designates "producer warps" that exclusively issue TMA loads and "consumer warps" that exclusively run GEMM + softmax. Both run concurrently, unlike FA2 where the same warps alternated between data copying and compute.
3. FP8 Tensor Cores - H100 FP8 Tensor Cores achieve 2x the throughput of FP16. FA3 supports FP8 attention by combining block quantization (per-block scaling factors) with "incoherent processing", which multiplies Q and K by a random orthogonal (Hadamard-based) matrix to spread out outlier values before quantization, without changing $QK^\top$.
Result: FA3 reaches up to about 740 TFLOPS on H100 in FP16, about 75% of the GPU's theoretical peak, and close to 1.2 PFLOPS in FP8.
Summary
FlashAttention is one of the most important algorithmic innovations in deep learning systems of the past decade. The key ideas are simple once you see them:
- Standard attention is memory-bound because the N x N attention matrix is too large for SRAM and must live in HBM. Every read/write is slow.
- Tiling avoids materialization by processing Q in blocks and streaming K, V through SRAM. The N x N matrix never exists as a contiguous allocation.
- Online softmax enables exact computation despite tiling, by maintaining a running max and sum that can be updated incrementally as each K/V tile arrives.
- IO complexity drops from $\Theta(Nd + N^2)$ to $\Theta(N^2 d^2 / M)$ - roughly a 5-6x reduction for typical hyperparameters, translating to 2-4x real-world speedup.
- FA2 and FA3 extracted further improvements through better parallelism and work partitioning (FA2), hardware-specific async pipelines (FA3 on H100), and FP8 support (FA3).
In production, FlashAttention is not optional for sequences above 2048 tokens. It is the difference between a model that fits in memory and one that does not.
