Hierarchical Models
The Problem of Many Groups with Sparse Data
You're building a recommendation system for 10,000 cities. Each city has different user preferences. You want to learn a recommendation model per city.
- Option 1 - No pooling (separate models): Train a completely separate model for each city. Problem: most cities have very few users. Your model for Podunk, Iowa (50 users) will be badly overfit, and a city you just launched in has zero historical data, so it gets no model at all.
- Option 2 - Complete pooling (one global model): Train a single model for all cities. Problem: you ignore real differences between cities. New York and rural Montana have genuinely different preferences, so a single global model fits the average city and is biased for every city that departs from it.
- Option 3 - Partial pooling (hierarchical model): Learn city-specific models, but share statistical strength across cities by assuming they're drawn from a common distribution. Cities with little data borrow strength from the global estimate; cities with lots of data rely mainly on their own data.
This is the hierarchical model approach, and when groups are related but unevenly sampled it typically outperforms both naive alternatives.
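To make the three options concrete, here is a minimal simulation sketch. Everything in it is an assumption chosen for illustration: the number of cities, the true values of `mu`, `tau` and `sigma`, and the simplification that `tau` and `sigma` are known rather than estimated.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical setup: true city means are drawn from N(mu, tau^2)
n_cities, mu, tau, sigma = 2000, 3.0, 1.0, 4.0
true_means = rng.normal(mu, tau, n_cities)
n_users = rng.integers(2, 200, n_cities)        # some cities are very sparse

# Observed city averages: noise shrinks with the number of users
se = sigma / np.sqrt(n_users)
city_avg = rng.normal(true_means, se)

# Option 1: no pooling -> each city's own average
no_pool = city_avg
# Option 2: complete pooling -> one user-weighted global average for everyone
global_avg = np.average(city_avg, weights=n_users)
complete_pool = np.full(n_cities, global_avg)
# Option 3: partial pooling -> precision-weighted blend (tau, sigma assumed known)
weight = tau**2 / (tau**2 + se**2)
partial_pool = weight * city_avg + (1 - weight) * global_avg

def mse(estimate):
    """Mean squared error against the simulated true city means."""
    return float(np.mean((estimate - true_means) ** 2))

print(f"no pooling:       {mse(no_pool):.3f}")
print(f"complete pooling: {mse(complete_pool):.3f}")
print(f"partial pooling:  {mse(partial_pool):.3f}")
```

In this setup the sparse cities dominate the no-pooling error and the between-city spread dominates the complete-pooling error; the blended estimate avoids both.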
The Mathematical Structure
A hierarchical (multilevel) Bayesian model has at least two levels:
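Level 1 - Observation model (within-group):

$$y_{ij} \mid \theta_j \sim p(y_{ij} \mid \theta_j)$$

Level 2 - Group-level model (across groups):

$$\theta_j \mid \phi \sim p(\theta_j \mid \phi)$$

Level 3 - Hyperprior (on global parameters):

$$\phi \sim p(\phi)$$

The group-level parameters $\theta_j$ are drawn from a common distribution parameterized by $\phi$ (hyperparameters). This is the key: rather than treating the $\theta_j$ as fixed unknowns or assuming all $\theta_j$ are identical, they're treated as samples from a shared distribution.
The posterior: By Bayes' theorem, the joint posterior is:

$$p(\theta_1, \dots, \theta_J, \phi \mid y) \propto p(\phi) \prod_{j=1}^{J} \left[ p(\theta_j \mid \phi) \prod_{i} p(y_{ij} \mid \theta_j) \right]$$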
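As a concrete instance, the Gaussian version used for the eight schools example later in this section can be written as follows. The symbols $\mu$ and $\tau$ (so that $\phi = (\mu, \tau)$) and the hyperprior scales are taken from that example's code; treat this as one possible choice of distributions, not the only one.

$$
\begin{aligned}
\hat{\theta}_j \mid \theta_j &\sim \mathcal{N}(\theta_j, \sigma_j^2) && \text{(observation model, } \sigma_j \text{ known)} \\
\theta_j \mid \mu, \tau &\sim \mathcal{N}(\mu, \tau^2) && \text{(group-level model)} \\
\mu &\sim \mathcal{N}(0, 10^2), \quad \tau \sim \text{HalfNormal}(10) && \text{(hyperprior)}
\end{aligned}
$$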
The Classic Example: Eight Schools
The famous "eight schools" dataset (Rubin, 1981) reports the estimated effect of a test preparation (SAT coaching) program at 8 different schools. The key challenge: the estimates for individual schools are noisy (small sample sizes), but schools may genuinely differ.
| School | Effect ($\hat{\theta}_j$) | Std Error ($\sigma_j$) |
|---|---|---|
| A | 28 | 15 |
| B | 8 | 10 |
| C | -3 | 16 |
| D | 7 | 11 |
| E | -1 | 9 |
| F | 1 | 11 |
| G | 18 | 10 |
| H | 12 | 18 |
No pooling: School A has an estimated effect of 28, but with SE=15 this is very uncertain. Complete pooling: all schools share one effect, which ignores real variation. Hierarchical: each school effect $\theta_j$ is drawn from $\mathcal{N}(\mu, \tau^2)$, so the model learns both the individual effects and their variability $\tau$.
import pymc as pm
import arviz as az
import numpy as np

# Eight Schools data
schools = ["A", "B", "C", "D", "E", "F", "G", "H"]
effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma_j = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)
J = len(schools)

# ============================================================
# Model 1: No Pooling -- independent estimates per school
# ============================================================
with pm.Model() as no_pool_model:
    theta = pm.Normal('theta', mu=0, sigma=100, shape=J)
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)
    no_pool_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                              progressbar=False, random_seed=42)

# ============================================================
# Model 2: Complete Pooling -- single effect for all schools
# ============================================================
with pm.Model() as pool_model:
    theta = pm.Normal('theta', mu=0, sigma=100)
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)
    pool_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                           progressbar=False, random_seed=42)

# ============================================================
# Model 3: Hierarchical (Partial Pooling) -- the correct approach
# Non-centered parameterization to avoid funnel geometry
# ============================================================
with pm.Model() as hierarchical_model:
    # Hyperpriors (global distribution of school effects)
    mu = pm.Normal('mu', mu=0, sigma=10)     # global mean effect
    tau = pm.HalfNormal('tau', sigma=10)     # between-school variability

    # Non-centered parameterization: theta_j = mu + tau * eta_j
    # eta_j ~ N(0, 1) avoids the funnel problem in HMC
    eta = pm.Normal('eta', mu=0, sigma=1, shape=J)
    theta = pm.Deterministic('theta', mu + tau * eta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)
    hier_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                           progressbar=False, random_seed=42)

# Compare posteriors for school A (most extreme estimate)
print("Posterior mean for School A effect:")
print(f" No pooling: {az.summary(no_pool_trace, var_names=['theta']).loc['theta[0]', 'mean']:.2f}")
print(f" Complete pool: {az.summary(pool_trace, var_names=['theta'])['mean']['theta']:.2f}")
print(f" Hierarchical: {az.summary(hier_trace, var_names=['theta']).loc['theta[0]', 'mean']:.2f}")
print(f" Observed: 28.00 (SE=15)")
print()
print("The hierarchical estimate is shrunk toward the global mean (~8).")
print("This is partial pooling / Bayesian shrinkage.")
The hierarchical model's estimate for School A (≈ 10-12) is shrunk toward the global mean (≈ 7-8), because the observed effect of 28 with SE=15 is uncertain. Schools with smaller standard errors (more data) are shrunk less.
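The global mean quoted above can be cross-checked without MCMC. Under complete pooling with a flat prior, the estimate is the precision-weighted average of the eight observed effects. The sketch below uses only the table values; the variable names are illustrative.

```python
import numpy as np

effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma_j = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)

# Each school is weighted by its precision 1 / sigma_j^2
precision = 1.0 / sigma_j**2
pooled_mean = np.sum(precision * effects) / np.sum(precision)
pooled_se = np.sqrt(1.0 / np.sum(precision))

print(f"pooled mean = {pooled_mean:.2f}, pooled SE = {pooled_se:.2f}")
```

The pooled mean comes out at about 7.7 with a standard error of about 4.1, which is the value the hierarchical estimates are pulled toward.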
Understanding Partial Pooling: The Shrinkage Formula
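For the Gaussian hierarchical model, the posterior mean of $\theta_j$, conditional on $\mu$ and $\tau$, has an elegant closed form:

$$E[\theta_j \mid y, \mu, \tau] = \frac{\frac{1}{\sigma_j^2}\hat{\theta}_j + \frac{1}{\tau^2}\mu}{\frac{1}{\sigma_j^2} + \frac{1}{\tau^2}}$$

This is a precision-weighted average of the observed school estimate $\hat{\theta}_j$ and the global mean $\mu$:

$$E[\theta_j \mid y, \mu, \tau] = \lambda_j \hat{\theta}_j + (1 - \lambda_j)\mu$$

The weight $\lambda_j = \frac{1/\sigma_j^2}{1/\sigma_j^2 + 1/\tau^2} = \frac{\tau^2}{\tau^2 + \sigma_j^2}$ is the shrinkage factor:
- If $\sigma_j \ll \tau$ (school data is very precise, much more precise than global variation): $\lambda_j \to 1$ - use the school's own estimate
- If $\sigma_j \gg \tau$ (school data is noisy, noisier than global variation): $\lambda_j \to 0$ - use the global mean
- $\tau \to 0$ (no between-school variation): all schools get the complete pooling estimate
- $\tau \to \infty$ (infinite between-school variation): no pooling, each school is independent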
def shrinkage_factor(sigma_j, tau):
    """Compute the shrinkage factor lambda_j (weight on the group's own estimate)."""
    return (1/sigma_j**2) / (1/sigma_j**2 + 1/tau**2)

# Illustrate how shrinkage varies with each school's standard error
sigma_values = np.array([15, 10, 16, 11, 9, 11, 10, 18])  # eight schools
tau_estimate = 7.0  # typical between-school std from posterior
lambdas = shrinkage_factor(sigma_values, tau_estimate)

# The unweighted mean of the observed effects stands in for the global mean mu here
hier_means = lambdas * effects + (1 - lambdas) * np.mean(effects)

print("School | Observed | Sigma | Lambda | Hierarchical")
print("-" * 55)
for s, obs, sig, lam, hier in zip(schools, effects, sigma_values, lambdas, hier_means):
    print(f" {s} | {obs:5.1f} | {sig:4.0f} | {lam:.3f} | {hier:6.2f}")
print("\nSchools with higher sigma (less data) are shrunk more toward the mean.")
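The two limiting cases for $\tau$ listed earlier can also be checked numerically. This sketch sweeps $\tau$ for School A; the global mean of 8.0 is an assumed round value, and `shrinkage_weight` is an illustrative helper that uses the algebraically equivalent form $\tau^2 / (\tau^2 + \sigma_j^2)$ so that $\tau = 0$ does not divide by zero.

```python
import numpy as np

def shrinkage_weight(sigma_j, tau):
    # Same quantity as (1/sigma_j^2) / (1/sigma_j^2 + 1/tau^2), but finite at tau = 0
    return tau**2 / (tau**2 + sigma_j**2)

# School A from the table; global_mean = 8.0 is an assumption for illustration
sigma_a, observed_a, global_mean = 15.0, 28.0, 8.0

for tau in [0.0, 1.0, 7.0, 30.0, 1e6]:
    lam = shrinkage_weight(sigma_a, tau)
    estimate = lam * observed_a + (1 - lam) * global_mean
    print(f"tau={tau:>10.1f}  lambda={lam:.3f}  estimate={estimate:6.2f}")
```

At `tau = 0` the estimate equals the global mean (complete pooling), and for very large `tau` it approaches the observed 28 (no pooling).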
Hierarchical Linear Regression
Hierarchical models extend naturally to regression. Instead of a single regression line, each group gets its own line, with lines drawn from a shared distribution.
This is multilevel regression - the core statistical machinery of many recommendation and personalization systems.
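Written out for a binary outcome, one such model looks like the following. The notation ($\alpha_j$, $\beta_j$ for the intercept and slope of group $j$, $x_{ij}$ for the predictor) and the hyperprior scales mirror the logistic regression code in this section; they are one reasonable choice, not a required one.

$$
\begin{aligned}
y_{ij} &\sim \text{Bernoulli}(p_{ij}), \qquad \text{logit}(p_{ij}) = \alpha_j + \beta_j x_{ij} \\
\alpha_j &\sim \mathcal{N}(\mu_\alpha, \sigma_\alpha^2), \qquad \beta_j \sim \mathcal{N}(\mu_\beta, \sigma_\beta^2) \\
\mu_\alpha &\sim \mathcal{N}(0, 5^2), \quad \mu_\beta \sim \mathcal{N}(0, 2^2), \quad \sigma_\alpha \sim \text{HalfNormal}(2), \quad \sigma_\beta \sim \text{HalfNormal}(1)
\end{aligned}
$$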
import pymc as pm
import arviz as az  # needed for az.summary below
import numpy as np

# Simulate: users in 10 cities, with city-specific price sensitivity
np.random.seed(42)
n_cities = 10
users_per_city = np.random.randint(20, 200, n_cities)  # sparse for some cities
true_global_alpha = 5.0
true_global_beta = -0.8  # price sensitivity
true_tau_alpha = 1.5
true_tau_beta = 0.3

# Generate city-specific parameters from shared distribution
true_alphas = true_global_alpha + np.random.normal(0, true_tau_alpha, n_cities)
true_betas = true_global_beta + np.random.normal(0, true_tau_beta, n_cities)

# Generate data
cities, prices, purchases = [], [], []
for j in range(n_cities):
    n_j = users_per_city[j]
    price_j = np.random.uniform(10, 100, n_j)
    p_purchase = 1 / (1 + np.exp(-(true_alphas[j] + true_betas[j] * price_j / 10)))
    y_j = np.random.binomial(1, p_purchase, n_j)
    cities.extend([j] * n_j)
    prices.extend(price_j / 10)  # scale prices
    purchases.extend(y_j)
cities = np.array(cities)
prices = np.array(prices)
purchases = np.array(purchases)

# Hierarchical logistic regression
with pm.Model() as hierarchical_logistic:
    # Hyperpriors (global distribution of city effects)
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    mu_beta = pm.Normal('mu_beta', mu=0, sigma=2)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=2)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)

    # City-level parameters (non-centered)
    alpha_raw = pm.Normal('alpha_raw', mu=0, sigma=1, shape=n_cities)
    beta_raw = pm.Normal('beta_raw', mu=0, sigma=1, shape=n_cities)
    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_raw)
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_raw)

    # Likelihood: each observation uses its own city's intercept and slope
    logit_p = alpha[cities] + beta[cities] * prices
    y_obs = pm.Bernoulli('y_obs', logit_p=logit_p, observed=purchases)
    trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                      progressbar=False, random_seed=42)

summary = az.summary(trace, var_names=['mu_alpha', 'mu_beta',
                                       'sigma_alpha', 'sigma_beta'])
print(summary)
print(f"\nTrue global alpha: {true_global_alpha:.2f}, True global beta: {true_global_beta:.2f}")
print(f"Learned mu_alpha: {summary.loc['mu_alpha', 'mean']:.2f}")
print(f"Learned mu_beta: {summary.loc['mu_beta', 'mean']:.2f}")
Connection to Multi-Task Learning
Hierarchical models are the Bayesian perspective on multi-task learning (MTL) - learning multiple related tasks simultaneously, sharing information across them.
| Bayesian Hierarchical | Multi-Task Learning Equivalent |
|---|---|
| Group | Task |
| Group-level parameters | Task-specific model parameters |
| Shared hyperparameters | Shared representation / meta-parameters |
| Partial pooling | Soft parameter sharing across tasks |
| Prior $p(\theta_j \mid \phi)$ | L2 regularization toward shared parameters |
| Hierarchical GP | Gaussian process multi-task learning |
Hard parameter sharing (most common MTL in DL): All tasks share the same hidden layers, only the output heads differ.
Soft parameter sharing (Bayesian MTL): Each task has its own parameters, but they're regularized to be similar. Equivalent to a hierarchical model with Gaussian prior between task parameters.
import torch
import torch.nn as nn

class MultiTaskNetwork(nn.Module):
    """
    Multi-task neural network with a shared backbone and task-specific heads.
    The backbone is hard-shared across tasks; the heads hold the per-task
    parameters, loosely analogous to group-level parameters in a hierarchical model.
    """
    def __init__(self, input_dim, hidden_dim, n_tasks, output_dim=1):
        super().__init__()
        self.n_tasks = n_tasks
        # Shared backbone: captures shared structure across tasks
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Task-specific heads (each task has its own output layer)
        self.task_heads = nn.ModuleList([
            nn.Linear(hidden_dim, output_dim)
            for _ in range(n_tasks)
        ])

    def forward(self, x, task_id):
        """Forward pass for a specific task."""
        shared_repr = self.shared(x)
        return self.task_heads[task_id](shared_repr)

# Loose Bayesian reading: the shared backbone plays the role of phi (shared structure)
# The task heads play the role of theta_j (group-specific parameters)
# L2 weight decay on the heads acts like a zero-centered Gaussian prior on them:
# it shrinks the heads toward zero, not toward a learned cross-task mean
model = MultiTaskNetwork(input_dim=10, hidden_dim=64, n_tasks=10)

# Apply stronger L2 regularization to the task heads than to the backbone:
optimizer = torch.optim.Adam([
    {'params': model.shared.parameters(), 'weight_decay': 1e-4},
    {'params': model.task_heads.parameters(), 'weight_decay': 1e-2}  # stronger regularization
], lr=1e-3)
print("Multi-task model created with a shared backbone and per-task heads")
print("Stronger L2 on task heads = heads shrink toward zero and lean on the shared representation")
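Soft parameter sharing as described above, where every task keeps its own weights and is penalized for straying from the cross-task average, can be sketched for linear models in plain NumPy. The data, the `strength` value and the helper names are all hypothetical; the alternating update is one simple way to minimize the penalized objective, not the only one.

```python
import numpy as np

rng = np.random.default_rng(1)
n_tasks, n_features, n_samples = 5, 3, 8        # deliberately few samples per task

# Hypothetical data: task weights scattered around a common vector
shared_w = np.array([1.0, -2.0, 0.5])
true_w = shared_w + 0.1 * rng.normal(size=(n_tasks, n_features))
X = rng.normal(size=(n_tasks, n_samples, n_features))
y = np.einsum("tnf,tf->tn", X, true_w) + 0.5 * rng.normal(size=(n_tasks, n_samples))

def fit_soft_sharing(X, y, strength, n_iters=50):
    """Alternate per-task ridge solves centered on w_bar with updates of w_bar."""
    n_tasks, _, n_features = X.shape
    w = np.zeros((n_tasks, n_features))
    for _ in range(n_iters):
        w_bar = w.mean(axis=0)                  # current shared center
        for t in range(n_tasks):
            # Minimizes ||y_t - X_t w||^2 + strength * ||w - w_bar||^2 for task t
            A = X[t].T @ X[t] + strength * np.eye(n_features)
            b = X[t].T @ y[t] + strength * w_bar
            w[t] = np.linalg.solve(A, b)
    return w

def spread(w):
    """Total squared distance of the task weights from their average."""
    return float(np.sum((w - w.mean(axis=0)) ** 2))

w_independent = fit_soft_sharing(X, y, strength=0.0)   # no sharing: plain least squares
w_soft = fit_soft_sharing(X, y, strength=10.0)         # weights pulled toward their mean

print(f"spread without sharing: {spread(w_independent):.3f}")
print(f"spread with sharing:    {spread(w_soft):.3f}")
```

The penalty term is the negative log of a Gaussian prior centered on the task average, which is what makes this the MAP analogue of a hierarchical model.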
When Hierarchical Models Help Most
| Scenario | Why Hierarchical Helps |
|---|---|
| New user/item in recommendation | New entities start from the global distribution instead of a cold start |
| Rare event modeling (city-level fraud) | Small groups borrow strength from common distribution |
| Clinical trials across sites | Sites vary; partial pooling gives more accurate site estimates |
| NLP: language models for rare languages | Multilingual models share cross-lingual representations |
| A/B testing across heterogeneous segments | Segments with little data borrow from global effect |
| Personalized pricing | City/user-level models regularized toward global |
Interview Questions
Q1: What is partial pooling and why is it better than no pooling or complete pooling?
Partial pooling is the hierarchical Bayesian approach: each group has its own parameter $\theta_j$, but these parameters are assumed to come from a common distribution $p(\theta_j \mid \phi)$. The result is shrinkage: each group's estimate is pulled toward the global mean by an amount that grows with how noisy the group's data is. No pooling (separate estimates per group) overfits sparse groups - a group with 3 observations will have a very noisy estimate. Complete pooling (one estimate for all groups) underfits - it ignores real between-group variation. Partial pooling adaptively balances these extremes: groups with many observations rely mainly on their own data; groups with few observations borrow heavily from the global estimate. In recommendation, this is a principled answer to the cold-start problem: new users/items start from the global prior, which is updated as soon as data accumulates.
Q2: What is the non-centered parameterization and why is it important for MCMC?
In a hierarchical model, the centered parameterization writes $\theta_j \sim \mathcal{N}(\mu, \tau^2)$. When $\tau$ is small (groups are similar), the posterior forms a "funnel": near $\tau = 0$, the $\theta_j$ are tightly clustered around $\mu$, creating a highly curved posterior that HMC struggles to sample. The non-centered parameterization avoids this: write $\theta_j = \mu + \tau \eta_j$ where $\eta_j \sim \mathcal{N}(0, 1)$. Now $\eta_j$ and $\tau$ are independent in the prior and only weakly coupled in the posterior when each group has little data, so the funnel geometry largely disappears. HMC explores efficiently across all values of $\tau$, including near zero. This is one of the most important practical tricks for fitting hierarchical models with MCMC, especially with sparse groups. PyMC's pm.Normal('eta', 0, 1) + pm.Deterministic('theta', mu + tau * eta) is the standard implementation.
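A quick way to convince yourself that the two parameterizations describe the same distribution for $\theta_j$ is to sample both with fixed values of $\mu$ and $\tau$. The values below are arbitrary assumptions; this checks only the prior equivalence, not sampler efficiency.

```python
import numpy as np

rng = np.random.default_rng(2)
mu, tau, n_draws = 5.0, 0.3, 200_000

# Centered: draw theta directly from N(mu, tau^2)
theta_centered = rng.normal(mu, tau, n_draws)

# Non-centered: draw a standard normal eta, then shift and scale
eta = rng.normal(0.0, 1.0, n_draws)
theta_non_centered = mu + tau * eta

print(f"centered:     mean={theta_centered.mean():.3f}, std={theta_centered.std():.3f}")
print(f"non-centered: mean={theta_non_centered.mean():.3f}, std={theta_non_centered.std():.3f}")
print(f"eta std (does not depend on tau): {eta.std():.3f}")
```

The sampler works with `eta`, whose scale stays at 1 whatever `tau` is; that is what removes the narrowing neck of the funnel.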
Q3: How does the hierarchical model connect to regularization in standard ML?
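The connection is exact for Gaussian priors. In a hierarchical model, the prior $\theta_j \sim \mathcal{N}(\mu, \tau^2)$ induces shrinkage of each $\theta_j$ toward $\mu$. If we compute the MAP estimate for fixed $\mu$ and $\tau$, the optimization objective becomes: $\min_{\theta} \sum_j \left[ -\log p(y_j \mid \theta_j) + \frac{1}{2\tau^2} \lVert \theta_j - \mu \rVert^2 \right]$. The second term is exactly L2 regularization of each group's parameters toward $\mu$. In multi-task neural networks, this corresponds to penalizing each task's parameters (heads) for deviating from the global average. The regularization strength ($1/\tau^2$) is learned from data, unlike typical ML where $\lambda$ is fixed by cross-validation. A full Bayesian fit gives $\tau$ its own posterior; estimating $\tau$ by maximizing the marginal likelihood and plugging it into the objective is the empirical Bayes version of the same idea.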
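For a single Gaussian observation the claim can be verified directly: minimizing the penalized objective reproduces the precision-weighted shrinkage estimate. The sketch below uses School A's numbers from the eight schools table and assumed values for $\mu$ and $\tau$; the brute-force grid search is only there to avoid relying on the closed form.

```python
import numpy as np

y_j, sigma_j = 28.0, 15.0     # School A: observed effect and standard error
mu, tau = 8.0, 7.0            # assumed global mean and between-school std

def neg_log_posterior(theta):
    data_term = (y_j - theta) ** 2 / (2 * sigma_j**2)   # Gaussian negative log-likelihood
    penalty = (theta - mu) ** 2 / (2 * tau**2)          # L2 pull toward mu, strength 1/tau^2
    return data_term + penalty

# Brute-force MAP: evaluate the objective on a fine grid (step 1e-4)
grid = np.linspace(-20.0, 50.0, 700_001)
theta_map = grid[np.argmin(neg_log_posterior(grid))]

# Closed-form shrinkage estimate
lam = tau**2 / (tau**2 + sigma_j**2)
theta_closed_form = lam * y_j + (1 - lam) * mu

print(f"grid MAP:    {theta_map:.3f}")
print(f"closed form: {theta_closed_form:.3f}")
```

Both values are about 11.58, so the L2-penalized optimum and the shrinkage formula agree.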
Q4: What is shrinkage, and when should you shrink toward zero vs. shrink toward a global mean?
Shrinkage is the pull of Bayesian estimation toward the prior mean. Shrink toward zero when you believe parameters should be small in magnitude (ridge or lasso regularization: prior centered at zero). Shrink toward a global mean when you believe groups share a common baseline but differ around it (hierarchical model: prior centered at the learned global mean $\mu$). The key difference: with hierarchical shrinkage, the center is estimated from the data. If all cities show positive price elasticity, the prior mean will be positive, and city-specific estimates are shrunk toward that positive value - not toward zero. This is data-adaptive regularization, and it is usually a better fit than fixed zero-centered L2 when you have multiple related groups. James and Stein (1961) showed formally that, for Gaussian means, a shrinkage estimator has strictly lower total MSE than separate MLE estimation once there are three or more groups (four or more when the shrinkage target is the estimated common mean) - this is the classical justification for why hierarchical modeling improves accuracy.
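The difference between the two shrinkage targets shows up clearly in simulation when the true effects are centered away from zero. All numbers below are hypothetical, and the shrinkage weight is computed from the true `tau` and `sigma` to keep the sketch short.

```python
import numpy as np

rng = np.random.default_rng(3)
n_groups, mu, tau, sigma = 500, 5.0, 1.0, 2.0    # hypothetical values

true_effects = rng.normal(mu, tau, n_groups)     # groups are positive on average
observed = rng.normal(true_effects, sigma)       # one noisy estimate per group

lam = tau**2 / (tau**2 + sigma**2)               # same weight for both estimators
toward_zero = lam * observed                                   # prior centered at 0
toward_mean = lam * observed + (1 - lam) * observed.mean()     # prior centered at the estimated mean

def mse(estimate):
    return float(np.mean((estimate - true_effects) ** 2))

print(f"raw estimates:      {mse(observed):.2f}")
print(f"shrunk toward zero: {mse(toward_zero):.2f}")
print(f"shrunk toward mean: {mse(toward_mean):.2f}")
```

With the center in the wrong place, shrinking toward zero does worse than not shrinking at all, while shrinking toward the estimated mean does best.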
Q5: How would you implement a hierarchical model for multi-armed bandits (many arms, sparse data per arm)?
In a Thompson sampling bandit, maintain a Beta posterior per arm, $\text{Beta}(\alpha_k, \beta_k)$. The hierarchical extension: assume each arm's true CTR $\theta_k \sim \text{Beta}(\alpha_0, \beta_0)$, where $(\alpha_0, \beta_0)$ are learned from all arms. Concretely: (1) maintain hyperparameters $(\alpha_0, \beta_0)$ representing the global distribution of CTRs; (2) for each arm, the posterior is $\text{Beta}(\alpha_0 + s_k, \beta_0 + f_k)$ where $s_k, f_k$ are arm-specific successes and failures; (3) update $(\alpha_0, \beta_0)$ using empirical Bayes (maximize marginal likelihood) or with a hyperprior. New arms start with $\text{Beta}(\alpha_0, \beta_0)$ - borrowing from global experience rather than an uninformative prior. This gives better initialization for new arms and faster convergence to their true CTR. The pattern fits recommendation platforms where new content items are added continuously.
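The three steps can be sketched as follows. The function names and the simulated history are hypothetical, and the method-of-moments fit is a simple stand-in for maximizing the marginal likelihood (it slightly overstates the spread of CTRs because observed rates include binomial noise).

```python
import numpy as np

rng = np.random.default_rng(4)

def fit_beta_prior(successes, failures, min_trials=20):
    """Method-of-moments estimate of (alpha0, beta0) from arms with enough data."""
    trials = successes + failures
    mask = trials >= min_trials
    rates = successes[mask] / trials[mask]
    mean, var = rates.mean(), rates.var()
    # Beta moments: var = mean * (1 - mean) / (alpha0 + beta0 + 1)
    total = max(mean * (1 - mean) / max(var, 1e-9) - 1, 1.0)
    return mean * total, (1 - mean) * total

def thompson_choose(successes, failures, alpha0, beta0, rng):
    """Sample a CTR per arm from Beta(alpha0 + s_k, beta0 + f_k); play the largest."""
    samples = rng.beta(alpha0 + successes, beta0 + failures)
    return int(np.argmax(samples))

# Hypothetical history: 50 established arms with CTRs around 5%
true_ctr = rng.beta(2, 38, size=50)
trials = rng.integers(200, 2000, size=50)
successes = rng.binomial(trials, true_ctr).astype(float)
failures = trials - successes

# Steps (1) and (3): learn the global prior from all established arms
alpha0, beta0 = fit_beta_prior(successes, failures)

# A brand-new arm has s = f = 0, so its posterior is exactly Beta(alpha0, beta0)
successes = np.append(successes, 0.0)
failures = np.append(failures, 0.0)

# Step (2): Thompson sampling with the shared prior
arm = thompson_choose(successes, failures, alpha0, beta0, rng)
prior_mean = alpha0 / (alpha0 + beta0)
print(f"learned prior: Beta({alpha0:.1f}, {beta0:.1f}), mean CTR {prior_mean:.3f}")
print(f"arm chosen this round: {arm}")
```

The new arm is sampled from a prior centered near the typical CTR instead of from a flat Beta(1, 1), so it is neither ignored nor wildly over-explored in its first rounds.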
