Tensor Cores and Mixed Precision
Reading time: ~40 min · Interview relevance: Very High · Target roles: ML Engineer, Systems Engineer, DL Researcher
The Production Scenario
It is 2:17 AM and you are watching a training run that has been going for six days. You are training a 13-billion parameter language model on a cluster of 64 A100s. The loss curve looks healthy. Your estimated finish time is 22 more days. Your compute budget expires in 19.
You message your team lead. She replies immediately: "Did you verify Tensor Cores are actually being used?" You pull up Nsight Compute and look at sm__pipe_tensor_cycles_active. It reads 3%. Your stomach drops. The model is running almost entirely on FP32 CUDA cores.
The culprit turns out to be a single linear layer in your custom attention implementation where the sequence length dimension was hardcoded to 1023 instead of 1024. One wrong integer. Tensor Cores require matrix dimensions that are multiples of 16. The GPU silently fell back to scalar FP32 math, cutting your throughput to roughly a twelfth of what it should be. Fixing that one number drops your estimated completion time from 28 days to 4.
That is the lesson of Tensor Cores: they are extraordinarily powerful, completely silent when they are not active, and the difference between using them and not using them is often a shape you wrote in a config file three weeks ago.
This lesson explains exactly how Tensor Cores work, why they exist, what breaks them, and how to verify you are actually using them in production. By the end you will understand why a single GPU instruction from a Tensor Core replaces thousands of scalar FP32 operations, and why the precision format you choose - FP32, BF16, FP16, TF32, or FP8 - can mean the difference between a training run that finishes and one that never does.
Why This Exists
Before Tensor Cores shipped in the Volta V100 in 2017, every GPU computation worked the same way: one floating point multiply-accumulate (FMA) per CUDA core per clock cycle. Multiplying two $N \times N$ matrices requires $N^3$ FMA operations, or $2N^3$ floating point operations. At $N = 4096$, that is roughly 137 billion operations per matrix multiply. With 5120 CUDA cores running at roughly 1.5 GHz, a V100 in pure FP32 scalar mode could theoretically execute 15.7 TFLOPS. In practice, after memory latency, synchronization overhead, and pipeline stalls, real throughput on large GEMMs was closer to 10-12 TFLOPS.
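A quick back-of-the-envelope check of those numbers (the core count and clock are published V100 specs; this is arithmetic, not a benchmark):

```python
# V100 peak FP32: CUDA cores x 2 FLOPs per FMA x clock rate
cuda_cores = 5120
boost_clock_hz = 1.53e9          # V100 boost clock, ~1.5 GHz
peak_fp32_tflops = cuda_cores * 2 * boost_clock_hz / 1e12
print(f"V100 peak FP32: {peak_fp32_tflops:.1f} TFLOPS")      # ~15.7

# FLOPs for one N x N x N matrix multiply: 2 * N^3
N = 4096
print(f"{N}x{N} GEMM: {2 * N**3 / 1e9:.0f} GFLOPs")          # ~137 billion operations
```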
Neural network training is almost entirely composed of matrix multiplications. The forward pass through a transformer block is a sequence of GEMMs: QKV projections, attention scores, output projections, feed-forward expansions. The backward pass doubles that count. A training step for a 7B parameter model involves thousands of large matrix multiplications. At 12 TFLOPS with 32-bit arithmetic, training GPT-3 (175B parameters) on V100s would have taken years at a cost no company could justify.
The fundamental insight that Tensor Cores exploit is that matrix multiplication has a regular, predictable structure that can be hardwired into silicon. Rather than decomposing a matrix multiply into individual FMAs and issuing them one at a time to scalar cores, a Tensor Core takes a fixed-size matrix fragment, performs the entire partial dot product in one hardware instruction, and accumulates the result in a higher-precision register. The silicon area that would have gone to control logic for thousands of independent scalar units instead becomes a dense array of multiply-accumulate circuits that all fire in lockstep.
The result is astonishing: a single Tensor Core instruction on an A100 in BF16 performs the equivalent of 4096 scalar FMA operations in one clock cycle. The A100 ships with 432 Tensor Cores running at 1.41 GHz, delivering 312 TFLOPS in BF16 - roughly 16x more than its own FP32 CUDA core throughput of 19.5 TFLOPS. The H100 pushes this further: 989 TFLOPS in dense BF16, and roughly double that with 2:4 structured sparsity. These numbers are not just marketing - well-tuned large GEMMs come close to them when everything is properly aligned.
Historical Context
The story of Tensor Cores begins not with GPUs but with TPUs. Google's first Tensor Processing Unit, announced in 2016, was built around systolic arrays - a hardware architecture specifically designed for matrix multiplication that runs multiply-accumulate units in a wave-like cascade. The TPU v1 achieved 92 TFLOPS INT8 at 40W, an efficiency that shocked the industry. Intel, NVIDIA, and AMD all had to respond.
NVIDIA's answer was the Volta architecture, released in May 2017 as the Tesla V100. The V100 introduced Tensor Cores for the first time - 640 of them, organized as 8 per Streaming Multiprocessor across 80 SMs. Each Tensor Core performed a matrix multiply-accumulate in one clock cycle using FP16 inputs and FP32 accumulation. The stated throughput was 125 TFLOPS FP16 - versus 15.7 TFLOPS FP32. The ratio, roughly 8x, made clear that the era of scalar floating point as the primary training mode was over.
The Turing architecture (2018, RTX 20xx series) added INT8 and INT4 Tensor Core support, targeting inference workloads where integer quantization was already practical. Ampere (2020, A100) enlarged the per-clock Tensor Core operation - each A100 Tensor Core executes roughly 4x as many FMA operations per cycle as Volta's - in what NVIDIA calls the third-generation Tensor Core, and introduced TF32: a format that preserves FP32's 8-bit exponent range but rounds the mantissa to 10 bits (FP16's mantissa width), retaining FP32 compatibility for legacy code.
Hopper (2022, H100) introduced the fourth-generation Tensor Core. The key addition was FP8 support in two variants: E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits). FP8 doubles Tensor Core throughput versus FP16/BF16 on the same hardware, enabling nearly 2 PFLOPS of dense FP8 compute on a single H100 (close to 4 PFLOPS with structured sparsity). Hopper also introduced the Transformer Engine, which dynamically selects precision per-layer and per-step, automatically tracking the scaling factors required for FP8 stability.
The "aha moment" in Tensor Core history is the realization that neural networks do not need full FP32 precision for most of their computation. Weights that encode semantic relationships between tokens do not need 23 bits of mantissa precision. The gradients flowing back through a transformer benefit far more from a wider dynamic range than from higher mantissa resolution. This insight - that precision can be traded strategically without sacrificing training stability - is what makes the entire Tensor Core ecosystem viable.
Core Concepts
The GEMM Problem: Why Matrix Multiply Dominates
A transformer's compute budget is almost entirely matrix multiplications. For a model with hidden dimension $d$ and sequence length $s$, the QKV projection alone computes:

$$[Q, K, V] = X W_{QKV}, \qquad X \in \mathbb{R}^{s \times d}, \; W_{QKV} \in \mathbb{R}^{d \times 3d}$$

The FLOPs for this operation are $2 \cdot s \cdot d \cdot 3d = 6 s d^2$ per sequence. For GPT-3 style dimensions ($d = 12288$, $s = 2048$), a single QKV projection costs roughly 1.9 TFLOPs (about 1,855 GFLOPs). A transformer block has 4-6 such projections. A 96-layer model with a batch size of 32 runs thousands of these per training step.
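A quick arithmetic check of that estimate, using the GPT-3-style dimensions assumed above:

```python
# FLOPs for a QKV projection: X (s x d) multiplied by W_QKV (d x 3d) = 2 * s * d * 3d
d_model = 12288   # GPT-3-style hidden dimension
seq_len = 2048

qkv_flops = 2 * seq_len * d_model * (3 * d_model)
print(f"QKV projection: {qkv_flops / 1e9:.0f} GFLOPs per sequence")   # ~1855 GFLOPs
```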
The matrix multiply operation is defined as:

$$C_{ij} = \sum_{k=1}^{K} A_{ik} B_{kj}, \qquad \text{or in BLAS form} \quad C \leftarrow \alpha A B + \beta C$$

This is the GEMM (General Matrix Multiply) that BLAS libraries have optimized for decades, and it is exactly the operation Tensor Cores are hardwired to perform.
How a Tensor Core Works: The Warp-Level MMA
Understanding Tensor Cores requires understanding the GPU's thread hierarchy. CUDA threads are grouped into warps of 32 threads. A warp is the fundamental unit of execution - all 32 threads in a warp execute the same instruction simultaneously on different data (SIMT model).
The key insight: a Tensor Core operation is a warp-level operation. All 32 threads in a warp collectively execute a single matrix multiply-accumulate instruction (MMA). The matrix fragments are distributed across the 32 thread registers - no single thread holds a complete matrix fragment. The hardware fuses these 32 threads' register files into a single unified compute path.
The operation executed by one Tensor Core instruction:

$$D = A \times B + C$$

where:
- $A$ is a $16 \times 16$ matrix fragment (FP16 or BF16)
- $B$ is a $16 \times 16$ matrix fragment (FP16 or BF16)
- $C$ is a $16 \times 16$ accumulator matrix (FP32)
- $D$ is the $16 \times 16$ result matrix (FP32)

This $16 \times 16 \times 16$ MMA involves $16 \times 16 \times 16 = 4096$ individual multiply-add operations, or 8192 floating point operations. A single warp-level Tensor Core instruction executes all 8192 operations in one cycle.
One Tensor Core MMA cycle:
Input A: 16×16 FP16 = 512 bytes from warp registers
Input B: 16×16 FP16 = 512 bytes from warp registers
Accum C: 16×16 FP32 = 1024 bytes from warp registers
Output D: 16×16 FP32 = 1024 bytes to warp registers
Operations: 4096 multiply-adds (8192 FLOPs)
Cycles: 1
The WMMA (Warp Matrix Multiply Accumulate) API in CUDA C++ exposes this directly. PyTorch's cuBLAS backend calls these instructions automatically when the inputs satisfy the precision and shape requirements.
Precision Formats: The Trade-offs
Every floating point format makes a trade-off between dynamic range (controlled by exponent bits) and precision (controlled by mantissa bits). Here is the full comparison:
| Format | Sign | Exponent | Mantissa | Dynamic Range | Precision |
|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | ±3.4 × 10³⁸ | ~7 decimal digits |
| FP16 | 1 | 5 | 10 | ±65,504 | ~3 decimal digits |
| BF16 | 1 | 8 | 7 | ±3.4 × 10³⁸ | ~2 decimal digits |
| TF32 | 1 | 8 | 10 | ±3.4 × 10³⁸ | ~3 decimal digits |
| FP8 E4M3 | 1 | 4 | 3 | ±448 | ~1 decimal digit |
| FP8 E5M2 | 1 | 5 | 2 | ±57,344 | ~0.5 decimal digit |
The critical observation about BF16: it has the same number of exponent bits as FP32. This means BF16 and FP32 represent the same range of numbers - only the precision differs. Gradients that would overflow FP16 (max 65504) fit comfortably in BF16. This is why FP16 requires loss scaling and BF16 does not.
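You can read these limits straight out of PyTorch (`torch.finfo` is standard; no GPU required):

```python
import torch

# Compare dynamic range (max, smallest normal) and precision (eps) across formats.
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):16s} max={info.max:.3e}  min_normal={info.tiny:.3e}  eps={info.eps:.3e}")

# Expected pattern: float16 max is 65504 with eps ~9.8e-4; bfloat16 max matches
# float32 (~3.4e38) but eps is ~7.8e-3 because of its 7-bit mantissa.
```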
Mixed Precision Training: The Full Picture
Mixed precision training is not simply "use FP16 everywhere." The correct mental model is: store master weights in FP32, compute in lower precision, accumulate gradients in FP32.
The reason for this split is subtle. Neural network weights are typically on the order of $10^{-2}$ to $10^{0}$, and small gradient updates need to be representable relative to the current weight magnitude. If a weight has value $1.0$ and the gradient step is $10^{-4}$, the relative magnitude of the update is $10^{-4}$. FP16 with 10 mantissa bits can only represent relative differences down to about $2^{-10} \approx 10^{-3}$. Updates below that threshold round away entirely; the errors accumulate and weights stop updating - the model stalls.
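A two-line demonstration of the stall (plain PyTorch, CPU is fine):

```python
import torch

w = torch.tensor(1.0, dtype=torch.float16)
update = torch.tensor(1e-4, dtype=torch.float16)

# The update is below FP16's relative resolution near 1.0 (eps ~ 9.8e-4),
# so the addition rounds back to the original value: the weight never moves.
print((w + update) == w)        # tensor(True)

# The same update applied to an FP32 master copy is preserved.
w32 = torch.tensor(1.0, dtype=torch.float32)
print((w32 + 1e-4) == w32)      # tensor(False)
```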
FP32 accumulation in the master weight copy avoids this. The update rule is:

$$w_{t+1} = w_t - \eta \, g_t \qquad \text{(applied in FP32 to the master copy)}$$
In mixed precision:
- Cast master FP32 weights to FP16/BF16 for the forward pass
- Compute loss in FP16/BF16 (Tensor Core accelerated)
- Compute gradients in FP16/BF16 (Tensor Core accelerated)
- Cast gradients to FP32 and update master weights in FP32
The compute-expensive steps (forward, backward) run at Tensor Core speed. The weight update, which is memory-bound anyway, runs in FP32 with no significant throughput loss.
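A minimal sketch of that split written out by hand, without `autocast`, just to make the mechanics visible (assumes a plain SGD update; in practice `torch.cuda.amp` and your optimizer handle all of this for you):

```python
import torch
import torch.nn as nn

model_fp32 = nn.Linear(4096, 4096, bias=False).cuda()   # FP32 master weights
lr = 1e-3

def manual_mixed_precision_step(x: torch.Tensor) -> None:
    # 1. Cast a BF16 copy of the weights for the compute-heavy forward/backward.
    w_bf16 = model_fp32.weight.detach().to(torch.bfloat16).requires_grad_(True)

    # 2. Forward + backward run in BF16 (Tensor Core eligible).
    loss = (x.to(torch.bfloat16) @ w_bf16.t()).float().pow(2).mean()
    loss.backward()

    # 3. Update the FP32 master weights with the gradient cast back to FP32.
    with torch.no_grad():
        model_fp32.weight -= lr * w_bf16.grad.float()

x = torch.randn(32, 4096, device="cuda")
manual_mixed_precision_step(x)
```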
Loss Scaling for FP16
FP16's limited dynamic range (max 65504) creates a specific failure mode during backward passes. Small gradient values that should be on the order of $10^{-5}$ to $10^{-7}$ fall into FP16's subnormal range or underflow to zero entirely (FP16's smallest normal value is about $6 \times 10^{-5}$). When gradients vanish due to underflow - not because of actual gradient flow - the model appears to converge but parameters are not actually being updated.
Loss scaling solves this by multiplying the loss by a large scalar $S$ (typically 2048 to 65536) before the backward pass. By the chain rule, every gradient is multiplied by $S$. Gradients that would have been $10^{-7}$ become roughly $10^{-4}$ to $10^{-2}$, safely within FP16 range. After the backward pass, gradients are divided by $S$ before the weight update.
The PyTorch GradScaler implements dynamic loss scaling: if no overflow is detected for a configured number of consecutive steps (growth_interval, 2000 by default), the scale increases; if an overflow (nan/inf) is detected, the scale decreases and the step is skipped.
BF16 does not need loss scaling because its 8-bit exponent matches FP32's exponent range. Gradients that fit in FP32 fit in BF16. The precision loss in the mantissa affects accuracy but not overflow behavior.
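The underflow behavior is easy to see directly (plain PyTorch, CPU is fine):

```python
import torch

g = 1e-8  # a small gradient value, comfortably within FP32 and BF16 range

print(torch.tensor(g, dtype=torch.float16))    # tensor(0., dtype=torch.float16) - underflow
print(torch.tensor(g, dtype=torch.bfloat16))   # ~1e-08 - survives thanks to the 8-bit exponent

# Loss scaling rescues FP16: scale up before casting, divide back down in FP32 afterwards.
scale = 65536.0
scaled_fp16 = torch.tensor(g * scale, dtype=torch.float16)   # ~6.55e-4, representable in FP16
print(scaled_fp16.float() / scale)                            # ~1e-08 recovered after unscaling
```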
Code Examples
Basic Mixed Precision with torch.cuda.amp
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import time
# Simple transformer-like model for demonstration
class FeedForward(nn.Module):
def __init__(self, d_model: int = 4096, d_ff: int = 16384):
super().__init__()
self.w1 = nn.Linear(d_model, d_ff, bias=False)
self.w2 = nn.Linear(d_ff, d_model, bias=False)
self.act = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(self.act(self.w1(x)))
def benchmark_precision(
model: nn.Module,
x: torch.Tensor,
use_amp: bool = True,
dtype: torch.dtype = torch.bfloat16,
n_steps: int = 50,
) -> float:
"""Returns average step time in milliseconds."""
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler only needed for FP16, not BF16
scaler = GradScaler() if (use_amp and dtype == torch.float16) else None
# Warmup
for _ in range(5):
with autocast(dtype=dtype, enabled=use_amp):
loss = model(x).mean()
if scaler:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_steps):
with autocast(dtype=dtype, enabled=use_amp):
loss = model(x).mean()
if scaler:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - start) * 1000 / n_steps
return elapsed_ms
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FeedForward(d_model=4096, d_ff=16384).to(device)
# Batch=32, seq_len=2048, d_model=4096 - all multiples of 16 for Tensor Core alignment
x = torch.randn(32, 2048, 4096, device=device)
fp32_time = benchmark_precision(model, x, use_amp=False)
bf16_time = benchmark_precision(model, x, use_amp=True, dtype=torch.bfloat16)
fp16_time = benchmark_precision(model, x, use_amp=True, dtype=torch.float16)
print(f"FP32 (no AMP): {fp32_time:.2f} ms/step")
print(f"BF16 AMP: {bf16_time:.2f} ms/step ({fp32_time / bf16_time:.1f}x speedup)")
print(f"FP16 AMP: {fp16_time:.2f} ms/step ({fp32_time / fp16_time:.1f}x speedup)")
Verifying Tensor Core Usage with CUDA Events and Shape Inspection
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
def check_tensor_core_alignment(tensor: torch.Tensor, name: str = "tensor") -> None:
"""
Tensor Cores require dimensions to be multiples of 8 (minimum) or 16 (optimal).
FP16/BF16: multiples of 8 minimum, 16 for best performance.
FP8: multiples of 16.
"""
shape = tensor.shape
violations = []
for dim_idx, size in enumerate(shape):
if size % 16 != 0:
if size % 8 == 0:
violations.append(
f" dim {dim_idx}={size}: multiple of 8 but not 16 "
f"(suboptimal, may miss peak throughput)"
)
else:
violations.append(
f" dim {dim_idx}={size}: NOT a multiple of 8 "
f"(Tensor Cores WILL NOT be used)"
)
if violations:
print(f"[WARNING] {name} shape {shape} has alignment issues:")
for v in violations:
print(v)
else:
print(f"[OK] {name} shape {shape} - all dims multiples of 16")
def timed_matmul(
a: torch.Tensor,
b: torch.Tensor,
n_repeats: int = 100,
) -> float:
"""Returns throughput in TFLOPS."""
# Warmup
for _ in range(10):
_ = torch.matmul(a, b)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(n_repeats):
_ = torch.matmul(a, b)
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
avg_ms = elapsed_ms / n_repeats
M, K = a.shape
_, N = b.shape
flops = 2 * M * K * N # multiply-add = 2 ops
tflops = flops / (avg_ms * 1e-3) / 1e12
return tflops
if __name__ == "__main__":
device = torch.device("cuda")
print("=== Shape Alignment Checks ===")
check_tensor_core_alignment(torch.empty(1024, 1024), "good_square")
check_tensor_core_alignment(torch.empty(1023, 1024), "bad_M")
check_tensor_core_alignment(torch.empty(2048, 512), "good_rectangle")
check_tensor_core_alignment(torch.empty(2048, 511), "bad_N")
print("\n=== Throughput: Aligned vs Misaligned Shapes ===")
# Aligned - should use Tensor Cores
a_good = torch.randn(4096, 4096, device=device, dtype=torch.float16)
b_good = torch.randn(4096, 4096, device=device, dtype=torch.float16)
tflops_good = timed_matmul(a_good, b_good)
# Misaligned by 1 - will NOT use Tensor Cores
a_bad = torch.randn(4095, 4096, device=device, dtype=torch.float16)
b_bad = torch.randn(4096, 4095, device=device, dtype=torch.float16)
tflops_bad = timed_matmul(a_bad, b_bad)
print(f"Aligned (4096x4096 FP16): {tflops_good:.1f} TFLOPS")
print(f"Misaligned (4095x4096 FP16): {tflops_bad:.1f} TFLOPS")
print(f"Tensor Core speedup: {tflops_good / tflops_bad:.1f}x")
GradScaler in Production: FP16 with Dynamic Loss Scaling
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
import logging
logger = logging.getLogger(__name__)
class MixedPrecisionTrainer:
"""
Production-ready wrapper for FP16 mixed precision training.
For BF16, set dtype=torch.bfloat16 and scaler will be disabled automatically.
"""
def __init__(
self,
model: nn.Module,
optimizer: torch.optim.Optimizer,
dtype: torch.dtype = torch.bfloat16,
initial_scale: float = 2.0 ** 16,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 2000,
):
self.model = model
self.optimizer = optimizer
self.dtype = dtype
# GradScaler is only needed for FP16
self.use_scaler = dtype == torch.float16
self.scaler = GradScaler(
init_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
enabled=self.use_scaler,
) if self.use_scaler else None
self._step_count = 0
self._overflow_count = 0
def training_step(self, batch_inputs: torch.Tensor, batch_labels: torch.Tensor):
self.optimizer.zero_grad(set_to_none=True)
with autocast(dtype=self.dtype):
logits = self.model(batch_inputs)
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
batch_labels.view(-1),
ignore_index=-100,
)
if self.scaler is not None:
# FP16 path: scale loss, backward, unscale, step
self.scaler.scale(loss).backward()
# Unscale before gradient clipping - CRITICAL ordering
self.scaler.unscale_(self.optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=1.0
)
# scaler.step() skips the step if gradients contain inf/nan
self.scaler.step(self.optimizer)
self.scaler.update()
# Track overflow events to detect instability
if self.scaler.get_scale() < 2.0 ** 8:
self._overflow_count += 1
logger.warning(
f"Step {self._step_count}: loss scale dropped to "
f"{self.scaler.get_scale():.0f} - FP16 overflow detected"
)
else:
# BF16 path: no scaling needed
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=1.0
)
self.optimizer.step()
self._step_count += 1
return loss.item(), grad_norm.item()
def get_scale(self) -> float:
"""Returns current loss scale (1.0 for BF16 where scaling is disabled)."""
if self.scaler is not None:
return self.scaler.get_scale()
return 1.0
def overflow_rate(self) -> float:
"""Fraction of steps with FP16 overflow. Should stay below 0.01 in healthy runs."""
if self._step_count == 0:
return 0.0
return self._overflow_count / self._step_count
The WMMA API: How cuBLAS Calls Tensor Cores
While you will almost never write WMMA code directly in Python, understanding what cuBLAS is doing under the hood clarifies why shapes and dtypes matter so much.
# This is a conceptual illustration of what happens at the CUDA C++ level
# when torch.matmul is called with FP16 inputs on a GPU with Tensor Cores.
# You do not write this code - cuBLAS does.
"""
// CUDA C++ WMMA API (illustrative - cuBLAS abstracts this)
#include <mma.h>
using namespace nvcuda::wmma;
__global__ void tensor_core_gemm_kernel(
const half* A, const half* B, float* C,
int M, int N, int K
) {
// Each warp of 32 threads collectively holds matrix fragments
fragment<matrix_a, 16, 16, 16, half, row_major> a_frag;
fragment<matrix_b, 16, 16, 16, half, col_major> b_frag;
fragment<accumulator, 16, 16, 16, float> c_frag;
// Initialize accumulator to zero
fill_fragment(c_frag, 0.0f);
// Load input fragments from shared memory into warp registers
// All 32 threads cooperate - each thread holds part of the fragment
load_matrix_sync(a_frag, A + warp_row * K, K);
load_matrix_sync(b_frag, B + warp_col, N);
// THE TENSOR CORE INSTRUCTION
// This single call executes 8192 multiply-adds in one hardware cycle
mma_sync(c_frag, a_frag, b_frag, c_frag);
// Store result back to global memory
store_matrix_sync(C + warp_row * N + warp_col, c_frag, N, mem_row_major);
}
// cuBLAS selects the optimal kernel (tiling, warp count, pipeline depth)
// based on M, N, K dimensions and available GPU architecture.
"""
# In Python, you just do this and cuBLAS handles everything:
import torch
M, K, N = 4096, 4096, 4096 # all multiples of 16 - Tensor Cores WILL be used
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
# cuBLAS detects FP16 inputs + aligned shapes + Volta+ architecture
# and selects the Tensor Core WMMA kernel automatically
c = torch.matmul(a, b) # internally calls cublasSgemmEx or cublasHgemm
# c is FP16; cuBLAS still accumulates the partial products in FP32 internally
# even though the inputs and the output tensor are FP16
Checking Tensor Core Activity Programmatically
import torch
import subprocess
import re
from contextlib import contextmanager
@contextmanager
def nvml_check():
"""Context manager that prints GPU utilization info after a block."""
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
info_before = pynvml.nvmlDeviceGetUtilizationRates(handle)
yield
torch.cuda.synchronize()
info_after = pynvml.nvmlDeviceGetUtilizationRates(handle)
print(f"GPU utilization: {info_after.gpu}%")
pynvml.nvmlShutdown()
def print_matmul_config(a: torch.Tensor, b: torch.Tensor) -> None:
"""
Diagnose whether a given matmul will use Tensor Cores.
Tensor Core requirements:
1. dtype must be FP16, BF16, TF32, or FP8 (not FP32)
2. M, N, K all multiples of 8 (FP16/BF16) or 16 (FP8)
3. Input pointers must be 16-byte aligned (guaranteed by PyTorch allocator)
4. GPU must be Volta (cc 7.0) or newer
"""
M, K = a.shape[-2], a.shape[-1]
_, N = b.shape[-2], b.shape[-1]
dtype_ok = a.dtype in (torch.float16, torch.bfloat16)
align_ok_8 = (M % 8 == 0) and (K % 8 == 0) and (N % 8 == 0)
align_ok_16 = (M % 16 == 0) and (K % 16 == 0) and (N % 16 == 0)
cc = torch.cuda.get_device_capability()
arch_ok = cc[0] >= 7 # Volta = 7.0
print(f"Matrix shapes: A={a.shape}, B={b.shape}")
print(f"dtype: {a.dtype} ({'OK' if dtype_ok else 'FAIL - must be FP16/BF16'})")
print(f"Shape mod 8: {'OK' if align_ok_8 else 'FAIL'} "
f"({'optimal' if align_ok_16 else 'suboptimal (use multiples of 16)'})")
print(f"GPU arch: sm_{cc[0]}{cc[1]} ({'OK' if arch_ok else 'FAIL - need Volta+'})")
print(f"Tensor Core usage: {'LIKELY' if (dtype_ok and align_ok_8 and arch_ok) else 'UNLIKELY'}")
if __name__ == "__main__":
print("--- Good case ---")
a = torch.randn(2048, 4096, dtype=torch.float16, device="cuda")
b = torch.randn(4096, 2048, dtype=torch.float16, device="cuda")
print_matmul_config(a, b)
print("\n--- Bad dtype ---")
a2 = torch.randn(2048, 4096, dtype=torch.float32, device="cuda")
b2 = torch.randn(4096, 2048, dtype=torch.float32, device="cuda")
print_matmul_config(a2, b2)
print("\n--- Bad shape ---")
a3 = torch.randn(2047, 4096, dtype=torch.float16, device="cuda")
b3 = torch.randn(4096, 2048, dtype=torch.float16, device="cuda")
print_matmul_config(a3, b3)
FP8 Training with Transformer Engine (H100)
# transformer_engine is available on H100 (Hopper) systems
# pip install transformer-engine
try:
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
TE_AVAILABLE = True
except ImportError:
TE_AVAILABLE = False
print("transformer_engine not available - FP8 requires H100")
import torch
import torch.nn as nn
def build_fp8_model(d_model: int = 4096, d_ff: int = 16384) -> nn.Module:
"""
Build a feedforward block using Transformer Engine's FP8-capable layers.
TE's Linear layer transparently selects FP8 E4M3/E5M2 when fp8_autocast is active.
"""
if not TE_AVAILABLE:
# Fallback to standard PyTorch BF16
return nn.Sequential(
nn.Linear(d_model, d_ff, bias=False),
nn.GELU(),
nn.Linear(d_ff, d_model, bias=False),
)
return nn.Sequential(
te.Linear(d_model, d_ff, bias=False),
nn.GELU(),
te.Linear(d_ff, d_model, bias=False),
)
def train_fp8_step(
model: nn.Module,
x: torch.Tensor,
optimizer: torch.optim.Optimizer,
) -> float:
"""Single training step using FP8 compute on H100."""
if not TE_AVAILABLE:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
loss = model(x).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
return loss.item()
# FP8 recipe: E4M3 for forward, E5M2 for backward (more dynamic range for gradients)
fp8_recipe = DelayedScaling(
fp8_format=Format.HYBRID, # E4M3 forward, E5M2 backward
amax_history_len=16, # how many steps to track for auto-scaling
amax_compute_algo="max", # use max of recent amax history
)
    # Forward runs under fp8_autocast; loss and backward are taken outside the context,
    # matching Transformer Engine's documented usage pattern.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(x)
    loss = out.mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()
if __name__ == "__main__" and TE_AVAILABLE:
device = torch.device("cuda")
model = build_fp8_model().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# All dims must be multiples of 16 for FP8 Tensor Cores
x = torch.randn(32, 2048, 4096, device=device)
for step in range(10):
loss = train_fp8_step(model, x, optimizer)
if step % 2 == 0:
print(f"Step {step}: loss={loss:.4f}")
Mermaid: Tensor Core Decision Tree
Performance Numbers: Real Hardware
These are the peak dense Tensor Core throughput figures for large GEMM operations on real datacenter GPUs; well-tuned large GEMMs come within a healthy fraction of them:
| GPU | FP32 CUDA Cores | TF32 Tensor Cores | BF16 Tensor Cores | FP16 Tensor Cores | FP8 Tensor Cores |
|---|---|---|---|---|---|
| V100 | 14 TFLOPS | - | - | 112 TFLOPS | - |
| A100 40GB | 19.5 TFLOPS | 156 TFLOPS | 312 TFLOPS | 312 TFLOPS | - |
| A100 80GB | 19.5 TFLOPS | 156 TFLOPS | 312 TFLOPS | 312 TFLOPS | - |
| H100 SXM | 67 TFLOPS | 494 TFLOPS | 989 TFLOPS | 989 TFLOPS | 1979 TFLOPS |
| H100 PCIe | 51 TFLOPS | 378 TFLOPS | 756 TFLOPS | 756 TFLOPS | 1513 TFLOPS |
Numbers with sparsity enabled (2:4 structured sparsity) are 2x the dense values for Tensor Core paths.
The A100 ratio of BF16 Tensor Cores to FP32 CUDA Cores is 16x. This is the gap you leave on the table when a shape alignment issue forces fallback to scalar compute.
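A quick way to turn a measured GEMM time into a fraction of these peaks (a sketch; the peak value you divide by is an assumption about your exact SKU and clocks, taken from the table above):

```python
import time
import torch

PEAK_TFLOPS = {"A100-BF16": 312.0, "H100-SXM-BF16": 989.0}   # dense figures from the table

def gemm_fraction_of_peak(m: int, k: int, n: int, peak_tflops: float, n_repeats: int = 50) -> float:
    """Achieved fraction of peak for an (m x k) @ (k x n) BF16 matmul."""
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    for _ in range(10):                      # warmup
        torch.matmul(a, b)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_repeats):
        torch.matmul(a, b)
    torch.cuda.synchronize()
    avg_s = (time.perf_counter() - start) / n_repeats
    achieved_tflops = 2 * m * k * n / avg_s / 1e12
    return achieved_tflops / peak_tflops

if torch.cuda.is_available():
    frac = gemm_fraction_of_peak(8192, 8192, 8192, PEAK_TFLOPS["A100-BF16"])
    print(f"Achieved {frac:.0%} of A100 BF16 peak")   # large, aligned GEMMs typically land at 60-90%
```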
Production Engineering Notes
The TF32 Trap
Ampere and later GPUs route FP32 matrix multiplications through Tensor Cores in TF32 mode whenever torch.backends.cuda.matmul.allow_tf32 = True (the default from PyTorch 1.7 through 1.11; since 1.12 the matmul flag defaults to False, while cuDNN convolutions still default to TF32). TF32 truncates the FP32 mantissa to 10 bits before feeding it into Tensor Cores, silently introducing numerical differences.
For most training workloads, TF32 is fine. For numerical validation, scientific computing, or any code that compares against reference FP32 results, disable it:
import torch
# Disable TF32 for matmul (enabled by default on A100+)
torch.backends.cuda.matmul.allow_tf32 = False
# Disable TF32 for cuDNN convolutions as well
torch.backends.cudnn.allow_tf32 = False
# Verify current settings
print(torch.backends.cuda.matmul.allow_tf32) # False
print(torch.backends.cudnn.allow_tf32) # False
If you notice small numerical differences between a V100 and an A100 on the same model with FP32 dtype, TF32 is the likely cause.
Padding Tensors to Tensor Core Alignment
When your problem dimensions are not multiples of 16 - for example, a vocabulary size of 32001 or a custom sequence length of 511 - you can pad inputs to the next multiple of 16 without affecting correctness:
import torch
import torch.nn.functional as F
def pad_to_multiple(tensor: torch.Tensor, multiple: int = 16, dim: int = -1) -> torch.Tensor:
    """
    Pad the last dimension (or the specified dim) of a tensor up to the next multiple.
    The padded region must be masked out in loss computation or attention.
    """
    size = tensor.shape[dim]
    remainder = size % multiple
    if remainder == 0:
        return tensor
    pad_amount = multiple - remainder
    # F.pad's spec runs from the last dimension backwards:
    # (last_left, last_right, second_to_last_left, second_to_last_right, ...)
    pad_spec = [0] * (2 * tensor.ndim)
    axis_from_end = (tensor.ndim - 1) - (dim % tensor.ndim)
    pad_spec[2 * axis_from_end + 1] = pad_amount   # right-pad the target dimension only
    return F.pad(tensor, pad_spec)
# Example: vocabulary embedding with size 32001 (not multiple of 16)
vocab_size = 32001
padded_size = ((vocab_size + 15) // 16) * 16 # 32016
print(f"Original: {vocab_size}, Padded: {padded_size}")
embedding = torch.nn.Embedding(padded_size, 4096).cuda()
# Input indices must still be < vocab_size; padded rows are never accessed
Profiling with Nsight Compute
The definitive way to verify Tensor Core usage is Nsight Compute. The key metric is sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active.
# Profile a single training step
ncu --metrics sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active \
--target-processes all \
python train.py --steps 1 --profile
# Full tensor core breakdown
ncu --set full --target-processes all python train.py --steps 1 --profile
# For A100, also check:
# sm__pipe_tensor_op_hmma_cycles_active (FP16/BF16 Tensor Core)
# sm__pipe_tensor_op_imma_cycles_active (INT8 Tensor Core)
A well-tuned BF16 transformer training run should show sm__pipe_tensor_cycles_active above 60-70%. Below 20% is a warning sign. Below 5% means something is seriously misaligned - check shapes immediately.
NCCL and Mixed Precision in Distributed Training
In distributed training (DDP, FSDP), gradient all-reduce happens in the communication dtype. By default, PyTorch DDP all-reduces in FP32 regardless of compute dtype. For large models, this can make communication the bottleneck.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# Reduce gradient all-reduce bandwidth by communicating in BF16 instead of FP32
model = DDP(
    model,
    device_ids=[local_rank],          # local_rank comes from your launcher (e.g. torchrun)
    gradient_as_bucket_view=True,     # avoids an extra gradient copy
)
# Override the communication dtype for gradient buckets with a DDP comm hook.
# bf16_compress_hook all-reduces each bucket in BF16 and casts back afterwards
# (needs a reasonably recent PyTorch/NCCL; fp16_compress_hook is the older equivalent).
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)
# Better approach: use FSDP with mixed_precision policy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
bf16_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16, # gradient all-reduce in BF16
buffer_dtype=torch.bfloat16,
)
model = FSDP(model, mixed_precision=bf16_policy)
Choosing Between BF16 and FP16 in Practice
The choice is almost always BF16 for training on modern hardware. The practical rules, with a small dtype-selection sketch after the list:
- A100, H100, TPU v4+: use BF16. Same exponent range as FP32, no loss scaling, clean and stable.
- V100 and Turing consumer GPUs (RTX 20xx): use FP16. These GPUs do not support BF16 Tensor Cores (Ampere consumer cards such as the RTX 3080 do).
- Inference only: INT8 or INT4 quantization (bitsandbytes, llm.int8()) for maximum throughput.
- FP8 (H100 only): use Transformer Engine's fp8_autocast. Do not implement it manually.
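A small selection helper along those lines (a sketch of the rule of thumb above; `torch.cuda.is_bf16_supported()` is standard PyTorch):

```python
import torch

def pick_training_dtype() -> torch.dtype:
    """Rule-of-thumb autocast dtype selection based on hardware capability."""
    if not torch.cuda.is_available():
        return torch.float32
    # Ampere (sm_80) and newer expose BF16 Tensor Cores; prefer BF16 there.
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    # Volta/Turing: FP16 Tensor Cores only - remember to pair with GradScaler.
    return torch.float16

print(f"Selected autocast dtype: {pick_training_dtype()}")
```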
Common Mistakes
:::danger Using FP32 dtype and expecting Tensor Core speed
If you call torch.matmul(a, b) where a and b are torch.float32, you are using CUDA cores, not Tensor Cores (except on Ampere+ with TF32 enabled). The throughput gap is 16x on A100. Always verify your dtype before benchmarking.
# WRONG - FP32 bypasses Tensor Cores entirely on pre-Ampere
x = torch.randn(4096, 4096, device="cuda") # default dtype is float32
y = torch.matmul(x, x) # no Tensor Cores on V100
# RIGHT - explicit FP16 or BF16
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
y = torch.matmul(x, x) # Tensor Cores engaged
:::
:::danger Using FP16 for LLM training without understanding overflow
FP16's maximum representable value is 65504. LLM gradient norms routinely exceed this during early training, causing NaN propagation. Without proper loss scaling and overflow detection, FP16 LLM training silently corrupts weights.
Use BF16 for any model with more than ~1B parameters. If you must use FP16 (V100, older hardware), monitor scaler.get_scale() every 100 steps. If it drops below 256, your training is in overflow trouble.
:::
:::warning Shapes that are not multiples of 16
This is the most common silent performance killer in production. A batch size of 96 (multiple of 16: yes), a sequence length of 1023 (multiple of 16: no), or a vocabulary size of 50257 (multiple of 16: no) will kill Tensor Core efficiency on the affected matrix multiplications.
The fix for vocabulary: round up to 50272 (next multiple of 16). The fix for sequence length: pad to 1024. These changes are worth 5-10x throughput recovery on affected layers. :::
:::warning Calling scaler.unscale_() in the wrong order
If you clip gradients with torch.nn.utils.clip_grad_norm_() while using FP16 AMP, you must call scaler.unscale_(optimizer) BEFORE clipping. Clipping scaled gradients produces incorrect gradient norms and silently corrupts training.
# WRONG - clip before unscale
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # clips SCALED grads
scaler.step(optimizer)
scaler.update()
# RIGHT - unscale first, then clip
scaler.scale(loss).backward()
scaler.unscale_(optimizer) # must come before clip_grad_norm_
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # clips true grads
scaler.step(optimizer)
scaler.update()
:::
:::warning Autocast context does not persist through DataLoader workers
torch.cuda.amp.autocast is a thread-local context manager. It does not propagate into DataLoader worker processes, other subprocesses, or newly spawned threads. Any preprocessing that does GPU work in workers runs in FP32. This is usually not a performance issue since DataLoader workers typically run on CPU, but it can cause confusion when debugging dtype mismatches.
:::
:::danger FP8 without Transformer Engine
FP8 is not simply tensor.to(torch.float8_e4m3fn). FP8 training requires per-tensor scaling factors that must be tracked and updated every step. Raw FP8 matmuls without proper scaling will diverge within 100 steps. Always use NVIDIA's Transformer Engine (import transformer_engine.pytorch as te) for FP8 training - it handles scaling automatically.
:::
The Tensor Core Architecture in Full
Interview Q&A
Q1: What exactly does a Tensor Core do that a CUDA core cannot, and what is the throughput difference on an A100?
Answer:
A CUDA core executes one floating point multiply-accumulate (FMA) per clock cycle. It operates on scalar values: one number in, one number out. At 1.41 GHz with 6912 CUDA cores, the A100's peak FP32 throughput is about 19.5 TFLOPS.
A Tensor Core is a fixed-function matrix multiply unit. It takes two matrix fragments and an accumulator as inputs and produces a matrix result. One Tensor Core instruction on an A100 executes a 16×16×16 MMA, which involves 4096 multiply-add operations (8192 FLOPs). This happens in a single hardware cycle, across all 32 threads of a warp cooperatively. With 432 Tensor Cores at 1.41 GHz, the A100 achieves 312 TFLOPS in BF16 - a 16x improvement over its CUDA core FP32 throughput.
The operation is D = A * B + C where A and B are FP16/BF16 inputs and C/D are FP32 accumulators. The higher-precision accumulation prevents error accumulation across many tiles in a large GEMM.
The key constraint is that Tensor Cores are a warp-level (not per-thread) resource. A single thread cannot invoke a Tensor Core - it takes 32 threads acting together. This is why the WMMA API in CUDA C++ operates on "fragments" distributed across warp registers.
Q2: Why does FP16 training require loss scaling but BF16 does not?
Answer:
The difference comes from how exponent bits are allocated. FP16 uses 5 exponent bits, giving a maximum representable value of 65504 and a minimum positive normal value around $6 \times 10^{-5}$. BF16 uses 8 exponent bits - the same as FP32 - giving a maximum of roughly $3.4 \times 10^{38}$.
During backpropagation, gradient values for deep networks are often very small - on the order of $10^{-5}$ to $10^{-7}$. These values are representable in FP32 but fall into FP16's subnormal range or underflow to zero (the minimum positive FP16 normal is about $6 \times 10^{-5}$). When gradients underflow to zero, those weights receive no update signal. The model appears to train (loss decreases) but parameter updates stall for layers where gradients are small.
Loss scaling multiplies the loss by a large constant $S$ before backward, which multiplies all gradients by $S$ via the chain rule. Gradients of $10^{-7}$ become roughly $10^{-4}$ to $10^{-2}$, safely in FP16 range. After the backward pass but before the optimizer step, gradients are divided by $S$.
BF16 does not need this because BF16's exponent range matches FP32. Any gradient that survives in FP32 also survives in BF16. The lower mantissa precision of BF16 (7 bits vs 23 in FP32) introduces quantization noise but not systematic underflow.
Dynamic loss scaling (GradScaler) starts with a large scale, detects overflows (inf/nan in gradients), and reduces the scale on overflow. It gradually increases the scale when no overflow is detected for several hundred steps, adapting to the actual gradient magnitude range of the current training phase.
Q3: Your training run is 10x slower than expected on an A100. Walk me through how you would diagnose whether Tensor Cores are being used.
Answer:
The diagnostic process has three levels, going from coarse to fine:
Level 1 - Quick sanity checks in code:
- Verify dtypes: under AMP the parameters stay FP32 and only the compute inside the autocast region runs in FP16/BF16, so confirm the autocast context actually wraps the forward pass (and, if you cast the model explicitly, that next(model.parameters()).dtype reports float16 or bfloat16).
- Check that all matrix dimensions are multiples of 8 (minimum) and ideally 16. Pay special attention to custom layers, vocabulary sizes, and sequence lengths.
- Verify GPU arch: torch.cuda.get_device_capability() should return (8, 0) for A100 or higher.
Level 2 - Pytorch profiler:
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    for _ in range(3):
        train_step()  # placeholder: run a few of your training steps here
prof.export_chrome_trace("trace.json")
Open in chrome://tracing. Look for volta_h884gemm or ampere_h16816gemm kernel names - these confirm Tensor Core usage. sgemm kernels (no h/half prefix) indicate FP32 scalar fallback.
Level 3 - Nsight Compute:
ncu --metrics sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_active \
python train.py
This directly measures the fraction of SM cycles where Tensor Core units were active. A healthy BF16 training run should show 60-80%. Below 10% confirms the problem.
Common root causes once confirmed: a single layer with misaligned shapes (often vocabulary embedding or a custom attention variant), the entire model running in FP32 because autocast was applied to the loss but not the model forward (wrong scope), or a torch.compile interaction that forced eager mode on key layers.
Q4: Explain TF32. Why does it exist and what are its risks?
Answer:
TF32 (TensorFloat-32) is a compute format introduced with the A100 that is not a storage format - you never allocate a TF32 tensor. Instead, when the GPU performs an FP32 matrix multiplication on A100+, it silently truncates each FP32 operand's mantissa from 23 bits to 10 bits before feeding it into the Tensor Core. The exponent (8 bits) is preserved. The output is accumulated and stored in FP32.
The motivation: Tensor Cores need fixed-size, reduced-precision operands. By truncating the FP32 mantissa to 10 bits (the same mantissa width as FP16) while keeping the full 8-bit FP32 exponent, the A100 can run FP32-typed matrix multiplications through Tensor Cores and achieve roughly 156 TFLOPS versus 19.5 TFLOPS for true scalar FP32. This gives legacy FP32 code a significant speedup with no code changes.
The risk: TF32 introduces numerical differences compared to "true" FP32. Results differ in the last 13 mantissa bits. For neural network training, this is almost never a problem - the noise from TF32 truncation is smaller than gradient noise. For numerical algorithms (scientific simulations, iterative solvers), it can cause subtle convergence failures that are hard to debug.
torch.backends.cuda.matmul.allow_tf32 defaulted to True from PyTorch 1.7 through 1.11 and defaults to False since 1.12 (the cuDNN TF32 flag remains True by default). Setting it to False forces true FP32 on CUDA cores at 19.5 TFLOPS. Teams doing reproducibility validation across GPU generations should set it to False and accept the speed cost.
Q5: What are the shape requirements for Tensor Core efficiency and what strategies do you use when your problem dimensions do not satisfy them?
Answer:
The requirements exist because Tensor Core hardware operates on fixed-size matrix tiles. The minimum tile is 8 in each dimension for FP16/BF16, and optimal is 16. For FP8 on Hopper, all dimensions must be multiples of 16. If any dimension is not a multiple of 8, cuBLAS falls back to scalar CUDA cores entirely. If dimensions are multiples of 8 but not 16, Tensor Cores are used but tile utilization is suboptimal.
Practically, this means M, N, and K in your GEMM must all be multiples of 16 for full efficiency.
Strategy 1 - Pad vocabulary size: Language model embedding matrices have vocab_size rows. GPT-2 uses 50257, GPT-NeoX uses 50432 (already 16-aligned). Simply round up: padded_vocab = ((vocab_size + 15) // 16) * 16. The extra rows are never accessed in normal operation.
Strategy 2 - Pad batch/sequence dimensions: If your sequence length is 1023 or your batch size is 6, pad to 1024 and 8 respectively. Use attention masks to ignore padded positions.
Strategy 3 - Redesign for alignment: When building new models, choose dimensions that are multiples of 64 (a common Tensor Core tile size in cuBLAS kernels). Hidden dims of 768, 1024, 2048, 4096 are all good. Avoid 1023, 1000, or arbitrary power-of-10 sizes.
Strategy 4 - torch.compile with max_autotune: In PyTorch 2.0+, torch.compile(model, mode="max-autotune") uses Triton's autotuner to find the best tiling strategy for your specific shapes, sometimes recovering efficiency even for non-standard dimensions.
Strategy 5 - Profile before fixing: Not all layers matter equally. A misaligned batch norm layer with small matrices contributes negligibly to total runtime. Use Nsight Compute or PyTorch Profiler to identify which specific kernel calls are falling back to scalar compute, then prioritize those.
Q6: How does FP8 training work on H100 and what makes it different from simply casting to FP8?
Answer:
FP8 on H100 is not a drop-in replacement for BF16. It requires a per-tensor scaling framework because FP8's extremely limited dynamic range (max 448 for E4M3, max 57344 for E5M2) means raw values from neural networks will overflow in most layers without pre-scaling.
The NVIDIA Transformer Engine implements Delayed Scaling: it tracks the maximum absolute value (amax) of each tensor over the last N steps, uses this history to compute a per-tensor scale factor, and applies the scale before the FP8 cast and its inverse after the computation. Each te.Linear layer maintains its own scale tensors for inputs, weights, and output gradients.
Two FP8 variants serve different roles:
- FP8 E4M3: used in the forward pass. Higher mantissa precision (3 bits) at the cost of dynamic range. Forward activations tend to be well-behaved and don't need wide range.
- FP8 E5M2: used in the backward pass. Wider dynamic range (5 exponent bits) handles the larger variance in gradient magnitudes.
The practical result: H100 with FP8 Transformer Engine achieves approximately 2x the compute throughput of H100 BF16 for transformer workloads - from ~989 TFLOPS to ~1979 TFLOPS with dense arithmetic. Early adopters (Meta LLaMA 3, Mistral) report training cost savings of 30-40% compared to BF16 runs on the same hardware.
The risk is that FP8 training is still sensitive to model architecture. Layers with unusual activation distributions (sharp attention logits, poorly initialized layers) can see increased instability. The standard recommendation is to train the first 1-5% of steps in BF16 to stabilize the model, then switch to FP8.
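If you just want to inspect the FP8 ranges quoted above, recent PyTorch (2.1+) exposes the raw FP8 dtypes independently of Transformer Engine (a sketch for inspection only; these dtypes alone do not give you stable FP8 training):

```python
import torch

# Inspect FP8 dynamic range - storage/inspection only, not a training recipe.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{str(dtype):24s} max={info.max:>8.0f}  min_normal={info.tiny:.2e}")

# Expected: e4m3fn max = 448 (more mantissa, less range - used in the forward pass),
# e5m2 max = 57344 (more range, less mantissa - used for gradients).
```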
Summary
Tensor Cores are hardware matrix multiply-accumulate units that execute 16×16×16 MMA operations at the warp level. They achieve 16x the throughput of FP32 CUDA cores on A100 (312 vs 19.5 TFLOPS) by hardwiring the inner loop of matrix multiplication into silicon and operating on all 32 threads of a warp simultaneously.
Using them requires three things: the right dtype (FP16, BF16, TF32, or FP8), shapes that are multiples of 8 (preferably 16), and a GPU from Volta (2017) or later. One misaligned dimension silently kills Tensor Core usage across an entire matrix multiplication.
Mixed precision training keeps FP32 master weights in the optimizer state while computing in BF16 or FP16. BF16 is strongly preferred for modern training because its exponent range matches FP32, eliminating overflow and loss scaling complexity. FP8 on H100 doubles throughput again but requires the NVIDIA Transformer Engine for stable per-tensor scaling.
In production, always verify Tensor Core utilization via sm__pipe_tensor_cycles_active in Nsight Compute before declaring a run optimized. A healthy BF16 training workload should show this metric above 60%.
