Mixed Precision and Quantization Kernels
Reading time: ~45 min · Interview relevance: Very High · Target roles: CUDA Developer, ML Systems Engineer, Inference Engineer
The mistake is not quantizing to INT8. The mistake is quantizing only half the pipeline and calling it done.
The Benchmark That Lied
A team at a mid-size AI startup gets a directive: cut inference costs by 4x. The plan is straightforward - run the 70B parameter model in INT8 instead of FP16. INT8 is 8 bits, FP16 is 16 bits. Half the memory, twice the throughput on Tensor Cores. Math checks out.
They spend three weeks hand-rolling a custom INT8 linear layer. Weights are stored as INT8. They load them, dequantize to FP16, run the matmul in FP16. Benchmark day arrives. The INT8 kernel is 12% slower than the original FP16 kernel.
The post-mortem reveals three separate mistakes layered on top of each other. First, they dequantized the weights into a full FP16 tensor before the matmul instead of fusing the dequantization into the GEMM. The GEMM then read FP16 weights from HBM anyway, so on top of the original FP16 traffic they paid for reading the INT8 weights and writing the FP16 copy. Second, they left activations in FP32, not FP16, so the matmul ran at FP32 speed no matter how the weights were stored. Third, they never verified their kernel was issuing INT8 GEMM instructions at all. Because they mixed types incorrectly, it was falling back to FP32 scalar operations.
This lesson explains the full pipeline. Why quantization is only as fast as its slowest dtype. How to correctly fuse dequantization. How INT8 GEMM actually works on Tensor Cores. And how to debug all of this when the benchmark disagrees with your theory.
Why This Exists
The Arithmetic of Model Size
A GPT-3 scale model has 175 billion parameters. In FP32, that is 700 GB of weights, which is more than eight 80 GB A100s (640 GB) can hold. In FP16 or BF16, it is 350 GB. In INT8, it is 175 GB. In INT4, it is 87.5 GB.
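The numbers above are just parameter count times bytes per parameter. Below is a minimal sketch of that arithmetic. It counts weights only and ignores the KV cache, activations, and any per-group scale storage. It uses decimal gigabytes.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

GPT3_PARAMS = 175e9
for name, bits in [("FP32", 32), ("FP16/BF16", 16), ("INT8", 8), ("INT4", 4)]:
    print(f"{name:10s} {weight_memory_gb(GPT3_PARAMS, bits):7.1f} GB")
```

Every halving of the bit width halves the footprint. Whether it also halves the runtime depends on the rest of the pipeline.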
Memory is the primary constraint in large model inference. When a model does not fit on the available GPUs, you either rent more hardware or reduce precision. Reducing precision also changes the arithmetic throughput - INT8 Tensor Core operations are roughly 2x the throughput of FP16 on A100, and 4x for INT4 matrix operations on some implementations.
But the gains only materialize if you implement the pipeline correctly. The entire datapath - from weight loading through matmul through activation - must use the target dtype. A single FP32 operation in the critical path can nullify all the gains from quantization.
The Two Distinct Wins
Quantization produces two different wins, and you need to keep them separate in your head:
Memory bandwidth win: Quantized weights are smaller, so loading them from HBM takes fewer bytes. This helps for memory-bandwidth-bound operations, which most large transformer inference is. Even if you dequantize back to FP16 immediately after loading, you still win if the dequantization cost is less than the bandwidth saved.
Compute win: Lower-precision arithmetic has higher throughput on Tensor Cores. INT8 GEMM is 2x faster than FP16 GEMM on A100 (in theory). FP8 GEMM on H100 is 2x faster than FP16. This win requires keeping the data in the quantized format all the way through the matrix multiply - you cannot dequantize before the GEMM and still get the compute win.
Most production systems get the memory bandwidth win. Getting the compute win requires careful kernel engineering.
Historical Context
Low-precision training and inference have a long history. The first major practical result was in 2017, when NVIDIA's Micikevicius et al. published "Mixed Precision Training" showing that FP16 training with FP32 master weights was stable for large networks. This became the standard training approach and is now the default in every modern framework.
INT8 inference emerged around 2018-2019. NVIDIA added INT8 Tensor Core support in Turing (RTX 2000 / T4). Google published quantization-aware training papers. TensorRT added INT8 calibration workflows.
The breakthrough for large language models came from Tim Dettmers at the University of Washington. His LLM.int8() paper (2022) showed that INT8 quantization for transformers was surprisingly tricky - a small fraction of activation channels had dramatically larger values, causing INT8 to catastrophically fail for large models. His solution was vector-wise quantization with outlier decomposition in FP16. This was implemented in the bitsandbytes library and became the first practical approach to running 65B parameter models on consumer GPUs.
GPTQ (2022) from Elias Frantar et al. at IST Austria pushed further. It quantizes weights to INT4 or even 3-bit, using approximate second-order (Hessian) information to compensate for rounding error. AWQ (2023) from MIT took an activation-aware approach: it identifies the most salient weight channels from activation magnitudes and protects them by scaling them before quantization, rather than keeping them in higher precision.
Hopper (H100) was NVIDIA's first data-center GPU with FP8 Tensor Core instructions. NVIDIA's Transformer Engine manages the FP8 scale factors automatically using delayed scaling. This is the current frontier.
Numeric Formats: The Full Landscape
Understanding what you are working with is prerequisite to writing correct kernels.
FP32 - The Baseline
32-bit floating point: 1 sign bit, 8 exponent bits, 23 mantissa bits. Range: approximately $1.2 \times 10^{-38}$ (smallest normal) to $3.4 \times 10^{38}$. Precision: about 7 decimal digits. This is the baseline that everything else is measured against. On A100, Tensor Core throughput for FP32 inputs (TF32 mode) is 156 TFLOPS, or 312 TFLOPS with structured sparsity. True FP32 (non-Tensor Core) is 19.5 TFLOPS.
FP16 - The Training Standard
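16-bit: 1 sign, 5 exponent, 10 mantissa. Range: about $6.1 \times 10^{-5}$ (smallest normal; subnormals reach about $6.0 \times 10^{-8}$) to $65504$. The narrow exponent range causes two problems: overflow (values above 65504 become inf) and underflow (very small gradients flush to zero). Mixed precision training addresses this with FP32 master weights and loss scaling. Loss scaling multiplies the loss by a large factor so small gradients stay representable in FP16, then divides it back out before the weight update. Peak throughput on A100: 312 TFLOPS (Tensor Core).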
BF16 - The Training Workhorse
Brain Float 16: 1 sign, 8 exponent, 7 mantissa. It has the same exponent range as FP32, so any value FP32 can hold without overflowing, BF16 can hold too, just more coarsely rounded. Only 7 mantissa bits means less precision than FP16, but range matters more than precision for training stability. BF16 has almost completely replaced FP16 for training since Ampere. Same Tensor Core throughput as FP16 on A100.
TF32 - The Invisible One
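Tensor Float 32: a compute mode used internally by A100 Tensor Cores for FP32 matmuls. 1 sign, 8 exponent, 10 mantissa. You cannot store TF32 tensors, because it is an internal compute format. When TF32 is enabled, an FP32 matmul on A100 silently rounds its inputs to TF32 and runs on Tensor Cores. That gives roughly 8x the non-Tensor-Core FP32 peak (156 vs 19.5 TFLOPS), at the cost of some precision. PyTorch has left TF32 off for matmuls by default since 1.12, while cuDNN convolutions still default to it. Control it with torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32; set them to False if you need exact FP32.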
FP8 - The New Standard
Two variants designed for different roles:
- FP8 E4M3: 4 exponent bits, 3 mantissa bits. Higher precision, lower range. Used for forward pass (activations and weights). Range: up to 448.
- FP8 E5M2: 5 exponent bits, 2 mantissa bits. Lower precision, higher range. Used for gradient computation. Range: up to 57344.
FP8 requires explicit scaling tensors because the range is so narrow. You store a float32 scale factor alongside each FP8 tensor. The Transformer Engine on H100 manages these scales automatically with "delayed scaling" - tracking the max absolute value from the previous iteration and computing the scale factor accordingly.
FP8 Tensor Core throughput on H100: 1979 TFLOPS. Compared to FP16 at 989 TFLOPS - a 2x gain.
INT8 - The Inference Standard
8-bit signed integer: range -128 to 127. For matrix multiply, you accumulate into INT32 to avoid overflow. INT8 Tensor Core throughput on A100: 624 TOPS. Compared to FP16 at 312 TFLOPS - 2x throughput if you stay in INT8 all the way through the GEMM.
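The quantization formula for symmetric INT8:
$$q = \operatorname{clamp}\left(\operatorname{round}\left(\frac{x}{s}\right), -127, 127\right)$$
where $s$ is the scale factor:
$$s = \frac{\max |x|}{127}$$
Dequantization:
$$\hat{x} = q \cdot s$$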
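Here is a minimal NumPy sketch of symmetric per-tensor INT8 quantization and dequantization. It checks the property that makes the scheme predictable: with round-to-nearest, each element's reconstruction error is at most half a quantization step, $s/2$. The helper names are illustrative.

```python
import numpy as np

def quantize_int8_symmetric(x: np.ndarray):
    """Per-tensor symmetric INT8: s = max|x| / 127, q = clamp(round(x / s), -127, 127)."""
    s = max(float(np.max(np.abs(x))) / 127.0, 1e-12)  # guard against an all-zero tensor
    q = np.clip(np.round(x / s), -127, 127).astype(np.int8)
    return q, s

def dequantize_int8(q: np.ndarray, s: float) -> np.ndarray:
    """x_hat = q * s"""
    return q.astype(np.float32) * s

rng = np.random.default_rng(0)
x = rng.standard_normal(1000).astype(np.float32)
q, s = quantize_int8_symmetric(x)
x_hat = dequantize_int8(q, s)
# Round-to-nearest error is at most half a quantization step (plus FP32 rounding)
print(np.max(np.abs(x - x_hat)) <= s / 2 + 1e-6)  # True
```

The largest-magnitude element always maps to ±127, so the full INT8 range is used. The remaining 254 levels are spread evenly, no matter how the values are distributed.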
INT4 - The Aggressive Case
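4-bit integer: range -8 to 7 (signed), or 0 to 15 when used with a zero-point. Hardware support is uneven. Turing and Ampere Tensor Cores have INT4 instructions, but H100's published specifications list no INT4 Tensor Core throughput, and mainstream LLM kernels rarely use INT4 math. In practice INT4 is used for weight-only quantization. Values are packed two per byte (or eight per 32-bit word), unpacked and dequantized in registers, and fed to an FP16 GEMM. The win is purely memory bandwidth: weights are 4x smaller than FP16.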
The Pipeline Must Be Fully Quantized
This is the single most important concept in this lesson. Visualizing the pipeline helps:
The wrong path (dequantize before GEMM):
- Load INT8 weights from HBM - reads N bytes
- Convert INT8 to FP16 in registers - cheap, but not free (extra ALU instructions inside the GEMM loop)
- Run FP16 GEMM - uses FP16 Tensor Core throughput
- Net result: you saved memory bandwidth on the weight load, but you ran FP16 compute, not INT8
The right path (INT8 GEMM with epilogue dequantize):
- Load INT8 weights from HBM - reads N bytes (same as above)
- Load INT8 activations from HBM - half the bytes of the FP16 activations
- Run INT8 GEMM accumulating into INT32 - uses INT8 Tensor Core throughput (2x faster)
- Apply scale factors in epilogue (fast, no extra memory traffic)
- Net result: both memory bandwidth win AND compute win
Quantization Strategies
Weights-Only Quantization (W4A16, W8A16)
Weights stored in INT4 or INT8, activations kept in FP16. The matmul runs in FP16 after dequantizing weights to FP16.
When to use it: Models where the weight-to-activation ratio is high, and where you are memory-bandwidth bound. Most LLM inference with batch size 1 or small batches is in this category. The key insight is that at small batch sizes, you execute very few FLOPs per weight byte loaded - so the bottleneck is HBM bandwidth, not compute.
Win: Weights are 2x (INT8) or 4x (INT4) smaller. Loading them from HBM is 2-4x faster.
Limitation: No compute win. You are still running FP16 GEMM.
Implementations: GPTQ (INT4), AWQ (INT4), bitsandbytes NF4.
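The batch-size argument above is about arithmetic intensity: FLOPs performed per byte moved from HBM. The sketch below computes it for one linear layer, assuming each operand crosses HBM exactly once. The A100 figures used for the ridge point (about 312 TFLOPS FP16 and about 2 TB/s HBM) are assumptions, so substitute your own GPU's numbers. Below the ridge point a kernel is bandwidth-bound, so shrinking the weight bytes speeds it up directly.

```python
def linear_arithmetic_intensity(M: int, N: int, K: int, weight_bytes: float,
                                act_bytes: float = 2.0, out_bytes: float = 2.0) -> float:
    """FLOPs per HBM byte for Y[M, N] = X[M, K] @ W[N, K].T.

    Assumes ideal tiling: weights, activations and outputs each move once.
    """
    flops = 2 * M * N * K  # one multiply + one add per multiply-accumulate
    bytes_moved = N * K * weight_bytes + M * K * act_bytes + M * N * out_bytes
    return flops / bytes_moved

# Assumed A100 figures: ~312 TFLOPS FP16 Tensor Core, ~2 TB/s HBM
RIDGE = 312e12 / 2.0e12  # ~156 FLOP/byte
for M in (1, 16, 256):
    fp16 = linear_arithmetic_intensity(M, 4096, 4096, weight_bytes=2)
    int4 = linear_arithmetic_intensity(M, 4096, 4096, weight_bytes=0.5)
    print(f"M={M:4d}  FP16 W: {fp16:6.1f} FLOP/B  INT4 W: {int4:6.1f} FLOP/B  (ridge ~{RIDGE:.0f})")
```

At M = 1 the layer does about one FLOP per byte, two orders of magnitude below the ridge. That is why weight-only quantization pays off at small batch sizes.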
Weight + Activation Quantization (W8A8)
Both weights and activations in INT8. INT8 GEMM on Tensor Cores. Scale factors applied as a post-GEMM epilogue operation.
When to use it: Large batch inference where compute is the bottleneck. Server-side inference with high concurrency.
Win: Both memory bandwidth and compute. 2x GEMM throughput, 2x weight bandwidth, 2x activation bandwidth.
Challenge: Activations are harder to quantize than weights. Weights are static - you can calibrate once offline. Activations change with every input. Dynamic quantization (compute scale on the fly) adds some overhead. Static quantization (calibrate scale on a representative dataset) requires calibration but is faster at runtime.
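The static/dynamic trade-off can be seen in a few lines. This NumPy sketch fixes one activation scale from calibration batches (static, a running max) and compares it with a scale computed from the live tensor (dynamic). The runtime input is deliberately wider than anything in the calibration set, which is the failure mode of unrepresentative calibration data. All names are illustrative.

```python
import numpy as np

def dynamic_scale(x: np.ndarray) -> float:
    """Dynamic quantization: scale from the current tensor (an extra reduction per call)."""
    return max(float(np.abs(x).max()), 1e-12) / 127.0

def calibrate_static_scale(batches) -> float:
    """Static quantization: one scale fixed offline from a calibration set (running max)."""
    amax = 0.0
    for b in batches:
        amax = max(amax, float(np.abs(b).max()))
    return max(amax, 1e-12) / 127.0

def fake_quant(x: np.ndarray, scale: float) -> np.ndarray:
    """Quantize to INT8 and immediately dequantize, to measure the error."""
    return np.clip(np.round(x / scale), -127, 127) * scale

rng = np.random.default_rng(0)
calib = [rng.standard_normal((8, 64)) for _ in range(16)]
static_s = calibrate_static_scale(calib)

# Runtime input with a wider range than anything seen during calibration
x = rng.standard_normal((8, 64)) * 3.0
err_static = float(np.abs(x - fake_quant(x, static_s)).max())
err_dynamic = float(np.abs(x - fake_quant(x, dynamic_scale(x))).max())
print(f"static scale max error:  {err_static:.3f}")   # large: out-of-range values clip
print(f"dynamic scale max error: {err_dynamic:.3f}")  # at most half a quantization step
```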
Implementation: SmoothQuant, LLM.int8() for outlier handling, cuBLAS cublasGemmEx with CUDA_R_8I.
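SmoothQuant's core trick is an exact rewrite of the layer: divide each activation channel by a factor $s_j$ and multiply the matching weight column by the same $s_j$. The product is unchanged, but the activation outliers are flattened. The sketch below follows the smoothing formula from the SmoothQuant paper. The choice α = 0.5 and the max-based statistics are assumptions for illustration. In practice the division is folded offline into the preceding layer.

```python
import numpy as np

def smooth(X: np.ndarray, W: np.ndarray, alpha: float = 0.5):
    """SmoothQuant-style smoothing of one linear layer (sketch).

    X: [tokens, K] activations; W: [N, K] weights in nn.Linear layout.
    Per input channel j: s_j = max|X[:, j]|**alpha / max|W[:, j]|**(1 - alpha).
    (X / s) @ (W * s).T == X @ W.T, but X / s has a flatter per-channel range.
    """
    x_max = np.abs(X).max(axis=0)  # per input channel, over tokens
    w_max = np.abs(W).max(axis=0)  # per input channel, over output channels
    s = np.maximum(x_max, 1e-5) ** alpha / np.maximum(w_max, 1e-5) ** (1 - alpha)
    return X / s, W * s, s

def outlier_ratio(A: np.ndarray) -> float:
    """Largest per-channel max divided by the median per-channel max."""
    ch_max = np.abs(A).max(axis=0)
    return float(ch_max.max() / np.median(ch_max))

rng = np.random.default_rng(0)
X = rng.standard_normal((32, 128))
X[:, 7] *= 50.0                      # a single outlier activation channel
W = rng.standard_normal((256, 128)) * 0.02

X_s, W_s, s = smooth(X, W)
ratio_before, ratio_after = outlier_ratio(X), outlier_ratio(X_s)
print(np.allclose(X @ W.T, X_s @ W_s.T))  # True: the layer output is unchanged
print(f"outlier ratio: {ratio_before:.1f} -> {ratio_after:.1f}")
```

The weights absorb part of the difficulty. Weights are quantized per channel, so they tolerate the extra range far better than per-tensor activations do.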
FP8 W8A8 (Transformer Engine)
Both weights and activations in FP8 E4M3. GEMM runs on FP8 Tensor Cores on H100.
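Win: 2x compute vs FP16 (989 TFLOPS to 1979 TFLOPS). Memory bandwidth win same as INT8. Because FP8 is floating point, its relative precision is roughly constant across its range. That handles wide, heavy-tailed distributions better than INT8's uniform grid, though with only 3 mantissa bits E4M3 is coarser than INT8 near the top of the range. No need for zero-point terms.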
How it works: Transformer Engine maintains a scale factor per tensor, updated each iteration using delayed scaling. The scale from iteration N-1 is applied to iteration N. A history of max absolute values smooths out outliers.
Quantization Granularity
How you choose scale factors dramatically affects both kernel complexity and model accuracy.
Per-Tensor Quantization
One scale factor per weight tensor (e.g., one scale for the entire 4096x4096 weight matrix). Simplest to implement. Lowest accuracy for large models: an outlier anywhere in the tensor forces the scale to be large, so most values land on only a few quantization levels and most of the dynamic range goes unused.
Kernel complexity: trivial. One multiply at the end.
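To see the outlier effect concretely, here is a NumPy sketch comparing one scale for the whole tensor against one scale per row, the per-channel scheme described next. The synthetic weight matrix gives a few output channels 20x larger magnitudes. That is an assumption chosen to mimic the per-channel variance of real transformer weights.

```python
import numpy as np

def fake_quant_int8(W: np.ndarray, per_channel: bool) -> np.ndarray:
    """Symmetric INT8 quantize-dequantize with one scale per tensor or per row."""
    if per_channel:
        amax = np.abs(W).max(axis=1, keepdims=True)  # one scale per output channel
    else:
        amax = np.abs(W).max()                       # one scale for everything
    scale = np.maximum(amax, 1e-12) / 127.0
    return np.clip(np.round(W / scale), -127, 127) * scale

rng = np.random.default_rng(0)
# 8 of 512 output channels are 20x larger than the rest
row_gain = np.where(np.arange(512) < 8, 20.0, 1.0)[:, None]
W = rng.standard_normal((512, 1024)) * 0.02 * row_gain

errors = {}
for per_channel in (False, True):
    W_hat = fake_quant_int8(W, per_channel)
    errors[per_channel] = float(np.abs(W - W_hat).mean() / np.abs(W).mean())
    print(f"per_channel={per_channel}: mean relative error {errors[per_channel]:.4f}")
```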
Per-Channel (Per-Row/Column) Quantization
One scale factor per output channel (row of weight matrix). Dramatically better accuracy for weights. Standard for almost all production INT8 quantization.
Kernel complexity: moderate. After GEMM, multiply each output row by the corresponding scale factor. This is the "epilogue" - it can be fused into the GEMM output store.
$$y_{ij} = s^{x}_{i}\, s^{w}_{j} \sum_{k} q^{x}_{ik}\, q^{w}_{jk}$$
where $s^{w}_{j}$ is the weight scale for output channel $j$ and $s^{x}_{i}$ is the activation scale for input token $i$.
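The epilogue works because both scales are constant inside the sum over $k$, so they factor out of it. The NumPy sketch below checks that INT32 accumulation followed by one scale multiply per output element gives the same answer as dequantizing both operands first. It uses per-token activation scales and per-channel weight scales; `quantize_rows` is an illustrative helper.

```python
import numpy as np

def quantize_rows(A: np.ndarray):
    """Symmetric INT8 with one scale per row: per token for X, per output channel for W."""
    scale = np.maximum(np.abs(A).max(axis=1), 1e-12) / 127.0
    q = np.clip(np.round(A / scale[:, None]), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 256))           # [tokens, K]
W = rng.standard_normal((512, 256)) * 0.02  # [N, K]
qx, sx = quantize_rows(X)  # sx[i]: scale of token i
qw, sw = quantize_rows(W)  # sw[j]: scale of output channel j

# INT8 x INT8 products accumulated in INT32, then the epilogue: y_ij = sx_i * sw_j * acc_ij
acc = qx.astype(np.int32) @ qw.T.astype(np.int32)
y_epilogue = acc * sx[:, None] * sw[None, :]

# Reference: dequantize both operands first, then a floating-point matmul
y_dequant_first = (qx * sx[:, None]) @ (qw * sw[:, None]).T
print(np.allclose(y_epilogue, y_dequant_first))  # True: the scales factor out of the sum
```

In a real kernel the epilogue multiply happens once per output tile, just before the store, so it adds no HBM traffic.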
Per-Group Quantization
One scale factor per group of G consecutive elements within a row. Typical group size: 128 (GPTQ default). Dramatically better accuracy than per-channel for INT4. The cost: more scale factors to store and more complex dequantization.
For a weight matrix of shape $N \times K$ with group size $G$, you have $N \cdot K / G$ scale factors. For a 4096x4096 matrix with G=128, that is 131,072 FP16 scale factors (256 KB), versus 4096 for per-channel. Not free, but small relative to the weight matrix itself (8 MB at INT4).
The dequantization formula per group:
$$\hat{w} = (q - z_g) \cdot s_g$$
where $q$ is the quantized value, $z_g$ is the zero-point of group $g$ (for asymmetric quantization), and $s_g$ is the group scale. The zero-point is what moves the quantization range from symmetric (centered at zero) to asymmetric (centered anywhere).
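Here is a NumPy sketch of asymmetric per-group INT4 along the input dimension, following $\hat{w} = (q - z_g) \cdot s_g$. Each group of `group_size` consecutive columns in a row gets its own scale and integer zero-point, and codes live in 0..15. This is a plain round-to-nearest scheme for illustration, not GPTQ's error-compensating algorithm. Codes are left unpacked, one per byte.

```python
import numpy as np

def quantize_per_group_int4(W: np.ndarray, group_size: int = 128):
    """Asymmetric per-group INT4 along K. Assumes K is divisible by group_size.

    Returns codes q in [0, 15] (one per uint8), plus scales and zero-points
    of shape [N, K // group_size].
    """
    N, K = W.shape
    Wg = W.reshape(N, K // group_size, group_size)
    w_min, w_max = Wg.min(axis=-1), Wg.max(axis=-1)
    scales = np.maximum(w_max - w_min, 1e-8) / 15.0    # 16 levels: 0..15
    zeros = np.clip(np.round(-w_min / scales), 0, 15)  # integer zero-point
    q = np.clip(np.round(Wg / scales[..., None]) + zeros[..., None], 0, 15)
    return q.astype(np.uint8).reshape(N, K), scales, zeros.astype(np.uint8)

def dequantize_per_group_int4(q, scales, zeros, group_size: int = 128):
    """w_hat = (q - z_g) * s_g, broadcast over each group of group_size columns."""
    N, K = q.shape
    qg = q.reshape(N, K // group_size, group_size).astype(np.float32)
    return ((qg - zeros[..., None]) * scales[..., None]).reshape(N, K)

rng = np.random.default_rng(0)
W = (rng.standard_normal((256, 1024)) * 0.02).astype(np.float32)
q, s, z = quantize_per_group_int4(W)
W_hat = dequantize_per_group_int4(q, s, z)
print(s.shape)  # (256, 8): one scale and one zero-point per 128-element group
print(np.abs(W - W_hat).mean() / np.abs(W).mean())
```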
Code: INT8 Linear Layer from Scratch
First, a NumPy implementation to understand the mechanics clearly:
import numpy as np
def quantize_symmetric(x: np.ndarray, bits: int = 8) -> tuple:
    """Symmetric per-tensor quantization."""
    qmax = 2 ** (bits - 1) - 1  # 127 for INT8
    scale = np.max(np.abs(x)) / qmax
    q = np.clip(np.round(x / scale), -qmax, qmax).astype(np.int8)
    return q, scale
def quantize_per_channel(W: np.ndarray, bits: int = 8) -> tuple:
    """Per-channel (per-row) symmetric quantization of weight matrix."""
    qmax = 2 ** (bits - 1) - 1
    # Compute scale per output channel (row)
    scales = np.max(np.abs(W), axis=1, keepdims=True) / qmax
    Q = np.clip(np.round(W / scales), -qmax, qmax).astype(np.int8)
    return Q, scales.squeeze()  # scales shape: [out_channels]
def int8_linear_wrong(x_fp32: np.ndarray, W_int8: np.ndarray, W_scales: np.ndarray) -> np.ndarray:
    """
    WRONG: Dequantize weights BEFORE matmul.
    Memory bandwidth win only (loaded INT8, converted to FP32).
    Compute runs at FP32 throughput - no compute win.
    """
    # Dequantize weights to FP32
    W_fp32 = W_int8.astype(np.float32) * W_scales[:, np.newaxis]
    # FP32 matmul - no INT8 compute benefit
    return x_fp32 @ W_fp32.T
def int8_linear_correct(x_int8: np.ndarray, x_scale: float,
                        W_int8: np.ndarray, W_scales: np.ndarray) -> np.ndarray:
    """
    CORRECT: INT8 matmul, dequantize in epilogue.
    Both activations and weights quantized - INT8 GEMM + scale epilogue.
    """
    # INT8 matmul accumulates to INT32
    acc_int32 = x_int8.astype(np.int32) @ W_int8.T.astype(np.int32)
    # Dequantize: multiply by activation scale and per-channel weight scales
    # Result shape: [batch, out_channels]
    output = acc_int32.astype(np.float32) * x_scale * W_scales[np.newaxis, :]
    return output
# Demonstrate correctness
np.random.seed(42)
batch, in_features, out_features = 4, 256, 512
x_fp32 = np.random.randn(batch, in_features).astype(np.float32) * 0.5
W_fp32 = np.random.randn(out_features, in_features).astype(np.float32) * 0.02
# Reference FP32 output
ref = x_fp32 @ W_fp32.T
# Quantize
x_int8, x_scale = quantize_symmetric(x_fp32)
W_int8, W_scales = quantize_per_channel(W_fp32)
# Correct INT8 linear
out_int8 = int8_linear_correct(x_int8, x_scale, W_int8, W_scales)
# Check error
error = np.mean(np.abs(ref - out_int8)) / np.mean(np.abs(ref))
print(f"Relative error (INT8 W8A8): {error:.4f}")  # roughly 0.01 with these shapes and scales
PyTorch INT8 Linear with cuBLAS
import torch
import torch.nn as nn
class INT8Linear(nn.Module):
    """
    INT8 quantized linear layer built on the private torch._int_mm op.
    Uses W8A8: both weights and activations in INT8.
    """
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Store weights as INT8, scales as FP32 (symmetric, so no zero-points)
        self.register_buffer('weight_int8', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('weight_scales', torch.ones(out_features, dtype=torch.float32))
        self.register_buffer('bias', None)  # filled in by from_float if the source has a bias
    @classmethod
    @torch.no_grad()
    def from_float(cls, module: nn.Linear) -> 'INT8Linear':
        """Quantize a pretrained FP32/FP16 Linear layer to INT8."""
        layer = cls(module.in_features, module.out_features)
        W = module.weight.detach().float()  # work in FP32 for quantization math
        # Per-channel symmetric quantization: scale = max abs value per row / 127
        scales = W.abs().max(dim=1).values / 127.0
        scales = scales.clamp(min=1e-8)  # avoid div by zero
        # Quantize: round and clamp to the symmetric range [-127, 127]
        W_int8 = (W / scales.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
        layer.weight_int8.copy_(W_int8)
        layer.weight_scales.copy_(scales)
        if module.bias is not None:
            layer.bias = module.bias.detach().float().clone()
        return layer
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize activations (per-tensor, dynamic)
        x_scale = x.abs().max().float() / 127.0
        x_scale = x_scale.clamp(min=1e-8)
        x_int8 = (x / x_scale).round().clamp(-127, 127).to(torch.int8)
        # INT8 GEMM via torch._int_mm (true INT8 Tensor Core ops)
        # Input: [M, in_features] INT8
        # Weight.T: [in_features, out_features] INT8 (column-major view of the stored weight)
        # Output: [M, out_features] INT32
        batch_shape = x_int8.shape[:-1]
        x_2d = x_int8.reshape(-1, self.in_features)
        # torch._int_mm is a private CUDA op; shape limits vary by release
        # (many releases require M > 16 and K, N multiples of 8)
        out_int32 = torch._int_mm(x_2d, self.weight_int8.T)
        # Dequantize in epilogue: INT32 -> FP32, then activation scale x per-channel weight scales
        out = out_int32.float() * x_scale * self.weight_scales.unsqueeze(0)
        if self.bias is not None:
            out = out + self.bias.unsqueeze(0)
        return out.reshape(*batch_shape, self.out_features)
# Usage (32 rows, since torch._int_mm rejects M <= 16 on many releases)
linear_fp32 = nn.Linear(4096, 4096).cuda()
linear_int8 = INT8Linear.from_float(linear_fp32).cuda()
x = torch.randn(32, 4096).cuda()
with torch.no_grad():
    out_fp32 = linear_fp32(x)
    out_int8 = linear_int8(x)
rel_error = (out_fp32 - out_int8).abs().mean() / out_fp32.abs().mean()
print(f"Relative error: {rel_error:.4f}")
Triton Kernel: Fused Dequantize + GEMM
The key insight in production quantized kernels is that the dequantization must be fused with the GEMM. Here is a Triton kernel for the GPTQ-style INT4 dequantize-and-multiply pattern:
import triton
import triton.language as tl
import torch
@triton.jit
def dequant_gemm_kernel(
    # Input pointers
    X_ptr,  # [M, K] FP16 activations, row-major
    W_ptr,  # [K//8, N] packed INT4 weights (8 values per int32, along K)
    S_ptr,  # [K//group_size, N] FP16 scales per group
    Z_ptr,  # [K//group_size, N] INT8 zero-points per group
    # Output
    Y_ptr,  # [M, N] FP16 output
    # Dimensions
    M, N, K,
    group_size: tl.constexpr,
    # Tile sizes (fixed at launch here; wrap with @triton.autotune to tune them)
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """
    Fused INT4 dequantize + FP16 GEMM kernel (GPTQ-style weight-only quantization).
    Weights are stored as packed INT4 (8 values per INT32).
    Dequantization happens in registers during the matmul loop.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Tile offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Accumulator in FP32 for numerical stability
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # K-loop: iterate over K dimension in tiles
    for k_start in range(0, K, BLOCK_K):
        k_offs = k_start + offs_k
        # Load FP16 activations: [BLOCK_M, BLOCK_K]
        x_ptrs = X_ptr + offs_m[:, None] * K + k_offs[None, :]
        x_mask = (offs_m[:, None] < M) & (k_offs[None, :] < K)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float16)
        # Load packed weights as a [BLOCK_K, BLOCK_N] tile: each group of 8
        # consecutive k rows reads the same int32 word (low-to-high nibbles)
        w_packed_ptrs = W_ptr + (k_offs[:, None] // 8) * N + offs_n[None, :]
        w_mask = (k_offs[:, None] < K) & (offs_n[None, :] < N)
        w_packed = tl.load(w_packed_ptrs, mask=w_mask, other=0)
        # Unpack: row k selects nibble k % 8 of its int32 word
        nibble_idx = k_offs % 8
        w_int4 = (w_packed >> (nibble_idx[:, None] * 4)) & 0xF  # [BLOCK_K, BLOCK_N]
        # Load group scales and zero-points for this K tile: [BLOCK_K, BLOCK_N]
        group_idx = k_offs // group_size
        s_ptrs = S_ptr + group_idx[:, None] * N + offs_n[None, :]
        z_ptrs = Z_ptr + group_idx[:, None] * N + offs_n[None, :]
        s = tl.load(s_ptrs, mask=w_mask, other=1.0).to(tl.float16)
        z = tl.load(z_ptrs, mask=w_mask, other=0).to(tl.float16)
        # Dequantize in registers: (int4_value - zero_point) * scale
        w_fp16 = (w_int4.to(tl.float16) - z) * s
        # GEMM tile: x [BLOCK_M, BLOCK_K] @ w_fp16 [BLOCK_K, BLOCK_N]
        acc += tl.dot(x, w_fp16)
    # Store output
    y_ptrs = Y_ptr + offs_m[:, None] * N + offs_n[None, :]
    y_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(y_ptrs, acc.to(tl.float16), mask=y_mask)
def dequant_gemm(X: torch.Tensor, W_packed: torch.Tensor,
                 scales: torch.Tensor, zeros: torch.Tensor,
                 group_size: int = 128) -> torch.Tensor:
    """Launch the fused INT4 dequant + GEMM kernel."""
    assert X.dtype == torch.float16 and X.is_contiguous()
    assert W_packed.dtype == torch.int32 and W_packed.is_contiguous()
    M, K = X.shape
    N = W_packed.shape[1]
    assert W_packed.shape[0] * 8 == K and K % group_size == 0
    output = torch.empty(M, N, device=X.device, dtype=torch.float16)
    # Grid: one program per output tile
    grid = lambda meta: (
        triton.cdiv(M, meta['BLOCK_M']),
        triton.cdiv(N, meta['BLOCK_N']),
    )
    dequant_gemm_kernel[grid](
        X, W_packed, scales, zeros, output,
        M, N, K,
        group_size=group_size,
        BLOCK_M=16, BLOCK_N=64, BLOCK_K=32,  # tune these
    )
    return output
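A kernel like this is only as trustworthy as its packing convention, so it helps to have a CPU reference. The NumPy sketch below packs unsigned 4-bit codes into the `[K//8, N]` int32 layout the kernel reads: nibble `k % 8` of word `k // 8`, low to high. It also provides the matching unpack and a reference `Y = X @ ((q - z) * s)`. These helper names are hypothetical. The layout mirrors the kernel above and is assumed, not checked, to match any particular GPTQ checkpoint format.

```python
import numpy as np

def pack_int4_along_k(q: np.ndarray) -> np.ndarray:
    """Pack 4-bit codes q[K, N] (values 0..15) into int32 [K // 8, N].

    Nibble i (bits 4*i .. 4*i+3) of packed[r, n] holds q[8*r + i, n].
    Assumes K is a multiple of 8.
    """
    K, N = q.shape
    q32 = q.astype(np.uint32).reshape(K // 8, 8, N)
    packed = np.zeros((K // 8, N), dtype=np.uint32)
    for i in range(8):
        packed |= q32[:, i, :] << np.uint32(4 * i)
    return packed.view(np.int32)  # reinterpret bits; the kernel loads int32

def unpack_int4_along_k(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_int4_along_k, mirroring the kernel's shift-and-mask."""
    K8, N = packed.shape
    k = np.arange(K8 * 8)
    rows = packed[k // 8, :]                           # [K, N] int32
    shifts = ((k % 8) * 4)[:, None].astype(np.int32)
    return ((rows >> shifts) & 0xF).astype(np.uint8)   # arithmetic shift, then mask

def reference_dequant_gemm(X, packed, scales, zeros, group_size=128):
    """NumPy reference for dequant_gemm: Y = X @ ((q - z) * s), groups along K."""
    q = unpack_int4_along_k(packed).astype(np.float32)  # [K, N]
    g = np.arange(q.shape[0]) // group_size              # group index of each k row
    W = (q - zeros[g, :]) * scales[g, :]                 # [K, N]
    return X.astype(np.float32) @ W

rng = np.random.default_rng(0)
codes = rng.integers(0, 16, size=(256, 64), dtype=np.uint8)
packed = pack_int4_along_k(codes)
print(packed.shape, np.array_equal(unpack_int4_along_k(packed), codes))  # (32, 64) True
```

To check the Triton kernel, the same `packed`, `scales`, and `zeros` arrays can be moved to the GPU and compared against `reference_dequant_gemm` with a tolerance suited to FP16 output.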
Benchmark: FP16 vs INT8 vs Weights-Only INT4
import torch
def benchmark_precision(M: int, N: int, K: int, dtype: str, n_warmup: int = 10, n_iter: int = 100):
    """
    Benchmark different precision strategies for a single linear layer.
    Shows the practical speedup (or slowdown) from each approach.
    """
    device = 'cuda'
    if dtype == 'fp16':
        x = torch.randn(M, K, dtype=torch.float16, device=device)
        W = torch.randn(N, K, dtype=torch.float16, device=device)
        fn = lambda: torch.nn.functional.linear(x, W)
    elif dtype == 'bf16':
        x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
        W = torch.randn(N, K, dtype=torch.bfloat16, device=device)
        fn = lambda: torch.nn.functional.linear(x, W)
    elif dtype == 'int8_w8a8':
        x_int8 = torch.randint(-128, 128, (M, K), dtype=torch.int8, device=device)
        W_int8 = torch.randint(-128, 128, (N, K), dtype=torch.int8, device=device)
        # torch._int_mm takes 2D int8 tensors and returns INT32;
        # the cost of quantizing activations is not included here
        fn = lambda: torch._int_mm(x_int8, W_int8.T)
    elif dtype == 'int8_wrong':
        # The WRONG approach: materialize FP16 weights, then run an FP16 matmul
        x = torch.randn(M, K, dtype=torch.float16, device=device)
        W_int8 = torch.randint(-128, 128, (N, K), dtype=torch.int8, device=device)
        scales = torch.ones(N, dtype=torch.float16, device=device)
        # Dequantize inside fn so every call pays for it, as the broken pipeline does
        fn = lambda: torch.nn.functional.linear(x, W_int8.to(torch.float16) * scales.unsqueeze(1))
    else:
        raise ValueError(f"unknown dtype: {dtype}")
    # Warmup
    for _ in range(n_warmup):
        fn()
    torch.cuda.synchronize()
    # Benchmark with CUDA events
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn()
    end.record()
    torch.cuda.synchronize()
    elapsed_ms = start.elapsed_time(end) / n_iter
    return elapsed_ms
# Run benchmark for a 4096x4096 weight (LLM-sized)
M, N, K = 64, 4096, 4096  # 64 tokens per batch, typical inference
results = {}
for dtype in ['fp16', 'bf16', 'int8_w8a8', 'int8_wrong']:
    results[dtype] = benchmark_precision(M, N, K, dtype)
    print(f"{dtype:20s}: {results[dtype]:.3f} ms")
fp16_time = results['fp16']
for dtype, t in results.items():
    print(f"{dtype:20s}: {fp16_time/t:.2f}x vs FP16")
# Illustrative output on A100 (absolute times vary with GPU SKU, clocks, and library versions):
# fp16 : 0.187 ms
# bf16 : 0.189 ms (1.01x vs FP16)
# int8_w8a8 : 0.091 ms (2.06x vs FP16) <- correct INT8
# int8_wrong : 0.211 ms (0.89x vs FP16) <- SLOWER than FP16
The bitsandbytes Library: LLM.int8() and QLoRA
The bitsandbytes library by Tim Dettmers is the most widely used quantization library for LLMs. Understanding what it actually does helps you debug production deployments.
LLM.int8() - Outlier Decomposition
Large transformer models (>6B parameters) exhibit a phenomenon called activation outliers: a small fraction of activation channels carry values far larger (often by more than an order of magnitude) than the rest. These outliers break naive INT8 quantization. The scale must grow to accommodate the outlier, so all other values collapse onto a handful of quantization levels (many round to zero) and lose most of their precision.
LLM.int8()'s solution:
- On each forward pass, identify columns of X (activation channels) where any value exceeds a threshold (typically 6.0)
- Extract those outlier columns and compute them in FP16
- Quantize the remaining normal columns to INT8 and compute in INT8
- Sum the two partial results
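import torch
import torch.nn.functional as F
# Conceptual LLM.int8() forward pass (simplified; bitsandbytes uses its own fused kernels)
def llm_int8_linear(X: torch.Tensor, W: torch.Tensor, threshold: float = 6.0):
    """
    Decompose linear into: INT8 regular + high-precision outlier parts.
    X: [tokens, in_features] FP16 on CUDA; W: [out_features, in_features] FP16.
    """
    # Find outlier columns (any row value exceeds threshold)
    col_max = X.abs().max(dim=0).values
    outlier_mask = col_max > threshold  # [in_features] bool mask
    # Split into outlier and non-outlier columns
    X_outlier = X[:, outlier_mask].float()  # high-precision path (FP32 here for simplicity)
    X_normal = X[:, ~outlier_mask]          # INT8 path
    W_outlier = W[:, outlier_mask].float()  # corresponding weight columns
    W_normal = W[:, ~outlier_mask]
    # High-precision computation for the few outlier columns
    out_outlier = X_outlier @ W_outlier.T
    # Vector-wise INT8 quantization: one scale per row of X and per row of W
    x_scale = (X_normal.abs().max(dim=1, keepdim=True).values.float() / 127.0).clamp(min=1e-8)
    x_int8 = (X_normal.float() / x_scale).round().clamp(-127, 127).to(torch.int8)
    w_scales = (W_normal.abs().max(dim=1).values.float() / 127.0).clamp(min=1e-8)
    w_int8 = (W_normal.float() / w_scales.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
    # Removing outlier columns can leave K off a multiple of 8, which torch._int_mm rejects:
    # zero-pad the inner dimension (zeros add nothing to the dot products)
    pad = (-x_int8.shape[1]) % 8
    if pad:
        x_int8 = F.pad(x_int8, (0, pad))
        w_int8 = F.pad(w_int8, (0, pad))
    out_int32 = torch._int_mm(x_int8, w_int8.T)  # many releases also require M > 16
    out_int8 = out_int32.float() * x_scale * w_scales  # [M, 1] and [N] broadcast
    return out_outlier + out_int8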
QLoRA Double Quantization
QLoRA (Dettmers et al., 2023) introduces NF4 (NormalFloat4), a 4-bit format designed for normally distributed weights. The insight: if you know weights are approximately Gaussian, you can design quantization levels to minimize expected quantization error.
QLoRA also introduces "double quantization" - quantizing the quantization constants themselves:
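- Quantize weights to NF4 in blocks of 64 values, each block with its own FP32 absmax constant: the constants alone add 32/64 = 0.5 bits per parameter
- Quantize those FP32 constants to 8-bit floats in blocks of 256, each block with one FP32 second-level constant
- Net result: 8/64 + 32/(64 · 256) ≈ 0.127 bits of constant overhead per parameter, about 4.13 bits per parameter in total instead of 4 + 32/group_size (4.5 for group_size = 64)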
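import torch
# Illustrative NF4 code values (taken from the bitsandbytes NF4 table; verify against your version)
NF4_LEVELS = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
])
def quantize_to_nf4(W: torch.Tensor, absmax: torch.Tensor, group_size: int) -> torch.Tensor:
    """Map each weight to the index of the nearest NF4 level after absmax normalization.
    Returns one 4-bit code per uint8 (left unpacked for clarity)."""
    N, K = W.shape
    Wn = W.reshape(N, K // group_size, group_size) / absmax.unsqueeze(-1)  # now in [-1, 1]
    codes = (Wn.unsqueeze(-1) - NF4_LEVELS.to(W.device)).abs().argmin(dim=-1)
    return codes.to(torch.uint8).reshape(N, K)
# Double quantization conceptual implementation
def double_quantize(W: torch.Tensor, group_size: int = 64):
    """
    Step 1: Quantize W to NF4 with per-group FP32 absmax scales (first quantization)
    Step 2: Quantize those FP32 scales to 8 bits with a single FP32 meta-scale
    """
    N, K = W.shape
    num_groups = K // group_size
    # First quantization: W -> NF4 with per-group FP32 absmax scales
    scales_fp32 = W.reshape(N, num_groups, group_size).abs().amax(dim=-1).clamp(min=1e-12)
    W_nf4 = quantize_to_nf4(W, scales_fp32, group_size)
    # Second quantization (simplified): QLoRA quantizes the scales to 8-bit floats
    # in blocks of 256; this sketch uses symmetric INT8 with one meta-scale
    meta_scale = scales_fp32.max() / 127.0
    scales_q = (scales_fp32 / meta_scale).round().clamp(-127, 127).to(torch.int8)
    # Storage (QLoRA layout): 4 bits/param + 8 bits per 64 params + 32 bits per 64*256 params
    # Effective: 4 + 8/64 + 32/16384 = 4 + 0.125 + 0.002 ≈ 4.127 bits/param
    return W_nf4, scales_q, meta_scale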
Precision Selection Flowchart
Use this when deciding which dtype to use for a new deployment:
The FP8 Transformer Engine: How It Actually Works
NVIDIA's Transformer Engine (TE) on Hopper automates FP8 scaling. Understanding the mechanism is important for debugging.
Delayed Scaling
The core idea: you cannot know the max absolute value of the current iteration's activations before computing them. TE solves this with delayed scaling:
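- Keep a history of recent max absolute values (the "amax history") per tensor; the history length is configurable in the recipe
- Compute the scale for iteration T+1 from the history recorded up to iteration T (for example, the max over the window)
- If amax grows suddenly, the scale is wrong for one iteration: values beyond the stale range saturate until the history catches up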
The scale computation:
$$\text{scale} = \frac{\text{max\_fp8\_value}}{\text{amax}}$$
where max_fp8_value for E4M3 is 448.0.
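Here is a NumPy sketch of the delayed-scaling loop. `DelayedScaler` is a hypothetical class, not Transformer Engine's API. It uses a 16-entry history and max-over-window as illustrative choices, and it omits the actual FP8 cast, modeling only the saturation at ±448. The run shows the one-iteration lag: the activation spike at step 4 saturates under the stale scale, and step 5 is fine again.

```python
import numpy as np

E4M3_MAX = 448.0

class DelayedScaler:
    """Delayed scaling sketch: the scale used at step t comes only from amaxes seen before t."""

    def __init__(self, history_len: int = 16):
        self.history_len = history_len
        self.amax_history = []

    def current_scale(self) -> float:
        if not self.amax_history:
            return 1.0  # no statistics yet
        return E4M3_MAX / max(max(self.amax_history), 1e-12)

    def step(self, x: np.ndarray):
        """Scale x with the previous steps' statistics, then record this step's amax."""
        scale = self.current_scale()
        x_scaled = x * scale
        saturated = float(np.mean(np.abs(x_scaled) > E4M3_MAX))  # fraction that would clip
        x_scaled = np.clip(x_scaled, -E4M3_MAX, E4M3_MAX)        # FP8 cast omitted
        self.amax_history = (self.amax_history + [float(np.abs(x).max())])[-self.history_len:]
        return x_scaled, scale, saturated

rng = np.random.default_rng(0)
scaler = DelayedScaler()
log = []
for t in range(6):
    gain = 10.0 if t == 4 else 1.0  # activation spike at step 4
    _, scale, saturated = scaler.step(rng.standard_normal(4096) * gain)
    log.append((scale, saturated))
    print(f"t={t}  scale={scale:8.2f}  saturated={saturated:.2%}")
```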
# Transformer Engine usage
import torch
import transformer_engine.pytorch as te
# Replace nn.Linear with te.Linear - automatic FP8 management
model = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16)
input_tensor = torch.randn(32, 4096, device='cuda', dtype=torch.bfloat16)
# Enable FP8 during forward pass (requires an FP8-capable GPU such as H100)
with te.fp8_autocast(enabled=True):
    output = model(input_tensor)
# TE automatically:
# 1. Casts the input to FP8 (E4M3 in the forward pass) using the current scale
# 2. Runs the GEMM on FP8 Tensor Cores (H100 peak: 1979 TFLOPS dense)
# 3. Returns output in higher precision (BF16 here)
# 4. Updates the amax history used to compute the next scale
Manual FP8 Scaling
For custom kernels, you need to manage scaling yourself:
import torch
def fp8_e4m3_quantize(x: torch.Tensor) -> tuple:
    """Manually quantize to FP8 E4M3 with per-tensor scaling."""
    MAX_FP8_E4M3 = 448.0
    # Compute scale from current tensor statistics (current scaling, not delayed)
    amax = x.abs().max().float()
    scale = MAX_FP8_E4M3 / (amax + 1e-12)
    scale_inv = 1.0 / scale  # float32 scalar tensor on x's device
    # Scale and convert to FP8
    x_scaled = (x.float() * scale).clamp(-MAX_FP8_E4M3, MAX_FP8_E4M3)
    # Note: torch.float8_e4m3fn available in PyTorch >= 2.1
    x_fp8 = x_scaled.to(torch.float8_e4m3fn)
    return x_fp8, scale_inv  # scale_inv needed for dequantization
def fp8_gemm(X_fp8: torch.Tensor, W_fp8: torch.Tensor,
             x_scale_inv: torch.Tensor, w_scale_inv: torch.Tensor) -> torch.Tensor:
    """
    FP8 GEMM with scale correction: returns X @ W.T in BF16.
    X_fp8: [M, K] and W_fp8: [N, K], both row-major FP8 E4M3.
    On H100, this uses FP8 Tensor Cores.
    """
    # torch._scaled_mm is a private op for FP8 GEMM. It expects the second operand
    # column-major (W_fp8.T provides that) and float32 scale tensors on the same device.
    output = torch._scaled_mm(
        X_fp8, W_fp8.T,
        scale_a=x_scale_inv,
        scale_b=w_scale_inv,
        out_dtype=torch.bfloat16,
    )
    # Older PyTorch releases returned (output, amax) instead of a single tensor
    if isinstance(output, tuple):
        output = output[0]
    return output
Production Engineering Notes
Quantization Calibration Dataset
Static quantization requires calibration data to compute activation scales. The choice of calibration data matters more than most practitioners realize.
Use a representative sample of your actual workload, not the training set. For LLMs, a few hundred examples from your production distribution are usually sufficient. Using only one type of input (e.g., all short questions) will set scales too narrow for long documents.
import torch
from torch.ao.quantization import prepare, convert, get_default_qconfig
def calibrate_int8_model(model, calibration_loader, num_batches=100):
    """
    Run calibration to compute activation statistics for static INT8 quantization.
    Uses PyTorch's eager-mode quantization observers. Eager mode assumes the model
    wraps its inputs and outputs in QuantStub/DeQuantStub, and the converted model
    runs on CPU quantized backends, not on CUDA.
    """
    # Set up observers on all supported layers
    model.qconfig = get_default_qconfig('x86')  # 'x86' or 'fbgemm' for x86 servers, 'qnnpack' for ARM
    model.eval()
    prepare(model, inplace=True)
    # Run calibration batches
    with torch.no_grad():
        for i, (inputs, _) in enumerate(calibration_loader):
            if i >= num_batches:
                break
            model(inputs)
    # Convert to INT8 using observed statistics
    convert(model, inplace=True)
    return model
Evaluating Quantization Quality
Always measure accuracy degradation, not just speed:
import torch
def evaluate_quantization(fp16_model, int8_model, test_loader):
    """Compare FP16 and INT8 model outputs on a held-out test set."""
    max_rel_error = 0.0
    avg_rel_error = 0.0
    n_batches = 0
    fp16_model.eval()
    int8_model.eval()
    with torch.no_grad():
        for x, _ in test_loader:
            x = x.cuda()
            out_fp16 = fp16_model(x.half()).float()
            out_int8 = int8_model(x.half()).float()
            # Relative error per sample (norm over the output dimensions). A per-element
            # ratio would blow up wherever the reference output is close to zero.
            diff = (out_fp16 - out_int8).flatten(1).norm(dim=1)
            ref = out_fp16.flatten(1).norm(dim=1).clamp(min=1e-6)
            rel_err = diff / ref
            max_rel_error = max(max_rel_error, rel_err.max().item())
            avg_rel_error += rel_err.mean().item()
            n_batches += 1
    avg_rel_error /= n_batches
    print(f"Max relative error: {max_rel_error:.4f}")
    print(f"Avg relative error: {avg_rel_error:.4f}")
    # Rules of thumb for production: avg relative error < 0.01 (1%),
    # max relative error < 0.05 (5%) is usually acceptable; confirm with task metrics
    return max_rel_error, avg_rel_error
Common Mistakes
:::danger Dequantize Before GEMM Never materialize dequantized weights in memory before the matrix multiply. If you convert INT8 weights to a full FP16 tensor and then call the GEMM, you lose the INT8 compute win entirely. You may also lose the bandwidth win, because the GEMM reads FP16 weights anyway. Either dequantize inside the kernel (in registers for weight-only schemes, in the GEMM epilogue for W8A8), or use a library function (cublasGemmEx, torch._int_mm) that performs the operation in quantized format natively. :::
:::danger Mixed Dtypes in the Critical Path
The slowest dtype in your pipeline determines your speed. If activations are FP32 and weights are INT8, your GEMM must upcast to FP32 before computation. Always verify that every tensor in the matmul critical path is the dtype you intend. Use tensor.dtype checks in debug builds.
:::
:::warning Per-Tensor vs Per-Channel Scale Choice Using per-tensor quantization for weights when per-channel is available is leaving accuracy on the table. Weight matrices in transformers have large per-channel variance - different output neurons have drastically different weight magnitudes. Per-tensor scale forces a compromise that degrades quality for many channels. Default to per-channel for weights unless your target hardware cannot handle it. :::
:::warning Activation Outliers in Large Models
LLMs with more than about 6B parameters typically have systematic activation outliers. Naive INT8 quantization without outlier handling will produce noticeably degraded outputs. Use SmoothQuant, LLM.int8() with outlier decomposition, or calibrate with a large enough dataset to detect outlier channels and apply per-channel scaling.
:::
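SmoothQuant's core idea fits in a few lines: for Y = XW, divide each activation channel j by a factor s_j and multiply the matching row of W by s_j. The product is unchanged, but the outlier magnitude moves into the weights, which are easier to quantize. The sketch follows the SmoothQuant smoothing rule s_j = max|X_j|^alpha / max|W_j|^(1-alpha), with alpha = 0.5 as an assumed starting point; the data and the outlier channel are synthetic.

```python
import numpy as np


def smooth(x: np.ndarray, w: np.ndarray, alpha: float = 0.5):
    """SmoothQuant-style smoothing for Y = X @ W, with X: [tokens, in], W: [in, out].

    Per input channel j: s_j = max|X[:, j]|**alpha / max|W[j, :]|**(1 - alpha).
    X's columns are divided by s and W's rows multiplied by s, so X @ W is
    unchanged while activation outliers shrink.
    """
    act_max = np.abs(x).max(axis=0)
    w_max = np.maximum(np.abs(w).max(axis=1), 1e-8)
    s = np.maximum(act_max**alpha / w_max ** (1 - alpha), 1e-8)
    return x / s, w * s[:, None], s


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 128)).astype(np.float32)
x[:, 7] *= 60.0  # synthetic outlier channel
w = (rng.standard_normal((128, 32)) * 0.05).astype(np.float32)

x_s, w_s, s = smooth(x, w)
print("max |X| before:", float(np.abs(x).max()), "after:", float(np.abs(x_s).max()))
print("product preserved:", np.allclose(x @ w, x_s @ w_s, rtol=1e-4, atol=1e-3))
```

In a deployed model the division by s is usually folded offline into the preceding LayerNorm weights, so the smoothing adds no runtime work.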
:::warning Forgetting to Quantize the KV Cache
In transformer inference, the KV cache can be as large as the model weights for long sequences and large batches. INT8 or FP8 KV cache quantization is a separate step from weight quantization. Many engineers quantize weights but forget the KV cache, leaving significant memory savings unrealized for long-context inference.
:::
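A quick size estimate shows why the KV cache deserves its own quantization step. The config below is a hypothetical 7B-class model without grouped-query attention (32 layers, 32 KV heads, head dimension 128), and the estimate ignores the small overhead of storing scale factors for a quantized cache.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem):
    """Bytes for the K and V caches across all layers (factor 2: one K, one V)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical 7B-class config without grouped-query attention (assumption).
cfg = dict(n_layers=32, n_kv_heads=32, head_dim=128)
GiB = 2**30
for batch in (1, 8):
    fp16 = kv_cache_bytes(**cfg, seq_len=4096, batch=batch, bytes_per_elem=2)
    int8 = kv_cache_bytes(**cfg, seq_len=4096, batch=batch, bytes_per_elem=1)
    print(f"batch={batch}: FP16 KV {fp16 / GiB:.0f} GiB, 8-bit KV {int8 / GiB:.0f} GiB")
```

Under these assumptions a batch of 8 sequences at 4K context needs 16 GiB of FP16 cache, more than the roughly 13 GiB of FP16 weights for 7B parameters; an 8-bit cache halves that.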
Interview Questions and Answers
Q1: A colleague proposes quantizing model weights to INT8 to get 2x speedup. What questions do you ask before agreeing?
You need to understand which half of the INT8 benefit they are expecting. Ask: (1) What is the batch size in production? At small batch sizes (1-16), inference is memory-bandwidth bound, so weight-only quantization (W8A16 or W4A16) gives the bandwidth win without the complexity of full INT8 GEMM. At large batch sizes, compute becomes the bottleneck and you need W8A8 to get the throughput win. (2) How are they planning to handle activation quantization? W8A8 requires calibration data and careful outlier handling. (3) What GPU hardware are they targeting? On H100, Transformer Engine often makes FP8 a better option than INT8 for new deployments. (4) What accuracy degradation is acceptable? INT8 static quantization typically causes 0.5-1% task metric degradation and INT4 often 1-3%, but these figures vary by model and task and must be measured, not assumed.
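The batch-size question can be answered with a rough roofline check: compare the GEMM's arithmetic intensity (FLOPs per byte of HBM traffic) with the GPU's ratio of peak compute to memory bandwidth. This sketch assumes ideal caching (each operand moves once), an 8192 x 8192 layer, and approximate A100 figures (312 TFLOPS dense FP16, about 2.0 TB/s for the 80 GB SXM part).

```python
def gemm_arithmetic_intensity(batch, k, n, w_bytes, a_bytes=2, out_bytes=2):
    """FLOPs per HBM byte for Y[batch, n] = X[batch, k] @ W[k, n], assuming each
    operand is read or written exactly once (ideal caching)."""
    flops = 2 * batch * k * n
    bytes_moved = k * n * w_bytes + batch * k * a_bytes + batch * n * out_bytes
    return flops / bytes_moved


# Approximate A100 figures (assumption): dense FP16 Tensor Core peak and the
# ~2.0 TB/s HBM bandwidth of the 80 GB SXM part.
PEAK_FP16_FLOPS = 312e12
HBM_BYTES_PER_S = 2.0e12
RIDGE = PEAK_FP16_FLOPS / HBM_BYTES_PER_S  # FLOP/byte where the two limits meet

for batch in (1, 16, 256):
    ai = gemm_arithmetic_intensity(batch, 8192, 8192, w_bytes=2)
    regime = "memory-bound" if ai < RIDGE else "compute-bound"
    print(f"batch={batch:>3}: {ai:6.1f} FLOP/byte -> {regime}")
```

With FP16 weights the intensity is roughly equal to the batch size, so batch 1-16 sits far below the ridge point. Storing weights in INT8 (w_bytes=1) roughly doubles the intensity at small batch, which is the bandwidth win in roofline terms. Achieved bandwidth and FLOP rates fall short of peak, so treat the crossover point as approximate.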
Q2: Explain when quantization saves bandwidth vs compute, and why these are distinct wins.
Bandwidth win: quantized weights are smaller on disk and in HBM. Loading a 4096x4096 INT8 weight matrix takes 16 MB vs 32 MB for FP16 - half the HBM bandwidth. This win applies whether you run an INT8 GEMM or dequantize to FP16, as long as the dequantization happens inside the GEMM kernel rather than in a separate pass that writes FP16 weights back to HBM. For memory-bandwidth-bound workloads (small batch inference), this is often the dominant benefit.
Compute win: INT8 Tensor Core GEMM has 2x the throughput of FP16 GEMM on A100 (624 TOPS vs 312 TFLOPS). But you only get this win if the GEMM itself runs in INT8 - both inputs must be INT8 and you must use a kernel that issues INT8 tensor core instructions (cublasGemmEx with CUDA_R_8I, or torch._int_mm). Dequantizing before the GEMM eliminates this win entirely. For compute-bound workloads (large batch inference), the compute win is what matters.
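The arithmetic of a W8A8 GEMM can be emulated on the CPU to show where each step happens: INT8 inputs, an INT32 accumulator, and a floating-point rescale in the epilogue. This is not the Tensor Core path (on a GPU that would be cublasGemmEx with INT8 inputs or torch._int_mm); it is a NumPy sketch assuming symmetric quantization with dynamic per-token activation scales and per-output-channel weight scales. The function names are hypothetical.

```python
import numpy as np


def quantize_rows(a: np.ndarray):
    """Symmetric INT8 with one scale per row."""
    scale = np.maximum(np.abs(a).max(axis=1, keepdims=True), 1e-12) / 127.0
    q = np.clip(np.round(a / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)


def w8a8_linear(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Emulate a W8A8 GEMM for Y = X @ W.T, with X: [tokens, in], W: [out, in].

    Mainloop: INT8 x INT8 products accumulated in INT32.
    Epilogue: rescale the INT32 accumulator by per-token x per-channel scales.
    """
    qx, sx = quantize_rows(x)  # per-token activation scales, [tokens, 1]
    qw, sw = quantize_rows(w)  # per-output-channel weight scales, [out, 1]
    acc = qx.astype(np.int32) @ qw.astype(np.int32).T  # INT32 accumulator
    return acc.astype(np.float32) * sx * sw.T          # epilogue dequantization


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 256)).astype(np.float32)
w = (rng.standard_normal((64, 256)) * 0.02).astype(np.float32)
y_ref = x @ w.T
y_q = w8a8_linear(x, w)
rel = float(np.abs(y_q - y_ref).max() / np.abs(y_ref).max())
print(f"max error relative to max |Y|: {rel:.4f}")
```

Because the scales are per row of X and per row of W, they factor out of the sum over the inner dimension, which is why the rescale can wait until the epilogue. Static calibrated activation scales would replace the quantize_rows(x) call.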
Q3: Why does dequantize-before-GEMM fail to capture the compute speedup?
The compute speedup from INT8 comes from Tensor Cores executing integer multiply-accumulate operations at 2x the rate of FP16 MACs. This requires the input data to be in INT8 format when it enters the Tensor Core. If you convert INT8 weights to FP16 before the GEMM, the Tensor Core sees FP16 inputs and executes FP16 operations; the INT8 hardware path is never engaged. The only benefit left is the bandwidth win on weight loading, and it survives only if the conversion happens in registers inside the GEMM kernel (you read INT8 from HBM, half the bytes, and expand to FP16 just before the multiply). If instead a separate kernel writes an FP16 copy of the weights to HBM, the GEMM re-reads FP16 weights and total traffic exceeds the pure FP16 baseline. Even in-register conversion adds instructions and register pressure to the inner loop; if it is not vectorized and overlapped with memory loads, it can eat the bandwidth win, sometimes making the INT8 approach slower than pure FP16.
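The dataflow of a correctly fused weight-only (W8A16) GEMM can be sketched the same way. In this NumPy emulation, each K-tile of INT8 weights is converted to floating point just before it is used, standing in for the in-register conversion of a real kernel, so no full-size dequantized weight tensor is ever built. The function name and block_k are hypothetical, and per-output-channel scales are assumed, which lets the scale multiply move to the epilogue because it does not vary along K.

```python
import numpy as np


def w8a16_tiled(x: np.ndarray, qw: np.ndarray, scale: np.ndarray, block_k: int = 64):
    """Emulate a fused weight-only GEMM: Y = X @ dequant(QW).T.

    x: [tokens, in] activations; qw: [out, in] INT8 weights;
    scale: [out, 1] per-output-channel scales.
    """
    tokens, k = x.shape
    acc = np.zeros((tokens, qw.shape[0]), dtype=np.float32)  # float accumulator
    for k0 in range(0, k, block_k):
        # Convert only this K-tile (a real kernel does this in registers, e.g. to FP16).
        w_tile = qw[:, k0:k0 + block_k].astype(np.float32)
        acc += x[:, k0:k0 + block_k] @ w_tile.T
    # Per-channel scales are constant along K, so apply them once (epilogue).
    return acc * scale.T


rng = np.random.default_rng(0)
w = rng.standard_normal((32, 256)).astype(np.float32)
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
qw = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
x = rng.standard_normal((4, 256)).astype(np.float32)

y_fused = w8a16_tiled(x, qw, scale)
y_ref = x @ (qw.astype(np.float32) * scale).T  # dequantize-first reference
print(np.allclose(y_fused, y_ref, rtol=1e-4, atol=1e-3))
```

With per-group scales, the scale would instead be applied to each tile inside the loop, since it changes along K.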
Q4: What is the Transformer Engine's delayed scaling approach and why is it necessary?
FP8 E4M3 can only represent magnitudes up to 448 (a range of -448 to 448). For a tensor with a max absolute value of 1000, you would need to multiply by a scale of 448/1000 = 0.448 before quantizing to prevent overflow. But you cannot know the max absolute value of the current iteration's activations before computing them - it is a chicken-and-egg problem.
Delayed scaling solves this by using statistics from previous iterations. Transformer Engine maintains an "amax history" - a rolling window of the maximum absolute values seen in each tensor over recent iterations (for example, the last 16; the window length is set by the amax_history_len option of its DelayedScaling recipe). The scale for iteration T+1 is computed from the maximum over this history: scale = 448.0 / max(amax_history). This works well when tensors have stable statistics. The main failure mode is a sudden spike: if iteration T+1 has an amax 100x larger than anything in the history, you will overflow for one iteration and then the scale will adjust. TE handles this with NaN detection and scale rollback.
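A simplified model of delayed scaling makes the spike behavior concrete. This is not Transformer Engine's implementation: DelayedScaler is a hypothetical class, the seed amax stands in for statistics from an earlier calibration pass, and overflow is only reported, not handled.

```python
from collections import deque

FP8_E4M3_MAX = 448.0


class DelayedScaler:
    """Simplified delayed scaling; not Transformer Engine's implementation.

    The scale used at a step comes only from amax values of earlier steps;
    the current step's amax is recorded after the step runs.
    """

    def __init__(self, history_len: int = 16, initial_amax: float = FP8_E4M3_MAX):
        # Seeding with an initial amax (e.g. from a calibration pass) avoids an
        # empty history; the default seed gives a starting scale of 1.0.
        self.history = deque([initial_amax], maxlen=history_len)

    def scale(self) -> float:
        return FP8_E4M3_MAX / max(self.history)

    def step(self, amax: float) -> bool:
        """Process one step whose tensor has this amax; return True on overflow."""
        overflow = amax * self.scale() > FP8_E4M3_MAX
        self.history.append(amax)
        return overflow


scaler = DelayedScaler(history_len=16, initial_amax=1000.0)
for it, amax in enumerate([800.0, 900.0, 100_000.0, 900.0]):
    s = scaler.scale()
    overflowed = scaler.step(amax)
    print(f"iter {it}: scale={s:.5f} amax={amax:>9.0f} overflow={overflowed}")
```

Note the flip side of a max-over-history rule: after the spike, the scale stays small (and resolution stays coarse) until the spike ages out of the window.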
Q5: Explain per-group quantization for INT4. Why does it improve accuracy over per-channel, and what is the engineering tradeoff?
Per-channel quantization assigns one scale factor per output channel (row of the weight matrix). For a 4096x4096 weight matrix, that is 4096 scale factors. The problem: within a single row, weights may still vary widely across groups of columns - weights corresponding to different feature clusters may have very different magnitudes. Because the row's single scale is set by its largest-magnitude weights, groups of small weights use only a few of the available quantization levels, so their relative error is large.
Per-group quantization assigns one scale per group of G consecutive elements within each row (e.g., G=128). For a 4096x4096 matrix with G=128, you have 4096 * (4096/128) = 131,072 scale factors. Each group of 128 weights gets its own scale, dramatically reducing the maximum quantization error per group.
Engineering tradeoff: (1) Scale storage cost: 131,072 FP16 values = 256 KB vs 8 KB for per-channel. That is small relative to the 8 MB INT4 weight matrix (4096 x 4096 x 4 bits), an overhead of about 3%. (2) Dequantization cost: in the kernel, each thread must load the correct group scale for its element, requiring an additional indexed memory access per BLOCK_K tile (more than one if BLOCK_K exceeds G). (3) Accuracy benefit: GPTQ INT4 per-group (G=128) typically achieves < 1% perplexity degradation on LLMs vs 3-5% for per-tensor INT4. The accuracy gain almost always justifies the modest overhead.
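Both the accuracy effect and the storage arithmetic above can be checked with a short sketch. It assumes symmetric INT4 with levels -7..7 (held unpacked in int8 containers for simplicity) and a synthetic weight matrix in which one 128-column group is much larger than the rest of each row; quantize_int4_grouped is a hypothetical helper.

```python
import numpy as np


def quantize_int4_grouped(w: np.ndarray, group_size: int):
    """Symmetric INT4 (levels -7..7), one scale per `group_size` consecutive
    elements of each row. group_size == w.shape[1] reduces to per-channel."""
    out_f, in_f = w.shape
    if in_f % group_size:
        raise ValueError("in_features must be divisible by group_size")
    g = w.reshape(out_f, in_f // group_size, group_size)
    scale = np.maximum(np.abs(g).max(axis=2, keepdims=True), 1e-12) / 7.0
    q = np.clip(np.round(g / scale), -7, 7).astype(np.int8)  # unpacked INT4 values
    w_hat = (q * scale).reshape(out_f, in_f)                 # dequantized weights
    return q.reshape(out_f, in_f), scale.squeeze(-1), w_hat


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 1024)).astype(np.float32)
w[:, :128] *= 20.0  # one large-magnitude column group in every row (synthetic)

for name, gs in (("per-channel", 1024), ("per-group G=128", 128)):
    _, scales, w_hat = quantize_int4_grouped(w, gs)
    mse = float(np.mean((w - w_hat) ** 2))
    print(f"{name:>15}: {scales.size:4d} scales, MSE {mse:.4f}")

# Storage for the 4096 x 4096 example in the text (FP16 scales, packed INT4).
n_scales = 4096 * (4096 // 128)
scale_kib = n_scales * 2 / 2**10
weight_mib = 4096 * 4096 // 2 / 2**20
print(f"{n_scales} scales = {scale_kib:.0f} KiB vs {weight_mib:.0f} MiB of INT4 weights")
```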
Q6: What are activation outliers in LLMs and how does LLM.int8() handle them?
Activation outliers are a phenomenon characterized by Dettmers et al. in which a small number of hidden-state feature dimensions (roughly 0.1% of all channels) consistently have values 10-100x larger than the typical channel. These outliers appear in the same channels across most tokens and inputs - they are a systematic property of the trained model, not random noise.
The problem for INT8 quantization: with a per-tensor scale, the scale factor must accommodate the maximum value in the tensor. If most values are in [-1, 1] but 0.1% of channels reach 100, the scale is set to 100/127 = 0.787. Values of 0.01 then quantize to round(0.01/0.787) = 0, losing all information; in fact every value smaller than about 0.39 in magnitude (half a quantization step) rounds to 0. The model output degrades severely.
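LLM.int8() decomposes the linear operation: identify the feature dimensions (columns of the activation matrix) that contain outliers, extract them together with the matching rows of the weight matrix, and compute that slice in FP16 at full precision. Compute all remaining dimensions in INT8, then sum the two results. The key insight is that even though outlier channels are individually large, there are so few of them (0.1%) that computing them in FP16 adds very little compute. The INT8 path handles 99.9% of the multiply-accumulates, while the FP16 path handles outliers without quality loss. The decomposition itself (detecting and gathering outlier columns, plus a second small GEMM) does add overhead, so the net speedup depends on model size and kernel quality.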
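The decomposition can be sketched in NumPy. The threshold of 6.0 mirrors the one used in the LLM.int8() paper; everything else is an assumption for illustration: synthetic data with two outlier feature dimensions, vector-wise absmax INT8 (per-row scales for X, per-column scales for W, with W stored as [in, out]), float32 standing in for the FP16 path, and hypothetical function names.

```python
import numpy as np


def absmax_int8(a: np.ndarray, axis: int):
    """Symmetric INT8 with one scale per slice along `axis` (vector-wise)."""
    scale = np.maximum(np.abs(a).max(axis=axis, keepdims=True), 1e-12) / 127.0
    return np.clip(np.round(a / scale), -127, 127).astype(np.int8), scale


def int8_matmul(x, w):
    """X: [tokens, k], W: [k, out]; per-row scales for X, per-column for W."""
    qx, sx = absmax_int8(x, axis=1)  # [tokens, 1]
    qw, sw = absmax_int8(w, axis=0)  # [1, out]
    acc = qx.astype(np.int32) @ qw.astype(np.int32)
    return acc * sx * sw


def llm_int8_style_matmul(x, w, threshold=6.0):
    """Outlier feature dimensions of X (and matching rows of W) take a
    high-precision matmul; all other dimensions take the INT8 path."""
    outlier = (np.abs(x) > threshold).any(axis=0)  # [k] boolean mask
    if outlier.all():
        return x @ w
    y_hi = x[:, outlier] @ w[outlier, :]                # FP16 path on a GPU
    y_lo = int8_matmul(x[:, ~outlier], w[~outlier, :])  # INT8 path
    return y_hi + y_lo


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 512)).astype(np.float32)
x[:, [3, 100]] *= 50.0  # two synthetic outlier feature dimensions
w = (rng.standard_normal((512, 64)) * 0.02).astype(np.float32)
y_ref = x @ w


def rel_err(y):
    return float(np.abs(y - y_ref).max() / np.abs(y_ref).max())


print(f"plain INT8:         {rel_err(int8_matmul(x, w)):.4f}")
print(f"with decomposition: {rel_err(llm_int8_style_matmul(x, w)):.4f}")
```

On data like this, the plain INT8 path loses most of the resolution of the normal features to the outlier-driven row scales, while the decomposed version stays close to the FP32 reference.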
Summary
Mixed precision and quantization kernels are the foundation of efficient LLM inference. The core insight is that quantization only pays off when every tensor in the critical path is in the dtype you intend - a single unintended FP32 operation can negate all gains.
The hierarchy of approaches, from simplest to most complex:
- BF16 training: standard, no calibration needed, nearly free accuracy
- FP16 inference: standard for most deployments
- Weight-only quantization (W8A16 with INT8 weights, W4A16 with INT4): memory bandwidth win, works for small batch inference
- W8A8 INT8: both memory and compute win, requires calibration and outlier handling
- FP8 on H100: state of the art, managed by Transformer Engine
Master the distinction between bandwidth wins and compute wins. Design kernels to fuse dequantization into the GEMM, with scales applied in the epilogue, instead of materializing dequantized weights. Always benchmark with realistic batch sizes. Measure accuracy degradation, not just throughput.
