Cross-Entropy and Loss Functions
The Scenario That Motivates This Lesson
A junior engineer asks you:
"Why do we use cross-entropy loss for classification instead of mean squared error? Both measure prediction error."
You've used cross-entropy hundreds of times, but can you explain why it works? The answer involves three convergent perspectives:
- Information theory: minimizing cross-entropy minimizes the information lost when using the model's distribution to approximate the true label distribution
- Maximum likelihood: cross-entropy minimization is exactly maximum likelihood estimation
- Gradient dynamics: cross-entropy gives larger gradients when the model is confidently wrong, while MSE gives small gradients even for catastrophically wrong predictions
This lesson derives cross-entropy from first principles and shows exactly when and why to use each variant.
Definition: Cross-Entropy
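The cross-entropy between a true distribution $p$ and an approximate distribution $q$ is:
$$H(p, q) = -\sum_{x} p(x) \log q(x)$$
For continuous distributions:
$$H(p, q) = -\int p(x) \log q(x) \, dx$$
Interpretation: The average number of bits (or nats, when using the natural log) needed to encode samples from $p$ using a code optimized for $q$. If $q \neq p$, you need more bits than the entropy $H(p)$ - the excess is exactly the KL divergence $D_{KL}(p \| q)$.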
The Fundamental Identity
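$$H(p, q) = H(p) + D_{KL}(p \| q)$$
Proof:
$$H(p, q) = -\sum_x p(x) \log q(x) = -\sum_x p(x) \log p(x) + \sum_x p(x) \log \frac{p(x)}{q(x)} = H(p) + D_{KL}(p \| q)$$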
This identity is the key to understanding everything that follows.
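As a quick numeric sanity check of the identity $H(p, q) = H(p) + D_{KL}(p \| q)$, the following pure-Python sketch evaluates both sides for two small distributions. The distributions `p` and `q` and the helper names are made up for illustration.

```python
import math


def entropy(p):
    """H(p) in nats; terms with p_i = 0 contribute nothing."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p, q):
    """H(p, q) in nats: expected code length under q for samples from p."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def kl_divergence(p, q):
    """D_KL(p || q) in nats."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Illustrative distributions over three outcomes (assumed values).
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]

h_p = entropy(p)
h_pq = cross_entropy(p, q)
kl = kl_divergence(p, q)

print(f"H(p)            = {h_p:.6f}")
print(f"D_KL(p || q)    = {kl:.6f}")
print(f"H(p) + D_KL     = {h_p + kl:.6f}")
print(f"H(p, q)         = {h_pq:.6f}")
```

The last two printed values agree up to floating-point rounding, and setting `q = p` makes the KL term vanish so that `cross_entropy(p, p)` equals `entropy(p)`.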
Why Minimizing Cross-Entropy Minimizes KL Divergence
When training a classifier, the true label distribution $p$ is fixed (determined by your dataset). Therefore $H(p)$ is a constant with respect to the model parameters $\theta$:
$$\arg\min_\theta H(p, q_\theta) = \arg\min_\theta \left[ H(p) + D_{KL}(p \| q_\theta) \right] = \arg\min_\theta D_{KL}(p \| q_\theta)$$
Minimizing cross-entropy is equivalent to minimizing the KL divergence from the true distribution to the model distribution. You are making the model's predictions as close as possible to the true label distribution in the information-theoretic sense.
Why Cross-Entropy Minimization = Maximum Likelihood Estimation
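Consider a classification dataset $\{(x_i, y_i)\}_{i=1}^{N}$ where $y_i \in \{1, \dots, C\}$. The model predicts $q_\theta(y \mid x)$.
The log-likelihood of the dataset is:
$$\log \mathcal{L}(\theta) = \sum_{i=1}^{N} \log q_\theta(y_i \mid x_i)$$
The empirical cross-entropy loss is:
$$L_{CE}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \log q_\theta(y_i \mid x_i)$$
Maximizing the log-likelihood = minimizing the cross-entropy loss. The two objectives differ only by sign and a scaling factor of $1/N$, so they share the same optimal parameters.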
This is why cross-entropy is the natural loss for classification: it is not an arbitrary heuristic - it is exactly maximum likelihood estimation of the model's class probabilities.
:::note The connection
Maximum likelihood estimation and cross-entropy minimization are both equivalent to making $q_\theta$ as close to $p$ as possible in the KL sense. This is the principled reason cross-entropy works.
:::
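The equivalence between the log-likelihood and the empirical cross-entropy can be checked on a toy dataset. In this pure-Python sketch, the predicted probabilities and labels are made-up values standing in for a model's outputs $q_\theta(y \mid x_i)$.

```python
import math

# Hypothetical model outputs q_theta(y | x_i) for N = 4 examples and C = 3 classes.
predicted = [
    [0.70, 0.20, 0.10],
    [0.10, 0.30, 0.60],
    [0.20, 0.50, 0.30],
    [0.25, 0.25, 0.50],
]
labels = [0, 2, 1, 2]  # observed class index y_i for each example
n = len(labels)
num_classes = 3


def one_hot(y, c):
    """One-hot target vector of length c with a 1 at index y."""
    return [1.0 if k == y else 0.0 for k in range(c)]


# Log-likelihood: sum of log-probabilities assigned to the observed labels.
log_likelihood = sum(math.log(probs[y]) for probs, y in zip(predicted, labels))

# Empirical cross-entropy against one-hot targets, averaged over examples.
cross_entropy = sum(
    -sum(t * math.log(q) for t, q in zip(one_hot(y, num_classes), probs))
    for probs, y in zip(predicted, labels)
) / n

print(f"log-likelihood        = {log_likelihood:.6f}")
print(f"-log-likelihood / N   = {-log_likelihood / n:.6f}")
print(f"empirical CE          = {cross_entropy:.6f}")
```

Because the one-hot target zeroes out every term except the true class, the cross-entropy sum collapses to the negative log-probability of the observed label, which is why the last two printed numbers coincide.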
Binary Cross-Entropy
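For binary classification ($y \in \{0, 1\}$), let $\hat{p} = \sigma(z)$, where $\sigma$ is the sigmoid function and $z$ is the logit. The binary cross-entropy is:
$$L_{BCE} = -\left[ y \log \hat{p} + (1 - y) \log(1 - \hat{p}) \right]$$
- When $y = 1$: $L_{BCE} = -\log \hat{p}$, which penalizes a small $\hat{p}$
- When $y = 0$: $L_{BCE} = -\log(1 - \hat{p})$, which penalizes a large $\hat{p}$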
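Loss vs. predicted probability $\hat{p}$ for binary CE (figure):
- $y = 1$: the loss $-\log \hat{p}$ grows without bound as $\hat{p} \to 0$ (about 4.6 at $\hat{p} = 0.01$), is about 0.69 at $\hat{p} = 0.5$, and reaches 0 at $\hat{p} = 1$.
- $y = 0$: mirror image.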
Gradient Analysis: Why BCE Beats MSE for Classification
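For the BCE loss with output $\hat{p} = \sigma(z)$:
$$\frac{\partial L_{BCE}}{\partial z} = \hat{p} - y$$
This gradient is large when $\hat{p}$ and $y$ differ greatly. If the model predicts $\hat{p} = 0.01$ for a positive label ($y = 1$), the gradient is $-0.99$ - a large, useful update signal.
Compare with MSE loss on the raw probability, $L_{MSE} = (\hat{p} - y)^2$:
$$\frac{\partial L_{MSE}}{\partial z} = 2(\hat{p} - y)\,\hat{p}(1 - \hat{p})$$
When $\hat{p} = 0.01$ and $y = 1$: gradient $= 2(-0.99)(0.01)(0.99) \approx -0.02$ - nearly zero, about 50 times smaller than the BCE gradient. The sigmoid's saturation region kills the gradient.
Cross-entropy avoids this because its derivative with respect to $\hat{p}$, $(\hat{p} - y) / (\hat{p}(1 - \hat{p}))$, cancels the sigmoid derivative $\hat{p}(1 - \hat{p})$, leaving $\hat{p} - y$, which is always large when the model is confidently wrong.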
Gradient with respect to the logit z for y=1, BCE (p - y) vs MSE (2(p - y)p(1 - p)):
predicted p | BCE gradient | MSE gradient (on z)
------------|--------------|--------------------
0.01        | -0.990       | -0.020   ← MSE nearly dead
0.10        | -0.900       | -0.162
0.50        | -0.500       | -0.250
0.90        | -0.100       | -0.018
0.99        | -0.010       | -0.0002
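The two gradient formulas can be checked without any autograd library. This sketch (pure Python; the helper names are illustrative) takes central finite differences of each loss with respect to the logit $z$, which can then be compared against the closed forms $\hat{p} - y$ for BCE and $2(\hat{p} - y)\hat{p}(1 - \hat{p})$ for MSE on the sigmoid output.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def logit(p):
    """Inverse of the sigmoid: the logit z that produces probability p."""
    return math.log(p / (1.0 - p))


def bce(z, y):
    p = sigmoid(z)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def mse(z, y):
    return (sigmoid(z) - y) ** 2


def numeric_grad(loss, z, y, h=1e-6):
    """Central finite-difference estimate of d(loss)/dz."""
    return (loss(z + h, y) - loss(z - h, y)) / (2.0 * h)


y = 1  # positive label
rows = []
for p in [0.01, 0.10, 0.50, 0.90, 0.99]:
    z = logit(p)
    rows.append((p, numeric_grad(bce, z, y), numeric_grad(mse, z, y)))

print(f"{'p':>6} {'dBCE/dz':>10} {'dMSE/dz':>10}")
for p, g_bce, g_mse in rows:
    print(f"{p:>6.2f} {g_bce:>10.4f} {g_mse:>10.4f}")
```

The MSE gradient also shrinks when the model is confidently right ($\hat{p}$ near 1), which is harmless; the damaging case is the confidently wrong one ($\hat{p}$ near 0 for $y = 1$), where BCE keeps a gradient close to $-1$ while MSE does not.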
Categorical Cross-Entropy
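For multiclass classification with $C$ classes, where $\hat{p}_c = \text{softmax}(z)_c$ is the predicted probability of class $c$:
$$L_{CE} = -\sum_{c=1}^{C} y_c \log \hat{p}_c$$
With one-hot labels (only the true class $k$ has $y_k = 1$):
$$L_{CE} = -\log \hat{p}_k = -z_k + \log \sum_{j=1}^{C} e^{z_j}$$
This is the negative log-softmax loss, also called cross-entropy with softmax.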
Numerically Stable Implementation
import numpy as np
import torch
import torch.nn.functional as F


def cross_entropy_naive(logits: np.ndarray, labels: np.ndarray) -> float:
    """
    Naive cross-entropy - numerically unstable for large logits.
    DO NOT use in production.
    """
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
    n = len(labels)
    return float(-np.mean(np.log(probs[np.arange(n), labels])))


def cross_entropy_stable(logits: np.ndarray, labels: np.ndarray) -> float:
    """
    Numerically stable cross-entropy using log-sum-exp trick.
    log(softmax(z))_k = z_k - log(sum_j exp(z_j))
    Use: log(sum_j exp(z_j)) = max(z) + log(sum_j exp(z_j - max(z)))
    """
    n = len(labels)
    # Subtract max for numerical stability (log-sum-exp trick)
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_sum_exp = np.log(np.sum(np.exp(shifted), axis=-1))
    log_probs = shifted - log_sum_exp[:, np.newaxis]  # log_softmax
    # Pick log-probability of true class
    return float(-np.mean(log_probs[np.arange(n), labels]))


# Verify against PyTorch
np.random.seed(42)
logits_np = np.random.randn(4, 3)
labels_np = np.array([0, 2, 1, 2])
logits_pt = torch.tensor(logits_np, dtype=torch.float32)
labels_pt = torch.tensor(labels_np, dtype=torch.long)

naive_loss = cross_entropy_naive(logits_np, labels_np)
stable_loss = cross_entropy_stable(logits_np, labels_np)
torch_loss = F.cross_entropy(logits_pt, labels_pt).item()
print(f"Naive CE: {naive_loss:.6f}")
print(f"Stable CE: {stable_loss:.6f}")
print(f"PyTorch CE: {torch_loss:.6f}")
# All three should agree to within float32 precision

# Demonstrate instability of naive approach
large_logits = np.array([[1000.0, 2000.0, 3000.0]])  # np.exp overflows to inf
# NumPy does not raise here: exp overflows to inf and inf/inf gives nan
# (with RuntimeWarnings), so the naive loss silently becomes nan.
naive_result = cross_entropy_naive(large_logits, np.array([2]))
print(f"Naive result: {naive_result}")  # nan
result = cross_entropy_stable(large_logits, np.array([2]))
print(f"Stable result: {result:.6f}")  # zero loss (may print as -0.000000) - correct!
Cross-Entropy vs. MSE: When to Use Which
Loss Function  | Use When                   | Why
---------------|----------------------------|--------------------------------
Cross-entropy  | Classification (any)       | = MLE; proper scoring rule
               | Outputs are probabilities  | Good gradients everywhere
Binary CE      | Binary classification      | Sigmoid + BCE = sigmoid CE
               | Multi-label                | Independent binary decisions
Cat. CE        | Multiclass (one-hot)       | Softmax + NLL
MSE            | Regression                 | L2 norm; Gaussian noise model
               | Direct value prediction    | When residuals are Gaussian
MAE            | Regression, outliers       | L1 norm; Laplacian noise model
Huber loss     | Regression, some outliers  | Smooth transition L1/L2
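:::warning When NOT to use MSE for classification
Using MSE to train a classifier that outputs probabilities has two problems: (1) the gradient dies in sigmoid saturation regions (shown above), and (2) its penalty for a confidently wrong prediction is bounded (at most 1 per example), so it punishes such mistakes far less than cross-entropy, whose loss grows without bound. Note that MSE on probabilities (the Brier score) is itself a proper scoring rule, so propriety is not what separates the two losses; MSE corresponds to a Gaussian noise model rather than the Bernoulli/categorical likelihood of class labels. Cross-entropy is a strictly proper scoring rule: its expected value has a unique minimum, achieved exactly when $q = p$.
:::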
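To make the proper-scoring-rule property of cross-entropy concrete, the sketch below (pure Python, with a hypothetical true positive rate `p_true = 0.7`) scans candidate predictions `q` on a grid and looks for the one with the smallest expected cross-entropy.

```python
import math


def expected_ce(p, q):
    """Expected binary cross-entropy when labels are Bernoulli(p) and we predict q."""
    return -(p * math.log(q) + (1.0 - p) * math.log(1.0 - q))


p_true = 0.7  # assumed true probability of the positive class
grid = [i / 100 for i in range(1, 100)]  # candidate predictions 0.01 ... 0.99

best_q = min(grid, key=lambda q: expected_ce(p_true, q))

print(f"best prediction on the grid: {best_q:.2f}")
print(f"expected CE at q=0.70: {expected_ce(p_true, 0.70):.4f}")
print(f"expected CE at q=0.90: {expected_ce(p_true, 0.90):.4f}")
print(f"expected CE at q=0.99: {expected_ce(p_true, 0.99):.4f}")
```

The minimizer lands on the true probability, and the expected loss climbs steeply as the prediction becomes more confident than the data warrants.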
Label Smoothing: Regularizing with Soft Targets
Standard cross-entropy uses hard one-hot labels: the model is pushed to predict probability 1 for the correct class. This encourages overconfidence and poor calibration.
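Label smoothing replaces the one-hot distribution with a soft target:
$$y_c^{LS} = (1 - \varepsilon)\, y_c + \frac{\varepsilon}{C}$$
where $\varepsilon$ is the smoothing factor (typically 0.1) and $C$ is the number of classes. This gives every class a small non-zero target probability.
The label-smoothed cross-entropy becomes (with $k$ the true class and $\hat{p}_c$ the predicted probability of class $c$):
$$L_{LS} = -\sum_{c=1}^{C} y_c^{LS} \log \hat{p}_c = (1 - \varepsilon)\left(-\log \hat{p}_k\right) + \frac{\varepsilon}{C} \sum_{c=1}^{C} \left(-\log \hat{p}_c\right)$$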
# Assumes torch and F (torch.nn.functional) are imported as in the earlier block
from typing import Optional


def label_smoothed_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    smoothing: float = 0.1,
    num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Cross-entropy with label smoothing.
    Args:
        logits: model logits (batch, num_classes)
        labels: true class indices (batch,)
        smoothing: label smoothing factor ε ∈ [0, 1]
        num_classes: C (inferred from logits if None)
    Returns:
        scalar loss
    """
    if num_classes is None:
        num_classes = logits.size(-1)
    # Log-softmax for numerical stability
    log_probs = F.log_softmax(logits, dim=-1)  # (batch, C)
    # Smoothed targets: one-hot * (1-eps) + eps/C
    # Hard CE: -log_probs[i, labels[i]]
    # Soft CE: -sum_c y_c * log_probs[i, c]
    nll_loss = -log_probs.gather(dim=-1, index=labels.unsqueeze(1)).squeeze(1)
    # Uniform part of the target: (1/C) * sum_c -log_probs[i, c]
    smooth_loss = -log_probs.sum(dim=-1) / num_classes
    loss = (1 - smoothing) * nll_loss + smoothing * smooth_loss
    return loss.mean()


# Compare standard CE vs label-smoothed CE
torch.manual_seed(0)
batch_size, n_classes = 8, 10
logits = torch.randn(batch_size, n_classes)
labels = torch.randint(0, n_classes, (batch_size,))
ce_standard = F.cross_entropy(logits, labels).item()
ce_smoothed = label_smoothed_cross_entropy(logits, labels, smoothing=0.1).item()
print(f"Standard cross-entropy: {ce_standard:.4f}")
print(f"Label-smoothed (eps=0.1) CE: {ce_smoothed:.4f}")
# For a model that already favors the true class, the smoothed loss is
# slightly higher, but training with it tends to give better-calibrated models
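To see why an implementation can mix the hard-label loss with a uniform term instead of building soft targets, this pure-Python sketch constructs the smoothed target vector explicitly and compares the two computations. The logits, the label, and the helper name `log_softmax` are illustrative assumptions.

```python
import math


def log_softmax(logits):
    """Stable log-softmax via the log-sum-exp shift."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


logits = [2.0, 0.5, -1.0, 0.0]  # made-up logits for one example
label = 0                       # true class index
eps = 0.1                       # smoothing factor
num_classes = len(logits)

log_probs = log_softmax(logits)

# Explicit soft targets: (1 - eps) * one_hot + eps / C
targets = [
    (1 - eps) * (1.0 if c == label else 0.0) + eps / num_classes
    for c in range(num_classes)
]
soft_ce = -sum(t * lp for t, lp in zip(targets, log_probs))

# Decomposed form: hard-label NLL mixed with the mean negative log-probability
nll = -log_probs[label]
uniform_term = -sum(log_probs) / num_classes
decomposed = (1 - eps) * nll + eps * uniform_term

print(f"targets          = {[round(t, 3) for t in targets]}")
print(f"soft-target CE   = {soft_ce:.6f}")
print(f"decomposed form  = {decomposed:.6f}")
```

The two numbers match because the cross-entropy is linear in the target distribution, so a mixture of targets gives the same mixture of losses.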
:::tip Label Smoothing in Practice
Label smoothing (ε = 0.1) is used in:
- Image classification (ResNets, EfficientNets)
- Machine translation (Transformer paper used ε=0.1)
- Large language models during SFT

It improves calibration (model confidence better matches actual accuracy) and reduces overconfidence on noisy labels.
:::
Focal Loss: Handling Class Imbalance
Standard cross-entropy treats all examples equally. In object detection (RetinaNet) and other imbalanced settings, easy negatives dominate the loss.
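Focal loss down-weights easy examples:
$$FL(p_t) = -(1 - p_t)^{\gamma} \log(p_t)$$
where:
- $p_t = \hat{p}$ if $y = 1$, else $1 - \hat{p}$ (the probability the model assigns to the true class)
- $\gamma \geq 0$ is the focusing parameter (typically 2.0)
When $p_t \to 1$ (easy example), $(1 - p_t)^{\gamma} \to 0$ - the loss is suppressed. When $p_t \to 0$ (hard example), $(1 - p_t)^{\gamma} \to 1$ - the loss is nearly unchanged.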
def focal_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,
) -> torch.Tensor:
    """
    Focal loss for binary classification (RetinaNet-style).
    Args:
        logits: raw predictions (batch,)
        labels: binary targets (batch,) - 0 or 1
        gamma: focusing parameter (0 = standard CE, 2 = typical)
        alpha: class balancing weight for positive class
    Returns:
        mean focal loss
    """
    targets = labels.float()
    probs = torch.sigmoid(logits)
    # p_t: probability the model assigns to the true class
    probs_t = torch.where(labels == 1, probs, 1 - probs)
    # Standard BCE
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # Focal weight: down-weight easy examples
    focal_weight = (1 - probs_t) ** gamma
    # Alpha weighting for class imbalance: alpha for positives, 1 - alpha for negatives
    # (written arithmetically so a plain float alpha works without torch.where)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    loss = alpha_t * focal_weight * bce
    return loss.mean()


# Compare CE vs Focal on easy and hard examples
print("\n=== Focal Loss vs Standard BCE ===")
print(f"{'Prediction':>12} {'True Label':>12} {'BCE':>10} {'Focal(γ=2)':>12}")
print("-" * 50)
for pred_logit, true_label in [
    (3.0, 1),   # easy positive (correct, confident)
    (-3.0, 0),  # easy negative (correct, confident)
    (0.5, 1),   # hard positive (correct, uncertain)
    (0.5, 0),   # hard negative
    (-3.0, 1),  # very hard positive (wrong, confident)
]:
    logit_t = torch.tensor([pred_logit])
    label_t = torch.tensor([true_label])
    pred_p = torch.sigmoid(logit_t).item()
    bce_val = F.binary_cross_entropy_with_logits(logit_t, label_t.float()).item()
    focal_val = focal_loss(logit_t, label_t, gamma=2.0, alpha=0.5).item()
    print(f"{pred_p:>12.3f} {true_label:>12} {bce_val:>10.4f} {focal_val:>12.4f}")
Temperature Scaling for Calibration
During inference, softmax predictions can be overconfident. Temperature scaling divides the logits by a scalar $T > 0$ before softmax:
$$\hat{p}_c = \frac{\exp(z_c / T)}{\sum_{j} \exp(z_j / T)}$$
- $T = 1$: standard softmax
- $T > 1$: softer distribution, less confident (higher entropy)
- $T < 1$: sharper distribution, more confident (lower entropy)
def temperature_scaled_cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Cross-entropy with temperature scaling (used for calibration)."""
    scaled_logits = logits / temperature
    return F.cross_entropy(scaled_logits, labels)


# Effect of temperature on entropy
logits = torch.tensor([[2.0, 0.5, -1.0, -0.5]])  # 4-class example
print("\n=== Temperature Scaling Effect ===")
print(f"{'Temperature':>12} | {'Probabilities':<40} | {'Entropy (nats)':>14}")
print("-" * 72)
for T in [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]:
    probs = F.softmax(logits / T, dim=-1).squeeze()
    log_probs = F.log_softmax(logits / T, dim=-1).squeeze()
    h = (-probs * log_probs).sum().item()
    p_str = " ".join(f"{p:.3f}" for p in probs.tolist())
    print(f"{T:>12.1f} | [{p_str}] | {h:>14.4f}")
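The text does not say how the temperature is chosen. A common approach, assumed here, is to pick the $T$ that minimizes the negative log-likelihood on a held-out validation set while leaving the model weights untouched. The sketch below does this with a simple grid search in pure Python; the validation logits and labels are made up, and `mean_nll` is a hypothetical helper.

```python
import math


def mean_nll(logits_batch, labels, temperature):
    """Average negative log-likelihood of the labels after dividing logits by T."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        scaled = [z / temperature for z in logits]
        m = max(scaled)  # log-sum-exp shift for stability
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scaled))
        total += log_sum_exp - scaled[y]
    return total / len(labels)


# Made-up validation set: every prediction is confident, but two of six are wrong.
val_logits = [
    [4.0, 0.0, 0.0],
    [0.0, 5.0, 0.0],
    [3.0, 0.0, 0.0],
    [0.0, 0.0, 6.0],
    [0.0, 4.0, 0.0],
    [5.0, 0.0, 0.0],
]
val_labels = [0, 1, 1, 2, 2, 0]

candidates = [0.5 + 0.1 * i for i in range(46)]  # T = 0.5, 0.6, ..., 5.0
best_T = min(candidates, key=lambda t: mean_nll(val_logits, val_labels, t))

print(f"NLL at T=1.0:   {mean_nll(val_logits, val_labels, 1.0):.4f}")
print(f"best T on grid: {best_T:.1f}")
print(f"NLL at best T:  {mean_nll(val_logits, val_labels, best_T):.4f}")
```

Because the toy model is overconfident, the selected temperature comes out above 1. Dividing all logits by the same positive scalar does not change the argmax, so accuracy is unaffected; only the confidence changes.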
Cross-Entropy Across ML Frameworks
import torch
import torch.nn.functional as F

# PyTorch equivalences:
# F.cross_entropy(logits, labels) = NLLLoss(log_softmax(logits), labels)
# F.binary_cross_entropy_with_logits = numerically stable sigmoid + BCE

# All three are equivalent for multiclass:
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))

# Method 1: F.cross_entropy (recommended, internally stable)
loss1 = F.cross_entropy(logits, labels)

# Method 2: log_softmax + NLLLoss
log_probs = F.log_softmax(logits, dim=-1)
loss2 = F.nll_loss(log_probs, labels)

# Method 3: manual computation (shows the formula)
probs = F.softmax(logits, dim=-1)
log_probs_manual = torch.log(probs + 1e-10)  # small epsilon for stability
loss3 = F.nll_loss(log_probs_manual, labels)

print(f"F.cross_entropy: {loss1.item():.6f}")
print(f"log_softmax + NLL: {loss2.item():.6f}")
print(f"Manual (less stable): {loss3.item():.6f}")

# For binary classification:
logits_bin = torch.randn(8)
labels_bin = torch.randint(0, 2, (8,)).float()

# Recommended: numerically stable
loss_bce = F.binary_cross_entropy_with_logits(logits_bin, labels_bin)
print(f"\nBCE with logits: {loss_bce.item():.6f}")

# NOT recommended: can have numerical issues when the sigmoid saturates
probs_bin = torch.sigmoid(logits_bin)
loss_bce_naive = F.binary_cross_entropy(probs_bin, labels_bin)
print(f"BCE (naive): {loss_bce_naive.item():.6f}")
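Why the logits-based variant is preferred can be seen in scalar form. The sketch below is pure Python and uses the standard rearrangement $\max(z, 0) - z y + \log(1 + e^{-|z|})$ of binary cross-entropy; it illustrates the idea and is not a claim about any framework's exact internals. The function names are hypothetical.

```python
import math


def bce_naive(z, y):
    """Sigmoid first, then log: breaks once the sigmoid rounds to exactly 0 or 1."""
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def bce_with_logits(z, y):
    """Stable form: softplus(z) - z*y, with softplus written to avoid overflow."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


# For moderate logits the two forms agree.
moderate_gap = abs(bce_naive(0.5, 1) - bce_with_logits(0.5, 1))
print(f"difference at z=0.5: {moderate_gap:.2e}")

# For a confidently wrong prediction the naive form hits log(0).
try:
    bce_naive(40.0, 0)
    naive_failed = False
except ValueError:
    naive_failed = True

stable_value = bce_with_logits(40.0, 0)
print(f"naive failed at z=40: {naive_failed}")
print(f"stable loss at z=40:  {stable_value:.6f}")
```

At $z = 40$ the sigmoid rounds to exactly 1.0 in double precision, so the naive version evaluates $\log(0)$, while the rearranged form returns a loss close to 40, the value the formula approaches for large logits.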
Perplexity: Cross-Entropy for Language Models
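For language models, the evaluation metric is perplexity:
$$\text{PPL} = 2^{H}, \quad H = -\frac{1}{N} \sum_{i=1}^{N} \log_2 q(w_i \mid w_{<i})$$
Or equivalently with natural log:
$$\text{PPL} = \exp\left( -\frac{1}{N} \sum_{i=1}^{N} \ln q(w_i \mid w_{<i}) \right)$$
Perplexity = e^(cross-entropy loss), with the loss measured in nats per token. Lower perplexity = better model.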
Illustrative, order-of-magnitude values (perplexity depends on the corpus and tokenizer, so numbers are only comparable on the same benchmark):
Model              | Perplexity | Interpretation
-------------------|------------|----------------------------------
Random (50k vocab) | 50,000     | Uniform guess over the vocabulary
GPT-2 (small)      | ~50        | Avg. 50 equally likely choices
GPT-3              | ~20        | Avg. 20 equally likely choices
GPT-4 (est.)       | ~10        | High confidence per token
Human (estimate)   | ~7-10      | Approximate human-level
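The relationship between per-token probabilities and perplexity can be worked through directly. This pure-Python sketch uses made-up token probabilities; the function names are illustrative. It computes perplexity from the average loss in nats and in bits, and checks the uniform-guess case.

```python
import math


def perplexity_nats(token_probs):
    """exp of the average negative natural-log probability per token."""
    avg_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(avg_nll)


def perplexity_bits(token_probs):
    """2 raised to the average number of bits per token."""
    avg_bits = -sum(math.log2(p) for p in token_probs) / len(token_probs)
    return 2.0 ** avg_bits


# Hypothetical probabilities a model assigned to the tokens that actually occurred.
token_probs = [0.5, 0.1, 0.25, 0.02, 0.4]

ppl_e = perplexity_nats(token_probs)
ppl_2 = perplexity_bits(token_probs)
print(f"perplexity (nats route): {ppl_e:.4f}")
print(f"perplexity (bits route): {ppl_2:.4f}")

# A model that guesses uniformly over a 50,000-token vocabulary.
uniform_ppl = perplexity_nats([1 / 50_000] * 20)
print(f"uniform over 50k vocab:  {uniform_ppl:.1f}")
```

The base of the logarithm does not matter as long as the exponentiation uses the same base, and a uniform guess over $V$ tokens always has perplexity $V$.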
Interview Questions and Answers
Q1: Why do we use cross-entropy loss for classification instead of mean squared error?
Three reasons:
- Information-theoretic: Cross-entropy minimization equals minimizing the KL divergence between the true label distribution and the model's predicted distribution - it is the information-theoretically correct loss for fitting a probability distribution.
- MLE equivalence: Cross-entropy minimization is exactly maximum likelihood estimation. The model parameters that maximize the likelihood of observed labels are exactly those that minimize cross-entropy.
- Gradient behavior: For a sigmoid output $\hat{p} = \sigma(z)$ with label $y$, the BCE gradient with respect to the logit is $\hat{p} - y$ - always proportional to the error, even when the model is confidently wrong. MSE combined with sigmoid gives a gradient of $2(\hat{p} - y)\,\hat{p}(1 - \hat{p})$, which goes to zero when $\hat{p} \to 0$ or $\hat{p} \to 1$ - killing learning exactly when the model is most wrong.
Q2: Derive the relationship between cross-entropy, KL divergence, and entropy.
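Starting from the definition of cross-entropy:
$$H(p, q) = -\sum_x p(x) \log q(x)$$
Add and subtract $\sum_x p(x) \log p(x)$:
$$H(p, q) = -\sum_x p(x) \log p(x) + \sum_x p(x) \log \frac{p(x)}{q(x)} = H(p) + D_{KL}(p \| q)$$
Therefore: cross-entropy = entropy of true distribution + KL divergence from true to model. Since $H(p)$ is fixed during training, minimizing cross-entropy is equivalent to minimizing $D_{KL}(p \| q)$.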
Q3: What is label smoothing, and why does it improve model calibration?
Label smoothing replaces the one-hot target vector with a soft distribution: $y_c^{LS} = (1 - \varepsilon)\, y_c + \varepsilon / C$. Instead of pushing the model to output probability exactly 1 for the correct class, we ask for something like 0.91 (with $\varepsilon = 0.1$ and $C = 10$ classes) while giving 0.01 to each other class.
This improves calibration because hard labels encourage the model to become infinitely confident (logit $\to \infty$), which drives the softmax output to 1 and the cross-entropy to 0. But this is overfitting: a well-calibrated model should express 90% confidence when it is right 90% of the time. Label smoothing regularizes against overconfidence, resulting in models whose predicted probabilities better reflect true accuracy (as measured by calibration curves and ECE).
Q4: When would you use focal loss over standard cross-entropy?
Use focal loss when there is severe class imbalance, especially when the dominant class consists of easy examples. The motivating case is object detection: in a typical image, there are perhaps 5 objects and thousands of background patches (easy negatives). Standard BCE averages over all patches, so easy negatives dominate the loss and gradients - the model learns almost nothing from the rare hard positives.
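Focal loss multiplies each example's loss by $(1 - p_t)^{\gamma}$, so the down-weighting grows with the model's confidence in the true class. With $\gamma = 2$, for example, a correctly classified easy negative with $p_t = 0.99$ is down-weighted by $(1 - 0.99)^2 = 10^{-4}$ - effectively ignored. A hard example with $p_t = 0.5$ is down-weighted by only $(0.5)^2 = 0.25$. This focuses training on the informative hard examples.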
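The size of the modulating factor $(1 - p_t)^{\gamma}$ is easy to tabulate. A small pure-Python sketch, assuming the typical $\gamma = 2$ and a few illustrative values of $p_t$ (the function names are hypothetical):

```python
import math


def focal_factor(p_t, gamma=2.0):
    """Modulating factor (1 - p_t)^gamma applied to the cross-entropy term."""
    return (1.0 - p_t) ** gamma


def focal_loss_value(p_t, gamma=2.0):
    """Focal loss for one example, without the optional alpha class weight."""
    return -focal_factor(p_t, gamma) * math.log(p_t)


print(f"{'p_t':>6} {'CE':>10} {'factor':>10} {'focal':>12}")
for p_t in [0.99, 0.9, 0.5, 0.1]:
    ce = -math.log(p_t)
    print(f"{p_t:>6.2f} {ce:>10.4f} {focal_factor(p_t):>10.4f} {focal_loss_value(p_t):>12.6f}")
```

Setting `gamma=0.0` makes the factor equal to 1 for every example, which recovers ordinary cross-entropy.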
Q5: How does perplexity relate to cross-entropy, and what does a perplexity of 50 mean for a language model?
Perplexity is the exponentiated cross-entropy: $\text{PPL} = e^{L}$, where $L$ is the average negative log-likelihood per token (in nats).
A perplexity of 50 means the model is as uncertain as if it had to choose uniformly among 50 equally likely next tokens at each step. Equivalently, the model assigns a geometric-mean probability of $1/50 = 0.02$ per token - it needs about $\log_2 50 \approx 5.64$ bits to encode each token on average.
Lower perplexity = the model assigns higher probability to the text that actually occurs = it is better at predicting the next token = it is a better compressor of the language. GPT-2 achieves PPL ~50 on PTB, GPT-3 achieves ~20, and human-level perplexity is estimated at 7–15 depending on the corpus.
