Variational Autoencoders
The Real Interview Moment
You are building a drug discovery system at a pharmaceutical company. The goal: generate novel molecular structures that are chemically similar to known effective compounds but differ enough to constitute new intellectual property. Your data is 50,000 known drug molecules, each represented as a high-dimensional fingerprint vector. You need to explore the "chemical space" around each known drug.
A senior engineer proposes a standard autoencoder. It compresses each molecule to a 128-dimensional latent vector, then reconstructs it. You test it: pick two known drug molecules, encode them to latent vectors $z_1$ and $z_2$, and decode the midpoint $(z_1 + z_2)/2$. The output is incoherent garbage. Why? The latent space has holes. The encoder learned to map each molecule to a compact, isolated region of latent space, optimized purely for reconstruction. The space between those regions is unexplored, unconstrained, and decodes to nonsense.
A VAE solves this by imposing a prior on the latent space. Every molecule maps not to a point but to a Gaussian distribution. The KL divergence term in the loss pulls these distributions toward a standard Gaussian prior so that they overlap. The result is a continuous, structured latent space where every point decodes to something meaningful. You can now sample freely, interpolate smoothly, and systematically explore the neighborhood of any known drug. This is why VAEs became a backbone of molecular generation work such as the chemical VAE of Gómez-Bombarelli et al. (2018), which learned a continuous latent representation of molecules that could be searched and decoded into new candidate structures.
This lesson covers VAEs completely - the probabilistic framework, the ELBO derivation, the reparameterization trick, β-VAE disentanglement, VQ-VAE discrete latent spaces, conditional generation, and production anomaly detection patterns. All with complete PyTorch implementations.
Why Standard Autoencoders Fail at Generation
Before diving into VAEs, it is worth understanding precisely why standard autoencoders (AEs) fail as generative models.
A standard autoencoder consists of an encoder $f_\phi$ and a decoder $g_\theta$, trained to minimize reconstruction loss: $\mathcal{L}_{\text{AE}} = \lVert x - g_\theta(f_\phi(x)) \rVert^2$.
The encoder learns to compress each input to a single point in latent space. There is no constraint on where these points land, how they are distributed, or what the space between them looks like. The encoder is free to use a tiny, fragmented patch of latent space for each input, and nothing in the loss discourages this. The reconstruction loss only rewards codes that let the decoder tell training inputs apart; it does not reward codes that fill the space between them.
The consequence: if you sample a random latent vector and decode it, you almost certainly land in an unconstrained region between the training data's fragmented patches. The decoder has never seen this region during training and produces garbage output.
The AE also cannot tell you how similar two inputs are in a semantically meaningful way: the latent space geometry is arbitrary, not organized by semantic content. Interpolation between two latent points passes through unmapped territory.
What we need: a generative model where the entire latent space is structured and meaningful. VAEs achieve this by learning a latent space that matches a known prior distribution (the standard Gaussian), making every point in the space decodable and every direction in the space semantically smooth.
Historical Context
2013: Diederik Kingma and Max Welling publish "Auto-Encoding Variational Bayes" (the original VAE paper). The reparameterization trick - their key technical contribution - enables backpropagation through stochastic sampling, making variational inference scalable to deep networks.
2014: Danilo Rezende, Shakir Mohamed, and Daan Wierstra independently develop a similar framework with the "stochastic backpropagation" technique.
2017: Irina Higgins et al. (DeepMind) publish β-VAE: "BETA-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework," introducing disentangled representation learning.
2017: Aaron van den Oord et al. publish VQ-VAE: "Neural Discrete Representation Learning," introducing discrete latent spaces that produce sharper images by avoiding the blurriness inherent in continuous VAEs.
2018: Conditional VAEs and molecular design VAEs achieve landmark results in drug discovery (Gómez-Bombarelli) and protein structure prediction.
The key insight of the original VAE paper: variational inference (approximating an intractable posterior with a simpler distribution) had existed in statistics for decades, but training the approximate posterior simultaneously with the model parameters required a differentiable sampling operation. The reparameterization trick provided that differentiability, unlocking scalable amortized variational inference with neural networks.
The Probabilistic Framework
A VAE models the data distribution through a continuous latent variable $z$:
$$p_\theta(x) = \int p_\theta(x \mid z)\, p(z)\, dz$$
This integral is intractable in general (it cannot be computed in closed form for neural network decoders), so we need an approximation strategy. The four distributions involved in a VAE:
Prior $p(z)$: the distribution we impose on the latent space. Standard choice: isotropic Gaussian $\mathcal{N}(0, I_d)$, where $d$ is the latent dimension. This is the target shape for the latent space - we want all latent codes to collectively follow this distribution.
Likelihood $p_\theta(x \mid z)$: the decoder's generative model. Given a latent code $z$, the decoder outputs the parameters of a distribution over $x$. For images with pixel values in $[0, 1]$, this is typically a Bernoulli distribution (BCE loss). For continuous data, a Gaussian (MSE loss).
Posterior $p_\theta(z \mid x)$: the true posterior - given a data point $x$, what latent code explains it? By Bayes' theorem: $p_\theta(z \mid x) = p_\theta(x \mid z)\, p(z) / p_\theta(x)$. This is intractable because computing $p_\theta(x)$ requires integrating over all $z$.
Variational posterior (encoder) $q_\phi(z \mid x)$: a tractable approximation to the true posterior, parameterized by the encoder network with parameters $\phi$. Standard choice: a diagonal Gaussian $\mathcal{N}\big(\mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x))\big)$, where $\mu_\phi(x)$ and $\sigma_\phi^2(x)$ are outputs of the encoder network.
The encoder effectively answers: "given input $x$, what distribution of latent codes would explain it?" Rather than outputting a single code, it outputs the mean and variance of a Gaussian - a distribution of plausible codes.
The ELBO: Evidence Lower Bound
We want to maximize $\log p_\theta(x)$ - the log probability of our data under the model. We cannot compute this directly because $p_\theta(x)$ is intractable.
Derivation of the ELBO
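Start with the log-likelihood and introduce the variational posterior $q_\phi(z \mid x)$:
$$\log p_\theta(x) = \log \int p_\theta(x, z)\, dz$$
Multiply and divide by $q_\phi(z \mid x)$ inside the integral:
$$\log p_\theta(x) = \log \int q_\phi(z \mid x)\, \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\, dz$$
This is $\log \mathbb{E}_{q_\phi(z \mid x)}\!\left[\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]$.
Apply Jensen's inequality ($f(\mathbb{E}[Y]) \ge \mathbb{E}[f(Y)]$ for concave $f$, here $f = \log$):
$$\log p_\theta(x) \ge \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right]$$
Expand the logarithm using $p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$:
$$\mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] + \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log \frac{p(z)}{q_\phi(z \mid x)}\right]$$
This is the ELBO (Evidence Lower Bound):
$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)$$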
Maximizing the ELBO is equivalent to:
- Maximizing the reconstruction term: pushing the decoder to accurately reconstruct $x$ from samples $z \sim q_\phi(z \mid x)$ of the approximate posterior. This is the familiar reconstruction loss.
- Minimizing the KL divergence: pushing the approximate posterior $q_\phi(z \mid x)$ toward the prior $p(z)$. This regularizes the latent space.
The gap between $\log p_\theta(x)$ and the ELBO equals $D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)\big)$ - how far the approximate posterior is from the true posterior. A perfect approximation closes the gap entirely.
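The gap identity can be checked exactly on a toy model where every quantity has a closed form. That is not true for neural decoders, so the following is an illustration under assumptions and not part of the VAE itself. The sketch assumes a linear-Gaussian model with $p(z) = \mathcal{N}(0, 1)$ and $p(x \mid z) = \mathcal{N}(wz, s^2)$. The names `w`, `s2` and the helper functions are hypothetical.

```python
import math

# Toy linear-Gaussian model (illustrative assumption):
#   p(z) = N(0, 1),   p(x | z) = N(w * z, s2)
w, s2 = 2.0, 0.5


def log_marginal(x):
    # For this model p(x) = N(0, w^2 + s2) exactly
    var = w**2 + s2
    return -0.5 * math.log(2 * math.pi * var) - x**2 / (2 * var)


def true_posterior(x):
    # Gaussian conjugacy: p(z | x) = N(m_post, v_post)
    v_post = s2 / (s2 + w**2)
    m_post = w * x / (s2 + w**2)
    return m_post, v_post


def elbo(x, m, v):
    # ELBO for q(z) = N(m, v): E_q[log p(x|z)] - KL(q || p(z)), both in closed form
    expected_loglik = (-0.5 * math.log(2 * math.pi * s2)
                       - ((x - w * m) ** 2 + w**2 * v) / (2 * s2))
    kl_to_prior = 0.5 * (v + m**2 - 1.0 - math.log(v))
    return expected_loglik - kl_to_prior


def kl_gauss(m1, v1, m2, v2):
    # KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


x = 1.3
m_post, v_post = true_posterior(x)
candidates = {"true posterior": (m_post, v_post),
              "prior as q": (0.0, 1.0),
              "misfit q": (1.0, 0.1)}
for label, (m, v) in candidates.items():
    gap = log_marginal(x) - elbo(x, m, v)
    print(f"{label:15s} gap={gap:.6f}  KL(q || p(z|x))={kl_gauss(m, v, m_post, v_post):.6f}")
```

For every candidate $q$, the printed gap matches $D_{\mathrm{KL}}(q \,\|\, p(z \mid x))$, and it is zero when $q$ is the true posterior.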
Closed-Form KL for Gaussians
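For the standard VAE with a diagonal Gaussian encoder and a standard Gaussian prior, the KL divergence has a closed form that avoids Monte Carlo estimation:
$$D_{\mathrm{KL}}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$
This follows from evaluating the expected log-ratio of the two Gaussians analytically. In practice, the encoder outputs $\log \sigma^2$ (not $\sigma^2$ directly) for numerical stability: the log-variance is unconstrained, whereas the variance must be positive.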
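A quick sanity check of the closed-form KL compares it against two references. The first is PyTorch's own `torch.distributions.kl_divergence` for two `Normal` distributions. The second is a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$. The `mu`/`logvar` values are arbitrary stand-ins for encoder outputs, and the helper names are hypothetical.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_closed_form(mu, logvar):
    # KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)


def kl_monte_carlo(mu, logvar, n_samples=200_000):
    # Estimate E_q[log q(z) - log p(z)] by sampling z ~ q
    q = Normal(mu, torch.exp(0.5 * logvar))
    p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
    z = q.sample((n_samples,))                               # (n_samples, d)
    return (q.log_prob(z) - p.log_prob(z)).sum(dim=-1).mean()


torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0, 0.0])
logvar = torch.tensor([0.0, -0.5, 0.7])

analytic = kl_closed_form(mu, logvar)
reference = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)),
                          Normal(torch.zeros(3), torch.ones(3))).sum()
estimate = kl_monte_carlo(mu, logvar)
print(analytic.item(), reference.item(), estimate.item())
```

The analytic and library values agree to floating-point precision. The Monte Carlo estimate fluctuates around them, which is exactly the noise the closed form avoids.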
The Reconstruction-Regularization Trade-off
The ELBO creates a fundamental tension:
- Reconstruction term wants the encoder to map each $x$ to a narrow, precise distribution (so the decoder gets informative, low-variance latent codes)
- KL term wants the encoder to map every $x$ to $\mathcal{N}(0, I)$ (totally uninformative)
At equilibrium, the VAE finds latent codes that are informative enough to reconstruct but similar enough to the prior that the entire latent space is covered. Dense regions of the prior (near the origin in a standard Gaussian) correspond to common, typical inputs. The latent space is smooth: nearby latent points decode to similar outputs because the encoder was forced to spread its distributions across overlapping regions.
The Reparameterization Trick
The ELBO contains $\mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)]$ - an expectation over the encoder's output distribution. To compute gradients with respect to $\phi$, we need to differentiate through the sampling operation $z \sim q_\phi(z \mid x)$.
The problem: sampling is a stochastic operation. The standard backpropagation algorithm requires the computation graph to be deterministic - gradients cannot flow through a random node. If we draw $z$ as a sample and compute the loss, there is no gradient from the loss back to $\mu$ and $\sigma$.
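The solution: reparameterize the sample as a deterministic transformation of the encoder outputs plus a fixed, independent noise variable:
$$z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$
where $\odot$ is elementwise multiplication. This $z$ has exactly the same distribution as $z \sim \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$, but now:
- $\epsilon$ is sampled once and treated as a constant input - no gradient flows through it
- $z$ is a deterministic function of $\mu$ and $\sigma$, which are outputs of the encoder with learnable parameters $\phi$
- Gradients flow through $\mu$ and $\sigma$ (and on to $\phi$) via standard backpropagation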
```
Without reparameterization:

  x → Encoder → [μ, σ] → sample → z → Decoder → x̂ → loss
                           ↑
                GRADIENT BLOCKED HERE (stochastic node)

With reparameterization:

  x → Encoder → [μ, σ] → z = μ + σ·ε → Decoder → x̂ → loss
      ↑______________________________________________↑
                GRADIENT FLOWS (deterministic path)
                ε ~ N(0,I) treated as fixed input
```
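The reparameterization trick also applies to other distributions. For a uniform distribution $z \sim \mathcal{U}(a, b)$, reparameterize as $z = a + (b - a)\,u$ with $u \sim \mathcal{U}(0, 1)$. For more complex posteriors, normalizing flows apply a chain of invertible, differentiable transformations to a reparameterized base sample, so gradients still flow through the sampling step.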
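The following minimal sketch shows the difference in PyTorch. `torch.distributions.Normal.sample()` returns a tensor detached from `mu`/`logvar`, while the manual $z = \mu + \sigma \odot \epsilon$ (equivalently `Normal.rsample()`) keeps the graph intact. The loss $\sum_j z_j^2$ is an arbitrary stand-in for a decoder loss, chosen so the gradients can be checked by hand.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.zeros(4, requires_grad=True)       # stand-in for encoder mean
logvar = torch.zeros(4, requires_grad=True)   # stand-in for encoder log-variance
std = torch.exp(0.5 * logvar)

# Plain sampling: the sample is cut off from the graph (no path back to mu, logvar)
z_blocked = Normal(mu, std).sample()

# PyTorch's built-in reparameterized sampler keeps the graph
z_rsample = Normal(mu, std).rsample()

# Manual reparameterization: z = mu + sigma * eps, eps drawn independently
eps = torch.randn(4)
z = mu + std * eps
loss = (z ** 2).sum()
loss.backward()

# Hand-derived gradients for comparison:
#   dL/dmu     = 2z
#   dL/dlogvar = 2z * eps * (0.5 * sigma) = z * eps * sigma
print(z_blocked.requires_grad, z_rsample.requires_grad)  # False True
print(mu.grad, logvar.grad)
```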
Full VAE Implementation in PyTorch
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt


class VAE(nn.Module):
    """
    Variational Autoencoder for MNIST (28×28 flattened to 784).

    Architecture:
        Encoder: 784 → 400 → 400 → (μ ∈ R^d, log σ² ∈ R^d)
        Decoder: d → 400 → 400 → 784 (sigmoid for pixel values)

    Parameters
    ----------
    input_dim : int - flattened input dimension (784 for MNIST)
    latent_dim : int - dimension of latent space
    hidden_dim : int - width of encoder/decoder hidden layers
    """

    def __init__(self, input_dim: int = 784,
                 latent_dim: int = 20,
                 hidden_dim: int = 400):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # --- Encoder ---
        # Two-layer MLP → separate heads for μ and log σ²
        self.enc_shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # log σ²

        # --- Decoder ---
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # output in [0, 1] - pixel probabilities
        )

    def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode input to variational posterior parameters.

        Parameters
        ----------
        x : (batch, input_dim) float tensor

        Returns
        -------
        mu : (batch, latent_dim) - posterior mean
        logvar : (batch, latent_dim) - posterior log-variance
        """
        h = self.enc_shared(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor,
                       logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: z = μ + σ·ε, ε ~ N(0, I)

        During training: stochastic sample (enables exploration of latent space)
        During eval: return μ deterministically (best single estimate)

        Parameters
        ----------
        mu : (batch, latent_dim) posterior mean
        logvar : (batch, latent_dim) posterior log-variance

        Returns
        -------
        z : (batch, latent_dim) sampled latent code
        """
        if self.training:
            std = torch.exp(0.5 * logvar)  # σ = exp(log σ² / 2)
            eps = torch.randn_like(std)    # ε ~ N(0, I), same shape as std
            return mu + std * eps          # reparameterized sample
        return mu                          # deterministic at inference time

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent code to input space reconstruction."""
        return self.decoder(z)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Full forward pass: encode → reparameterize → decode.

        Returns
        -------
        x_hat : (batch, input_dim) reconstruction
        mu : (batch, latent_dim) posterior mean
        logvar : (batch, latent_dim) posterior log-variance
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

    @torch.no_grad()
    def generate(self, n_samples: int,
                 device: str = 'cpu') -> torch.Tensor:
        """
        Generate new samples by sampling from the prior p(z) = N(0, I).

        Parameters
        ----------
        n_samples : number of samples to generate
        device : torch device

        Returns
        -------
        samples : (n_samples, input_dim) generated outputs
        """
        self.eval()
        z = torch.randn(n_samples, self.latent_dim, device=device)
        return self.decode(z)

    @torch.no_grad()
    def interpolate(self, x1: torch.Tensor, x2: torch.Tensor,
                    steps: int = 10) -> torch.Tensor:
        """
        Linearly interpolate between two inputs in latent space.
        Smooth interpolation confirms the latent space is continuous.
        """
        self.eval()
        mu1, _ = self.encode(x1.unsqueeze(0))
        mu2, _ = self.encode(x2.unsqueeze(0))
        alphas = torch.linspace(0, 1, steps, device=mu1.device)
        z_interp = torch.stack([
            (1 - a) * mu1 + a * mu2
            for a in alphas
        ]).squeeze(1)  # (steps, latent_dim)
        return self.decode(z_interp)


def vae_loss(x_hat: torch.Tensor, x: torch.Tensor,
             mu: torch.Tensor, logvar: torch.Tensor,
             beta: float = 1.0) -> dict:
    """
    VAE ELBO loss = Reconstruction loss + β × KL divergence.

    Parameters
    ----------
    x_hat : (batch, D) reconstruction
    x : (batch, D) original input
    mu : (batch, latent_dim) encoder mean
    logvar : (batch, latent_dim) encoder log-variance
    beta : KL weight (1.0 = standard VAE, > 1.0 = β-VAE)

    Returns
    -------
    dict with 'total', 'recon', 'kl' keys (all scalar tensors)
    """
    # Reconstruction loss: binary cross-entropy (pixel probabilities)
    # reduction='sum' gives total over the batch - divide by N for per-sample
    recon_loss = F.binary_cross_entropy(
        x_hat, x, reduction='sum'
    )

    # KL divergence: closed form for N(μ, σ²) vs N(0, I)
    # KL = -0.5 * Σ_j (1 + log σ²_j - μ²_j - σ²_j)
    # KL ≥ 0, and it is 0 only when every posterior equals the prior N(0, I)
    kl_loss = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp()
    )

    total = recon_loss + beta * kl_loss
    return {
        'total': total,
        'recon': recon_loss,
        'kl': kl_loss,
    }


def train_vae(model: VAE, n_epochs: int = 30, batch_size: int = 128,
              lr: float = 1e-3, beta: float = 1.0,
              kl_anneal: bool = False, warmup_epochs: int = 10,
              device: str = 'cpu') -> dict:
    """
    Train VAE on MNIST.

    Parameters
    ----------
    model : VAE instance
    n_epochs : number of training epochs
    batch_size : mini-batch size
    lr : Adam learning rate
    beta : KL weight (1.0 = standard VAE)
    kl_anneal : if True, linearly increase β from 0 to beta over warmup_epochs
    warmup_epochs : epochs over which to anneal β (if kl_anneal=True)
    device : 'cpu' or 'cuda'

    Returns
    -------
    history : dict of per-epoch losses (averaged per training sample)
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)),  # 28×28 → 784
    ])
    dataset = datasets.MNIST('./data', train=True,
                             download=True, transform=transform)
    # num_workers=0: spawn-based worker processes (the default on Windows and
    # macOS) cannot pickle the lambda transform above
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=0,
                        pin_memory=str(device).startswith('cuda'))
    n_train = len(dataset)

    history = {'total': [], 'recon': [], 'kl': [], 'beta': []}

    for epoch in range(n_epochs):
        # KL annealing: ramp beta from 0 to target value
        if kl_anneal:
            current_beta = beta * min(1.0, epoch / max(warmup_epochs, 1))
        else:
            current_beta = beta

        model.train()
        epoch_total = epoch_recon = epoch_kl = 0.0

        for x, _ in loader:
            x = x.to(device)
            optimizer.zero_grad()
            x_hat, mu, logvar = model(x)
            losses = vae_loss(x_hat, x, mu, logvar, beta=current_beta)
            losses['total'].backward()
            # Clip gradients: VAE can have large KL gradients early in training
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            optimizer.step()

            epoch_total += losses['total'].item()
            epoch_recon += losses['recon'].item()
            epoch_kl += losses['kl'].item()

        history['total'].append(epoch_total / n_train)
        history['recon'].append(epoch_recon / n_train)
        history['kl'].append(epoch_kl / n_train)
        history['beta'].append(current_beta)

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:3d}/{n_epochs} β={current_beta:.2f} "
                  f"Total: {history['total'][-1]:.2f} "
                  f"Recon: {history['recon'][-1]:.2f} "
                  f"KL: {history['kl'][-1]:.2f}")

    return history


# Train standard VAE
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = VAE(input_dim=784, latent_dim=20, hidden_dim=400)
history = train_vae(vae, n_epochs=30, beta=1.0, device=device)

# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, key, color in zip(axes, ['total', 'recon', 'kl'],
                          ['steelblue', 'seagreen', 'tomato']):
    ax.plot(history[key], color=color, lw=2)
    ax.set_title(f"{key.capitalize()} Loss")
    ax.set_xlabel("Epoch")
    ax.grid(True, alpha=0.3)
plt.suptitle("VAE Training Curves (β=1)")
plt.tight_layout()
plt.show()
```
Visualizing the Latent Space
With a 2D latent dimension, you can directly plot the learned manifold and see how different input classes organize themselves:
```python
# Train a 2D VAE for visualization
vae_2d = VAE(input_dim=784, latent_dim=2, hidden_dim=400)
history_2d = train_vae(vae_2d, n_epochs=30, device=device)

# Encode the test set and plot by digit class
transform_flat = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])
test_dataset = datasets.MNIST('./data', train=False, transform=transform_flat)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

vae_2d.eval()
all_mu, all_labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        mu, _ = vae_2d.encode(x.to(device))
        all_mu.append(mu.cpu())
        all_labels.append(y)

all_mu = torch.cat(all_mu).numpy()  # (10000, 2)
all_labels = torch.cat(all_labels).numpy()

plt.figure(figsize=(10, 8))
scatter = plt.scatter(all_mu[:, 0], all_mu[:, 1],
                      c=all_labels, cmap='tab10',
                      alpha=0.5, s=5)
plt.colorbar(scatter, label='Digit Class')
plt.xlabel("z₁ (Latent Dimension 1)")
plt.ylabel("z₂ (Latent Dimension 2)")
plt.title("VAE 2D Latent Space - MNIST Test Set")
plt.tight_layout()
plt.show()

# Traverse the latent space in a grid
# This produces the "manifold walk" - each grid position decodes to an image
vae_2d.eval()
grid_range = np.linspace(-3, 3, 20)
z1_grid, z2_grid = np.meshgrid(grid_range, grid_range)
z_grid = torch.tensor(
    np.stack([z1_grid.ravel(), z2_grid.ravel()], axis=1),
    dtype=torch.float32
).to(device)

with torch.no_grad():
    decoded = vae_2d.decode(z_grid).cpu()
decoded = decoded.view(-1, 28, 28).numpy()

fig, axes = plt.subplots(20, 20, figsize=(15, 15))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(decoded[i], cmap='gray', vmin=0, vmax=1)
    ax.axis('off')
plt.suptitle("VAE Latent Space Grid - Traversing z₁ × z₂", fontsize=12)
plt.tight_layout()
plt.show()
```
The latent space grid reveals the structure the VAE has learned: neighboring grid cells produce similar digits, digits of the same class cluster together, and transitions between classes are smooth.
Latent Space Interpolation
Smooth interpolation in latent space is a direct demonstration that the latent space is continuous and structured - a property that standard autoencoders lack.
```python
# Get test batch for interpolation examples
test_iter = iter(test_loader)
x_batch, y_batch = next(test_iter)

# Pick a digit 4 and a digit 9
idx_4 = (y_batch == 4).nonzero(as_tuple=True)[0][0].item()
idx_9 = (y_batch == 9).nonzero(as_tuple=True)[0][0].item()
x1 = x_batch[idx_4].to(device)  # "4"
x2 = x_batch[idx_9].to(device)  # "9"

vae.eval()
with torch.no_grad():
    interpolated = vae.interpolate(x1, x2, steps=12)
interpolated = interpolated.view(12, 28, 28).cpu().numpy()

fig, axes = plt.subplots(1, 12, figsize=(18, 2))
for i, ax in enumerate(axes):
    ax.imshow(interpolated[i], cmap='gray')
    ax.axis('off')
    ax.set_title(f"α={i/11:.1f}", fontsize=7)
plt.suptitle("Latent Space Interpolation: 4 → 9")
plt.tight_layout()
plt.show()
```
The smooth morphing from "4" to "9" indicates that the VAE has learned a continuous, structured latent space. A standard AE typically produces incoherent images at intermediate points, because the path passes through unmapped territory.
β-VAE: Disentangled Representations
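β-VAE (Higgins et al., 2017) modifies the ELBO by multiplying the KL term by a constant $\beta > 1$:
$$\mathcal{L}_\beta(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\!\left[\log p_\theta(x \mid z)\right] - \beta\, D_{\mathrm{KL}}\!\left(q_\phi(z \mid x) \,\|\, p(z)\right)$$
A larger $\beta$ imposes a stronger information bottleneck: the encoder is penalized more heavily for any posterior that differs from the prior. This forces the encoder to be more selective - it must compress the input into fewer, more independent latent dimensions.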
Why This Produces Disentanglement
With $\beta = 1$, the encoder can spread information arbitrarily across the latent dimensions. With $\beta > 1$, the strong regularization pushes the encoder to use as few "active" dimensions as possible, and the dimensions it does use tend to capture statistically independent factors of variation.
Intuitively: suppose a dataset has 3 independent factors (e.g., digit identity, thickness, and rotation for MNIST). A β-VAE with high $\beta$ tends to encode each factor in a separate dimension rather than mixing them. Mixing would require multiple dimensions to carry the same information, and the KL term penalizes that.
The β trade-off: higher $\beta$ → more disentangled but worse reconstruction. The information bottleneck limits how much the latent code can carry, so some reconstruction detail is inevitably lost.
```python
import torch
import matplotlib.pyplot as plt

# Train multiple models with different β values to compare
beta_values = [1.0, 2.0, 4.0, 10.0]
vae_betas = {}
histories_betas = {}

for beta in beta_values:
    print(f"\n--- Training β-VAE with β={beta} ---")
    model = VAE(input_dim=784, latent_dim=10, hidden_dim=400)
    history = train_vae(model, n_epochs=20, beta=beta, device=device)
    vae_betas[beta] = model
    histories_betas[beta] = history

# Compare final metrics
# Note: 'Total' already includes β × KL, so compare Recon and KL across β
print("\nFinal epoch metrics by β:")
print(f"{'β':>6} {'Total':>8} {'Recon':>8} {'KL':>8}")
for beta in beta_values:
    h = histories_betas[beta]
    print(f"{beta:>6.1f} {h['total'][-1]:>8.2f} "
          f"{h['recon'][-1]:>8.2f} {h['kl'][-1]:>8.2f}")
# Higher β → lower KL (posterior more closely matches prior → more regularized)
# Higher β → higher reconstruction loss (less information in latent code)
```
Measuring Disentanglement
A simple latent traversal experiment: fix all latent dimensions to their mean value (z = 0 under standard Gaussian prior), then vary one dimension at a time and observe what changes in the decoded output:
```python
def latent_traversal(model: VAE, dim: int, n_steps: int = 11,
                     value_range: float = 3.0,
                     device: str = 'cpu') -> np.ndarray:
    """
    Traverse one latent dimension while holding others at 0.
    In a disentangled VAE, this changes exactly one factor of variation.
    """
    model.eval()
    latent_dim = model.latent_dim
    values = np.linspace(-value_range, value_range, n_steps)
    images = []
    with torch.no_grad():
        for val in values:
            z = torch.zeros(1, latent_dim, device=device)
            z[0, dim] = val
            img = model.decode(z).view(28, 28).cpu().numpy()
            images.append(img)
    return np.array(images)  # (n_steps, 28, 28)


# Traverse the first 8 latent dimensions of the β=4 model
fig, axes = plt.subplots(8, 11, figsize=(20, 14))
for dim in range(8):
    traversal = latent_traversal(vae_betas[4.0], dim=dim,
                                 n_steps=11, device=device)
    for step_idx, img in enumerate(traversal):
        ax = axes[dim, step_idx]
        ax.imshow(img, cmap='gray')
        # Hide ticks instead of axis('off'), which would also hide the row label
        ax.set_xticks([])
        ax.set_yticks([])
    axes[dim, 0].set_ylabel(f"z_{dim}", fontsize=10)
plt.suptitle("β-VAE (β=4) Latent Traversal - Each Row Varies One Dimension",
             fontsize=12)
plt.tight_layout()
plt.show()
```
In a well-disentangled β-VAE, each row of the traversal plot ideally changes a single interpretable property: digit width, digit slant, loop size, etc. In a standard VAE (β=1), varying one dimension typically changes multiple properties simultaneously.
VQ-VAE: Discrete Latent Spaces
VQ-VAE (van den Oord et al., 2017) replaces the continuous Gaussian latent space with a discrete codebook of embedding vectors. Instead of sampling from a Gaussian, the encoder maps each input to its nearest codebook vector (vector quantization).
Motivation
VAEs with continuous Gaussian posteriors produce blurry reconstructions - this is a direct consequence of the reconstruction loss averaging over the posterior $q_\phi(z \mid x)$. If $q_\phi(z \mid x)$ has nonzero variance, the decoder must produce an output that is reasonable for all samples in the distribution, not just the mode. This averaging creates blurriness.
VQ-VAE forces discrete latent codes: each input maps to exactly one of $K$ codebook embeddings, with no averaging. The decoder receives a specific, concrete code rather than a distribution. This produces sharper reconstructions.
VQ-VAE also enables powerful priors over the discrete space: after training the VQ-VAE, you can train a powerful autoregressive model (PixelCNN, GPT) on the sequences of discrete codes - much easier than modeling a continuous high-dimensional latent.
The VQ-VAE Architecture
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    Vector Quantization layer for VQ-VAE.
    Maps continuous encoder output to nearest codebook vector.
    Uses the straight-through estimator for the gradient of the argmin operation.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 commitment_cost: float = 0.25):
        super().__init__()
        self.num_embeddings = num_embeddings    # K: number of codes
        self.embedding_dim = embedding_dim      # D: code dimension
        self.commitment_cost = commitment_cost  # β in VQ-VAE paper

        # The codebook: K embedding vectors of dimension D
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Initialize uniformly - common practice for VQ
        nn.init.uniform_(
            self.embedding.weight,
            -1 / num_embeddings,
            1 / num_embeddings
        )

    def forward(self, z_e: torch.Tensor) -> tuple:
        """
        Parameters
        ----------
        z_e : (B, D, H, W) or (B, D) encoder output (continuous)

        Returns
        -------
        z_q : quantized output (same shape as z_e)
        loss : codebook + commitment loss
        perplexity : codebook usage metric (higher = more spread)
        indices : (B, H, W) or (B,) codebook indices
        """
        # Move the code dimension D last so each row of `flat` is one D-dim vector
        if z_e.dim() == 4:
            z = z_e.permute(0, 2, 3, 1).contiguous()  # (B, H, W, D)
        else:
            z = z_e                                   # (B, D)
        flat = z.reshape(-1, self.embedding_dim)      # (N, D), N = B*H*W or B

        # Compute squared distances to all codebook vectors
        # ||z_e - e_k||² = ||z_e||² + ||e_k||² - 2 * z_e · e_k
        dist = (
            flat.pow(2).sum(dim=1, keepdim=True)        # (N, 1)
            + self.embedding.weight.pow(2).sum(dim=1)   # (K,) broadcast
            - 2 * flat @ self.embedding.weight.t()      # (N, K)
        )

        # Nearest codebook entry
        indices = dist.argmin(dim=1)                    # (N,)
        z_q = self.embedding(indices).view(z.shape)     # same layout as z

        # Codebook loss: move codebook vectors toward the (frozen) encoder output
        loss_codebook = F.mse_loss(z_q, z.detach())
        # Commitment loss: keep the encoder output close to its (frozen) code
        loss_commit = self.commitment_cost * F.mse_loss(z, z_q.detach())
        loss = loss_codebook + loss_commit

        # Straight-through estimator: the forward value is z_q, but the gradient
        # arriving at z_q is copied unchanged onto the encoder output
        # (the codebook is trained only by loss_codebook)
        z_q = z + (z_q - z).detach()

        # Perplexity: measures codebook utilization
        # Perfect use = every code used equally → perplexity = K
        avg_probs = F.one_hot(indices, self.num_embeddings).float().mean(0)
        perplexity = torch.exp(-torch.sum(
            avg_probs * torch.log(avg_probs + 1e-10)
        ))

        # Restore the caller's layout
        indices = indices.view(z.shape[:-1])            # (B, H, W) or (B,)
        if z_e.dim() == 4:
            z_q = z_q.permute(0, 3, 1, 2).contiguous()  # back to (B, D, H, W)
        return z_q, loss, perplexity, indices


class VQVAE(nn.Module):
    """
    VQ-VAE: discrete latent space via vector quantization.
    Produces sharper reconstructions than continuous VAE.
    """

    def __init__(self, input_dim: int = 784,
                 hidden_dim: int = 256,
                 num_embeddings: int = 512,
                 embedding_dim: int = 64,
                 commitment_cost: float = 0.25):
        super().__init__()
        # This MLP emits ONE embedding_dim vector per image, so each image is
        # quantized to a single code and at most num_embeddings distinct
        # reconstructions exist. Convolutional VQ-VAEs quantize a spatial grid
        # of vectors (the (B, D, H, W) case above) for far more capacity.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )
        self.vq = VectorQuantizer(num_embeddings, embedding_dim,
                                  commitment_cost)
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> tuple:
        z_e = self.encoder(x)
        z_q, vq_loss, perplexity, indices = self.vq(z_e)
        x_hat = self.decoder(z_q)
        recon_loss = F.binary_cross_entropy(x_hat, x, reduction='mean')
        total_loss = recon_loss + vq_loss
        return x_hat, total_loss, recon_loss, perplexity
```
The straight-through estimator is the VQ-VAE equivalent of the reparameterization trick. During the forward pass, $z_q$ is the quantized (discretized) code. During the backward pass, gradients flow as if $z_q = z_e$ (the quantization is treated as an identity in the gradient computation). This allows the encoder to learn despite the non-differentiable argmin operation.
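The following minimal sketch checks where gradients go in a VQ layer, using plain tensors in place of `nn.Embedding` and an encoder. The straight-through term passes the downstream gradient to $z_e$ as an identity and gives the codebook nothing. The codebook loss $\lVert \mathrm{sg}[z_e] - e \rVert^2$ updates only the codebook, and the commitment loss $\beta \lVert z_e - \mathrm{sg}[e] \rVert^2$ updates only the encoder output. The sizes, the `quantize` helper, and the linear "decoder" weights are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, D, N = 8, 4, 5
codebook = torch.randn(K, D, requires_grad=True)  # stand-in for nn.Embedding weights
z_e = torch.randn(N, D, requires_grad=True)       # stand-in for encoder outputs


def quantize(z_e, codebook):
    # Nearest codebook vector by Euclidean distance; argmin carries no gradient
    indices = torch.cdist(z_e, codebook).argmin(dim=1)
    return codebook[indices], indices


# 1) Straight-through: forward value is z_q, backward is the identity onto z_e
z_q, _ = quantize(z_e, codebook)
z_st = z_e + (z_q - z_e).detach()
st_forward_matches = torch.allclose(z_st, z_q)
weights = torch.arange(D, dtype=torch.float32)    # arbitrary linear "decoder"
(z_st * weights).sum().backward()
st_grad_z = z_e.grad.clone()        # every row equals `weights`
st_grad_codebook = codebook.grad    # None: the decoder path never reaches the codebook

# 2) Codebook loss ||sg[z_e] - e||^2: only the codebook receives a gradient
z_e.grad = None
z_q, _ = quantize(z_e, codebook)
F.mse_loss(z_q, z_e.detach()).backward()
cb_grad_z, cb_grad_codebook = z_e.grad, codebook.grad.clone()

# 3) Commitment loss beta * ||z_e - sg[e]||^2: only the encoder output receives a gradient
codebook.grad = None
z_q, _ = quantize(z_e, codebook)
(0.25 * F.mse_loss(z_e, z_q.detach())).backward()
cm_grad_z, cm_grad_codebook = z_e.grad.clone(), codebook.grad

print(st_grad_z[0], st_grad_codebook, cb_grad_z, cm_grad_codebook)
```

This split is why the `VectorQuantizer` above applies `commitment_cost` only to the term whose gradient reaches the encoder.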
Conditional VAE (CVAE)
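A Conditional VAE (CVAE) extends the standard VAE by conditioning both the encoder and decoder on a label $y$. This enables class-conditional generation: "generate a digit that looks like a 7" instead of sampling any random digit.
The ELBO becomes (with the prior kept at $p(z) = \mathcal{N}(0, I)$, as in the implementation below):
$$\mathcal{L}(\theta, \phi; x, y) = \mathbb{E}_{q_\phi(z \mid x, y)}\!\left[\log p_\theta(x \mid z, y)\right] - D_{\mathrm{KL}}\!\left(q_\phi(z \mid x, y) \,\|\, p(z)\right)$$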
```python
class CVAE(nn.Module):
    """
    Conditional VAE: conditions both encoder and decoder on class label.
    Enables class-specific generation and controlled interpolation.
    """

    def __init__(self, input_dim: int = 784,
                 latent_dim: int = 20,
                 hidden_dim: int = 400,
                 n_classes: int = 10):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_classes = n_classes

        # Encoder: takes [x, one_hot(y)] concatenated
        self.enc = nn.Sequential(
            nn.Linear(input_dim + n_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: takes [z, one_hot(y)] concatenated
        self.dec = nn.Sequential(
            nn.Linear(latent_dim + n_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x: torch.Tensor,
               y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """y: (batch,) integer class labels"""
        y_onehot = F.one_hot(y, self.n_classes).float()
        h = self.enc(torch.cat([x, y_onehot], dim=1))
        return self.fc_mu(h), self.fc_logvar(h)

    def decode(self, z: torch.Tensor,
               y: torch.Tensor) -> torch.Tensor:
        y_onehot = F.one_hot(y, self.n_classes).float()
        return self.dec(torch.cat([z, y_onehot], dim=1))

    def forward(self, x: torch.Tensor,
                y: torch.Tensor) -> tuple:
        mu, logvar = self.encode(x, y)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std) if self.training else torch.zeros_like(std)
        z = mu + std * eps
        x_hat = self.decode(z, y)
        return x_hat, mu, logvar

    @torch.no_grad()
    def generate_class(self, digit_class: int,
                       n_samples: int = 16,
                       device: str = 'cpu') -> torch.Tensor:
        """Generate samples of a specific digit class."""
        self.eval()
        z = torch.randn(n_samples, self.latent_dim, device=device)
        y = torch.full((n_samples,), digit_class,
                       dtype=torch.long, device=device)
        return self.decode(z, y)
```
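The CVAE above comes without a training loop. The sketch below shows one possible training step on the conditional ELBO, assuming the same BCE + closed-form KL loss as `vae_loss`. So that it runs standalone, it uses a hypothetical `TinyCVAE` (a smaller, one-hidden-layer copy of the CVAE above) and synthetic class-template data in place of MNIST. The widths, step count, and data are illustrative choices, not tuned settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyCVAE(nn.Module):
    """Hypothetical, scaled-down copy of the CVAE above (one hidden layer)."""

    def __init__(self, input_dim=784, latent_dim=8, hidden_dim=64, n_classes=10):
        super().__init__()
        self.n_classes = n_classes
        self.enc = nn.Sequential(nn.Linear(input_dim + n_classes, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.dec = nn.Sequential(
            nn.Linear(latent_dim + n_classes, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def forward(self, x, y):
        y_onehot = F.one_hot(y, self.n_classes).float()
        h = self.enc(torch.cat([x, y_onehot], dim=1))              # q(z | x, y)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)    # reparameterized sample
        return self.dec(torch.cat([z, y_onehot], dim=1)), mu, logvar  # p(x | z, y)


def cvae_train_step(model, optimizer, x, y, beta=1.0):
    """One gradient step on the negative conditional ELBO, averaged per sample."""
    optimizer.zero_grad()
    x_hat, mu, logvar = model(x, y)
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss = (recon + beta * kl) / x.size(0)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
# Synthetic stand-in for MNIST: each class is a fixed random binary "template"
templates = (torch.rand(10, 784) > 0.5).float()
y = torch.randint(0, 10, (64,))
x = templates[y]

model = TinyCVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = [cvae_train_step(model, optimizer, x, y) for _ in range(200)]
print(f"loss: {losses[0]:.1f} -> {losses[-1]:.1f}")

# Class-conditional generation: sample z from the prior N(0, I), fix the label to 7
model.eval()
with torch.no_grad():
    z = torch.randn(4, 8)
    labels = torch.full((4,), 7, dtype=torch.long)
    samples = model.dec(torch.cat([z, F.one_hot(labels, 10).float()], dim=1))
```

The only difference from the unconditional step is that the one-hot label is concatenated to both the encoder input and the decoder input. The KL term is unchanged because the prior does not depend on $y$ here.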
VAE vs GAN: Coverage vs Sharpness
VAEs and GANs both learn to generate data, but with fundamentally different objectives and failure modes:
| Property | VAE | GAN |
|---|---|---|
| Training objective | Maximize ELBO (reconstruction + KL) | Min-max game (generator vs discriminator) |
| Sample quality | Blurry but diverse | Sharp but potentially mode-dropping |
| Latent space | Structured, continuous, interpolatable | Often entangled, less structured |
| Training stability | Stable (single loss function) | Unstable (mode collapse, training oscillation) |
| Reconstruction | Yes - can encode and reconstruct inputs | No - generator has no encoder |
| Anomaly detection | Yes - measure reconstruction error | Difficult |
| Uncertainty | Yes - posterior distribution quantifies uncertainty | No |
| Best for | Structured latent space, interpolation, anomaly detection, generation with control | Maximum perceptual quality in images, video |
The blurriness of VAEs is a direct consequence of the reconstruction loss averaged over the posterior. The MSE or BCE loss penalizes deviations from the mean - if a VAE is uncertain between two slightly different outputs, it will produce their average (which is blurry). GANs avoid this by using a discriminator that evaluates the entire output holistically, enabling sharp predictions.
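The averaging argument can be checked on a toy problem. This is an illustrative assumption, not a VAE. There are two equally likely targets: the same one-pixel stroke, shifted by one pixel. A single prediction trained with MSE against both converges to their pixelwise mean, which is a "blurred" two-pixel stroke at half intensity. Under MSE, that blurred output scores better than committing to either sharp target.

```python
import torch

# Two equally likely targets: the same 1-pixel stroke, shifted by one pixel
targets = torch.tensor([[0.0, 1.0, 0.0, 0.0],
                        [0.0, 0.0, 1.0, 0.0]])

# One deterministic prediction must serve both targets, like a decoder
# that cannot tell which of two nearby outputs is the right one
pred = torch.zeros(4, requires_grad=True)
optimizer = torch.optim.SGD([pred], lr=0.5)
for _ in range(200):
    optimizer.zero_grad()
    loss = ((pred - targets) ** 2).mean()  # MSE averaged over both targets
    loss.backward()
    optimizer.step()


def mse_against_both(candidate):
    # Average MSE of one candidate output against both possible targets
    return ((candidate - targets) ** 2).mean().item()


print(pred.detach())                     # converges to [0, 0.5, 0.5, 0]
print(mse_against_both(targets[0]))      # sharp guess: 0.25
print(mse_against_both(pred.detach()))   # blurred average: ≈ 0.125
```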
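Modern approaches build on these ideas. VQGAN pairs a VQ-VAE-style discrete codebook with adversarial and perceptual losses and a transformer prior; the original DALL-E generated images as sequences of discrete-VAE tokens with a transformer; and latent diffusion models such as Stable Diffusion run diffusion in the compressed latent space of a regularized autoencoder. These hybrids aim for the best of both worlds: structured, controllable latent spaces with high visual quality.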
Production: VAE for Anomaly Detection
VAEs enable powerful unsupervised anomaly detection: train on normal data, then flag inputs with high reconstruction error (or high negative ELBO) as anomalies. Unlike density-based methods such as DBSCAN, whose distance thresholds become less informative as dimensionality grows, this approach works on high-dimensional data and can capture complex multi-modal normal distributions.
```python
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

# Assumes the VAE class and the vae_loss(x_hat, x, mu, logvar, beta) helper
# defined earlier in this lesson are in scope.


class VAEAnomalyDetector:
    """
    Unsupervised anomaly detector using VAE reconstruction error.
    Train on normal data. Anomalies = inputs with high reconstruction error.
    """

    def __init__(self, input_dim: int, latent_dim: int = 20,
                 hidden_dim: int = 256, beta: float = 1.0):
        self.vae = VAE(input_dim=input_dim,
                       latent_dim=latent_dim,
                       hidden_dim=hidden_dim)
        self.beta = beta
        self.threshold_ = None

    def fit(self, X_normal: np.ndarray,
            n_epochs: int = 30,
            percentile: float = 95.0,
            device: str = 'cpu'):
        """
        Train on normal data. Set anomaly threshold at a given percentile.

        Parameters
        ----------
        X_normal : (n, d) numpy array of normal observations, scaled to [0, 1]
                   (required by the BCE reconstruction loss)
        n_epochs : training epochs
        percentile : reconstruction error percentile to use as threshold
                     95 → flag top 5% of training scores as anomalous
        """
        # Convert to tensors
        X_tensor = torch.tensor(X_normal, dtype=torch.float32)
        dataset = torch.utils.data.TensorDataset(X_tensor)
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=128, shuffle=True
        )
        self.vae = self.vae.to(device)
        optimizer = torch.optim.Adam(self.vae.parameters(), lr=1e-3)

        # Train
        self.vae.train()
        for epoch in range(n_epochs):
            for (x,) in loader:
                x = x.to(device)
                optimizer.zero_grad()
                x_hat, mu, logvar = self.vae(x)
                losses = vae_loss(x_hat, x, mu, logvar, beta=self.beta)
                losses['total'].backward()
                optimizer.step()

        # Set threshold based on training data reconstruction errors
        training_scores = self._compute_scores(X_normal, device)
        self.threshold_ = np.percentile(training_scores, percentile)
        print(f"Anomaly threshold set at {self.threshold_:.4f} "
              f"({percentile:.0f}th percentile of training scores)")
        return self

    def _compute_scores(self, X: np.ndarray,
                        device: str = 'cpu',
                        n_samples: int = 10) -> np.ndarray:
        """
        Compute reconstruction error scores, averaged over n_samples forward passes.
        Averaging only reduces variance if the VAE's forward() samples z in eval
        mode. If it returns the posterior mean in eval mode (as the conditional
        VAE above does), every pass is identical and n_samples=1 is equivalent.
        """
        self.vae.eval()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        scores = []
        with torch.no_grad():
            for i in range(0, len(X), 256):  # batch to avoid OOM
                batch = X_tensor[i:i + 256]
                batch_scores = torch.zeros(len(batch))
                for _ in range(n_samples):
                    x_hat, mu, logvar = self.vae(batch)
                    # Per-sample reconstruction error (BCE summed over features)
                    recon = F.binary_cross_entropy(
                        x_hat, batch, reduction='none'
                    ).sum(dim=1)
                    batch_scores += recon.cpu() / n_samples
                scores.append(batch_scores.numpy())
        return np.concatenate(scores)

    def anomaly_score(self, X: np.ndarray,
                      device: str = 'cpu') -> np.ndarray:
        return self._compute_scores(X, device=device)

    def predict(self, X: np.ndarray,
                device: str = 'cpu') -> np.ndarray:
        """Returns 1 for anomaly, 0 for normal."""
        assert self.threshold_ is not None, "Call fit() first"
        scores = self.anomaly_score(X, device=device)
        return (scores > self.threshold_).astype(int)


# Example: anomaly detection on MNIST (treat class 9 as anomalies)
from torchvision import datasets

mnist_train = datasets.MNIST('./data', train=True, download=True)
mnist_test = datasets.MNIST('./data', train=False)

# Training: only digits 0-8 (normal), pixels scaled to [0, 1]
train_data = np.array([
    img.numpy().flatten() / 255.0
    for img, label in zip(mnist_train.data, mnist_train.targets)
    if label < 9
])

# Test: mix of 0-8 (normal) and 9 (anomaly)
test_normal = np.array([
    img.numpy().flatten() / 255.0
    for img, label in zip(mnist_test.data[:1000], mnist_test.targets[:1000])
    if label < 9
])[:200]
test_anomaly = np.array([
    img.numpy().flatten() / 255.0
    for img, label in zip(mnist_test.data, mnist_test.targets)
    if label == 9
])[:200]
X_test = np.vstack([test_normal, test_anomaly])
y_test = np.array([0] * len(test_normal) + [1] * len(test_anomaly))

detector = VAEAnomalyDetector(input_dim=784, latent_dim=20)
detector.fit(train_data[:5000], n_epochs=20, percentile=95)
scores = detector.anomaly_score(X_test)
auc = roc_auc_score(y_test, scores)
print(f"ROC-AUC for digit-9 anomaly detection: {auc:.3f}")
```
KL Annealing: Preventing Posterior Collapse
Posterior collapse is one of the most common VAE training failures: the KL term dominates early and the encoder learns $q_\phi(z|x) \approx p(z)$ for all inputs, becoming uninformative. The decoder then ignores $z$ entirely and generates a blurry mean image.
Why it happens: early in training, the decoder is random and the reconstruction loss is high regardless of $z$. The easiest way to reduce the total loss quickly is to collapse the KL term to zero: the encoder outputs the prior and the KL vanishes. Once the decoder learns to ignore $z$, it is hard to recover.
KL annealing: start with $\beta = 0$ (effectively a stochastic autoencoder with no KL penalty), then gradually increase $\beta$ to 1 over a number of warmup epochs. This lets the decoder first learn to use the latent code for reconstruction, making it harder to ignore later.
```python
import matplotlib.pyplot as plt

# Assumes VAE, train_vae, device, and `history` (the earlier run trained with a
# fixed β=1) from earlier in this lesson are in scope.

# Train with KL annealing to prevent posterior collapse
vae_annealed = VAE(input_dim=784, latent_dim=20, hidden_dim=400)
history_annealed = train_vae(
    vae_annealed,
    n_epochs=50,
    beta=1.0,
    kl_anneal=True,
    warmup_epochs=15,  # ramp β from 0 to 1 over first 15 epochs
    device=device
)

# Compare: standard training vs annealed training
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].plot(history['kl'], label='Standard (β=1 fixed)', lw=2)
axes[0].plot(history_annealed['kl'], label='KL Annealing', lw=2)
axes[0].set_title("KL Divergence During Training")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("KL divergence")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

axes[1].plot(history_annealed['beta'], lw=2, color='tomato')
axes[1].set_title("β Schedule (KL Annealing)")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("β value")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
```
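The code above relies on train_vae's built-in schedule, whose internals are not shown in this section. As an assumption about what a linear warmup looks like, here is a minimal standalone schedule; `kl_weight` is a hypothetical helper, not part of train_vae's API.

```python
def kl_weight(epoch: int, warmup_epochs: int, beta_max: float = 1.0) -> float:
    """Linear KL annealing: β rises from 0 to beta_max over warmup_epochs, then stays flat."""
    if warmup_epochs <= 0:
        return beta_max  # no warmup requested: use the full KL weight from the start
    return beta_max * min(1.0, epoch / warmup_epochs)

# Inside a training loop the per-batch loss would then be, for example:
#   loss = recon_loss + kl_weight(epoch, warmup_epochs=15) * kl_loss
```

A linear ramp is the simplest choice; cyclical or sigmoid-shaped schedules are common alternatives that follow the same idea of keeping β small while the decoder learns to use $z$.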
YouTube Resources
| Title | Channel | Why Watch |
|---|---|---|
| Variational Autoencoders - Arxiv Insights | Arxiv Insights | Best intuitive explanation of the ELBO and reparameterization trick |
| VAE from Scratch in PyTorch | AladdinPersson | Complete code walkthrough - good for interview prep |
| β-VAE and Disentangled Representations | DeepMind | Authors explain disentanglement, latent traversal, and evaluation |
| VQ-VAE Explained | Yannic Kilcher | Deep dive into discrete latent spaces, straight-through estimator |
| VAE vs GAN: Comparison | Serrano.Academy | Clear comparison of both generative approaches with visual intuition |
Common Mistakes
:::danger Using MSE Reconstruction Loss for Binary Data
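If your data has pixel values in {0, 1} or [0, 1] (images normalized to 0–1), use binary cross-entropy (F.binary_cross_entropy), not MSE. BCE is the exact negative log-likelihood of a Bernoulli decoder, while MSE implicitly assumes a Gaussian output distribution with fixed variance. In practice, MSE caps the per-pixel penalty at 1, and when the decoder ends in a sigmoid its gradient shrinks toward zero as the output saturates at the wrong value, so confident wrong predictions are barely corrected. BCE's gradient with respect to the logit is simply p - x, so those predictions are penalized strongly. In a VAE, a weak reconstruction term also lets the KL term dominate, which tends to make samples blurrier.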
:::
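A quick check of the gradient argument, assuming a decoder that outputs logits followed by a sigmoid (the values below are illustrative, not from a trained model). For a pixel that should be on but is predicted almost off, BCE pushes back hard while MSE barely moves.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([1.0])                      # the true pixel is "on"
logit = torch.tensor([-6.0], requires_grad=True)  # confidently wrong: sigmoid(-6) ≈ 0.0025

# BCE on logits: gradient w.r.t. the logit is sigmoid(logit) - target
bce = F.binary_cross_entropy_with_logits(logit, target)
bce_grad, = torch.autograd.grad(bce, logit)

# MSE on the sigmoid output: gradient is scaled by sigmoid'(logit), which is tiny here
mse = F.mse_loss(torch.sigmoid(logit), target)
mse_grad, = torch.autograd.grad(mse, logit)

print(bce_grad.item(), mse_grad.item())  # BCE gradient is roughly 200x larger in magnitude
```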
:::danger Forgetting the Straight-Through Estimator in VQ-VAE
In VQ-VAE, the quantization operation z_q = codebook[argmin(distances)] is not differentiable - the argmin has zero gradient almost everywhere. If you naively backpropagate, the encoder gets no gradient from the reconstruction loss and cannot learn. The straight-through estimator z_q = z_e + (z_q - z_e).detach() passes gradients from the decoder through z_q as if it equals z_e, while the VQ loss trains the codebook separately. Forgetting this means the encoder never learns, even if the decoder loss decreases.
:::
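A minimal sketch of a quantization step with the straight-through estimator, assuming flattened encoder outputs of shape (N, D) and a learnable codebook of shape (K, D). `vector_quantize` is a hypothetical helper, and the commitment weight 0.25 is an illustrative default rather than a required value.

```python
import torch
import torch.nn.functional as F

def vector_quantize(z_e: torch.Tensor, codebook: torch.Tensor, beta: float = 0.25):
    """z_e: (N, D) encoder outputs; codebook: (K, D) learnable code vectors."""
    distances = torch.cdist(z_e, codebook)  # (N, K) Euclidean distances
    idx = distances.argmin(dim=1)           # nearest code per vector (non-differentiable)
    z_q = codebook[idx]                     # (N, D) quantized vectors

    codebook_loss = F.mse_loss(z_q, z_e.detach())    # moves the codes toward encoder outputs
    commitment_loss = F.mse_loss(z_e, z_q.detach())  # keeps the encoder close to its codes

    # Straight-through: forward value equals z_q, gradient w.r.t. z_e is the identity
    z_q_st = z_e + (z_q - z_e).detach()
    return z_q_st, codebook_loss + beta * commitment_loss, idx
```

The decoder receives `z_q_st`, so its loss reaches the encoder through `z_e` but never reaches the codebook; the codebook is updated only by the returned VQ loss.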
:::warning KL Term Goes to Zero (Posterior Collapse)
If your KL divergence drops to near zero and stays there after a few epochs, posterior collapse has occurred. The encoder outputs the prior regardless of input; the decoder ignores the latent code. Diagnostics: check whether the encoder means $\mu$ are near zero and the variances $\sigma^2$ are near 1 for all inputs (posterior = prior). Fix: use KL annealing (kl_anneal=True), reduce decoder capacity, reduce learning rate, or use "free bits" - set a minimum KL per latent dimension (e.g., 0.5 nats) so the encoder is forced to use at least some information.
:::
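A minimal sketch of free bits, assuming the diagonal-Gaussian encoder used throughout this lesson. `kl_free_bits` and its `free_bits` argument are hypothetical names; this version clamps the batch-averaged KL of each latent dimension from below, which is one common variant of the idea.

```python
import torch

def kl_free_bits(mu: torch.Tensor, logvar: torch.Tensor, free_bits: float = 0.5) -> torch.Tensor:
    """Sum over dimensions of max(free_bits, batch-mean KL_j) for q = N(mu, exp(logvar)), p = N(0, I)."""
    # Closed-form KL per sample and per latent dimension, shape (batch, latent_dim)
    kl_per_dim = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar)
    kl_mean = kl_per_dim.mean(dim=0)  # average over the batch -> (latent_dim,)
    # Below the floor the clamp is constant, so those dimensions get no KL gradient
    return torch.clamp(kl_mean, min=free_bits).sum()
```

Because the clamp removes the KL pressure on any dimension already under the floor, the optimizer has no incentive to push informative dimensions all the way down to the prior.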
:::warning Very Large β Destroys Reconstruction Quality
Setting β too high in β-VAE starves the reconstruction term: the KL penalty dominates and forces the encoder to output something close to the prior for all inputs. The result: all inputs map to similar latent codes, and the decoder has nothing useful to decode. Start with β ∈ {2, 4, 6} and evaluate reconstruction quality visually before pushing to β = 10 or higher. Monitor both reconstruction loss and KL during training: the KL should decrease somewhat but stay clearly above zero, with at least some latent dimensions still carrying information.
:::
Interview Q&A
Q1: What is the reparameterization trick and why is it necessary in VAEs?
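To train a VAE we need gradients of the loss with respect to the encoder parameters $\phi$ - specifically, we need $\nabla_\phi \, \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)]$. The expectation depends on $\phi$ through the sampling distribution, and a random sampling step is not a differentiable function of $\phi$, so backpropagation cannot pass through it. The reparameterization trick separates the randomness from the parameters: instead of sampling $z \sim q_\phi(z|x) = \mathcal{N}(\mu_\phi(x), \mathrm{diag}(\sigma^2_\phi(x)))$ directly, we sample $\epsilon \sim \mathcal{N}(0, I)$ independently and compute $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$. Now $z$ is a deterministic, differentiable function of $\phi$ plus a fixed noise $\epsilon$. The gradient $\nabla_\phi z$ is well-defined, and standard backpropagation flows through $\mu$ and $\sigma$ to the encoder weights. Without this trick, we could only use score-function (REINFORCE) estimators, which typically have much higher variance and converge far more slowly.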
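A small numerical check of the trick, using the toy objective $\mathbb{E}[z^2]$ with $z \sim \mathcal{N}(\mu, \sigma^2)$ (an illustrative choice, not a VAE loss). Analytically $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so $\partial/\partial\mu = 2\mu$ and $\partial/\partial \log\sigma = 2\sigma^2$; the reparameterized Monte Carlo gradients should land close to those values.

```python
import torch

torch.manual_seed(0)
mu = torch.tensor(1.5, requires_grad=True)
log_sigma = torch.tensor(-0.5, requires_grad=True)

eps = torch.randn(200_000)               # noise drawn independently of the parameters
z = mu + torch.exp(log_sigma) * eps      # reparameterized samples: differentiable in mu, log_sigma
loss = (z ** 2).mean()                   # Monte Carlo estimate of E[z^2]
loss.backward()

sigma_sq = torch.exp(2 * log_sigma).item()
print(mu.grad.item(), 2 * mu.item())            # MC gradient vs analytic 2*mu
print(log_sigma.grad.item(), 2 * sigma_sq)      # MC gradient vs analytic 2*sigma^2
```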
Q2: Explain the ELBO - what each term does and why we maximize it instead of the log-likelihood.
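The ELBO (Evidence Lower Bound) is $\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - D_{KL}(q_\phi(z|x) \,\|\, p(z))$. We maximize it instead of $\log p_\theta(x)$ because $p_\theta(x) = \int p_\theta(x|z)\,p(z)\,dz$ is intractable - computing it requires integrating over all possible latent codes $z$, which has no closed form for neural network decoders. The ELBO is always a lower bound on $\log p_\theta(x)$, with the gap being exactly $D_{KL}(q_\phi(z|x) \,\|\, p_\theta(z|x))$, i.e. $\log p_\theta(x) = \mathcal{L}(\theta, \phi; x) + D_{KL}(q_\phi(z|x) \,\|\, p_\theta(z|x))$. Because $\log p_\theta(x)$ does not depend on $\phi$, maximizing the ELBO over $\phi$ is the same as shrinking this gap, pulling the encoder toward the true posterior. Maximizing the ELBO simultaneously: (1) maximizes the reconstruction term - the decoder is trained to reconstruct $x$ from samples $z \sim q_\phi(z|x)$ of the approximate posterior; (2) minimizes the KL term - the encoder's approximate posterior is regularized toward the standard Gaussian prior, creating a structured, continuous latent space where the prior covers meaningful regions.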
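The identity $\log p_\theta(x) = \mathcal{L} + D_{KL}(q_\phi(z|x) \,\|\, p_\theta(z|x))$ can be checked exactly on a toy linear-Gaussian model where everything has a closed form. The model below ($p(z) = \mathcal{N}(0, 1)$, $p(x|z) = \mathcal{N}(z, s^2)$) and the approximate posterior $q = \mathcal{N}(m, v)$ are illustrative assumptions, not the lesson's neural VAE.

```python
import math

def log_normal(x: float, mean: float, var: float) -> float:
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def kl_gauss(m1: float, v1: float, m2: float, v2: float) -> float:
    # KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo(x: float, m: float, v: float, s2: float) -> float:
    # q(z|x) = N(m, v), prior p(z) = N(0, 1), likelihood p(x|z) = N(z, s2)
    # E_q[(x - z)^2] = (x - m)^2 + v, so the expected log-likelihood is closed form
    expected_loglik = -0.5 * math.log(2 * math.pi * s2) - ((x - m) ** 2 + v) / (2 * s2)
    return expected_loglik - kl_gauss(m, v, 0.0, 1.0)

x, s2 = 1.2, 0.5
log_px = log_normal(x, 0.0, 1.0 + s2)              # exact evidence: x ~ N(0, 1 + s2)
post_m, post_v = x / (1.0 + s2), s2 / (1.0 + s2)    # exact posterior p(z|x)

m, v = 0.3, 0.8                                     # an arbitrary, imperfect q(z|x)
gap = log_px - elbo(x, m, v, s2)
print(gap, kl_gauss(m, v, post_m, post_v))          # the two numbers agree
```

Setting `m, v = post_m, post_v` makes the gap vanish, which is the sense in which a perfect encoder makes the bound tight.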
Q3: What is posterior collapse in VAEs, how do you detect it, and how do you fix it?
Posterior collapse occurs when the KL term goes to zero during training - the encoder outputs $q_\phi(z|x) \approx p(z) = \mathcal{N}(0, I)$ for all inputs, making the latent code uninformative. The decoder then learns to ignore $z$ and generates a blurry mean image regardless of the latent code. Detection: monitor per-epoch KL divergence - if it drops close to zero and stays low while reconstruction loss plateaus, you have posterior collapse. Also check: the encoder's mean $\mu$ should be non-zero and vary across inputs, and the variance $\sigma^2$ should be well below 1 (the prior variance) in the dimensions the model actually uses. Fixes: (1) KL annealing - start with $\beta = 0$ and ramp to 1 over warmup epochs; (2) free bits - set a minimum KL per dimension (e.g., 0.5 nats) by optimizing $\sum_j \max(\lambda, \mathrm{KL}_j)$, where $\lambda$ is the minimum and $\mathrm{KL}_j$ is the KL of latent dimension $j$; (3) reduce decoder capacity - force the decoder to need the latent code by making it shallower; (4) use a more expressive posterior (e.g., normalizing flows) that can better approximate the true posterior.
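One quantitative version of the "vary across inputs" check is to count active latent dimensions: those whose posterior mean changes across inputs. The sketch assumes you have already stacked encoder means `mu` for a held-out set; the 0.01 variance threshold follows the "active units" heuristic from Burda et al. (2016) and is a convention, not a hard rule. `count_active_units` is a hypothetical helper name.

```python
import torch

def count_active_units(mu: torch.Tensor, threshold: float = 1e-2) -> int:
    """mu: (n_inputs, latent_dim) encoder means over a held-out set.

    A dimension counts as active if its mean varies across inputs by more than
    `threshold` (variance over the dataset). Collapsed dims sit at the prior mean.
    """
    per_dim_variance = mu.var(dim=0)
    return int((per_dim_variance > threshold).sum())
```

A fully collapsed model reports zero active units; a healthy one typically uses several, even when many of its dimensions are inactive.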
Q4: How does a VAE differ from a standard autoencoder in terms of what it can do?
A standard autoencoder maps each input to a single point in latent space. The latent space geometry is unconstrained - the encoder places codes wherever minimizes reconstruction error, leaving the space between points unmapped. Sampling a random point from latent space usually produces garbage. A VAE maps each input to a distribution (typically a Gaussian) and regularizes the entire space to match the standard Gaussian prior via the KL term. The consequences: (1) generation - you can sample $z \sim \mathcal{N}(0, I)$ and decode it to get valid outputs, because the KL term ensures the prior region is covered; (2) smooth interpolation - the latent space is continuous, so interpolating between two codes produces semantically smooth transitions; (3) uncertainty quantification - the encoder's variance measures how ambiguous the encoding of a given input is. The trade-off is slightly worse reconstruction compared to an AE of equal capacity, because the KL regularization limits the information the latent code can carry.
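A minimal sketch of the interpolation path used to probe smoothness. `interpolate_latents` is a hypothetical helper; in practice `z_start` and `z_end` would be the encoder means of two inputs, and the resulting batch would be passed to the VAE's decoder (assumed here to expose a `decode` method like the conditional VAE above).

```python
import torch

def interpolate_latents(z_start: torch.Tensor, z_end: torch.Tensor, n_steps: int = 8) -> torch.Tensor:
    """Linear path between two latent codes of shape (latent_dim,); returns (n_steps, latent_dim)."""
    alphas = torch.linspace(0.0, 1.0, n_steps).unsqueeze(1)  # (n_steps, 1) mixing weights
    return (1 - alphas) * z_start.unsqueeze(0) + alphas * z_end.unsqueeze(0)

# Usage sketch (hypothetical names):
#   path = interpolate_latents(mu_a, mu_b, n_steps=10)
#   frames = vae.decode(path)
```

Linear interpolation is the simplest choice; spherical interpolation (slerp) is a common alternative for high-dimensional Gaussian latents, where straight lines pass through lower-density regions near the origin.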
Q5: What is β-VAE and when would you use it versus a standard VAE?
β-VAE multiplies the KL divergence term by a constant $\beta$: $\mathcal{L}_\beta = \mathbb{E}_{q_\phi(z|x)}[\log p_\theta(x|z)] - \beta\, D_{KL}(q_\phi(z|x) \,\|\, p(z))$. A larger $\beta > 1$ creates a stronger information bottleneck - the encoder must compress input information into fewer, more independent latent dimensions. When $\beta$ is large enough, the latent dimensions tend toward disentanglement: each dimension captures one independent factor of variation (e.g., separate dimensions for face orientation, lighting, expression, and identity in a face dataset). The cost is reduced reconstruction quality, since the stronger KL constraint limits the information the latent code can carry. Use β-VAE when the primary goal is interpretable, disentangled representations for: latent space visualization, attribute-controlled generation (generate a face with the same identity but different lighting), downstream classification using latent features, or understanding the factors of variation in a dataset. Use standard VAE (β=1) when reconstruction quality or anomaly detection performance is the primary concern.
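Disentanglement is usually inspected with a latent traversal: hold one code fixed, sweep a single dimension, and decode the batch to see which factor changes. The helper below is a hypothetical sketch; the sweep range of roughly ±3 standard deviations is a common convention for a $\mathcal{N}(0, I)$ prior, not a requirement.

```python
import torch

def latent_traversal(z: torch.Tensor, dim: int, values) -> torch.Tensor:
    """Copy base code z (latent_dim,) once per value and overwrite a single dimension.

    Returns (len(values), latent_dim); decode it to visualise what `dim` controls.
    """
    values = torch.as_tensor(values, dtype=z.dtype)
    batch = z.unsqueeze(0).repeat(len(values), 1)  # independent copies of the base code
    batch[:, dim] = values                          # vary only the chosen dimension
    return batch

# Usage sketch (hypothetical names): vae.decode(latent_traversal(mu_x, dim=3,
#                                   values=torch.linspace(-3, 3, 7)))
```

In a well-disentangled β-VAE, each row of decoded traversals should change one attribute (say, rotation) while the others stay fixed.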
Q6: Explain the VQ-VAE's straight-through estimator. Why is it needed and how does it work?
VQ-VAE uses vector quantization: the encoder output $z_e(x)$ is replaced by the nearest codebook vector $z_q = e_k$, where $k = \arg\min_j \|z_e(x) - e_j\|_2$. The argmin operation is non-differentiable - it is piecewise constant, so its gradient is zero almost everywhere and undefined at the decision boundaries, making standard backpropagation to the encoder impossible. The straight-through estimator solves this by substituting an identity gradient in the backward pass: in the forward pass, use $z_q$ (the quantized code) as intended; in the backward pass, pass gradients to $z_e$ as if $z_q = z_e$, by writing z_q_st = z_e + (z_q - z_e).detach(). The .detach() stops gradients from flowing through $(z_q - z_e)$, but $z_e$ is in the computation graph, so gradients from the decoder loss propagate back through $z_e$ to the encoder. The codebook vectors are trained separately via the VQ loss $\|\mathrm{sg}[z_e] - e_k\|_2^2$ (moves the codebook toward the encoder outputs), while the commitment loss $\beta\,\|z_e - \mathrm{sg}[e_k]\|_2^2$ keeps the encoder close to its chosen codes; $\mathrm{sg}$ denotes stop-gradient.
Q7: How would you use a VAE for anomaly detection in a manufacturing system?
The setup: collect a large dataset of normal manufacturing sensor readings during steady-state operation. Train a VAE on this data with an appropriate β and latent dimension. At inference, run each new sensor reading through the VAE and compute its negative ELBO (reconstruction loss + KL). For standardized, continuous sensor channels, a Gaussian (MSE-style) reconstruction term fits better than BCE, which assumes values in [0, 1]. Normal readings should have a low negative ELBO; anomalies (equipment faults, unusual operating conditions) tend to have high reconstruction error because the VAE was never trained to reconstruct them. Set a threshold at the 95th or 99th percentile of the training scores and tune it based on acceptable false-positive rates. Advantages over DBSCAN: works in high dimensions (100+ sensor channels), handles continuous smooth anomalies rather than requiring density-gap separation, and naturally integrates uncertainty (the latent variance is high for ambiguous readings). Key considerations: the VAE should only be trained on truly normal data - if anomalies appear in training, the model learns to reconstruct them and won't flag them as anomalous; monitor the reconstruction error distribution over time and retrain when the normal distribution drifts.
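A minimal scoring sketch for continuous sensor data, assuming standardized features and a VAE whose forward pass returns `x_hat, mu, logvar` like the ones in this lesson. The function names are hypothetical; the reconstruction term is the unit-variance Gaussian negative log-likelihood with its additive constant dropped.

```python
import numpy as np
import torch

def negative_elbo_gaussian(x: torch.Tensor, x_hat: torch.Tensor,
                           mu: torch.Tensor, logvar: torch.Tensor,
                           beta: float = 1.0) -> torch.Tensor:
    """Per-sample negative ELBO, shape (batch,), for standardized continuous features."""
    recon = 0.5 * ((x - x_hat) ** 2).sum(dim=1)                        # Gaussian NLL up to a constant
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=1)    # KL(q(z|x) || N(0, I))
    return recon + beta * kl

def fit_threshold(train_scores: np.ndarray, percentile: float = 99.0) -> float:
    """Alarm threshold from scores on normal training data."""
    return float(np.percentile(train_scores, percentile))

def flag_anomalies(scores: np.ndarray, threshold: float) -> np.ndarray:
    return (scores > threshold).astype(int)  # 1 = anomaly, 0 = normal
```

Raising the percentile from 95 to 99 trades missed faults for fewer false alarms; the right setting depends on the cost of each on the production line.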
