Tiling and Shared Memory Optimization
Reading time: ~45 min · Interview relevance: Very High · Target roles: CUDA Developer, Kernel Engineer, ML Systems Engineer
The naive matrix multiply reads every element of A and B repeatedly from HBM. The tiled version reads each element exactly once from HBM and reuses it entirely from shared memory. Same arithmetic, seven times the throughput. This is what data reuse looks like in practice.
The 4 TFLOPS Benchmark
The number appeared in the profiling report before anyone even opened Nsight: 4.2 TFLOPS on an A100 that is rated for 77.6 TFLOPS FP32. The matrix multiply was achieving 5.4% of peak. Not a power issue, not a thermal throttle, not a driver problem. Just a naive implementation that had never been touched since it was written three years earlier to validate correctness.
The A100 has 40MB of HBM2 bandwidth at 2 TB/s. The naive matrix multiply for 4096x4096 matrices needs to read matrix A and matrix B a combined total of bytes - approximately 274 GB of HBM reads for one multiplication. The memory system cannot supply data fast enough. The CUDA cores sit idle 95% of the time waiting for their next operands.
A colleague dropped a tiled implementation into the same benchmark slot. Same matrix sizes. Same hardware. The profiling report updated: 28.4 TFLOPS. Six times the throughput with the same math and the same GPU. The only change was memory access pattern - specifically, loading small tiles of A and B into the 192KB shared memory on each SM, computing the partial dot products entirely from shared memory, then loading the next tiles.
The reason this works is data reuse. In the naive kernel, every output element needs the full row and the full column . If two threads in the same block compute and , they both need the same row of A. Without tiling, both threads independently read the same row from HBM. With tiling, that row is loaded once into shared memory and both threads read from there.
The tiling pattern is not specific to matrix multiply. It appears in every algorithm where a subset of the input is accessed multiple times: convolution accesses neighborhoods repeatedly, attention applies the same KV pairs to multiple queries, reduction operations scan arrays multiple times. Once you understand tiling for matrix multiply, you can apply the pattern to any of these operations.
This lesson walks through the full progression: why naive GEMM fails, how tiling fixes it, the math of the reuse ratio, a complete CUDA implementation, double buffering with cp.async for overlapping compute and memory, a Triton implementation that achieves comparable performance in a fraction of the code, and the extension to attention.
Why This Exists
Before shared memory was programmable (pre-G80, pre-CUDA), tiling was not possible for general-purpose workloads. The GPU's caches were fixed-function and managed by hardware. You could not explicitly manage what went into L1 or when.
CUDA's introduction in 2007 exposed shared memory as a programmer-controlled scratchpad. The concept existed in signal processing and scientific computing long before GPUs - blocked matrix algorithms date to the 1970s and were developed to fit matrices into cache on early vector processors. The fundamental observation is unchanged: if you can fit a block of data into fast memory and reuse it there, you reduce traffic from slow memory.
The seminal paper is "Goto and Van De Geijn, 2008 - Anatomy of High-Performance Matrix Multiplication" which proved that the key to near-peak FLOP/s is not clever instruction scheduling but cache (or shared memory) reuse. This paper is why cuBLAS, BLIS, OpenBLAS, and every production GEMM library use blocked algorithms.
CUTLASS, NVIDIA's open-source CUDA template library for GEMM, was first published in 2017 and represents the production-quality version of everything in this lesson. FlashAttention (Dao et al., 2022) applied the same tiling principles to attention and achieved the first attention implementation that was both correct in numerics and fast in practice by staying resident in shared memory.
Core Concepts
The Memory Reuse Ratio
The fundamental question tiling answers: how many times is each element of A read from HBM in a naive matrix multiply?
For computing where all matrices are :
- Output matrix has elements
- Each output requires one inner product:
- Computing reads elements from row of and elements from column of
- Across all outputs, row of A is read times (once for each column )
- Total reads from HBM: elements for A, same for B
For : billion element reads. At 4 bytes per float: 549 GB of HBM reads for matrices totaling only MB.
Now with tiling using tile size :
- Partition , , and into blocks
- For each output tile of , we read tiles of and tiles of
- Each tile is loaded into shared memory once and reused times for the inner products within the output tile
- Total HBM reads for A: elements
The reuse ratio is exactly . With , you read fewer bytes from HBM. This is the complete explanation for why tiled GEMM runs faster on the A100 - the arithmetic intensity increases from FLOP/byte to FLOP/byte (before accounting for B symmetrically).
The A100 FP32 roofline ridge point is approximately FLOP/byte. With : arithmetic intensity FLOP/byte - still memory-bound but much better. With : intensity FLOP/byte - approaching compute-bound. Production kernels targeting near-peak FLOP/s use register tiling on top of shared memory tiling to push intensity above the ridge point.
Tile Size Selection
Larger tiles mean more reuse and higher arithmetic intensity, but larger tiles consume more shared memory per block. More shared memory per block means fewer blocks can be resident per SM, which reduces occupancy.
The tradeoff:
(Factor of 2 because we need tiles for both A and B)
| Tile Size | Smem/block (FP32) | Blocks/SM (100KB) | Occupancy approx |
|---|---|---|---|
| 16x16 | 2 KB | 50 blocks | high |
| 32x32 | 8 KB | 12 blocks | moderate-high |
| 64x64 | 32 KB | 3 blocks | moderate |
| 128x128 | 128 KB | 0 blocks | not feasible |
The practical sweet spot is 32x32 for FP32 and 64x64 for FP16/BF16 (elements are half the size so 64x64 FP16 uses 32KB, same as 32x32 FP32). Production GEMM kernels from CUTLASS use 128x128 thread blocks with register tiling to overcome the occupancy limit - each thread accumulates a small register tile, dramatically increasing arithmetic intensity above what shared memory tiling alone achieves.
The hardware constraint on the A100/H100 that makes 32x32 tiles natural: 32 threads in a warp, 32-element vectors. A 32x32 tile naturally maps to one thread per element, no waste, perfect alignment with the warp's memory access width.
The Two-Phase Tiled Computation Loop
The tiled matmul algorithm has a clean structure that repeats for each tile:
- Load phase: threads cooperatively load a tile of A and a tile of B from global to shared memory
- Sync:
__syncthreads()to ensure all data is loaded before any thread reads it - Compute phase: each thread computes its partial dot product contribution from the shared memory tile
- Sync:
__syncthreads()to ensure all reads from shared memory finish before the next load overwrites it - Advance to the next tile along the K dimension, repeat
The two __syncthreads() calls are non-negotiable. The first ensures no thread reads a tile before it is fully written. The second ensures no thread overwrites a tile while another thread is still reading it.
Double Buffering - Overlapping Load and Compute
The basic two-phase pattern has a serialization: load tile K, sync, compute tile K, sync, load tile K+1, sync, compute K+1... The compute during tile K is blocked waiting for load of tile K+1 to finish.
Double buffering breaks this serialization. Allocate two shared memory buffers (A0/B0 and A1/B1):
- While computing tile K from buffer 0, asynchronously load tile K+1 into buffer 1
- When compute of K finishes, swap buffer pointers and repeat
The async load uses cp.async (introduced on Ampere, A100). This instruction copies data from global memory to shared memory without routing through registers. Critically, it does not stall the issuing warp - the copy proceeds in the background while the warp continues issuing compute instructions.
The software pipeline:
Cycle 0: issue cp.async for tile 0 into buffer A
Cycle 1-N: wait for tile 0 to arrive (cp.async.wait_group)
Cycle N+: compute tile 0 from buffer A, simultaneously issue cp.async for tile 1 into buffer B
compute tile 1 from buffer B, simultaneously issue cp.async for tile 2 into buffer A
...
With perfect double buffering, HBM latency is completely hidden by compute. The SM spends zero cycles waiting for data if the compute time for one tile exceeds the HBM latency for the next tile.
Full CUDA Implementation
Naive GEMM - Baseline
#include <cuda_runtime.h>
#include <stdio.h>
// Naive matrix multiply - no shared memory, no tiling
// C = A * B where A is M x K, B is K x N, C is M x N
__global__
void gemm_naive(
const float* __restrict__ A, // M x K
const float* __restrict__ B, // K x N
float* __restrict__ C, // M x N
int M, int K, int N
) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= M || col >= N) return;
float acc = 0.0f;
for (int k = 0; k < K; k++) {
// These two reads hit global memory every iteration
// A[row, k] is reused N times (once per column of B) but loaded fresh each time
// B[k, col] is reused M times (once per row of A) but loaded fresh each time
acc += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = acc;
}
Tiled GEMM - With Shared Memory
#define TILE_SIZE 32
__global__
__launch_bounds__(TILE_SIZE * TILE_SIZE, 2)
void gemm_tiled(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int K, int N
) {
// Shared memory tiles: one for A, one for B
__shared__ float s_A[TILE_SIZE][TILE_SIZE];
__shared__ float s_B[TILE_SIZE][TILE_SIZE];
int tx = threadIdx.x; // column within block (0..TILE_SIZE-1)
int ty = threadIdx.y; // row within block (0..TILE_SIZE-1)
int row = blockIdx.y * TILE_SIZE + ty; // global row of output
int col = blockIdx.x * TILE_SIZE + tx; // global col of output
float acc = 0.0f;
// Iterate over tiles along the K dimension
int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
for (int tile = 0; tile < num_tiles; tile++) {
// --- Phase 1: Load tile ---
// Each thread loads one element of A and one element of B
int a_col = tile * TILE_SIZE + tx; // column of A to load
int b_row = tile * TILE_SIZE + ty; // row of B to load
// Bounds checking for non-square matrices
s_A[ty][tx] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
s_B[ty][tx] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
// Ensure all threads have finished loading before any thread reads
__syncthreads();
// --- Phase 2: Compute partial dot product from shared memory ---
// TILE_SIZE multiplications and additions, all from shared memory (~4 cycle latency)
// vs global memory (~400 cycle latency)
#pragma unroll
for (int k = 0; k < TILE_SIZE; k++) {
acc += s_A[ty][k] * s_B[k][tx];
}
// Ensure all threads have finished reading before the next tile overwrites shared memory
__syncthreads();
}
// Write result
if (row < M && col < N) {
C[row * N + col] = acc;
}
}
Tiled GEMM with Double Buffering and cp.async
#include <cuda/pipeline>
#define TILE_SIZE 32
__global__
__launch_bounds__(TILE_SIZE * TILE_SIZE, 2)
void gemm_double_buffered(
const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int K, int N
) {
// Two shared memory buffers for double buffering
// Buffer 0 and Buffer 1 alternate: while computing from one, loading into the other
__shared__ float s_A[2][TILE_SIZE][TILE_SIZE];
__shared__ float s_B[2][TILE_SIZE][TILE_SIZE];
int tx = threadIdx.x;
int ty = threadIdx.y;
int row = blockIdx.y * TILE_SIZE + ty;
int col = blockIdx.x * TILE_SIZE + tx;
float acc = 0.0f;
int num_tiles = (K + TILE_SIZE - 1) / TILE_SIZE;
// Create a cuda::pipeline for asynchronous copy coordination
auto pipe = cuda::make_pipeline();
// Prefetch the first tile into buffer 0 (async)
pipe.producer_acquire();
{
int a_col = tx;
int b_row = ty;
if (row < M && a_col < K) {
cuda::memcpy_async(
&s_A[0][ty][tx], &A[row * K + a_col],
sizeof(float), pipe
);
} else {
s_A[0][ty][tx] = 0.0f;
}
if (b_row < K && col < N) {
cuda::memcpy_async(
&s_B[0][ty][tx], &B[b_row * N + col],
sizeof(float), pipe
);
} else {
s_B[0][ty][tx] = 0.0f;
}
}
pipe.producer_commit();
for (int tile = 0; tile < num_tiles; tile++) {
int current_buf = tile & 1; // 0, 1, 0, 1, ...
int next_buf = (tile + 1) & 1;
// Prefetch next tile into the other buffer (async, no stall)
if (tile + 1 < num_tiles) {
pipe.producer_acquire();
int a_col = (tile + 1) * TILE_SIZE + tx;
int b_row = (tile + 1) * TILE_SIZE + ty;
if (row < M && a_col < K) {
cuda::memcpy_async(
&s_A[next_buf][ty][tx], &A[row * K + a_col],
sizeof(float), pipe
);
} else {
s_A[next_buf][ty][tx] = 0.0f;
}
if (b_row < K && col < N) {
cuda::memcpy_async(
&s_B[next_buf][ty][tx], &B[b_row * N + col],
sizeof(float), pipe
);
} else {
s_B[next_buf][ty][tx] = 0.0f;
}
pipe.producer_commit();
}
// Wait for current tile to arrive in shared memory
pipe.consumer_wait();
__syncthreads(); // all threads see the same committed data
// Compute from current buffer while next tile loads in background
#pragma unroll
for (int k = 0; k < TILE_SIZE; k++) {
acc += s_A[current_buf][ty][k] * s_B[current_buf][k][tx];
}
pipe.consumer_release();
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = acc;
}
}
Benchmark Harness
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#define CHECK_CUDA(call) do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(1); \
} \
} while (0)
double benchmark_gemm(
void (*kernel)(const float*, const float*, float*, int, int, int),
int M, int K, int N,
dim3 grid, dim3 block,
int warmup_iters, int bench_iters
) {
size_t bytes_A = (size_t)M * K * sizeof(float);
size_t bytes_B = (size_t)K * N * sizeof(float);
size_t bytes_C = (size_t)M * N * sizeof(float);
float *d_A, *d_B, *d_C;
CHECK_CUDA(cudaMalloc(&d_A, bytes_A));
CHECK_CUDA(cudaMalloc(&d_B, bytes_B));
CHECK_CUDA(cudaMalloc(&d_C, bytes_C));
// Random init on device
// (simplified - in practice use cuRAND)
CHECK_CUDA(cudaMemset(d_A, 0x3f, bytes_A));
CHECK_CUDA(cudaMemset(d_B, 0x3f, bytes_B));
// Warmup
for (int i = 0; i < warmup_iters; i++) {
kernel<<<grid, block>>>(d_A, d_B, d_C, M, K, N);
}
CHECK_CUDA(cudaDeviceSynchronize());
// Timed benchmark
cudaEvent_t start, stop;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start));
for (int i = 0; i < bench_iters; i++) {
kernel<<<grid, block>>>(d_A, d_B, d_C, M, K, N);
}
CHECK_CUDA(cudaEventRecord(stop));
CHECK_CUDA(cudaEventSynchronize(stop));
float ms;
CHECK_CUDA(cudaEventElapsedTime(&ms, start, stop));
double avg_ms = ms / bench_iters;
double flops = 2.0 * M * K * N; // multiply-add = 2 FLOP
double tflops = flops / (avg_ms * 1e-3) / 1e12;
CHECK_CUDA(cudaFree(d_A));
CHECK_CUDA(cudaFree(d_B));
CHECK_CUDA(cudaFree(d_C));
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
return tflops;
}
int main() {
int M = 4096, K = 4096, N = 4096;
dim3 naive_block(16, 16);
dim3 naive_grid((N + 15) / 16, (M + 15) / 16);
dim3 tiled_block(TILE_SIZE, TILE_SIZE);
dim3 tiled_grid((N + TILE_SIZE - 1) / TILE_SIZE,
(M + TILE_SIZE - 1) / TILE_SIZE);
printf("Matrix size: %dx%d x %dx%d\n\n", M, K, K, N);
double naive_tflops = benchmark_gemm(
gemm_naive, M, K, N, naive_grid, naive_block, 3, 20
);
printf("Naive GEMM : %.2f TFLOPS\n", naive_tflops);
double tiled_tflops = benchmark_gemm(
gemm_tiled, M, K, N, tiled_grid, tiled_block, 3, 20
);
printf("Tiled GEMM : %.2f TFLOPS\n", tiled_tflops);
double db_tflops = benchmark_gemm(
gemm_double_buffered, M, K, N, tiled_grid, tiled_block, 3, 20
);
printf("Double-buffered : %.2f TFLOPS\n", db_tflops);
printf("\nSpeedup (tiled vs naive): %.1fx\n", tiled_tflops / naive_tflops);
printf("Speedup (db vs tiled) : %.1fx\n", db_tflops / tiled_tflops);
return 0;
}
Expected output on A100:
Matrix size: 4096x4096 x 4096x4096
Naive GEMM : 4.2 TFLOPS
Tiled GEMM : 21.8 TFLOPS
Double-buffered : 28.4 TFLOPS
Speedup (tiled vs naive): 5.2x
Speedup (db vs tiled) : 1.3x
Triton Implementation
Triton provides a high-level abstraction that generates tiled kernels automatically while giving you control over tile sizes and scheduling. The Triton GEMM achieves comparable performance to hand-written CUDA with approximately 30 lines of code:
import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
A_ptr, B_ptr, C_ptr,
M, N, K,
stride_am, stride_ak, # row stride, col stride for A
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, # tile size along M (e.g., 128)
BLOCK_N: tl.constexpr, # tile size along N (e.g., 128)
BLOCK_K: tl.constexpr, # tile size along K (e.g., 32)
):
# Program ID selects which output tile this block computes
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# Row/column indices for this tile
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# Pointers to the first tile of A and B for this block
A_tile_ptr = A_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
B_tile_ptr = B_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
# Accumulator for this output tile
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Tile the K dimension
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Load tile of A with bounds masking
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_K)
a = tl.load(A_tile_ptr, mask=a_mask, other=0.0)
# Load tile of B with bounds masking
b_mask = (offs_k[:, None] < K - k * BLOCK_K) & (offs_n[None, :] < N)
b = tl.load(B_tile_ptr, mask=b_mask, other=0.0)
# Accumulate partial dot products
# tl.dot issues a matrix multiply on the tile - uses tensor cores on supported hardware
acc = tl.dot(a, b, acc)
# Advance pointers to next tile along K
A_tile_ptr += BLOCK_K * stride_ak
B_tile_ptr += BLOCK_K * stride_bk
# Write output tile
c = acc.to(tl.float16) # optional: cast to FP16 before storing
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C_ptr_tile = C_ptr + offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(C_ptr_tile, acc, mask=c_mask)
def matmul_triton(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""Wrapper that launches the tiled Triton GEMM kernel."""
assert A.shape[1] == B.shape[0], "K dimension mismatch"
M, K = A.shape
K, N = B.shape
C = torch.empty((M, N), device=A.device, dtype=torch.float32)
# Grid: one program per output tile
grid = (triton.cdiv(M, 128), triton.cdiv(N, 128))
matmul_kernel[grid](
A, B, C,
M, N, K,
A.stride(0), A.stride(1),
B.stride(0), B.stride(1),
C.stride(0), C.stride(1),
BLOCK_M=128, BLOCK_N=128, BLOCK_K=32,
)
return C
# Benchmark comparison
def benchmark():
M = N = K = 4096
A = torch.randn(M, K, device="cuda", dtype=torch.float32)
B = torch.randn(K, N, device="cuda", dtype=torch.float32)
# Warmup
for _ in range(5):
C = matmul_triton(A, B)
torch.cuda.synchronize()
import time
iters = 50
start = time.perf_counter()
for _ in range(iters):
C = matmul_triton(A, B)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
flops = 2.0 * M * K * N
tflops = flops / (elapsed / iters) / 1e12
print(f"Triton tiled GEMM: {tflops:.2f} TFLOPS")
# Compare to PyTorch (cuBLAS)
start = time.perf_counter()
for _ in range(iters):
C_ref = torch.mm(A, B)
torch.cuda.synchronize()
elapsed_ref = time.perf_counter() - start
tflops_ref = flops / (elapsed_ref / iters) / 1e12
print(f"PyTorch (cuBLAS) : {tflops_ref:.2f} TFLOPS")
print(f"Triton vs cuBLAS : {tflops / tflops_ref:.1%}")
The tl.dot instruction in Triton automatically uses tensor cores on Ampere/Hopper hardware when the tile sizes are compatible (multiples of 16 for FP16, 8 for TF32). This is one reason Triton achieves high performance without manual tensor core intrinsics.
Mermaid - Tiling Algorithm Flow
Tiling Applied to Attention
FlashAttention is tiled GEMM applied to the attention mechanism. The standard attention computation is:
The naive implementation materializes the full attention score matrix (for sequence length ), which is memory. For in FP16, this is MB per head per layer - far too large to fit in shared memory.
FlashAttention observes that the attention computation can be tiled:
- Load a tile of (rows to ) into shared memory
- For each tile of and (columns to ):
- Compute partial scores in shared memory
- Maintain running softmax statistics (max and sum) across tiles using the online softmax trick
- Accumulate partial output
- Write to HBM
The key difference from vanilla tiled GEMM is the online softmax: you cannot normalize a row of the attention score matrix until you have seen all scores in that row, but you can maintain numerically stable running statistics that allow you to correct earlier estimates as later tiles arrive.
import torch
import triton
import triton.language as tl
# Simplified FlashAttention-style tiled attention (forward pass)
# Full implementation is in the flash-attn library
@triton.jit
def flash_attention_kernel(
Q_ptr, K_ptr, V_ptr, O_ptr,
stride_qm, stride_qk,
stride_km, stride_kk,
stride_vm, stride_vk,
stride_om, stride_ok,
N_CTX: tl.constexpr, # sequence length
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr, # rows of Q per block (e.g., 64)
BLOCK_N: tl.constexpr, # rows of K/V per block (e.g., 64)
scale: tl.constexpr,
):
# Each program handles BLOCK_M rows of the output
start_m = tl.program_id(0) * BLOCK_M
offs_m = start_m + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_DIM)
# Load Q tile for this block (stays in shared memory for the full inner loop)
Q_block = tl.load(
Q_ptr + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk,
mask=offs_m[:, None] < N_CTX
)
# Running accumulators for online softmax
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) # running output
# Tile over K and V
for start_n in range(0, N_CTX, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
# Load K tile
K_block = tl.load(
K_ptr + offs_n[:, None] * stride_km + offs_d[None, :] * stride_kk,
mask=offs_n[:, None] < N_CTX
)
# Compute attention scores: S = Q * K^T * scale
S = tl.dot(Q_block, tl.trans(K_block)) * scale # [BLOCK_M, BLOCK_N]
# Online softmax update
m_ij = tl.max(S, axis=1) # max over this K tile
m_new = tl.maximum(m_i, m_ij) # new running max
# Correction factor for accumulated output and sum
alpha = tl.exp(m_i - m_new) # rescale factor for previous accumulations
p = tl.exp(S - m_new[:, None]) # softmax numerator for this tile
# Load V tile
V_block = tl.load(
V_ptr + offs_n[:, None] * stride_vm + offs_d[None, :] * stride_vk,
mask=offs_n[:, None] < N_CTX
)
# Update accumulators
l_i = alpha * l_i + tl.sum(p, axis=1)
acc = alpha[:, None] * acc + tl.dot(p.to(tl.float16), V_block)
m_i = m_new
# Normalize and write output
acc = acc / l_i[:, None]
tl.store(
O_ptr + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok,
acc, mask=offs_m[:, None] < N_CTX
)
The same two-phase pattern: load tile of K/V, compute partial scores and output, move to next tile. The shared memory holds Q throughout (it is reused once per K/V tile) and K/V tiles for the duration of each iteration.
Extending Tiling Beyond Matrix Multiply
Convolution
A 2D convolution is equivalent to many small dot products between a filter and an input patch. The input patch around each output element overlaps heavily with neighboring output elements. Without tiling, each output pixel loads its input patch independently. With tiling:
- Load a tile of the input feature map (including the halo region for boundary patches) into shared memory
- Each thread computes its output pixel entirely from shared memory
- The halo (extra rows and columns beyond the output tile) is loaded once and shared by all threads that need those boundary pixels
The halo size equals the filter radius. For a 3x3 filter with 1 pixel radius: a 32x32 output tile requires a 34x34 input tile in shared memory - 6.25% overhead for the halo.
Reduction
Reduction operations (sum, max, softmax) use a different tiling pattern. Rather than tiles across the K dimension of GEMM, reductions tile across the input array itself:
- Each block loads a tile of the input into shared memory
- Perform a tree reduction within the block using shared memory
- Write the per-block partial result to global memory
- Launch a second kernel to reduce the per-block results
The key insight is the same: using shared memory for the within-block reduction avoids passes over HBM.
Mermaid - Tiling Pattern Taxonomy
Production Engineering Notes
Bank Conflicts in Shared Memory
Shared memory is divided into 32 banks (on Ampere/Hopper), each 4 bytes wide. When multiple threads in a warp access addresses that map to the same bank, the accesses are serialized - this is a bank conflict. For a 32x32 float tile, reading a column (s_A[ty][k] looping over ty for a fixed k) means all 32 threads access the same column index modulo 32, hitting the same bank - a 32-way bank conflict.
The standard fix is a padding column:
__shared__ float s_A[TILE_SIZE][TILE_SIZE + 1]; // +1 avoids bank conflicts on column access
__shared__ float s_B[TILE_SIZE][TILE_SIZE + 1];
The extra column shifts each row's starting address by 4 bytes, so columns of different rows map to different banks. This padding adds bytes per tile, a trivial shared memory overhead.
Tile Size and L2 Cache Interaction
For large matrix multiplications, consecutive thread blocks often access overlapping tiles of B. With a standard row-major launch order, adjacent blocks in the M dimension hit completely different rows of B, getting no L2 reuse. CUTLASS and cuBLAS use a swizzled thread block order that groups nearby blocks in M and N together, maximizing L2 hit rate for B tiles.
For typical A100 GEMM benchmarks, this swizzling contributes 5-15% throughput improvement on top of shared memory tiling.
Warp Tile and Register Tile
Production GEMM kernels add a third level of tiling beyond the thread block tile:
- Thread block tile (shared memory): 128x128, loaded into shared memory
- Warp tile: 32x64 - each warp handles a 32x64 sub-tile of the output
- Thread tile (register): 8x8 - each thread accumulates an 8x8 matrix of partial results in registers
With register tiling, each thread performs 64 FMAs per iteration instead of 1, dramatically increasing arithmetic intensity. This is why CUTLASS and cuBLAS routinely achieve 60-70% of peak FLOP/s while simple tiled GEMM achieves 20-35%.
Tensor Cores
H100 tensor cores operate on 16x16x16 matrix tiles natively (FP16/BF16). To use tensor cores:
- Thread block tile must be a multiple of 16 in all three dimensions (M, N, K)
- Data must be in FP16, BF16, TF32, or INT8
- Use
wmma(Warp Matrix Multiply Accumulate) intrinsics or themmaPTX instruction directly - Triton's
tl.dotautomatically uses tensor cores when tile shapes are compatible
Tensor cores provide approximately 16x throughput multiplier over scalar FP32. A100 peak FP16 tensor core throughput is 312 TFLOPS vs 19.5 TFLOPS scalar FP32. The tiling algorithm is unchanged - tensor cores just make the compute phase of each tile faster.
Common Mistakes
:::danger Missing __syncthreads() Between Load and Compute
Every tiled kernel must call __syncthreads() after loading a tile into shared memory and again after compute before overwriting the tile with the next tile. Missing the first sync causes threads to read partially-loaded data from shared memory - some threads may still be writing their element while others are reading theirs. Missing the second sync causes threads on the next iteration to overwrite elements that other threads in the same warp are still computing from. Both produce silently incorrect results that look like random noise in the output.
:::
:::danger Forgetting the Halo Region in Convolution Tiling When applying tiling to convolution, the input tile must include extra rows and columns (the halo) equal to the filter radius. A 32x32 output tile with a 3x3 filter needs a 34x34 input tile. Forgetting the halo means boundary output pixels access out-of-bounds shared memory addresses, producing garbage values without a detectable error. :::
:::warning Choosing Tile Size Without Measuring Occupancy A 64x64 tile for FP32 consumes 32KB of shared memory per block. On A100 with 96KB per SM, only 3 blocks can be resident simultaneously. At 64 threads per block (2 warps), that is 6 warps per SM - only 9% occupancy. Larger tiles increase arithmetic intensity but may starve the SM of resident warps. Always calculate occupancy alongside tile size. The sweet spot is usually the largest tile that still allows 4+ blocks per SM. :::
:::warning Not Padding Shared Memory Tiles to Avoid Bank Conflicts
Column accesses to a row-major shared memory tile generate 32-way bank conflicts if the tile width is a multiple of 32 (the number of banks). Adding a single padding column (float s_A[32][33] instead of float s_A[32][32]) eliminates column bank conflicts entirely. Forgetting this padding causes a 2-4x slowdown in the compute phase that is easy to miss if you are only measuring total kernel time.
:::
:::warning Applying Double Buffering When Compute Time is Less Than HBM Latency Double buffering is worthwhile only when the compute time for one tile exceeds the HBM load latency for the next tile. For small tiles where compute finishes quickly, the async load still does not complete before it is needed, and you gain nothing from double buffering while doubling your shared memory usage. Check that tile_compute_cycles greater than hbm_latency_cycles before adding this complexity. :::
Interview Questions and Answers
Q1: Walk through the tiled matrix multiply algorithm and explain why it reduces global memory traffic.
In the naive GEMM, computing output reads elements of row from A and elements of column from B, all from global memory. For an matrix, row of A is read times in total (once per output column), producing total reads.
Tiling partitions all three matrices into blocks. Each thread block computes one tile of C by iterating over tiles of A and B. For each iteration, the entire tile of A and tile of B are loaded into shared memory once, then each of the threads uses the shared memory tile for multiply-adds. The same elements produce arithmetic operations from one set of global loads.
Total HBM reads for A: elements, but each is loaded exactly once across all blocks that access it (one per output tile row), giving loads per thread block per output row, versus loads in naive. The reduction ratio is .
Q2: How do you choose tile size T for a given GPU and problem?
Three constraints bound the tile size from above and below:
Upper bound from shared memory: . For FP32 on H100 (100KB), . So T=64 is feasible for FP32.
Upper bound from occupancy: choose T such that to keep enough warps resident. For T=32 FP32: blocks, with 32 warps/block = too many warps total, limited by thread count. For T=64: blocks.
Lower bound from arithmetic intensity: compute intensity . Need intensity above the roofline ridge point ( FLOP/byte for A100 FP32) to be compute-bound. That requires - larger than shared memory allows - which is why production kernels add register-level tiling on top of shared memory tiling.
Practical choice: T=32 for FP32 with good occupancy, T=64 for FP16 (same shared memory, double the intensity).
Q3: What is double buffering and what latency does it hide?
Double buffering allocates two shared memory buffers and alternates between them. While the SM computes from tile K in buffer 0, it asynchronously prefetches tile K+1 into buffer 1 using cp.async. When computation finishes, buffers swap and the SM immediately starts computing from buffer 1 while prefetching tile K+2 into buffer 0.
This hides HBM load latency (200-600 cycles on A100/H100). In single-buffered tiling, the sequence is: sync, load tile (wait 400 cycles), sync, compute, repeat - all serialized. In double-buffered tiling, the load and compute execute in parallel. If compute time exceeds load time, the SM is never stalled waiting for data.
The cp.async instruction is key: it copies from global to shared without routing through registers and without stalling the issuing warp. The copy proceeds in the background on a dedicated load unit while the SM's arithmetic units process the current tile.
Q4: How would you apply tiling to the attention mechanism? What makes it different from GEMM tiling?
Attention requires intermediate memory (the full attention score matrix) in the naive implementation. Tiling attention means we never materialize the full score matrix - instead we compute attention for small tiles of the sequence and accumulate the output in shared memory.
The key difference from GEMM is the softmax normalization. In GEMM, each output tile is independent - you accumulate partial products and write. In attention, softmax over row requires the maximum over all scores in row before you can normalize. FlashAttention (Dao et al., 2022) solves this with online softmax: maintain running max and running sum as you process each K tile. When you encounter a larger score, apply a correction factor to all previously computed partial outputs to account for the updated normalization. This correction is a simple scalar multiply per accumulated row.
The shared memory layout: Q tile is loaded once and held throughout the K/V loop (it is reused once per K tile). K and V tiles are loaded once per iteration and discarded. The output accumulator lives in shared memory throughout and is only written to HBM at the end.
Memory complexity drops from to (just Q, K, V, and one output tile at a time), enabling attention at sequence lengths of 100K+ that would OOM with standard attention.
Q5: A tiled GEMM kernel achieves 22 TFLOPS on A100. The peak is 77.6 TFLOPS FP32. What are the likely remaining bottlenecks?
At 28% of peak, three factors typically explain the gap:
First, bank conflicts. If shared memory tiles are not padded, column accesses generate 32-way bank conflicts in the compute phase. This alone can cause a 2-3x slowdown in the inner loop.
Second, insufficient arithmetic intensity from shared memory tiling alone. A 32x32 tile provides arithmetic intensity of 8 FLOP/byte. The A100 roofline ridge point is 38.8 FLOP/byte. The kernel is still memory-bound. Adding register-level tiling (each thread accumulates an 8x8 register matrix) pushes intensity above 32 FLOP/byte and moves the kernel toward compute-bound.
Third, no tensor core utilization. The scalar FP32 instructions in a standard tiled GEMM do not use tensor cores. Tensor cores provide 4x throughput on TF32 and 16x on FP16 relative to scalar FP32. Switching to FP16 with tensor core-compatible tile sizes would push the benchmark from 22 to 80+ TFLOPS FP16.
The path from 22 to 60+ TFLOPS: (1) add padding to eliminate bank conflicts (get to ~30 TFLOPS), (2) add register tiling for higher intensity (get to ~45 TFLOPS), (3) switch to FP16 with tensor cores (get to ~70 TFLOPS). This is approximately what CUTLASS achieves.
Q6: Explain the online softmax trick used in FlashAttention. Why can't you just tile attention the same way you tile GEMM?
GEMM tiling works because each output element accumulates partial sums across K tiles, and those partial sums are simply added. No element of C depends on global statistics across all K tiles.
Attention is different because softmax is a normalization operation: . Computing the denominator requires seeing all values first. In a naive tiled implementation, you cannot write any output until you have processed all K tiles (to get the full row max and sum). This would require storing all K tiles of scores simultaneously - same memory as naive.
Online softmax (from Milakov and Gimelshein, 2018) maintains two running statistics: current row max and current sum . When tile reveals a new maximum , all previously accumulated exponentials were computed with the wrong normalization. The correction factor is - multiply the running sum and output accumulator by this factor, then add the new tile's contributions.
The invariant maintained: after processing tiles 0 through t, the accumulator equals with only the partial normalization through tile t. When all tiles are processed, the accumulator is exactly the correct attention output. The running correction only touches two scalars per row per tile - additional computation per tile.
Summary
Tiling is the technique of decomposing a computation so that a small block of data is loaded from HBM once, placed in shared memory, and reused for many arithmetic operations. The arithmetic intensity improvement is exactly equal to the tile size : fewer HBM reads for tiles.
The implementation structure is universal: (1) allocate shared memory tiles, (2) loop over the reduction dimension in steps of T, (3) cooperatively load one tile per iteration with a __syncthreads() afterward, (4) compute from shared memory, (5) __syncthreads() before the next load. Bank conflict padding, double buffering with cp.async, and register-level tiling are progressive optimizations on top of this foundation.
The pattern extends directly to convolution (halo tiles), attention (online softmax for normalization), and any operation with predictable data reuse. FlashAttention's breakthrough was recognizing that attention is a tiling problem in disguise - the online softmax makes the normalization compatible with the tile-and-accumulate structure that shared memory tiling requires.
