Floating-Point Arithmetic - Precision, Overflow, and Mixed Precision Training
Reading time: ~26 minutes | Level: Numerical Foundations → Production ML
It is 3 AM. Your 7B parameter model has been training for 18 hours on 8 A100s. Then: loss: nan. Training is dead. You restart. Same thing at epoch 3. Your team is burning $40/hour in cloud compute.
The culprit: a single log(0) in your loss function, triggered by a probability that numerically underflowed to 0.0 in float16. Works perfectly in float32 on your laptop.
This is the real cost of not understanding floating-point arithmetic. It is not a theoretical concern. Every production ML engineer encounters it.
What You Will Learn
- The IEEE 754 standard: what a float actually is in memory
- Machine epsilon: the fundamental precision limit of floating-point
- float16 vs float32 vs bfloat16: the trade-offs that define mixed precision training
- Catastrophic cancellation: when subtraction destroys all your precision
- Log-sum-exp trick and other numerically stable reformulations
- Gradient scaling in mixed precision training
- How to detect and debug numerical instability in PyTorch
Part 1 - What a Floating-Point Number Actually Is
The IEEE 754 representation
Every float32 number is stored in 32 bits split into three fields:
float32 (32 bits):
┌───┬──────────┬───────────────────────────┐
│ s │ exp │ mantissa │
│ 1 │ 8 │ 23 │
└───┴──────────┴───────────────────────────┘
Value = (-1)^s × 2^(exp - 127) × (1 + mantissa/2^23)
- Sign (1 bit): 0 = positive, 1 = negative
- Exponent (8 bits): biased by 127, so stored exponent 127 → actual exponent 0
- Mantissa (23 bits): the fractional part, with an implicit leading 1
import struct
import numpy as np
def float32_bits(x: float) -> str:
    """Show the IEEE 754 binary representation of a float32."""
    bits = struct.pack('>f', x)
    n = int.from_bytes(bits, 'big')
    binary = format(n, '032b')
    sign = binary[0]
    exponent = binary[1:9]
    mantissa = binary[9:]
    exp_val = int(exponent, 2) - 127
    return f"sign={sign} | exp={exponent} ({exp_val}) | mantissa={mantissa}"
print(float32_bits(1.0))
# sign=0 | exp=01111111 (0) | mantissa=00000000000000000000000
# Value: (-1)^0 × 2^0 × 1.0 = 1.0
print(float32_bits(0.1))
# sign=0 | exp=01111011 (-4) | mantissa=10011001100110011001101
# 0.1 cannot be represented exactly in binary!
# Demonstrate that 0.1 + 0.2 ≠ 0.3 in floating point
print(0.1 + 0.2) # 0.30000000000000004
print(0.1 + 0.2 == 0.3) # False
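The value formula can be checked in the other direction: pull the three fields out of the bit pattern and rebuild the number from them. The helper below is a minimal sketch; the name `decode_float32` is ours (not a library function), and it assumes a normal number, so zeros, subnormals, inf, and NaN are rejected.

```python
import struct

def decode_float32(x: float) -> float:
    """Rebuild a float32 value from its sign, exponent, and mantissa fields."""
    n = struct.unpack('>I', struct.pack('>f', x))[0]  # raw 32-bit pattern
    sign = n >> 31
    exponent = (n >> 23) & 0xFF   # 8 stored exponent bits (biased by 127)
    mantissa = n & 0x7FFFFF       # 23 stored fraction bits
    if exponent in (0, 255):
        raise ValueError("only normal numbers are handled in this sketch")
    # Value = (-1)^s × 2^(exp - 127) × (1 + mantissa/2^23)
    return (-1) ** sign * 2.0 ** (exponent - 127) * (1 + mantissa / 2 ** 23)

print(decode_float32(1.0))   # 1.0
print(decode_float32(-6.5))  # -6.5
print(decode_float32(0.1))   # the float32 nearest to 0.1, slightly above 0.1
```

The last line prints a value that differs from 0.1 after the eighth significant digit, which is the rounding error introduced by storing 0.1 in 23 mantissa bits.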
The key insight: not all real numbers are representable
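Between any two consecutive float32 values there is a gap. The gap size scales with the magnitude of the number:
$$\text{gap}(x) \approx \epsilon_{\text{mach}} \cdot |x|, \qquad \epsilon_{\text{mach}} = 2^{-23} \approx 1.19 \times 10^{-7}$$
This means:
- Near 1.0: gap ≈ 1.19×10⁻⁷ (about 7 decimal digits of precision)
- Near 1000.0: gap ≈ 6.1×10⁻⁵ (only about 4 decimal places near 1000)
- Near 0.001: gap ≈ 1.16×10⁻¹⁰ (finer absolute resolution near zero)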
# Gap between consecutive float32 values
print(np.spacing(np.float32(1.0))) # ~1.19e-7
print(np.spacing(np.float32(1000.0))) # ~6.10e-5
print(np.spacing(np.float32(1e7))) # 1.0 - above 2^24 (~1.68e7), not every integer
# can be represented exactly!
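The gap rule can be made exact for normal float32 values: within the binade $[2^e, 2^{e+1})$ every gap equals $2^{e-23}$. The sketch below compares that formula with `np.spacing`; `float32_gap` is a hypothetical helper written for this illustration, and it ignores subnormals and values that sit exactly on a rounding boundary.

```python
import math
import numpy as np

def float32_gap(x: float) -> float:
    """Spacing of float32 values around x, assuming x is in the normal range."""
    exponent = math.floor(math.log2(abs(x)))  # x lies in [2**exponent, 2**(exponent + 1))
    return 2.0 ** (exponent - 23)             # 23 mantissa bits subdivide each binade

for x in [0.001, 1.0, 1000.0, 1e7]:
    print(x, float32_gap(x), float(np.spacing(np.float32(x))))

# 2**24 + 1 is the first positive integer that float32 cannot hold
print(np.float32(16777217.0) == np.float32(16777216.0))  # True
```

Both columns agree for each test value, which is why relative precision stays roughly constant while absolute precision degrades as numbers grow.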
Part 2 - Machine Epsilon: The Fundamental Precision Limit
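Machine epsilon ($\epsilon_{\text{mach}}$) is the smallest positive power of two such that:
$$\mathrm{fl}(1 + \epsilon_{\text{mach}}) > 1$$
in floating-point arithmetic. Equivalently, it is the gap between 1.0 and the next representable number.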
import numpy as np
# Machine epsilon for different dtypes
for dtype in [np.float16, np.float32, np.float64]:
    info = np.finfo(dtype)
    print(f"{dtype.__name__}:")
    print(f" epsilon = {info.eps:.3e} # smallest ε: 1+ε ≠ 1")
    print(f" tiny = {info.tiny:.3e} # smallest positive normal")
    print(f" max = {info.max:.3e} # largest finite value")
    print()
# float16: epsilon=9.77e-04, tiny=6.10e-05, max=6.55e+04
# float32: epsilon=1.19e-07, tiny=1.18e-38, max=3.40e+38
# float64: epsilon=2.22e-16, tiny=2.22e-308, max=1.80e+308
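Machine epsilon can also be found empirically rather than looked up. The sketch below halves a candidate until adding half of it to 1.0 no longer changes the result; `machine_epsilon` is an illustrative helper name, and the loop assumes the default round-to-nearest-even mode that NumPy scalars use.

```python
import numpy as np

def machine_epsilon(dtype) -> float:
    """Find the gap between 1.0 and the next float by repeated halving."""
    one, two = dtype(1.0), dtype(2.0)
    eps = dtype(1.0)
    # Keep halving while adding half of eps still moves the sum past 1.0
    while one + eps / two > one:
        eps = eps / two
    return float(eps)

for dtype in [np.float16, np.float32, np.float64]:
    print(dtype.__name__, machine_epsilon(dtype), float(np.finfo(dtype).eps))
```

The loop stops because `1 + eps/2` lands exactly halfway between two floats and is rounded back down to 1.0, so the value it returns matches `np.finfo(dtype).eps`.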
ML implication: learning rate floors
Machine epsilon defines the practical precision floor. A learning rate smaller than roughly:
$$\eta_{\min} \approx \epsilon_{\text{mach}} \cdot \frac{|w|}{|g|}$$
produces no actual weight update - the addition rounds away to zero. In float32 with weights and gradients of magnitude ~1.0, learning rates below about $10^{-7}$ have no effect.
:::warning float16 learning rate precision
In float16, machine epsilon is ~$10^{-3}$. A learning rate of $10^{-4}$ may produce weight updates so small they round to zero. This is one reason mixed precision training keeps a master copy of weights in float32.
:::
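To make the rounding-away effect concrete, here is a small simulation of the same update applied to a float16 weight and to a float32 master copy. The learning rate, gradient, and step count are illustrative values chosen for the demonstration, not numbers from a real training run.

```python
import numpy as np

lr, grad, steps = 1e-4, 1.0, 1000  # illustrative values only

w_half = np.float16(1.0)    # weight stored only in float16
w_master = np.float32(1.0)  # float32 master copy of the same weight

for _ in range(steps):
    # Each update is 1e-4, far below float16's resolution near 1.0
    w_half = np.float16(w_half - np.float16(lr * grad))
    w_master = np.float32(w_master - np.float32(lr * grad))

print(w_half)    # 1.0 - every single update was rounded away
print(w_master)  # close to 0.9 - the updates accumulated
```

After 1000 steps the float16 weight has not moved at all, while the float32 copy has travelled the expected distance of about 0.1.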
Part 3 - The Three Floating-Point Formats in Deep Learning
Comparison table
| Property | float16 | bfloat16 | float32 | float64 |
|---|---|---|---|---|
| Total bits | 16 | 16 | 32 | 64 |
| Exponent bits | 5 | 8 | 8 | 11 |
| Mantissa bits | 10 | 7 | 23 | 52 |
| Max value | 65,504 | ~3.4×10³⁸ | ~3.4×10³⁸ | ~1.8×10³⁰⁸ |
| Machine epsilon | 9.77×10⁻⁴ | 7.81×10⁻³ | 1.19×10⁻⁷ | 2.22×10⁻¹⁶ |
| GPU throughput (rough, hardware-dependent) | ≥2× float32 | ≥2× float32 | Baseline | ≤0.5× float32 |
float16 - high throughput, dangerously narrow range
float16 has only 5 exponent bits, giving a maximum value of 65,504. This is easily exceeded during training:
import numpy as np
# float16 overflow is silent - produces inf (at most a RuntimeWarning), not an error
x = np.float16(65504.0) # Maximum representable float16
print(x + np.float16(1.0)) # 65504.0 - the +1 is lost: floats near 65504 are 32 apart
print(x * np.float16(2.0)) # inf - overflow to infinity
# Logits that are fine in float32 but overflow in float16
logits = np.array([12.0, -2.0, 5.0], dtype=np.float32)
logits_f16 = logits.astype(np.float16)
# e^12 ≈ 162754, which overflows float16's max of 65504
print(np.exp(logits_f16)) # [inf, 0.1353, 148.4] - inf contaminates softmax
print(np.exp(logits)) # [162754.79, 0.135, 148.41] - correct in float32
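The overflow point of `exp` follows directly from each format's largest finite value: `exp(x)` overflows once `x` exceeds `ln(max)`. The short sketch below computes that threshold per dtype (the `thresholds` dictionary is just a local name for this example).

```python
import numpy as np

thresholds = {}
for dtype in [np.float16, np.float32, np.float64]:
    largest = float(np.finfo(dtype).max)
    thresholds[dtype.__name__] = float(np.log(largest))  # exp(x) overflows above this x

for name, limit in thresholds.items():
    print(f"{name}: exp(x) overflows for x > {limit:.2f}")
# float16: exp(x) overflows for x > 11.09
# float32: exp(x) overflows for x > 88.72
# float64: exp(x) overflows for x > 709.78
```

A logit of 12 is unremarkable in float32 but already past the float16 limit, which is why the example above produces `inf`.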
bfloat16 - the ML-optimized format
bfloat16 was designed by Google Brain specifically for deep learning. It has the same 8-bit exponent as float32 (same dynamic range: roughly $10^{-38}$ to $3.4 \times 10^{38}$), but only 7 mantissa bits versus float32's 23.
Why this is the right trade-off for ML:
- Neural network weights and gradients rarely need more than 2–3 decimal digits of precision - the stochastic noise of SGD dwarfs float32's precision advantage
- But activations can span a wide dynamic range - bfloat16's float32-equivalent range prevents overflow
- Result: far safer than float16 for training, with identical GPU throughput
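The range-versus-precision trade-off can be seen without a GPU by emulating bfloat16 in NumPy: a bfloat16 is the top 16 bits of a float32, so rounding away the low 16 bits gives the same value. This is an illustrative emulation, not how any framework implements the type; `round_to_bfloat16` is a hypothetical helper and it handles finite inputs only.

```python
import numpy as np

def round_to_bfloat16(values) -> np.ndarray:
    """Round float32 values to the nearest bfloat16, returned as float32.

    Illustrative emulation for finite inputs only (inf/NaN are not handled).
    """
    bits = np.atleast_1d(np.asarray(values, dtype=np.float32)).view(np.uint32)
    # Round-to-nearest-even on the 16 low bits that bfloat16 drops
    bias = ((bits >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)
    rounded = (bits + bias) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

x = np.exp(np.array([12.0, -2.0, 5.0], dtype=np.float32))
print(round_to_bfloat16(x))      # about [162816, 0.1357, 148]: finite, ~3 significant digits
with np.errstate(over='ignore'):
    print(x.astype(np.float16))  # exp(12) becomes inf in float16
```

The bfloat16 values are coarse but finite, while float16 keeps more digits for the small values and loses the large one entirely.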
import torch
# bfloat16 in PyTorch - natively supported on A100, H100, TPUs
x = torch.tensor([12.0, -2.0, 5.0], dtype=torch.bfloat16)
print(torch.exp(x)) # ≈ [162816, 0.1357, 148] in bfloat16 - finite, but only ~3 significant digits
# No overflow - bfloat16 has same range as float32
# Mixed precision training pattern with bfloat16 (no GradScaler needed)
# MyModel, dataloader, and criterion are placeholders defined elsewhere
model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(batch['input_ids'])
        loss = criterion(output, batch['labels'])
    loss.backward()
    optimizer.step()
When to use each format
| Hardware | Recommended dtype | Notes |
|---|---|---|
| A100, H100, TPU | bfloat16 | Wide range, no GradScaler needed |
| V100, T4, RTX 20xx | float16 + GradScaler | No native bfloat16; tensor cores are efficient but need loss scaling |
| CPU inference | float32 | bfloat16 may not be optimized |
| Optimizer state (Adam) | float32 | Always - need precision for small updates |
| Loss computation | float32 | Reductions need full precision |
Part 4 - Catastrophic Cancellation
Catastrophic cancellation occurs when you subtract two nearly equal floating-point numbers. The result has far fewer significant digits than the inputs.
Example: computing variance naively
The textbook identity for variance is:
$$\mathrm{Var}(x) = \mathbb{E}[x^2] - (\mathbb{E}[x])^2$$
Algebraically correct, but numerically catastrophic for data with a large mean:
import numpy as np
# Data with large mean (~1e6) and true variance of about 1.0 (exactly 1.04)
data = np.array([1000000.0, 1000001.0, 1000002.0, 999999.0, 1000000.0],
                dtype=np.float32)
# Numerically UNSTABLE - catastrophic cancellation
mean_sq = np.mean(data**2) # ~1.0e12
sq_mean = np.mean(data)**2 # ~1.0e12 - nearly equal to mean_sq
variance_unstable = mean_sq - sq_mean
print(f"Unstable variance: {variance_unstable}") # Garbage: 0.0 or off by tens of thousands (float32 spacing near 1e12 is 65536)
# Numerically STABLE - compute deviations from mean first
mean = np.mean(data)
variance_stable = np.mean((data - mean)**2)
print(f"Stable variance: {variance_stable}") # ≈ 1.0 ✓
# NumPy's built-in var() uses the stable algorithm
print(f"NumPy var: {np.var(data)}") # ≈ 1.0 ✓
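What happened: Both $\mathbb{E}[x^2]$ and $(\mathbb{E}[x])^2$ are approximately $10^{12}$. Their difference should be about 1.0, but float32 resolves values near $10^{12}$ only to the nearest 65,536, so the subtraction cancels all significant digits, leaving only rounding noise.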
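When the data arrives as a stream and a second pass is not possible, a different technique, Welford's online algorithm, avoids the same cancellation by updating a running mean and a running sum of squared deviations. The sketch below is a plain-Python illustration that deliberately keeps every intermediate in float32; `welford_variance` is our own helper name, and it returns the population variance.

```python
import numpy as np

def welford_variance(values, dtype=np.float32):
    """One-pass population variance (Welford's algorithm), computed in `dtype`."""
    mean = dtype(0.0)
    m2 = dtype(0.0)  # running sum of squared deviations from the current mean
    count = 0
    for count, x in enumerate(values, start=1):
        x = dtype(x)
        delta = x - mean                 # small difference, no large squares involved
        mean = mean + delta / dtype(count)
        m2 = m2 + delta * (x - mean)
    return m2 / dtype(count)

data = np.array([1000000.0, 1000001.0, 1000002.0, 999999.0, 1000000.0],
                dtype=np.float32)
print(welford_variance(data))  # about 1.04, even though everything stays in float32
```

The algorithm only ever subtracts a value from a nearby mean, so it never forms the two huge, nearly equal quantities that the naive formula subtracts.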
Catastrophic cancellation in deep learning: the log-softmax problem
The most common numerical stability problem in ML is computing log(softmax(x)):
$$\log \mathrm{softmax}(x)_i = \log \frac{e^{x_i}}{\sum_j e^{x_j}}$$
Naive implementation (DO NOT use):
def naive_log_softmax(x):
    """Numerically UNSTABLE."""
    softmax = np.exp(x) / np.sum(np.exp(x)) # overflow risk
    return np.log(softmax) # underflow: log(0) = -inf
x = np.array([1000.0, 1001.0, 1002.0])
print(naive_log_softmax(x)) # [nan, nan, nan] - overflow at exp(1000)
The log-sum-exp trick
Subtract the maximum before exponentiating - algebraically equivalent but numerically stable:
def stable_log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax."""
    c = np.max(x) # subtract max
    shifted = x - c # all values ≤ 0, no overflow
    log_sum_exp = np.log(np.sum(np.exp(shifted))) + c # add c back
    return x - log_sum_exp
x = np.array([1000.0, 1001.0, 1002.0])
result = stable_log_softmax(x)
print(result) # [-2.4076, -1.4076, -0.4076] - correct!
print(np.sum(np.exp(result))) # ≈ 1.0 ✓ probabilities sum to 1
# PyTorch uses this trick automatically
import torch
x_torch = torch.tensor([1000.0, 1001.0, 1002.0])
print(torch.log_softmax(x_torch, dim=0)) # identical result
General log-sum-exp
from scipy.special import logsumexp
import numpy as np
x = np.array([1000.0, 1001.0, 1002.0])
print(logsumexp(x)) # 1002.4076 - correct, no overflow
# Use case: stable log-likelihood in Bayesian models
def stable_log_likelihood(logits: np.ndarray, y: int) -> float:
    """log p(y|x) = log_softmax(logits)[y]"""
    return logits[y] - logsumexp(logits)
Part 5 - Overflow and Underflow in Practice
Float16 overflow in transformer attention
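The scaled dot-product attention formula:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V$$
The scaling factor $1/\sqrt{d_k}$ keeps the standard deviation of the logits near 1 regardless of head dimension. It was introduced to stop softmax from saturating, and it also makes float16 overflow in the exponential far less likely: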
import numpy as np
d_k = 64 # typical attention head dimension
Q = np.random.randn(1, d_k).astype(np.float16)
K = np.random.randn(10, d_k).astype(np.float16)
# Without scaling: dot products can easily exceed float16 max
raw_dots = Q @ K.T
print(f"Raw max: {raw_dots.max():.1f}") # Could be 20-30 → exp(30) >> 65504
# With sqrt(d_k) scaling: reduced by factor of 8
scaled_dots = Q @ K.T / np.sqrt(d_k)
print(f"Scaled max: {scaled_dots.max():.1f}") # Much safer for exp()
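The reason for dividing by exactly $\sqrt{d_k}$ is statistical: if the entries of a query and a key are independent with unit variance, their dot product has standard deviation $\sqrt{d_k}$. The sketch below checks this empirically in float64 with random Gaussian vectors, which is an assumption about the inputs rather than a property of any trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
stds = {}
for d_k in [16, 64, 256]:
    q = rng.standard_normal((5000, d_k))
    k = rng.standard_normal((5000, d_k))
    dots = np.sum(q * k, axis=1)  # 5000 independent q·k samples
    stds[d_k] = (float(dots.std()), float((dots / np.sqrt(d_k)).std()))
    print(f"d_k={d_k:3d}  raw std={stds[d_k][0]:6.2f}  scaled std={stds[d_k][1]:.2f}")
```

The raw standard deviation grows like $\sqrt{d_k}$ (about 4, 8, and 16 here) while the scaled one stays close to 1, so the logits fed to `exp` remain in a range that float16 can handle.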
Underflow: probabilities that vanish to zero
When a probability becomes smaller than float16's smallest positive value (about $6 \times 10^{-8}$ for subnormals; the smallest normal value is $6.1 \times 10^{-5}$), it underflows to 0.0. Then log(0.0) = -inf, which propagates as NaN through backpropagation.
import numpy as np
p = np.float16(1e-8) # Below float16's smallest subnormal (~6.0e-8)
print(p) # 0.0 - silently underflowed!
print(np.log(p)) # -inf - loss function catastrophically fails
# Safe pattern: always use log_softmax, never log(softmax)
# PyTorch's nn.CrossEntropyLoss does this automatically:
# F.cross_entropy = F.nll_loss(F.log_softmax(logits, dim=1), targets)
def safe_log(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Numerically safe log - clips near zero to avoid -inf."""
    return np.log(np.maximum(x, eps))
Part 6 - Mixed Precision Training: The Full Picture
The four-step pattern with GradScaler (float16)
import torch
from torch.cuda.amp import GradScaler  # newer PyTorch releases also expose torch.amp.GradScaler
# TransformerModel, dataloader, and criterion are placeholders defined elsewhere
model = TransformerModel(d_model=512, n_heads=8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# GradScaler multiplies the loss by a large constant before backward().
# This prevents float16 gradient underflow.
# It divides the gradients back before the optimizer step.
scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad()
    # Step 1: forward pass in float16
    # (torch.autocast takes device_type; the older torch.cuda.amp.autocast does not)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        output = model(batch['input_ids'])
        loss = criterion(output, batch['labels'])
    # Step 2: scale loss to prevent gradient underflow
    scaler.scale(loss).backward()
    # Step 3: unscale gradients, check for inf/nan
    scaler.unscale_(optimizer)
    # Gradient clipping - important for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    # Step 4: update weights in float32 (skips step if grads are inf/nan)
    scaler.step(optimizer)
    scaler.update() # adjusts scale factor for next iteration
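The idea behind loss scaling can be reproduced with three numbers. The sketch below is not GradScaler's implementation, only the arithmetic it relies on: a hypothetical gradient of `1e-8` is stored in float16 with and without scaling, using `2**16`, which is GradScaler's default initial scale.

```python
import numpy as np

true_grad = 1e-8   # hypothetical tiny gradient value
scale = 2.0 ** 16  # GradScaler's default initial scale

unscaled_f16 = np.float16(true_grad)                    # stored directly in float16
scaled_f16 = np.float16(true_grad * scale)              # gradient of the scaled loss
recovered = np.float32(scaled_f16) / np.float32(scale)  # unscale in float32

print(unscaled_f16)  # 0.0 - the gradient underflowed
print(recovered)     # close to 1e-08 - it survived the round trip
```

Because scaling by a power of two only shifts the exponent, the recovered value is off by float16's rounding error alone (well under 1% here), whereas the unscaled gradient is lost completely.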
What autocast decides per operation
torch.autocast maintains a policy - some operations run in reduced precision, others always in float32:
| Run in float16/bfloat16 | Always float32 |
|---|---|
| matmul, linear, conv | loss functions (cross-entropy, MSE) |
| batch matrix multiply | log, exp, pow |
| attention (after scaling) | softmax, layer norm |
| embedding lookup | batch norm statistics |
# Force float32 for a specific operation inside autocast
class StableAttention(torch.nn.Module):
    def forward(self, Q, K, V):
        d_k = Q.shape[-1]
        # matmul runs in float16 - fast
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
        # softmax: force float32 for stability
        scores = scores.float()
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = attn_weights.to(Q.dtype) # cast back for V multiply
        return torch.matmul(attn_weights, V)
Part 7 - Detecting and Debugging Numerical Issues
Gradient monitoring hooks
import torch
def add_gradient_hooks(model: torch.nn.Module) -> None:
    """Monitor gradients for NaN/Inf during training."""
    def make_hook(name):
        def hook(grad):
            if grad is not None:
                if torch.isnan(grad).any():
                    print(f"[NaN gradient] layer: {name}")
                if torch.isinf(grad).any():
                    print(f"[Inf gradient] layer: {name}")
                grad_norm = grad.norm().item()
                if grad_norm > 1000:
                    print(f"[Large gradient norm {grad_norm:.1f}] layer: {name}")
            return grad
        return hook
    for name, param in model.named_parameters():
        if param.requires_grad:  # hooks can only be registered on tensors that require grad
            param.register_hook(make_hook(name))
Pre-training numerical health check
def check_numerical_health(model: torch.nn.Module, sample_batch: dict) -> dict:
    """Inspect a sample batch and the model's parameters for numerical issues before training."""
    issues = {}
    with torch.no_grad():
        # Check inputs
        for key, val in sample_batch.items():
            if isinstance(val, torch.Tensor) and val.is_floating_point():
                if torch.isnan(val).any():
                    issues[f'input_{key}_has_nan'] = True
                abs_max = val.abs().max().item()
                if abs_max > 1e4:
                    issues[f'input_{key}_large_values'] = abs_max
        # Check weight initialization
        for name, param in model.named_parameters():
            norm = param.norm().item()
            if norm > 100:
                issues[f'param_{name}_large_norm'] = norm
            if torch.isnan(param).any():
                issues[f'param_{name}_has_nan'] = True
    return issues
Numerically stable loss patterns
import numpy as np
import torch
import torch.nn.functional as F
def safe_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    x_shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def safe_batch_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """BatchNorm with eps to prevent zero-denominator catastrophe."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps)
# PyTorch's BatchNorm uses eps=1e-5 by default for exactly this reason
# Cross-entropy: ALWAYS use F.cross_entropy, never log(softmax) manually
def correct_cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # PyTorch internally computes: F.nll_loss(F.log_softmax(logits, dim=1), targets)
    # log_softmax uses the log-sum-exp trick - safe even for extreme logits
    return F.cross_entropy(logits, targets)
Interview Questions
Q1: Why does float16 training sometimes diverge even with gradient scaling?
Float16 has three key failure modes:
- Overflow: Max value 65,504. Activations, intermediate results, or gradients exceeding this overflow to inf, which propagates as NaN through all downstream operations.
- Low precision: Machine epsilon ~9.77×10⁻⁴. Weight updates smaller than ~0.001× weight magnitude are rounded to zero - learning effectively stops.
- Gradient underflow: Even with GradScaler scaling the loss, individual layer gradients can span many orders of magnitude. If a gradient at a specific layer is intrinsically tiny, scaling the global loss may not be enough - it still underflows to zero in float16.
Solutions in priority order:
- Switch to bfloat16 (same range as float32 - removes most overflow and underflow problems)
- Use gradient clipping (max_norm=1.0) to reduce gradient variance
- Keep batch norm and layer norm computations in float32
- Reduce sequence length / batch size if activations overflow
- Monitor scaler.get_scale() - if it repeatedly drops to 1.0, you have chronic overflow
Q2: What is catastrophic cancellation and how does the log-sum-exp trick prevent it?
Catastrophic cancellation occurs when two nearly equal floating-point numbers are subtracted. The leading significant digits cancel, leaving only rounding noise in the result.
Example: both $\mathbb{E}[x^2]$ and $(\mathbb{E}[x])^2$ are approximately $10^{12}$ for data with mean $10^6$. Their difference should be 1.0, but float32 can only represent $10^{12}$ with 7 significant digits - so the last few digits are noise, and their difference has zero valid significant digits.
In ML: Computing log(softmax(x)) suffers from:
- exp(large x) overflows to inf in float16
- exp(x_i) / sum(exp(x_j)) can be inf / inf = NaN
- log(tiny probability) underflows to -inf
The log-sum-exp trick reformulates:
$$\log \mathrm{softmax}(x)_i = (x_i - c) - \log \sum_j e^{x_j - c}, \qquad c = \max_j x_j$$
By subtracting the max first, all exponentials are in $(0, 1]$ - no overflow. The sum is at least 1 - no underflow. The final result involves no catastrophic cancellation.
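The difference between the two formulations shows up most clearly when every step is forced to stay in float16. The sketch below is an illustration in NumPy, with both helper functions redefined locally so the block runs on its own; the logits are the same modest values used earlier, not extreme ones.

```python
import numpy as np

def naive_log_softmax(x):
    """log(exp(x) / sum(exp(x))), evaluated literally."""
    return np.log(np.exp(x) / np.sum(np.exp(x)))

def stable_log_softmax(x):
    """(x - c) - log(sum(exp(x - c))) with c = max(x)."""
    shifted = x - np.max(x)
    return shifted - np.log(np.sum(np.exp(shifted)))

logits = np.array([12.0, -2.0, 5.0], dtype=np.float16)

with np.errstate(over='ignore', invalid='ignore', divide='ignore'):
    naive = naive_log_softmax(logits)  # exp(12) overflows, so inf/inf and log(0) appear
stable = stable_log_softmax(logits)

print(naive)   # contains nan and -inf
print(stable)  # all finite
```

Only the order of operations changed; the stable version never asks float16 to hold a number larger than 1 before taking the logarithm.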
Q3: Why is bfloat16 generally safer than float16 for training transformers?
The key difference is the exponent width:
- float16: 5 exponent bits → max value 65,504 → easy overflow in attention, activations, and gradients
- bfloat16: 8 exponent bits (same as float32) → max value ~3.4×10³⁸ → same dynamic range as float32
Practical benefits for transformers:
- No GradScaler needed: The entire purpose of GradScaler is preventing float16's narrow range from causing gradient underflow. bfloat16's range eliminates this need.
- Simpler training code: torch.autocast(device_type='cuda', dtype=torch.bfloat16) with no other changes typically works.
- No attention overflow: Unscaled attention dot products can reach values of 20–50 for long sequences. exp(50) ≈ 5×10²¹ - no problem for bfloat16, catastrophic for float16.
Downside: bfloat16 has 7 mantissa bits vs float16's 10 → less precision per unit of range. But neural network training is inherently stochastic - the SGD noise floor is much larger than the precision difference. In practice, bfloat16 matches float32 training quality on nearly all tasks.
Q4: A transformer's softmax output contains NaN values. Walk through your debugging process.
Step-by-step debugging:
Step 1: Check input logits magnitude
import math
# Before softmax, print max attention logit
max_logit = (Q @ K.T / math.sqrt(d_k)).max().item()
print(f"Max attention logit: {max_logit}")
# If > ~11.09 (= ln 65504) in float16, exp overflows to inf → NaN after normalization
Step 2: Check dtype
Are you in a float16 autocast region? Even exp(12) ≈ 162754 overflows float16's max of 65504.
Step 3: Check for attention mask issues
If attention mask adds large negative values (like -1e9 or -inf) to padding positions, exp(-inf) = 0 and softmax([-inf, -inf, ...]) = 0/0 = NaN.
# Fix: use -1e4 instead of -inf for float16, or cast to float32 before softmax
attention_scores = attention_scores.float() # force float32 for softmax
attention_probs = torch.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.to(original_dtype)
Step 4: Look for all-masked rows
If an entire sequence row is masked out (e.g., padding-only input), the softmax denominator is sum(exp(-inf, -inf, ...)) = 0 → 0/0 = NaN. Add a guard: if all positions are masked, the output should be zero, not NaN.
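One way to write such a guard is sketched below in NumPy. The function name `masked_softmax` and the convention that `mask` is `True` for valid positions are assumptions made for this example; the same logic carries over to PyTorch tensors.

```python
import numpy as np

def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over the last axis; rows with no valid position return all zeros."""
    masked = np.where(mask, scores, -np.inf)
    row_max = np.max(masked, axis=-1, keepdims=True)
    # All-masked rows have max = -inf; replace it to avoid (-inf) - (-inf) = NaN
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    exp = np.exp(masked - row_max)                # masked positions become exactly 0
    denom = np.sum(exp, axis=-1, keepdims=True)
    return exp / np.where(denom > 0, denom, 1.0)  # 0 / 1 = 0 for all-masked rows

scores = np.array([[2.0, 1.0, 0.5],
                   [3.0, 0.0, -1.0]])
mask = np.array([[True, True, False],
                 [False, False, False]])  # second row is padding only
probs = masked_softmax(scores, mask)
print(probs)
```

The first row is an ordinary softmax over its two valid positions, and the padding-only row comes back as zeros instead of NaN, so it contributes nothing when multiplied with V.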
Q5: What is machine epsilon and how does it bound the minimum effective learning rate?
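Machine epsilon ($\epsilon_{\text{mach}}$) is the gap between 1 and the next representable number: the smallest power of two $\epsilon$ such that fl(1 + ε) > 1 in the floating-point system. Up to a factor of 2, it is also the upper bound on relative rounding error for any single floating-point operation.
For float32: $\epsilon_{\text{mach}} = 2^{-23} \approx 1.19 \times 10^{-7}$.
Learning rate implication: Weight update: $w \leftarrow w - \eta \cdot g$
In floating-point: fl(w + (-η·g)) = (w − η·g)·(1 + δ) where $|\delta| \le \epsilon_{\text{mach}}$.
If $|\eta \cdot g| \lesssim \epsilon_{\text{mach}} \cdot |w|$, the update is at or below the rounding error in representing $w$ - it simply rounds away. The weight does not change.
Minimum effective learning rate (for float32, weights of magnitude ~1, gradients of magnitude ~1):
$$\eta_{\min} \approx \epsilon_{\text{mach}} \cdot \frac{|w|}{|g|} \approx 1.19 \times 10^{-7}$$
For float16: $\epsilon_{\text{mach}} \approx 9.77 \times 10^{-4}$, so learning rates below ~$10^{-3}$ (relative to weight magnitude) may have no effect. This is why: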
- Mixed precision keeps master weights in float32
- Very small learning rates (cosine annealing to 1e-6) require float32 weight updates
Quick Reference
| Format | Max Value | Epsilon | Use Case |
|---|---|---|---|
| float16 | 65,504 | 9.77×10⁻⁴ | Inference, V100/T4 training with GradScaler |
| bfloat16 | 3.4×10³⁸ | 7.81×10⁻³ | A100/H100/TPU training - preferred |
| float32 | 3.4×10³⁸ | 1.19×10⁻⁷ | Optimizer state, master weights, loss |
| float64 | 1.8×10³⁰⁸ | 2.22×10⁻¹⁶ | Scientific computing, rarely in ML |
| Problem | Symptom | Fix |
|---|---|---|
| exp overflow | loss: inf or NaN | log-sum-exp trick, temperature scaling |
| log underflow | NaN at first epoch | Use log_softmax not log(softmax) |
| Gradient explosion | Spiky loss → NaN | Gradient clipping, LR warmup |
| Float16 overflow | NaN only in float16 mode | Switch to bfloat16 or use GradScaler |
| Catastrophic cancellation | Variance = 0 for large-mean data | Two-pass stable algorithm |
| All-masked softmax | NaN in padding positions | Guard against all-zero softmax rows |
