KL Divergence
The Scenario That Motivates This Lesson
You are training a Variational Autoencoder (VAE). The loss function has two terms: a reconstruction loss and a "KL loss." Your training crashes - the KL loss explodes to thousands while reconstruction is still poor.
Your colleague says: "The KL term is swamping the objective - just reduce the beta coefficient." You nod, but do you actually understand why the KL divergence appears in the VAE objective, what it measures, and why it can explode?
This lesson answers those questions. KL divergence is a fundamental measure of how much one probability distribution differs from another, and it appears throughout modern ML.
Definition: KL Divergence
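The Kullback-Leibler divergence of a distribution $P$ from a distribution $Q$ is:
$$D_{KL}(P \| Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}$$
For continuous distributions:
$$D_{KL}(P \| Q) = \int p(x) \log \frac{p(x)}{q(x)} \, dx$$
Read as: "The KL divergence of $P$ from $Q$" or "the KL divergence when using $Q$ to approximate $P$."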
Alternative Expressions
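As an expectation under $P$:
$$D_{KL}(P \| Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right] = \mathbb{E}_{x \sim P}[\log P(x)] - \mathbb{E}_{x \sim P}[\log Q(x)]$$
Or equivalently using entropy:
$$D_{KL}(P \| Q) = H(P, Q) - H(P)$$
where $H(P, Q) = -\sum_x P(x) \log Q(x)$ is the cross-entropy. This identity - $H(P, Q) = H(P) + D_{KL}(P \| Q)$ - is one of the most important equations in this entire module.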
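The identity between cross-entropy, entropy, and KL divergence can be checked numerically. The sketch below uses two made-up three-outcome distributions (the values are illustrative assumptions, not taken from any dataset) and natural logarithms, so all quantities are in nats.

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])   # "true" distribution P (illustrative values)
q = np.array([0.5, 0.3, 0.2])   # approximating distribution Q (illustrative values)

entropy_p = -np.sum(p * np.log(p))        # H(P)
cross_entropy = -np.sum(p * np.log(q))    # H(P, Q)
kl_pq = np.sum(p * np.log(p / q))         # D_KL(P || Q)

print(f"H(P)             = {entropy_p:.4f} nats")
print(f"H(P, Q)          = {cross_entropy:.4f} nats")
print(f"D_KL(P || Q)     = {kl_pq:.4f} nats")
print(f"H(P) + D_KL(P||Q) = {entropy_p + kl_pq:.4f} nats")
```

The last two printed cross-entropy values agree, which is why minimizing cross-entropy against a fixed $P$ is the same as minimizing the KL divergence: $H(P)$ does not depend on $Q$.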
:::note By Convention: 0 log(0/0) = 0 and 0 log(0/q) = 0
If $P(x) = 0$, the term contributes 0 regardless of $Q(x)$. But if $P(x) > 0$ and $Q(x) = 0$, the KL divergence is $+\infty$. This means: Q must have support everywhere P does, or KL divergence is infinite.
:::
Key Properties
Non-Negativity (Gibbs' Inequality)
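$$D_{KL}(P \| Q) \geq 0, \quad \text{with equality if and only if } P = Q$$
Proof by Jensen's inequality:
Since $-\log$ is convex, by Jensen's inequality:
$$D_{KL}(P \| Q) = \mathbb{E}_{P}\left[-\log \frac{Q(x)}{P(x)}\right] \geq -\log \mathbb{E}_{P}\left[\frac{Q(x)}{P(x)}\right] = -\log \sum_{x:\,P(x)>0} Q(x) \geq -\log 1 = 0$$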
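As a quick sanity check of non-negativity, the following sketch draws random pairs of distributions and confirms that the KL divergence never goes below zero. The helper name `kl`, the seed, and the sample sizes are arbitrary choices made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

def kl(p, q):
    """D_KL(P || Q) in nats for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

# Dirichlet samples are positive and sum to 1, so each one is a valid distribution
pairs = [(rng.dirichlet(np.ones(5)), rng.dirichlet(np.ones(5))) for _ in range(1000)]
kl_values = np.array([kl(p, q) for p, q in pairs])

print(f"smallest KL over 1000 random pairs: {kl_values.min():.6f}")
print(f"KL(P || P) for the first sample:    {kl(pairs[0][0], pairs[0][0]):.6f}")
```

A numerical check like this is not a proof, but it is a useful guard when writing your own KL routine: a negative result almost always means the inputs were not normalized.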
Asymmetry: Why KL Is Not a Distance
This is the most important property to internalize. KL divergence is not a metric - it does not satisfy the symmetry axiom, and it does not satisfy the triangle inequality.
Asymmetry Example:
P = [0.9, 0.1] (concentrated on outcome 1)
Q = [0.1, 0.9] (concentrated on outcome 2)

D_KL(P||Q) = 0.9 * log(0.9/0.1) + 0.1 * log(0.1/0.9)
           = 0.9 * log(9) + 0.1 * log(1/9)
           = 0.9 * 2.197 - 0.1 * 2.197
           = 1.758 nats

D_KL(Q||P) = 0.1 * log(0.1/0.9) + 0.9 * log(0.9/0.1)
           = 0.1 * (-2.197) + 0.9 * 2.197
           = 1.758 nats

(The two values coincide here only because Q is P with its two outcomes swapped, which makes the two sums identical term by term.
In general D_KL(P||Q) ≠ D_KL(Q||P).)
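A pair that is not a simple relabeling makes the asymmetry visible. The sketch below compares a peaked distribution with a uniform one; the helper `kl` and the chosen values are illustrative assumptions.

```python
import numpy as np

def kl(p, q):
    """D_KL(P || Q) in nats for strictly positive discrete distributions."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

p = [0.9, 0.1]   # peaked
q = [0.5, 0.5]   # uniform

print(f"D_KL(P||Q) = {kl(p, q):.4f} nats")   # 0.3681
print(f"D_KL(Q||P) = {kl(q, p):.4f} nats")   # 0.5108
```

Approximating the uniform distribution with the peaked one costs more than the other way round, because the peaked distribution assigns only 0.1 to an outcome that occurs half the time.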
No Triangle Inequality
$$D_{KL}(P \| R) \not\leq D_{KL}(P \| Q) + D_{KL}(Q \| R)$$
in general.
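A concrete counterexample helps here. The three distributions below were picked by hand for this illustration; going "directly" from P to R costs more than the two-step route through Q, which a true metric would never allow.

```python
import numpy as np

def kl(p, q):
    """D_KL(P || Q) in nats for strictly positive discrete distributions."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

P = [0.5, 0.5]
Q = [0.25, 0.75]
R = [0.1, 0.9]

direct = kl(P, R)
via_q = kl(P, Q) + kl(Q, R)

print(f"D_KL(P||R)              = {direct:.4f} nats")
print(f"D_KL(P||Q) + D_KL(Q||R) = {via_q:.4f} nats")
print(f"triangle inequality violated: {direct > via_q}")
```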
Forward KL vs. Reverse KL: Geometric Intuition
This asymmetry has profound consequences for variational inference. There are two ways to use KL divergence to fit an approximate distribution $Q$ to a target $P$:
Forward KL: $D_{KL}(P \| Q)$ - "Mean-Seeking"
Minimize the expectation under $P$:
$$\min_Q D_{KL}(P \| Q) = \min_Q \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]$$
The loss is large when $Q(x)$ is small but $P(x)$ is large. This forces $Q$ to cover all modes of $P$. If $P$ is bimodal, $Q$ must place probability mass at both modes.
When $P$ has a mode where $P(x) > 0$ but $Q(x) \to 0$, the term $P(x) \log \frac{P(x)}{Q(x)} \to \infty$. So $Q$ is penalized severely for ignoring any region where $P$ has mass.
Result: $Q$ tends to be diffuse, covering all modes of $P$ (mode-covering behavior).
Reverse KL: $D_{KL}(Q \| P)$ - "Mode-Seeking"
Minimize the expectation under $Q$:
$$\min_Q D_{KL}(Q \| P) = \min_Q \mathbb{E}_{x \sim Q}\left[\log \frac{Q(x)}{P(x)}\right]$$
The loss is large when $Q(x)$ is large but $P(x)$ is small. This forces $Q$ to not place mass where $P$ has none. But when $Q(x) = 0$, the term contributes 0 regardless of $P(x)$.
So $Q$ is free to ignore modes of $P$ (by setting $Q(x) = 0$ there), as long as it concentrates on a region where $P(x) > 0$.
Result: $Q$ tends to be a sharp approximation concentrated on one mode of $P$ (mode-seeking behavior).
Forward KL (P||Q): Q must cover all modes of P
    True P (bimodal):    |  *       *  |
    Q fitted by fwd KL:  |  *********  |   ← spreads across both modes
                         +-------------+
Reverse KL (Q||P): Q can pick one mode and ignore the rest
    True P (bimodal):    |  *       *  |
    Q fitted by rev KL:  | ***         |   ← concentrates on one mode
                         +-------------+
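The mode-covering versus mode-seeking behavior can be reproduced with a brute-force experiment. The sketch below fits a single Gaussian to a bimodal target by grid search, once under forward KL and once under reverse KL. Everything here is an assumption made for illustration: the target (an equal mixture of N(-3, 1) and N(3, 1)), the integration grid, the search ranges, and the helper names `log_normal_pdf`, `forward_kl`, and `reverse_kl`.

```python
import numpy as np

# Integration grid; integrals are approximated by Riemann sums
x = np.linspace(-20.0, 20.0, 4001)
dx = x[1] - x[0]

def log_normal_pdf(x, mu, sigma):
    """Log density of N(mu, sigma^2)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))

# Target P: equal mixture of N(-3, 1) and N(3, 1), built in log space for stability
log_p = np.logaddexp(np.log(0.5) + log_normal_pdf(x, -3.0, 1.0),
                     np.log(0.5) + log_normal_pdf(x, 3.0, 1.0))
p = np.exp(log_p)

def forward_kl(mu, sigma):
    """Approximate D_KL(P || Q) for Q = N(mu, sigma^2)."""
    log_q = log_normal_pdf(x, mu, sigma)
    return float(np.sum(p * (log_p - log_q)) * dx)

def reverse_kl(mu, sigma):
    """Approximate D_KL(Q || P) for Q = N(mu, sigma^2)."""
    log_q = log_normal_pdf(x, mu, sigma)
    return float(np.sum(np.exp(log_q) * (log_q - log_p)) * dx)

# Exhaustive search over a coarse (mu, sigma) grid
mus = np.linspace(-4.0, 4.0, 81)
sigmas = np.linspace(0.5, 4.0, 36)
grid = [(m, s) for m in mus for s in sigmas]

fwd_mu, fwd_sigma = min(grid, key=lambda ms: forward_kl(*ms))
rev_mu, rev_sigma = min(grid, key=lambda ms: reverse_kl(*ms))

print(f"forward KL fit: mu={fwd_mu:+.2f}, sigma={fwd_sigma:.2f}")
print(f"reverse KL fit: mu={rev_mu:+.2f}, sigma={rev_sigma:.2f}")
```

Under these assumptions the forward-KL fit sits between the modes with a large standard deviation (close to the mixture's overall standard deviation of about 3.16), while the reverse-KL fit lands on one of the two modes with a standard deviation near 1. Which mode the reverse fit picks is arbitrary, since the target is symmetric.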
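:::info ML Connection - VAEs Use Reverse KL
The KL term in the VAE objective is $D_{KL}(q_\phi(z|x) \| p(z))$, with the approximate posterior $q_\phi$ in the first argument, so the expectation is taken under the encoder's own distribution - the same ordering as the reverse KL above. Maximizing the ELBO is also equivalent to minimizing the reverse KL $D_{KL}(q_\phi(z|x) \| p_\theta(z|x))$ to the true posterior, which is why a simple Gaussian encoder tends to lock onto one mode of a multimodal true posterior rather than cover all of it. Posterior collapse is a related but distinct failure: when the KL term dominates, $q_\phi(z|x)$ shrinks to the prior for every $x$, the KL term goes to zero, and the decoder learns to ignore $z$.
:::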
KL Divergence Between Gaussians
A particularly important closed-form result used everywhere in deep learning:
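For two univariate Gaussians $P = \mathcal{N}(\mu_1, \sigma_1^2)$ and $Q = \mathcal{N}(\mu_2, \sigma_2^2)$:
$$D_{KL}(P \| Q) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
Special case for VAEs - KL divergence of $\mathcal{N}(\mu, \sigma^2)$ from the standard normal $\mathcal{N}(0, 1)$:
$$D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \frac{1}{2}\left(\mu^2 + \sigma^2 - \log \sigma^2 - 1\right)$$
For a diagonal multivariate Gaussian of dimension $d$:
$$D_{KL}\big(\mathcal{N}(\boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2)) \,\|\, \mathcal{N}(\mathbf{0}, I)\big) = \frac{1}{2} \sum_{j=1}^{d} \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)$$
This is the formula used in standard VAE implementations that pair a diagonal Gaussian encoder with a standard normal prior.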
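The VAE special case is a direct substitution into the general univariate result. As a short derivation sketch, set $\mu_1 = \mu$, $\sigma_1 = \sigma$, $\mu_2 = 0$, $\sigma_2 = 1$ in $\log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$:

$$
\begin{aligned}
D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big)
&= \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2} \\
&= -\tfrac{1}{2} \log \sigma^2 + \tfrac{1}{2} \sigma^2 + \tfrac{1}{2} \mu^2 - \tfrac{1}{2} \\
&= \tfrac{1}{2} \left(\mu^2 + \sigma^2 - \log \sigma^2 - 1\right)
\end{aligned}
$$

The second line uses $\log \frac{1}{\sigma} = -\tfrac{1}{2} \log \sigma^2$, which is why implementations that predict $\log \sigma^2$ directly can plug it in without taking a square root. The diagonal multivariate case is the sum of this expression over dimensions, because the KL divergence between product distributions is the sum of the per-dimension divergences.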
Python: Computing KL Divergence
import numpy as np


def kl_divergence_discrete(p: np.ndarray, q: np.ndarray) -> float:
    """
    KL divergence D_KL(P||Q) for discrete distributions.

    Args:
        p: true distribution (reference)
        q: approximate distribution (model)
    Returns:
        KL divergence in nats
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Ensure valid distributions
    assert np.isclose(p.sum(), 1.0) and np.isclose(q.sum(), 1.0)
    # Only sum where p > 0 (0 * log(0/q) = 0 by convention)
    mask = p > 0
    if np.any(q[mask] == 0):
        return float('inf')  # KL is infinite if q is 0 where p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def kl_divergence_gaussian(mu1, sigma1, mu2=0.0, sigma2=1.0) -> float:
    """
    KL divergence D_KL(N(mu1,sigma1^2) || N(mu2,sigma2^2)).
    Default Q = N(0,1) for VAE KL term.
    """
    return (
        np.log(sigma2 / sigma1)
        + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2)
        - 0.5
    )


def vae_kl_loss(mu: np.ndarray, log_var: np.ndarray) -> float:
    """
    KL divergence for VAE: D_KL(N(mu, sigma^2) || N(0, I)).
    Input log_var = log(sigma^2) (standard VAE parameterization).
    """
    # Per-dimension: 0.5 * (mu^2 + sigma^2 - log_sigma^2 - 1)
    sigma_sq = np.exp(log_var)
    per_dim = 0.5 * (mu**2 + sigma_sq - log_var - 1)
    return float(np.sum(per_dim))  # sum over latent dimensions
# --- Demo 1: Asymmetry of KL divergence ---
print("=== Asymmetry of KL Divergence ===")
distributions = [
    ([0.9, 0.1], [0.5, 0.5], "Peaked P, Uniform Q"),
    ([0.5, 0.5], [0.9, 0.1], "Uniform P, Peaked Q"),
    ([0.7, 0.2, 0.1], [0.1, 0.2, 0.7], "Reversed peaks"),
    ([0.99, 0.01], [0.5, 0.5], "Very peaked P"),
]
print(f"{'Description':<30} | {'D_KL(P||Q)':>12} | {'D_KL(Q||P)':>12}")
print("-" * 60)
for p, q, desc in distributions:
    kl_fwd = kl_divergence_discrete(p, q)
    kl_rev = kl_divergence_discrete(q, p)
    print(f"{desc:<30} | {kl_fwd:>12.4f} | {kl_rev:>12.4f}")

# --- Demo 2: KL divergence goes infinite ---
p_has_support = [0.7, 0.2, 0.1]
q_missing_support = [0.8, 0.2, 0.0]  # q=0 where p>0!
kl = kl_divergence_discrete(p_has_support, q_missing_support)
print(f"\nD_KL(P||Q) when Q has zero where P>0: {kl}")
# Output: inf

# --- Demo 3: Gaussian KL for VAE ---
print("\n=== VAE KL Loss (D_KL(q||N(0,1))) ===")
test_cases = [
    (0.0, 0.0, "Perfect match: mu=0, log_var=0 (sigma=1)"),
    (1.0, 0.0, "Shifted: mu=1, sigma=1"),
    (0.0, 1.0, "Wider: mu=0, sigma=sqrt(e)≈1.65"),
    (2.0, 1.0, "Shifted + wider: mu=2, sigma=sqrt(e)"),
    (0.0, -2.0, "Narrow: mu=0, sigma=exp(-1)≈0.37"),
]
for mu, log_var, desc in test_cases:
    kl = vae_kl_loss(np.array([mu]), np.array([log_var]))
    print(f" {desc}")
    print(f" KL = {kl:.4f} nats")
KL Divergence in the VAE Loss
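The VAE maximizes the Evidence Lower BOund (ELBO):
$$\mathcal{L}(\theta, \phi; x) = \underbrace{\mathbb{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right]}_{\text{Term 1}} - \underbrace{D_{KL}\big(q_\phi(z|x) \,\|\, p(z)\big)}_{\text{Term 2}}$$
Term 1 (Reconstruction): How well does the decoder reconstruct $x$ from latent $z$? Measures the likelihood of the data.
Term 2 (KL Penalty): How much does the encoder posterior $q_\phi(z|x)$ deviate from the prior $p(z)$?
The KL term acts as regularization: if the encoder encodes too much specific information about $x$ (departing from the prior), it is penalized.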
import torch
import torch.nn.functional as F


def vae_loss(
    x: torch.Tensor,
    x_recon: torch.Tensor,
    mu: torch.Tensor,
    log_var: torch.Tensor,
    beta: float = 1.0,
) -> dict:
    """
    VAE ELBO loss = reconstruction + beta * KL.

    Args:
        x: original input (batch, dim)
        x_recon: decoder output (batch, dim)
        mu: encoder mean (batch, latent_dim)
        log_var: encoder log variance (batch, latent_dim)
        beta: KL weight (beta-VAE: beta>1 encourages disentanglement)
    Returns:
        dict with 'total', 'reconstruction', 'kl' losses
    """
    # Reconstruction: binary cross-entropy (for images in [0,1]), summed over the batch
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # KL: -0.5 * sum(1 + log_var - mu^2 - exp(log_var)), summed over batch and latent dims
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    total = recon_loss + beta * kl_loss
    return {
        'total': total,
        'reconstruction': recon_loss,
        'kl': kl_loss,
    }


# Example: batch of 4 samples, latent dim 2
torch.manual_seed(42)
batch_size, latent_dim = 4, 2
mu = torch.randn(batch_size, latent_dim) * 0.5
log_var = torch.randn(batch_size, latent_dim) * 0.5
x = torch.rand(batch_size, 16)
x_recon = torch.sigmoid(torch.randn(batch_size, 16))
losses = vae_loss(x, x_recon, mu, log_var, beta=1.0)
print("VAE losses:")
print(f" Reconstruction: {losses['reconstruction'].item():.4f}")
print(f" KL Divergence: {losses['kl'].item():.4f}")
print(f" Total: {losses['total'].item():.4f}")
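The opening scenario asked why the KL loss can explode to thousands. One plausible mechanism, offered here as an illustration rather than a diagnosis of any particular run, is visible in the closed form itself: the term $\exp(\texttt{log\_var})$ grows exponentially and $\mu^2$ grows quadratically, so an encoder whose outputs drift early in training can produce a very large KL from a single latent dimension. The helper name `kl_per_dim` and the test values are assumptions for this sketch.

```python
import numpy as np

def kl_per_dim(mu, log_var):
    """Per-dimension D_KL(N(mu, sigma^2) || N(0, 1)) with log_var = log(sigma^2)."""
    return 0.5 * (mu**2 + np.exp(log_var) - log_var - 1.0)

cases = [(0.0, 0.0), (3.0, 0.0), (0.0, 5.0), (0.0, 10.0), (0.0, -10.0)]
for mu, log_var in cases:
    print(f"mu={mu:5.1f}  log_var={log_var:6.1f}  KL={kl_per_dim(mu, log_var):12.3f}")
```

A log-variance of 10 alone contributes a KL on the order of $10^4$ nats, while a log-variance of -10 costs only a few nats. The penalty is therefore far steeper for posteriors that are too wide than for ones that are too narrow.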
KL Divergence in Reinforcement Learning: PPO
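Proximal Policy Optimization (PPO) limits how much the policy changes at each update. Its most widely used form clips the probability ratio $r_t(\theta) = \pi_\theta(a_t|s_t) / \pi_{\theta_{old}}(a_t|s_t)$:
$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\big(r_t(\theta)\hat{A}_t,\ \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\big)\right]$$
The clipping acts as a heuristic stand-in for a trust-region constraint on:
$$\hat{\mathbb{E}}_t\left[D_{KL}\big(\pi_{\theta_{old}}(\cdot|s_t) \,\|\, \pi_\theta(\cdot|s_t)\big)\right]$$
The KL-penalty version of PPO makes this explicit:
$$L^{KLPEN}(\theta) = \hat{\mathbb{E}}_t\left[r_t(\theta)\hat{A}_t - \beta\, D_{KL}\big(\pi_{\theta_{old}}(\cdot|s_t) \,\|\, \pi_\theta(\cdot|s_t)\big)\right]$$
Why? Large policy changes can destabilize training. By measuring the distributional change (not just parameter change), we ensure the new policy doesn't behave too differently from the old one - regardless of the scale of parameter space.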
import torch
import torch.nn.functional as F


def ppo_kl_loss(
    old_log_probs: torch.Tensor,
    new_log_probs: torch.Tensor,
    actions: torch.Tensor,
    advantages: torch.Tensor,
    beta: float = 0.01,
    kl_target: float = 0.01,
) -> dict:
    """
    PPO with KL penalty.

    Args:
        old_log_probs: log probs under old policy, shape (batch, num_actions)
        new_log_probs: log probs under new policy, shape (batch, num_actions)
        actions: indices of the actions actually taken, shape (batch,)
        advantages: advantage estimates, shape (batch,)
        beta: KL penalty coefficient
        kl_target: target KL for adaptive beta
    Returns:
        dict with policy loss, kl, total loss, and the beta to use for the next update
    """
    # Importance sampling ratio for the action taken in each state
    idx = actions.unsqueeze(-1)
    taken_old = old_log_probs.gather(-1, idx).squeeze(-1)
    taken_new = new_log_probs.gather(-1, idx).squeeze(-1)
    ratio = torch.exp(taken_new - taken_old)  # shape (batch,)

    # Surrogate policy-gradient loss (negated because optimizers minimize)
    pg_loss = -(ratio * advantages).mean()

    # KL divergence: D_KL(old || new)
    # = sum_a pi_old(a) * log(pi_old(a) / pi_new(a))
    old_probs = torch.exp(old_log_probs)
    kl = torch.sum(old_probs * (old_log_probs - new_log_probs), dim=-1).mean()

    # The penalized loss uses the current beta
    total_loss = pg_loss + beta * kl

    # Adaptive beta for the next update (adaptive KL penalty from the PPO paper)
    next_beta = beta
    if kl.item() > 1.5 * kl_target:
        next_beta = beta * 2
    elif kl.item() < kl_target / 1.5:
        next_beta = beta / 2

    return {'policy_loss': pg_loss, 'kl': kl, 'beta': next_beta, 'total': total_loss}


# Simulate a small policy update
torch.manual_seed(0)
batch = 8
num_actions = 4
old_logits = torch.randn(batch, num_actions)
new_logits = old_logits + 0.1 * torch.randn(batch, num_actions)  # small update
advantages = torch.randn(batch)
old_log_probs = F.log_softmax(old_logits, dim=-1)
new_log_probs = F.log_softmax(new_logits, dim=-1)
# Sample the actions taken under the old policy
actions = torch.multinomial(old_log_probs.exp(), num_samples=1).squeeze(-1)
result = ppo_kl_loss(old_log_probs, new_log_probs, actions, advantages)
print(f"KL divergence (old||new): {result['kl'].item():.6f}")
print(f"Policy loss: {result['policy_loss'].item():.4f}")
print(f"Adaptive beta: {result['beta']:.4f}")
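For comparison with the KL-penalty form, here is a minimal forward-only sketch of PPO's clipped surrogate, min(r * A, clip(r, 1 - eps, 1 + eps) * A), evaluated per sample. It uses NumPy rather than PyTorch because no gradients are needed to see the shape of the objective. The function name `clipped_surrogate` and the value `eps=0.2` are assumptions for this illustration.

```python
import numpy as np

def clipped_surrogate(ratio, advantage, eps=0.2):
    """Per-sample clipped objective: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    unclipped = ratio * advantage
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps) * advantage
    return np.minimum(unclipped, clipped)

ratios = np.array([0.5, 0.9, 1.0, 1.1, 1.5])   # pi_new(a|s) / pi_old(a|s)
for adv in (1.0, -1.0):
    print(f"A = {adv:+.0f}:", clipped_surrogate(ratios, adv))
```

With a positive advantage the objective stops growing once the ratio exceeds 1 + eps, and with a negative advantage it stops improving once the ratio falls below 1 - eps. In both cases the incentive to move the policy further from the old one disappears, which is how clipping limits the update without computing a KL divergence explicitly.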
Jensen-Shannon Divergence: The Symmetric Alternative
The Jensen-Shannon divergence is a symmetrized, bounded version of KL:
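$$JSD(P \| Q) = \frac{1}{2} D_{KL}(P \| M) + \frac{1}{2} D_{KL}(Q \| M)$$
where $M = \frac{1}{2}(P + Q)$ is the mixture distribution.
Properties:
- Symmetric: $JSD(P \| Q) = JSD(Q \| P)$
- Bounded: $0 \leq JSD \leq \ln 2$ (nats) or $0 \leq JSD \leq 1$ (bits)
- Well-defined even when distributions don't share support (unlike KL)
- $\sqrt{JSD}$ is a proper metric (satisfies triangle inequality)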
:::info ML Connection - GANs and Jensen-Shannon Divergence
The original GAN paper (Goodfellow et al., 2014) showed that, when the discriminator is optimal, training the generator minimizes the Jensen-Shannon divergence between the real data distribution $p_{data}$ and the generator distribution $p_g$:
$$C(G) = -\log 4 + 2 \cdot JSD(p_{data} \| p_g)$$
This is why standard GANs can suffer from vanishing gradients when the two distributions don't overlap - $JSD$ saturates at $\ln 2$ for non-overlapping distributions. The Wasserstein GAN replaced JS divergence with the Earth Mover's distance to fix this.
:::
def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen-Shannon divergence (nats). Always in [0, ln(2)]."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence_discrete(p, m) + 0.5 * kl_divergence_discrete(q, m)


# Compare KL vs JS for distributions with varying overlap
print("\n=== KL vs JS Divergence ===")
test_pairs = [
    ([0.5, 0.5], [0.5, 0.5], "Identical"),
    ([0.9, 0.1], [0.1, 0.9], "Opposite"),
    ([0.7, 0.3], [0.5, 0.5], "Moderate"),
    ([0.99, 0.01], [0.5, 0.5], "Very different"),
]
print(f"{'Case':<20} | {'KL(P||Q)':>10} | {'KL(Q||P)':>10} | {'JS':>8}")
print("-" * 55)
for p, q, name in test_pairs:
    kl_pq = kl_divergence_discrete(p, q)
    kl_qp = kl_divergence_discrete(q, p)
    js = js_divergence(p, q)
    print(f"{name:<20} | {kl_pq:>10.4f} | {kl_qp:>10.4f} | {js:>8.4f}")
print(f"\nln(2) = {np.log(2):.4f} nats (maximum JS)")
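The saturation at ln(2) is easiest to see with distributions whose supports do not overlap at all. The sketch below is self-contained (it redefines small `kl` and `js` helpers so it can run on its own) and uses two hand-picked pairs as an illustration.

```python
import numpy as np

def kl(p, q):
    """D_KL(P || Q) in nats; returns inf if Q is zero somewhere P is not."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0
    if np.any(q[mask] == 0):
        return float("inf")
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def js(p, q):
    """Jensen-Shannon divergence in nats."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

p = [0.5, 0.5, 0.0, 0.0]
q_near = [0.0, 0.0, 0.5, 0.5]   # disjoint from p
q_far = [0.0, 0.0, 0.0, 1.0]    # also disjoint, shaped differently

print(f"KL(p || q_near) = {kl(p, q_near)}")
print(f"JS(p, q_near)   = {js(p, q_near):.4f}")
print(f"JS(p, q_far)    = {js(p, q_far):.4f}")
print(f"ln(2)           = {np.log(2):.4f}")
```

KL is infinite, and JS returns the same value, ln(2), for both disjoint pairs. Because the value does not change as the generator's distribution moves around without overlapping the data, it carries no gradient signal about which direction is better.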
Summary: KL Variants Used in ML
Divergence        | Formula                        | ML Application
------------------+--------------------------------+------------------------
Forward KL        | E_P[log P/Q]                   | EM algorithm, ADF
Reverse KL        | E_Q[log Q/P]                   | VAE, VI (mode-seeking)
JS Divergence     | 0.5*KL(P||M) + 0.5*KL(Q||M)    | GAN training
Alpha divergence  | family parameterized by α      | EP, power EP
Rényi divergence  | 1/(α-1) log E_Q[(P/Q)^α]       | Min-max RL
Total variation   | 0.5 Σ |p(x)-q(x)|              | Theoretical bounds
Interview Questions and Answers
Q1: What is KL divergence and why is it not a true distance metric?
KL divergence measures how much information is lost when using $Q$ to approximate $P$. It is not a metric because it violates two metric axioms: (1) it is not symmetric - $D_{KL}(P \| Q) \neq D_{KL}(Q \| P)$ in general, and (2) it does not satisfy the triangle inequality. We still use it because it has excellent theoretical properties: it equals zero iff $P = Q$, is always non-negative, and has a natural information-theoretic interpretation.
Q2: What is the difference between forward and reverse KL, and which does a VAE use?
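Forward KL: $D_{KL}(P \| Q)$ - minimize by placing $Q$'s mass everywhere $P$ has mass (mode-covering). If you ignore a mode of $P$ with $Q(x) \approx 0$, the term $P(x) \log \frac{P(x)}{Q(x)}$ becomes huge.
Reverse KL: $D_{KL}(Q \| P)$ - minimize by not placing $Q$'s mass where $P$ has none (mode-seeking). Regions where $P$ has mass can be ignored at no cost by setting $Q(x) = 0$ there.
VAEs use reverse KL: $D_{KL}(q_\phi(z|x) \| p(z))$. The encoder posterior $q_\phi(z|x)$ is penalized for deviating from the prior $p(z)$, but not for failing to cover all of $p(z)$'s support. When this penalty dominates the reconstruction term, the posterior can shrink onto the prior and stop carrying information about $x$ (posterior collapse) - a known VAE failure mode.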
Q3: Why does the KL divergence go to infinity when Q assigns zero probability to a region where P has nonzero probability?
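In $D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$, when $Q(x) \to 0$ but $P(x) > 0$, the term $P(x) \log \frac{P(x)}{Q(x)}$ becomes $+\infty$. Intuitively: $Q$ is asserting that event $x$ is impossible, but $P$ says it can happen. If $x$ does occur (sampled from $P$), we have infinite surprise - $Q$ is completely unprepared for it. This is why in practice, we add label smoothing or use $\epsilon$-regularization to prevent any $Q(x) = 0$.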
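A minimal sketch of the smoothing idea: add a small constant to every entry of Q and renormalize, so that no outcome has exactly zero probability. The helper names `kl` and `smooth` and the value `eps=1e-3` are assumptions chosen for illustration, not a recommended setting.

```python
import math
import numpy as np

def kl(p, q):
    """D_KL(P || Q) in nats; returns inf if Q is zero somewhere P is not."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    mask = p > 0
    if np.any(q[mask] == 0):
        return float("inf")
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def smooth(q, eps=1e-3):
    """Additive smoothing: add eps to every entry, then renormalize."""
    q = np.asarray(q, dtype=float) + eps
    return q / q.sum()

p = [0.7, 0.2, 0.1]
q = [0.8, 0.2, 0.0]   # zero where p > 0

kl_raw = kl(p, q)
kl_smoothed = kl(p, smooth(q))
print(f"KL without smoothing: {kl_raw}")
print(f"KL with smoothing:    {kl_smoothed:.4f} nats")
```

The smoothed value is finite but depends strongly on `eps`: a smaller `eps` gives a larger KL, so the constant is effectively a modeling choice about how much probability to reserve for outcomes Q considered impossible.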
Q4: How does KL divergence appear in the PPO algorithm, and why use it there instead of L2 distance between parameters?
PPO constrains policy updates using $D_{KL}(\pi_{old} \| \pi_\theta)$. This measures the behavioral difference between policies - how differently they would act in the environment - not the Euclidean distance between parameters.
Parameters can have very different scales across layers, so L2 distance in parameter space is not behaviorally meaningful: a tiny change in a highly sensitive parameter can drastically change policy behavior. KL divergence directly measures the change in the action distribution, which is what matters for stability. Two policies can have very different parameters yet zero KL divergence between them (and thus identical behavior), so it is the behavioral change, not the parameter change, that we want to constrain.
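The gap between parameter distance and behavioral distance can be shown with a toy softmax policy. In this sketch the logits stand in for the policy's parameters, which is a simplifying assumption; the specific numbers are illustrative.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over a 1-D array of logits."""
    z = logits - logits.max()
    e = np.exp(z)
    return e / e.sum()

def kl(p, q):
    """D_KL(P || Q) in nats for strictly positive discrete distributions."""
    return float(np.sum(p * np.log(p / q)))

logits_a = np.array([1.0, 2.0, 3.0])
logits_b = logits_a + 100.0                       # huge shift, same action distribution
logits_c = logits_a + np.array([0.0, 0.0, 3.0])   # small move, different distribution

pa, pb, pc = softmax(logits_a), softmax(logits_b), softmax(logits_c)

l2_ab, kl_ab = np.linalg.norm(logits_a - logits_b), kl(pa, pb)
l2_ac, kl_ac = np.linalg.norm(logits_a - logits_c), kl(pa, pc)

print(f"a vs b: L2 = {l2_ab:8.2f}, KL = {kl_ab:.4f} nats")
print(f"a vs c: L2 = {l2_ac:8.2f}, KL = {kl_ac:.4f} nats")
```

The pair that is far apart in L2 terms behaves identically, while the pair that is close in L2 terms behaves quite differently. A constraint on parameter distance would penalize the harmless change and permit the disruptive one.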
Q5: What is the relationship between KL divergence and information gain in Bayesian inference?
In Bayesian inference, after observing data $D$, you update from prior $p(\theta)$ to posterior $p(\theta|D)$. The KL divergence of the posterior from the prior:
$$D_{KL}\big(p(\theta|D) \,\|\, p(\theta)\big) = \int p(\theta|D) \log \frac{p(\theta|D)}{p(\theta)} \, d\theta$$
measures how much information you gained about $\theta$ from $D$. This is also called the information gain or the Bayesian surprise; its average over possible datasets is the expected information gain. It quantifies how much the data shifted your beliefs. Small KL divergence = data was not very informative (consistent with many parameter values). Large KL = data strongly ruled out large regions of parameter space. This is used in Bayesian experimental design to choose experiments that maximize expected information gain.
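A coin-flip example makes the information gain concrete. The sketch below assumes a uniform Beta(1, 1) prior on the coin's bias, so the posterior after observing some heads and tails is Beta(1 + heads, 1 + tails). The KL divergence is approximated with a Riemann sum on a grid; the function names and the observation counts are illustrative assumptions.

```python
import math
import numpy as np

# Grid over the coin's bias, avoiding the endpoints where log(0) would appear
theta = np.linspace(1e-6, 1.0 - 1e-6, 200001)
dtheta = theta[1] - theta[0]

def log_beta_pdf(theta, a, b):
    """Log density of Beta(a, b)."""
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1) * np.log(theta) + (b - 1) * np.log1p(-theta) - log_norm

def information_gain(heads, tails):
    """Approximate D_KL(posterior || prior) in nats for a Beta(1, 1) prior."""
    log_post = log_beta_pdf(theta, 1 + heads, 1 + tails)
    log_prior = log_beta_pdf(theta, 1, 1)   # uniform prior: log density is 0
    return float(np.sum(np.exp(log_post) * (log_post - log_prior)) * dtheta)

gain_none = information_gain(0, 0)
gain_small = information_gain(7, 3)
gain_large = information_gain(70, 30)

print(f"no data:            {gain_none:.4f} nats")
print(f"7 heads, 3 tails:   {gain_small:.4f} nats")
print(f"70 heads, 30 tails: {gain_large:.4f} nats")
```

With no data the posterior equals the prior and the gain is zero. The gain grows as more flips narrow the posterior, even though the observed proportion of heads is the same in both non-empty cases.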
