
Hierarchical Models

The Problem of Many Groups with Sparse Data

You're building a recommendation system for 10,000 cities. Each city has different user preferences. You want to learn a recommendation model per city.

  • Option 1 - No pooling (separate models): Train a completely separate model for each city. Problem: most cities have very few users. Your model for Podunk, Iowa (50 users) will be wildly overfit. A city you just launched in has zero historical data - no model at all.

  • Option 2 - Complete pooling (one global model): Train a single model for all cities. Problem: you ignore real differences between cities. New York and rural Montana have genuinely different preferences, so the global model is biased for every city that differs from the average.

  • Option 3 - Partial pooling (hierarchical model): Learn city-specific models, but share statistical strength across cities by assuming they're drawn from a common distribution. Cities with little data borrow strength from the global estimate; cities with lots of data rely mainly on their own data.

This is the hierarchical model approach - and when there are many related groups with sparse data per group, it typically outperforms both naive alternatives.

The Mathematical Structure

A hierarchical (multilevel) Bayesian model has at least two levels, plus a hyperprior on the top-level parameters in the fully Bayesian version:

Level 1 - Observation model (within-group): $y_{ij} \sim P(y \mid \theta_j), \quad i = 1, \ldots, n_j, \quad j = 1, \ldots, J$

Level 2 - Group-level model (across groups): $\theta_j \sim P(\theta \mid \phi)$

Level 3 - Hyperprior (on global parameters): $\phi \sim P(\phi)$

The group-level parameters $\theta_j$ are drawn from a common distribution parameterized by $\phi$ (the hyperparameters). This is the key: rather than treating the $\theta_j$ as fixed unknowns or assuming all $\theta_j$ are identical, they're treated as samples from a shared distribution.

The posterior: By Bayes' theorem, the joint posterior is:

$$P(\theta_1, \ldots, \theta_J, \phi \mid \mathbf{y}) \propto P(\phi) \prod_{j=1}^J \left[ P(\theta_j \mid \phi) \prod_{i=1}^{n_j} P(y_{ij} \mid \theta_j) \right]$$
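
The factorization above can be evaluated directly as a log density. The following is a minimal sketch for one concrete case, assuming Gaussian distributions at every level; the function names, the known observation noise `sigma_y`, and the hyperprior scales `mu_sd` and `tau_sd` are illustrative assumptions, not part of the general model.

```python
import numpy as np

def normal_logpdf(x, mean, sd):
    """Elementwise log density of N(mean, sd^2)."""
    return -0.5 * np.log(2 * np.pi * sd**2) - 0.5 * ((x - mean) / sd) ** 2

def log_joint(theta, mu, tau, y_groups, sigma_y=1.0, mu_sd=10.0, tau_sd=10.0):
    """Unnormalized log posterior for a Gaussian hierarchical model.

    Mirrors the factorization term by term:
    log P(phi) + sum_j [ log P(theta_j | phi) + sum_i log P(y_ij | theta_j) ],
    with phi = (mu, tau).
    """
    if tau <= 0:
        return -np.inf  # tau is a standard deviation and must be positive
    # Hyperprior: mu ~ N(0, mu_sd^2), tau ~ HalfNormal(tau_sd) (up to a constant)
    log_p = normal_logpdf(mu, 0.0, mu_sd) + normal_logpdf(tau, 0.0, tau_sd)
    for theta_j, y_j in zip(theta, y_groups):
        log_p += normal_logpdf(theta_j, mu, tau)                        # group level
        log_p += normal_logpdf(np.asarray(y_j), theta_j, sigma_y).sum() # observations
    return log_p

# Three groups with unequal sample sizes (made-up numbers)
y_groups = [np.array([1.2, 0.8, 1.1]), np.array([2.5]), np.array([0.1, -0.3])]
theta = np.array([1.0, 2.0, 0.0])
print(log_joint(theta, mu=1.0, tau=1.0, y_groups=y_groups))
```

An MCMC sampler only needs this quantity up to an additive constant, which is why the normalizing constant of the posterior never has to be computed.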

The Classic Example: Eight Schools

The famous "eight schools" dataset (Rubin, 1981) records the estimated effect of a test preparation program at 8 different schools, together with the standard error of each estimate. The key challenge: the estimates for individual schools are noisy (small sample sizes), but schools may genuinely differ.

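| School | Effect ($\hat{\theta}_j$) | Std Error ($\sigma_j$) |
| --- | --- | --- |
| A | 28 | 15 |
| B | 8 | 10 |
| C | -3 | 16 |
| D | 7 | 11 |
| E | -1 | 9 |
| F | 1 | 11 |
| G | 18 | 10 |
| H | 12 | 18 |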

  • No pooling: School A has effect 28 - but with SE=15, this is very uncertain.
  • Complete pooling: All schools have the same effect - ignores real variation.
  • Hierarchical: Each school effect is drawn from $\mathcal{N}(\mu, \tau^2)$ - learns both individual effects and their variability.
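
Before turning to MCMC, the two extremes can be computed by hand. This is a small sketch that assumes flat priors (an assumption made here for the closed form; the PyMC models in this section use wide normal priors, which should give nearly the same answers): under no pooling each school's estimate is simply its observed effect, and under complete pooling the common effect is the precision-weighted mean.

```python
import numpy as np

effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma_j = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)

# No pooling (flat prior): posterior mean = observed effect, posterior sd = sigma_j
no_pool_mean, no_pool_sd = effects, sigma_j

# Complete pooling (flat prior): precision-weighted mean of all eight effects
precision = 1.0 / sigma_j**2
pooled_mean = np.sum(precision * effects) / np.sum(precision)
pooled_sd = np.sqrt(1.0 / np.sum(precision))

print(f"No pooling, School A: {no_pool_mean[0]:.2f} +/- {no_pool_sd[0]:.2f}")
print(f"Complete pooling:     {pooled_mean:.2f} +/- {pooled_sd:.2f}")
```

The pooled estimate comes out near 7.7 with a standard error of about 4.1, which is the "global mean" that the hierarchical estimates are pulled toward.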

import pymc as pm
import arviz as az
import numpy as np

# Eight Schools data
schools = ["A", "B", "C", "D", "E", "F", "G", "H"]
effects = np.array([28, 8, -3, 7, -1, 1, 18, 12], dtype=float)
sigma_j = np.array([15, 10, 16, 11, 9, 11, 10, 18], dtype=float)
J = len(schools)

# ============================================================
# Model 1: No Pooling -- independent estimates per school
# ============================================================
with pm.Model() as no_pool_model:
    theta = pm.Normal('theta', mu=0, sigma=100, shape=J)
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)
    no_pool_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                              progressbar=False, random_seed=42)

# ============================================================
# Model 2: Complete Pooling -- single effect for all schools
# ============================================================
with pm.Model() as pool_model:
    theta = pm.Normal('theta', mu=0, sigma=100)
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)
    pool_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                           progressbar=False, random_seed=42)

# ============================================================
# Model 3: Hierarchical (Partial Pooling) -- the correct approach
# Non-centered parameterization to avoid funnel geometry
# ============================================================
with pm.Model() as hierarchical_model:
    # Hyperpriors (global distribution of school effects)
    mu = pm.Normal('mu', mu=0, sigma=10)   # global mean effect
    tau = pm.HalfNormal('tau', sigma=10)   # between-school variability

    # Non-centered parameterization: theta_j = mu + tau * eta_j
    # eta_j ~ N(0, 1) avoids the funnel problem in HMC
    eta = pm.Normal('eta', mu=0, sigma=1, shape=J)
    theta = pm.Deterministic('theta', mu + tau * eta)

    # Likelihood
    y_obs = pm.Normal('y_obs', mu=theta, sigma=sigma_j, observed=effects)

    hier_trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                           progressbar=False, random_seed=42)

# Compare posteriors for school A (most extreme estimate)
print("Posterior mean for School A effect:")
print(f" No pooling: {az.summary(no_pool_trace, var_names=['theta']).loc['theta[0]', 'mean']:.2f}")
print(f" Complete pool: {az.summary(pool_trace, var_names=['theta']).loc['theta', 'mean']:.2f}")
print(f" Hierarchical: {az.summary(hier_trace, var_names=['theta']).loc['theta[0]', 'mean']:.2f}")
print(" Observed: 28.00 (SE=15)")
print()
print("The hierarchical estimate is shrunk toward the global mean (~8).")
print("This is partial pooling / Bayesian shrinkage.")

The hierarchical model's estimate for School A (roughly 10-12; the exact value depends on the priors and the sampler seed) is shrunk toward the global mean (≈ 7-8), because the observed effect of 28 with SE=15 is uncertain. Schools with smaller standard errors (more data) are shrunk less.

Understanding Partial Pooling: The Shrinkage Formula

For the Gaussian hierarchical model, the posterior mean of $\theta_j$, conditional on $\mu$ and $\tau$, has an elegant closed form:

$$\hat{\theta}_j^{(hier)} = \frac{\frac{1}{\sigma_j^2} y_j + \frac{1}{\tau^2} \mu}{\frac{1}{\sigma_j^2} + \frac{1}{\tau^2}}$$

This is a precision-weighted average of the observed school estimate $y_j$ and the global mean $\mu$:

$$\hat{\theta}_j^{(hier)} = \lambda_j \cdot y_j + (1 - \lambda_j) \cdot \mu, \quad \lambda_j = \frac{1/\sigma_j^2}{1/\sigma_j^2 + 1/\tau^2}$$

The weight $\lambda_j \in [0, 1]$ is the shrinkage factor. It is the weight on the school's own data, so the closer it is to 0, the more the estimate is pulled toward $\mu$:

  • If $\sigma_j^2 \ll \tau^2$ (the school's estimate is much more precise than the between-school variation): $\lambda_j \to 1$ - use the school's own estimate
  • If $\sigma_j^2 \gg \tau^2$ (the school's estimate is noisier than the between-school variation): $\lambda_j \to 0$ - use the global mean
  • If $\tau^2 = 0$ (no between-school variation): all schools get the complete pooling estimate
  • If $\tau^2 \to \infty$ (unbounded between-school variation): no pooling, each school is estimated independently
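
To see where the closed form comes from, here is a short derivation sketch, conditional on $\mu$ and $\tau$ and using the eight schools likelihood $y_j \mid \theta_j \sim \mathcal{N}(\theta_j, \sigma_j^2)$. The values plugged into the worked example ($\tau = 7$ and $\mu = 8.75$, the unweighted mean of the eight observed effects) are illustrative choices, not posterior estimates.

$$
\begin{aligned}
p(\theta_j \mid y_j, \mu, \tau) &\propto \exp\left(-\frac{(y_j - \theta_j)^2}{2\sigma_j^2}\right) \exp\left(-\frac{(\theta_j - \mu)^2}{2\tau^2}\right) \\
&\propto \exp\left(-\frac{1}{2}\left(\frac{1}{\sigma_j^2} + \frac{1}{\tau^2}\right)\left(\theta_j - \hat{\theta}_j^{(hier)}\right)^2\right)
\end{aligned}
$$

Completing the square in $\theta_j$ gives a normal distribution whose precision is the sum of the two precisions and whose mean is the precision-weighted average:

$$\theta_j \mid y_j, \mu, \tau \sim \mathcal{N}\left(\hat{\theta}_j^{(hier)}, \left(\frac{1}{\sigma_j^2} + \frac{1}{\tau^2}\right)^{-1}\right)$$

For School A ($y_A = 28$, $\sigma_A = 15$):

$$\lambda_A = \frac{1/15^2}{1/15^2 + 1/7^2} = \frac{49}{49 + 225} \approx 0.179, \qquad \hat{\theta}_A^{(hier)} \approx 0.179 \cdot 28 + 0.821 \cdot 8.75 \approx 12.2$$
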

def shrinkage_factor(sigma_j, tau):
    """Compute the shrinkage factor lambda_j for each group."""
    return (1/sigma_j**2) / (1/sigma_j**2 + 1/tau**2)

# Illustrate how shrinkage varies with each school's standard error
# (reuses `schools` and `effects` from the Eight Schools block above)
sigma_values = np.array([15, 10, 16, 11, 9, 11, 10, 18])  # eight schools
tau_estimate = 7.0  # typical between-school std from posterior

lambdas = shrinkage_factor(sigma_values, tau_estimate)
# The unweighted mean of the observed effects stands in for mu here
hier_means = lambdas * effects + (1 - lambdas) * np.mean(effects)

print("School | Observed | Sigma | Lambda | Hierarchical")
print("-" * 55)
for s, obs, sig, lam, hier in zip(schools, effects, sigma_values, lambdas, hier_means):
    print(f" {s} | {obs:5.1f} | {sig:4.0f} | {lam:.3f} | {hier:6.2f}")

print("\nSchools with higher sigma (less data) are shrunk more toward the mean.")
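
The limiting cases listed earlier can be checked numerically. One caveat: the `1/tau**2` form divides by zero at $\tau = 0$. The sketch below uses the algebraically equivalent form $\lambda_j = \tau^2 / (\tau^2 + \sigma_j^2)$, which is well defined there; the function name `shrinkage_weight` and the chosen values are illustrative.

```python
import numpy as np

def shrinkage_weight(sigma, tau):
    """lambda = tau^2 / (tau^2 + sigma^2).

    Equal to (1/sigma^2) / (1/sigma^2 + 1/tau^2) for tau > 0,
    and equal to 0 (complete pooling) at tau = 0.
    """
    sigma = np.asarray(sigma, dtype=float)
    tau = np.asarray(tau, dtype=float)
    return tau**2 / (tau**2 + sigma**2)

sigma = np.array([9.0, 15.0, 18.0])  # three of the eight schools' standard errors
for tau in [0.0, 1.0, 7.0, 100.0, 1e6]:
    print(f"tau={tau:>9.1f}  lambda={np.round(shrinkage_weight(sigma, tau), 3)}")
```

As $\tau$ grows, every weight approaches 1 (no pooling); at $\tau = 0$ every weight is exactly 0 (complete pooling), and in between the noisiest school always gets the smallest weight.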

Hierarchical Linear Regression

Hierarchical models extend naturally to regression. Instead of a single regression line, each group gets its own line, with lines drawn from a shared distribution.

$$y_{ij} = \alpha_j + \beta_j x_{ij} + \epsilon_{ij}$$

$$(\alpha_j, \beta_j) \sim \mathcal{N}((\mu_\alpha, \mu_\beta), \boldsymbol{\Sigma})$$

$$(\mu_\alpha, \mu_\beta, \boldsymbol{\Sigma}) \sim \text{hyperpriors}$$

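This is multilevel regression - a common building block in recommendation and personalization systems. The example below adapts it to binary purchase data with a logistic link and, for simplicity, gives $\alpha_j$ and $\beta_j$ independent priors (a diagonal $\boldsymbol{\Sigma}$).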

import pymc as pm
import arviz as az
import numpy as np

# Simulate: users in 10 cities, with city-specific price sensitivity
np.random.seed(42)
n_cities = 10
users_per_city = np.random.randint(20, 200, n_cities)  # sparse for some cities
true_global_alpha = 5.0
true_global_beta = -0.8  # price sensitivity
true_tau_alpha = 1.5
true_tau_beta = 0.3

# Generate city-specific parameters from shared distribution
true_alphas = true_global_alpha + np.random.normal(0, true_tau_alpha, n_cities)
true_betas = true_global_beta + np.random.normal(0, true_tau_beta, n_cities)

# Generate data
cities, prices, purchases = [], [], []
for j in range(n_cities):
    n_j = users_per_city[j]
    price_j = np.random.uniform(10, 100, n_j)
    p_purchase = 1 / (1 + np.exp(-(true_alphas[j] + true_betas[j] * price_j / 10)))
    y_j = np.random.binomial(1, p_purchase, n_j)
    cities.extend([j] * n_j)
    prices.extend(price_j / 10)  # scale prices
    purchases.extend(y_j)

cities = np.array(cities)
prices = np.array(prices)
purchases = np.array(purchases)

# Hierarchical logistic regression
with pm.Model() as hierarchical_logistic:
    # Hyperpriors (global distribution of city effects)
    mu_alpha = pm.Normal('mu_alpha', mu=0, sigma=5)
    mu_beta = pm.Normal('mu_beta', mu=0, sigma=2)
    sigma_alpha = pm.HalfNormal('sigma_alpha', sigma=2)
    sigma_beta = pm.HalfNormal('sigma_beta', sigma=1)

    # City-level parameters (non-centered)
    alpha_raw = pm.Normal('alpha_raw', mu=0, sigma=1, shape=n_cities)
    beta_raw = pm.Normal('beta_raw', mu=0, sigma=1, shape=n_cities)

    alpha = pm.Deterministic('alpha', mu_alpha + sigma_alpha * alpha_raw)
    beta = pm.Deterministic('beta', mu_beta + sigma_beta * beta_raw)

    # Likelihood: city index selects each user's intercept and slope
    logit_p = alpha[cities] + beta[cities] * prices
    y_obs = pm.Bernoulli('y_obs', logit_p=logit_p, observed=purchases)

    trace = pm.sample(2000, tune=1000, return_inferencedata=True,
                      progressbar=False, random_seed=42)

summary = az.summary(trace, var_names=['mu_alpha', 'mu_beta',
                                       'sigma_alpha', 'sigma_beta'])
print(summary)
print(f"\nTrue global alpha: {true_global_alpha:.2f}, True global beta: {true_global_beta:.2f}")
print(f"Learned mu_alpha: {summary.loc['mu_alpha', 'mean']:.2f}")
print(f"Learned mu_beta: {summary.loc['mu_beta', 'mean']:.2f}")
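
One payoff of the fitted hyperparameters is a prediction for a city with no data at all. The sketch below shows the idea with NumPy only: for each posterior draw of the hyperparameters, draw a new city's intercept and slope from the group-level distribution and push them through the logistic link. The function `predict_new_city` is a hypothetical helper, and the "posterior draws" are made-up stand-ins so the example runs on its own.

```python
import numpy as np

def predict_new_city(mu_alpha, mu_beta, sigma_alpha, sigma_beta, price, seed=0):
    """Purchase probability draws for an unseen city at a given (scaled) price.

    All arguments except `price` and `seed` are 1-D arrays of posterior draws.
    """
    rng = np.random.default_rng(seed)
    alpha_new = rng.normal(mu_alpha, sigma_alpha)  # one new-city intercept per draw
    beta_new = rng.normal(mu_beta, sigma_beta)     # one new-city slope per draw
    return 1.0 / (1.0 + np.exp(-(alpha_new + beta_new * price)))

# Stand-in draws (hypothetical values, for illustration only). With a fitted
# InferenceData object, the draws could instead be taken from the posterior
# group, e.g. trace.posterior["mu_alpha"].values.ravel()
rng = np.random.default_rng(1)
n_draws = 4000
draws = dict(
    mu_alpha=rng.normal(5.0, 0.5, n_draws),
    mu_beta=rng.normal(-0.8, 0.1, n_draws),
    sigma_alpha=np.abs(rng.normal(1.5, 0.3, n_draws)),
    sigma_beta=np.abs(rng.normal(0.3, 0.1, n_draws)),
)

for price in [2.0, 5.0, 8.0]:  # prices on the same /10 scale as the model
    p = predict_new_city(price=price, **draws)
    lo, hi = np.percentile(p, [5, 95])
    print(f"price={price:.1f}: mean p={p.mean():.2f}, 90% interval=({lo:.2f}, {hi:.2f})")
```

The resulting intervals are wide because they include between-city variation, which is the honest answer for a city that has not been observed yet.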

Connection to Multi-Task Learning

Hierarchical models are the Bayesian perspective on multi-task learning (MTL) - learning multiple related tasks simultaneously, sharing information across them.

| Bayesian Hierarchical | Multi-Task Learning Equivalent |
| --- | --- |
| Group $j$ | Task $j$ |
| Group-level parameters $\theta_j$ | Task-specific model parameters |
| Shared hyperparameters $\phi$ | Shared representation / meta-parameters |
| Partial pooling | Soft parameter sharing across tasks |
| Prior $\theta_j \sim P(\theta \mid \phi)$ | $L_2$ regularization toward shared parameters |
| Hierarchical GP | Gaussian process multi-task learning |

Hard parameter sharing (most common MTL in DL): All tasks share the same hidden layers, only the output heads differ.

Soft parameter sharing (Bayesian MTL): Each task has its own parameters, but they're regularized to be similar. Under MAP estimation, this is equivalent to a hierarchical model with a Gaussian prior tying the task parameters to a shared mean.

import torch
import torch.nn as nn

class MultiTaskNetwork(nn.Module):
    """
    Multi-task neural network with a shared backbone and task-specific heads
    (hard sharing of the hidden layers, separate output layers per task).
    Loose Bayesian analogy: the backbone holds the structure shared by all
    tasks, and each head holds group-specific parameters theta_j.
    """
    def __init__(self, input_dim, hidden_dim, n_tasks, output_dim=1):
        super().__init__()
        self.n_tasks = n_tasks

        # Shared backbone: captures shared structure across tasks
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Task-specific heads (each task has its own output layer)
        self.task_heads = nn.ModuleList([
            nn.Linear(hidden_dim, output_dim)
            for _ in range(n_tasks)
        ])

    def forward(self, x, task_id):
        """Forward pass for a specific task."""
        shared_repr = self.shared(x)
        return self.task_heads[task_id](shared_repr)

# Bayesian interpretation: the shared backbone plays the role of phi (shared parameters)
# The task heads capture theta_j (group-specific parameters)
# L2 weight decay on the heads = zero-mean Gaussian prior on head weights.
# Note: this shrinks each head toward zero, not toward the other heads; soft
# sharing in the strict sense would penalize distance to a shared set of head weights.
model = MultiTaskNetwork(input_dim=10, hidden_dim=64, n_tasks=10)
# Train with stronger L2 regularization on the task heads:
optimizer = torch.optim.Adam([
    {'params': model.shared.parameters(), 'weight_decay': 1e-4},
    {'params': model.task_heads.parameters(), 'weight_decay': 1e-2}  # stronger regularization
], lr=1e-3)
print("Multi-task model created with a shared backbone and task-specific heads")
print("Stronger L2 on task heads = stronger shrinkage of task-specific weights toward zero")

When Hierarchical Models Help Most

| Scenario | Why Hierarchical Helps |
| --- | --- |
| New user/item in recommendation | Starts from the global distribution instead of a cold start |
| Rare event modeling (city-level fraud) | Small groups borrow strength from common distribution |
| Clinical trials across sites | Sites vary; partial pooling gives more accurate site estimates |
| NLP: language models for rare languages | Multilingual models share cross-lingual representations |
| A/B testing across heterogeneous segments | Segments with little data borrow from global effect |
| Personalized pricing | City/user-level models regularized toward global |

Interview Questions

Q1: What is partial pooling and why is it better than no pooling or complete pooling?

Partial pooling is the hierarchical Bayesian approach: each group has its own parameter $\theta_j$, but these parameters are assumed to come from a common distribution $\theta_j \sim P(\theta \mid \phi)$. The result is shrinkage: each group's estimate is pulled toward the global mean, and the pull is stronger the noisier the group's data is. No pooling (separate estimates per group) overfits sparse groups - a group with 3 observations will have a terrible estimate. Complete pooling (one estimate for all groups) underfits - it ignores real between-group variation. Partial pooling adaptively balances these extremes: groups with many observations rely mainly on their own data; groups with few observations borrow heavily from the global estimate. In recommendation, this is a principled answer to the cold-start problem: new users/items start from the global prior, which is updated as soon as data accumulates.

Q2: What is the non-centered parameterization and why is it important for MCMC?

In a hierarchical model, the centered parameterization writes $\theta_j \sim \mathcal{N}(\mu, \tau^2)$. When the data only weakly constrain each group, the posterior forms a "funnel": near $\tau \approx 0$, the $\theta_j$ are tightly clustered around $\mu$, while for large $\tau$ they spread out, creating a highly curved posterior that HMC struggles to sample with a single step size. The non-centered parameterization avoids this: write $\theta_j = \mu + \tau \eta_j$ where $\eta_j \sim \mathcal{N}(0, 1)$. Now $\eta_j$ and $\tau$ are independent a priori and nearly independent in the posterior, so the funnel geometry largely disappears. HMC explores efficiently across all values of $\tau$, including near zero. This is one of the most important practical tricks for fitting hierarchical models with MCMC (when every group has a lot of data, the centered form can sample better). In PyMC, the standard implementation is pm.Normal('eta', 0, 1, shape=J) followed by pm.Deterministic('theta', mu + tau * eta).
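
The dependence that creates the funnel is already visible in prior draws, without running any sampler. The sketch below is a simplified illustration (prior only, so it ignores how data reshape the posterior); the HalfNormal(10) scale and the variable names are assumptions chosen to match the eight schools model.

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, mu = 20_000, 0.0

tau = np.abs(rng.normal(0.0, 10.0, n_draws))  # HalfNormal(10) draws
eta = rng.normal(0.0, 1.0, n_draws)           # non-centered variable, independent of tau
theta = mu + tau * eta                        # implied centered variable

# Centered: the spread of theta scales with tau (the funnel).
corr_centered = np.corrcoef(tau, np.abs(theta - mu))[0, 1]
# Non-centered: the spread of eta does not depend on tau.
corr_noncentered = np.corrcoef(tau, np.abs(eta))[0, 1]

print(f"corr(tau, |theta - mu|) = {corr_centered:.2f}")
print(f"corr(tau, |eta|)        = {corr_noncentered:.2f}")
```

In the centered coordinates the scale of `theta` is tied to `tau`, so no single step size suits both the narrow neck and the wide mouth of the funnel; in the non-centered coordinates the sampler sees variables on a fixed unit scale.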

Q3: How does the hierarchical model connect to regularization in standard ML?

The connection is exact for MAP estimation with $\mu$ and $\tau$ held fixed. In a hierarchical model, the prior $\theta_j \sim \mathcal{N}(\mu, \tau^2 I)$ induces shrinkage of each $\theta_j$ toward $\mu$. If we compute the MAP estimate, the objective to minimize becomes: $\sum_j [-\text{log-likelihood for group } j] + \sum_j \frac{1}{2\tau^2}\|\theta_j - \mu\|^2$. The second term is exactly L2 regularization of each group's parameters toward $\mu$. In multi-task neural networks, this corresponds to penalizing each task's parameters for deviating from a shared set of parameters. The regularization strength ($1/\tau^2$) can be learned from data, unlike typical ML where $\lambda$ is chosen by cross-validation: estimating $\tau$ by maximizing the marginal likelihood is empirical Bayes, while a fully Bayesian fit places a hyperprior on $\tau$ and integrates over it.
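
For linear-Gaussian tasks this objective can be minimized with closed-form updates. The following is a minimal sketch, assuming unit observation noise (so `lam` plays the role of $1/\tau^2$) and estimating $\mu$ as the average of the task parameters; the function name `fit_soft_shared` and the simulated data are illustrative.

```python
import numpy as np

def fit_soft_shared(X_list, y_list, lam, n_iter=200):
    """Alternating minimization of
        sum_j 0.5 * ||y_j - X_j theta_j||^2 + 0.5 * lam * sum_j ||theta_j - mu||^2
    over the task parameters theta_j and the shared center mu."""
    d = X_list[0].shape[1]
    mu = np.zeros(d)
    thetas = np.zeros((len(X_list), d))
    for _ in range(n_iter):
        for j, (X, y) in enumerate(zip(X_list, y_list)):
            # Ridge regression for task j, centered at mu instead of zero
            A = X.T @ X + lam * np.eye(d)
            thetas[j] = np.linalg.solve(A, X.T @ y + lam * mu)
        mu = thetas.mean(axis=0)  # minimizer of the penalty term over mu
    return thetas, mu

# Four related tasks with very different sample sizes (simulated)
rng = np.random.default_rng(0)
true_mu = np.array([1.0, -2.0])
X_list, y_list = [], []
for n_j in [5, 10, 40, 80]:
    theta_j = true_mu + rng.normal(0, 0.5, size=2)
    X = rng.normal(size=(n_j, 2))
    X_list.append(X)
    y_list.append(X @ theta_j + rng.normal(size=n_j))

for lam in [0.0, 1.0, 100.0]:
    thetas, mu = fit_soft_shared(X_list, y_list, lam)
    spread = np.linalg.norm(thetas - thetas.mean(axis=0))
    print(f"lam={lam:6.1f}  spread of task parameters = {spread:.3f}")
```

With `lam = 0` each task reduces to ordinary least squares (no pooling); as `lam` grows the task parameters collapse toward a common value (complete pooling), and the small tasks move the most.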

Q4: What is shrinkage, and when should you shrink toward zero vs. shrink toward a global mean?

Shrinkage is the pull of Bayesian estimation toward the prior mean. Shrink toward zero when you believe parameters should be small in magnitude (ridge/L2 or lasso/L1 regularization: prior centered at zero). Shrink toward a global mean when you believe groups share a common baseline but differ around it (hierarchical model: prior centered at the learned global mean $\mu$). The key difference: with hierarchical shrinkage, the center is estimated from the data. If all cities show positive price elasticity, the prior mean will be positive, and city-specific estimates are shrunk toward that positive value - not toward zero. This is data-adaptive regularization, and it often works better than fixed L2 toward zero when you have multiple related groups. James and Stein (1961) showed formally that, for three or more groups, a shrinkage estimator has strictly lower total MSE than estimating each mean separately by maximum likelihood (shrinking toward the common mean requires four or more groups) - this is one reason hierarchical modeling improves accuracy.
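
The James-Stein result is easy to check by simulation. The sketch below uses the positive-part variant that shrinks toward the grand mean, with a known noise standard deviation; the number of groups, the true means, and the number of replications are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
J, sigma, n_reps = 20, 1.0, 2000
true_means = rng.normal(5.0, 1.0, size=J)  # fixed "true" group means

mle_sse, js_sse = 0.0, 0.0
for _ in range(n_reps):
    y = true_means + rng.normal(0.0, sigma, size=J)  # one noisy estimate per group
    grand = y.mean()
    s = np.sum((y - grand) ** 2)
    # Positive-part James-Stein factor for shrinking toward the grand mean
    shrink = max(0.0, 1.0 - (J - 3) * sigma**2 / s)
    js = grand + shrink * (y - grand)
    mle_sse += np.sum((y - true_means) ** 2)
    js_sse += np.sum((js - true_means) ** 2)

print(f"Separate estimates, total squared error: {mle_sse / n_reps:.2f}")
print(f"James-Stein,        total squared error: {js_sse / n_reps:.2f}")
```

The guarantee is about total error summed over groups: an individual group with an unusual true mean can do worse under shrinkage, which is the usual trade-off of partial pooling.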

Q5: How would you implement a hierarchical model for multi-armed bandits (many arms, sparse data per arm)?

In a Thompson sampling bandit, maintain Beta posteriors $\text{Beta}(\alpha_k, \beta_k)$ per arm $k$. The hierarchical extension: assume each arm's true CTR $\theta_k \sim \text{Beta}(\alpha_0, \beta_0)$, where $(\alpha_0, \beta_0)$ are learned from all arms. Concretely: (1) maintain hyperparameters $(\alpha_0, \beta_0)$ representing the global distribution of CTRs; (2) for each arm, the posterior is $\text{Beta}(\alpha_0 + h_k, \beta_0 + t_k)$ where $h_k, t_k$ are the arm-specific successes and failures; (3) update $(\alpha_0, \beta_0)$ using empirical Bayes (maximize the marginal likelihood) or with a hyperprior. New arms start with $\text{Beta}(\alpha_0, \beta_0)$ - borrowing from global experience rather than an uninformative prior. This gives better initialization for new arms and typically faster convergence to their true CTR. The pattern suits recommendation platforms where new content items are added continuously.
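
The three steps can be sketched in a few lines of NumPy. Two simplifications are assumed here: the class `HierarchicalBetaBandit` is a hypothetical illustration rather than a reference implementation, and the prior is refit with a method-of-moments match on the observed per-arm rates instead of maximizing the marginal likelihood (simpler, but it ignores sampling noise in the rates and so tends to understate the prior strength).

```python
import numpy as np

class HierarchicalBetaBandit:
    def __init__(self, n_arms, alpha0=1.0, beta0=1.0, seed=0):
        self.alpha0, self.beta0 = alpha0, beta0  # shared prior Beta(alpha0, beta0)
        self.successes = np.zeros(n_arms)        # h_k
        self.failures = np.zeros(n_arms)         # t_k
        self.rng = np.random.default_rng(seed)

    def select_arm(self):
        # Thompson sampling: one draw from each arm's posterior
        # Beta(alpha0 + h_k, beta0 + t_k), then play the largest draw
        samples = self.rng.beta(self.alpha0 + self.successes,
                                self.beta0 + self.failures)
        return int(np.argmax(samples))

    def update(self, arm, reward):
        if reward:
            self.successes[arm] += 1
        else:
            self.failures[arm] += 1

    def refit_prior(self, min_pulls=20):
        """Method-of-moments refit of (alpha0, beta0) from well-sampled arms."""
        pulls = self.successes + self.failures
        mask = pulls >= min_pulls
        if mask.sum() < 2:
            return  # not enough arms to estimate a spread
        rates = self.successes[mask] / pulls[mask]
        m, v = rates.mean(), rates.var(ddof=1)
        if v <= 0 or v >= m * (1 - m):
            return  # moments do not correspond to a valid Beta; keep old prior
        strength = m * (1 - m) / v - 1.0         # implied alpha0 + beta0
        self.alpha0, self.beta0 = m * strength, (1 - m) * strength

    def add_arm(self):
        # A new arm has no data, so its posterior equals the shared prior
        self.successes = np.append(self.successes, 0.0)
        self.failures = np.append(self.failures, 0.0)

# Simulated environment: 30 arms with low click-through rates (made-up values)
true_ctr = np.random.default_rng(1).beta(2, 20, size=30)
env = np.random.default_rng(2)
bandit = HierarchicalBetaBandit(n_arms=30)
for t in range(5000):
    arm = bandit.select_arm()
    bandit.update(arm, env.random() < true_ctr[arm])
    if (t + 1) % 500 == 0:
        bandit.refit_prior()

bandit.add_arm()  # cold-start arm inherits Beta(alpha0, beta0)
print(f"Learned prior: Beta({bandit.alpha0:.2f}, {bandit.beta0:.2f})")
```

Because Thompson sampling concentrates pulls on the best arms, the arms that pass `min_pulls` are a biased sample of all arms; a production system would need to account for that when fitting the shared prior.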


© 2026 EngineersOfAI. All rights reserved.