Quantization Deep Dive: Making Large Models Fit in Small GPUs
The $16,000-Per-Month Question That Changed Everything
The year is 2022. A startup has been running LLM inference on two A100 80GB GPUs - a $16,000 per month cloud expense - for a product that barely covers its own infrastructure costs. Their largest model is a 30B-parameter fine-tuned variant that consumes 62 GB of GPU memory in FP16, leaving only 18 GB for the KV cache and activations that serve actual requests. Their maximum concurrent users: 4. Their burn rate: catastrophic.
A new engineer joins and asks a simple question: "Does every weight actually need 16-bit precision?" The team spends a week running experiments. They quantize the model to 8-bit. Accuracy drops less than 0.5% on their task. Memory drops to 32 GB. They move from two A100s to one, cutting infrastructure costs in half. Then they try 4-bit. Accuracy drops 1.8%. Memory drops to 16 GB. They move to a single A10G at $1.50 per hour.
Total cost reduction: 89%. Total capability lost: less than 2% on the metrics that mattered to their customers. The startup survives its next funding round. This story is real - it reflects what happened at dozens of AI startups in the 2022-2023 period when quantization libraries became practical. Understanding how quantization works lets you replicate this kind of result for your own systems.
Why Weights Don't Need Full Precision
Machine learning models store weights as floating-point numbers. FP32 (32-bit floats) was the traditional default for training, but most modern LLMs train in BF16 or FP16 (16-bit), typically with mixed precision. A single parameter takes 2 bytes in FP16, so a 70B-parameter model needs 140 GB for its weights alone.
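The same arithmetic works for any format: weight memory is the parameter count times the bytes per parameter. The sketch below covers weights only. It ignores the KV cache, activations, and the per-block scale overhead discussed later. The helper name is ours.

```python
# Bytes per parameter for common storage formats (weights only)
BYTES_PER_PARAM = {
    "fp32": 4.0,
    "fp16": 2.0,
    "bf16": 2.0,
    "fp8": 1.0,
    "int8": 1.0,
    "int4": 0.5,
}


def weight_memory_gb(n_params: float, fmt: str) -> float:
    """Weight storage in GB (1 GB = 1e9 bytes); excludes KV cache and activations."""
    return n_params * BYTES_PER_PARAM[fmt] / 1e9


if __name__ == "__main__":
    for fmt in ("fp32", "fp16", "int8", "int4"):
        print(f"70B params in {fmt}: {weight_memory_gb(70e9, fmt):.0f} GB")
```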
FP16 can represent numbers from approximately -65,504 to +65,504, with roughly 3-4 decimal digits of precision - more than enough for neural network weights, which typically live in a much smaller range. The key question is: what is the minimum precision needed to preserve model behavior?
The insight: neural network weights are not uniformly distributed. They cluster near zero in an approximately Gaussian distribution. A representation that efficiently covers the dense central region and handles sparse outliers can do much better than using the full floating-point dynamic range.
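To make the clustering concrete, the check below uses weights drawn from N(0, 0.02²) as a stand-in for a real LLM layer. That distribution is an assumption: real layers deviate from it, especially in their tails. The check measures how much of FP16's range such weights actually use.

```python
import torch

torch.manual_seed(0)
# Stand-in for an LLM weight matrix: Gaussian with std 0.02 (an assumption, not real weights)
w = torch.randn(1024, 1024) * 0.02

fp16_max = torch.finfo(torch.float16).max                      # largest finite FP16 value
frac_within_3sigma = (w.abs() <= 0.06).float().mean().item()   # share of weights within +/-3 std
used_fraction_of_range = w.abs().max().item() / fp16_max       # how much of FP16's range is touched

print(f"FP16 max representable value: {fp16_max}")
print(f"Fraction of weights within +/-0.06: {frac_within_3sigma:.4f}")
print(f"Largest |w| as a fraction of FP16's max: {used_fraction_of_range:.2e}")
```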
```text
FP16 has ~32,768 positive bit patterns, spread from 0 to +65504:
|-----------------------------------------------------------------|
0                                                            +65504

Most LLM weight values cluster in a tiny slice near zero:
                   .-'''-.
               .-'         '-.
        ....-'                 '-....
     -0.5             0             +0.5

INT8 provides only 256 distinct values - but 256 levels covering [-0.5, 0.5]
gives resolution of ~0.004, which is more than adequate for most weights.

INT4 provides only 16 distinct values - surprisingly sufficient for most weights
when combined with per-group scaling and careful calibration (GPTQ, AWQ).
```
The Mathematical Foundation: Absmax Quantization
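The simplest quantization scheme is absmax (absolute maximum) quantization. It maps a floating-point tensor to integers by finding the maximum absolute value and scaling to fill the integer range:

$$s = \frac{Q_{\max}}{\max_i |w_i|}, \qquad w_q = \operatorname{clamp}\big(\operatorname{round}(s \cdot w),\ -Q_{\max},\ Q_{\max}\big)$$

where $Q_{\max}$ is 127 for INT8 (signed, range -128 to 127; absmax uses the symmetric range [-127, 127]) or 7 for INT4 (range -8 to 7).

To dequantize (reconstruct floating-point from integer):

$$\hat{w} = \frac{w_q}{s}$$

The quantization error per weight is $e = w - \hat{w}$, bounded by half the step size $\Delta = 1/s$: $|e| \le \frac{\Delta}{2} = \frac{\max_i |w_i|}{2\,Q_{\max}}$. For INT8, this maximum error is $\max_i |w_i| / 254$.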
```python
import torch
from typing import Tuple


def absmax_quantize_int8(
    weight: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
    """
    Absmax quantization to INT8.

    Maps the full weight range to [-127, 127].
    Returns (quantized_int8_tensor, scale_factor).
    Dequantize: weight_fp ≈ quantized / scale

    This is the simplest possible quantization - one scale per entire tensor.
    Works well when the weight distribution has no extreme outliers.
    """
    Q_MAX = 127.0
    abs_max = weight.abs().max().item()
    if abs_max == 0:
        return torch.zeros_like(weight, dtype=torch.int8), 1.0
    scale = Q_MAX / abs_max
    # Quantize: scale up, round to nearest integer, clamp to valid range
    quantized = torch.round(weight * scale).clamp(-127, 127).to(torch.int8)
    return quantized, scale


def absmax_dequantize(
    quantized: torch.Tensor,
    scale: float,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Reconstruct approximate floating-point values from INT8."""
    return quantized.to(output_dtype) / scale


def measure_quantization_error(
    original: torch.Tensor,
    reconstructed: torch.Tensor,
) -> dict:
    """Compute error statistics between original and quantized-then-dequantized tensor."""
    error = original.float() - reconstructed.float()
    abs_error = error.abs()
    rel_error = abs_error / (original.float().abs() + 1e-8)
    return {
        "max_absolute_error": abs_error.max().item(),
        "mean_absolute_error": abs_error.mean().item(),
        "rmse": (error ** 2).mean().sqrt().item(),
        "mean_relative_error_pct": rel_error.mean().item() * 100,
        "snr_db": 10 * torch.log10(
            (original.float() ** 2).mean() /
            ((error ** 2).mean() + 1e-12)
        ).item(),
    }


# Demonstration on a realistic LLM weight matrix
if __name__ == "__main__":
    torch.manual_seed(42)
    # Simulate a realistic LLM weight distribution
    weight = torch.randn(4096, 4096) * 0.02  # Typical scale for LLM weights

    q_weight, scale = absmax_quantize_int8(weight)
    reconstructed = absmax_dequantize(q_weight, scale)

    # Compare storage requirements
    original_bytes = weight.element_size() * weight.numel()       # FP32: 4 bytes/param
    quantized_bytes = q_weight.element_size() * q_weight.numel()  # INT8: 1 byte/param
    print(f"Original (FP32): {original_bytes / 1024**2:.1f} MB")
    print(f"Quantized (INT8): {quantized_bytes / 1024**2:.1f} MB")
    print(f"Compression: {original_bytes / quantized_bytes:.0f}x")

    print("\nQuantization error statistics:")
    for metric, value in measure_quantization_error(weight, reconstructed).items():
        print(f"  {metric:30s}: {value:.6f}")
```
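The Outlier Problem: Why Absmax Fails at Scale
Absmax works well when values are normally distributed without extreme outliers. In large language models, however, a small fraction of values have extremely large magnitudes, most notably in the activations that feed attention and feed-forward projections. These outliers emerge systematically and become more pronounced with scale: the LLM.int8() paper (Dettmers et al., 2022) reports that models above approximately 6.7B parameters consistently develop large-magnitude outlier features in a few hidden dimensions. Whether the outliers sit in a weight tensor or an activation tensor, a single absmax scale is dominated by them.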
Here is what happens when outliers dominate the scale computation:
```text
Example: Weight tensor with 0.5% outliers
  99.5% of weights: between -0.1 and +0.1
  0.5% of weights:  values up to ±15 (outliers)

Absmax scale = 127 / 15 = 8.47

This means:
  weight = 0.01 → quantized = round(0.01 × 8.47) = round(0.085) = 0
  weight = 0.05 → quantized = round(0.05 × 8.47) = round(0.42)  = 0
  weight = 0.10 → quantized = round(0.10 × 8.47) = round(0.85)  = 1
  weight = 0.15 → quantized = round(0.15 × 8.47) = round(1.27)  = 1

Every weight with |w| < 0.059 maps to 0, and the rest of the normal range collapses onto ±1.
All precision is spent on the outlier's range.
Signal-to-noise ratio collapses for the 99.5% normal weights.
```
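The rounding in this example can be reproduced directly. The sketch below also computes the cutoff below which every weight collapses to zero: any |w| below half a quantization step rounds to 0. The values are chosen to match the example.

```python
import torch

outlier_max = 15.0
scale = 127.0 / outlier_max  # ~8.47, dictated entirely by the outlier

weights = torch.tensor([0.01, 0.05, 0.10, 0.15])
quantized = torch.round(weights * scale).to(torch.int8)
print(quantized.tolist())  # [0, 0, 1, 1]

# Half a quantization step, expressed in weight units
zero_cutoff = 0.5 / scale
print(f"|w| < {zero_cutoff:.4f} maps to 0")  # |w| < 0.0591 maps to 0
```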
This is the fundamental quantization accuracy problem. Three main solutions have emerged, each addressing it differently:
Solution 1: Percentile Clipping
Clip to a high percentile (99.9%, 99.99%) instead of the absolute maximum. Outliers are clipped to the threshold value. Normal weights get the full quantization range:
```python
import torch
from typing import Tuple

# Reuses absmax_quantize_int8, absmax_dequantize and measure_quantization_error from the previous listing


def percentile_quantize(
    weight: torch.Tensor,
    percentile: float = 99.9,
    n_bits: int = 8,
) -> Tuple[torch.Tensor, float, float]:
    """
    Clip weights to a percentile before quantizing.

    Returns (quantized, scale, clip_threshold).
    The clip_threshold is the absolute value at the specified percentile.
    Values beyond this are clipped - small information loss for outliers,
    large precision gain for the 99.9% of normal weights.

    Args:
        percentile: Higher = less clipping, but more outlier influence.
                    99.9% is a common default for INT8.
                    99.0% is more aggressive and often suits INT4 better.
        n_bits: Target bit width (8 for INT8, 4 for INT4).
    """
    Q_MAX = 2 ** (n_bits - 1) - 1  # 127 for INT8, 7 for INT4 signed
    w_float = weight.float()
    # Note: torch.quantile rejects very large inputs (roughly 16M elements);
    # for full weight matrices, compute the threshold on a random sample.
    clip_val = torch.quantile(w_float.abs(), percentile / 100.0).item()
    clip_val = max(clip_val, 1e-7)  # Prevent division by zero
    # Clip outliers to the percentile threshold
    clipped = w_float.clamp(-clip_val, clip_val)
    # Quantize the clipped range with full resolution
    scale = Q_MAX / clip_val
    # INT4 values are stored in an int8 container here; production kernels pack two per byte
    quantized = torch.round(clipped * scale).clamp(-Q_MAX, Q_MAX).to(torch.int8)
    return quantized, scale, clip_val


def demonstrate_outlier_impact():
    """
    Show numerically why outliers destroy quantization quality
    and how percentile clipping repairs it.
    """
    torch.manual_seed(0)
    # 998 normal weights + 2 extreme outliers
    normal_weights = torch.randn(998) * 0.1
    outliers = torch.tensor([8.0, -7.5])
    weights = torch.cat([normal_weights, outliers])

    # Method 1: Absmax - scale dominated by outliers
    q_absmax, scale_absmax = absmax_quantize_int8(weights)
    rec_absmax = absmax_dequantize(q_absmax, scale_absmax)

    # Method 2: Percentile clipping - outliers don't dominate
    q_pct, scale_pct, clip = percentile_quantize(weights, percentile=99.5)
    rec_pct = q_pct.float() / scale_pct

    # Measure error on normal weights only (excluding the 2 outliers)
    normal_mask = torch.cat([
        torch.ones(998, dtype=torch.bool),
        torch.zeros(2, dtype=torch.bool),
    ])
    err_absmax = measure_quantization_error(weights[normal_mask], rec_absmax[normal_mask])
    err_pct = measure_quantization_error(weights[normal_mask], rec_pct[normal_mask])

    print("Impact of outliers on INT8 quantization quality (normal weights only):")
    print(f"  Absmax:             MAE = {err_absmax['mean_absolute_error']:.6f}")
    print(f"  Percentile (99.5%): MAE = {err_pct['mean_absolute_error']:.6f}")
    improvement = err_absmax['mean_absolute_error'] / err_pct['mean_absolute_error']
    print(f"  Improvement: {improvement:.1f}x lower MAE on the 99.8% of normal weights")


if __name__ == "__main__":
    demonstrate_outlier_impact()
```
Solution 2: LLM.int8() - Mixed-Precision via Outlier Decomposition
Tim Dettmers' LLM.int8() (2022, implemented in bitsandbytes) takes a different approach. Instead of clipping or ignoring outliers, it handles them separately in FP16 and quantizes everything else to INT8.
The key insight: outliers concentrate in a small number of feature dimensions, and the same dimensions recur across inputs. LLM.int8() exploits this without any calibration step. At each forward pass, it finds the input dimensions where any activation magnitude exceeds a threshold (6.0 by default). It multiplies those columns of the activations, and the matching slice of the weight matrix, in FP16. The rest runs as an INT8 matmul with vector-wise scales, and the two partial results are summed. The outlier dimensions are a tiny fraction of the total: the paper reports that more than 99.9% of values are still multiplied in 8-bit. As a result, the extra FP16 matmul adds little overhead.
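To see the mechanics, here is a minimal CPU emulation of the decomposition just described. It is an illustrative sketch, not bitsandbytes' kernel:

- The function name is ours.
- The integer matmul is emulated in float32.
- The weight is laid out as (in_features, out_features), so outlier input dimensions index its rows.

```python
import torch
from typing import Tuple


def int8_matmul_with_outliers(
    x: torch.Tensor,         # (tokens, in_features) activations
    w: torch.Tensor,         # (in_features, out_features) weights
    threshold: float = 6.0,  # same role as llm_int8_threshold
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical emulation of LLM.int8()-style mixed-precision matmul."""
    # 1. Outlier feature dimensions: any activation magnitude above the threshold
    outlier = (x.abs() > threshold).any(dim=0)  # (in_features,) bool
    normal = ~outlier

    # 2. High-precision path for the few outlier dimensions
    y_outlier = x[:, outlier] @ w[outlier, :]

    # 3. INT8 path with vector-wise scales: one per row of x, one per column of w
    x_n, w_n = x[:, normal], w[normal, :]
    sx = x_n.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0  # (tokens, 1)
    sw = w_n.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0  # (1, out_features)
    xq = torch.round(x_n / sx).clamp(-127, 127)
    wq = torch.round(w_n / sw).clamp(-127, 127)
    # Integer matmul emulated in float32 (real kernels accumulate in INT32),
    # then rescaled by the outer product of the two scale vectors
    y_normal = (xq @ wq) * sx * sw

    return y_normal + y_outlier, outlier


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(8, 64)
    x[:, 5] = 20.0  # inject one outlier feature dimension
    w = torch.randn(64, 32) * 0.02
    y, mask = int8_matmul_with_outliers(x, w)
    rel_err = ((y - x @ w).norm() / (x @ w).norm()).item()
    print("outlier dims:", mask.nonzero().flatten().tolist(), f"relative error: {rel_err:.4f}")
```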
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch


def load_model_int8(
    model_name: str,
    threshold: float = 6.0,
) -> tuple:
    """
    Load a model with LLM.int8() quantization.

    The threshold controls outlier detection. At each forward pass, input
    feature dimensions where any |activation| > threshold are treated as
    outliers and multiplied in FP16; everything else goes through INT8.

    Lower threshold = more dimensions in FP16 (more accurate, slower)
    Higher threshold = more dimensions in INT8 (less accurate, faster)
    Weights are stored in INT8 either way, so the threshold mainly trades
    speed for accuracy rather than changing weight memory.

    6.0 is the default in the LLM.int8() paper and in BitsAndBytesConfig.
    If you see degradation on your model, try a lower value and re-evaluate.

    Args:
        model_name: HuggingFace model ID
        threshold: Outlier detection threshold (default 6.0)
    """
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=threshold,
        # lm_head is sensitive - it must produce accurate logits for token ranking
        llm_int8_skip_modules=["lm_head"],  # Keep output projection in FP16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    if torch.cuda.is_available():
        mem_gb = torch.cuda.max_memory_allocated() / 1e9
        print(f"INT8 model loaded: {mem_gb:.2f} GB GPU memory")
    return model, tokenizer


def inspect_int8_layer(
    model,
    layer_path: str = "model.layers.0.self_attn.q_proj",
) -> dict:
    """
    Inspect the stored state of an INT8-quantized (bitsandbytes) layer.

    LLM.int8() picks outlier dimensions per forward pass from the activations,
    so they cannot be read from the stored weights. What is stored is the INT8
    weight plus its scale statistics (SCB, typically one entry per output row).
    Attribute names follow bitsandbytes' Int8Params and may differ between
    versions, hence the defensive getattr calls.
    """
    module = model.get_submodule(layer_path)
    info = {"type": type(module).__name__}

    weight = getattr(module, "weight", None)
    if weight is not None:
        info["weight_shape"] = tuple(weight.shape)
        info["weight_dtype"] = str(weight.dtype)
        scb = getattr(weight, "SCB", None)
        if scb is not None:
            info["SCB_shape"] = tuple(scb.shape)

    for k, v in info.items():
        print(f"  {k}: {v}")
    return info
```
Solution 3: Block-Wise Quantization - Isolating Outlier Damage
Block-wise quantization is the most important engineering improvement to basic absmax. Instead of using a single scale for the entire weight matrix, compute a separate scale for each small block of weights (typically 64 or 128 consecutive elements per row).
This means outliers in one block affect only that block's quantization scale. Other blocks are unaffected. The precision loss caused by outliers is isolated to a small fraction of the matrix.
```python
import torch
from typing import Tuple

# Reuses absmax_quantize_int8 and absmax_dequantize from the first listing


def blockwise_quantize_int8(
    weight: torch.Tensor,
    block_size: int = 64,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Block-wise INT8 quantization.

    Each block of block_size consecutive elements (of the flattened tensor)
    shares one scale; when the row length is a multiple of block_size, blocks
    never straddle rows. This isolates outlier damage: an outlier only degrades
    quantization quality for the other weights in its block, not the entire row.

    Memory overhead from scales (block_size=64, FP16 scales):
        2 bytes of scale per 64 INT8 values (64 bytes) = ~3.1% of the INT8 payload,
        or 2 / (64 × 2 bytes) = ~1.6% relative to the original FP16 weights.

    Args:
        weight: Weight tensor of any shape
        block_size: Elements per block (64 or 128 are common choices)

    Returns:
        quantized: INT8 tensor, same shape as weight
        scales: FP16 tensor, shape (ceil(n_elements / block_size),)
    """
    Q_MAX = 127.0
    flat = weight.reshape(-1).float()
    n_elements = flat.numel()

    # Pad to multiple of block_size if needed
    remainder = n_elements % block_size
    if remainder != 0:
        pad_size = block_size - remainder
        flat = torch.cat([flat, torch.zeros(pad_size, dtype=flat.dtype)])

    n_blocks = flat.numel() // block_size
    blocks = flat.reshape(n_blocks, block_size)

    # Per-block scale: absmax within each block
    abs_max = blocks.abs().max(dim=1).values  # (n_blocks,)
    scales = abs_max / Q_MAX
    scales = scales.clamp(min=1e-10)  # Prevent division by zero

    # Quantize each block independently
    quantized_blocks = torch.round(blocks / scales.unsqueeze(1)).clamp(-Q_MAX, Q_MAX)
    quantized_blocks = quantized_blocks.to(torch.int8)

    # Return in original shape (trim padding)
    quantized = quantized_blocks.reshape(-1)[:n_elements].reshape(weight.shape)
    return quantized, scales.half()  # Store scales as FP16


def blockwise_dequantize_int8(
    quantized: torch.Tensor,
    scales: torch.Tensor,
    original_shape: torch.Size,
    block_size: int = 64,
) -> torch.Tensor:
    """Reconstruct FP16 weights from block-wise INT8 quantization."""
    flat = quantized.reshape(-1).float()
    n_elements = flat.numel()
    remainder = n_elements % block_size
    if remainder != 0:
        pad_size = block_size - remainder
        flat = torch.cat([flat, torch.zeros(pad_size, dtype=flat.dtype)])
    n_blocks = flat.numel() // block_size
    blocks = flat.reshape(n_blocks, block_size)
    reconstructed_blocks = blocks * scales.float().unsqueeze(1)
    reconstructed = reconstructed_blocks.reshape(-1)[:n_elements]
    return reconstructed.reshape(original_shape).half()


def demonstrate_blockwise_benefit():
    """
    Show quantitatively how block-wise quantization isolates outlier damage.

    Key insight: with global absmax, a single outlier anywhere in the tensor
    degrades precision everywhere. With block-wise, the damage is contained.
    """
    torch.manual_seed(42)
    weights = torch.randn(256) * 0.1
    # Inject a single outlier at position 128 (the start of block 2 when block_size=64)
    weights[128] = 15.0

    # Global absmax: single scale dominated by outlier
    q_global, s_global = absmax_quantize_int8(weights)
    rec_global = absmax_dequantize(q_global, s_global)
    # Error in block 0 (positions 0-63, far from the outlier)
    err_global_block0 = (weights[:64] - rec_global[:64]).abs().mean()

    # Block-wise (block_size=64): the outlier only affects block 2 (positions 128-191)
    q_block, scales_block = blockwise_quantize_int8(weights, block_size=64)
    rec_block = blockwise_dequantize_int8(q_block, scales_block, weights.shape, block_size=64)
    # Error in block 0 - its scale is computed without the outlier
    err_block_block0 = (weights[:64] - rec_block[:64].float()).abs().mean()

    print("Outlier at position 128. Error in the first block (positions 0-63):")
    print(f"  Global absmax:   MAE = {err_global_block0:.6f}")
    print(f"  Block-wise (64): MAE = {err_block_block0:.6f}")
    print(f"  Block-wise is {err_global_block0 / err_block_block0:.1f}x better for unaffected blocks")
    print("  (The first block has no outlier, so block-wise keeps near-full INT8 precision)")
```
NF4: The Optimal 4-Bit Data Type for Neural Networks
QLoRA (Dettmers et al., 2023) introduced NF4 (NormalFloat 4-bit) - an information-theoretically optimal 4-bit data type for weights that follow a normal distribution. The design principle: place quantization points where weights actually are, not where they are not.
Standard INT4 places 16 quantization values at evenly-spaced intervals. But neural network weights follow a near-normal distribution - dense near zero, sparse at extremes. Uniform spacing wastes bins at the sparse tails and has insufficient bins at the dense center.
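NF4 places quantization points at the quantiles of a standard normal distribution. In the symmetric form used in the code below, the 16 levels are

$$q_i = \Phi^{-1}\!\left(\frac{2i+1}{32}\right), \qquad i = 0, 1, \dots, 15,$$

rescaled so that the largest level is 1. Here $\Phi^{-1}$ is the inverse standard normal CDF (the probit function). This ensures each of the 16 bins covers a roughly equal fraction of the actual weight distribution - maximum information per bit.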
```python
import numpy as np
from scipy import stats
from typing import Tuple


def compute_nf4_levels() -> np.ndarray:
    """
    Compute 16 NF4-style quantization levels.

    These are the quantiles of a standard normal distribution at positions
    1/32, 3/32, 5/32, ..., 31/32 - the centers of 16 equal-probability bins.
    This places more quantization points near zero (where most weights are)
    and fewer at the extremes (where few weights are).
    Simplified, symmetric construction: it has no exact zero level.
    """
    n_levels = 16
    quantile_positions = np.array([(2 * i + 1) / (2 * n_levels) for i in range(n_levels)])
    nf4_levels = stats.norm.ppf(quantile_positions)  # Inverse normal CDF
    # Normalize to [-1, 1] for consistent use with block scales
    nf4_levels = nf4_levels / nf4_levels.max()
    return nf4_levels


def nf4_quantize(
    weight: np.ndarray,
) -> Tuple[np.ndarray, float]:
    """
    Quantize a weight array to NF4.

    Process:
        1. Compute the absmax scale for the weight
        2. Normalize weights to [-1, 1]
        3. For each weight, find the nearest NF4 level (index 0-15)
        4. Store the 4-bit index
    Dequantize: weight ≈ nf4_levels[index] * scale

    Args:
        weight: 1D array of float32 weights (treated as one block; QLoRA uses blocks of 64)

    Returns:
        indices: Array of 4-bit indices (stored as uint8)
        scale: Absmax scale factor for this block
    """
    nf4_levels = compute_nf4_levels()
    scale = np.abs(weight).max()
    if scale == 0:
        return np.zeros(len(weight), dtype=np.uint8), 1.0
    # Normalize to [-1, 1]
    w_normalized = weight / scale
    # Find nearest NF4 level for each weight; distances shape: (n_weights, 16)
    distances = np.abs(w_normalized[:, None] - nf4_levels[None, :])
    indices = distances.argmin(axis=1).astype(np.uint8)  # 4-bit index: 0-15
    return indices, scale


def compare_nf4_vs_uniform_int4(
    n_weights: int = 100_000,
    seed: int = 42,
) -> dict:
    """
    Empirically compare NF4 vs uniform INT4 quantization error
    for normally distributed weights.

    Expected result: NF4 has lower MSE than uniform INT4 for near-normal
    weights. The size of the gap depends on the block size and on how far
    the absmax value sits in the tails, so measure rather than assume.
    """
    np.random.seed(seed)
    weights = np.random.randn(n_weights) * 0.02  # Typical LLM weight scale
    nf4_levels = compute_nf4_levels()

    # --- NF4 quantization ---
    indices_nf4, scale_nf4 = nf4_quantize(weights)
    nf4_reconstructed = nf4_levels[indices_nf4] * scale_nf4

    # --- Uniform INT4 quantization (symmetric, 7 levels positive) ---
    scale_int4 = 7.0 / np.abs(weights).max()
    int4_q = np.round(weights * scale_int4).clip(-7, 7).astype(np.int8)
    int4_reconstructed = int4_q.astype(np.float32) / scale_int4

    nf4_mse = np.mean((weights - nf4_reconstructed) ** 2)
    int4_mse = np.mean((weights - int4_reconstructed) ** 2)
    return {
        "nf4_mse": nf4_mse,
        "int4_uniform_mse": int4_mse,
        "mse_improvement": int4_mse / nf4_mse,
        "nf4_rmse": np.sqrt(nf4_mse),
        "int4_rmse": np.sqrt(int4_mse),
    }


# Verify NF4 advantage
if __name__ == "__main__":
    nf4_levels = compute_nf4_levels()
    print("NF4 quantization levels (16 values, normalized to [-1, 1]):")
    print(np.round(nf4_levels, 4))
    print()
    comparison = compare_nf4_vs_uniform_int4()
    print("NF4 vs Uniform INT4 for normally distributed weights:")
    print(f"  NF4 MSE:  {comparison['nf4_mse']:.2e}")
    print(f"  INT4 MSE: {comparison['int4_uniform_mse']:.2e}")
    print(f"  NF4 advantage: {comparison['mse_improvement']:.2f}x lower MSE")
```
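The `compute_nf4_levels` function above uses a symmetric construction, so none of its 16 levels is exactly zero. As we understand the QLoRA paper, the real NF4 table is built asymmetrically:

- quantiles are estimated separately for 8 positive and 7 negative levels;
- an exact zero is added, so zero is represented exactly.

The sketch below uses only the standard library. Two caveats apply. The outermost quantile offset (0.9677083) is our assumption, which we believe matches bitsandbytes' default, so check it against the library before relying on exact values. The function name is ours.

```python
from statistics import NormalDist
from typing import List


def qlora_style_nf4_levels(offset: float = 0.9677083) -> List[float]:
    """
    Hypothetical reconstruction of an asymmetric NF4 table:
    8 positive quantiles + 7 negative quantiles + an exact zero, normalized to [-1, 1].
    `offset` is the probability of the outermost quantile (an assumption).
    """
    nd = NormalDist()  # standard normal

    def linspace(start: float, stop: float, n: int) -> List[float]:
        return [start + (stop - start) * i / (n - 1) for i in range(n)]

    # Drop the final point of each side (probability 0.5 -> quantile 0)
    positive = [nd.inv_cdf(p) for p in linspace(offset, 0.5, 9)[:-1]]   # 8 values
    negative = [-nd.inv_cdf(p) for p in linspace(offset, 0.5, 8)[:-1]]  # 7 values
    levels = sorted(positive + [0.0] + negative)
    largest = max(abs(v) for v in levels)
    return [v / largest for v in levels]


if __name__ == "__main__":
    levels = qlora_style_nf4_levels()
    print([round(v, 4) for v in levels])
    print("exact zero present:", 0.0 in levels)
```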
Double Quantization: Squeezing Out the Last Bits
QLoRA introduced double quantization: apply a second quantization pass to the scale constants of the first quantization. Here is why this matters at scale.
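In block-wise quantization with block_size=64, each block has one FP32 scale constant (4 bytes). For a 70B parameter model:

$$\frac{70 \times 10^{9}}{64} \approx 1.09 \times 10^{9}\ \text{scales}, \qquad 1.09 \times 10^{9} \times 4\ \text{bytes} \approx 4.4\ \text{GB} \quad \left(\tfrac{32}{64} = 0.5\ \text{bits per parameter}\right)$$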
That is significant overhead. Double quantization applies a second quantization pass to these scale constants:
- Group 256 scale constants together
- Quantize them to 8-bit with an FP32 meta-scale
- Each block's scale now costs 1 byte instead of 4 bytes (plus trivial meta-scale overhead)
Net savings: $\frac{32}{64} - \left(\frac{8}{64} + \frac{32}{64 \cdot 256}\right) = 0.5 - 0.127 \approx 0.373$ bits per original model parameter. For a 70B model, this saves roughly 3.3 GB - meaningful when you are trying to fit the model on a single GPU.
```python
import torch
from typing import Tuple


def double_quantize_scales(
    first_level_scales: torch.Tensor,  # FP32 scales from block-wise quantization
    group_size: int = 256,             # How many first-level scales to group together
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply double quantization to reduce scale constant storage.

    First-level: one FP32 scale per block of 64 original weights (4 bytes)
    Second-level: one INT8 per first-level scale (1 byte) + one FP32 meta-scale
                  per 256 first-level scales

    Storage reduction for 70B model:
        Before: 1.09B scales × 4 bytes = 4.4 GB
        After:  1.09B scales × 1 byte  = 1.1 GB + tiny meta-scale overhead
        Savings: ~3.3 GB (~0.37 bits per original parameter)

    Simplification: QLoRA quantizes the second level to 8-bit floats after
    subtracting the mean of the (all-positive) scales; this sketch uses
    symmetric INT8 absmax, which leaves the negative half of the range unused.

    Returns:
        scales_q8: INT8-quantized first-level scales, shape (n_groups, group_size)
        meta_scales: FP32 meta-scales (one per group of group_size scales)
    """
    flat = first_level_scales.reshape(-1).float()
    n_scales = flat.numel()

    # Pad to a multiple of group_size by repeating the last scale
    remainder = n_scales % group_size
    if remainder != 0:
        flat = torch.cat([flat, flat[-1].expand(group_size - remainder)])
    groups = flat.reshape(-1, group_size)

    # Second-level absmax quantization of scales
    meta_max = groups.abs().max(dim=1).values
    meta_scales = (meta_max / 127.0).clamp(min=1e-10)

    # Quantize first-level scales to INT8
    scales_q8 = torch.round(groups / meta_scales.unsqueeze(1)).clamp(-127, 127).to(torch.int8)
    return scales_q8, meta_scales.float()


def estimate_memory_with_double_quant(
    model_params_b: float,
    block_size: int = 64,
    meta_group_size: int = 256,
    n_bits: int = 4,
) -> dict:
    """
    Estimate memory for a model with and without double quantization.
    Shows the concrete savings in GB.
    """
    n_params = model_params_b * 1e9
    # Model weight memory (INT4 = 0.5 bytes per param)
    weight_memory_gb = n_params * n_bits / 8 / 1e9
    # First-level scales: one per block
    n_first_level_scales = n_params / block_size
    # Without double quantization: FP32 (4 bytes each)
    scales_without_dq_gb = n_first_level_scales * 4 / 1e9
    # With double quantization: INT8 (1 byte each) + one FP32 meta-scale per group
    scales_int8_gb = n_first_level_scales * 1 / 1e9
    n_meta_scales = n_first_level_scales / meta_group_size
    meta_gb = n_meta_scales * 4 / 1e9
    scales_with_dq_gb = scales_int8_gb + meta_gb
    savings_gb = scales_without_dq_gb - scales_with_dq_gb
    savings_bits_per_param = savings_gb * 8e9 / n_params
    return {
        "model_params_b": model_params_b,
        "n_bits": n_bits,
        "weight_memory_gb": round(weight_memory_gb, 2),
        "scales_without_dq_gb": round(scales_without_dq_gb, 3),
        "scales_with_dq_gb": round(scales_with_dq_gb, 3),
        "savings_gb": round(savings_gb, 3),
        "savings_bits_per_param": round(savings_bits_per_param, 4),
        "overhead_reduction_pct": round((1 - scales_with_dq_gb / scales_without_dq_gb) * 100, 1),
    }


if __name__ == "__main__":
    # Show savings for common model sizes
    for b in [7.0, 13.0, 70.0]:
        est = estimate_memory_with_double_quant(b)
        print(f"{b}B model: weight={est['weight_memory_gb']}GB, "
              f"scale savings={est['savings_gb']}GB, "
              f"{est['overhead_reduction_pct']}% scale reduction")
```
Production Configuration: bitsandbytes and Transformers
In production, you use libraries rather than implementing quantization from scratch. Here is the complete, production-ready pattern for each configuration:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import time


def create_quantization_configs() -> dict:
    """
    The standard quantization configurations used in production.
    Each config is a BitsAndBytesConfig ready to pass to from_pretrained.
    """
    return {
        # FP16 baseline - no quantization config, just set dtype
        "fp16": None,
        # INT8 via LLM.int8() - outlier-aware mixed precision
        # Best choice when accuracy is paramount and memory allows
        "int8": BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,             # Outlier detection threshold
            llm_int8_skip_modules=["lm_head"],  # Keep output layer in FP16
        ),
        # INT4 NF4 - designed for normally distributed weights; the QLoRA default
        # Best choice for general INT4 inference and QLoRA fine-tuning
        "nf4": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",              # NormalFloat4 data type
            bnb_4bit_use_double_quant=True,         # Quantize the scales too (~0.37 bits/param saving)
            bnb_4bit_compute_dtype=torch.bfloat16,  # Dequantize to BF16 for matmul
        ),
        # INT4 FP4 - alternative 4-bit type, occasionally better for specific models
        "fp4": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="fp4",  # Standard 4-bit float
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    }


def benchmark_quantization_configs(
    model_name: str,
    test_prompt: str = "Explain the concept of transformer attention in detail:",
    max_new_tokens: int = 100,
    n_warmup: int = 3,
    n_runs: int = 10,
) -> dict:
    """
    Benchmark a model under different quantization configurations (requires CUDA).
    Reports memory, latency, and throughput for each configuration.
    Use this to empirically determine the best config for your GPU and model.
    Run before committing to a configuration for production.
    """
    configs = create_quantization_configs()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer(test_prompt, return_tensors="pt")
    all_results = {}

    for config_name, bnb_config in configs.items():
        print(f"\n{'=' * 55}")
        print(f"Testing: {config_name.upper()}")
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        load_kwargs = {"device_map": "auto"}
        if bnb_config is not None:
            load_kwargs["quantization_config"] = bnb_config
        else:
            load_kwargs["torch_dtype"] = torch.float16

        t_load = time.time()
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
        load_time = time.time() - t_load
        load_mem = torch.cuda.memory_allocated() / 1e9

        inputs_gpu = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_len = inputs_gpu["input_ids"].shape[1]

        # Warmup - the first runs pay one-time costs (CUDA context, kernel selection, allocator growth)
        for _ in range(n_warmup):
            with torch.no_grad():
                model.generate(**inputs_gpu, max_new_tokens=10, do_sample=False)

        # Timed benchmark runs
        times = []
        for _ in range(n_runs):
            torch.cuda.synchronize()
            t = time.perf_counter()
            with torch.no_grad():
                output = model.generate(
                    **inputs_gpu,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                )
            torch.cuda.synchronize()
            times.append(time.perf_counter() - t)

        peak_mem = torch.cuda.max_memory_allocated() / 1e9  # weights + KV cache + activations
        times.sort()
        mean_time = sum(times) / len(times)
        p50_time = times[len(times) // 2]
        p90_time = times[int(len(times) * 0.9)]
        # Count tokens actually generated (generation can stop early at EOS)
        n_generated = output.shape[1] - prompt_len
        tokens_per_sec = n_generated / mean_time

        gen_text = tokenizer.decode(
            output[0][prompt_len:],
            skip_special_tokens=True,
        )[:80]

        result = {
            "load_time_s": round(load_time, 1),
            "memory_after_load_gb": round(load_mem, 2),
            "peak_memory_gb": round(peak_mem, 2),
            "mean_latency_ms": round(mean_time * 1000, 1),
            "p50_latency_ms": round(p50_time * 1000, 1),
            "p90_latency_ms": round(p90_time * 1000, 1),
            "tokens_generated": int(n_generated),
            "tokens_per_second": round(tokens_per_sec, 1),
            "sample_output_80chars": gen_text,
        }
        all_results[config_name] = result

        print(f"  Load time:         {result['load_time_s']}s")
        print(f"  Memory after load: {result['memory_after_load_gb']} GB")
        print(f"  Peak memory:       {result['peak_memory_gb']} GB")
        print(f"  Mean latency:      {result['mean_latency_ms']} ms ({n_generated} tokens)")
        print(f"  P90 latency:       {result['p90_latency_ms']} ms")
        print(f"  Throughput:        {result['tokens_per_second']} tok/s")

        del model
        torch.cuda.empty_cache()

    return all_results
```
FP8: The Next Generation Quantization Format
H100 GPUs natively support FP8 - a floating-point format that maintains exponent and mantissa bits instead of mapping to integers. FP8 comes in two variants:
- E4M3: 4 exponent bits, 3 mantissa bits. Range: approximately ±448. More precision per value; typically used for weights and activations in the forward pass.
- E5M2: 5 exponent bits, 2 mantissa bits. Range: approximately ±57,344. More dynamic range; typically used for gradients in the backward pass.
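These ranges follow directly from the bit layouts. The sketch below derives them under the conventions we assume for H100-era FP8:

- E4M3 is the variant with no infinities, where only the all-ones exponent and mantissa pattern is NaN.
- E5M2 is IEEE-style, with the top exponent reserved for infinity and NaN.

The `FP8Format` class and its methods are our own illustration, not a library API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FP8Format:
    """Hypothetical helper describing an 8-bit float layout (sign + exponent + mantissa)."""
    name: str
    exp_bits: int
    man_bits: int
    bias: int
    # True: top exponent reserved for inf/NaN (E5M2)
    # False: only exponent=all-ones with mantissa=all-ones is NaN (E4M3)
    ieee_specials: bool

    def max_normal(self) -> float:
        top_exp = (1 << self.exp_bits) - 1
        if self.ieee_specials:
            exp_field, man_field = top_exp - 1, (1 << self.man_bits) - 1
        else:
            exp_field, man_field = top_exp, (1 << self.man_bits) - 2
        return 2.0 ** (exp_field - self.bias) * (1 + man_field / (1 << self.man_bits))

    def min_normal(self) -> float:
        return 2.0 ** (1 - self.bias)

    def min_subnormal(self) -> float:
        return 2.0 ** (1 - self.bias - self.man_bits)

    def relative_step(self) -> float:
        """Spacing between adjacent normal values, relative to the power of two below them."""
        return 2.0 ** -self.man_bits


E4M3 = FP8Format("E4M3", exp_bits=4, man_bits=3, bias=7, ieee_specials=False)
E5M2 = FP8Format("E5M2", exp_bits=5, man_bits=2, bias=15, ieee_specials=True)

if __name__ == "__main__":
    for fmt in (E4M3, E5M2):
        print(f"{fmt.name}: max={fmt.max_normal()}, min normal={fmt.min_normal():.3g}, "
              f"min subnormal={fmt.min_subnormal():.3g}, relative step={fmt.relative_step()}")
```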
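FP8 is increasingly important because it stores and multiplies values at twice the density of FP16, with hardware-native Tensor Core support on H100 and newer GPUs, for both training and inference. FP8 recipes still use scaling factors (typically one per tensor). However, every FP8 value carries its own exponent, so a single scale can cover a wide spread of magnitudes, and results come back in floating point. Integer formats instead need a quantize, integer matmul, and rescale pipeline, or, in weight-only schemes, dequantization of the weights before the matmul.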
```python
import torch

# Reuses absmax_quantize_int8 and absmax_dequantize from the first listing


def approx_fp8_e4m3(t: torch.Tensor) -> torch.Tensor:
    """
    Approximate FP8 E4M3 rounding: 3 mantissa bits, max 448,
    min normal 2^-6 (~0.0156), subnormal step 2^-9 (~1.95e-3).
    """
    clamped = t.clamp(-448.0, 448.0)
    abs_vals = clamped.abs().clamp(min=1e-37)
    # Exponent of each value; below the normal range the step stays fixed (subnormals)
    exponents = torch.floor(torch.log2(abs_vals)).clamp(min=-6)
    step = 2.0 ** exponents / 8.0  # 2^3 = 8 mantissa steps per power of two
    return torch.round(clamped / step) * step


def compare_fp8_vs_fp16_vs_int8(
    n_weights: int = 100_000,
    weight_std: float = 0.02,
    seed: int = 42,
) -> dict:
    """
    Compare approximate quantization error for FP16, FP8 E4M3 and INT8.

    Note: true FP8 computation requires H100-class hardware and a library such
    as Transformer Engine; this only simulates the rounding error. FP8 is shown
    both without scaling and with a per-tensor scale (absmax mapped to 448), as
    FP8 recipes typically use. Which format wins depends on the weight
    distribution, so read the numbers rather than assuming.
    """
    torch.manual_seed(seed)
    weights = torch.randn(n_weights) * weight_std

    # FP16: virtually lossless at this scale
    fp16_mse = ((weights - weights.half().float()) ** 2).mean().item()

    # INT8 (absmax, per-tensor)
    q_int8, s_int8 = absmax_quantize_int8(weights)
    int8_mse = ((weights - absmax_dequantize(q_int8, s_int8)) ** 2).mean().item()

    # FP8 E4M3 without scaling: small weights land in the coarse subnormal range
    fp8_raw_mse = ((weights - approx_fp8_e4m3(weights)) ** 2).mean().item()

    # FP8 E4M3 with a per-tensor scale that maps absmax to 448
    fp8_scale = 448.0 / weights.abs().max().item()
    fp8_scaled = approx_fp8_e4m3(weights * fp8_scale) / fp8_scale
    fp8_scaled_mse = ((weights - fp8_scaled) ** 2).mean().item()

    return {
        "fp16_mse": fp16_mse,
        "fp8_e4m3_unscaled_mse": fp8_raw_mse,
        "fp8_e4m3_scaled_mse": fp8_scaled_mse,
        "int8_mse": int8_mse,
        "fp8_scaled_vs_int8_ratio": fp8_scaled_mse / (int8_mse + 1e-20),
    }
```
The Full Picture: Formats, Accuracy, and Use Cases
| Configuration | Weight Memory (7B) | Quality Loss | Inference Speed | Best For |
|---|---|---|---|---|
| FP32 | 28 GB | 0% | 1.0x | Training only |
| FP16/BF16 | 14 GB | ~0% | 1.5x vs FP32 | All inference (default) |
| FP8 (H100 native) | 7 GB | less than 0.5% | 2x vs FP16 | H100 inference + training |
| INT8 (LLM.int8) | 7 GB | 0.5-1% | 1.3-1.5x vs FP16 | Accuracy-critical inference |
| NF4 (QLoRA) | 3.5 GB | 1-3% | 2-2.5x vs FP16 | Consumer GPUs, QLoRA fine-tuning |
| GPTQ INT4 | 3.5 GB | 1-3% | 2-3x vs FP16 | Production INT4 inference |
| AWQ INT4 | 3.5 GB | 1-2.5% | 2.5-3.5x vs FP16 | Best accuracy + speed at INT4 |
Production Pitfalls and How to Avoid Them
:::danger Quantizing Attention Layers Without Checking Sensitivity
Attention Q, K, V projections are often the layers most sensitive to quantization. The outlier problem is most severe here, and even one extra bit of precision can make a significant difference. When accuracy degrades with INT4, check the attention layers first. Use llm_int8_skip_modules in bitsandbytes (or the equivalent module-exclusion setting in your GPTQ/AWQ toolchain) to leave the most sensitive layers unquantized in FP16/BF16, or in INT8 where your toolchain supports per-layer bit widths. Phi-3 and Gemma architectures have more outlier-prone attention layers than Llama-style models; if you see unusual degradation, this is the first place to look.
:::
:::warning Not Testing at Production Sequence Lengths Quantization errors compound over long context windows in ways that short-context benchmarks miss. A model that looks fine on 512-token prompts may degrade at 4K or 8K tokens, because quantization noise accumulates across layers differently when attention spans longer sequences. Always benchmark at the sequence lengths you will actually use in production before declaring a quantized model production-ready. :::
:::tip Use BF16 Instead of FP16 for Compute Dtype
When using bitsandbytes 4-bit quantization, set bnb_4bit_compute_dtype=torch.bfloat16 instead of float16. BF16 has wider dynamic range (same exponent bits as FP32) and is less prone to overflow/underflow during the dequantized computations. On modern NVIDIA GPUs (Ampere and later), BF16 arithmetic is just as fast as FP16. Using FP16 compute dtype with NF4 weights can cause subtle numerical instability during generation of long sequences.
:::
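The dynamic-range difference behind this tip is easy to see without a GPU. The sketch below uses only the standard library: it emulates FP16 with the struct module's half-precision "e" format and emulates BF16 by keeping the top 16 bits of an FP32 value with round-to-nearest-even. The helper names to_fp16 and to_bf16 are hypothetical, inputs are assumed to fit in FP32, and NaN handling is ignored.

```python
import math
import struct

def to_fp16(x):
    """Round x to IEEE FP16; values beyond FP16's range (max 65504) become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, x)

def to_bf16(x):
    """Round x to BF16: FP32 sign and 8 exponent bits, top 7 mantissa bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 low bits being dropped
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in (70000.0, 1e-8, 0.1):
    print(f"{value}: FP16 -> {to_fp16(value)}, BF16 -> {to_bf16(value)}")
```

Here 70000.0 overflows to inf in FP16 but stays finite in BF16, and 1e-8 underflows to 0 in FP16 while BF16 keeps it. The price is precision: BF16 keeps only 7 explicit mantissa bits, so 0.1 is stored less accurately than in FP16.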
:::info When INT8 Is Better Than INT4 Despite INT4 being more compressed, INT8 is sometimes the better production choice: (1) When accuracy is non-negotiable - medical, legal, financial applications. (2) When your GPU has native INT8 Tensor Core support but not optimized INT4 kernels (older Ampere GPUs). (3) When you are doing batch inference at large batch sizes - INT4 dequantization overhead can make it slower than INT8 at high batch sizes. (4) When you are combining quantization with other techniques and need a better accuracy starting point. Measure both before committing. :::
Interview Questions
Q1: Explain the outlier problem in LLM quantization. How does LLM.int8() address it?
Outliers in LLM activations are feature dimensions where the input values are systematically much larger than the rest, often 10-100x larger. These outliers emerge predictably at approximately 6.7B parameters and increase in prevalence and magnitude with scale. The problem: absmax quantization computes a single scale factor from the maximum absolute value, which is dominated by outliers. For a tensor where 99.5% of values are in [-0.1, 0.1] but the max is 15, the scale maps 15 → 127, so 0.1 → roughly 0.85 → quantized to 1. Every value with magnitude below about 0.059 (0.5 × 15/127) rounds to 0, and the entire [-0.1, 0.1] range collapses onto just three levels (-1, 0, 1), destroying most of its information. LLM.int8() addresses this with a mixed-precision decomposition: at inference time it flags every input feature dimension that contains an activation whose magnitude exceeds a threshold (6.0 by default, llm_int8_threshold in bitsandbytes), runs those dimensions through a separate FP16 matrix multiply, and quantizes all other dimensions to INT8 with vector-wise scales. The FP16 result for the outlier portion (a small fraction of columns) is added to the dequantized INT8 result for the normal portion. The accuracy cost is near-zero because the outlier dimensions keep full FP16 precision.
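A minimal NumPy sketch of both halves of this answer, assuming vector-wise absmax scales and the default 6.0 threshold. The first part reproduces the worked example (one value of 15 pushes 0.1 down to code 1 and 0.05 to code 0); the second compares a plain INT8 matmul with an LLM.int8()-style decomposition on a matrix with one injected outlier feature column. The function names are hypothetical, and this illustrates the idea rather than the bitsandbytes kernel.

```python
import numpy as np

def absmax_int8(x, axis=None):
    """Symmetric absmax INT8 quantization; returns (int8 codes, scale)."""
    scale = 127.0 / np.max(np.abs(x), axis=axis, keepdims=True)
    codes = np.clip(np.round(x * scale), -127, 127).astype(np.int8)
    return codes, scale

# Worked example: the single value 15 sets the scale for the whole tensor
codes, _ = absmax_int8(np.array([0.1, -0.1, 0.05, 0.03, 15.0]))
print(codes)  # 0.1 -> 1, 0.05 -> 0, 15 -> 127

def int8_matmul(X, W):
    """Vector-wise INT8 matmul: one scale per row of X and per column of W."""
    qx, sx = absmax_int8(X, axis=1)  # sx has shape (m, 1)
    qw, sw = absmax_int8(W, axis=0)  # sw has shape (1, n)
    acc = qx.astype(np.int32) @ qw.astype(np.int32)  # integer accumulation
    return acc / (sx * sw)  # undo both scales

def mixed_precision_matmul(X, W, threshold=6.0):
    """LLM.int8()-style split: outlier feature columns in FP16, the rest in INT8."""
    outlier = np.any(np.abs(X) > threshold, axis=0)  # flag whole feature dimensions
    regular = int8_matmul(X[:, ~outlier], W[~outlier, :])
    high = X[:, outlier].astype(np.float16) @ W[outlier, :].astype(np.float16)
    return regular + high.astype(np.float64)

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 64))
X[:, 3] *= 40.0  # hypothetical systematic outlier feature dimension
W = rng.normal(scale=0.05, size=(64, 16))
ref = X @ W
err_int8 = np.abs(int8_matmul(X, W) - ref).mean()
err_mixed = np.abs(mixed_precision_matmul(X, W) - ref).mean()
print(f"mean abs error: plain INT8 {err_int8:.4f}, mixed precision {err_mixed:.4f}")
```

On this toy input the decomposed version should show a clearly smaller error, because moving column 3 out of the INT8 path lets every row of X use a fine-grained INT8 scale.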
Q2: What makes NF4 better than standard INT4 for neural network weights?
Standard INT4 places its 16 quantization levels at evenly spaced intervals across the weight range. Neural network weights follow a near-normal distribution: most values cluster near zero, with exponentially fewer values toward the extremes. Uniform spacing wastes quantization levels on the sparse tails and leaves too few levels in the dense center. NF4 (NormalFloat 4-bit) addresses this by placing its 16 levels at quantiles of a standard normal distribution, so each bin covers roughly the same probability mass of a normally distributed weight; the construction also reserves an exact level for zero. For typical LLM weight distributions, this reduces quantization MSE by 20-40% compared to uniform INT4 at the same bit-width. The implementation: normalize each block of weights to [-1, 1] by dividing by its absolute maximum, find the nearest of the 16 NF4 levels, and store the 4-bit index. Dequantize by looking up the level and multiplying by the stored block scale.
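The recipe in this answer can be sketched in pure Python. The level construction follows the QLoRA-style approach as we understand it (8 positive and 7 negative normal quantiles with offset 0.9677083, plus an exact zero, rescaled to [-1, 1]); treat it as an illustration, not the bitsandbytes source. The uniform baseline is the common symmetric INT4 variant with integer levels -7..7, the block size of 64 and Gaussian weights with std 0.02 are assumptions, and all function names are hypothetical.

```python
import bisect
import random
from statistics import NormalDist

def nf4_levels(offset=0.9677083):
    """NF4-style code book: 8 positive and 7 negative normal quantiles plus 0."""
    ppf = NormalDist().inv_cdf
    pos = [ppf(offset - i * (offset - 0.5) / 8) for i in range(8)]
    neg = [-ppf(offset - i * (offset - 0.5) / 7) for i in range(7)]
    levels = sorted(pos + neg + [0.0])
    return [v / levels[-1] for v in levels]  # rescale so the extremes are -1 and 1

def nearest_index(levels, x):
    """Index of the level closest to x (levels must be sorted)."""
    i = bisect.bisect_left(levels, x)
    if i == 0:
        return 0
    if i == len(levels):
        return len(levels) - 1
    return i if levels[i] - x < x - levels[i - 1] else i - 1

def blockwise_roundtrip(weights, levels, block_size=64):
    """Quantize per block (absmax scale + nearest-level index), then dequantize."""
    out = []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        scale = max(abs(w) for w in block) or 1.0  # one scale per block
        codes = [nearest_index(levels, w / scale) for w in block]  # 4-bit indices
        out.extend(levels[c] * scale for c in codes)
    return out

def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(64 * 200)]
nf4 = nf4_levels()
int4 = [k / 7 for k in range(-7, 8)]  # symmetric absmax INT4, evenly spaced
nf4_mse = mse(weights, blockwise_roundtrip(weights, nf4))
int4_mse = mse(weights, blockwise_roundtrip(weights, int4))
print(f"NF4 MSE {nf4_mse:.3e} vs uniform INT4 MSE {int4_mse:.3e}")
```

The exact percentage gap depends on block size and on how close your weights are to normal, so measure on real weights before quoting a number.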
Q3: What is double quantization and when does it provide meaningful savings?
Double quantization (from QLoRA) applies a second quantization pass to the scale constants of the first quantization. In block-wise 4-bit quantization with block_size=64, each block of 64 weights has one FP32 scale constant. For a 70B model, this produces 70B / 64 ≈ 1.09 billion scale constants × 4 bytes ≈ 4.4 GB just for scales. Double quantization groups 256 of these scales together and quantizes them to 8 bits with one FP32 meta-scale per group, reducing each scale from 4 bytes (FP32) to about 1 byte plus a small meta-scale overhead (4 bytes per 256 scales). For a 70B model, this saves roughly 3 GB, about 0.37 bits per parameter. The accuracy impact is negligible because scale constants follow a smooth, well-behaved distribution that 8-bit quantization handles well. Enable it with bnb_4bit_use_double_quant=True in bitsandbytes; there is essentially no reason not to use it.
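The arithmetic in this answer as a small script, assuming block size 64 with FP32 first-level scales and groups of 256 scales stored in 8 bits with one FP32 meta-scale each (the settings described above). The helpers scale_storage_gb and extra_bits_per_param are hypothetical names.

```python
def scale_storage_gb(n_params, block_size=64, double_quant=False, dq_group=256):
    """Bytes spent on block-wise quantization scales, in GB (1 GB = 1e9 bytes)."""
    n_scales = n_params / block_size  # one scale per block of weights
    if not double_quant:
        return n_scales * 4 / 1e9  # FP32 scale = 4 bytes
    # 8-bit scales (1 byte each) plus one FP32 meta-scale per dq_group scales
    return (n_scales * 1 + n_scales / dq_group * 4) / 1e9

def extra_bits_per_param(block_size=64, double_quant=False, dq_group=256):
    """Scale overhead expressed in bits per quantized weight."""
    if not double_quant:
        return 32 / block_size
    return 8 / block_size + 32 / (block_size * dq_group)

plain = scale_storage_gb(70e9)
dq = scale_storage_gb(70e9, double_quant=True)
print(f"70B scales: {plain:.2f} GB -> {dq:.2f} GB (saves {plain - dq:.2f} GB)")
print(f"bits/param: {extra_bits_per_param():.3f} -> "
      f"{extra_bits_per_param(double_quant=True):.3f}")
```

The per-parameter view (0.5 bits of scale overhead per weight dropping to about 0.127) is often easier to reason about than absolute gigabytes, since it holds for any model size.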
Q4: When would you choose INT8 over INT4 for production inference?
Choose INT8 when: (1) Your task is accuracy-sensitive and tolerates at most 1% degradation - INT8 with LLM.int8() typically loses 0.3-0.8% on standard benchmarks versus FP16; INT4 loses 1-3%. For medical triage, legal document analysis, or financial calculations, that difference matters. (2) Your GPU has native INT8 Tensor Core support but your INT4 kernel support is limited. (3) Your serving batch size is large (greater than 16) and you are compute-bound - INT4 dequantization overhead can exceed its bandwidth savings at large batches, making INT8 faster in throughput. (4) You need to support fine-tuning after deployment - INT8 can be fine-tuned more easily than INT4. Choose INT4 when: the model simply does not fit at INT8, you are serving single-user requests (bandwidth-bound), or you have measured that the accuracy difference is acceptable for your task.
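Point (3) and the bandwidth-bound single-user case can be made concrete with a crude roofline estimate. This is a sketch under loud assumptions: every weight is read once per decoding step, each weight costs about 2 FLOPs per sequence in the batch, KV-cache traffic and kernel overheads are ignored, and the hardware figures are illustrative A100-class numbers rather than measurements. It also assumes the INT8 path runs on INT8 tensor cores while the weight-only INT4 path dequantizes to FP16 and runs FP16 math. decode_step_ms is a hypothetical helper.

```python
def decode_step_ms(n_params, bits_per_weight, batch_size,
                   bandwidth_gb_s, math_tflops):
    """Crude roofline time for one decoding step (hypothetical helper).

    Takes the slower of two bounds: streaming all weights from memory once,
    or doing ~2 FLOPs per weight for every sequence in the batch.
    """
    memory_ms = n_params * bits_per_weight / 8 / (bandwidth_gb_s * 1e9) * 1e3
    compute_ms = 2 * n_params * batch_size / (math_tflops * 1e12) * 1e3
    return max(memory_ms, compute_ms)

# Illustrative, roughly A100-class figures (assumptions, not measurements):
# INT8 math on INT8 tensor cores (~624 TOPS); weight-only INT4 dequantizes
# to FP16 and runs FP16 math (~312 TFLOPS); ~2000 GB/s memory bandwidth.
BANDWIDTH_GB_S = 2000
for batch in (1, 256):
    int8 = decode_step_ms(7e9, 8, batch, BANDWIDTH_GB_S, 624)
    int4 = decode_step_ms(7e9, 4, batch, BANDWIDTH_GB_S, 312)
    print(f"batch {batch:>3}: INT8 {int8:.2f} ms/step, INT4 {int4:.2f} ms/step")
```

With these inputs INT4 wins at batch 1, where only weight bytes matter, and INT8 wins at batch 256, where the step is compute-bound and INT8 math is assumed to run at twice the FP16 rate. The crossover batch size depends entirely on the kernels and GPU you plug in, which is why measuring both remains the safer call.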
Q5: How would you diagnose and fix accuracy degradation after quantizing a model?
Systematic diagnosis: (1) Identify which capabilities are degraded. Benchmark on your specific task - arithmetic, code generation, reasoning, factual recall. Do not stop at MMLU. (2) Run layer-by-layer sensitivity analysis. Quantize one layer at a time and measure the accuracy impact. This identifies which layers are most sensitive - typically attention projections in the early and late layers. (3) Use mixed precision for sensitive layers. Keep the two or three most sensitive layers in FP16 or INT8 using llm_int8_skip_modules in bitsandbytes or the equivalent in GPTQ/AWQ configurations. (4) Check your calibration data. For GPTQ and AWQ, the calibration dataset directly affects which weights are considered important. Use domain-matched calibration data. (5) Try a different algorithm. If bitsandbytes NF4 degrades accuracy by 4%, GPTQ or AWQ might achieve 2% degradation through better error compensation. (6) Consider INT8 instead of INT4 for the affected layers, or for the entire model if the accuracy requirement is strict.
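Step (2) can be prototyped before touching a real model. The sketch below uses a toy stack of bias-free ReLU linear layers as a stand-in for transformer blocks, fake-quantizes one layer at a time with per-tensor absmax INT4, and scores each layer by the relative error of the final output. A hypothetical outlier weight is planted in layer 1 so the ranking has something to find. With a real model you would replace the relative-error proxy with your task benchmark and the layer list with the model's modules; every name here is hypothetical.

```python
import numpy as np

def fake_quant_int4(w):
    """Per-tensor symmetric absmax INT4 round trip (integer levels -7..7)."""
    scale = np.abs(w).max() / 7.0
    return np.clip(np.round(w / scale), -7, 7) * scale

def forward(x, layers):
    """Toy stand-in for a transformer: bias-free linear layers with ReLU."""
    for w in layers:
        x = np.maximum(x @ w.T, 0.0)
    return x

def layer_sensitivity(layers, x):
    """Relative output error when only layer i is quantized, for each i."""
    ref = forward(x, layers)
    scores = []
    for i in range(len(layers)):
        trial = list(layers)
        trial[i] = fake_quant_int4(layers[i])  # quantize this layer only
        err = np.linalg.norm(forward(x, trial) - ref) / np.linalg.norm(ref)
        scores.append(float(err))
    return scores

def layers_to_skip(scores, k):
    """Indices of the k most sensitive layers, to keep in FP16/BF16."""
    return sorted(int(i) for i in np.argsort(scores)[::-1][:k])

rng = np.random.default_rng(0)
dim, n_layers = 64, 6
std = np.sqrt(2.0 / dim)
layers = [rng.normal(0.0, std, size=(dim, dim)) for _ in range(n_layers)]
layers[1][5, 7] = 50.0 * std  # hypothetical outlier weight in layer 1
x = rng.normal(size=(32, dim))

scores = layer_sensitivity(layers, x)
print([round(s, 3) for s in scores], "-> keep in high precision:",
      layers_to_skip(scores, k=2))
```

The outlier forces a coarse per-tensor scale that rounds almost every other weight in layer 1 to zero, which is exactly the kind of layer that step (3) would exclude from quantization.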
