
Iterative Solvers - Conjugate Gradient, Krylov Methods, and Large-Scale ML

Reading time: ~22 minutes | Level: Numerical Methods → Large-Scale ML

You have a neural network with 100 million parameters. You want to use a second-order optimizer like K-FAC or natural gradient descent, which requires solving a linear system involving the Fisher information matrix. That matrix is $10^8 \times 10^8$.

You cannot form it explicitly: it has $10^{16}$ entries, roughly 40 PB in float32. You cannot use LU factorization: it would take $O(n^3) = O(10^{24})$ operations.

But you can compute matrix-vector products $Fv$ efficiently, at a cost comparable to a few gradient evaluations and without ever forming $F$. And that is all iterative solvers need.

What You Will Learn

  • Why iterative methods succeed where direct methods fail for large systems
  • The conjugate gradient (CG) method: derivation and implementation
  • GMRES and other Krylov subspace methods
  • Preconditioning: the key to fast convergence
  • Applications to second-order optimization in ML
  • SciPy's sparse linear algebra interface

Part 1 - When Direct Methods Fail

The curse of scale

Direct methods (LU, QR, Cholesky) require:

  • Storage: $O(n^2)$ to store the matrix
  • Time: $O(n^3)$ to factorize

| $n$ | $O(n^2)$ storage (float32) | $O(n^3)$ FLOPs |
|---|---|---|
| $10^3$ | 4 MB | $10^9$ - feasible |
| $10^6$ | 4 TB | $10^{18}$ - years of computation |
| $10^8$ | 40 PB | impossible |

For ML-scale problems, direct methods are fundamentally off the table. We need algorithms that:

  1. Never form the full matrix explicitly
  2. Converge to acceptable accuracy in far fewer than $n$ iterations
  3. Only require matrix-vector products $Av$

The iterative idea

Start with an initial guess $x_0$. Generate a sequence:

$$x_0 \to x_1 \to x_2 \to \cdots \to x^*$$

where each $x_k$ better approximates the solution. Stop when the residual $r_k = b - Ax_k$ is small enough.

The simplest iterative method is gradient descent on the quadratic:

$$\min_x f(x) = \frac{1}{2} x^T A x - b^T x \quad \Rightarrow \quad \nabla f = Ax - b$$

Setting $\nabla f = 0$ recovers $Ax = b$, so for symmetric positive definite $A$ the minimizer of $f$ is exactly the solution of the linear system.

This connects linear system solving directly to optimization - a natural bridge for ML engineers.
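To make the bridge concrete, here is a minimal sketch of gradient descent with exact line search on that quadratic, written for a small dense SPD matrix. The function name `steepest_descent` and the 2×2 test system are illustrative assumptions, not part of the lesson's codebase.

```python
import numpy as np

def steepest_descent(A, b, tol=1e-8, maxiter=10_000):
    """Minimize 0.5 x^T A x - b^T x for SPD A (hypothetical helper).

    Returns the approximate solution of A x = b and the iteration count.
    """
    x = np.zeros_like(b, dtype=float)
    r = b - A @ x                        # residual = negative gradient of f
    b_norm = np.linalg.norm(b)
    for k in range(maxiter):
        if np.linalg.norm(r) <= tol * b_norm:
            return x, k
        Ar = A @ r
        alpha = (r @ r) / (r @ Ar)       # exact line search along r
        x = x + alpha * r
        r = r - alpha * Ar               # keeps r equal to b - A x
    return x, maxiter

# Small illustrative SPD system (assumed for the example)
A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x_sd, iters = steepest_descent(A, b)
print(x_sd, iters)
```

Each step costs one matrix-vector product, the same per-iteration cost that CG pays; the difference lies in how the search directions are chosen.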

Part 2 - Conjugate Gradient: The Gold Standard for SPD Systems

Setup and motivation

For solving $Ax = b$ where $A$ is symmetric positive definite (SPD), conjugate gradient (CG) is the method of choice.

Gradient descent on $f(x) = \frac{1}{2} x^T A x - b^T x$ converges slowly because successive directions are orthogonal but not $A$-orthogonal (conjugate). CG fixes this.

The conjugate gradient algorithm

Given: A (SPD), b, initial guess x_0
r_0 = b - A x_0                               (initial residual)
p_0 = r_0                                     (initial search direction)
k = 0

LOOP until ‖r_k‖ < tolerance:
    α_k = (r_k · r_k) / (p_k · A p_k)         (optimal step size)
    x_{k+1} = x_k + α_k p_k                   (update solution)
    r_{k+1} = r_k - α_k A p_k                 (update residual)
    β_k = (r_{k+1} · r_{k+1}) / (r_k · r_k)   (Fletcher-Reeves)
    p_{k+1} = r_{k+1} + β_k p_k               (conjugate search direction)
    k = k + 1

Key property: The search directions $p_0, p_1, \ldots$ are mutually $A$-conjugate: $p_i^T A p_j = 0$ for $i \neq j$. Because at most $n$ nonzero vectors in $\mathbb{R}^n$ can be mutually $A$-conjugate, CG reaches the exact solution in at most $n$ steps - it is a direct method in exact arithmetic.

In practice CG is stopped by a tolerance criterion, and when the eigenvalues of $A$ are clustered or $\kappa(A)$ is moderate it reaches that tolerance in far fewer than $n$ iterations.

from typing import Callable, Optional

import numpy as np

def conjugate_gradient(
    A_matvec: Callable[[np.ndarray], np.ndarray],  # Function: v -> A @ v (no need to form A explicitly)
    b: np.ndarray,
    x0: Optional[np.ndarray] = None,
    tol: float = 1e-6,
    maxiter: Optional[int] = None
) -> tuple:
    """
    Conjugate Gradient for solving Ax = b where A is symmetric positive definite.

    A_matvec: callable - computes A @ v without forming A explicitly.
    This is the key: we only need matrix-vector products.

    Returns (x, residuals), where residuals[0] is the initial residual norm.
    """
    n = len(b)
    if x0 is None:
        x0 = np.zeros(n)
    if maxiter is None:
        maxiter = 2 * n

    x = x0.copy()
    r = b - A_matvec(x)   # Initial residual
    p = r.copy()          # Initial search direction
    r_dot = r @ r         # ‖r‖²

    residuals = [np.sqrt(r_dot)]

    for k in range(maxiter):
        # Relative stopping test; <= also handles the trivial case b = 0
        if np.sqrt(r_dot) <= tol * np.linalg.norm(b):
            break

        Ap = A_matvec(p)
        alpha = r_dot / (p @ Ap)    # Optimal step size

        x = x + alpha * p           # Update solution
        r = r - alpha * Ap          # Update residual

        r_dot_new = r @ r
        beta = r_dot_new / r_dot    # Fletcher-Reeves coefficient
        p = r + beta * p            # New conjugate direction

        r_dot = r_dot_new
        residuals.append(np.sqrt(r_dot))

    return x, residuals

# --- Example: solve a large SPD system ---
n = 1000
# Generate an SPD matrix A = X^T X + n I (in a real application you would not store it)
np.random.seed(42)
X = np.random.randn(n, n)
# In practice, A_matvec might call a neural network or physics simulator
A = X.T @ X + n * np.eye(n)  # Only for validation; normally you'd never form this

def A_matvec(v: np.ndarray) -> np.ndarray:
    """Matrix-vector product without forming A explicitly."""
    return X.T @ (X @ v) + n * v  # = (X^T X + nI) v

b = np.random.randn(n)
x_cg, residuals = conjugate_gradient(A_matvec, b)

# Verify
residual = np.linalg.norm(b - A @ x_cg)
print(f"Final residual: {residual:.2e}")
print(f"Iterations: {len(residuals) - 1}")  # residuals[0] is the initial residual

# Compare with direct solve
x_direct = np.linalg.solve(A, b)
print(f"CG vs direct difference: {np.linalg.norm(x_cg - x_direct):.2e}")

Convergence rate

CG converges at rate:

$$\|x_k - x^*\|_A \leq 2 \left(\frac{\sqrt{\kappa} - 1}{\sqrt{\kappa} + 1}\right)^k \|x_0 - x^*\|_A$$

where $\kappa = \kappa(A)$ is the condition number. High $\kappa$ → slow convergence. This is the motivation for preconditioning.
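The bound can be turned into a rough iteration estimate. The sketch below solves it for the smallest $k$ that guarantees a relative error of `eps` in the $A$-norm, and does the same for the classical steepest-descent factor $(\kappa - 1)/(\kappa + 1)$. The function names are made up for this example, and the numbers are worst-case bounds rather than observed iteration counts.

```python
import math

def cg_iterations_bound(kappa: float, eps: float) -> int:
    """Smallest k with 2 * rho**k <= eps, rho = (sqrt(kappa)-1)/(sqrt(kappa)+1)."""
    if kappa <= 1.0:
        return 1                      # A is a multiple of I: one step suffices
    rho = (math.sqrt(kappa) - 1.0) / (math.sqrt(kappa) + 1.0)
    return math.ceil(math.log(eps / 2.0) / math.log(rho))

def gd_iterations_bound(kappa: float, eps: float) -> int:
    """Smallest k with rho**k <= eps, rho = (kappa-1)/(kappa+1) (exact line search)."""
    if kappa <= 1.0:
        return 1
    rho = (kappa - 1.0) / (kappa + 1.0)
    return math.ceil(math.log(eps) / math.log(rho))

for kappa in (1e2, 1e4, 1e6):
    print(f"kappa={kappa:.0e}: CG <= {cg_iterations_bound(kappa, 1e-6)}, "
          f"GD <= {gd_iterations_bound(kappa, 1e-6)}")
```

The CG estimate grows like $\sqrt{\kappa}$ while the gradient descent estimate grows like $\kappa$, which is the quantitative reason preconditioning pays off.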

Part 3 - Preconditioning: Transforming the Problem

Why preconditioning matters

CG on the original problem $Ax = b$ with $\kappa(A) = 10^6$ may require thousands of iterations, since the iteration count grows like $\sqrt{\kappa}$. Preconditioning transforms the problem to have a much smaller condition number.

Left preconditioning: Solve $M^{-1}Ax = M^{-1}b$ instead of $Ax = b$.

If $M \approx A$, then $M^{-1}A \approx I$ and $\kappa(M^{-1}A) \approx 1$ - near-instant convergence.

The preconditioner $M$ (itself SPD when used with CG) must satisfy:

  1. $M \approx A$ (so $\kappa(M^{-1}A) \ll \kappa(A)$)
  2. $M^{-1}v$ is cheap to compute (otherwise the speedup is lost)

import numpy as np
from scipy.sparse.linalg import spilu

def preconditioned_cg(
    A_matvec,
    b: np.ndarray,
    M_inv_matvec,  # Preconditioner: computes M^{-1} v
    tol: float = 1e-8
) -> np.ndarray:
    """
    Preconditioned Conjugate Gradient.
    M_inv_matvec should be cheap and approximate A^{-1} v.
    Stops when the absolute residual norm ‖r‖ drops below tol.
    """
    n = len(b)
    x = np.zeros(n)
    r = b - A_matvec(x)
    z = M_inv_matvec(r)  # Apply preconditioner
    p = z.copy()
    rz = r @ z

    for _ in range(n):
        # Test before dividing so a zero residual never produces 0/0
        if np.linalg.norm(r) < tol:
            break

        Ap = A_matvec(p)
        alpha = rz / (p @ Ap)

        x = x + alpha * p
        r = r - alpha * Ap

        z = M_inv_matvec(r)
        rz_new = r @ z
        beta = rz_new / rz
        p = z + beta * p
        rz = rz_new

    return x

# --- Common preconditioners ---

# 1. Diagonal (Jacobi) preconditioner - cheapest, moderate effectiveness
def jacobi_preconditioner(A: np.ndarray):
    """M = diag(A), M^{-1}v = v / diag(A)"""
    diag = np.diag(A)
    return lambda v: v / diag

# 2. Incomplete factorization - for sparse systems, usually more effective
def incomplete_cholesky_preconditioner(A_sparse):
    """
    Incomplete-factorization preconditioner for large sparse systems.

    SciPy provides incomplete LU (spilu) but no incomplete Cholesky, so ILU
    is used here as a stand-in. CG expects an SPD preconditioner, which ILU
    does not guarantee; a dedicated incomplete Cholesky is preferable for SPD A.
    """
    ilu = spilu(A_sparse.tocsc())  # spilu works on CSC format
    return lambda v: ilu.solve(v)
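
As a rough illustration of why even the Jacobi preconditioner can help, the sketch below builds a deliberately badly scaled SPD matrix and compares condition numbers before and after diagonal scaling. The test matrix is an assumption chosen to make the effect visible; the symmetric form $D^{-1/2} A D^{-1/2}$ is used because it has the same eigenvalues as $D^{-1}A$ and is easy to pass to `np.linalg.cond`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50

# Well-conditioned SPD core, then scale rows and columns very unevenly
G = rng.standard_normal((n, n))
B = G @ G.T / n + np.eye(n)
scales = np.logspace(0, 3, n)                 # scales from 1 to 1000
A = scales[:, None] * B * scales[None, :]     # A = S B S, still SPD

# Jacobi scaling: D = diag(A), compare A with D^{-1/2} A D^{-1/2}
d_sqrt = np.sqrt(np.diag(A))
A_jacobi = A / d_sqrt[:, None] / d_sqrt[None, :]

kappa_before = np.linalg.cond(A)
kappa_after = np.linalg.cond(A_jacobi)
print(f"kappa(A)            = {kappa_before:.2e}")
print(f"kappa(D^-1/2 A D^-1/2) = {kappa_after:.2e}")
```

Jacobi works this well only because the ill-conditioning here comes entirely from scaling; when it comes from the coupling between variables, a diagonal preconditioner helps far less.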

Preconditioners in ML

| ML Context | Preconditioner | Rationale |
|---|---|---|
| Second-order optimization | Fisher information diagonal | Cheap per-parameter curvature estimate |
| K-FAC optimizer | Kronecker-factored curvature | Approximate Fisher per layer |
| Federated learning | Local Hessian approximation | Each client provides local curvature |
| Graph neural networks | Degree normalization $D^{-1/2} A D^{-1/2}$ | Spectral normalization of adjacency |

Part 4 - GMRES for Non-Symmetric Systems

Conjugate Gradient requires $A$ to be symmetric positive definite. For non-symmetric systems (common in physics simulations, some ML applications), use GMRES (Generalized Minimal Residual).

The GMRES idea

GMRES finds the best solution $x_k$ in the Krylov subspace:

$$\mathcal{K}_k(A, b) = \text{span}\{b, Ab, A^2b, \ldots, A^{k-1}b\}$$

by minimizing $\|b - Ax_k\|$ over this subspace. Each iteration adds one vector to the subspace via the Arnoldi process.

from scipy.sparse.linalg import gmres, spilu, LinearOperator
import numpy as np
import scipy.sparse as sp

# Example: solve a large sparse non-symmetric system
n = 10000

# Create sparse non-symmetric matrix
offsets = [-1, 0, 1, 10]
data = [
    -1.0 * np.ones(n-1),
    4.0 * np.ones(n),
    -1.0 * np.ones(n-1),
    -0.5 * np.ones(n-10)
]
A_sparse = sp.diags(data, offsets, shape=(n, n), format='csr')

b = np.ones(n)

# Solve with GMRES (good for non-symmetric)
# Note: SciPy >= 1.12 names the relative tolerance `rtol`; older releases used `tol`
x_gmres, info = gmres(A_sparse, b, rtol=1e-8, restart=50)
print(f"GMRES converged: {info == 0}")
print(f"Residual: {np.linalg.norm(b - A_sparse @ x_gmres):.2e}")

# With preconditioning - typically reduces the iteration count substantially
ilu = spilu(A_sparse.tocsc())  # spilu works on CSC format
M_precond = LinearOperator((n, n), matvec=ilu.solve)
x_precond, info = gmres(A_sparse, b, M=M_precond, rtol=1e-8)
print(f"Preconditioned GMRES residual: {np.linalg.norm(b - A_sparse @ x_precond):.2e}")

Krylov subspace methods: the family

Krylov Subspace Methods
├── Symmetric Positive Definite (A = A^T, A ≻ 0)
│   └── Conjugate Gradient (CG) - optimal
│
├── Symmetric Indefinite (A = A^T, not necessarily PD)
│   └── MINRES - minimizes residual for symmetric systems
│
├── Non-Symmetric
│   ├── GMRES - minimizes residual, requires O(kn) memory after k steps
│   ├── BiCGSTAB - short recurrences, less memory than GMRES
│   └── LGMRES - GMRES with augmented subspace
│
└── Normal Equations (always applicable, often slow)
    └── CGLS / LSQR - for least-squares problems

Part 5 - Applications to ML

Second-order optimization: K-FAC

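K-FAC (Kronecker-Factored Approximate Curvature) approximates the Fisher information matrix layer by layer using Kronecker products, whose factors are small enough to invert directly. When the Fisher is instead accessed only through matrix-vector products, as in Hessian-free and truncated natural gradient methods, CG solves the damped system at each step. The example below shows that CG-based solve.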

import numpy as np

class KFACLinearSolveExample:
    """
    Conceptual demonstration of a CG-based natural gradient solve.
    K-FAC implementations typically invert their Kronecker factors directly.
    In practice, use: https://github.com/tensorflow/kfac
    """

    def natural_gradient_update(
        self,
        gradients: np.ndarray,
        fisher_matvec,  # F @ v without forming F explicitly
        damping: float = 1e-3,
        cg_tol: float = 1e-4
    ) -> np.ndarray:
        """
        Compute (F + λI)^{-1} g using CG.

        g: gradient vector (n_params,)
        F: Fisher information matrix (n_params x n_params)
        damping: regularization to make F+λI well-conditioned

        The natural gradient update δθ = F^{-1} g
        has better scaling properties than vanilla SGD.
        """
        # (F + λI) v = g → solve for v = (F + λI)^{-1} g
        def damped_fisher_matvec(v):
            return fisher_matvec(v) + damping * v

        # conjugate_gradient is the function defined in Part 2
        natural_grad, _ = conjugate_gradient(
            damped_fisher_matvec, gradients, tol=cg_tol
        )
        return natural_grad
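
The role of the damping term can be checked numerically. In the sketch below, a synthetic rank-deficient "Fisher" matrix $F = J^T J / m$ (an assumed stand-in, built from random data rather than a real network) is singular on its own, and adding $\lambda I$ bounds the condition number at about $(\lambda_{\max} + \lambda)/\lambda$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 20, 50                       # fewer samples than parameters
J = rng.standard_normal((m, n))     # synthetic per-sample gradients (assumption)
F = J.T @ J / m                     # rank <= m, so F is singular

lam_max = np.linalg.eigvalsh(F)[-1]

kappas = {}
for damping in (1e-1, 1e-3):
    eigs = np.linalg.eigvalsh(F + damping * np.eye(n))   # ascending order
    kappas[damping] = eigs[-1] / eigs[0]
    print(f"damping={damping:g}: kappa = {kappas[damping]:.3e}, "
          f"(lam_max + damping)/damping = {(lam_max + damping) / damping:.3e}")
```

Larger damping makes the CG solve faster but pulls the update toward a plain scaled gradient step, so the value is a trade-off rather than a free parameter.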

Conjugate gradient in GPU-accelerated ML

from typing import Optional

import torch

def torch_cg(
    A_matvec,
    b: torch.Tensor,
    x0: Optional[torch.Tensor] = None,
    tol: float = 1e-6,
    maxiter: int = 100
) -> torch.Tensor:
    """
    Conjugate Gradient implemented in PyTorch.
    Works on GPU - A_matvec can call neural network layers.
    Stops when the absolute residual norm drops below tol.
    """
    if x0 is None:
        x = torch.zeros_like(b)
    else:
        x = x0.clone()

    r = b - A_matvec(x)
    p = r.clone()
    r_dot = torch.dot(r, r)

    for _ in range(maxiter):
        if r_dot.sqrt() < tol:
            break

        Ap = A_matvec(p)
        alpha = r_dot / torch.dot(p, Ap)

        x = x + alpha * p
        r = r - alpha * Ap

        r_dot_new = torch.dot(r, r)
        beta = r_dot_new / r_dot
        p = r + beta * p
        r_dot = r_dot_new

    return x


def implicit_hessian_matvec(loss_fn, params, v):
    """
    Compute Hessian-vector product H @ v using two backward passes.
    This is the Pearlmutter trick - O(n) cost, never forms the Hessian.

    v must be a flat vector with one entry per parameter, in the order of params.
    """
    # First backward pass: gradient, keeping the graph so it can be differentiated again
    grad = torch.autograd.grad(loss_fn(), params, create_graph=True)
    grad_vec = torch.cat([g.flatten() for g in grad])

    # Second backward pass: derivative of grad · v w.r.t. params
    grad_v = torch.dot(grad_vec, v)
    hvp = torch.autograd.grad(grad_v, params)
    return torch.cat([h.flatten() for h in hvp])

Solving normal equations for large-scale linear regression

from scipy.sparse.linalg import cg as scipy_cg, LinearOperator
import numpy as np

def large_scale_ridge_regression(
    X: np.ndarray,
    y: np.ndarray,
    lambda_reg: float = 1e-3
) -> np.ndarray:
    """
    Solve Ridge regression for large X using CG on the normal equations.
    Avoids forming X^T X explicitly.

    Normal equations: (X^T X + λI) β = X^T y
    """
    m, n = X.shape

    # Matrix-vector product without forming X^T X
    def matvec(v: np.ndarray) -> np.ndarray:
        return X.T @ (X @ v) + lambda_reg * v  # O(m*n) per call

    A_op = LinearOperator((n, n), matvec=matvec, dtype=X.dtype)
    rhs = X.T @ y  # X^T y - only computed once

    # Solve with CG - converges in O(sqrt(κ)) iterations
    # Note: SciPy >= 1.12 names the relative tolerance `rtol`; older releases used `tol`
    beta, info = scipy_cg(A_op, rhs, rtol=1e-6, maxiter=500)

    if info == 0:
        print("CG converged successfully")
    else:
        print(f"CG did not converge (info={info})")

    return beta

Part 6 - When Iterative Beats Direct

Choose Iterative Solver When:
├── n > 10^4 (direct methods become slow/infeasible)
├── Matrix is never explicitly formed (second-order ML optimization)
├── Matrix is sparse (>95% zeros) - direct LU fills in zeros
├── You need only moderate accuracy (10^-4 to 10^-6 residual)
├── A good preconditioner is available
└── You can compute A @ v efficiently (e.g., via autodiff)

Choose Direct Solver When:
├── n < 10^4 - direct is fast and exact
├── Many right-hand sides with same A (LU once, solve many times)
├── Very high accuracy required (< 10^-12 residual)
└── Matrix fits in memory and is not sparse

Interview Questions

Q1: What makes conjugate gradient more efficient than gradient descent for quadratic optimization?

Both CG and gradient descent minimize $f(x) = \frac{1}{2}x^T Ax - b^T x$ (equivalent to solving $Ax = b$ for SPD $A$).

Gradient descent: Each step uses the negative gradient (= residual $r_k = b - Ax_k$) as the search direction. With exact line search, successive gradients are orthogonal: $r_k \perp r_{k+1}$. This causes zigzagging - you repeatedly search in directions you've already explored, requiring many iterations to traverse a "canyon" in the objective.

Conjugate gradient: Search directions are chosen to be $A$-conjugate: $p_i^T A p_j = 0$ for $i \neq j$. This means minimizing along $p_i$ never "undoes" progress along $p_j$. Each new direction is $A$-orthogonal to all previous directions, so the minimization already achieved along them is preserved.

Key theoretical result: CG on a problem with $A$ having only $k$ distinct eigenvalues converges in at most $k$ steps (in exact arithmetic), regardless of $n$. For a matrix with clustered eigenvalues, CG is extremely efficient.
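
The distinct-eigenvalue result is easy to check numerically. The sketch below builds a 60×60 SPD matrix with only three distinct eigenvalues (an assumed test case) and counts CG iterations with a compact re-implementation named `cg_iteration_count`, introduced here only so the example is self-contained.

```python
import numpy as np

def cg_iteration_count(A, b, tol=1e-10, maxiter=1000):
    """Run plain CG from x0 = 0 and return the number of iterations used."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rr = r @ r
    b_norm = np.linalg.norm(b)
    for k in range(maxiter):
        if np.sqrt(rr) <= tol * b_norm:
            return k
        Ap = A @ p
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = r @ r
        p = r + (rr_new / rr) * p
        rr = rr_new
    return maxiter

rng = np.random.default_rng(0)
n = 60
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))     # random orthogonal basis
eigs = np.repeat([1.0, 5.0, 20.0], n // 3)            # three distinct eigenvalues
A = (Q * eigs) @ Q.T                                  # A = Q diag(eigs) Q^T
b = rng.standard_normal(n)

iters = cg_iteration_count(A, b)
print(f"n = {n}, distinct eigenvalues = 3, CG iterations = {iters}")
```

In floating-point arithmetic the count can drift slightly for larger or worse-conditioned problems, but the dependence on the number of eigenvalue clusters rather than on $n$ remains.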

Practical convergence: CG's error bound shrinks by a factor of $(\sqrt{\kappa}-1)/(\sqrt{\kappa}+1)$ per iteration, while gradient descent achieves only $(\kappa-1)/(\kappa+1)$. For $\kappa = 100$: the bounds give roughly 70 CG iterations to reach $10^{-6}$ relative accuracy, versus roughly 700 for gradient descent.

Q2: What is a Krylov subspace and why do iterative solvers build solutions within it?

The Krylov subspace of order $k$ for matrix $A$ and vector $b$ is:

$$\mathcal{K}_k(A, b) = \text{span}\{b, Ab, A^2b, \ldots, A^{k-1}b\}$$

Why this subspace? The solution $x^* = A^{-1}b$ can be written as a polynomial in $A$ applied to $b$:

$$x^* = p(A) b, \quad p(A) = c_0 I + c_1 A + \cdots + c_{n-1} A^{n-1}$$

(This follows from the Cayley-Hamilton theorem: $A$ satisfies its own characteristic polynomial, so for nonsingular $A$ the inverse $A^{-1}$ can be expressed as a polynomial of degree at most $n-1$ in $A$.)

So the exact solution lies in $\mathcal{K}_n(A, b)$. Krylov methods build increasingly accurate approximations within $\mathcal{K}_k$ for $k = 1, 2, \ldots$, adding one matrix-vector product per step.

The efficiency: Each step only requires computing $A \cdot v$ once, plus a handful of vector operations for CG or $O(k)$ of them for GMRES. No need to form or store $A$ explicitly.

CG vs GMRES distinction: CG minimizes the $A$-norm of the error within $\mathcal{K}_k$. GMRES minimizes the Euclidean norm of the residual. For symmetric positive definite $A$, CG's norm is more natural and allows shorter recurrences (constant storage), whereas GMRES needs to store all $k$ basis vectors ($O(kn)$ storage).
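
A small numerical check of the Krylov argument, under the assumption of a 6×6 SPD test matrix: the sketch forms the vectors $b, Ab, \ldots, A^{n-1}b$ explicitly and finds, for each $k$, the minimum-residual approximation in $\mathcal{K}_k$. Forming Krylov vectors explicitly is numerically poor for larger $n$ (real solvers orthogonalize instead), so this is for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = (Q * np.arange(1.0, n + 1)) @ Q.T        # SPD with eigenvalues 1, 2, ..., 6
b = rng.standard_normal(n)

# Columns of K are b, Ab, A^2 b, ..., A^{n-1} b
K = np.empty((n, n))
K[:, 0] = b
for j in range(1, n):
    K[:, j] = A @ K[:, j - 1]

# Best approximation (minimum residual) within K_k for k = 1..n
residuals = []
for k in range(1, n + 1):
    AK = A @ K[:, :k]
    c, *_ = np.linalg.lstsq(AK, b, rcond=None)
    residuals.append(np.linalg.norm(b - AK @ c))

for k, res in enumerate(residuals, start=1):
    print(f"k = {k}: min residual over K_k = {res:.2e}")
```

The residuals can only shrink as $k$ grows because the subspaces are nested, and at $k = n$ the residual drops to rounding level, matching the Cayley-Hamilton argument.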

Q3: What is preconditioning and why is it essential for fast CG convergence?

CG's convergence rate depends on the condition number: roughly $\sqrt{\kappa}$ iterations are needed for each decade of error reduction. For $\kappa = 10^6$, that is on the order of a thousand matrix-vector products per decade, which is often prohibitively slow.

Preconditioning transforms $Ax = b$ into $M^{-1}Ax = M^{-1}b$ where $M \approx A$ and $M^{-1}v$ is cheap to compute. If $\kappa(M^{-1}A) \ll \kappa(A)$, CG on the preconditioned system converges much faster.

Common preconditioners:

  1. Jacobi (diagonal): $M = \text{diag}(A)$. Trivial to apply, effective when diagonal dominates. $O(n)$ cost per step.

  2. SSOR (Symmetric Successive Over-Relaxation): Uses both forward and backward sweeps. Moderate effectiveness, $O(n)$ cost.

  3. Incomplete LU (ILU): Factorizes $A \approx L'U'$ keeping only the sparsity pattern of $A$. Effective for sparse systems. $O(\text{nnz})$ cost.

  4. Multigrid: Solves a hierarchy of coarsened problems. Can achieve $\kappa(M^{-1}A) = O(1)$ for some PDE-based problems - $O(n)$ total solve time.

ML applications:

  • K-FAC: Uses Kronecker-factored blocks of the Fisher matrix as preconditioner for natural gradient CG
  • Adam as diagonal preconditioner: Adam's $1/\sqrt{v_t}$ adaptive learning rate is effectively a diagonal preconditioner for gradient descent

Q4: How does the Hessian-vector product (HVP) trick enable large-scale second-order optimization?

Second-order optimization requires solving $H\delta = g$ where $H$ is the Hessian ($n \times n$ for $n$ parameters). For 100M parameter models, $H$ has $10^{16}$ entries (about 40 PB in float32) - impossible to store.

The Pearlmutter trick computes $Hv$ for any vector $v$ using only two backward passes through the network - without ever forming $H$:

# Forward pass with gradient graph
with torch.enable_grad():
    outputs = model(inputs)
    loss = criterion(outputs, targets)

# First backward: compute gradients
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
grad_vec = torch.cat([g.flatten() for g in grads])

# Second backward: d(grad · v)/dθ = H @ v
hessian_vec_prod = torch.autograd.grad(
    torch.dot(grad_vec, v),
    model.parameters()
)

Cost: $O(n)$ - about the same as two gradient computations. Compare to $O(n^2)$ for building the full Hessian.

Using HVP with CG: Newton's update $\delta^* = H^{-1}g$ can be computed by running CG with A_matvec = HVP (in practice on a damped or Gauss-Newton variant, since CG assumes a positive definite matrix). Each CG iteration requires one HVP. For well-conditioned Hessians (or with good preconditioning), convergence in 20–50 iterations is typical - matching the cost of 40–100 gradient computations.

This is the foundation of Hessian-free optimization. K-FAC and variants such as KFAC-Reduce take a complementary route, approximating the curvature with Kronecker factors that are inverted directly.
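
For readers without an autodiff framework at hand, the same matrix-free idea can be sketched in NumPy using a central finite difference of the gradient, $Hv \approx (\nabla f(\theta + \epsilon v) - \nabla f(\theta - \epsilon v)) / (2\epsilon)$. This is a different technique from the autodiff-based trick above (approximate rather than exact), and the logistic-regression setup and function names are assumptions made for the example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_grad(theta, X, y):
    """Gradient of the mean logistic loss with labels in {0, 1}."""
    return X.T @ (sigmoid(X @ theta) - y) / len(y)

def hvp_finite_difference(theta, v, X, y, eps=1e-5):
    """Approximate H @ v from two gradient evaluations; never forms H."""
    g_plus = logistic_grad(theta + eps * v, X, y)
    g_minus = logistic_grad(theta - eps * v, X, y)
    return (g_plus - g_minus) / (2.0 * eps)

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = (rng.random(200) < 0.5).astype(float)
theta = 0.1 * rng.standard_normal(5)
v = rng.standard_normal(5)

# Explicit Hessian, affordable only because this problem has 5 parameters
p = sigmoid(X @ theta)
H = (X.T * (p * (1.0 - p))) @ X / len(y)

hv_fd = hvp_finite_difference(theta, v, X, y)
print(f"max |H v - finite-difference HVP| = {np.max(np.abs(H @ v - hv_fd)):.2e}")
```

Either way the cost is two gradient evaluations per product; the autodiff version avoids the choice of `eps` and the associated truncation and rounding error.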

Q5: When would you choose GMRES over conjugate gradient?

Use CG when: $A$ is symmetric positive definite (SPD).

  • Shorter recurrences → constant $O(n)$ memory per iteration
  • Theoretically optimal convergence for SPD systems
  • Examples: Gaussian process covariance solves, ridge regression normal equations, SPD physics matrices

Use GMRES when: $A$ is non-symmetric or indefinite.

  • More general: works for any non-singular $A$
  • Builds orthogonal Krylov basis via Arnoldi process
  • Memory grows as $O(kn)$ - typically use GMRES(m) with restart every $m$ steps
  • Examples: non-symmetric PDEs, certain ML applications with non-symmetric kernel matrices

BiCGSTAB: A middle ground - works for non-symmetric systems but uses only $O(n)$ memory (like CG). Convergence is less predictable than GMRES but memory-efficient for large systems.

Practical rule for ML:

  • Fisher information matrix is SPD → CG
  • Hessian at a saddle point is indefinite → MINRES (not CG!)
  • Asymmetric recurrent network Jacobians → GMRES
  • Large sparse non-symmetric systems (physics-informed NNs, graph problems) → GMRES with ILU preconditioner
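
The rule of thumb can be written down as a tiny helper. This is only a sketch for small dense matrices: `suggest_solver` is a hypothetical name, and the symmetry and Cholesky checks it uses require the explicit matrix, which is exactly what is unavailable at ML scale, where the structure has to be known from how the operator is defined.

```python
import numpy as np

def suggest_solver(A, sym_tol=1e-10):
    """Suggest a Krylov method for a small dense matrix (illustrative only)."""
    A = np.asarray(A, dtype=float)
    if not np.allclose(A, A.T, atol=sym_tol):
        return "GMRES"                    # non-symmetric
    try:
        np.linalg.cholesky(A)             # succeeds only for positive definite A
        return "CG"
    except np.linalg.LinAlgError:
        return "MINRES"                   # symmetric but not positive definite

print(suggest_solver([[2.0, 1.0], [1.0, 2.0]]))    # SPD
print(suggest_solver([[1.0, 0.0], [0.0, -1.0]]))   # symmetric indefinite
print(suggest_solver([[1.0, 2.0], [0.0, 1.0]]))    # non-symmetric
```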

Quick Reference

| Method | Matrix Type | Memory | When to Use |
|---|---|---|---|
| Conjugate Gradient | SPD only | $O(n)$ | Natural gradient, Gaussian processes, ridge |
| MINRES | Symmetric | $O(n)$ | Symmetric indefinite systems |
| GMRES(m) | Any | $O(mn)$ | Non-symmetric, restart every $m$ steps |
| BiCGSTAB | Any | $O(n)$ | Non-symmetric, memory-constrained |
| LSQR | Least squares | $O(n)$ | Overdetermined systems, regression |

| Preconditioner | Effectiveness | Cost per step | Use case |
|---|---|---|---|
| Jacobi (diagonal) | Low–moderate | $O(n)$ | Always try first |
| SSOR | Moderate | $O(n)$ | Diagonally dominant systems |
| ILU(0) | Good | $O(\text{nnz})$ | Sparse systems |
| Multigrid | Excellent | $O(n)$ | PDE-based problems |
| K-FAC blocks | Good | $O(n)$ per layer | Neural network natural gradient |

Next: Lesson 04: Numerical Differentiation →


© 2026 EngineersOfAI. All rights reserved.