Information Geometry
The Scenario That Motivates This Lesson
You are reading the K-FAC paper (Kronecker-Factored Approximate Curvature), one of the most influential papers on neural network optimization. The abstract says:
"We propose K-FAC, an approximation to the natural gradient..."
And later, in the Adam optimizer's analysis:
"Adam is an approximation to the natural gradient when the Fisher information matrix is diagonal..."
You have heard these claims before. But what is the natural gradient? Why is it "natural"? What is the Fisher information matrix? And why does using it lead to better optimization?
This lesson answers all of these questions, building from the geometry of probability distributions to practical second-order optimization algorithms.
Statistical Manifolds: The Geometry of Distributions
A statistical manifold is a smooth manifold where each point represents a probability distribution. The coordinates are the parameters of the distribution family.
For example, the family of Gaussian distributions forms a 2D manifold with coordinates $(\mu, \sigma)$:
Statistical Manifold of Gaussians N(μ, σ²):
  σ (scale)
  │
  │      each point = one Gaussian
4 │    *        *        *        *
  │ [curve]  [curve]  [curve]  [curve]
2 │    *        *        *        *
  │
1 │    *        *        *        *
  │
  └─────────────────────────────────── μ (mean)
      -4        0        4        8
Moving along this manifold changes the distribution.
What is the "distance" between two nearby points?
The central question of information geometry: what is the right notion of distance between nearby probability distributions?
The Fisher Information Matrix
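The Fisher information matrix (FIM) defines a natural Riemannian metric on the statistical manifold. For a model $p(x; \theta)$:
$$F(\theta) = \mathbb{E}_{x \sim p(x;\theta)}\left[\nabla_\theta \log p(x;\theta)\, \nabla_\theta \log p(x;\theta)^\top\right]$$
The vector $\nabla_\theta \log p(x;\theta)$ is called the score function. Because the score has zero mean under the model, the FIM is the covariance matrix of the score function.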
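As a quick sanity check of this definition, the sketch below computes the exact mean and variance of the score for a Bernoulli model by summing over its two outcomes. The helper names `bernoulli_score` and `score_moments` are illustrative and not part of any library.

```python
def bernoulli_score(x: int, p: float) -> float:
    """d/dp log p(x; p) for x in {0, 1}."""
    return x / p - (1 - x) / (1 - p)


def score_moments(p: float) -> tuple[float, float]:
    """Exact mean and variance of the score under Bernoulli(p)."""
    probs = {1: p, 0: 1.0 - p}
    mean = sum(prob * bernoulli_score(x, p) for x, prob in probs.items())
    second = sum(prob * bernoulli_score(x, p) ** 2 for x, prob in probs.items())
    return mean, second - mean ** 2


for p in (0.1, 0.5, 0.9):
    mean, var = score_moments(p)
    print(f"p={p}: E[score]={mean:+.6f}, Var[score]={var:.4f}, "
          f"1/(p(1-p))={1 / (p * (1 - p)):.4f}")
```

The mean comes out as zero (up to rounding) and the variance matches 1/(p(1-p)), the Bernoulli Fisher information.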
Alternative Definition via KL Divergence
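The Fisher information matrix appears as the second-order term in the Taylor expansion of KL divergence:
$$D_{\mathrm{KL}}\big(p(x;\theta)\,\|\,p(x;\theta+\delta\theta)\big) \approx \tfrac{1}{2}\,\delta\theta^\top F(\theta)\,\delta\theta$$
This is the key geometric insight: the FIM defines the local curvature of the KL divergence landscape. The zeroth- and first-order terms vanish because the KL divergence attains its minimum value of zero at $\delta\theta = 0$. Moving a distance $\delta\theta$ in parameter space changes the distribution by approximately $\tfrac{1}{2}\,\delta\theta^\top F(\theta)\,\delta\theta$ in KL terms.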
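The quadratic approximation can be checked numerically. The following is a minimal sketch for a Bernoulli model, where the Fisher information is 1/(p(1-p)); the function names are chosen for illustration only.

```python
import math


def kl_bernoulli(p: float, q: float) -> float:
    """KL(Bernoulli(p) || Bernoulli(q))."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


def quadratic_kl(p: float, delta: float) -> float:
    """Second-order approximation 0.5 * delta^2 * F(p), with F(p) = 1/(p(1-p))."""
    return 0.5 * delta ** 2 / (p * (1 - p))


p = 0.3
for delta in (0.1, 0.01, 0.001):
    exact = kl_bernoulli(p, p + delta)
    approx = quadratic_kl(p, delta)
    print(f"delta={delta:<6} exact={exact:.3e} approx={approx:.3e} "
          f"ratio={exact / approx:.4f}")
```

The ratio of exact to approximate KL moves toward 1 as the step shrinks, which is what a second-order Taylor expansion predicts.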
Fisher Information for Common Distributions
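Bernoulli($p$) (single parameter):
$$F(p) = \frac{1}{p(1-p)}$$
The Fisher information is large when $p$ is near 0 or 1 (the distribution is nearly deterministic - small changes have large effects) and small near $p = 0.5$ (the distribution is maximally uncertain - small changes have small effects).
Gaussian $\mathcal{N}(\mu, \sigma^2)$ (two parameters, $\theta = (\mu, \sigma)$):
$$F(\mu, \sigma) = \begin{pmatrix} 1/\sigma^2 & 0 \\ 0 & 2/\sigma^2 \end{pmatrix}$$
The parameters are orthogonal in Fisher space - the geometry of Gaussians separates mean and scale.
Softmax with $K$ classes (with respect to the logits, where $p$ is the vector of class probabilities):
$$F = \mathrm{diag}(p) - p\,p^\top$$
This is a positive semi-definite $K \times K$ matrix that, in neural networks, becomes block-structured corresponding to the weight matrices.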
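For the softmax case, the closed form diag(p) - p pᵀ can be compared against the definition of the Fisher information as the expected outer product of scores. This is a small NumPy sketch with hypothetical helper names, taking the logits as the parameters.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()


def softmax_fisher_from_scores(z: np.ndarray) -> np.ndarray:
    """E_y[score score^T] with score = d log p(y) / dz = onehot(y) - p."""
    p = softmax(z)
    K = len(p)
    fisher = np.zeros((K, K))
    for y in range(K):
        score = np.eye(K)[y] - p
        fisher += p[y] * np.outer(score, score)
    return fisher


z = np.array([2.0, 0.5, -1.0])
p = softmax(z)
closed_form = np.diag(p) - np.outer(p, p)
print(np.allclose(softmax_fisher_from_scores(z), closed_form))  # True
print(np.linalg.eigvalsh(closed_form))  # smallest eigenvalue is ~0
```

The near-zero eigenvalue reflects that adding a constant to every logit leaves the distribution unchanged, so the matrix is only positive semi-definite.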
Computing the Fisher Information Matrix
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


def fisher_information_bernoulli(p: float) -> float:
    """Fisher information for Bernoulli(p): F(p) = 1/(p*(1-p))."""
    assert 0 < p < 1, "p must be in (0, 1)"
    return 1.0 / (p * (1 - p))


def fisher_information_gaussian_analytical(sigma: float = 1.0) -> np.ndarray:
    """
    Fisher information matrix for N(mu, sigma^2).
    Parameters: theta = (mu, sigma).
    Returns the 2x2 FIM, F = diag(1/sigma^2, 2/sigma^2).
    """
    assert sigma > 0, "sigma must be positive"
    # Off-diagonal entries are zero: mu and sigma are orthogonal.
    # For the default sigma=1: F = diag(1, 2)
    return np.array([[1.0 / sigma**2, 0.0],
                     [0.0, 2.0 / sigma**2]])


def empirical_fisher(
    model: nn.Module,
    data_loader,
    n_samples: int = 100,
    device: str = "cpu",
) -> list:
    """
    Monte Carlo estimate of the diagonal Fisher:
        diag(F) ~ (1/N) * sum_i grad_i ** 2
    where grad_i = gradient of log p(y_i | x_i) w.r.t. the parameters and
    y_i is sampled from the model's own predictive distribution.
    (Using the dataset labels instead gives the "empirical Fisher".)
    The full FIM needs O(p^2) memory, so only the diagonal (O(p)) is kept.
    Returns one tensor per parameter - for illustration only.
    """
    model.eval()
    diagonal_fisher = [torch.zeros_like(p) for p in model.parameters()]
    count = 0
    for x, _ in data_loader:  # labels are unused: y is sampled from the model
        x = x.to(device)
        # One backward pass per example: gradients must be squared before
        # averaging, so a batch-mean gradient would give the wrong answer.
        for xi in x:
            if count >= n_samples:
                break
            model.zero_grad()
            log_probs = F.log_softmax(model(xi.unsqueeze(0)), dim=-1)
            # Sample a label from the model distribution
            sampled_y = torch.multinomial(log_probs.detach().exp(), 1).squeeze(-1)
            F.nll_loss(log_probs, sampled_y).backward()
            for i, param in enumerate(model.parameters()):
                if param.grad is not None:
                    diagonal_fisher[i] += param.grad.detach() ** 2
            count += 1
        if count >= n_samples:
            break
    # Average over samples
    for d in diagonal_fisher:
        d.div_(max(count, 1))
    return diagonal_fisher


# Demo: Fisher information for Bernoulli
print("=== Fisher Information for Bernoulli(p) ===")
print(f"{'p':>6} | {'F(p)':>10} | {'Interpretation'}")
print("-" * 50)
for p in [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99]:
    fi = fisher_information_bernoulli(p)
    interp = "near-deterministic" if fi > 50 else ("uncertain" if fi < 5 else "")
    print(f"{p:>6.2f} | {fi:>10.2f} | {interp}")
# Fisher is highest near 0 and 1: the distribution is most "sensitive" there
The Natural Gradient
Standard gradient descent moves in parameter space - it treats all parameter directions equally:
$$\theta_{t+1} = \theta_t - \eta\,\nabla_\theta L(\theta_t)$$
But parameter space has a misleading geometry. Moving the same Euclidean distance in different directions can change the model's behavior by wildly different amounts.
Natural gradient descent moves in distribution space - it respects the geometry induced by the FIM:
$$\theta_{t+1} = \theta_t - \eta\,F(\theta_t)^{-1}\,\nabla_\theta L(\theta_t)$$
The inverse FIM preconditions the gradient: it stretches directions where the loss landscape is flat (needs larger steps) and compresses directions where it is steep (needs smaller steps) - but measured in distribution space, not parameter space.
Why "Natural"?
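The natural gradient is the direction of steepest descent when measured by KL divergence rather than Euclidean distance in parameter space:
$$\delta\theta^{*} = \arg\min_{\delta\theta}\ L(\theta + \delta\theta) \quad \text{subject to} \quad D_{\mathrm{KL}}\big(p(x;\theta)\,\|\,p(x;\theta+\delta\theta)\big) \le \epsilon$$
Linearizing the loss and replacing the KL divergence with its quadratic approximation $\tfrac{1}{2}\,\delta\theta^\top F(\theta)\,\delta\theta$ gives $\delta\theta^{*} \propto -F(\theta)^{-1}\,\nabla_\theta L(\theta)$.
Interpretation: "Move in the direction that most reduces the loss, subject to not changing the model's distribution by more than $\epsilon$ in KL terms."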
This is "natural" because it is invariant to the parameterization of the model: for small steps, the update changes the distribution in the same way regardless of how you parameterize the probability distribution.
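The invariance claim can be illustrated on a single Bernoulli model written in two coordinate systems: the probability p itself and the logit θ = log(p/(1-p)). This is a sketch under the assumption of a cross-entropy loss against a target frequency; all function names are hypothetical.

```python
import math


def sigmoid(t: float) -> float:
    return 1.0 / (1.0 + math.exp(-t))


def logit(p: float) -> float:
    return math.log(p / (1.0 - p))


def step_in_p(p: float, target: float, lr: float, natural: bool) -> float:
    """One update using p itself as the parameter; returns the new p."""
    grad = (p - target) / (p * (1 - p))   # dL/dp for cross-entropy
    fisher = 1.0 / (p * (1 - p))          # F in the p coordinate
    return p - lr * (grad / fisher if natural else grad)


def step_in_logit(p: float, target: float, lr: float, natural: bool) -> float:
    """One update using theta = logit(p) as the parameter; returns the new p."""
    grad = p - target                     # dL/dtheta for cross-entropy
    fisher = p * (1 - p)                  # F in the theta coordinate
    theta = logit(p) - lr * (grad / fisher if natural else grad)
    return sigmoid(theta)


p0, target, lr = 0.9, 0.5, 0.01
for natural in (False, True):
    dp_p = step_in_p(p0, target, lr, natural) - p0
    dp_logit = step_in_logit(p0, target, lr, natural) - p0
    name = "natural" if natural else "standard"
    print(f"{name:>8}: change in p via p-coords={dp_p:+.5f}, "
          f"via logit-coords={dp_logit:+.5f}")
```

The standard gradient produces very different changes in p depending on the coordinates, while the two natural gradient steps agree to first order in the step size.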
Natural Gradient vs. Standard Gradient: Geometric Intuition
Standard gradient (Euclidean):
Parameter space ──────────────────────────────────────
θ₁ * current θ
*
* ← gradient
* points "down the hill"
* in Euclidean parameter space
─────────────────────────────────
θ₂
Problem: equal Euclidean steps in θ-space can cause
very different changes in p(x;θ) distribution.
Near a saturating sigmoid, large δθ → tiny δp.
Near a linear region, small δθ → large δp.
Natural gradient (Fisher-Riemannian):
Distribution space ───────────────────────────────────
The step δθ is rescaled so it always corresponds to
the same KL distance in distribution space.
This makes the update scale-invariant.
Practical Computation: The Problem with FIM
For a neural network with $p$ parameters, the FIM is $p \times p$ - storing it requires $O(p^2)$ memory. For a 100M parameter model, that's $10^{16}$ entries.
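To put that number in perspective, here is the arithmetic as a short sketch. It assumes a dense matrix stored in 32-bit floats, and the helper name `full_fim_bytes` is made up for this example.

```python
def full_fim_bytes(n_params: int, bytes_per_entry: int = 4) -> int:
    """Bytes needed to store a dense n_params x n_params matrix."""
    return n_params * n_params * bytes_per_entry


n_params = 100_000_000
entries = n_params ** 2
print(f"entries: {entries:.1e}")                                     # 1.0e+16
print(f"float32 storage: {full_fim_bytes(n_params) / 1e15:.0f} PB")  # 40 PB
```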
Several approximations make natural gradient tractable:
1. Diagonal Fisher (Adaptive Learning Rates)
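Approximate $F \approx \mathrm{diag}(F_{11}, \dots, F_{pp})$. With $g_i = \partial L / \partial \theta_i$, the natural gradient update becomes:
$$\theta_i \leftarrow \theta_i - \eta\,\frac{g_i}{F_{ii}}$$
This is the idea behind the adaptive learning rates in Adagrad, RMSprop, and (approximately) Adam, which estimate $F_{ii}$ with an accumulator $v$ of squared gradients. In practice they divide by $\sqrt{v}$ rather than $v$, so the match to the natural gradient is approximate:
- Adagrad: $v_t = \sum_{s \le t} g_s^2$ (sum of squared gradients)
- RMSprop: $v_t = \beta\, v_{t-1} + (1-\beta)\, g_t^2$ (exponential moving average)
- Adam: $v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2$ + bias correction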
:::info ML Connection - Adam as Approximate Natural Gradient
Adam's diagonal preconditioning divides each gradient by the square root of the running average of squared gradients. This approximates $F^{-1} g$ (up to the square root) under the assumption that the Fisher information is diagonal.
The connection explains why adaptive learning rates work: they are an approximation to the geometrically correct (natural gradient) update in distribution space, scaled per-parameter by the local Fisher information.
:::
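The three accumulators can be sketched side by side. This is a simplified illustration rather than any optimizer's actual implementation: it omits momentum on the gradient itself, and `second_moment_update` is a hypothetical helper.

```python
import numpy as np


def second_moment_update(v, g, kind, beta=0.999, t=1):
    """Return (new accumulator, value used to scale the step) for gradient g."""
    if kind == "adagrad":
        v = v + g ** 2                      # running sum of squared gradients
        return v, v
    v = beta * v + (1 - beta) * g ** 2      # exponential moving average
    if kind == "adam":
        return v, v / (1 - beta ** t)       # bias-corrected estimate
    return v, v                             # rmsprop


# Toy gradients whose two coordinates differ in scale by 100x.
rng = np.random.default_rng(0)
grads = rng.normal(size=(200, 2)) * np.array([10.0, 0.1])

for kind in ("adagrad", "rmsprop", "adam"):
    v = np.zeros(2)
    for t, g in enumerate(grads, start=1):
        v, v_hat = second_moment_update(v, g, kind, t=t)
    step = grads[-1] / (np.sqrt(v_hat) + 1e-8)   # preconditioned direction
    print(kind, "v_hat =", v_hat, "| step =", step)
```

In each case the coordinate with large gradients gets a large accumulator and therefore a proportionally smaller step, which is the diagonal rescaling described above.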
2. Block-Diagonal Fisher (K-FAC)
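Kronecker-Factored Approximate Curvature (K-FAC) approximates the FIM for each layer independently. For a linear layer $y = Wx + b$ with input $x$ and pre-activation $y$:
$$F_W \approx A \otimes B$$
where:
- $A = \mathbb{E}[x\,x^\top]$ is the input covariance
- $B = \mathbb{E}[\delta_y\,\delta_y^\top]$, with $\delta_y = \nabla_y \log p$, is the pre-activation gradient covariance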
This Kronecker product structure allows efficient computation:
import torch
import torch.nn.functional as F


class KFACLinear(torch.nn.Module):
    """
    Linear layer with K-FAC Fisher information approximation.
    Tracks A = E[x x^T] and B = E[delta_y delta_y^T].
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))
        # K-FAC factors
        self.register_buffer('A', torch.eye(in_features))   # input covariance
        self.register_buffer('B', torch.eye(out_features))  # grad covariance
        self.n_updates = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Store input for K-FAC update
        self._input = x
        return F.linear(x, self.weight, self.bias)

    def update_kfac_factors(self, output_grad: torch.Tensor, momentum: float = 0.9):
        """Update the Kronecker factors from current batch."""
        x = self._input.detach()   # (batch, in)
        g = output_grad.detach()   # (batch, out)
        # Update input covariance: A = mom*A + (1-mom)*E[x x^T]
        A_new = torch.einsum('bi,bj->ij', x, x) / x.size(0)
        self.A = momentum * self.A + (1 - momentum) * A_new
        # Update gradient covariance: B = mom*B + (1-mom)*E[g g^T]
        B_new = torch.einsum('bi,bj->ij', g, g) / g.size(0)
        self.B = momentum * self.B + (1 - momentum) * B_new
        self.n_updates += 1

    def natural_gradient(
        self,
        gradient: torch.Tensor,
        damping: float = 1e-3,
    ) -> torch.Tensor:
        """
        Compute natural gradient: F^{-1} g ≈ (A^{-1} ⊗ B^{-1}) vec(G)
        = B^{-1} G A^{-1} (for matrix G = gradient of W)
        """
        # Add damping for numerical stability: F + λI
        A_damp = self.A + damping * torch.eye(self.A.size(0), device=self.A.device)
        B_damp = self.B + damping * torch.eye(self.B.size(0), device=self.B.device)
        # Solve: nat_grad = B^{-1} * grad * A^{-1}
        A_inv = torch.linalg.inv(A_damp)
        B_inv = torch.linalg.inv(B_damp)
        return B_inv @ gradient @ A_inv


# Demo: Natural gradient vs. standard gradient step
torch.manual_seed(42)
layer = KFACLinear(4, 3)

# Simulate forward/backward pass
x = torch.randn(16, 4)
target = torch.randint(0, 3, (16,))
output = layer(x)
output.retain_grad()  # keep d(loss)/d(output), which the B factor needs
loss = F.cross_entropy(output, target)
loss.backward()

# Update the K-FAC factors from the per-example output gradients, shape (batch, out).
# cross_entropy averages over the batch, so multiply by the batch size to undo
# the 1/batch scaling. (Simplified: this uses the given targets rather than
# labels sampled from the model.)
with torch.no_grad():
    layer.update_kfac_factors(output.grad * x.size(0))

# Compare gradient magnitudes
grad = layer.weight.grad
nat_grad = layer.natural_gradient(grad)
print("=== Natural Gradient vs. Standard Gradient ===")
print(f"Standard gradient norm: {grad.norm().item():.6f}")
print(f"Natural gradient norm: {nat_grad.norm().item():.6f}")
print(f"Ratio: {nat_grad.norm().item() / grad.norm().item():.4f}")
print("\n(Natural gradient rescales updates to be uniform in distribution space)")
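The identity behind the factored solve, (A ⊗ B)⁻¹ vec(G) = vec(B⁻¹ G A⁻¹) for symmetric A, can be verified on small random matrices. The sketch below uses NumPy with random positive-definite stand-ins for the two covariance factors; it is independent of the layer class above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out = 4, 3


def random_spd(n: int) -> np.ndarray:
    """Random symmetric positive-definite matrix (stand-in for a covariance)."""
    m = rng.normal(size=(n, n))
    return m @ m.T + n * np.eye(n)


A = random_spd(n_in)                  # plays the role of E[x x^T]
B = random_spd(n_out)                 # plays the role of E[g g^T]
G = rng.normal(size=(n_out, n_in))    # gradient with the same shape as W

# Dense route: build the (in*out) x (in*out) matrix and solve against vec(G).
# vec() stacks columns, which is order="F" in NumPy.
dense = np.linalg.solve(np.kron(A, B), G.reshape(-1, order="F"))

# Factored route: two small solves, never forming the Kronecker product.
factored = np.linalg.solve(B, G) @ np.linalg.inv(A)

print(np.allclose(dense, factored.reshape(-1, order="F")))  # True
```

The factored route only ever touches an in×in and an out×out matrix, which is where the savings come from as layers grow.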
Fisher Information and the Cramér-Rao Bound
The Fisher information has a fundamental statistical interpretation: it quantifies how much information a sample carries about the parameter.
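Cramér-Rao Bound: For any unbiased estimator $\hat{\theta}$ of $\theta$ computed from $n$ i.i.d. samples:
$$\mathrm{Cov}(\hat{\theta}) \succeq \frac{1}{n}\,F(\theta)^{-1}$$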
The inverse Fisher information is a lower bound on the variance of any unbiased estimator. More Fisher information = smaller estimation uncertainty = better inference.
For neural networks, this has an optimization interpretation: parameters with high Fisher information (high curvature in distribution space) are more sensitive - their estimation is more constrained. Parameters with low Fisher information are relatively free to move.
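A small simulation makes the bound concrete. For a Bernoulli parameter the sample mean is the maximum likelihood estimator, and its variance should sit at the bound p(1-p)/n. The sketch below assumes NumPy; the sample size, trial count, and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
p_true, n, n_trials = 0.3, 50, 20_000

# Each row is one dataset of n coin flips; the MLE of p is the sample mean.
samples = rng.random((n_trials, n)) < p_true
p_hat = samples.mean(axis=1)

fisher_per_sample = 1.0 / (p_true * (1 - p_true))
cramer_rao = 1.0 / (n * fisher_per_sample)     # = p(1-p)/n

print(f"empirical Var(p_hat): {p_hat.var():.6f}")
print(f"Cramér-Rao bound:     {cramer_rao:.6f}")
```

Because the variance is itself estimated from a finite number of trials, the empirical value fluctuates slightly around the bound rather than landing exactly on it.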
Amari's Natural Gradient Convergence Properties
Why is the natural gradient faster than standard gradient descent?
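Key result (Amari, 1998): Under appropriate conditions, online natural gradient learning is asymptotically Fisher-efficient: its estimates attain the Cramér-Rao bound, matching the best batch estimator. Its local convergence does not depend on the conditioning of the problem, while standard gradient descent's convergence depends on the condition number of the FIM.
For a quadratic loss near an optimum:
- Standard GD convergence rate: the error shrinks by a factor of at most $\frac{\kappa - 1}{\kappa + 1}$ per step with the best fixed step size, where $\kappa$ is the condition number
- Natural GD convergence rate: approaches $|1 - \eta|$ per step when the FIM matches the local curvature (independent of condition number)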
Condition number and convergence:
Loss landscape: Condition number:
Long, thin valley κ = λ_max/λ_min >> 1
(sharp valley walls)
→ Standard GD zigzags
→ Natural GD goes straight
Round bowl κ ≈ 1
→ Standard GD works fine
→ Natural GD = Standard GD
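The dependence on the condition number can be made concrete. For a quadratic with curvatures between λ_min and λ_max, gradient descent with the fixed step 2/(λ_min + λ_max) shrinks the error by (κ - 1)/(κ + 1) per step in the worst case. The sketch below checks that factor and converts it into a step count; the function names are illustrative.

```python
import numpy as np


def gd_contraction(eigs: np.ndarray) -> float:
    """Worst-case per-step error factor of GD with step 2 / (l_min + l_max)."""
    lr = 2.0 / (eigs.min() + eigs.max())
    return float(np.max(np.abs(1.0 - lr * eigs)))


def steps_to_tolerance(factor: float, tol: float = 1e-6) -> int:
    """Steps needed for factor**k <= tol."""
    return int(np.ceil(np.log(tol) / np.log(factor)))


for kappa in (1.5, 10.0, 100.0, 1000.0):
    eigs = np.array([1.0, kappa])
    rho = gd_contraction(eigs)
    print(f"kappa={kappa:>6}: factor={rho:.4f} "
          f"(formula {(kappa - 1) / (kappa + 1):.4f}), "
          f"steps to 1e-6: {steps_to_tolerance(rho)}")
```

The step count grows roughly linearly with κ, which is the zigzag behavior in the long, thin valley case.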
Connection to Second-Order Optimization
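The natural gradient is closely related to Newton's method. Newton's update is:
$$\theta_{t+1} = \theta_t - H(\theta_t)^{-1}\,\nabla_\theta L(\theta_t)$$
where $H$ is the Hessian of the loss. The natural gradient uses the FIM instead of the Hessian:
$$\theta_{t+1} = \theta_t - \eta\,F(\theta_t)^{-1}\,\nabla_\theta L(\theta_t)$$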
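When are they equal? The FIM always equals the expected Hessian of the negative log-likelihood when the expectation over labels is taken under the model's own distribution. For generalized linear models with canonical links and log-linear models (for example logistic or softmax regression), the Hessian does not depend on the labels at all, so the FIM equals the Hessian of the cross-entropy loss and natural gradient coincides with Newton's method (for $\eta = 1$).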
For general neural networks, they differ - the FIM captures the curvature of the output distribution, while the Hessian captures the curvature of the loss landscape including label information.
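The equality for logistic regression can be checked directly. The sketch below compares a finite-difference Hessian of the mean negative log-likelihood with the Fisher matrix computed by averaging over labels drawn from the model. The data, weights, and labels are random placeholders, and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
w = np.array([0.5, -1.0, 2.0])
y = (rng.random(200) < 0.5).astype(float)   # arbitrary labels


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def nll_grad(w, X, y):
    """Gradient of the mean negative log-likelihood of logistic regression."""
    return X.T @ (sigmoid(X @ w) - y) / len(y)


def numerical_hessian(w, X, y, eps=1e-5):
    """Central finite differences of the gradient, one column per coordinate."""
    cols = []
    for i in range(len(w)):
        e = np.zeros_like(w)
        e[i] = eps
        cols.append((nll_grad(w + e, X, y) - nll_grad(w - e, X, y)) / (2 * eps))
    return np.stack(cols, axis=1)


def fisher(w, X):
    """E_x E_{y ~ model}[score score^T]; the score for label y is (y - p) x."""
    p = sigmoid(X @ w)
    weights = p * (1 - p) ** 2 + (1 - p) * p ** 2   # = p (1 - p)
    return (X * weights[:, None]).T @ X / len(X)


print(np.allclose(numerical_hessian(w, X, y), fisher(w, X), atol=1e-6))  # True
```

The labels `y` never enter the Hessian here, which is exactly why the two matrices coincide for this model class.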
Optimizer | Curvature Matrix | Cost | Accuracy
-------------------|---------------------|-------------|------------------
SGD | Identity | O(p) | Depends on LR
Momentum | Identity (damped) | O(p) | Better than SGD
Adagrad/Adam | Diagonal of FIM | O(p) | Good approximation
K-FAC | Block Kronecker FIM | O(p^1.5) | Near-exact per layer
Newton's method | Full Hessian | O(p^3) | Exact (local)
Natural gradient | Full FIM | O(p^3) | Exact (distributional)
Practical Impact: When to Use Natural Gradient Methods
def compare_convergence():
    """
    Illustrate why Fisher preconditioning helps on ill-conditioned problems.
    This is a simplified 2D example showing the geometric advantage.
    """
    import numpy as np

    # Ill-conditioned quadratic: L(theta) = 0.5 * theta^T * H * theta
    # H with high condition number
    H = np.array([[100.0, 0.0],
                  [0.0, 1.0]])  # condition number = 100
    # Fisher information (for this simplified problem = H)
    fisher = H
    fisher_inv = np.linalg.inv(fisher)

    theta_gd = np.array([1.0, 1.0])
    theta_ng = np.array([1.0, 1.0])
    # Plain GD is only stable for lr < 2 / lambda_max = 0.02, so the steep
    # direction caps its step size. After preconditioning every direction has
    # unit curvature, so natural GD tolerates a much larger step.
    lr_gd = 0.01
    lr_ng = 0.5
    n_steps = 50
    loss_gd, loss_ng = [], []

    for _ in range(n_steps):
        grad = H @ theta_gd
        theta_gd -= lr_gd * grad
        loss_gd.append(0.5 * theta_gd @ H @ theta_gd)

    # Natural gradient with FIM preconditioning
    for _ in range(n_steps):
        grad = H @ theta_ng
        nat_grad = fisher_inv @ grad  # = H^{-1} * H * theta = theta
        theta_ng -= lr_ng * nat_grad
        loss_ng.append(0.5 * theta_ng @ H @ theta_ng)

    print("\n=== Convergence Comparison (ill-conditioned problem) ===")
    print(f"{'Step':>6} | {'GD Loss':>12} | {'Nat GD Loss':>12}")
    print("-" * 35)
    for step in [0, 5, 10, 20, 49]:
        print(f"{step:>6} | {loss_gd[step]:>12.6f} | {loss_ng[step]:>12.6f}")


compare_convergence()
Interview Questions and Answers
Q1: What is the Fisher information matrix and why does it matter for optimization?
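The Fisher information matrix measures the average amount of information a sample carries about the parameters $\theta$. Geometrically, it defines a local Riemannian metric on the manifold of probability distributions: the "distance" between $p(x;\theta)$ and $p(x;\theta+\delta\theta)$ in distribution space is approximately $\sqrt{\delta\theta^\top F(\theta)\,\delta\theta}$, and the KL divergence between them is about half the square of that distance.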
For optimization, it matters because the standard Euclidean metric on parameter space is misleading. The same Euclidean step in different directions causes vastly different changes to the model's outputs. The FIM rescales directions according to their actual impact on the distribution - directions with high FIM are "expensive" (small changes have big effects), directions with low FIM are "cheap."
Q2: What is the natural gradient and how does it differ from the standard gradient?
The natural gradient is the direction of steepest loss descent when distance is measured in distribution space (KL divergence) rather than Euclidean parameter space.
Standard gradient: "Move in the direction that most reduces loss per unit of Euclidean distance in parameter space."
Natural gradient: "Move in the direction that most reduces loss per unit of KL divergence from the current distribution."
The natural gradient is invariant to reparameterization - if you rescale or rotate parameters, the natural gradient update produces the same behavioral change to the model. The standard gradient does not have this property.
Q3: How does Adam relate to the natural gradient?
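Adam maintains a running estimate $v_t$ of $\mathbb{E}[g^2]$ (the second moment of gradients) and divides each gradient component by $\sqrt{\hat{v}_t} + \epsilon$. This is a diagonal approximation to the natural gradient update:
$$\theta_{t+1} = \theta_t - \eta\,\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
Here $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected running averages of the gradient and the squared gradient. Under the assumption that the Fisher information matrix is diagonal (off-diagonal correlations between parameters are ignored), $F_{ii} \approx \mathbb{E}[g_i^2]$, and Adam's update is approximately $\mathrm{diag}(F)^{-1/2}\,g$. The square root means Adam applies a softened version of the full $F^{-1}$ preconditioning.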
This explains why Adam converges faster than SGD on ill-conditioned problems: it is an approximate second-order method that accounts for different curvatures in different parameter directions.
Q4: What is K-FAC and why is it more accurate than Adam?
K-FAC (Kronecker-Factored Approximate Curvature) approximates the Fisher information matrix for each weight matrix as a Kronecker product: $F_W \approx A \otimes B$, where $A = \mathbb{E}[x\,x^\top]$ is the input covariance and $B = \mathbb{E}[\delta_y\,\delta_y^\top]$ is the pre-activation gradient covariance.
This is more accurate than Adam's diagonal approximation because it captures correlations between rows and columns of the weight matrix (within a layer), not just per-parameter curvature. It is more expensive ($O(d_{\text{in}}^3 + d_{\text{out}}^3)$ per layer to invert the two factors), but the Kronecker structure makes it much cheaper than the full FIM ($O(p^3)$). K-FAC is used in large-scale training (language models, RL) where the better curvature estimation significantly improves convergence.
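The storage side of that trade-off is easy to quantify. The sketch below counts matrix entries for a single linear layer under the three approximations discussed; the 1024×1024 layer size is an arbitrary example and `kfac_storage` is a hypothetical helper.

```python
def kfac_storage(d_in: int, d_out: int) -> dict:
    """Entry counts for one linear layer's curvature under three approximations."""
    n_weights = d_in * d_out
    return {
        "diagonal": n_weights,                 # one entry per weight
        "kfac_factors": d_in**2 + d_out**2,    # A is d_in x d_in, B is d_out x d_out
        "full_block": n_weights**2,            # dense Fisher block for the layer
    }


counts = kfac_storage(1024, 1024)
for name, n in counts.items():
    print(f"{name:>13}: {n:.2e} entries")
```

For this layer the Kronecker factors need only about twice the storage of the diagonal, while the dense block is several hundred thousand times larger.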
Q5: Explain the Cramér-Rao bound and what it says about estimation limits.
The Cramér-Rao bound states that for any unbiased estimator $\hat{\theta}$ of a parameter $\theta$ based on $n$ i.i.d. samples: $\mathrm{Var}(\hat{\theta}) \ge \frac{1}{n\,F(\theta)}$, where $F(\theta)$ is the per-sample Fisher information. No unbiased estimator can achieve lower variance than the inverse Fisher information.
The bound is tight: the maximum likelihood estimator (MLE) achieves it asymptotically - $\sqrt{n}\,(\hat{\theta}_{\mathrm{MLE}} - \theta) \to \mathcal{N}\big(0, F(\theta)^{-1}\big)$ in distribution as $n \to \infty$.
For neural networks, this has a regularization interpretation: parameters with high Fisher information (sensitive parameters) have a natural precision limit. When training on $n$ samples, the effective estimation variance for parameter $\theta_i$ is approximately $1/(n F_{ii})$. This motivates Fisher-weighted regularization - penalizing changes to high-Fisher parameters more strongly, which is the foundation of Elastic Weight Consolidation (EWC) for continual learning.
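A Fisher-weighted penalty of the kind used in EWC can be sketched in a few lines. This assumes a diagonal Fisher estimate and the usual quadratic form 0.5 · λ · Σᵢ Fᵢᵢ (θᵢ - θ*ᵢ)²; the numbers and the name `ewc_penalty` are made up for illustration.

```python
import numpy as np


def ewc_penalty(theta, theta_star, fisher_diag, lam=1.0):
    """0.5 * lam * sum_i F_ii * (theta_i - theta_star_i)^2."""
    return 0.5 * lam * float(np.sum(fisher_diag * (theta - theta_star) ** 2))


theta_star = np.array([1.0, -2.0])     # weights after the first task
fisher_diag = np.array([50.0, 0.1])    # first weight is sensitive, second is not
move_sensitive = theta_star + np.array([0.1, 0.0])
move_free = theta_star + np.array([0.0, 0.1])

print("moving the high-Fisher weight:", ewc_penalty(move_sensitive, theta_star, fisher_diag))
print("moving the low-Fisher weight: ", ewc_penalty(move_free, theta_star, fisher_diag))
```

The same size of move costs 500 times more along the high-Fisher coordinate, so later training is steered toward changing the parameters the earlier task cared about least.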
