import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem';
Variational Autoencoders - Learning Latent Distributions with Evidence Lower Bound
Reading time: 50–60 minutes
Interview relevance: Very High - VAEs appear in generative modeling, representation learning, and research rounds
Target roles: Machine Learning Engineer, Research Engineer, AI Engineer
The Real Interview Moment
It is December 2013. Diederik Kingma is a PhD student at the University of Amsterdam. He and his advisor Max Welling are wrestling with a problem that has frustrated the generative modeling community for years.
Everyone wants to learn a generative model of images - something that can understand the underlying structure of the data and generate new examples from that structure. The mathematical framework is clear: model the joint distribution $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$, where $z$ is a low-dimensional latent representation and $x$ is the observed image. To learn the model, you need to maximize the log-likelihood:

$$
\log p_\theta(x) = \log \int p_\theta(x \mid z)\, p(z)\, dz
$$

The integral is intractable. Sampling-based methods like MCMC converge too slowly for training on large datasets. The field is stuck.
Kingma's insight is elegant: what if you use a neural network to approximate the posterior $p_\theta(z \mid x)$, and then train everything end-to-end with backpropagation? The encoder neural network takes $x$ and outputs the parameters of a Gaussian distribution over $z$. The decoder takes $z$ and reconstructs $x$. The whole system is differentiable - almost. There is a sampling step in between: you need to sample $z$ from the encoder's distribution to feed the decoder. Sampling is not differentiable.
Then Kingma and Welling introduce the reparameterization trick. Instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly, compute $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$. The randomness is isolated in $\epsilon$, which has no learnable parameters. The gradient flows through $\mu$ and $\sigma$ normally.
The VAE paper, "Auto-Encoding Variational Bayes," appears on arXiv in December 2013. It goes on to accumulate over 30,000 citations and becomes one of the most influential papers in deep learning history - not because it generated the most realistic images, but because it showed how to combine probabilistic generative models with deep learning in a principled, trainable way.
The Problem: Why Autoencoders Are Not Generative
A standard autoencoder learns a deterministic mapping:
- Encoder: $z = f_\phi(x)$ - compress $x$ to a latent code
- Decoder: $\hat{x} = g_\theta(z)$ - reconstruct $x$ from $z$
Trained to minimize a reconstruction loss such as $\|x - \hat{x}\|^2$, the encoder can map each training example to any point in latent space. There is no constraint on the structure of that space.
The consequence: the latent space has holes. If you sample a random point $z$ from the latent space and decode it, you get nonsense - because that region of $z$-space was never assigned to any training example. The autoencoder memorizes a mapping; it does not learn a generative model.
| Standard Autoencoder Latent Space | VAE Latent Space |
|---|---|
| Discrete clusters | Continuous, smooth manifold |
| Holes between clusters | Regularized toward $\mathcal{N}(0, I)$, far fewer holes |
| Sample random $z$ → garbage | Sample $z \sim \mathcal{N}(0, I)$ → plausible image |
| Cannot interpolate smoothly | Smooth interpolation between points |
| Not generative | Generative by construction |
The VAE fixes this by forcing the latent space to conform to a known prior distribution $p(z) = \mathcal{N}(0, I)$. The encoder does not output a point - it outputs a distribution, and sampling happens from that distribution. The KL term in the ELBO enforces that this distribution stays close to the standard normal prior.
The Generative Model
The VAE posits the following generative story:
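- Sample a latent code: $z \sim p(z) = \mathcal{N}(0, I)$
- Decode to observation parameters: $g_\theta(z)$ - parameterized by a neural network
- Sample an observation: $x \sim p_\theta(x \mid z)$
For images with pixel values in $[0, 1]$: $p_\theta(x \mid z) = \mathrm{Bernoulli}\big(x;\ g_\theta(z)\big)$. For continuous data: $p_\theta(x \mid z) = \mathcal{N}\big(x;\ g_\theta(z),\ \sigma^2 I\big)$.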
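To connect the Bernoulli likelihood to the loss used in the implementation later in this article, here is a minimal sketch, assuming PyTorch. Random tensors stand in for a decoder output $g_\theta(z)$ and a binarized image (nothing is trained). It checks that summed binary cross-entropy is exactly $-\log p_\theta(x \mid z)$ under a per-pixel Bernoulli model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for a decoder output g_theta(z): per-pixel Bernoulli means in (0, 1)
probs = torch.rand(4, 784).clamp(1e-4, 1 - 1e-4)
# Stand-in for a batch of binarized 28x28 images, x_i in {0, 1}
x = torch.bernoulli(torch.rand(4, 784))

# log p(x|z) = sum_i [ x_i log p_i + (1 - x_i) log(1 - p_i) ]
log_lik = (x * probs.log() + (1 - x) * (1 - probs).log()).sum(dim=1)

# Same quantity via torch.distributions
log_lik_dist = torch.distributions.Bernoulli(probs=probs).log_prob(x).sum(dim=1)

# Summed binary cross-entropy, as used for the reconstruction term
bce = F.binary_cross_entropy(probs, x, reduction="none").sum(dim=1)

print(torch.allclose(log_lik, log_lik_dist, atol=1e-4))
print(torch.allclose(-log_lik, bce, atol=1e-3))
```

This is why minimizing BCE with `reduction="sum"` is the same as maximizing the Bernoulli log-likelihood term of the ELBO.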
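The joint distribution:

$$
p_\theta(x, z) = p_\theta(x \mid z)\, p(z)
$$

The marginal likelihood (what we want to maximize):

$$
p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz
$$

This integral is intractable: it has no closed form when $g_\theta$ is a neural network, and for the same reason the true posterior $p_\theta(z \mid x) = p_\theta(x \mid z)\, p(z) / p_\theta(x)$ is intractable too. The VAE's solution: introduce an approximate posterior.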
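To see why "intractable" is a practical problem and not just a formal one, the sketch below estimates $\log p_\theta(x)$ by naive Monte Carlo, averaging $p_\theta(x \mid z_k)$ over $z_k \sim p(z)$. It uses a hypothetical 1-D linear-Gaussian model, chosen only because its marginal has a closed form to compare against. Most prior samples land where $p_\theta(x \mid z)$ is tiny, so even in one dimension the estimate needs many samples; with a high-dimensional $z$ and an image-sized $x$, this approach degrades quickly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model: z ~ N(0, 1),  x | z ~ N(w * z + b, s^2)
w, b, s = 2.0, 0.5, 0.3
x = 1.7

def log_normal(v, mean, std):
    """Log-density of N(mean, std^2) evaluated at v."""
    return -0.5 * np.log(2 * np.pi * std**2) - 0.5 * ((v - mean) / std) ** 2

# Exact answer for this conjugate model: x ~ N(b, w^2 + s^2)
exact = log_normal(x, b, np.sqrt(w**2 + s**2))

def naive_mc_log_marginal(n_samples):
    """log (1/K) sum_k p(x | z_k) with z_k ~ p(z), computed via log-sum-exp."""
    z = rng.standard_normal(n_samples)
    log_px_given_z = log_normal(x, w * z + b, s)
    m = log_px_given_z.max()
    return m + np.log(np.mean(np.exp(log_px_given_z - m)))

for k in (10, 1_000, 100_000):
    print(f"K={k:>7}: estimate={naive_mc_log_marginal(k):.4f}  exact={exact:.4f}")
```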
The Variational Posterior
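Introduce an inference network (encoder) $q_\phi(z \mid x)$ that approximates the true posterior $p_\theta(z \mid x)$:

$$
q_\phi(z \mid x) = \mathcal{N}\big(z;\ \mu_\phi(x),\ \operatorname{diag}(\sigma^2_\phi(x))\big)
$$

The encoder takes $x$ as input and outputs the mean $\mu_\phi(x)$ and log-variance $\log \sigma^2_\phi(x)$ of a diagonal Gaussian.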
Deriving the ELBO
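Start from the marginal log-likelihood and introduce $q_\phi(z \mid x)$:

$$
\log p_\theta(x) = \log \int q_\phi(z \mid x)\, \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\, dz = \log \mathbb{E}_{q_\phi(z \mid x)}\left[\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]
$$

By Jensen's inequality ($\log$ is concave, so $\log \mathbb{E}[Y] \geq \mathbb{E}[\log Y]$):

$$
\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]
$$

Decompose, using $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$:

$$
\mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{q_\phi(z \mid x)}{p(z)}\right]
$$

This is the Evidence Lower BOund (ELBO):

$$
\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

The gap between the ELBO and $\log p_\theta(x)$ is exactly $D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$. Maximizing the ELBO simultaneously: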
- Makes the decoder explain the data well (reconstruction).
- Makes the encoder's distribution close to the prior (regularization).
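The claim that the gap equals the KL to the true posterior can be checked numerically. The sketch below uses a hypothetical 1-D linear-Gaussian model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(wz + b, s^2)$) in which the exact marginal, the exact posterior, and the ELBO for any Gaussian $q = \mathcal{N}(m_q, v_q)$ all have closed forms. For every $q$, $\log p_\theta(x) - \text{ELBO}$ should equal $D_{\mathrm{KL}}(q \,\|\, p_\theta(z \mid x))$, and the gap should vanish when $q$ is the exact posterior.

```python
import numpy as np

# Hypothetical toy model: z ~ N(0, 1),  x | z ~ N(w z + b, s^2)
w, b, s, x = 2.0, 0.5, 0.3, 1.7

def gauss_kl(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians (v = variance)."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(mq, vq):
    """ELBO for q = N(mq, vq): E_q[log p(x|z)] - KL(q || p(z)), both in closed form."""
    # E_q[(x - b - w z)^2] = (x - b - w mq)^2 + w^2 vq
    expected_loglik = (-0.5 * np.log(2 * np.pi * s**2)
                       - ((x - b - w * mq) ** 2 + w**2 * vq) / (2 * s**2))
    return expected_loglik - gauss_kl(mq, vq, 0.0, 1.0)

# Exact quantities for this conjugate model
log_px = -0.5 * np.log(2 * np.pi * (w**2 + s**2)) - 0.5 * (x - b) ** 2 / (w**2 + s**2)
v_post = 1.0 / (1.0 + w**2 / s**2)
m_post = v_post * w * (x - b) / s**2

for mq, vq in [(0.0, 1.0), (0.5, 0.05), (m_post, v_post)]:
    gap = log_px - elbo(mq, vq)
    print(f"q=N({mq:.3f}, {vq:.4f})  ELBO={elbo(mq, vq):.4f}  "
          f"gap={gap:.4f}  KL(q || posterior)={gauss_kl(mq, vq, m_post, v_post):.4f}")
```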
KL Divergence - Closed Form for Gaussians
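For the diagonal Gaussian encoder $q_\phi(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ vs. the standard normal prior $p(z) = \mathcal{N}(0, I)$, with latent dimension $J$:

$$
D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big) = \frac{1}{2} \sum_{j=1}^{J} \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)
$$

This is computed analytically - no sampling required. The gradient flows cleanly through $\mu$ and $\sigma$.
Intuition of each term:
- $\sigma_j^2$: penalizes large variance (pushes toward unit variance)
- $\mu_j^2$: penalizes large mean (pushes toward zero mean)
- $-1$: constant normalization (makes the KL exactly zero when $\mu_j = 0$ and $\sigma_j = 1$)
- $-\log \sigma_j^2$: prevents variance from collapsing to zero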
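A quick sanity check of the closed form, assuming PyTorch: the sketch compares the formula against `torch.distributions.kl_divergence` between two `Normal` distributions (summed over dimensions) and against a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$. The `mu` and `log_var` values are arbitrary.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.tensor([0.0, 1.5, -0.7])
log_var = torch.tensor([0.0, -1.0, 0.8])
std = (0.5 * log_var).exp()

# Closed form: 0.5 * sum_j (sigma_j^2 + mu_j^2 - 1 - log sigma_j^2)
kl_closed = 0.5 * (log_var.exp() + mu.pow(2) - 1.0 - log_var).sum()

# Reference: PyTorch's analytic Normal-Normal KL, summed over dimensions
q = Normal(mu, std)
p = Normal(torch.zeros(3), torch.ones(3))
kl_torch = kl_divergence(q, p).sum()

# Monte Carlo estimate: E_q[log q(z) - log p(z)]
z = q.sample((200_000,))
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum(dim=1).mean()

print(kl_closed.item(), kl_torch.item(), kl_mc.item())
```

Note that the first dimension ($\mu = 0$, $\log\sigma^2 = 0$) contributes exactly zero KL: that is what a collapsed latent dimension looks like.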
The Reparameterization Trick
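The reconstruction term $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ requires sampling $z \sim q_\phi(z \mid x)$. Sampling is not differentiable - we cannot backpropagate through a random node.
The reparameterization trick: instead of sampling $z \sim q_\phi(z \mid x)$ directly, compute:

$$
z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

Now $z$ is a deterministic function of the network parameters $\phi$ and a fixed-distribution noise $\epsilon$. The gradient flows through $\mu$ and $\sigma$:

$$
\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\left[\nabla_\phi \log p_\theta\big(x \mid \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon\big)\right]
$$

In practice, the expectation is estimated with a single sample of $\epsilon$ per data point per step; averaged over a minibatch, the resulting gradient has low enough variance for stable training.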
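The sketch below, assuming PyTorch, shows the reparameterized gradient on a 1-D Gaussian $q = \mathcal{N}(\mu, \sigma^2)$. The function `f(z) = z**2` is a stand-in for $\log p_\theta(x \mid z)$, chosen only because $\mathbb{E}_q[z^2] = \mu^2 + \sigma^2$ has known gradients: $2\mu$ with respect to $\mu$ and $\sigma^2$ with respect to $\log \sigma^2$. Autograd differentiates through $z = \mu + \sigma\epsilon$ and should recover both up to Monte Carlo noise. Many samples are used here only to make the agreement visible.

```python
import torch

torch.manual_seed(0)

# Variational parameters of q(z) = N(mu, exp(log_var))
mu = torch.tensor(0.8, requires_grad=True)
log_var = torch.tensor(-0.4, requires_grad=True)

def f(z):
    # Stand-in for log p(x|z); E_q[f(z)] = mu^2 + sigma^2 is known exactly
    return z ** 2

n = 100_000
eps = torch.randn(n)                 # parameter-free noise
sigma = (0.5 * log_var).exp()
z = mu + sigma * eps                 # reparameterized draw, differentiable in mu and log_var
objective = f(z).mean()              # Monte Carlo estimate of E_q[f(z)]
objective.backward()

# Analytic gradients of mu^2 + exp(log_var)
grad_mu_exact = 2 * mu.item()
grad_logvar_exact = torch.exp(log_var).item()

print(mu.grad.item(), grad_mu_exact)
print(log_var.grad.item(), grad_logvar_exact)
```

`torch.distributions.Normal(mu, sigma).rsample()` performs the same reparameterized draw; `.sample()` does not propagate gradients to `mu` and `sigma`.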
Mermaid: VAE Architecture
Code: Complete VAE Implementation
Full VAE in PyTorch (MNIST)
```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image


class Encoder(nn.Module):
    """
    Inference network: q_phi(z|x) = N(mu, diag(sigma^2))
    Input:  flattened image [batch, 784]
    Output: mu [batch, latent_dim], log_var [batch, latent_dim]
    """

    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        # Predict log(sigma^2) so the output is unconstrained and sigma stays positive
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.shared(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var


class Decoder(nn.Module):
    """
    Generative network: p_theta(x|z) = Bernoulli(g_theta(z))
    Input:  latent code [batch, latent_dim]
    Output: reconstruction probabilities [batch, output_dim]
    """

    def __init__(self, latent_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),  # output in [0, 1] for Bernoulli
        )

    def forward(self, z):
        return self.net(z)


class VAE(nn.Module):
    """
    Variational Autoencoder (Kingma & Welling 2013).
    """

    def __init__(
        self,
        input_dim: int = 784,
        hidden_dim: int = 512,
        latent_dim: int = 20,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        z = mu + sigma * eps,  eps ~ N(0, I)
        sigma = exp(0.5 * log_var)
        """
        if self.training:
            std = (0.5 * log_var).exp()
            eps = torch.randn_like(std)
            return mu + std * eps
        # At eval time, use the mean (MAP estimate of q)
        return mu

    def forward(self, x: torch.Tensor):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def elbo_loss(self, x: torch.Tensor, x_recon: torch.Tensor,
                  mu: torch.Tensor, log_var: torch.Tensor,
                  beta: float = 1.0) -> dict:
        """
        ELBO = E_q[log p(x|z)] - beta * KL(q(z|x) || p(z))
        Reconstruction: binary cross-entropy (Bernoulli likelihood)
        KL: closed form for diagonal Gaussians vs N(0, I)
        Returns the negative ELBO, averaged over the batch, as "loss".
        """
        batch_size = x.size(0)
        # Reconstruction term: -E_q[log p(x|z)] (single-sample estimate)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction="sum") / batch_size
        # KL term: KL(N(mu, sigma^2) || N(0, 1))
        #   = 0.5 * sum(sigma^2 + mu^2 - 1 - log(sigma^2))
        kl_loss = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum() / batch_size
        # Total loss = negative ELBO (we minimize this)
        total_loss = recon_loss + beta * kl_loss
        return {
            "loss": total_loss,
            "recon_loss": recon_loss.item(),
            "kl_loss": kl_loss.item(),
        }

    @torch.no_grad()
    def generate(self, n_samples: int = 16) -> torch.Tensor:
        """Sample from the prior p(z) = N(0, I) and decode."""
        device = next(self.parameters()).device
        z = torch.randn(n_samples, self.latent_dim, device=device)
        return self.decoder(z)

    def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode (deterministic: uses the posterior mean)."""
        self.eval()
        with torch.no_grad():
            mu, _ = self.encoder(x)
            return self.decoder(mu)

    def interpolate(
        self, x1: torch.Tensor, x2: torch.Tensor, steps: int = 10
    ) -> torch.Tensor:
        """
        Linear interpolation in latent space between two images.
        """
        self.eval()
        with torch.no_grad():
            mu1, _ = self.encoder(x1.unsqueeze(0))
            mu2, _ = self.encoder(x2.unsqueeze(0))
            alphas = torch.linspace(0, 1, steps, device=mu1.device)
            images = []
            for alpha in alphas:
                z = (1 - alpha) * mu1 + alpha * mu2
                images.append(self.decoder(z))
            return torch.cat(images, dim=0)


# ── Training loop ─────────────────────────────────────────────────────────────
def train_vae(
    latent_dim: int = 20,
    hidden_dim: int = 512,
    beta: float = 1.0,
    epochs: int = 30,
    batch_size: int = 128,
    lr: float = 1e-3,
    save_dir: str = "vae_outputs",
):
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # MNIST dataset (ToTensor scales pixels to [0, 1], as the Bernoulli decoder expects)
    transform = transforms.Compose([transforms.ToTensor()])
    train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
    test_ds = datasets.MNIST("./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    model = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ── Train ──
        model.train()
        train_loss = 0.0
        train_recon = 0.0
        train_kl = 0.0
        for x, _ in train_loader:
            x = x.view(-1, 784).to(device)
            optimizer.zero_grad()
            x_recon, mu, log_var = model(x)
            losses = model.elbo_loss(x, x_recon, mu, log_var, beta=beta)
            losses["loss"].backward()
            optimizer.step()
            train_loss += losses["loss"].item()
            train_recon += losses["recon_loss"]
            train_kl += losses["kl_loss"]

        n_batches = len(train_loader)
        print(
            f"Epoch {epoch + 1:3d} | "
            f"Loss: {train_loss / n_batches:.2f} | "
            f"Recon: {train_recon / n_batches:.2f} | "
            f"KL: {train_kl / n_batches:.2f}"
        )

        # ── Generate samples ──
        if (epoch + 1) % 5 == 0:
            model.eval()
            samples = model.generate(64)
            save_image(
                samples.view(64, 1, 28, 28),
                f"{save_dir}/samples_epoch_{epoch + 1}.png",
                nrow=8,
            )

    return model
```
Posterior Collapse Detection and Fixes
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes the VAE class from the previous listing is in scope.


def detect_posterior_collapse(
    model: VAE,
    data_loader,
    device: torch.device,
    threshold: float = 0.01,
) -> dict:
    """
    Detect posterior collapse: dimensions where KL ≈ 0 are collapsed.
    A collapsed dimension means q(z_j|x) ≈ p(z_j) for all x.
    The latent dimension carries no information.
    """
    model.eval()
    kl_sum = None
    n_examples = 0
    with torch.no_grad():
        for x, _ in data_loader:
            x = x.view(-1, 784).to(device)
            mu, log_var = model.encoder(x)
            # Per-dimension KL: 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)
            kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
            batch_sum = kl_per_dim.sum(0)  # sum over examples in this batch
            kl_sum = batch_sum if kl_sum is None else kl_sum + batch_sum
            n_examples += x.size(0)

    avg_kl = kl_sum / n_examples  # [latent_dim], averaged over all examples
    collapsed_mask = avg_kl < threshold
    return {
        "avg_kl_per_dim": avg_kl.cpu().numpy(),
        "n_collapsed": collapsed_mask.sum().item(),
        "n_active": (~collapsed_mask).sum().item(),
        "collapsed_dims": collapsed_mask.nonzero(as_tuple=True)[0].tolist(),
    }


class KLAnnealing:
    """
    KL annealing: linearly ramp up the KL weight from 0 to 1 over warmup steps.
    Helps prevent posterior collapse in early training.
    """

    def __init__(self, warmup_steps: int = 10_000):
        self.warmup_steps = warmup_steps
        self.step = 0

    def get_beta(self) -> float:
        beta = min(1.0, self.step / self.warmup_steps)
        self.step += 1
        return beta


class FreebitVAE(VAE):
    """
    VAE with free bits (Kingma et al. 2016).
    Each dimension's batch-averaged KL is clamped from below at free_bits nats,
    so dimensions whose KL is already below that floor receive no KL gradient
    and the encoder is free to use them.
    """

    def elbo_loss(self, x, x_recon, mu, log_var, beta=1.0, free_bits=0.5):
        batch_size = x.size(0)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction="sum") / batch_size
        # Per-dimension KL, averaged over the batch: [latent_dim]
        kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean(0)
        # Free bits: max(free_bits, KL_j) for each dimension
        kl_clamped = kl_per_dim.clamp(min=free_bits)
        kl_loss = kl_clamped.sum()
        return {
            "loss": recon_loss + beta * kl_loss,
            "recon_loss": recon_loss.item(),
            "kl_loss": kl_loss.item(),
        }
```
Beta-VAE: Disentangled Representations
```python
import torch

# Assumes the VAE class from the first listing is in scope.


class BetaVAE(VAE):
    """
    beta-VAE (Higgins et al. 2017): higher beta encourages disentangled representations.
    loss = recon_loss + beta * KL   (the negative of the beta-weighted ELBO)
    beta > 1: stronger pressure toward the unit Gaussian prior
        → each latent dimension pushed to encode one independent factor
        → better disentanglement, worse reconstruction quality
    beta < 1: weaker KL penalty
        → better reconstruction, less regular latent space
          (samples drawn from the prior may decode poorly)
    """

    def __init__(self, *args, beta: float = 4.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.beta = beta

    def elbo_loss(self, x, x_recon, mu, log_var, beta=None):
        if beta is None:
            beta = self.beta
        return super().elbo_loss(x, x_recon, mu, log_var, beta=beta)


def evaluate_disentanglement(model: VAE, data_loader, latent_dim: int, device: torch.device):
    """
    Simple traversal: vary each latent dimension from -3 to +3 while holding others fixed.
    A disentangled VAE should show clear, interpretable variation per dimension.
    Returns decoded images of shape [n_dims * n_steps, input_dim].
    """
    model.eval()
    # Get a fixed input
    x, _ = next(iter(data_loader))
    x = x[0:1].view(1, -1).to(device)
    traversal_images = []
    with torch.no_grad():
        mu, _ = model.encoder(x)
        z_base = mu.clone()
        n_steps = 11
        values = torch.linspace(-3.0, 3.0, n_steps)
        for dim in range(min(latent_dim, 10)):  # show first 10 dims
            for val in values:
                z = z_base.clone()
                z[0, dim] = val.item()  # set one coordinate, keep the rest fixed
                x_gen = model.decoder(z)
                traversal_images.append(x_gen)
    return torch.cat(traversal_images, dim=0)
```
Generation and Interpolation Demo
```python
import torch

# Assumes the VAE class from the first listing is in scope.


def generation_and_interpolation_demo():
    """
    Demonstrate generation from prior and latent space interpolation.
    """
    device = torch.device("cpu")
    model = VAE(784, 512, 20).to(device)
    # (In practice: load a trained model)
    # model.load_state_dict(torch.load("vae.pt"))
    model.eval()

    print("=== Generation from Prior ===")
    samples = model.generate(n_samples=16)
    print(f"Generated shape: {samples.shape}")  # [16, 784]

    print("\n=== Reconstruction ===")
    # Simulate two test images (random pixels stand in for real data)
    x1 = torch.rand(784)
    x2 = torch.rand(784)
    recon = model.reconstruct(x1.unsqueeze(0))
    print(f"Reconstruction shape: {recon.shape}")  # [1, 784]

    print("\n=== Latent Space Interpolation ===")
    interp = model.interpolate(x1, x2, steps=10)
    print(f"Interpolation shape: {interp.shape}")  # [10, 784]
    print("Smooth interpolation between two images in latent space")

    print("\n=== Posterior Collapse Check ===")
    # Compute per-dimension KL for a batch of random data
    x_batch = torch.rand(64, 784)
    with torch.no_grad():
        mu, log_var = model.encoder(x_batch)
        kl_per_dim = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).mean(0)
    print(f"Avg KL per dim (first 5): {kl_per_dim[:5].numpy().round(3)}")
    print("Dims near 0 → collapsed (carrying no information)")


generation_and_interpolation_demo()
```
Posterior Collapse - The Core Failure Mode
Posterior collapse is the most important practical challenge in VAE training. It occurs when the KL term overwhelms the reconstruction term early in training, and the encoder learns to output the prior regardless of input:

$$
q_\phi(z \mid x) \approx p(z) \quad \text{for all } x
$$

When this happens, the decoder is forced to model $x$ without any information from the latent code - its output no longer depends on $z$. The latent space is useless.
Why it happens: setting $q_\phi(z \mid x) = p(z)$ drives the KL term to exactly zero, and this solution is easy for the optimizer to reach. Early in training the decoder has not yet learned to extract anything from $z$, so the reconstruction term gives little reward for an informative code, and reducing the KL is the easier gradient path.
Fixes:
| Fix | Mechanism | Trade-off |
|---|---|---|
| KL annealing | Start with $\beta = 0$, ramp to 1 over 10k steps | May delay but not prevent collapse |
| Free bits | Clamp per-dim KL to a minimum of $\lambda$ nats | Prevents collapse but adds a hyperparameter |
| $\beta$-VAE with $\beta < 1$ | Weaker KL penalty | Risk of poorly structured latent space |
| Weaker decoder | Reduce decoder capacity | Forces encoder to use latent code |
| NVAE / hierarchical VAE | Multiple latent groups at different scales | More complex, more robust |
:::warning Posterior Collapse Is Subtle
You can train a VAE for 50 epochs, watch the reconstruction loss improve steadily, and still have 90% of your latent dimensions collapsed. Always monitor per-dimension KL during training. Any dimension with average KL below 0.1 nats is likely collapsed.
:::
VAE vs GAN
| Aspect | VAE | GAN |
|---|---|---|
| Training objective | ELBO (likelihood-based) | Minimax adversarial |
| Training stability | Stable | Notoriously unstable (mode collapse, vanishing gradient) |
| Sample quality | Blurry (MSE-like reconstruction penalty) | Sharp (adversarial sharpening) |
| Latent space | Structured, continuous, interpolatable | Unstructured unless additional constraints added |
| Likelihood estimate | Available (ELBO is a lower bound) | Not available |
| Applications | Representation learning, anomaly detection, semi-supervised | High-quality image generation, super-resolution |
The blurriness of VAE samples comes from the pixel-wise reconstruction loss: maximizing $\log p_\theta(x \mid z)$ under a Gaussian likelihood with fixed variance corresponds to minimizing MSE, which averages over multiple plausible outputs. This produces soft, averaged images.
Fix: combine VAE with perceptual loss, adversarial loss (VAE-GAN), or use a diffusion model as the decoder.
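A toy illustration of the averaging argument, using made-up 6-pixel "images" in NumPy: if a sharp edge appears in one of two positions with equal probability, the prediction that minimizes expected squared error is their pixel-wise mean, a blurred edge that matches neither target. The Gaussian negative log-likelihood with fixed $\sigma$ is a scaled MSE plus a constant, so it prefers the same blurry average.

```python
import numpy as np

# Two equally likely "images" (6-pixel strips): the same edge at two positions
a = np.array([0, 0, 1, 1, 1, 1], dtype=float)
b = np.array([0, 0, 0, 0, 1, 1], dtype=float)

def expected_mse(pred):
    """Expected squared error when the target is a or b with probability 1/2 each."""
    return 0.5 * np.mean((pred - a) ** 2) + 0.5 * np.mean((pred - b) ** 2)

def gaussian_nll(pred, target, sigma=1.0):
    """-log N(target; pred, sigma^2 I) = ||target - pred||^2 / (2 sigma^2) + const."""
    d = target.size
    return (0.5 * np.sum((target - pred) ** 2) / sigma**2
            + 0.5 * d * np.log(2 * np.pi * sigma**2))

def expected_nll(pred):
    return 0.5 * gaussian_nll(pred, a) + 0.5 * gaussian_nll(pred, b)

mean_img = 0.5 * (a + b)   # [0, 0, 0.5, 0.5, 1, 1]: a half-strength, blurred edge

print("predict a  :", expected_mse(a), expected_nll(a))
print("predict avg:", expected_mse(mean_img), expected_nll(mean_img))
```

A sharp prediction is penalized heavily whenever the other mode occurs, so the pixel-wise loss rewards hedging between modes; adversarial or perceptual losses change this incentive.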
Mermaid: VAE Variants
:::tip VAE in Stable Diffusion
Latent Diffusion Models (Rombach et al. 2022) - the architecture behind Stable Diffusion - use a KL-regularized autoencoder (a VAE-style model with a small KL penalty toward a standard normal) to compress 512x512 images to a 64x64 latent space, then run the diffusion process in that latent space. This is why Stable Diffusion is so much faster than pixel-space diffusion: with 4 latent channels, the autoencoder reduces the dimensionality by 48x.
:::
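The 48x figure is simple arithmetic, under the assumption (true of the Stable Diffusion v1 autoencoder) that the 64x64 latent has 4 channels. A single-channel 64x64 latent would instead be a 192x reduction.

```python
# 512x512 RGB image in pixel space
pixel_values = 512 * 512 * 3        # 786,432 numbers
# 64x64 latent, assuming 4 latent channels
latent_values = 64 * 64 * 4         # 16,384 numbers
reduction = pixel_values / latent_values
print(reduction)
```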
YouTube Resources
| Resource | What You Will Learn |
|---|---|
| Arxiv Insights - VAE Explained | Best visual explanation of ELBO and reparameterization trick |
| Yannic Kilcher - VAE Paper Walkthrough | Line-by-line paper reading with commentary |
| Andrej Karpathy - VAE Code Walkthrough | PyTorch implementation from scratch |
| MIT 6.S191 - Deep Generative Models | VAEs in context of the full generative model landscape |
Interview Q&A
Q1: Derive the ELBO from the marginal log-likelihood.
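Answer: We want to maximize $\log p_\theta(x) = \log \int p_\theta(x \mid z)\, p(z)\, dz$, but it is intractable. Introduce an auxiliary distribution $q_\phi(z \mid x)$:

$$
\log p_\theta(x) = \log \mathbb{E}_{q_\phi(z \mid x)}\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right]
$$

By Jensen's inequality ($\log$ is concave):

$$
\log p_\theta(x) \geq \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

This is the ELBO. The gap is $D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$. Maximizing the ELBO = making the approximate posterior close to the true posterior AND making the decoder explain the data.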
Q2: Why does the reparameterization trick work, and why can't you just use REINFORCE instead?
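Answer: REINFORCE (score function estimator) computes gradients as:

$$
\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}[f(z)] = \mathbb{E}_{q_\phi(z \mid x)}\big[f(z)\, \nabla_\phi \log q_\phi(z \mid x)\big]
$$

This is unbiased but has very high variance - the gradient signal is noisy and training is slow or unstable.
The reparameterization trick rewrites $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$. Now:

$$
\nabla_\phi\, \mathbb{E}_{q_\phi(z \mid x)}[f(z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\big[\nabla_\phi f\big(\mu_\phi(x) + \sigma_\phi(x) \odot \epsilon\big)\big]
$$

The gradient passes directly through the deterministic transformation. In practice, one sample per data point per training step is usually sufficient for stable training. The variance of the reparameterization estimator is typically much lower because it uses the derivative of $f$ with respect to $z$, whereas the score-function estimator multiplies the raw value $f(z)$ by the gradient of a log-density and never sees how $f$ changes with $z$. The requirement: the transformation must be differentiable with respect to $\phi$, and $f$ must be differentiable with respect to $z$. For discrete latent variables (as in VQ-VAE), this fails - a straight-through estimator or Gumbel-softmax relaxation is needed instead.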
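A minimal NumPy comparison of the two estimators, using a stand-in objective $f(z) = z^2$ under $q = \mathcal{N}(\mu, \sigma^2)$, for which $\partial_\mu \mathbb{E}_q[f] = 2\mu$ exactly. Both per-sample estimators are unbiased, but for these settings the plain score-function estimator (no baseline or other variance reduction, which in practice narrows the gap) has about an order of magnitude more variance than the reparameterized one.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.8, 0.6
n = 100_000

def f(z):
    return z ** 2     # stand-in objective; E_q[f] = mu^2 + sigma^2, so d/dmu = 2 * mu

eps = rng.standard_normal(n)
z = mu + sigma * eps

# Score-function (REINFORCE) per-sample gradient w.r.t. mu:
#   f(z) * d/dmu log N(z; mu, sigma^2) = f(z) * (z - mu) / sigma^2
g_score = f(z) * (z - mu) / sigma**2

# Reparameterized per-sample gradient w.r.t. mu:
#   d/dmu f(mu + sigma * eps) = f'(z) = 2 z
g_reparam = 2 * z

print("exact      :", 2 * mu)
print("REINFORCE  : mean %.3f  var %.3f" % (g_score.mean(), g_score.var()))
print("reparam    : mean %.3f  var %.3f" % (g_reparam.mean(), g_reparam.var()))
```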
Q3: What is posterior collapse and how do you fix it?
Answer: Posterior collapse occurs when the encoder ignores the input and outputs the prior $p(z)$ regardless of $x$. The KL term becomes 0 (the posterior matches the prior perfectly), and the decoder must reconstruct $x$ without information from $z$ - it collapses to a fixed mean prediction.
This happens because: (1) the KL term is easy to minimize early in training by pushing the posterior to the prior, (2) powerful decoders can model $p(x)$ without using $z$ at all.
Fixes:
- KL annealing: start with $\beta = 0$ (pure reconstruction), ramp to 1 over thousands of steps. This lets the encoder learn to use $z$ before the KL penalty enforces structure.
- Free bits (Kingma et al. 2016): clamp per-dimension KL to at least $\lambda$ nats. Dimensions that have collapsed below $\lambda$ receive no gradient from the KL term, allowing them to re-activate.
- Weaker decoder: reduce decoder capacity so it cannot model $x$ without $z$. Counterintuitive but effective.
- Hierarchical VAEs: model latent structure at multiple scales (NVAE, VDVAE) - upper layers are harder to collapse.
Diagnosis: monitor per-dimension KL during training. Any dimension consistently below 0.1 nats is likely collapsed.
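How the clamp implements free bits can be seen from its gradient. In the sketch below (assuming PyTorch; the per-dimension KL values are made up), dimensions whose KL sits below $\lambda$ = `free_bits` contribute only a constant to the loss and receive zero gradient, while dimensions above $\lambda$ are penalized as usual.

```python
import torch

free_bits = 0.5
# Hypothetical batch-averaged KL per latent dimension, as a leaf tensor
kl_per_dim = torch.tensor([0.05, 0.30, 1.20, 2.00], requires_grad=True)

# Free bits: each dimension contributes max(free_bits, KL_j)
kl_loss = kl_per_dim.clamp(min=free_bits).sum()
kl_loss.backward()

print(kl_loss.item())    # 0.5 + 0.5 + 1.2 + 2.0
print(kl_per_dim.grad)   # zero for the two dimensions below free_bits
```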
Q4: How is VAE different from GAN, and when would you choose each?
Answer: VAEs and GANs are both generative models but with fundamentally different training objectives and failure modes.
VAE: maximizes the ELBO (likelihood-based). Pros: stable training, structured latent space enabling interpolation and disentanglement, provides a likelihood estimate. Cons: samples tend to be blurry (pixel-wise MSE penalty averages over modes). Use when: you need a structured latent space, need likelihood estimates, want stable training, or need the encoder for downstream tasks (semi-supervised learning, anomaly detection).
GAN: trains a generator and discriminator adversarially. Pros: sharp, realistic samples (adversarial loss sharpens details). Cons: unstable training (mode collapse, vanishing gradients), no encoder and no imposed latent structure, no likelihood estimate. Use when: output quality is the primary metric (image synthesis, super-resolution, style transfer).
In modern practice, the dichotomy has blurred: VAE-GAN hybrids combine the structured latent space of VAEs with the sharp samples of GANs. Latent diffusion models (Stable Diffusion) use a KL-regularized autoencoder to learn the latent space, then train a diffusion model in that space; the autoencoder itself is trained with perceptual and patch-based adversarial losses, so the system borrows from both families.
Q5: What is beta-VAE and why does higher beta lead to disentanglement?
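Answer: beta-VAE (Higgins et al. 2017) modifies the ELBO to:

$$
\mathcal{L}_\beta(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \beta\, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
$$

With $\beta > 1$, the KL penalty is amplified. Because the prior $\mathcal{N}(0, I)$ is factorized and has unit variance in every dimension, a stronger pull toward it pushes each latent dimension to: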
- Encode at most one independent factor of variation (statistical independence).
- Have variance close to 1 (information bottleneck).
The information bottleneck forces the encoder to prioritize the most important factors - the latent code can only carry so much information, so it uses each dimension efficiently for a single semantic factor (e.g., one dimension controls brightness, another controls shape).
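Trade-off: higher $\beta$ means worse reconstruction quality (more information is discarded). The disentanglement-reconstruction trade-off requires tuning per dataset; $\beta > 1$ is the usual starting point for image datasets (the `BetaVAE` class above defaults to $\beta = 4$). Monitoring: plot reconstruction quality vs. MIG (Mutual Information Gap) score to find the best $\beta$; note that MIG requires ground-truth factors of variation.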
Key Takeaways
- VAEs combine variational inference with neural networks to learn a probabilistic generative model end-to-end.
- The ELBO = reconstruction term + KL term. Reconstruction pushes the decoder to fit the data; KL regularizes the latent space to match the prior.
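- The reparameterization trick ($z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$) makes sampling differentiable by isolating randomness in $\epsilon$.
- KL closed form for diagonal Gaussians: $\frac{1}{2}\sum_j \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)$ - no sampling needed.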
- Posterior collapse is the main failure mode: fix with KL annealing, free bits, or weakened decoder.
- beta-VAE: amplifies the KL penalty, forcing disentangled representations at the cost of reconstruction quality.
- VAE produces blurry samples (MSE averaging); GAN produces sharp samples (adversarial sharpening) but with unstable training.
- Modern systems like Stable Diffusion use VAEs as a compression backbone - the latent space is where diffusion happens.
:::tip 🎮 Interactive Playground
Visualize this concept: Try the Variational Inference demo on the EngineersOfAI Playground - no code required.
:::
