Iterative Solvers - Conjugate Gradient, Krylov Methods, and Large-Scale ML
Reading time: ~22 minutes | Level: Numerical Methods → Large-Scale ML
You have a neural network with 100 million parameters. You want to use a second-order optimizer like K-FAC or natural gradient descent, which requires solving a linear system involving the Fisher information matrix. That matrix is $10^8 \times 10^8$.
You cannot form it explicitly - it would require about $4 \times 10^{16}$ bytes (40 PB) of storage in float32. You cannot use LU factorization - it would take on the order of $10^{24}$ operations.
But you can compute matrix-vector products efficiently (for roughly the cost of two extra passes through the network). And that is all iterative solvers need.
What You Will Learn
- Why iterative methods succeed where direct methods fail for large systems
- The conjugate gradient (CG) method: derivation and implementation
- GMRES and other Krylov subspace methods
- Preconditioning: the key to fast convergence
- Applications to second-order optimization in ML
- SciPy's sparse linear algebra interface
Part 1 - When Direct Methods Fail
The curse of scale
Direct methods (LU, QR, Cholesky) require:
- Storage: $O(n^2)$ to store the matrix
- Time: $O(n^3)$ to factorize
| $n$ | Storage (float32) | FLOPs ($\sim n^3$) |
|---|---|---|
| $10^3$ | 4 MB | $10^9$ - feasible |
| $10^6$ | 4 TB | $10^{18}$ - years of computation |
| $10^8$ | 40 PB | $10^{24}$ - impossible |
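The storage and FLOP figures follow from $n^2$ entries at 4 bytes each and roughly $n^3$ operations. A minimal sketch of that arithmetic (the helper name `dense_cost` is hypothetical, and $n^3$ is an order-of-magnitude stand-in for the exact factorization cost):

```python
def dense_cost(n: int, bytes_per_entry: int = 4) -> tuple:
    """Return (storage in bytes, rough factorization FLOPs) for a dense n x n matrix."""
    storage_bytes = bytes_per_entry * n * n   # n^2 entries
    flops = n ** 3                            # order of magnitude for LU / Cholesky
    return storage_bytes, flops

for n in (10**3, 10**6, 10**8):
    storage, flops = dense_cost(n)
    print(f"n = {n:.0e}: {storage:.1e} bytes, ~{flops:.0e} FLOPs")
```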
For ML-scale problems, direct methods are fundamentally off the table. We need algorithms that:
- Never form the full matrix explicitly
- Converge to acceptable accuracy in far fewer than $n$ iterations
- Only require matrix-vector products
The iterative idea
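Start with initial guess $x_0$. Generate a sequence:
$$x_0 \to x_1 \to x_2 \to \cdots$$
where each $x_k$ better approximates the solution $x^* = A^{-1} b$. Stop when the residual $r_k = b - A x_k$ is small enough.
The simplest iterative method is gradient descent on the quadratic:
$$f(x) = \tfrac{1}{2} x^T A x - b^T x, \qquad \nabla f(x) = A x - b$$
For SPD $A$, the minimizer of $f$ is exactly the solution of $Ax = b$, and the negative gradient is the residual.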
This connects linear system solving directly to optimization - a natural bridge for ML engineers.
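As a rough illustration of that bridge, here is a minimal sketch of steepest descent with exact line search used as a linear solver. The function name `gradient_descent_solve` and the small random test matrix are assumptions made for the example, not part of the lesson's own code.

```python
import numpy as np

def gradient_descent_solve(A, b, tol=1e-8, maxiter=10_000):
    """Steepest descent with exact line search on f(x) = 0.5 x^T A x - b^T x (A SPD)."""
    x = np.zeros_like(b, dtype=float)
    r = b - A @ x                      # residual = negative gradient of f
    for k in range(maxiter):
        if np.linalg.norm(r) <= tol * np.linalg.norm(b):
            return x, k
        Ar = A @ r
        alpha = (r @ r) / (r @ Ar)     # exact minimizer of f along direction r
        x = x + alpha * r
        r = r - alpha * Ar
    return x, maxiter

rng = np.random.default_rng(0)
Q = rng.standard_normal((20, 20))
A = Q.T @ Q + 20 * np.eye(20)          # small, well-conditioned SPD test matrix
b = rng.standard_normal(20)
x_gd, iters = gradient_descent_solve(A, b)
print(iters, np.linalg.norm(A @ x_gd - b))
```

On a well-conditioned matrix like this one the loop stops quickly; the iteration count grows roughly in proportion to the condition number, which is what motivates CG.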
Part 2 - Conjugate Gradient: The Gold Standard for SPD Systems
Setup and motivation
For solving $Ax = b$ where $A$ is symmetric positive definite (SPD), conjugate gradient (CG) is the method of choice.
Gradient descent on $f(x) = \tfrac{1}{2} x^T A x - b^T x$ converges slowly because successive directions are orthogonal but not $A$-orthogonal (conjugate). CG fixes this.
The conjugate gradient algorithm
Given: A (SPD), b, initial guess x_0
r_0 = b - A x_0 (initial residual)
p_0 = r_0 (initial search direction)
k = 0
LOOP until ‖r_k‖ < tolerance:
    α_k = (r_k · r_k) / (p_k · A p_k)            (optimal step size)
    x_{k+1} = x_k + α_k p_k                      (update solution)
    r_{k+1} = r_k - α_k A p_k                    (update residual)
    β_k = (r_{k+1} · r_{k+1}) / (r_k · r_k)      (Fletcher-Reeves)
    p_{k+1} = r_{k+1} + β_k p_k                  (conjugate search direction)
    k = k + 1
Key property: The search directions are mutually $A$-conjugate: $p_i^T A p_j = 0$ for $i \neq j$. This means CG spans an $n$-dimensional search space in at most $n$ steps - it is a direct method in exact arithmetic.
In practice (finite precision + tolerance criterion), CG usually reaches the tolerance in far fewer than $n$ iterations.
import numpy as np

def conjugate_gradient(
    A_matvec,              # Function: v -> A @ v (no need to form A explicitly)
    b: np.ndarray,
    x0: np.ndarray = None,
    tol: float = 1e-6,
    maxiter: int = None
) -> tuple:
    """
    Conjugate Gradient for solving Ax = b where A is symmetric positive definite.
    A_matvec: callable - computes A @ v without forming A explicitly.
    This is the key: we only need matrix-vector products.
    Returns (x, residuals), where residuals[k] is ‖r_k‖ after k iterations.
    """
    n = len(b)
    if x0 is None:
        x0 = np.zeros(n)
    if maxiter is None:
        maxiter = 2 * n
    x = x0.copy()
    r = b - A_matvec(x)       # Initial residual
    p = r.copy()              # Initial search direction
    r_dot = r @ r             # ‖r‖²
    residuals = [np.sqrt(r_dot)]
    threshold = tol * np.linalg.norm(b)   # Relative stopping criterion
    for k in range(maxiter):
        if np.sqrt(r_dot) <= threshold:
            break
        Ap = A_matvec(p)
        alpha = r_dot / (p @ Ap)    # Optimal step size
        x = x + alpha * p           # Update solution
        r = r - alpha * Ap          # Update residual
        r_dot_new = r @ r
        beta = r_dot_new / r_dot    # Fletcher-Reeves coefficient
        p = r + beta * p            # New conjugate direction
        r_dot = r_dot_new
        residuals.append(np.sqrt(r_dot))
    return x, residuals
# --- Example: solve a large SPD system ---
import numpy as np
n = 1000
# SPD matrix A = X^T X + n I (in practice DON'T store it - use the matvec)
np.random.seed(42)
X = np.random.randn(n, n)
# In practice, A_matvec might call a neural network or physics simulator
A = X.T @ X + n * np.eye(n)  # Only for validation; normally you'd never form this

def A_matvec(v: np.ndarray) -> np.ndarray:
    """Matrix-vector product without forming A explicitly."""
    return X.T @ (X @ v) + n * v  # = (X^T X + nI) v

b = np.random.randn(n)
x_cg, residuals = conjugate_gradient(A_matvec, b)
# Verify
residual = np.linalg.norm(b - A @ x_cg)
print(f"Final residual: {residual:.2e}")
print(f"Iterations: {len(residuals) - 1}")  # residuals[0] is the initial residual
# Compare with direct solve
x_direct = np.linalg.solve(A, b)
print(f"CG vs direct difference: {np.linalg.norm(x_cg - x_direct):.2e}")
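The conjugacy property stated earlier can be checked numerically. The following is a small sketch, not part of the lesson's solver: `cg_directions` is a hypothetical helper that records the search directions, and the 8 x 8 test matrix is arbitrary.

```python
import numpy as np

def cg_directions(A, b, steps):
    """Run `steps` CG iterations from x0 = 0 and return the search directions as columns."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    directions = []
    for _ in range(steps):
        directions.append(p.copy())
        Ap = A @ p
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r)
        p = r_new + beta * p
        r = r_new
    return np.column_stack(directions)

rng = np.random.default_rng(1)
Q = rng.standard_normal((8, 8))
A = Q.T @ Q + 8 * np.eye(8)             # small SPD test matrix
b = rng.standard_normal(8)
P = cg_directions(A, b, steps=5)
G = P.T @ A @ P                         # Gram matrix in the A-inner product
off_diag = G - np.diag(np.diag(G))
print(np.abs(off_diag).max())           # near zero: the directions are A-conjugate
```

The diagonal of `G` holds the values $p_i^T A p_i > 0$, while the off-diagonal entries vanish up to rounding error.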
Convergence rate
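CG converges at rate:
$$\|x_k - x^*\|_A \le 2 \left( \frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1} \right)^k \|x_0 - x^*\|_A$$
where $\kappa = \lambda_{\max}(A) / \lambda_{\min}(A)$ is the condition number. High $\kappa$ → slow convergence. This is the motivation for preconditioning.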
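To get a feel for how the condition number drives the iteration count, the standard CG error bound can be inverted for the number of steps. This is a sketch of a worst-case estimate only (the helper name `cg_iteration_bound` is hypothetical); real convergence is often faster when eigenvalues are clustered.

```python
import math

def cg_iteration_bound(kappa: float, eps: float) -> int:
    """Smallest k with 2 * ((sqrt(kappa) - 1) / (sqrt(kappa) + 1))**k <= eps."""
    if kappa <= 1.0:
        return 1   # A is a multiple of the identity: one step suffices
    rho = (math.sqrt(kappa) - 1.0) / (math.sqrt(kappa) + 1.0)
    return math.ceil(math.log(eps / 2.0) / math.log(rho))

for kappa in (1e2, 1e4, 1e6):
    print(f"kappa = {kappa:.0e}: at most {cg_iteration_bound(kappa, 1e-6)} iterations")
```

The bound grows like $\sqrt{\kappa}$, so reducing $\kappa$ by a factor of 100 cuts the worst-case iteration count by about a factor of 10.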
Part 3 - Preconditioning: Transforming the Problem
Why preconditioning matters
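CG on an ill-conditioned problem (very large $\kappa(A)$) may require a prohibitive number of iterations. Preconditioning transforms the problem to have a much smaller condition number.
Left preconditioning: Solve $M^{-1} A x = M^{-1} b$ instead of $Ax = b$.
If $M \approx A$, then $M^{-1} A \approx I$ and $\kappa(M^{-1} A) \approx 1$ - near-instant convergence.
The preconditioner $M$ must satisfy:
- $M \approx A$ (so $\kappa(M^{-1} A) \ll \kappa(A)$)
- $M^{-1} v$ is cheap to compute (otherwise the speedup is lost)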
import numpy as np

def preconditioned_cg(
    A_matvec,
    b: np.ndarray,
    M_inv_matvec,          # Preconditioner: computes M^{-1} v
    tol: float = 1e-8
) -> np.ndarray:
    """
    Preconditioned Conjugate Gradient.
    M_inv_matvec should be cheap and approximate A^{-1} v.
    Stops when ‖r‖ < tol * ‖b‖ (relative residual).
    """
    n = len(b)
    x = np.zeros(n)
    r = b - A_matvec(x)
    z = M_inv_matvec(r)    # Apply preconditioner
    p = z.copy()
    rz = r @ z
    b_norm = np.linalg.norm(b)
    for _ in range(n):
        Ap = A_matvec(p)
        alpha = rz / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        if np.linalg.norm(r) < tol * b_norm:
            break
        z = M_inv_matvec(r)
        rz_new = r @ z
        beta = rz_new / rz
        p = z + beta * p
        rz = rz_new
    return x
# --- Common preconditioners ---
# 1. Diagonal (Jacobi) preconditioner - cheapest, moderate effectiveness
def jacobi_preconditioner(A: np.ndarray):
    """M = diag(A), M^{-1}v = v / diag(A)"""
    diag = np.diag(A)
    return lambda v: v / diag

# 2. Incomplete factorization - usually more effective on sparse systems
def incomplete_lu_preconditioner(A_sparse):
    """Incomplete LU via scipy.sparse.linalg.spilu, for large sparse systems."""
    from scipy.sparse.linalg import spilu
    # SciPy provides incomplete LU (spilu) but no incomplete Cholesky. The ILU
    # preconditioner is generally not symmetric, so pair it with GMRES or BiCGSTAB;
    # for CG on SPD systems, use an incomplete Cholesky from another library.
    ilu = spilu(A_sparse.tocsc())  # spilu expects CSC format
    return lambda v: ilu.solve(v)
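To see why even the Jacobi preconditioner can help, the sketch below builds a deliberately badly scaled SPD matrix and compares condition numbers before and after diagonal scaling. The construction (`S`, `scales`) is an assumption chosen to make the effect visible; on matrices whose ill-conditioning is not caused by scaling, Jacobi helps much less.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
B = rng.standard_normal((n, n))
S = B.T @ B / n + np.eye(n)                 # well-conditioned SPD core
scales = np.logspace(0, 3, n)               # scales spread over 3 orders of magnitude
A = scales[:, None] * S * scales[None, :]   # A = D S D: still SPD, but badly scaled

d = np.diag(A)
# Symmetric Jacobi scaling D^{-1/2} A D^{-1/2}; it has the same eigenvalues
# as M^{-1} A with M = diag(A), so its condition number is what PCG "sees".
A_jacobi = A / np.sqrt(np.outer(d, d))

kappa_before = np.linalg.cond(A)
kappa_after = np.linalg.cond(A_jacobi)
print(f"kappa(A) = {kappa_before:.2e}, after Jacobi = {kappa_after:.2e}")
```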
Preconditioners in ML
| ML Context | Preconditioner | Rationale |
|---|---|---|
| Second-order optimization | Fisher information diagonal | Cheap per-parameter curvature estimate |
| K-FAC optimizer | Kronecker-factored curvature | Approximate Fisher per layer |
| Federated learning | Local Hessian approximation | Each client provides local curvature |
| Graph neural networks | Degree normalization | Spectral normalization of adjacency |
Part 4 - GMRES for Non-Symmetric Systems
Conjugate Gradient requires $A$ to be symmetric positive definite. For non-symmetric systems (common in physics simulations, some ML applications), use GMRES (Generalized Minimal Residual).
The GMRES idea
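GMRES finds the best solution in the Krylov subspace:
$$\mathcal{K}_k(A, r_0) = \text{span}\{r_0, A r_0, A^2 r_0, \ldots, A^{k-1} r_0\}$$
by minimizing the residual norm $\|b - A x\|_2$ over $x \in x_0 + \mathcal{K}_k(A, r_0)$. Each iteration adds one vector to the subspace via the Arnoldi process.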
from scipy.sparse.linalg import gmres, LinearOperator
import numpy as np
import scipy.sparse as sp
# Example: solve a large sparse non-symmetric system
n = 10000
# Create sparse non-symmetric matrix
offsets = [-1, 0, 1, 10]
data = [
-1.0 * np.ones(n-1),
4.0 * np.ones(n),
-1.0 * np.ones(n-1),
-0.5 * np.ones(n-10)
]
A_sparse = sp.diags(data, offsets, shape=(n, n), format='csr')
b = np.ones(n)
# Solve with GMRES (good for non-symmetric)
# Note: SciPy >= 1.12 names the relative tolerance `rtol`; older releases use `tol`
x_gmres, info = gmres(A_sparse, b, rtol=1e-8, restart=50)
print(f"GMRES converged: {info == 0}")
print(f"Residual: {np.linalg.norm(b - A_sparse @ x_gmres):.2e}")
# With preconditioning - often speeds up convergence dramatically
from scipy.sparse.linalg import spilu
ilu = spilu(A_sparse.tocsc())  # spilu expects CSC format
M_precond = LinearOperator((n, n), matvec=ilu.solve)
x_precond, info = gmres(A_sparse, b, M=M_precond, rtol=1e-8)
print(f"Preconditioned GMRES residual: {np.linalg.norm(b - A_sparse @ x_precond):.2e}")
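To make the Arnoldi step less of a black box, here is a minimal NumPy sketch of the process followed by a single unrestarted GMRES solve in the resulting basis. It assumes $x_0 = 0$ and a small dense test matrix; the function name `arnoldi` is hypothetical, and a production GMRES would update a QR factorization incrementally rather than call a least-squares solver.

```python
import numpy as np

def arnoldi(A_matvec, r0, k):
    """Build an orthonormal basis Q of K_k(A, r0) and the (k+1) x k Hessenberg
    matrix H with A Q[:, :k] = Q[:, :k+1] H (modified Gram-Schmidt variant)."""
    n = r0.shape[0]
    Q = np.zeros((n, k + 1))
    H = np.zeros((k + 1, k))
    Q[:, 0] = r0 / np.linalg.norm(r0)
    for j in range(k):
        w = A_matvec(Q[:, j])               # one matrix-vector product per step
        for i in range(j + 1):              # orthogonalize against earlier vectors
            H[i, j] = Q[:, i] @ w
            w = w - H[i, j] * Q[:, i]
        H[j + 1, j] = np.linalg.norm(w)
        if H[j + 1, j] < 1e-12:             # breakdown: the subspace is invariant
            return Q[:, : j + 1], H[: j + 1, : j + 1]
        Q[:, j + 1] = w / H[j + 1, j]
    return Q, H

rng = np.random.default_rng(0)
A = rng.standard_normal((30, 30)) + 6 * np.eye(30)   # non-symmetric test matrix
b = rng.standard_normal(30)
Q, H = arnoldi(lambda v: A @ v, b, k=10)

# GMRES step with x0 = 0: minimize ||beta * e1 - H y||_2, then x = Q_k y
e1 = np.zeros(H.shape[0])
e1[0] = np.linalg.norm(b)
y, *_ = np.linalg.lstsq(H, e1, rcond=None)
x = Q[:, : H.shape[1]] @ y
print(np.linalg.norm(b - A @ x) / np.linalg.norm(b))
```

Storing `Q` is what makes GMRES memory grow with the iteration count, which is why restarts are used in practice.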
Krylov subspace methods: the family
Krylov Subspace Methods
├── Symmetric Positive Definite (A = A^T, A ≻ 0)
│ └── Conjugate Gradient (CG) - optimal
│
├── Symmetric Indefinite (A = A^T, not necessarily PD)
│ └── MINRES - minimizes residual for symmetric systems
│
├── Non-Symmetric
│ ├── GMRES - minimizes residual, requires O(kn) memory for k stored basis vectors
│ ├── BiCGSTAB - short recurrences, less memory than GMRES
│ └── LGMRES - GMRES with augmented subspace
│
└── Normal Equations (always applicable, often slow)
└── CGLS / LSQR - for least-squares problems
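The decision tree can be written down as a toy dispatcher. This is purely illustrative: `suggest_krylov_solver` is a hypothetical helper, and it checks symmetry and definiteness with a dense eigenvalue computation, which is only viable for small matrices. At scale these properties are normally known from how the problem was constructed.

```python
import numpy as np

def suggest_krylov_solver(A: np.ndarray, sym_tol: float = 1e-10) -> str:
    """Suggest a solver name for a small dense matrix (illustrative only)."""
    if A.shape[0] != A.shape[1]:
        return "LSQR"                        # rectangular: least-squares problem
    if not np.allclose(A, A.T, atol=sym_tol):
        return "GMRES"                       # non-symmetric (BiCGSTAB if memory is tight)
    eigvals = np.linalg.eigvalsh(A)          # dense eigen-solve: small n only
    return "CG" if eigvals.min() > 0 else "MINRES"

print(suggest_krylov_solver(np.array([[2.0, 1.0], [1.0, 3.0]])))    # SPD
print(suggest_krylov_solver(np.array([[1.0, 2.0], [2.0, -1.0]])))   # symmetric indefinite
print(suggest_krylov_solver(np.array([[1.0, 2.0], [0.0, 1.0]])))    # non-symmetric
```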
Part 5 - Applications to ML
Second-order optimization: K-FAC
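K-FAC (Kronecker-Factored Approximate Curvature) approximates the Fisher information matrix with per-layer Kronecker products, which can be inverted factor by factor. Matrix-free natural gradient and Hessian-free methods instead solve the damped system $(F + \lambda I) v = g$ with CG at each step, using only Fisher-vector products; the sketch below shows that CG route.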
import numpy as np

class KFACLinearSolveExample:
    """
    Conceptual demonstration of a CG-based natural gradient solve.
    Reuses conjugate_gradient() defined in Part 2.
    For a K-FAC implementation, see: https://github.com/tensorflow/kfac
    """
    def natural_gradient_update(
        self,
        gradients: np.ndarray,
        fisher_matvec,          # F @ v without forming F explicitly
        damping: float = 1e-3,
        cg_tol: float = 1e-4
    ) -> np.ndarray:
        """
        Compute (F + λI)^{-1} g using CG.
        g: gradient vector (n_params,)
        F: Fisher information matrix (n_params x n_params)
        damping: regularization λ to make F + λI well-conditioned
        The natural gradient update δθ = F^{-1} g
        has better scaling properties than vanilla SGD.
        """
        # (F + λI) v = g → solve for v = (F + λI)^{-1} g
        def damped_fisher_matvec(v):
            return fisher_matvec(v) + damping * v

        natural_grad, _ = conjugate_gradient(
            damped_fisher_matvec, gradients, tol=cg_tol
        )
        return natural_grad
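One concrete way to obtain a `fisher_matvec` is the empirical Fisher built from per-example gradients. The sketch below is an assumption-laden illustration: the gradient matrix `G` is random stand-in data rather than gradients of a real model, and `make_empirical_fisher_matvec` and `cg_solve` are hypothetical helpers (CG is repeated only to keep the block self-contained).

```python
import numpy as np

def make_empirical_fisher_matvec(per_example_grads: np.ndarray):
    """Return v -> F v for the empirical Fisher F = (1/N) G^T G, without forming F.
    G has shape (N, n_params): one per-example gradient per row."""
    G = per_example_grads
    N = G.shape[0]
    return lambda v: G.T @ (G @ v) / N       # two thin products: O(N * n_params)

def cg_solve(matvec, b, tol=1e-10, maxiter=500):
    """Plain CG with a relative residual stopping rule."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol * np.linalg.norm(b):
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

rng = np.random.default_rng(0)
G = rng.standard_normal((32, 200))           # 32 examples, 200 parameters: F is rank-deficient
g = G.mean(axis=0)                           # average gradient
damping = 1e-2                               # makes F + damping * I positive definite
fisher_matvec = make_empirical_fisher_matvec(G)
step = cg_solve(lambda v: fisher_matvec(v) + damping * v, g)
print(np.linalg.norm(fisher_matvec(step) + damping * step - g))
```

Because `F` here has rank at most 32, the damping term is what keeps the system solvable; without it CG would be applied to a singular matrix.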
Conjugate gradient in GPU-accelerated ML
import torch

def torch_cg(
    A_matvec,
    b: torch.Tensor,
    x0: torch.Tensor = None,
    tol: float = 1e-6,
    maxiter: int = 100
) -> torch.Tensor:
    """
    Conjugate Gradient implemented in PyTorch.
    Works on GPU - A_matvec can call neural network layers.
    """
    if x0 is None:
        x = torch.zeros_like(b)
    else:
        x = x0.clone()
    r = b - A_matvec(x)
    p = r.clone()
    r_dot = torch.dot(r, r)
    for _ in range(maxiter):
        if r_dot.sqrt() < tol:
            break
        Ap = A_matvec(p)
        alpha = r_dot / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        r_dot_new = torch.dot(r, r)
        beta = r_dot_new / r_dot
        p = r + beta * p
        r_dot = r_dot_new
    return x
def implicit_hessian_matvec(loss_fn, params, v):
    """
    Compute Hessian-vector product H @ v using two backward passes.
    This is the Pearlmutter trick - O(n) cost, never forms the Hessian.
    params: list of parameter tensors; v: flat vector with one entry per parameter.
    """
    # First backward pass: gradient with the graph retained for differentiation
    grad = torch.autograd.grad(loss_fn(), params, create_graph=True)
    grad_vec = torch.cat([g.flatten() for g in grad])
    # Second backward pass: derivative of grad · v w.r.t. params
    grad_v = torch.dot(grad_vec, v)
    hvp = torch.autograd.grad(grad_v, params)
    return torch.cat([h.flatten() for h in hvp])
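When autodiff is not available, or as a sanity check on an autodiff HVP, a finite-difference approximation gives the same product from two gradient evaluations. Note that this is a different technique from the exact double-backward approach above: it is approximate and sensitive to the step size `eps`. The test function and the name `fd_hessian_vector_product` are assumptions for the sketch.

```python
import numpy as np

def fd_hessian_vector_product(grad_fn, x, v, eps=1e-5):
    """Approximate H(x) @ v by a central difference of the gradient:
    H v ~ (grad(x + eps*v) - grad(x - eps*v)) / (2*eps). No Hessian is formed."""
    return (grad_fn(x + eps * v) - grad_fn(x - eps * v)) / (2.0 * eps)

# Test function f(x) = sum(x^4)/4 + 0.5 x^T A x, so grad = x^3 + A x
# and the exact Hessian is diag(3 x^2) + A.
rng = np.random.default_rng(0)
B = rng.standard_normal((6, 6))
A = B.T @ B
grad_fn = lambda x: x**3 + A @ x

x = rng.standard_normal(6)
v = rng.standard_normal(6)
hv_fd = fd_hessian_vector_product(grad_fn, x, v)
hv_exact = (np.diag(3 * x**2) + A) @ v
print(np.max(np.abs(hv_fd - hv_exact)))
```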
Solving normal equations for large-scale linear regression
from scipy.sparse.linalg import cg as scipy_cg, LinearOperator
import numpy as np

def large_scale_ridge_regression(
    X: np.ndarray,
    y: np.ndarray,
    lambda_reg: float = 1e-3
) -> np.ndarray:
    """
    Solve Ridge regression for large X using CG on the normal equations.
    Avoids forming X^T X explicitly.
    Normal equations: (X^T X + λI) β = X^T y
    """
    m, n = X.shape

    # Matrix-vector product without forming X^T X
    def matvec(v: np.ndarray) -> np.ndarray:
        return X.T @ (X @ v) + lambda_reg * v  # O(m*n) per call

    A_op = LinearOperator((n, n), matvec=matvec, dtype=X.dtype)
    rhs = X.T @ y  # X^T y - only computed once
    # Solve with CG - converges in O(sqrt(κ)) iterations
    # Note: SciPy >= 1.12 names the relative tolerance `rtol`; older releases use `tol`
    beta, info = scipy_cg(A_op, rhs, rtol=1e-6, maxiter=500)
    if info == 0:
        print("CG converged successfully")
    else:
        print(f"CG did not converge (info={info})")
    return beta
Part 6 - When Iterative Beats Direct
Choose Iterative Solver When:
├── n > 10^4 (direct methods become slow/infeasible)
├── Matrix is never explicitly formed (second-order ML optimization)
├── Matrix is sparse (>95% zeros) - direct LU fills in zeros
├── You need only moderate accuracy (10^-4 to 10^-6 residual)
├── A good preconditioner is available
└── You can compute A @ v efficiently (e.g., via autodiff)
Choose Direct Solver When:
├── n < 10^4 - direct is fast and exact
├── Many right-hand sides with same A (LU once, solve many times)
├── Very high accuracy required (< 10^-12 residual)
└── Matrix fits in memory and is not sparse
Interview Questions
Q1: What makes conjugate gradient more efficient than gradient descent for quadratic optimization?
Both CG and gradient descent minimize $f(x) = \tfrac{1}{2} x^T A x - b^T x$ (equivalent to solving $Ax = b$ for SPD $A$).
Gradient descent: Each step uses the negative gradient (= residual $r_k = b - A x_k$) as the search direction. With exact line search, successive residuals are orthogonal: $r_{k+1}^T r_k = 0$. This causes zigzagging - you repeatedly search in directions you've already explored, requiring many iterations to traverse a "canyon" in the objective.
Conjugate gradient: Search directions are chosen to be $A$-conjugate: $p_i^T A p_j = 0$ for $i \neq j$. This means minimizing along $p_j$ never "undoes" progress along $p_i$. Each new direction is $A$-orthogonal to all previous ones.
Key theoretical result: CG on a problem with $A$ having only $k$ distinct eigenvalues converges in at most $k$ steps (in exact arithmetic), regardless of $n$. For a matrix with clustered eigenvalues, CG is extremely efficient.
Practical convergence: CG reduces the error bound by a factor $\frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1}$ per iteration, versus $\frac{\kappa - 1}{\kappa + 1}$ for gradient descent. For $\kappa = 10^4$: CG needs about $\sqrt{\kappa}/2 = 50$ iterations for each factor-of-$e$ reduction in the bound. Gradient descent would need about $\kappa/2 = 5000$.
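The finite-termination result is easy to observe numerically. The sketch below builds a 200 x 200 SPD matrix with only three distinct eigenvalues (an arbitrary choice for illustration) and tracks the relative residual over three CG steps; `cg_residual_history` is a hypothetical helper.

```python
import numpy as np

def cg_residual_history(A, b, steps):
    """Relative residual norms ||r_k|| / ||b|| for k = 0..steps, starting from x0 = 0."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    history = [1.0]
    for _ in range(steps):
        Ap = A @ p
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        p = r_new + ((r_new @ r_new) / (r @ r)) * p
        r = r_new
        history.append(np.linalg.norm(r) / np.linalg.norm(b))
    return history

rng = np.random.default_rng(0)
n = 200
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))     # random orthogonal eigenvectors
eigvals = rng.choice([1.0, 10.0, 100.0], size=n)     # only 3 distinct eigenvalues
A = (Q * eigvals) @ Q.T                              # A = Q diag(eigvals) Q^T
b = rng.standard_normal(n)
history = cg_residual_history(A, b, steps=3)
print(history)
```

The residual stays of order one for the first two steps and then collapses to rounding-error level at the third, even though $n = 200$.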
Q2: What is a Krylov subspace and why do iterative solvers build solutions within it?
The Krylov subspace of order $k$ for matrix $A$ and vector $b$ is:
$$\mathcal{K}_k(A, b) = \text{span}\{b, Ab, A^2 b, \ldots, A^{k-1} b\}$$
Why this subspace? The solution $x^* = A^{-1} b$ can be written as a polynomial in $A$ applied to $b$:
$$x^* = A^{-1} b = \sum_{i=0}^{n-1} c_i A^i b$$
(This follows from the Cayley-Hamilton theorem: $A^n$ can be expressed as a polynomial in lower powers of $A$, so for non-singular $A$ the inverse $A^{-1}$ is a polynomial in $A$ of degree at most $n - 1$.)
So the exact solution lies in $\mathcal{K}_n(A, b)$. Krylov methods build increasingly accurate approximations within $\mathcal{K}_k$ for $k \ll n$, adding one matrix-vector product per step.
The efficiency: Each step only requires computing $Av$ once, plus vector operations. No need to form or store $A$ explicitly.
CG vs GMRES distinction: CG minimizes the $A$-norm of the error within $\mathcal{K}_k$. GMRES minimizes the Euclidean norm of the residual. For symmetric positive definite $A$, CG's norm is more natural and allows shorter recurrences (constant storage), whereas GMRES needs to store all $k$ basis vectors ($O(kn)$ storage).
Q3: What is preconditioning and why is it essential for fast CG convergence?
CG convergence rate depends on the condition number: $O(\sqrt{\kappa})$ iterations are needed for each decade of accuracy reduction. For very large $\kappa$, this is prohibitively slow.
Preconditioning transforms $Ax = b$ into $M^{-1} A x = M^{-1} b$ where $M \approx A$ and $M^{-1} v$ is cheap to compute. If $\kappa(M^{-1} A) \ll \kappa(A)$, CG on the preconditioned system converges much faster.
Common preconditioners:
- Jacobi (diagonal): $M = \text{diag}(A)$. Trivial to apply, effective when the diagonal dominates. $O(n)$ cost per step.
- SSOR (Symmetric Successive Over-Relaxation): Uses both forward and backward sweeps. Moderate effectiveness, $O(\text{nnz})$ cost.
- Incomplete LU (ILU): Factorizes $A \approx LU$ keeping only the sparsity pattern of $A$. Effective for sparse systems. $O(\text{nnz})$ cost.
- Multigrid: Solves a hierarchy of coarsened problems. Can achieve $\kappa = O(1)$ for some PDE-based problems - $O(n)$ total solve time.
ML applications:
- K-FAC: Uses Kronecker-factored blocks of the Fisher matrix as preconditioner for natural gradient CG
- Adam as diagonal preconditioner: Adam's adaptive learning rate is effectively a diagonal preconditioner for gradient descent
Q4: How does the Hessian-vector product (HVP) trick enable large-scale second-order optimization?
Second-order optimization requires solving $H \delta = -g$ where $H$ is the Hessian ($n \times n$ for $n$ parameters). For 100M parameter models, $H$ requires about $4 \times 10^{16}$ bytes in float32 - impossible to store.
The Pearlmutter trick computes $Hv$ for any vector $v$ using only two backward passes through the network - without ever forming $H$:
import torch
# Assumes model, criterion, inputs, targets and a flat vector v are already defined
# Forward pass with gradient graph
with torch.enable_grad():
    outputs = model(inputs)
    loss = criterion(outputs, targets)
# First backward: compute gradients
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
grad_vec = torch.cat([g.flatten() for g in grads])
# Second backward: d(grad · v)/dθ = H @ v
hessian_vec_prod = torch.autograd.grad(
    torch.dot(grad_vec, v),
    model.parameters()
)
Cost: $O(n)$ - same as two gradient computations. Compare to $O(n^2)$ for building the full Hessian.
Using HVP with CG: Newton's update $\delta = -H^{-1} g$ can be computed by running CG with A_matvec = HVP. Each CG iteration requires one HVP. For well-conditioned Hessians (or with good preconditioning), convergence in 20–50 iterations is typical - matching the cost of 40–100 gradient computations.
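This is the foundation of Hessian-free (truncated Newton) optimization and related matrix-free second-order methods. Kronecker-factored methods such as K-FAC and KFAC-Reduce take a complementary route: they approximate the curvature in a structured form that is cheap to invert.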
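As a self-contained illustration of a Newton step computed with CG and Hessian-vector products, the sketch below uses L2-regularized logistic regression, where the HVP has a closed form and stands in for the autodiff HVP. The data are synthetic, and `newton_cg_step` is a hypothetical helper rather than any library's API.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def newton_cg_step(X, y, w, l2=1e-2, cg_iters=50, cg_tol=1e-10):
    """One Newton step for L2-regularized logistic regression, solving H d = -g with CG.
    The Hessian is only touched through H v = X^T (s * (X v)) / m + l2 * v."""
    m = X.shape[0]
    p = sigmoid(X @ w)
    g = X.T @ (p - y) / m + l2 * w           # gradient of the regularized loss
    s = p * (1.0 - p)                        # per-example curvature weights
    hvp = lambda v: X.T @ (s * (X @ v)) / m + l2 * v

    d = np.zeros_like(w)
    r = -g                                   # residual of H d = -g at d = 0
    q = r.copy()
    rs = r @ r
    for _ in range(cg_iters):
        if np.sqrt(rs) <= cg_tol * np.linalg.norm(g):
            break
        Hq = hvp(q)                          # one HVP per CG iteration
        alpha = rs / (q @ Hq)
        d = d + alpha * q
        r = r - alpha * Hq
        rs_new = r @ r
        q = r + (rs_new / rs) * q
        rs = rs_new
    return d, g, hvp

rng = np.random.default_rng(0)
X = rng.standard_normal((500, 10))
y = (rng.random(500) < sigmoid(X @ rng.standard_normal(10))).astype(float)
w0 = np.zeros(10)
d, g, hvp = newton_cg_step(X, y, w0)
print(np.linalg.norm(hvp(d) + g))            # residual of the Newton system
```

In a deep-learning setting the closed-form `hvp` would be replaced by a double-backward HVP, and the CG loop would typically be truncated after a fixed number of iterations.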
Q5: When would you choose GMRES over conjugate gradient?
Use CG when: $A$ is symmetric positive definite (SPD).
- Shorter recurrences → constant memory per iteration
- Theoretically optimal convergence for SPD systems
- Examples: Gaussian process covariance solves, ridge regression normal equations, SPD physics matrices
Use GMRES when: $A$ is non-symmetric or indefinite.
- More general: works for any non-singular $A$
- Builds orthogonal Krylov basis via Arnoldi process
- Memory grows as $O(kn)$ after $k$ iterations - typically use GMRES(m) with restart every $m$ steps
- Examples: non-symmetric PDEs, certain ML applications with non-symmetric kernel matrices
BiCGSTAB: A middle ground - works for non-symmetric systems but uses only $O(n)$ memory (like CG). Convergence is less predictable than GMRES but memory-efficient for large systems.
Practical rule for ML:
- Fisher information matrix is SPD → CG
- Hessian at a saddle point is indefinite → MINRES (not CG!)
- Asymmetric recurrent network Jacobians → GMRES
- Large sparse non-symmetric systems (physics-informed NNs, graph problems) → GMRES with ILU preconditioner
Quick Reference
| Method | Matrix Type | Memory | When to Use |
|---|---|---|---|
| Conjugate Gradient | SPD only | $O(n)$ | Natural gradient, Gaussian processes, ridge |
| MINRES | Symmetric | $O(n)$ | Symmetric indefinite systems |
| GMRES(m) | Any | $O(mn)$ | Non-symmetric, restart every $m$ steps |
| BiCGSTAB | Any | $O(n)$ | Non-symmetric, memory-constrained |
| LSQR | Least squares | $O(\text{rows} + \text{cols})$ | Overdetermined systems, regression |
| Preconditioner | Effectiveness | Cost per step | Use case |
|---|---|---|---|
| Jacobi (diagonal) | Low–moderate | $O(n)$ | Always try first |
| SSOR | Moderate | $O(\text{nnz})$ | Diagonally dominant systems |
| ILU(0) | Good | $O(\text{nnz})$ | Sparse systems |
| Multigrid | Excellent | $O(n)$ | PDE-based problems |
| K-FAC blocks | Good | per layer | Neural network natural gradient |
