
Flash Attention Kernel Deep Dive

Reading time: ~45 min · Interview relevance: Very High · Target roles: ML Systems Engineer, CUDA Developer, Research Engineer

Standard attention materializes the full N x N attention matrix in HBM - O(N²) memory that limits context length. FlashAttention never writes that matrix to HBM at all, reducing memory complexity to O(N) while computing the exact same result. The same 80GB A100 that struggled at 2048 tokens now handles 128k tokens.


The Production Scenario

It is 11pm and the inference team is trying to push context length from 4096 to 16384 tokens for the new document-summarization product. The A100 has 80GB of HBM. Someone runs the benchmark. It OOMs immediately.

The math is simple and brutal. Standard attention with N=16384 tokens and d=64 per head produces an attention matrix S of shape [16384, 16384]. In float32 that is $16384^2 \times 4$ bytes = 1.07 GB per attention head. With 32 heads and 32 transformer layers, the attention matrices alone need about 1.1 TB. The GPU has 80GB.

Even at reasonable context lengths the situation is bad. N=4096, d=64, 32 heads: $4096^2 \times 4 \times 32 = 2.1$ GB just for the S matrices in a single layer. Moving 2.1 GB through HBM at 2 TB/s takes about 1ms per pass per layer, purely for memory traffic on a matrix the model does not actually need to store.

FlashAttention (Dao et al., 2022) solved this by asking a different question: what if we never wrote the full N x N matrix at all? What if we computed the output directly from Q, K, V using tiles, keeping all intermediate values in SRAM? The result is exact - not approximate - and uses O(N) memory instead of O(N²). The same A100 that failed at 16k tokens now runs 128k context.

This lesson explains exactly how that works: the online softmax trick, the tiling algorithm, IO complexity analysis, and the FlashAttention 2 and 3 improvements.


Why This Exists - The Problem with Standard Attention

The Standard Attention Algorithm

Recall the attention mechanism. Given queries Q, keys K, values V of shape $[N, d]$ where N is sequence length and d is head dimension:

$$S = QK^T \quad \text{shape: } [N, N]$$

$$P = \text{softmax}(S) \quad \text{shape: } [N, N]$$

$$O = PV \quad \text{shape: } [N, d]$$

This is three separate operations. Each one reads from HBM and writes back to HBM. Here is the memory traffic breakdown:

| Operation | Reads from HBM | Writes to HBM | Notes |
|---|---|---|---|
| $S = QK^T$ | Q: $N \times d$, K: $N \times d$ | S: $N \times N$ | S is the problem |
| $P = \text{softmax}(S)$ | S: $N \times N$ | P: $N \times N$ | Reads and re-writes N² |
| $O = PV$ | P: $N \times N$, V: $N \times d$ | O: $N \times d$ | Reads N² again |

The S and P matrices each cost $N^2$ floats. For N=4096 and float32:

$$\text{Memory per head} = 4096^2 \times 4 \text{ bytes} \approx 67 \text{ MB}$$

With 32 heads: 2.1 GB just for attention matrices in one layer. For a 32-layer model: about 69 GB. That is 86% of an A100's entire HBM capacity, and that's before weights, activations, or anything else.

Why Standard Attention is Memory-Bound, Not Compute-Bound

The roofline model tells us when a kernel is compute-bound versus memory-bound. The arithmetic intensity of standard attention is low because most of the work is moving the N x N matrix around.

For $S = QK^T$: the operation requires $N^2 \times d$ multiply-adds, but reads $2Nd$ floats and writes $N^2$ floats. When N is large and d is small (d=64 is typical), the write dominates and intensity drops:

$$\text{Arithmetic Intensity} = \frac{2N^2 d}{(2Nd + N^2) \times 4 \text{ bytes}} \approx \frac{2d}{4} \approx 32 \text{ FLOP/byte}$$

The A100 can do 312 TFLOPS but only sustains 2 TB/s HBM bandwidth. The ratio is 156 FLOP/byte. Standard attention's 32 FLOP/byte is well below this - the kernel is memory-bound by a factor of 5x. The GPU's compute units sit idle, waiting for HBM to deliver the next chunk of the N x N matrix.
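The roofline comparison above can be checked with a few lines of arithmetic. The sketch below is only a back-of-the-envelope model under the same assumptions as the text (float32, S written back to HBM, the quoted A100 figures); the helper names are made up for this illustration.

```python
def arithmetic_intensity_qk(N, d, dtype_bytes=4):
    """FLOPs per HBM byte for S = Q @ K^T when S is written back to HBM."""
    flops = 2 * N * N * d                            # one multiply and one add per term
    bytes_moved = (2 * N * d + N * N) * dtype_bytes  # read Q and K, write S
    return flops / bytes_moved


def ridge_point(peak_flops, hbm_bytes_per_s):
    """Intensity below which a kernel is limited by bandwidth, not compute."""
    return peak_flops / hbm_bytes_per_s


# A100 figures quoted in the text: 312 TFLOPS and 2 TB/s.
ridge = ridge_point(312e12, 2e12)
intensity = arithmetic_intensity_qk(N=4096, d=64)

print(f"ridge point:    {ridge:.0f} FLOP/byte")
print(f"QK^T intensity: {intensity:.1f} FLOP/byte")
print(f"memory-bound:   {intensity < ridge}")
```

As N grows, the intensity approaches 2d/4, so increasing the sequence length alone never moves this kernel out of the memory-bound region.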

:::note What Memory-Bound Really Means
"Memory-bound" means the GPU computes so fast that it is starved for data. The arithmetic logic units finish their work and then wait - sometimes hundreds of nanoseconds - for HBM to send the next batch of numbers. Improving the algorithm's flop count does nothing. The only way to go faster is to reduce the number of bytes that need to cross the HBM-to-chip bus.
:::


Historical Context

The Pre-FlashAttention Era

Before 2022, the standard approach to long-context attention was approximation. Sparse attention (Child et al., 2019) computed only a subset of the attention connections. Linear attention (Katharopoulos et al., 2020) replaced softmax with kernel functions to get $O(N)$ complexity. Longformer (Beltagy et al., 2020) used sliding windows plus global tokens.

All of these changed what the model computed. The outputs were not the same as standard attention - they were approximations trading accuracy for speed.

The Online Softmax Insight

The key mathematical insight that makes FlashAttention possible was actually discovered earlier. Milakov and Gimelshein (2018) at NVIDIA published "Online normalizer calculation for softmax" - a way to compute softmax incrementally without seeing all values first. You process inputs in chunks, maintaining a running maximum and running sum, and produce the exact same result as batch softmax.

FlashAttention's core contribution was recognizing that this online softmax trick, combined with careful tiling, makes it possible to fuse the entire attention computation into a single kernel that never materializes the N x N matrix.

FlashAttention Timeline

| Year | Work | Speedup |
|---|---|---|
| 2018 | Online softmax (Milakov & Gimelshein) | Foundation |
| 2022 | FlashAttention v1 (Dao et al.) | 2-4x over PyTorch |
| 2023 | FlashAttention v2 (Dao) | 2x over FA1 |
| 2024 | FlashAttention v3 (Shah et al., H100) | 1.5-2x over FA2 |

Core Concept: Online Softmax

Before understanding FlashAttention, you need to understand how to compute softmax without seeing all values at once.

Naive Softmax (Requires Full Row)

Standard softmax for a row $x \in \mathbb{R}^N$:

$$\text{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}$$

This requires two passes over all N values: first to find the denominator, then to compute each output. If N = 4096 and x lives in HBM, that's two full HBM reads of a 4096-element row.

In practice, numerical stability requires subtracting the max first:

$$\text{softmax}(x)_i = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{N} e^{x_j - \max(x)}}$$

This makes three passes: one to find max, one to compute the denominator, one to compute outputs.
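To make the three passes concrete, here is a minimal pure-Python sketch. The function name is invented for this illustration, and the explicit loop is for readability, not speed.

```python
import math


def three_pass_softmax(x):
    """Numerically stable softmax that walks the input three times."""
    # Pass 1: find the maximum so every exponent is <= 0.
    m = max(x)
    # Pass 2: accumulate the denominator.
    denom = 0.0
    for v in x:
        denom += math.exp(v - m)
    # Pass 3: emit the normalized outputs.
    return [math.exp(v - m) / denom for v in x]


# Inputs this large would overflow math.exp without the max subtraction.
probs = three_pass_softmax([1000.0, 1001.0, 1002.0])
print(probs, sum(probs))
```

Each pass needs the whole row, which is exactly what a tiled kernel cannot afford when the row lives in HBM.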

Online Softmax (Processes One Tile at a Time)

The online softmax algorithm processes x in blocks without storing all values. For each new block, it updates a running maximum $m$ and running sum $\ell$.

Initialize: $m_0 = -\infty, \quad \ell_0 = 0$

After processing block $b$ with values $x^{(b)}$:

$$m_b = \max(m_{b-1}, \max(x^{(b)}))$$

$$\ell_b = e^{m_{b-1} - m_b} \cdot \ell_{b-1} + \sum_{j \in x^{(b)}} e^{x_j^{(b)} - m_b}$$

The rescaling factor $e^{m_{b-1} - m_b}$ corrects the previous partial sum for the new maximum. When $m_b > m_{b-1}$, the previous exponentials were computed relative to a smaller maximum and are therefore too large, so we scale them down.

Final output: divide each $e^{x_i - m_{\text{final}}}$ by $\ell_{\text{final}}$.

This produces the exact same result as the three-pass version, but processes the input in one pass through tiles that fit in SRAM.

import numpy as np

def online_softmax_demo(x_blocks):
    """
    Demonstrate online softmax - processes blocks without storing full input.
    x_blocks: list of 1D numpy arrays (the tiles)
    Returns: exact softmax over concatenated x_blocks
    """
    m = float('-inf')  # running max
    l = 0.0            # running sum of exp(x - m)

    # Single pass: update running max and sum (keeping per-block exps for the demo)
    block_exps = []
    block_maxes = []
    for block in x_blocks:
        m_new = max(m, float(np.max(block)))
        # Rescale previous sum to account for new max
        l = np.exp(m - m_new) * l + np.sum(np.exp(block - m_new))
        m = m_new
        block_exps.append(np.exp(block - m))  # relative to the running max so far
        block_maxes.append(m)                 # remember which max this block used

    # Normalize: rescale each block from its own running max to the final max,
    # then divide by the final l.
    # (In actual FlashAttention, output accumulation handles this inline)
    result = np.concatenate(
        [e * np.exp(mb - m) for e, mb in zip(block_exps, block_maxes)]
    ) / l

    # Verification
    x_full = np.concatenate(x_blocks)
    reference = np.exp(x_full - np.max(x_full))
    reference = reference / reference.sum()

    assert np.allclose(result, reference, atol=1e-6), "Online softmax result mismatch!"
    print(f"Max absolute error: {np.max(np.abs(result - reference)):.2e}")
    return result


# Test with 4 tiles of 16 elements each (simulating 64-element sequence)
np.random.seed(42)
x_blocks = [np.random.randn(16).astype(np.float32) for _ in range(4)]
result = online_softmax_demo(x_blocks)
print(f"Output sum: {result.sum():.6f}")  # Should be 1.0

FlashAttention Tiling Algorithm

Now we put the online softmax together with tiling to compute attention without materializing N x N.

The Tiling Strategy

Split Q into row blocks of size $B_r$ and split K, V into column blocks of size $B_c$. For each query block $Q_i$, iterate over all K, V blocks to accumulate the output.

The sizes are chosen so that the tiles $Q_i$, $K_j$, $V_j$ and the output accumulator $O_i$ fit in SRAM together:

$$B_r \times d + 2 B_c \times d + B_r \times d \leq M_{\text{SRAM}}$$

where $M_{\text{SRAM}}$ is the SRAM capacity in elements (on-chip memory is typically 96 KB - 200 KB per SM on modern GPUs).

Solving with $B_r = B_c$: $B_c = \lfloor M / (4d) \rfloor$ (a common heuristic).
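As a quick sanity check of the $M/(4d)$ heuristic, the sketch below computes the largest square tile for a given SRAM budget. It assumes four equally sized tiles (Q_i, K_j, V_j and the output accumulator) and nothing else on chip. Real kernels pick smaller tiles because the score tile and other temporaries also need space, and `tile_size` is a hypothetical helper name.

```python
def tile_size(sram_bytes, d, dtype_bytes=2):
    """Largest B = Br = Bc such that 4 * B * d elements fit in SRAM."""
    m_elements = sram_bytes // dtype_bytes
    b = m_elements // (4 * d)
    used_bytes = 4 * b * d * dtype_bytes
    return b, used_bytes


for d in (64, 128):
    b, used = tile_size(96 * 1024, d)
    print(f"d={d}: Br=Bc={b}, tiles occupy {used / 1024:.0f} KB")
```

Doubling the head dimension halves the tile size, which is one reason large head dimensions are harder to serve efficiently.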

The Algorithm (Pseudocode)

Input: Q, K, V in HBM, shape [N, d]
Output: O in HBM, shape [N, d]

Initialize O = zeros[N, d] in HBM
Initialize l = zeros[N] in HBM (running denominator)
Initialize m = -inf[N] in HBM (running max)

For each block Q_i (rows i*Br to (i+1)*Br):
    Load Q_i from HBM to SRAM
    Initialize O_i = zeros[Br, d] in registers
    Initialize l_i = zeros[Br] in registers
    Initialize m_i = -inf[Br] in registers

    For each block K_j, V_j (cols j*Bc to (j+1)*Bc):
        Load K_j, V_j from HBM to SRAM

        # Compute tile of attention scores
        S_ij = Q_i @ K_j^T                  # shape [Br, Bc], stays in SRAM/registers

        # Online softmax update
        m_ij = rowmax(S_ij)                 # shape [Br]
        P_ij = exp(S_ij - m_ij[:, None])    # shape [Br, Bc]
        l_ij = rowsum(P_ij)                 # shape [Br]

        # Update running statistics
        m_i_new = max(m_i, m_ij)
        l_i_new = exp(m_i - m_i_new) * l_i + exp(m_ij - m_i_new) * l_ij

        # Rescale accumulated output and add new contribution
        O_i = diag(exp(m_i - m_i_new)) @ O_i
              + diag(exp(m_ij - m_i_new)) @ (P_ij @ V_j)

        m_i = m_i_new
        l_i = l_i_new

    # Final normalization
    O_i = diag(1/l_i) @ O_i

    Write O_i to HBM

The entire N x N attention matrix is never written to HBM. The only HBM writes are the final output O, which is $N \times d$ - linear in N.
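The pseudocode can be mirrored in NumPy to check that tiling really reproduces standard attention. This is a CPU sketch for intuition only, not the CUDA kernel. The function names are invented here, there is no $1/\sqrt{d}$ scaling (matching the pseudocode), and the two rescaling factors are folded into one by exponentiating each tile directly against the new running max.

```python
import numpy as np


def standard_attention(Q, K, V):
    """Reference: materializes the full N x N score matrix."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V


def tiled_attention(Q, K, V, Br=16, Bc=16):
    """Tiled attention with online softmax; only Br x Bc scores exist at a time."""
    N = Q.shape[0]
    O = np.empty((N, V.shape[1]))
    for i in range(0, N, Br):
        Qi = Q[i:i + Br]
        rows = Qi.shape[0]
        Oi = np.zeros((rows, V.shape[1]))   # unnormalized output accumulator
        li = np.zeros(rows)                 # running denominator
        mi = np.full(rows, -np.inf)         # running row max
        for j in range(0, K.shape[0], Bc):
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]
            Sij = Qi @ Kj.T                          # one tile of scores
            m_new = np.maximum(mi, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            scale = np.exp(mi - m_new)               # 0 on the first tile (mi = -inf)
            li = scale * li + Pij.sum(axis=1)
            Oi = scale[:, None] * Oi + Pij @ Vj
            mi = m_new
        O[i:i + Br] = Oi / li[:, None]               # final normalization
    return O


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((50, 8)) for _ in range(3))
out = tiled_attention(Q, K, V, Br=16, Bc=12)  # tile sizes need not divide N
print("max abs diff:", np.max(np.abs(out - standard_attention(Q, K, V))))
```

The difference from the reference is at floating-point rounding level, which is the sense in which FlashAttention is exact rather than approximate.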

Memory Complexity Analysis

$$\text{Standard attention: } O(N^2) \text{ HBM memory}$$

$$\text{FlashAttention: } O(N) \text{ HBM memory}$$

Concretely for N=4096, d=64, float16:

$$\text{Standard: } N^2 \times 2 \text{ bytes} = 4096^2 \times 2 = 32 \text{ MB per head}$$

$$\text{FlashAttention: } N \times d \times 2 \text{ bytes} = 4096 \times 64 \times 2 = 512 \text{ KB per head}$$

A 64x memory reduction, since the ratio is $N/d$. For N=32768, the reduction is 512x.
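The reduction factor is simply $N^2 / (N d) = N/d$, which a short calculation confirms. The helper below is a hypothetical name, and it counts one $N \times N$ matrix against one $N \times d$ output, as in the comparison above.

```python
def attention_bytes_per_head(N, d, dtype_bytes=2):
    """Bytes for one materialized N x N matrix vs one N x d output."""
    standard = N * N * dtype_bytes
    flash = N * d * dtype_bytes
    return standard, flash


ratios = {}
for N in (4096, 32768):
    standard, flash = attention_bytes_per_head(N, d=64)
    ratios[N] = standard // flash
    print(f"N={N}: {standard / 2**20:.0f} MiB vs {flash / 2**10:.0f} KiB, ratio {ratios[N]}x")
```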

IO Complexity Analysis

How many bytes travel between HBM and chip? Let M = SRAM size, N = sequence length, d = head dimension.

Standard attention HBM reads/writes:

  • Write S: $N^2$ elements
  • Read S, write P: $2N^2$ elements
  • Read P, read V, write O: $N^2 + Nd + Nd$ elements
  • Total: $\approx 4N^2$ elements = $O(N^2)$

FlashAttention HBM reads/writes:

  • Read Q once: $Nd$ elements
  • For each of $N/B_r$ query blocks, read all of K and V: $(N/B_r) \times 2Nd = 2N^2 d / B_r$ elements. With $B_r = M/(4d)$, this is $8N^2 d^2 / M$
  • Write O: $Nd$ elements
  • Total: $O(N^2 d^2 / M)$ elements

$$\text{HBM IO ratio: } \frac{\text{FlashAttention}}{\text{Standard}} \approx \frac{8N^2 d^2 / M}{4N^2} = \frac{2d^2}{M}$$

For d=64, M=96KB (49152 float16 elements):

$$\text{Ratio} \approx \frac{2 \times 64^2}{49152} = \frac{8192}{49152} = \frac{1}{6}$$

FlashAttention reads/writes about 6x fewer bytes from HBM. In practice, measured speedups are 2-4x because there is overhead in the tiling logic and the roofline is not perfectly tight.

def compute_flash_attention_io(N, d, M_sram_bytes, dtype_bytes=2):
    """
    Estimate HBM IO for standard vs FlashAttention.

    Args:
        N: sequence length
        d: head dimension
        M_sram_bytes: SRAM capacity in bytes
        dtype_bytes: 2 for fp16, 4 for fp32
    """
    M = M_sram_bytes // dtype_bytes  # SRAM in elements

    # Standard attention HBM IO (read + write counts in elements)
    standard_io = 4 * N * N + 4 * N * d  # S, P read+write + Q,K,V,O

    # FlashAttention HBM IO
    # Q read once: N*d
    # K, V read once per query block: (N/Br) * N * d each, Br = M / (4d)
    Br = M // (4 * d)
    Bc = Br
    n_blocks_q = (N + Br - 1) // Br
    flash_io = (N * d +               # Q read (each Q block is loaded once)
                n_blocks_q * N * d +  # K reads, once per query block
                n_blocks_q * N * d +  # V reads, once per query block
                N * d)                # O write

    standard_bytes = standard_io * dtype_bytes
    flash_bytes = flash_io * dtype_bytes

    print(f"Sequence length N={N}, head dim d={d}")
    print(f"SRAM capacity: {M_sram_bytes/1024:.0f} KB = {M} elements")
    print(f"Block size Br=Bc={Br}")
    print(f"Standard attention HBM IO: {standard_bytes/1e6:.1f} MB")
    print(f"FlashAttention HBM IO: {flash_bytes/1e6:.1f} MB")
    print(f"IO reduction: {standard_bytes/flash_bytes:.1f}x")
    print()


# A100 has 192 KB of combined L1 cache/shared memory per SM; using 96KB as effective SRAM for tiling
compute_flash_attention_io(N=4096, d=64, M_sram_bytes=96*1024)
compute_flash_attention_io(N=8192, d=128, M_sram_bytes=96*1024)
compute_flash_attention_io(N=32768, d=64, M_sram_bytes=96*1024)

FlashAttention 2: Warp-Level Improvements

FlashAttention v1 was already a major step. FlashAttention 2 (Dao, 2023) extracted another 2x by fixing inefficiencies in how work was distributed across GPU warps.

Problem 1: Parallelism was Limited to Batch and Heads

FA1 parallelized across heads and batch size, but within each head, work was sequential across Q blocks. With long sequences the batch size is usually small, so batch × heads can yield too few thread blocks to keep all SMs busy (an A100 has 108). FA2 adds parallelism across Q blocks, so small-batch, long-sequence workloads still saturate the GPU.

Problem 2: Non-Matmul FLOPS Killed Tensor Core Utilization

The A100's Tensor Cores are 16x faster for matmul than for other ops. FA1 had too many non-matmul operations: the online softmax updates, rescaling factors, and normalization. These ran on regular CUDA cores at 19.5 TFLOPS instead of Tensor Cores at 312 TFLOPS.

FA2 reduced non-matmul FLOPs by reorganizing the computation. The key change: rather than computing the softmax normalization tile by tile (which requires many exp and division ops per tile), FA2 batches the rescaling and defers normalization to the end of each Q block.

Problem 3: Warp Communication Overhead

FA1 split K/V tiles across warps, requiring warps to communicate their partial results via shared memory. This created synchronization barriers. FA2 splits Q tiles across warps instead - each warp owns its query rows completely and independently accumulates the output. No inter-warp synchronization needed for the inner loop.

FA1 warp strategy (shared K/V):
Warp 0: owns K_j[:Bc/2], V_j[:Bc/2] → partial O
Warp 1: owns K_j[Bc/2:], V_j[Bc/2:] → partial O
→ must sum partial O across warps (shared memory communication)

FA2 warp strategy (independent Q rows):
Warp 0: owns Q_i[:Br/4] rows → full O rows, no coordination needed
Warp 1: owns Q_i[Br/4:Br/2] rows → full O rows, no coordination
→ warps are completely independent, no synchronization barriers
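The cost of the FA1 layout is the merge step: two workers that each saw only part of K/V must combine their running max, denominator, and unnormalized output before the result is valid. The NumPy sketch below illustrates that merge for a single query row. It models the arithmetic only, not warps or shared memory, and the function names are invented for this example.

```python
import numpy as np


def partial_state(q, K, V):
    """Unnormalized attention state (max, denominator, output) over one K/V slice."""
    s = K @ q
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ V


def merge_states(a, b):
    """Combine two partial states; this is the extra communication split-K/V needs."""
    (m1, l1, o1), (m2, l2, o2) = a, b
    m = max(m1, m2)
    w1, w2 = np.exp(m1 - m), np.exp(m2 - m)
    return m, w1 * l1 + w2 * l2, w1 * o1 + w2 * o2


rng = np.random.default_rng(1)
q = rng.standard_normal(8)
K = rng.standard_normal((32, 8))
V = rng.standard_normal((32, 4))

# Split-K/V style: each worker sees half the keys, then the states are merged.
m, l, o = merge_states(partial_state(q, K[:16], V[:16]),
                       partial_state(q, K[16:], V[16:]))
split_kv_out = o / l

# Split-Q style: one worker owns this query row and sees every key, no merge.
m_ref, l_ref, o_ref = partial_state(q, K, V)
full_out = o_ref / l_ref

print("max abs diff:", np.max(np.abs(split_kv_out - full_out)))
```

Both layouts give the same answer; the FA2 layout simply avoids paying for `merge_states` inside the inner loop.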

FA2 Speedup in Practice

| Configuration | FA1 throughput | FA2 throughput | Speedup |
|---|---|---|---|
| A100, N=2048, d=64, causal | 149 TFLOPS | 227 TFLOPS | 1.52x |
| A100, N=4096, d=128, causal | 198 TFLOPS | 335 TFLOPS | 1.69x |
| A100, N=8192, d=128, causal | 213 TFLOPS | 368 TFLOPS | 1.73x |

Read these figures as illustrative rather than measured: the A100's dense fp16 peak is 312 TFLOPS, so the 335 and 368 TFLOPS entries exceed what the hardware can deliver. The FlashAttention-2 paper reports reaching 50-73% of the theoretical peak on A100.

FlashAttention 3: H100-Specific Optimizations

FlashAttention 3 (Shah et al., 2024) targets the H100 specifically, exploiting hardware features that did not exist on A100.

Asynchronous Memory Copies (TMA)

H100 has the Tensor Memory Accelerator (TMA), a dedicated hardware unit that moves data between HBM and shared memory without using CUDA cores. In FA2, memory loads consumed compute resources. In FA3, TMA handles all K, V loading asynchronously in the background while the Tensor Cores run the matmul for the previous tile.

This is double buffering: while the GEMM processes tile j, TMA pre-fetches tile j+1 into a second shared memory buffer. By the time the GEMM finishes tile j, tile j+1 is already waiting in SRAM. FA3 additionally uses "ping-pong" scheduling between two consumer warpgroups, so that one warpgroup's softmax overlaps with the other's GEMMs.

Warp Specialization

FA3 dedicates different warps to different tasks. "Producer warps" exclusively run TMA copy operations. "Consumer warps" exclusively run GEMM and softmax. This specialization allows both to run at full speed concurrently, compared to FA2 where each warp alternated between copying and computing.

FP8 Support

H100 supports FP8 (E4M3) Tensor Core operations at 2x the throughput of FP16. FA3 adds FP8 attention and limits the accuracy loss from reduced precision with two techniques: block quantization, which applies per-block scaling factors, and "incoherent processing", which multiplies Q and K by a random orthogonal matrix before quantizing so that outliers are spread across coordinates.
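In the FlashAttention-3 paper, incoherent processing multiplies Q and K by a random orthogonal matrix before quantizing. The transform leaves $QK^T$ unchanged while spreading outliers across coordinates. The NumPy sketch below checks both properties, assuming a random-sign Hadamard matrix as the orthogonal transform. It does not simulate FP8 itself, and `random_sign_hadamard` is a hypothetical helper name.

```python
import numpy as np


def random_sign_hadamard(d, rng):
    """Orthogonal d x d matrix: random +/-1 row signs on a normalized Hadamard matrix."""
    assert d & (d - 1) == 0, "Sylvester construction needs d to be a power of two"
    H = np.ones((1, 1))
    while H.shape[0] < d:
        H = np.block([[H, H], [H, -H]])
    signs = rng.choice([-1.0, 1.0], size=d)
    return (signs[:, None] * H) / np.sqrt(d)


rng = np.random.default_rng(0)
d = 64
M = random_sign_hadamard(d, rng)
Q = rng.standard_normal((16, d))
K = rng.standard_normal((16, d))

# Because M @ M.T = I, the attention scores are unchanged by the rotation.
scores_before = Q @ K.T
scores_after = (Q @ M) @ (K @ M).T
print("scores unchanged:", np.allclose(scores_before, scores_after))

# A single large outlier is spread evenly over all d coordinates.
outlier = np.zeros(d)
outlier[3] = 50.0
rotated = outlier @ M
print("max |x| before:", np.abs(outlier).max(), "after:", np.abs(rotated).max())
```

A smaller dynamic range per block means a low-precision format spends its few bits on typical values instead of on one extreme entry.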

FA3 Performance on H100

| Configuration | FA2 throughput | FA3 throughput | Speedup |
|---|---|---|---|
| H100, N=8192, d=128, fp16 | 670 TFLOPS | 1000 TFLOPS | 1.49x |
| H100, N=8192, d=128, fp8 | - | 1300 TFLOPS | - |
| H100, N=32768, d=128, fp16 | 720 TFLOPS | 1100 TFLOPS | 1.53x |

Read these figures as illustrative rather than measured: the H100 SXM's dense fp16 peak is about 989 TFLOPS, so the fp16 FA3 entries sit at or beyond the hardware limit. The FlashAttention-3 paper reports up to about 740 TFLOPS in fp16 (roughly 75% utilization) and close to 1.2 PFLOPS in FP8.

Using FlashAttention in PyTorch

The Easy Path: scaled_dot_product_attention

Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention automatically dispatches to FlashAttention when conditions are met:

import torch
import torch.nn.functional as F

# This automatically uses FlashAttention on supported hardware
def efficient_attention(q, k, v, causal=False):
    """
    q, k, v: [batch, heads, seq_len, head_dim]
    All must be fp16 or bf16 for FlashAttention to activate.
    """
    # PyTorch dispatches to Flash Attention when:
    # 1. CUDA is available
    # 2. dtype is fp16 or bf16 (NOT float32)
    # 3. head_dim is supported by the installed build (limits vary by version)
    # 4. No custom attention mask (or causal mask only)

    # Note: newer PyTorch releases deprecate this context manager in favor of
    # torch.nn.attention.sdpa_kernel.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,  # disable fallback math kernel
        enable_mem_efficient=False
    ):
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=causal,
            scale=None  # defaults to 1/sqrt(head_dim)
        )
    return output


# Verify FlashAttention is being used
def check_flash_attention_active(seq_len=2048, head_dim=64, num_heads=8, batch=2):
    device = 'cuda'
    dtype = torch.float16  # Required for FlashAttention

    q = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch, num_heads, seq_len, head_dim, device=device, dtype=dtype)

    # Check which kernels are enabled for dispatch
    print("Flash attention enabled:", torch.backends.cuda.flash_sdp_enabled())
    print("Math kernel enabled:", torch.backends.cuda.math_sdp_enabled())
    print("Mem-efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())

    out = None  # stays None if the flash-only call fails
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        try:
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            print(f"FlashAttention succeeded. Output shape: {out.shape}")
        except RuntimeError as e:
            print(f"FlashAttention NOT used: {e}")

    return out


if torch.cuda.is_available():
    check_flash_attention_active()

When FlashAttention Does NOT Activate

import torch
import torch.nn.functional as F

# Case 1: float32 - FlashAttention requires fp16/bf16
q_fp32 = torch.randn(1, 8, 1024, 64, device='cuda', dtype=torch.float32)
# This falls back to another backend silently. Profile to see which kernel ran.

# Case 2: Arbitrary attention mask - Flash only supports causal mask
custom_mask = torch.zeros(1024, 1024, device='cuda', dtype=torch.bool)
custom_mask[:512, 512:] = True  # Some custom pattern
# Flash can NOT handle arbitrary masks. Falls back to standard attention.

# Case 3: Head dim not supported by the installed build
q_bad_dim = torch.randn(1, 8, 1024, 96, device='cuda', dtype=torch.float16)
# head_dim=96 may be rejected by older Flash versions; check the limits of your build

# How to detect which kernel was used:
# Use torch.profiler or nsight compute.
# In torch.profiler, look for "flash_fwd" vs "scaled_dot_product_attention_flash_attention_kernel"
# vs slower "native_sdp" or "aten::_softmax" patterns.

def audit_attention_kernel(q, k, v):
    """Profile which kernel actually runs."""
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True
    ) as prof:
        _ = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # Look for flash_fwd in kernel names
    events = prof.key_averages()
    flash_found = any('flash' in e.key.lower() for e in events)
    print(f"FlashAttention kernel detected: {flash_found}")
    for e in events:
        if 'attention' in e.key.lower() or 'flash' in e.key.lower():
            print(f"  {e.key}: {e.cuda_time_total/1e3:.2f}ms")

The flash_attn Package

The flash_attn package provides direct access to the FlashAttention kernels with more control:

# Install: pip install flash-attn --no-build-isolation
import torch
import torch.nn as nn
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention

# Basic usage - qkv packed (common for self-attention)
def flash_self_attention_packed(qkv, causal=True, softmax_scale=None):
    """
    qkv: [batch, seq_len, 3, num_heads, head_dim]
    Returns: [batch, seq_len, num_heads, head_dim]
    """
    return flash_attn_qkvpacked_func(
        qkv,
        dropout_p=0.0,
        softmax_scale=softmax_scale,  # None -> 1/sqrt(head_dim)
        causal=causal,
    )


# Separate Q, K, V (cross-attention or when shapes differ)
def flash_cross_attention_explicit(q, k, v, causal=False, softmax_scale=None):
    """
    q: [batch, seq_len_q, num_heads, head_dim]
    k: [batch, seq_len_k, num_heads, head_dim]
    v: [batch, seq_len_k, num_heads, head_dim]
    Returns: [batch, seq_len_q, num_heads, head_dim]
    """
    return flash_attn_func(
        q, k, v,
        dropout_p=0.0,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=(-1, -1),  # full attention
    )


# Sliding window attention (FA2 feature - useful for long context)
def flash_sliding_window_attention(q, k, v, window_size=512):
    """
    Each token attends only to the previous window_size tokens.
    Memory and compute scale as O(N * window_size) not O(N^2).
    """
    return flash_attn_func(
        q, k, v,
        causal=True,
        window_size=(window_size, 0),  # (left, right) window
    )


# Drop-in module replacement
class FlashMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, causal=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.causal = causal

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.flash_attn = FlashSelfAttention(causal=causal)

    def forward(self, x):
        B, T, C = x.shape
        # x: [B, T, C]
        qkv = self.qkv_proj(x)
        # Reshape to [B, T, 3, H, D]
        qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
        # flash_attn expects fp16 or bf16
        qkv = qkv.to(torch.float16)

        out = self.flash_attn(qkv)  # [B, T, H, D]
        out = out.reshape(B, T, C).to(x.dtype)
        return self.out_proj(out)

Architecture Diagram: FlashAttention Tiling


Production Engineering Notes

Memory Savings vs Standard Attention at Scale

def attention_memory_comparison(N, d, num_heads, num_layers, batch_size, dtype_bytes=2):
    """
    Compare peak HBM usage between standard and FlashAttention.
    All values in GB.
    """
    # Standard attention: must hold S and P matrices
    # Per head per layer: 2 * N^2 floats (S and P)
    standard_attn_bytes = (2 * N * N * dtype_bytes *
                           num_heads * num_layers * batch_size)

    # FlashAttention: only needs Q, K, V, O
    # Per head per layer: 4 * N * d floats
    flash_attn_bytes = (4 * N * d * dtype_bytes *
                        num_heads * num_layers * batch_size)

    print(f"N={N}, d={d}, {num_heads} heads, {num_layers} layers, batch={batch_size}")
    print(f"Standard attention matrices: {standard_attn_bytes/1e9:.1f} GB")
    print(f"FlashAttention activations: {flash_attn_bytes/1e9:.1f} GB")
    print(f"Memory reduction: {standard_attn_bytes/flash_attn_bytes:.0f}x")
    print()

# LLaMA-2 7B parameters, inference scenarios
attention_memory_comparison(N=4096, d=128, num_heads=32, num_layers=32, batch_size=1)
attention_memory_comparison(N=32768, d=128, num_heads=32, num_layers=32, batch_size=1)
attention_memory_comparison(N=128000, d=128, num_heads=32, num_layers=32, batch_size=1)

Verifying FlashAttention is Active in Training

import torch
import torch.nn.functional as F
from contextlib import contextmanager

@contextmanager
def require_flash_attention():
    """Context manager that raises if FlashAttention is not used."""
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,  # disable fallback
        enable_mem_efficient=False
    ):
        yield

def verify_model_uses_flash(model, sample_input):
    """
    Run forward pass and verify FlashAttention kernels were invoked.
    Raises RuntimeError if standard attention fallback occurs.
    """
    with require_flash_attention():
        try:
            output = model(sample_input)
            print("FlashAttention verified - forward pass succeeded under flash-only mode")
            return output
        except RuntimeError as e:
            print(f"FALLBACK DETECTED: {e}")
            print("Check: dtype must be fp16/bf16, head_dim supported by this build, no custom masks")
            raise

Common Production Failure: Wrong dtype

The single most common reason FlashAttention silently falls back to standard attention is float32 inputs. Models often use float32 by default, and the fallback is silent - you see no error, just slower performance.

# WRONG: float32 silently falls back to standard attention
model = TransformerModel(...)
output = model(input_ids) # If model weights are fp32, attention is fp32 too

# RIGHT: explicitly cast to bf16 or fp16
model = model.to(torch.bfloat16) # bf16 is preferred for training stability
output = model(input_ids.cuda())

# Or use automatic mixed precision:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    output = model(input_ids)

Common Mistakes

:::danger FlashAttention Does Not Activate for float32
The flash backend of F.scaled_dot_product_attention requires fp16 or bf16. With float32 inputs, dispatch silently falls back to another backend, in the worst case the math kernel that materializes the full O(N^2) attention matrix. This is the most common production mistake - the code runs, the output is correct, but memory usage and throughput can be far worse than expected. Always use fp16 or bf16 for attention. Use torch.autocast or explicitly cast model parameters.
:::

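:::danger Using Unsupported Head Dimensions
FlashAttention kernels support only certain head dimensions, and the limits vary by version: FlashAttention-2 supports head dimensions up to 256, while older kernels and older PyTorch builds are more restrictive. If your model uses an unusual size such as d=96 or d=160 on a build that does not support it, SDPA silently falls back. This is especially common when people experiment with custom model architectures. Always print the head dimension and verify it is supported by the installed version before benchmarking.
:::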

:::warning Custom Attention Masks Break FlashAttention
The flash backend in scaled_dot_product_attention only supports causal masking natively. Passing an arbitrary attn_mask (e.g., block-diagonal for document packing, sliding window via manual masking, ALiBi-style biases added before softmax) causes fallback to another kernel. Use the window_size parameter in flash_attn_func for sliding window attention, and use FA's built-in causal flag rather than constructing a manual causal mask.
:::

:::warning Backward Pass Costs More Than Forward
FlashAttention does not store S or P for the backward pass. It keeps the output and the per-row softmax statistics, which is O(N) memory, and recomputes the S and P tiles on the fly to compute gradients. Because of this recomputation, the backward pass does about 2-3x more FLOPS than the forward pass. Peak activation memory during training is still much lower than standard attention, but the speedup is smaller in backward than forward.
:::

:::tip Checking FlashAttention Version at Runtime

import flash_attn
print(flash_attn.__version__) # Should be 2.x for production use
# Also check:
import torch
print(torch.__version__)
# torch >= 2.0 required for F.scaled_dot_product_attention dispatch

:::


Interview Questions

Q1: Why does standard attention have O(N²) memory and how does FlashAttention reduce this to O(N)?

Standard attention computes $S = QK^T$ where S has shape [N, N]. With N=4096 in fp16, S alone is $4096^2 \times 2$ bytes = 32 MB per head. This matrix must be written to HBM (because it's too large for SRAM), then read back for the softmax, and the resulting P is read again for the $PV$ product. Total HBM traffic is approximately $4N^2$ elements.

FlashAttention avoids materializing S entirely. It tiles Q into blocks of size $B_r$ and K, V into blocks of size $B_c$, chosen so that a single tile $(Q_i, K_j, V_j)$ fits in SRAM. The key enabler is online softmax: you can accumulate the softmax output incrementally as tiles arrive, maintaining only a running max and running sum, without ever seeing the full row at once.

The only thing written to HBM is the final output O of shape [N, d], which is O(N). The N x N matrix S exists only transiently in SRAM/registers, never in HBM.

The IO complexity is $O(N^2 d^2 / M)$ where M is SRAM size in elements, versus $O(N^2)$ for standard attention. With the constants worked out for $B_r = M/(4d)$, the ratio is about $2d^2 / M \approx 1/6$ for d=64 and 96 KB of SRAM holding fp16 values.


Q2: Derive the online softmax update rule. Why is the rescaling factor exp(m_prev - m_new) necessary?

Say you have computed partial softmax statistics over the first j tiles: running max $m$ and running denominator $\ell = \sum_{i \leq j \cdot B_c} e^{x_i - m}$.

Now you see tile j+1 with max $m' = \max(x^{(j+1)})$.

The new global max is $m_{\text{new}} = \max(m, m')$.

The previous partial sum $\ell$ was computed relative to the old max $m$:

$$\ell = \sum_{i \leq j \cdot B_c} e^{x_i - m}$$

But we need the partial sum relative to the new max $m_{\text{new}}$:

$$\ell_{\text{new, prev part}} = \sum_{i \leq j \cdot B_c} e^{x_i - m_{\text{new}}} = \sum_{i \leq j \cdot B_c} e^{x_i - m} \cdot e^{m - m_{\text{new}}} = \ell \cdot e^{m - m_{\text{new}}}$$

This is the rescaling factor: $e^{m_{\text{prev}} - m_{\text{new}}}$. When the new tile has a higher maximum, $m_{\text{prev}} - m_{\text{new}} < 0$, so the factor is less than 1 - we scale down the old partial sum because the new maximum means old values were "over-weighted."

The accumulated output O_i must be similarly rescaled: $O_i \leftarrow O_i \cdot e^{m_{\text{prev}} - m_{\text{new}}}$.

After processing all tiles, we divide O_i by the final $\ell$ to get the normalized output. The result is identical to computing full softmax first and then doing the weighted sum.


Q3: What specific improvements did FlashAttention 2 make over FlashAttention 1?

Three categories of improvement:

1. Parallelism across Q blocks - FA1 parallelized over batch and head dimensions only. Within a head, it processed Q blocks sequentially. FA2 assigns different Q blocks to different thread blocks, which run in parallel on different SMs. This removes the bottleneck when batch × heads is small, which is typical for long sequences.

2. Warp partitioning for K/V processing - FA1 split K/V tiles across warps, requiring warps to sum their partial outputs via shared memory (inter-warp communication). FA2 assigns non-overlapping rows of Q to each warp. Each warp independently computes its output rows with full access to all K/V. No inter-warp communication in the inner loop.

3. Reduced non-matmul FLOPs - The A100 Tensor Cores run at 312 TFLOPS for fp16 matmul but only 19.5 TFLOPS for non-matmul ops. FA1 had many exp, division, and comparison operations in the inner loop. FA2 restructures to minimize these operations - for example, deferring the final $1/\ell$ normalization to after the K/V loop rather than applying it each tile.

Combined, these produce about 2x throughput improvement over FA1 on A100.
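
To see why the extra parallelism matters, it helps to count thread blocks. The sketch below is a back-of-the-envelope model, not the actual launch configuration of either kernel; the function names and the 128-row Q block are assumptions chosen for illustration.

```python
import math

A100_SMS = 108  # streaming multiprocessors on an A100

def fa1_style_blocks(batch, heads):
    """One thread block per (batch, head) pair."""
    return batch * heads

def fa2_style_blocks(batch, heads, seq_len, block_rows):
    """One thread block per (batch, head, Q block)."""
    return batch * heads * math.ceil(seq_len / block_rows)

# Long-context inference: batch 1, 32 heads, 16384 tokens, assumed B_r = 128.
print(fa1_style_blocks(1, 32))              # 32
print(fa2_style_blocks(1, 32, 16384, 128))  # 4096
```

With only 32 thread blocks, most of the 108 SMs would sit idle in this model; 4096 blocks give the scheduler enough work to fill the whole GPU.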


Q4: How would you verify that FlashAttention is actually being used in your model?

Three complementary approaches:

Method 1: Force flash-only mode and observe failures

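import torch

# Disable every backend except flash; mem_efficient is enabled by default.
# Newer PyTorch releases deprecate this in favor of torch.nn.attention.sdpa_kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = model(x)
# If this raises RuntimeError, a non-flash backend was being used as fallback.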

Method 2: Profile with PyTorch profiler

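import torch
from torch.profiler import ProfilerActivity

with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p:
    out = model(x)
print(p.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# Look for 'flash_fwd' in the kernel names. Standard attention shows 'aten::_softmax'.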

Method 3: Memory usage signature - FlashAttention should use roughly $4Nd$ bytes per head for attention activations, rather than an amount that grows with $N^2$. At N=4096 with 32 heads and d=64, standard attention uses ~2GB for fp32 score matrices, while FlashAttention uses ~32MB. Use torch.cuda.memory_stats() before and after the attention forward pass to measure.
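
The two figures in Method 3 follow from simple arithmetic. A quick sketch, assuming fp32 scores (4 bytes per element) and d=64 as in the article's running example; the function names are made up for illustration.

```python
def standard_attention_bytes(seq_len, heads, bytes_per_elem=4):
    """Bytes for the materialized N x N score matrices of one layer."""
    return seq_len * seq_len * heads * bytes_per_elem

def flash_attention_bytes(seq_len, head_dim, heads):
    """Roughly 4 * N * d bytes per head, per the estimate in the text."""
    return 4 * seq_len * head_dim * heads

MIB = 1024 ** 2
print(standard_attention_bytes(4096, 32) / MIB)   # 2048.0
print(flash_attention_bytes(4096, 64, 32) / MIB)  # 32.0
```

A measured delta near the first number suggests the score matrix is being materialized; a delta near the second is consistent with a fused kernel.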

Common reasons FlashAttention is not used:

  • dtype is float32 (must be fp16 or bf16)
  • head_dim is not supported by the installed kernel (commonly 64, 128, 256; the exact set depends on the version)
  • Custom attention mask is provided (only causal mask is natively supported)
  • PyTorch version is older than 2.0 (no SDPA dispatch)
  • flash_attn package version mismatch with installed CUDA version
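
The checklist can be turned into a small pre-flight function for debugging. This is a hypothetical helper: `flash_blockers` and its rules simply mirror the list above and are not any library's actual dispatch logic, so treat its output as a hint about where to look.

```python
def flash_blockers(dtype, head_dim, has_custom_mask, torch_version):
    """Return the checklist items that apply to this configuration.

    Hypothetical helper: the rules restate the article's list and will not
    match every PyTorch or flash_attn release.
    """
    reasons = []
    if dtype not in ("float16", "bfloat16"):
        reasons.append("dtype must be fp16 or bf16")
    if head_dim not in (64, 128, 256):
        reasons.append("unsupported head_dim")
    if has_custom_mask:
        reasons.append("custom attention mask")
    if torch_version < (2, 0):
        reasons.append("PyTorch older than 2.0")
    return reasons

print(flash_blockers("float32", 80, False, (2, 1)))
# ['dtype must be fp16 or bf16', 'unsupported head_dim']
```

An empty list only means none of the listed causes apply; the forced flash-only run from Method 1 remains the authoritative check.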

Q5: What is the IO complexity of FlashAttention and how does it compare to standard attention?

Let N = sequence length, d = head dimension, M = SRAM capacity in elements.

Standard attention:

  • Read Q and K, write S: $2Nd + N^2$ elements
  • Read S, compute softmax, write P: $2N^2$ elements
  • Read P and V, write O: $N^2 + 2Nd$ elements
  • Total: $\Theta(Nd + N^2)$ HBM IO, dominated by the $N^2$ terms

FlashAttention: The outer loop has $T_r = \lceil N/B_r \rceil$ iterations. In each, we load $Q_i$ once ($B_r d$ elements) and iterate over $T_c = \lceil N/B_c \rceil$ K/V blocks. Each inner iteration loads one K block and one V block ($2 B_c d$ elements). Adding the $Nd$ elements written for O, the total is:

$$
\begin{aligned}
\text{HBM traffic} &= T_r \cdot B_r d + T_r \cdot T_c \cdot 2 B_c d + Nd \\
&= Nd + \frac{N}{B_r} \cdot \frac{N}{B_c} \cdot 2 B_c d + Nd \\
&= 2Nd + \frac{2N^2 d}{B_r}
\end{aligned}
$$

With $B_r = \Theta(M/d)$: total = $O(N^2 d^2 / M)$.

Taking a working SRAM budget of M = 40 KB per SM on an A100 (20480 fp16 elements, well under the 192 KB physically available per SM) and d=64:

$$\text{Ratio} = \frac{N^2 d^2/M}{N^2} = \frac{d^2}{M} = \frac{4096}{20480} \approx 0.2$$

FlashAttention uses about 5x fewer HBM bytes. In practice (accounting for tiling overhead and non-matmul ops), measured speedups are 2-4x.
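
The counts above are easy to evaluate for concrete sizes. The sketch below itemizes the HBM traffic in elements for both algorithms; the function names and the choice of B_r = 128 are assumptions for illustration, and constant factors differ from the asymptotic ratio.

```python
def standard_hbm_elements(n, d):
    # read Q and K (2nd), write S (n^2), read S and write P (2n^2),
    # read P (n^2), read V (nd), write O (nd)
    return 4 * n * n + 4 * n * d

def flash_hbm_elements(n, d, block_rows):
    # Q read once (nd), K and V re-read once per Q block (2 n^2 d / B_r),
    # O written once (nd)
    return 2 * n * d + 2 * n * n * d // block_rows

def asymptotic_ratio(d, sram_elements):
    # d^2 / M from the derivation above
    return d * d / sram_elements

print(standard_hbm_elements(4096, 64))    # 68157440
print(flash_hbm_elements(4096, 64, 128))  # 17301504
print(asymptotic_ratio(64, 20480))        # 0.2
```

For N=4096 and d=64 this gives about a 3.9x reduction in elements moved with the assumed block size, the same order of magnitude as the asymptotic estimate.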


Q6: FlashAttention 3 targets H100 specifically. What H100 features does it exploit that A100 lacks?

Three H100-specific hardware features:

1. Tensor Memory Accelerator (TMA) - H100 has a dedicated hardware unit for async data movement between HBM and shared memory. FA3 uses TMA to pre-fetch the next K/V tile into a second shared memory buffer while the current tile's GEMM runs. This overlaps memory transfer with compute, eliminating the "stall and load" pattern of FA2. The buffers are used in a double-buffered fashion: GEMM processes buffer A while TMA fills buffer B, then they swap. FA3 additionally uses a "ping-pong" schedule between warp groups, so that one group's softmax overlaps with another group's GEMMs.

2. Warp group specialization - H100 supports "warp group async" where different warp groups run different instruction mixes concurrently. FA3 designates "producer warps" that exclusively run TMA operations and "consumer warps" that exclusively run GEMM + softmax. Both run concurrently at full efficiency, unlike FA2 where each warp alternated between data copying and compute.

3. FP8 Tensor Cores - H100 FP8 Tensor Cores achieve 2x the throughput of FP16. FA3 supports FP8 attention computation and reduces the resulting quantization error with two techniques: block quantization, which keeps per-block scaling factors, and "incoherent processing," which multiplies Q and K by a random orthogonal matrix to spread out outliers before quantizing.

Result: FA3 reaches up to about 740 TFLOPS on H100 at fp16, about 75% of the H100's theoretical peak, and close to 1.2 PFLOPS with FP8.
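
The benefit of overlapping loads with compute can be illustrated with a toy timing model. This is a sketch under simplifying assumptions (fixed per-tile load and compute times in arbitrary units, perfect overlap), not a model of the actual FA3 kernel; the function names are illustrative.

```python
def serial_time(num_tiles, load, compute):
    """Each tile is loaded, then computed, with no overlap."""
    return num_tiles * (load + compute)

def double_buffered_time(num_tiles, load, compute):
    """The load of tile i+1 overlaps the compute of tile i."""
    if num_tiles == 0:
        return 0
    # First load is exposed, last compute is exposed, and every step in
    # between costs whichever of load or compute is slower.
    return load + (num_tiles - 1) * max(load, compute) + compute

print(serial_time(8, 3, 5))           # 64
print(double_buffered_time(8, 3, 5))  # 43
```

In this model the pipeline hides all but the first load whenever compute is the slower stage, which is the effect the producer/consumer split is designed to achieve.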


Summary

FlashAttention is one of the most important algorithmic innovations in deep learning systems of the past decade. The key ideas are simple once you see them:

  1. Standard attention is memory-bound because the N x N attention matrix is too large for SRAM and must live in HBM. Every read/write is slow.

  2. Tiling avoids materialization by processing Q in blocks and streaming K, V through SRAM. The N x N matrix never exists as a contiguous allocation.

  3. Online softmax enables exact computation despite tiling, by maintaining a running max and sum that can be updated incrementally as each K/V tile arrives.

  4. IO complexity drops from $O(N^2)$ to $O(N^2 d^2 / M)$ - roughly a 5x reduction for typical hyperparameters, translating to 2-4x real-world speedup.

  5. FA2 and FA3 extracted further improvements through better parallelism (FA2), hardware-specific async pipelines (FA3 on H100), and FP8 support.

In production, FlashAttention is not optional for sequences above 2048 tokens. It is the difference between a model that fits in memory and one that does not.
