Variational Inference
Why MCMC Doesn't Scale to Deep Learning
Imagine you want full Bayesian uncertainty quantification for a neural network with 10 million parameters. MCMC requires:
- Drawing thousands of samples from the posterior
- Each step requires a full forward pass (and, for gradient-based samplers like HMC, a backward pass) through the network over the entire dataset
- Chains mix slowly in 10-million-dimensional space
- Convergence diagnosis requires multiple chains and thousands of steps
Even on modern hardware, this is intractable at neural-network scale.
Variational inference (VI) solves this by turning Bayesian inference into an optimization problem. Instead of sampling from the posterior $p(\theta \mid x)$, we approximate it with a simpler distribution $q_\phi(\theta)$ (the variational distribution) and find the parameters $\phi$ that make $q_\phi$ as close as possible to the true posterior $p(\theta \mid x)$.
The result: instead of MCMC sampling (slow, asymptotically exact), we run gradient-based optimization (fast, approximate). This is what makes VAEs, Bayesian neural networks, and much of scalable probabilistic ML practical.
The Evidence Lower Bound (ELBO): Derivation from Scratch
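We want to minimize the KL divergence from our approximation $q_\phi(\theta)$ to the true posterior $p(\theta \mid x)$:
$$\mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta \mid x)\big) = \mathbb{E}_{q_\phi}\big[\log q_\phi(\theta) - \log p(\theta \mid x)\big]$$
Expanding using Bayes' theorem, $p(\theta \mid x) = p(x \mid \theta)\,p(\theta) / p(x)$:
$$\mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta \mid x)\big) = \mathbb{E}_{q_\phi}\big[\log q_\phi(\theta) - \log p(x \mid \theta) - \log p(\theta) + \log p(x)\big]$$
Since $\log p(x)$ doesn't depend on $\theta$, it comes out of the expectation:
$$\mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta \mid x)\big) = -\mathbb{E}_{q_\phi}\big[\log p(x \mid \theta)\big] + \mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta)\big) + \log p(x)$$
Rearranging (and noting $\mathrm{KL} \ge 0$):
$$\log p(x) = \underbrace{\mathbb{E}_{q_\phi}\big[\log p(x \mid \theta)\big] - \mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta)\big)}_{\mathrm{ELBO}(\phi)} + \mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta \mid x)\big) \;\ge\; \mathrm{ELBO}(\phi)$$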
The ELBO (Evidence Lower BOund) is a lower bound on the log evidence $\log p(x)$. Because $\log p(x)$ does not depend on $\phi$, maximizing the ELBO is equivalent to minimizing the KL divergence from $q_\phi$ to the posterior.
Interpretation:
- First term: How well does the variational distribution explain the data? (Reconstruction quality)
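- Second term: How far is $q_\phi(\theta)$ from the prior $p(\theta)$? (Complexity penalty)
With latent variables $z$ in place of $\theta$ and an encoder that outputs $q_\phi(z \mid x)$, the negative ELBO is exactly the VAE loss function.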
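To make the identity $\log p(x) = \mathrm{ELBO}(\phi) + \mathrm{KL}(q_\phi(\theta)\,\|\,p(\theta \mid x))$ concrete, here is a small numerical check. It assumes a hypothetical one-dimensional conjugate model (prior $\theta \sim \mathcal{N}(0, 1)$, one observation $x \mid \theta \sim \mathcal{N}(\theta, \sigma^2)$), chosen only because its posterior and evidence have closed forms; `sigma2`, `x`, and the candidate distributions $q$ are arbitrary.

```python
import numpy as np

# Hypothetical 1-D conjugate model (assumption for illustration):
#   prior theta ~ N(0, 1), likelihood x | theta ~ N(theta, sigma2), one observation x
sigma2, x = 0.5, 1.3

# Closed-form posterior p(theta | x) and log evidence log p(x) = log N(x; 0, 1 + sigma2)
post_var = 1.0 / (1.0 + 1.0 / sigma2)
post_mean = post_var * x / sigma2
log_evidence = -0.5 * np.log(2 * np.pi * (1 + sigma2)) - x**2 / (2 * (1 + sigma2))


def gauss_kl(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for 1-D Gaussians."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def elbo(m, v):
    """ELBO for q = N(m, v): E_q[log p(x | theta)] - KL(q || prior)."""
    expected_loglik = -0.5 * np.log(2 * np.pi * sigma2) - ((x - m) ** 2 + v) / (2 * sigma2)
    return expected_loglik - gauss_kl(m, v, 0.0, 1.0)


# For every q: log p(x) = ELBO(q) + KL(q || posterior), so ELBO(q) <= log p(x)
for m, v in [(0.0, 1.0), (0.5, 0.2), (post_mean, post_var)]:
    gap = gauss_kl(m, v, post_mean, post_var)
    print(f"q=N({m:.3f}, {v:.3f}): ELBO={elbo(m, v):.4f}, "
          f"ELBO+KL={elbo(m, v) + gap:.4f}, log p(x)={log_evidence:.4f}")
```

The last candidate is the exact posterior, where the KL gap vanishes and the ELBO equals $\log p(x)$.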
Mean-Field Variational Inference
The most common VI approach: assume the variational distribution factorizes across parameters:
$$q(\theta) = \prod_{j=1}^{d} q_j(\theta_j)$$
This "mean-field" assumption makes computation tractable but ignores correlations between parameters.
Coordinate ascent VI (CAVI): Maximize the ELBO by updating each factor one at a time, holding others fixed.
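The optimal update for factor $q_j$ is:
$$\log q_j^*(\theta_j) = \mathbb{E}_{-j}\big[\log p(x, \theta)\big] + \text{const}$$
where $\mathbb{E}_{-j}$ means expectation over all $\theta_i$ with $i \neq j$, under $\prod_{i \neq j} q_i(\theta_i)$.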
For exponential family models, this often gives closed-form updates. The algorithm converges to a local optimum of the ELBO (not necessarily global).
```python
import numpy as np


class MeanFieldVI_BayesianLinearRegression:
    """
    Mean-field VI for Bayesian linear regression.
    Model: y = X w + eps, w ~ N(0, alpha^{-1} I), eps ~ N(0, beta^{-1})
    Variational distribution: q(w) = N(m, S) where S is diagonal
    """

    def __init__(self, alpha=1.0, beta=1.0):
        self.alpha = alpha  # prior precision on weights
        self.beta = beta    # likelihood precision (1/noise variance)
        self.m = None       # variational mean
        self.s2 = None      # variational variance (diagonal)

    def fit(self, X, y, n_iter=100):
        """
        For Bayesian linear regression with a Gaussian prior, the posterior
        IS Gaussian (conjugate). Here we compute it both exactly and via
        mean-field VI: CAVI recovers the exact posterior mean, but its
        diagonal variances can only match or underestimate the exact ones.
        """
        n, d = X.shape
        # Exact posterior (for comparison)
        S_N_inv = self.alpha * np.eye(d) + self.beta * X.T @ X
        self.S_exact = np.linalg.inv(S_N_inv)
        self.m_exact = self.beta * self.S_exact @ X.T @ y

        # Mean-field VI: approximate posterior as a diagonal Gaussian
        self.m = np.zeros(d)
        self.s2 = np.ones(d) / self.alpha
        elbos = []
        for _ in range(n_iter):
            # CAVI update for diagonal variances: 1 / (diagonal of the exact precision).
            # It does not depend on m, so it is fixed after the first iteration.
            self.s2 = 1.0 / (self.alpha + self.beta * np.sum(X**2, axis=0))
            # CAVI update for each mean coordinate, holding the others fixed
            for j in range(d):
                x_j = X[:, j]
                r_j = y - X @ self.m + self.m[j] * x_j  # residual excluding feature j
                self.m[j] = self.beta * self.s2[j] * (x_j @ r_j)
            # Track the ELBO (non-decreasing under CAVI)
            elbos.append(self._compute_elbo(X, y))
        return elbos

    def _compute_elbo(self, X, y):
        n, d = X.shape
        # E_q[log p(y|X,w)] = -beta/2 * E_q[||y - Xw||^2] + const
        sq_error = np.sum((y - X @ self.m) ** 2)
        trace_term = np.sum(np.sum(X**2, axis=0) * self.s2)
        expected_loglik = (n / 2) * np.log(self.beta / (2 * np.pi)) - \
            self.beta / 2 * (sq_error + trace_term)
        # KL(q(w) || p(w)) for diagonal Gaussian vs isotropic Gaussian prior
        kl = 0.5 * np.sum(self.alpha * (self.m**2 + self.s2) -
                          np.log(self.alpha * self.s2) - 1)
        return expected_loglik - kl

    def predict(self, X_new):
        """Posterior predictive mean and standard deviation."""
        y_mean = X_new @ self.m
        # Predictive variance: parameter uncertainty + noise
        y_var = (1 / self.beta) + np.sum(X_new**2 * self.s2, axis=1)
        return y_mean, np.sqrt(y_var)


# Demonstrate VI vs exact posterior
np.random.seed(42)
n, d = 50, 3
X = np.random.randn(n, d)
true_w = np.array([1.5, -0.8, 0.3])
y = X @ true_w + np.random.randn(n) * 0.5

vi_model = MeanFieldVI_BayesianLinearRegression(alpha=1.0, beta=4.0)
elbos = vi_model.fit(X, y, n_iter=100)

print("Mean-Field VI vs Exact Posterior:")
print(f"VI mean: {vi_model.m}")
print(f"Exact mean: {vi_model.m_exact}")
print(f"True w: {true_w}")
print(f"\nVI variance (diag): {vi_model.s2}")
print(f"Exact variance (diag): {np.diag(vi_model.S_exact)}")
print(f"\nELBO at convergence: {elbos[-1]:.4f}")
print("Note: mean-field VI recovers the exact posterior mean here, but its variances")
print("are <= the exact marginal variances (equal only if X^T X is diagonal).")
```
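The variance shrinkage is easier to see with strongly correlated coordinates. The sketch below assumes a hypothetical 2-D Gaussian target with correlation `rho = 0.9`, runs the mean-field CAVI mean updates (each factor $q_j$ gets precision $\Lambda_{jj}$, the $j$-th diagonal entry of the target precision), and compares the result with the exact moments.

```python
import numpy as np

# Illustrative target (assumption for this sketch): a correlated 2-D Gaussian
mu = np.array([1.0, -1.0])
rho = 0.9
Sigma = np.array([[1.0, rho], [rho, 1.0]])
Lam = np.linalg.inv(Sigma)  # precision matrix

# Mean-field CAVI for a Gaussian target:
#   q_j(theta_j) = N(m_j, 1 / Lam[j, j])
#   m_j = mu_j - (1 / Lam[j, j]) * sum_{i != j} Lam[j, i] * (m_i - mu_i)
m = np.zeros(2)
for _ in range(200):
    for j in range(2):
        others = [i for i in range(2) if i != j]
        m[j] = mu[j] - Lam[j, others] @ (m[others] - mu[others]) / Lam[j, j]
mf_var = 1.0 / np.diag(Lam)

print("CAVI mean:     ", m, " exact mean:", mu)
print("mean-field var:", mf_var, " exact var: ", np.diag(Sigma))
```

For this 2-D target the mean-field variance is exactly $1 - \rho^2$ times the true marginal variance, so at $\rho = 0.9$ the approximation reports about a fifth of the actual variance while getting the mean right.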
The Reparameterization Trick: The Bridge to Deep Learning
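The challenge with optimizing the ELBO by gradient descent: the ELBO involves an expectation over the variational distribution, which itself depends on the parameters we're optimizing. The naive Monte Carlo gradient estimator (the score-function, or REINFORCE, estimator) is unbiased but has high variance.
The reparameterization trick: instead of sampling $\theta \sim q_\phi(\theta)$ directly, express $\theta$ as a deterministic function of $\phi$ and a noise variable $\epsilon$ whose distribution does not depend on $\phi$:
$$\theta = g_\phi(\epsilon), \qquad \epsilon \sim p(\epsilon)$$
For a Gaussian variational distribution $q_\phi(\theta) = \mathcal{N}(\mu, \sigma^2)$ with $\phi = (\mu, \sigma)$:
$$\theta = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$
Now gradients can flow through the sampling operation:
$$\nabla_\phi\, \mathbb{E}_{q_\phi(\theta)}\big[f(\theta)\big] = \mathbb{E}_{p(\epsilon)}\big[\nabla_\phi f(g_\phi(\epsilon))\big]$$
In practice this usually gives much lower-variance gradient estimates than the score-function estimator, and it works directly with backpropagation.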
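A quick way to see the variance difference is to estimate $\nabla_\mu \mathbb{E}_{\theta \sim \mathcal{N}(\mu, \sigma^2)}[\theta^2]$, whose true value is $2\mu$, both ways. This is a sketch with an arbitrary test function $f(\theta) = \theta^2$ and arbitrary $\mu$, $\sigma$; both estimators are unbiased, but their spreads differ.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5
n = 100_000

# Objective: E_{theta ~ N(mu, sigma^2)}[theta^2] = mu^2 + sigma^2, so d/dmu = 2 * mu
eps = rng.standard_normal(n)
theta = mu + sigma * eps  # samples from q via the reparameterization

# Score-function (REINFORCE) estimator: f(theta) * d/dmu log q(theta)
score_grads = theta**2 * (theta - mu) / sigma**2
# Reparameterization estimator: d/dmu f(mu + sigma * eps) = 2 * (mu + sigma * eps)
reparam_grads = 2 * theta

print(f"true gradient:   {2 * mu:.3f}")
print(f"score-function:  mean={score_grads.mean():.3f}  var={score_grads.var():.3f}")
print(f"reparameterized: mean={reparam_grads.mean():.3f}  var={reparam_grads.var():.3f}")
```

For these settings the per-sample variance of the reparameterized estimator is analytically $4\sigma^2 = 1$, versus $21.75$ for the score-function estimator, so the latter needs roughly twenty times as many samples for the same precision.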
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VariationalAutoencoder(nn.Module):
    """
    VAE: canonical application of variational inference with reparameterization.
    The encoder computes q(z|x) = N(mu, sigma^2), a Gaussian approximate posterior.
    The decoder computes p(x|z), the generative model.
    The ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: x -> (mu, log_var) of q(z|x)
        self.encoder_fc1 = nn.Linear(input_dim, hidden_dim)
        self.encoder_mu = nn.Linear(hidden_dim, latent_dim)
        self.encoder_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder: z -> p(x|z)
        self.decoder_fc1 = nn.Linear(latent_dim, hidden_dim)
        self.decoder_out = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Compute variational parameters mu and log_var of q(z|x)."""
        h = F.relu(self.encoder_fc1(x))
        mu = self.encoder_mu(h)
        log_var = self.encoder_logvar(h)  # log variance (unconstrained)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """
        Reparameterization trick: z = mu + sigma * epsilon
        epsilon ~ N(0, I) -- the randomness is in epsilon, not in z
        This allows gradients to flow through z to the encoder parameters.
        """
        if self.training:
            std = torch.exp(0.5 * log_var)   # sigma = exp(log_var / 2)
            epsilon = torch.randn_like(std)  # epsilon ~ N(0, I)
            return mu + std * epsilon        # z = mu + sigma * epsilon
        else:
            return mu  # At test time, use the mean for deterministic prediction

    def decode(self, z):
        """Generate Bernoulli means for x from latent z."""
        h = F.relu(self.decoder_fc1(z))
        return torch.sigmoid(self.decoder_out(h))

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, log_var

    def elbo_loss(self, x, x_reconstructed, mu, log_var):
        """
        ELBO loss (negated for minimization), averaged over the batch.
        ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z))
        Reconstruction term: binary cross-entropy = -E_q[log p(x|z)]
        (single-sample Monte Carlo estimate using the reparameterized z)
        KL term: closed-form for Gaussian q vs Gaussian prior p(z) = N(0, I)
        """
        # Reconstruction: -E_q[log p(x|z)] (binary cross-entropy; x must lie in [0, 1])
        batch_size = x.shape[0]
        reconstruction_loss = F.binary_cross_entropy(
            x_reconstructed, x, reduction='sum'
        ) / batch_size
        # KL divergence: KL(N(mu, sigma^2) || N(0, I))
        # Closed form: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        ) / batch_size
        return reconstruction_loss + kl_loss


# Training loop: expects a DataLoader yielding (images, labels) with pixels in [0, 1], e.g. MNIST
def train_vae(model, data_loader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for x_batch, _ in data_loader:
            x_batch = x_batch.view(-1, 784)  # Flatten
            optimizer.zero_grad()
            x_recon, mu, log_var = model(x_batch)
            loss = model.elbo_loss(x_batch, x_recon, mu, log_var)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: negative ELBO = {total_loss/len(data_loader):.4f}")


# Initialize model
vae = VariationalAutoencoder(input_dim=784, hidden_dim=400, latent_dim=20)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
print("VAE model created successfully")
print(f"Encoder output: mu and log_var, each of dim {vae.latent_dim}")
print("KL term regularizes latent space toward N(0, I) prior")
print("Reconstruction term pushes encoder to capture data information")
```
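The closed-form KL term in `elbo_loss` can be sanity-checked against a Monte Carlo estimate. The sketch below uses NumPy only and made-up encoder outputs `mu` and `log_var` for a single example (assumptions for illustration); it draws $z$ with the same reparameterization as `reparameterize` and averages $\log q(z) - \log p(z)$.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical encoder outputs for one example (latent_dim = 3)
mu = np.array([0.5, -1.0, 0.2])
log_var = np.array([-0.5, 0.3, -1.2])
std = np.exp(0.5 * log_var)

# Closed form used in elbo_loss: KL(N(mu, sigma^2) || N(0, I))
kl_closed = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)] with reparameterized samples
eps = rng.standard_normal((200_000, 3))
z = mu + std * eps
# log N(z; mu, sigma^2) per dim; note (z - mu) / sigma = eps
log_q = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * log_var - 0.5 * eps**2, axis=1)
log_p = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * z**2, axis=1)
kl_mc = np.mean(log_q - log_p)

print(f"closed form: {kl_closed:.4f}   Monte Carlo: {kl_mc:.4f}")
```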
The KL Direction: Why VI Underestimates Variance
A subtle but critical point: VI minimizes $\mathrm{KL}(q\,\|\,p)$ (often called the reverse, or exclusive, KL), not $\mathrm{KL}(p\,\|\,q)$ (the forward, or inclusive, KL). These have different behaviors:
- $\mathrm{KL}(q\,\|\,p)$ is "zero-forcing": $q$ must be zero wherever $p$ is zero. This means $q$ tends to concentrate on one mode of $p$, ignoring other modes. The approximation is overconfident: it fits one mode well and ignores the rest.
- $\mathrm{KL}(p\,\|\,q)$ is "zero-avoiding": $q$ must be non-zero wherever $p$ is non-zero. This leads to a distribution that spreads over all modes, potentially being overdispersed.
Figure: a bimodal true posterior $p(\theta)$ (mode 1, mode 2) next to the VI approximation $q(\theta)$, which, by minimizing $\mathrm{KL}(q\,\|\,p)$, picks one mode.
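For unimodal posteriors (the common case), VI usually captures the location of the posterior well, although mean-field variances still tend to be too small. For multimodal posteriors, VI with a unimodal $q$ may miss modes entirely. These are the primary approximation errors of mean-field VI.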
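The two KL directions can be compared directly by fitting a single Gaussian to a bimodal target. This sketch assumes a hypothetical target $p = 0.5\,\mathcal{N}(-3, 1) + 0.5\,\mathcal{N}(3, 1)$, evaluates both KLs by a Riemann sum on a grid, and grid-searches the mean `m` and standard deviation `s` of $q$; the grid ranges are arbitrary choices.

```python
import numpy as np

# Illustrative bimodal target (an assumption): p = 0.5 N(-3, 1) + 0.5 N(3, 1)
grid = np.linspace(-12.0, 12.0, 4001)
dx = grid[1] - grid[0]


def normal_pdf(x, m, s):
    return np.exp(-0.5 * ((x - m) / s) ** 2) / (s * np.sqrt(2 * np.pi))


p = 0.5 * normal_pdf(grid, -3.0, 1.0) + 0.5 * normal_pdf(grid, 3.0, 1.0)


def kl(a, b):
    """Riemann-sum KL(a || b) on the grid; b is floored to avoid log(0)."""
    mask = a > 1e-300  # points where a underflows contribute ~0
    return np.sum(a[mask] * (np.log(a[mask]) - np.log(np.maximum(b[mask], 1e-300)))) * dx


# Grid search over Gaussian q = N(m, s^2) under each KL direction
best_rev = (np.inf, None, None)  # minimizes KL(q || p), the VI objective
best_fwd = (np.inf, None, None)  # minimizes KL(p || q)
for m in np.linspace(-5.0, 5.0, 101):
    for s in np.linspace(0.5, 5.0, 91):
        q = normal_pdf(grid, m, s)
        rev, fwd = kl(q, p), kl(p, q)
        if rev < best_rev[0]:
            best_rev = (rev, m, s)
        if fwd < best_fwd[0]:
            best_fwd = (fwd, m, s)

print(f"argmin KL(q||p): m={best_rev[1]:+.2f}, s={best_rev[2]:.2f}  (locks onto one mode)")
print(f"argmin KL(p||q): m={best_fwd[1]:+.2f}, s={best_fwd[2]:.2f}  (covers both modes)")
```

Minimizing $\mathrm{KL}(q\,\|\,p)$ lands on one mode with $s \approx 1$, while minimizing $\mathrm{KL}(p\,\|\,q)$ moment-matches the mixture ($m = 0$, $s^2 = 1 + 9 = 10$), putting much of $q$'s mass in the low-density gap between the modes.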
VI vs MCMC: The Engineering Decision
| Criterion | MCMC | Variational Inference |
|---|---|---|
| Asymptotic correctness | Yes (ergodic chains → true posterior) | No (the KL gap to the posterior stays nonzero unless the variational family contains the posterior) |
| Scales to neural networks | No (sampling in millions of dims is intractable) | Yes (optimization in parameter space) |
| Handles multimodal posteriors | Better (can explore multiple modes) | Poorly (zero-forcing KL may miss modes) |
| Computation | Slow (thousands of chain steps) | Fast (gradient steps, minibatch SGD) |
| Uncertainty quantification | Full posterior distribution | Approximate posterior, often underestimates variance |
| Tuning required | Proposal scaling, warmup length | Learning rate, choice of variational family $q$ |
| Production usage | Research, small models, scientific computing | Deep learning (VAEs, BNNs), recommender systems |
| Correctness checking | R-hat, trace plots, ESS | ELBO convergence, posterior predictive checks |
A practical rule of thumb: use MCMC when correctness matters and the model has fewer than 10,000 parameters. Use VI when you need to scale to neural networks or need fast inference.
Amortized Variational Inference
In standard VI, we optimize the variational parameters $\phi$ for a specific dataset $x$. In amortized VI, we train a neural network that maps any input $x$ directly to its variational parameters, learning to do inference across all possible inputs simultaneously.
This is the core innovation in VAEs: the encoder network IS the amortized inference network. Given a new image $x$, it outputs $(\mu(x), \sigma^2(x))$ in a single forward pass: the variational parameters of the latent distribution $q_\phi(z \mid x)$. No per-image optimization required.
Amortized inference trades optimization quality (for any specific $x$) for speed (inference is a single forward pass).
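To illustrate amortization, the sketch below trains a hypothetical linear encoder $q(z \mid x) = \mathcal{N}(a x + b, e^{c})$ on many inputs at once for the toy model $z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, \sigma^2)$ (both assumptions for illustration), using hand-derived ELBO gradients in NumPy. After training, inference for a new $x$ is just an evaluation of the encoder.

```python
import numpy as np

# Toy model (assumption): z ~ N(0, 1), x | z ~ N(z, sigma2)
# Hypothetical linear encoder: q(z | x) = N(a * x + b, exp(c))
rng = np.random.default_rng(0)
sigma2 = 0.5
xs = rng.normal(0.0, np.sqrt(1 + sigma2), size=5000)  # training inputs drawn from p(x)

a, b, c = 0.0, 0.0, 0.0
lr = 0.1
for _ in range(2000):
    m, v = a * xs + b, np.exp(c)
    # Per-example ELBO gradients w.r.t. the mean m and the log-variance c
    dm = (xs - m) / sigma2 - m
    dc = v * (-0.5 / sigma2 - 0.5) + 0.5
    # Gradient ascent on the ELBO averaged over all inputs (shared encoder weights)
    a += lr * np.mean(dm * xs)
    b += lr * np.mean(dm)
    c += lr * dc

# Inference for a new x: one encoder evaluation, no per-example optimization
x_new = 1.3
print(f"amortized q: mean={a * x_new + b:.4f}, var={np.exp(c):.4f}")
print(f"exact post.: mean={x_new / (1 + sigma2):.4f}, var={sigma2 / (1 + sigma2):.4f}")
```

Because the exact posterior mean $x / (1 + \sigma^2)$ is linear in $x$ here, the linear encoder family contains the exact inference map and there is no amortization gap. With a nonlinear posterior and a too-simple encoder, the gap would appear as a mismatch between the two printed lines.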
Interview Questions
Q1: Derive the ELBO and explain what each term represents.
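Starting from $\log p(x) = \mathrm{ELBO}(q) + \mathrm{KL}(q(\theta)\,\|\,p(\theta \mid x))$ and noting $\mathrm{KL} \ge 0$, we get $\log p(x) \ge \mathrm{ELBO}(q)$, hence "lower bound on the evidence." Expanding: $\mathrm{ELBO}(q) = \mathbb{E}_{q}[\log p(x \mid \theta)] - \mathrm{KL}(q(\theta)\,\|\,p(\theta))$. First term: expected log-likelihood under $q$, i.e. how well the approximate posterior explains the data (reconstruction quality in VAE terms). Second term: negative KL from $q$ to the prior $p(\theta)$, which penalizes how much the approximate posterior deviates from the prior (regularization in VAE terms). Maximizing the ELBO simultaneously fits the data and stays close to the prior, much like regularized MLE, but it returns an approximate posterior distribution rather than a point estimate.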
Q2: What is the reparameterization trick and why is it necessary?
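Without the reparameterization trick, the ELBO gradient with respect to $\phi$ is $\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)]$. We can't push the gradient inside the expectation because the sampling distribution $q_\phi$ depends on $\phi$. The naive approach, the REINFORCE (score-function) estimator $\mathbb{E}_{q_\phi}[f(z)\,\nabla_\phi \log q_\phi(z)]$, is unbiased but has very high variance, making optimization slow. The reparameterization trick: write $z = g_\phi(\epsilon)$ with $\epsilon \sim p(\epsilon)$ (a distribution independent of $\phi$). Then $\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}[\nabla_\phi f(g_\phi(\epsilon))]$. Now the gradient can flow through $z$ via backpropagation. For Gaussian $q_\phi(z) = \mathcal{N}(\mu, \sigma^2)$: $z = \mu + \sigma\epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$. The gradients with respect to $\mu$ and $\sigma$ (or $\log \sigma^2$) can be computed by automatic differentiation. This is what makes training VAEs feasible.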
Q3: Why does mean-field VI tend to underestimate posterior variance?
Mean-field VI minimizes $\mathrm{KL}(q\,\|\,p)$ (the reverse, or exclusive, KL). This is "zero-forcing": wherever the true posterior $p(\theta \mid x) = 0$, the approximation $q(\theta)$ must also be zero (otherwise the KL is infinite). For a multimodal posterior, a unimodal $q$ such as a mean-field Gaussian cannot cover all modes and typically collapses onto one of them. More subtly, even for unimodal posteriors, the zero-forcing property means $q$ tends to be narrower than $p$: it "fits inside" $p$ rather than covering it. For a Gaussian posterior with covariance $\Sigma$ and precision $\Lambda = \Sigma^{-1}$, the mean-field solution recovers the exact mean but has variances $1/\Lambda_{jj} \le \Sigma_{jj}$, with equality only when $\theta_j$ is uncorrelated with the other coordinates. The result is overconfidence: predictive intervals are too narrow. In practice, this matters for safety-critical applications where uncertainty must be reliable. One mitigation: use richer variational families (normalizing flows, full-covariance Gaussians) that can better approximate the true posterior.
Q4: Explain the VAE ELBO loss in terms of reconstruction and regularization.
The VAE ELBO is $\mathbb{E}_{q_\phi(z \mid x)}[\log p(x \mid z)] - \mathrm{KL}(q_\phi(z \mid x)\,\|\,p(z))$. Training maximizes this (minimizes the negative ELBO). The reconstruction term measures how well the decoder reconstructs $x$ from latent samples $z \sim q_\phi(z \mid x)$. For binary data and a Bernoulli decoder, its negative is the binary cross-entropy. The KL term regularizes the latent space: it pulls the approximate posterior toward the standard Gaussian prior $p(z) = \mathcal{N}(0, I)$. With a diagonal Gaussian approximate posterior, this KL has a closed form: $\mathrm{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$. The trade-off: high reconstruction quality wants a sharp posterior (small $\sigma$), but the KL pushes $q_\phi(z \mid x)$ toward the diffuse prior. The balance produces a smooth, structured latent space where interpolation generates meaningful images.
Q5: What is amortized inference, and why is it important for scalable Bayesian deep learning?
Standard VI optimizes a set of variational parameters for a specific dataset, requiring re-optimization for every new input. Amortized inference trains a neural network that maps any input directly to variational parameters: the network learns the inference mapping, not just the parameters for one dataset. The key benefits: (1) inference for a new input is a single forward pass rather than many optimization steps; (2) the inference network shares information across inputs, so learning to do inference for one image helps for similar images; (3) it scales to large datasets with minibatch stochastic optimization. The VAE encoder is an amortized inference network. Limitations: the amortization gap, i.e. the amortized approximation is generally worse than optimizing per input, because a single encoder of finite capacity must produce the variational parameters for every input instead of optimizing them individually. Iterative amortized inference (refining the encoder output with a few optimization steps at test time) reduces this gap.
