Router Mechanisms - How Tokens Get Assigned to Experts
The Traffic Controller Problem
Imagine running an airport. Thousands of flights arrive every day. Each flight needs to be assigned to a gate. If you assign all flights to gate 1, gate 1 is overwhelmed and every other gate sits empty. If you assign randomly, you get chaos with no optimization. What you want is intelligent routing: planes going to terminal B get assigned to gates in terminal B, and the load is roughly balanced across all gates.
The routing problem in MoE models is structurally identical. Thousands of tokens arrive per second. Each token needs to be assigned to one or more of N experts. If your router sends all tokens to expert 1, that expert handles everything (and is effectively a single dense model - you've wasted the other N-1 experts). If you route randomly, you lose all the benefits of expert specialization.
The router is the most critical component in a MoE architecture. A good router enables specialization and load balancing simultaneously. A bad router causes expert collapse, training instability, and wasted capacity.
This lesson covers the algorithms researchers have developed to solve the routing problem.
Why This Exists - The Difficulties of Routing
At first glance, routing seems straightforward: train a linear layer to assign tokens to experts, and the model will learn good assignments. In practice, several problems arise:
Expert collapse: a small subset of experts receives the vast majority of tokens. The other experts receive so few tokens that they never learn meaningful representations. The model degenerates toward a small dense model.
Discrete selection bottleneck: the top-k selection operation is discrete (non-differentiable). Gradients can only flow through the routing weights (the softmax scores), not through the routing decisions (which experts were selected). This creates a challenging optimization landscape.
Load imbalance amplification: early in training, when router weights are random, some experts happen to receive slightly more tokens. Those experts update more, becoming slightly better. Better experts receive higher router scores, receiving even more tokens. The feedback loop amplifies initial random noise into severe imbalance.
Communication bottleneck: in distributed training with expert parallelism, routing tokens to different devices creates communication overhead. Imbalanced routing means some devices are overloaded while others are idle.
These problems motivated the development of the routing mechanisms described in this lesson.
The Linear Router - Standard Token Choice
The most common routing approach: a single linear transformation that scores all experts, followed by top-k selection.
The Basic Formulation
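For a token representation $x \in \mathbb{R}^{d}$, the router computes one score per expert:

$$s = \text{softmax}(W_r x + b_r)$$

where $W_r \in \mathbb{R}^{N \times d}$ and $b_r \in \mathbb{R}^{N}$ (the router weight matrix and bias).
The top-k experts are selected, and their scores are renormalized:

$$g_i = \frac{s_i}{\sum_{j \in \text{TopK}(s)} s_j} \quad \text{for } i \in \text{TopK}(s), \qquad g_i = 0 \text{ otherwise}$$

The MoE output is then:

$$y = \sum_{i \in \text{TopK}(s)} g_i \, E_i(x)$$

where $E_i$ denotes the $i$-th expert network.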
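As a hand-checkable illustration of scoring, top-k selection, and renormalization, here is a minimal sketch in plain Python. The logits and the scalar "experts" are hypothetical stand-ins (real experts are feed-forward networks), and the helper names `route_token` and `moe_output` are made up for this example.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_token(logits, k):
    """Return {expert_index: gate_weight} for the top-k experts of one token."""
    scores = softmax(logits)
    # Indices of the k largest scores
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    # Renormalize so the selected gate weights sum to 1
    norm = sum(scores[i] for i in top)
    return {i: scores[i] / norm for i in top}

def moe_output(x, logits, experts, k):
    """Gate-weighted sum of the selected experts' outputs."""
    gates = route_token(logits, k)
    return sum(g * experts[i](x) for i, g in gates.items())

# Hypothetical router logits for one token and N = 4 experts
logits = [2.0, 0.5, 1.0, -1.0]
# Stand-in experts: scalar functions instead of feed-forward networks
experts = [lambda x: 1.0 * x, lambda x: 2.0 * x, lambda x: 3.0 * x, lambda x: 4.0 * x]

gates = route_token(logits, k=2)
y = moe_output(10.0, logits, experts, k=2)
print(gates, y)
```

Experts 0 and 2 have the two largest logits, so only they contribute, and the output is a convex combination of their two outputs.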
The Load Balancing Problem
Without additional mechanisms, linear routers are prone to catastrophic collapse. To see why, consider the beginning of training, when all router weights are random. By random chance, expert 3 gets slightly more tokens than expert 4. Expert 3 updates more per batch, becoming slightly better at its tasks. The router learns to assign slightly higher scores to expert 3. This creates a feedback loop:
More tokens → more updates → higher scores → even more tokens → complete collapse
After convergence, you might have expert 3 handling 80% of all tokens and experts 1, 2, 4–8 handling 2–3% each. This is expert collapse.
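The rich-get-richer dynamic can be caricatured in a few lines. The sketch below is a toy model, not real training: it assumes an expert's router score simply grows in proportion to the share of tokens it currently receives. The function name `simulate_feedback` and the parameters `gain` and `head_start` are invented for illustration.

```python
import math

def softmax(values):
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def simulate_feedback(n_experts=8, steps=200, gain=1.0, head_start=0.01):
    """Toy rich-get-richer loop: an expert's score grows with its token share.

    Returns the per-expert token fractions at the start and at the end.
    """
    scores = [0.0] * n_experts
    scores[2] = head_start  # one expert starts with a tiny advantage
    initial = softmax(scores)
    for _ in range(steps):
        fractions = softmax(scores)           # share of tokens each expert receives
        for i in range(n_experts):
            scores[i] += gain * fractions[i]  # more tokens -> more updates -> higher score
    return initial, softmax(scores)

initial, final = simulate_feedback()
print([round(f, 3) for f in initial])
print([round(f, 3) for f in final])
```

Under these assumptions a 0.01 head start in one score is enough for that expert to absorb almost all tokens, because the gap between its score and the others' grows at every step.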
The standard countermeasure is an auxiliary load balancing loss added to the training objective. The version below follows the Switch Transformer formulation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def load_balance_loss_switch(
    router_logits: torch.Tensor,  # [T, N] - T tokens, N experts
    top_k_indices: torch.Tensor,  # [T, k] - selected expert indices (int64)
    n_experts: int,
    alpha: float = 0.01,
) -> torch.Tensor:
    """
    Switch Transformer load balancing loss (Fedus et al., 2022).

    Penalizes the inner product of:
    - f_i: fraction of tokens dispatched to expert i (actual load)
    - p_i: average softmax probability assigned to expert i
           (a differentiable proxy for the load)

    The loss is minimized under uniform routing, where all experts
    receive equal fractions of tokens.
    The gradient flows through p_i (the softmax probabilities),
    not f_i (the discrete counts).

    Args:
        router_logits: Raw router outputs before softmax
        top_k_indices: Expert indices selected for each token
        n_experts: Total number of experts
        alpha: Loss scale factor

    Returns:
        Scalar auxiliary load balancing loss
    """
    T = router_logits.shape[0]

    # Routing probabilities (differentiable)
    router_probs = F.softmax(router_logits, dim=-1)  # [T, N]
    p = router_probs.mean(dim=0)  # [N] - mean probability per expert

    # Token fraction per expert (non-differentiable counts).
    # scatter_ writes a 1 at every selected expert index in each row.
    one_hot = torch.zeros(T, n_experts, device=router_logits.device)
    one_hot.scatter_(1, top_k_indices, 1.0)
    # Note: with top_k > 1 a token counts toward several experts, so f sums to k
    f = one_hot.mean(dim=0)  # [N] - fraction of tokens per expert

    # Load balance loss: alpha * n_experts * sum(f_i * p_i)
    # Perfectly balanced top-1 routing: f_i = p_i = 1/n for all i, so
    # n * sum((1/n)^2) = n * n * (1/n)^2 = 1 and the loss equals alpha
    # (alpha * k for top-k routing, because then f_i = k/n).
    return alpha * n_experts * (f * p).sum()
```
Noisy Top-K Gating - Shazeer et al. (2017)
The original sparse MoE paper (Shazeer et al., 2017) introduced noisy top-k gating to encourage exploration during training and prevent early collapse.
The idea: add Gaussian noise to the router logits before top-k selection during training. This noise ensures that all experts have some probability of being selected even if their baseline scores are low, giving them a chance to learn and develop useful representations.
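$$H(x)_i = (x \cdot W_g)_i + \epsilon_i \cdot \text{softplus}\big((x \cdot W_{noise})_i\big)$$

where $\epsilon_i \sim \mathcal{N}(0, 1)$ and $W_{noise}$ is a learned weight matrix that controls per-token noise magnitude.
The noise standard deviation is set by a second learned transformation, $\text{softplus}(x \cdot W_{noise})$, allowing the model to learn appropriate noise levels for different tokens.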
```python
class NoisyTopKRouter(nn.Module):
    """
    Noisy Top-K router from Shazeer et al. (2017).

    Adds learned Gaussian noise to router logits during training,
    encouraging exploration and preventing expert collapse.
    """

    def __init__(self, d_model: int, n_experts: int, top_k: int = 2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Standard router weights
        self.w_router = nn.Linear(d_model, n_experts, bias=False)
        # Noise magnitude weights (learned)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        training: bool = True,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch * seq_len, d_model]
            training: Whether to add noise (only during training)

        Returns:
            routing_weights: [T, n_experts] normalized routing weights
            top_k_indices: [T, top_k] selected expert indices
        """
        # Compute base logits
        logits = self.w_router(x)  # [T, N]

        if training:
            # Add Gaussian noise scaled by learned magnitude
            noise_std = F.softplus(self.w_noise(x))  # [T, N], always positive
            noise = torch.randn_like(logits)
            logits = logits + noise * noise_std

        # Softmax over all experts
        scores = F.softmax(logits, dim=-1)  # [T, N]

        # Select top-k
        top_k_scores, top_k_indices = torch.topk(scores, self.top_k, dim=-1)

        # Normalize top-k scores to sum to 1
        top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8)

        # Build full routing weight matrix [T, N] with zeros for non-selected
        routing_weights = torch.zeros_like(scores)
        routing_weights.scatter_(1, top_k_indices, top_k_weights)

        return routing_weights, top_k_indices
```
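To see the effect of the noise term in isolation, the following sketch counts how often each expert lands in the top-k with and without noise. It uses plain Python and fixed, hypothetical logits rather than a trained router; `noisy_top_k` and `selection_counts` are names invented here.

```python
import math
import random

def softplus(v):
    return math.log1p(math.exp(v))

def noisy_top_k(logits, noise_logits, k, rng=None):
    """Top-k expert indices after adding N(0, 1) * softplus(noise_logit) noise.

    Pass rng=None to disable the noise (inference-style routing).
    """
    noisy = []
    for h, n in zip(logits, noise_logits):
        eps = rng.gauss(0.0, 1.0) if rng is not None else 0.0
        noisy.append(h + eps * softplus(n))
    return sorted(range(len(noisy)), key=lambda i: noisy[i], reverse=True)[:k]

def selection_counts(logits, noise_logits, k, trials, rng=None):
    """How many times each expert is selected over repeated routing of one token."""
    counts = [0] * len(logits)
    for _ in range(trials):
        for i in noisy_top_k(logits, noise_logits, k, rng):
            counts[i] += 1
    return counts

logits = [1.0, 0.9, 0.0, 0.0]        # hypothetical clean router logits
noise_logits = [0.5, 0.5, 0.5, 0.5]  # hypothetical outputs of the noise projection

clean = selection_counts(logits, noise_logits, k=2, trials=1000)
noisy = selection_counts(logits, noise_logits, k=2, trials=1000, rng=random.Random(0))
print("without noise:", clean)
print("with noise:   ", noisy)
```

Without noise, experts 2 and 3 are never selected and so would never receive a gradient. With noise they are picked some of the time, which is the exploration effect described above.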
The Auxiliary Load Balancing Loss in Detail
The Switch Transformer load balance loss is the standard approach today, but it's worth understanding exactly why it works and its limitations.
Why the gradient doesn't flow through $f_i$: the fraction of tokens dispatched to expert $i$, $f_i$, is computed by counting how many tokens selected expert $i$ in their top-k. This counting operation is non-differentiable: you can't backpropagate through an argmax or top-k selection. So $f_i$ has no gradient.
How the gradient flows through $p_i$: the mean routing probability $p_i$ is computed from the softmax outputs, which are differentiable. So gradients can flow back through $p_i$ to the router weights.
The mechanism: the loss is small when $f_i \cdot p_i$ is small for all $i$. Since $f_i$ has no gradient, the optimizer reduces $p_i$ for experts that receive many tokens (high $f_i$). This lowers those experts' routing probabilities, ultimately routing fewer tokens to them in future batches. The feedback loop that caused collapse is counteracted.
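The arithmetic behind this mechanism can be checked by hand. The sketch below evaluates the unscaled term $N \sum_i f_i p_i$ for a balanced and a collapsed routing pattern, and the partial derivative with respect to each $p_i$ when $f_i$ is treated as a constant. It ignores the coupling that the softmax introduces between the $p_i$, and the collapsed distribution is a made-up example.

```python
def balance_term(f, p):
    """N * sum_i f_i * p_i, the unscaled load balancing term."""
    assert len(f) == len(p)
    return len(f) * sum(fi * pi for fi, pi in zip(f, p))

def grad_wrt_p(f):
    """Partial derivative of N * sum_i f_i * p_i w.r.t. each p_i, f held constant."""
    return [len(f) * fi for fi in f]

n = 8
uniform = [1.0 / n] * n
collapsed = [0.79] + [0.03] * 7  # hypothetical: one expert takes 79% of tokens

balanced_value = balance_term(uniform, uniform)
collapsed_value = balance_term(collapsed, collapsed)
grads = grad_wrt_p(collapsed)

print(balanced_value, collapsed_value)
print(grads)
```

The overloaded expert has by far the largest derivative, so a gradient step pushes its mean routing probability down hardest, which is the corrective pressure the loss is designed to apply.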
The load balancing loss is often paired with a second regularizer on the router, the z-loss:

```python
def z_loss(
    router_logits: torch.Tensor,
    beta: float = 0.001,
) -> torch.Tensor:
    """
    Z-loss from ST-MoE (Zoph et al., 2022).

    Additional regularization that prevents router logits from becoming
    too large in magnitude, which causes numerical instability in large
    MoE models.

    L_z = beta * (1/T) * sum_t (log sum_n exp(logit_{t,n}))^2

    This penalizes large logit magnitudes, keeping the softmax distribution
    from becoming too sharp (near one-hot).
    """
    # Log-sum-exp of logits for each token
    log_z = torch.logsumexp(router_logits, dim=-1)  # [T]

    # Z-loss: penalize a large log-partition function
    return beta * (log_z ** 2).mean()
```
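A small numeric check makes the behavior of the z-loss concrete. This sketch reimplements the formula in plain Python on hypothetical logits and compares three cases: modest logits, the same logits scaled by 10, and the same logits shifted by a constant.

```python
import math

def logsumexp(logits):
    """Numerically stable log(sum(exp(logits)))."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))

def z_loss_value(batch_logits, beta=0.001):
    """beta * mean over tokens of (log-sum-exp of that token's logits)^2."""
    return beta * sum(logsumexp(row) ** 2 for row in batch_logits) / len(batch_logits)

small = [[0.5, -0.5, 0.0, 0.2], [0.1, 0.3, -0.2, 0.0]]  # hypothetical modest logits
large = [[10.0 * v for v in row] for row in small]       # same ranking, 10x the magnitude
shifted = [[v + 20.0 for v in row] for row in small]     # same softmax, offset by +20

z_small = z_loss_value(small)
z_large = z_loss_value(large)
z_shifted = z_loss_value(shifted)
print(z_small, z_large, z_shifted)
```

The shifted case is worth noting: adding a constant to every logit leaves the softmax unchanged, yet the z-loss grows sharply. The penalty therefore anchors the overall scale of the logits near zero rather than only limiting how peaked the distribution is.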
Expert Choice Routing - An Alternative Paradigm
Standard routing is token choice: each token selects its top-$k$ experts. An alternative is expert choice routing: each expert selects its top-$m$ tokens.
Zhou et al. (2022) "Mixture-of-Experts with Expert Choice Routing" proposed this as a solution to load imbalance. With expert choice, each expert selects exactly $m$ tokens per batch (the tokens for which the expert is most relevant). This guarantees perfect load balance by construction.
Each expert selects the top-$m$ tokens from the batch, where $m = \frac{k \cdot T}{N}$ ensures each expert handles an equal share of tokens. Here $T$ is the number of tokens in the batch, $N$ is the number of experts, and $k$ is the average number of experts per token.
Advantages:
- Perfect load balance by construction - no auxiliary loss needed
- Tokens can receive a variable number of experts (k on average), so more compute goes to tokens that several experts rank highly
- Simplifies distributed training (experts have predictable workloads)
Disadvantages:
- Some tokens may be processed by 0 experts, others by 3+ (unequal coverage)
- Cannot be used directly for autoregressive generation at inference time: you don't have the full batch of tokens when generating token-by-token
- Batch-level coupling: a token's routing depends on which other tokens share its batch, which makes inference behavior batch-dependent
```python
class ExpertChoiceRouter(nn.Module):
    """
    Expert Choice Router (Zhou et al., 2022).

    Each expert selects the top-m tokens from the batch,
    guaranteeing perfect load balance.

    NOTE: This is for training only. Inference requires a different
    approach (e.g., token choice with smaller capacity factor).
    """

    def __init__(self, d_model: int, n_experts: int, expert_capacity: int):
        super().__init__()
        self.n_experts = n_experts
        self.expert_capacity = expert_capacity  # m tokens per expert
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [T, d_model] - T tokens in the batch

        Returns:
            dispatch_mask: [T, n_experts] sparse assignment matrix
            combine_weights: [T, n_experts] weighted combination matrix
        """
        T = x.shape[0]
        # torch.topk raises if asked for more entries than exist
        m = min(self.expert_capacity, T)

        # Compute scores: how relevant is each token for each expert?
        scores = F.softmax(self.router(x), dim=-1)  # [T, n_experts]

        # Each expert selects its top-m tokens, so put tokens on the last dim
        scores_T = scores.T  # [n_experts, T]

        # Top-m selection per expert
        top_m_scores, top_m_indices = torch.topk(scores_T, m, dim=-1)  # Both [n_experts, m]

        # Build dispatch mask [T, n_experts]: 1 if expert e processes token t
        dispatch_mask = torch.zeros(T, self.n_experts, device=x.device)
        for expert_idx in range(self.n_experts):
            dispatch_mask[top_m_indices[expert_idx], expert_idx] = 1.0

        # Normalize combine weights over the experts that selected each token.
        # Tokens selected by no expert keep an all-zero row (clamp avoids 0/0).
        combine_weights = scores * dispatch_mask  # [T, n_experts]
        combine_sum = combine_weights.sum(dim=-1, keepdim=True).clamp(min=1e-8)
        combine_weights = combine_weights / combine_sum

        return dispatch_mask, combine_weights
```
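The trade-off between balanced load and unequal token coverage shows up even in a tiny example. The sketch below applies expert choice selection to a hand-written score matrix (6 tokens, 3 experts, so $m = 2$ with one expert per token on average). The scores are hypothetical and the helper name `expert_choice` is made up for this illustration.

```python
def expert_choice(scores, m):
    """Each expert picks its m highest-scoring tokens.

    scores is a [T][N] list of lists. Returns (chosen, coverage) where
    chosen[e] lists the tokens picked by expert e and coverage[t] counts
    how many experts picked token t.
    """
    n_tokens, n_experts = len(scores), len(scores[0])
    chosen = {}
    for e in range(n_experts):
        ranked = sorted(range(n_tokens), key=lambda t: scores[t][e], reverse=True)
        chosen[e] = ranked[:m]
    coverage = [0] * n_tokens
    for tokens in chosen.values():
        for t in tokens:
            coverage[t] += 1
    return chosen, coverage

# Hypothetical router scores: rows are tokens, columns are experts
scores = [
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.5, 0.4, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.7, 0.2],
    [0.3, 0.3, 0.4],
]
k = 1                                        # average experts per token
m = k * len(scores) // len(scores[0])        # m = k * T / N = 2

chosen, coverage = expert_choice(scores, m)
print(chosen)    # tokens picked by each expert
print(coverage)  # experts per token
```

Every expert processes exactly two tokens, but token 2 is picked by no expert while token 4 is picked by two. That is the unequal coverage listed among the disadvantages.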
Token Choice vs. Expert Choice - The Trade-Off
| Aspect | Token Choice | Expert Choice |
|---|---|---|
| Load balance | Requires auxiliary loss | Perfect by construction |
| Token coverage | Every token is routed to exactly k experts | Some tokens may get 0 experts, others more than k |
| Inference compatibility | Works naturally (greedy per-token) | Requires batch-level aggregation |
| Implementation complexity | Simpler | More complex, batch-dependent |
| Typical use | Production models (Mixtral, DeepSeek) | Research, training-time only |
| Drop tokens? | Only if capacity exceeded | No capacity overflow (each expert processes exactly m tokens), but tokens chosen by no expert skip the layer |
For production models, token choice with auxiliary load balancing is the standard approach because it's compatible with autoregressive inference. Expert choice is primarily a training technique that requires modifications for deployment.
Switch Transformer - Top-1 Routing for Simplicity
Switch Transformer (Fedus et al., 2022) simplified MoE routing by using top-1 routing: each token is sent to exactly one expert. This eliminates the complexity of combining multiple expert outputs and reduces communication overhead in distributed training.
The key insight: Fedus et al. report that top-1 routing performs comparably to top-2 or higher on most tasks, while being substantially simpler to implement and cheaper to run.
Top-1 routing also removes one failure mode: with only one expert active per token, there's no possibility of two experts "canceling out" if they have conflicting representations.
```python
class SwitchRouter(nn.Module):
    """
    Switch Transformer router: top-1 routing with auxiliary load balance loss.
    (Fedus et al., 2022)

    Simplified MoE with exactly 1 expert per token.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        capacity_factor: float = 1.25,
        aux_loss_alpha: float = 0.01,
    ):
        super().__init__()
        self.n_experts = n_experts
        self.capacity_factor = capacity_factor
        self.aux_loss_alpha = aux_loss_alpha
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: [T, d_model]

        Returns:
            routing_weights: [T, n_experts] (sparse: at most one non-zero per row)
            top_1_indices: [T] - which expert handles each token
            aux_loss: Load balance auxiliary loss
            n_dropped: Number of tokens dropped due to capacity overflow
        """
        T = x.shape[0]

        # Compute router scores
        logits = self.router(x)            # [T, n_experts]
        probs = F.softmax(logits, dim=-1)  # [T, n_experts]

        # Top-1 selection: each token goes to its highest-scoring expert
        top_1_scores, top_1_indices = probs.max(dim=-1)  # Both [T]

        # Expert capacity: max tokens per expert
        # capacity_factor > 1 gives buffer for load imbalance
        capacity = int(self.capacity_factor * T / self.n_experts)

        # 1-based position of each token in its expert's queue, in token order
        one_hot = F.one_hot(top_1_indices, self.n_experts)        # [T, n_experts], int64
        position = (one_hot.cumsum(dim=0) * one_hot).sum(dim=-1)  # [T]

        # Tokens that arrive after the expert is full are dropped:
        # their row of routing_weights stays all-zero
        keep = position <= capacity  # [T], bool
        n_dropped = int((~keep).sum().item())
        routing_weights = one_hot * (top_1_scores * keep).unsqueeze(-1)  # [T, n_experts]

        # Auxiliary load balance loss from token fractions (counted before
        # dropping) and mean routing probabilities
        token_fraction = one_hot.float().mean(dim=0)  # [n_experts]
        mean_probs = probs.mean(dim=0)                # [n_experts]
        aux_loss = self.aux_loss_alpha * self.n_experts * (token_fraction * mean_probs).sum()

        return routing_weights, top_1_indices, aux_loss, n_dropped
```
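The capacity rule is easy to verify by hand. This sketch reproduces the capacity formula and the first-come, first-served dropping rule in plain Python for a made-up assignment of 16 tokens to 4 experts. The helper names `expert_capacity` and `count_dropped` are hypothetical.

```python
def expert_capacity(n_tokens, n_experts, capacity_factor):
    """Maximum tokens one expert may accept: floor(capacity_factor * T / N)."""
    return int(capacity_factor * n_tokens / n_experts)

def count_dropped(assignments, n_experts, capacity):
    """Process tokens in order; a token is dropped if its expert is already full."""
    load = [0] * n_experts
    dropped = 0
    for e in assignments:
        if load[e] < capacity:
            load[e] += 1
        else:
            dropped += 1
    return load, dropped

# Hypothetical top-1 assignments for T = 16 tokens over N = 4 experts
# (expert 0 is chosen by 8 tokens, expert 1 by 4, expert 2 by 3, expert 3 by 1)
assignments = [0, 0, 1, 0, 2, 0, 0, 1, 0, 3, 0, 1, 2, 0, 1, 2]

cap_default = expert_capacity(16, 4, 1.25)  # 5 slots per expert
load_default, dropped_default = count_dropped(assignments, 4, cap_default)

cap_large = expert_capacity(16, 4, 2.0)     # 8 slots per expert
load_large, dropped_large = count_dropped(assignments, 4, cap_large)

print(cap_default, load_default, dropped_default)
print(cap_large, load_large, dropped_large)
```

With a capacity factor of 1.25 the popular expert fills its 5 slots and the remaining 3 tokens routed to it are dropped. Raising the factor to 2.0 removes the drops, at the cost of reserving 8 slots on every expert, including the one that only uses a single slot.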
The Importance Factor - Additional Load Balancing
Shazeer et al. (2017) introduced an additional concept: the importance of an expert, defined as the sum of routing scores assigned to it:
$$\text{Importance}_i = \sum_{t=1}^{T} s_{t,i}$$

where $s_{t,i}$ is the routing score (pre-top-k normalization) assigned to expert $i$ for token $t$.
A balanced model should have all experts with similar importance. The importance loss penalizes high variance in importance across experts:

$$L_{importance} = w_{importance} \cdot \text{CV}(\text{Importance})^2$$

where CV is the coefficient of variation (std/mean) and $w_{importance}$ is a hand-tuned scaling factor. This encourages all experts to receive similar total routing probability, complementing the per-batch load balancing loss.
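As a worked example of the importance loss, the sketch below computes per-expert importance and the squared coefficient of variation for a balanced and a skewed batch. The score matrices and the weight `w_importance=0.01` are hypothetical, and the use of the population variance (dividing by N rather than N-1) is an assumption of this sketch.

```python
def importance(scores):
    """Per-expert sum of routing scores over a batch; scores is [T][N]."""
    n_experts = len(scores[0])
    return [sum(row[i] for row in scores) for i in range(n_experts)]

def cv_squared(values):
    """Squared coefficient of variation: variance / mean^2."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # population variance (assumption)
    return var / (mean ** 2)

def importance_loss(scores, w_importance=0.01):
    return w_importance * cv_squared(importance(scores))

balanced = [[0.25, 0.25, 0.25, 0.25]] * 4  # every token spreads its score evenly
skewed = [[0.85, 0.05, 0.05, 0.05]] * 4    # every token strongly prefers expert 0

loss_balanced = importance_loss(balanced)
loss_skewed = importance_loss(skewed)
print(importance(balanced), loss_balanced)
print(importance(skewed), loss_skewed)
```

In the skewed batch the importances are roughly 3.4, 0.2, 0.2, 0.2, giving a squared CV of about 1.92, while the balanced batch gives zero. The loss therefore only becomes active once experts start to differ in total routing score.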
Production Engineering Notes
Monitoring Router Behavior
In production, monitoring the router is essential for diagnosing problems:
```python
class RouterMonitor:
    """
    Monitor router behavior during training and inference.
    Track expert utilization, load balance, and routing patterns.
    """

    def __init__(self, n_experts: int, n_layers: int):
        self.n_experts = n_experts
        self.n_layers = n_layers
        self.reset_stats()

    def reset_stats(self):
        self.expert_counts = {
            layer: [0] * self.n_experts
            for layer in range(self.n_layers)
        }
        self.total_tokens = 0

    def record_routing(
        self,
        layer_idx: int,
        top_k_indices: torch.Tensor,
    ):
        """Record which experts handled which tokens."""
        self.total_tokens += top_k_indices.shape[0]
        for k_idx in range(top_k_indices.shape[1]):
            for expert_idx in top_k_indices[:, k_idx].tolist():
                self.expert_counts[layer_idx][expert_idx] += 1

    def load_balance_stats(self) -> dict:
        """Compute load balance statistics for every layer with recorded data."""
        stats = {}
        for layer_idx, counts in self.expert_counts.items():
            total = sum(counts)
            if total == 0:
                continue

            fractions = [c / total for c in counts]
            ideal_fraction = 1.0 / self.n_experts

            # CV of expert utilization
            mean = sum(fractions) / len(fractions)
            variance = sum((f - mean) ** 2 for f in fractions) / len(fractions)
            cv = (variance ** 0.5) / mean if mean > 0 else 0

            stats[f"layer_{layer_idx}"] = {
                "expert_fractions": fractions,
                "max_fraction": max(fractions),
                "min_fraction": min(fractions),
                "cv": cv,  # coefficient of variation - lower is more balanced
                "load_balance_ok": max(fractions) < 3.0 * ideal_fraction,
            }

        return stats
```
Auxiliary Loss Weight Tuning
The auxiliary loss coefficient $\alpha$ requires careful tuning:
- Too low ($\alpha$ less than 0.001): expert collapse is likely during training
- Too high ($\alpha$ greater than 0.1): the auxiliary loss dominates, hurting primary task performance
- Recommended starting point: $\alpha = 0.01$ (the value used by Switch Transformer)
Monitor both the auxiliary loss value and the actual expert utilization variance. The goal is not zero variance (some specialization is beneficial) but preventing extreme collapse.
:::danger Common Mistake: Forgetting the Auxiliary Loss
The single most common mistake in implementing MoE models is training without the auxiliary load balancing loss or with $\alpha$ set too small. Without it, expert collapse is very likely within the first few thousand training steps. Always include the auxiliary loss and monitor expert utilization during early training.
:::
:::warning Token Dropping in Production
Switch Transformer and similar architectures can "drop" tokens when an expert's capacity buffer overflows. Dropped tokens are replaced with zeros, effectively causing those tokens to skip the MoE layer. This is acceptable during training (dropout-like regularization) but can cause quality degradation at inference. For production serving, either raise the capacity factor so that overflow becomes rare (at the cost of extra reserved capacity) or switch to an algorithm that never drops tokens.
:::
:::tip Diagnosing Router Issues
If a MoE model's performance is poor despite reasonable perplexity, check router behavior first. Signs of router problems:
1. An expert's token fraction exceeding 5x the uniform share $1/N$.
2. One or two experts consistently receiving greater than 50% of tokens.
3. Some experts receiving fewer than 1% of tokens.
4. Router confidence increasing monotonically (one-hot routing): the model is degenerating to a single expert.

All of these suggest the auxiliary loss is too small or the router is otherwise misconfigured.
:::
Interview Questions and Answers
Q1: Explain the expert collapse problem and how load balancing loss prevents it.
Expert collapse occurs when the routing process creates a feedback loop: experts that receive more tokens update more and become better, causing the router to assign them even more tokens. This degenerates the MoE into a model with only a few active experts, wasting the remaining capacity. The auxiliary load balancing loss prevents this by penalizing imbalanced routing. The loss is $L_{aux} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot p_i$, where $f_i$ is the fraction of tokens sent to expert $i$ (non-differentiable) and $p_i$ is the mean routing probability assigned to expert $i$ (differentiable). When expert $i$ receives too many tokens ($f_i$ is high), the gradient reduces $p_i$, lowering the expert's routing scores and eventually reducing the fraction of tokens it receives.
Q2: What is the difference between token choice and expert choice routing?
Token choice: each token selects its top-k experts based on its own routing scores. Standard in production models (Mixtral, DeepSeek). Simple, works for autoregressive inference, but requires auxiliary load balancing loss to prevent collapse. Expert choice: each expert selects the top-m tokens from the batch to process. Guarantees perfect load balance by construction, requires no auxiliary loss. But some tokens may receive 0 or more than k experts' processing, and it cannot be used directly for autoregressive inference (requires the full batch, not available during token-by-token generation). Expert choice is primarily useful as a training technique.
Q3: Why did Switch Transformer use top-1 routing instead of top-2?
Switch Transformer used top-1 to simplify the architecture and reduce communication overhead in distributed training. With top-1, each token is sent to exactly one expert on one device, minimizing all-to-all communication. With top-2 or higher, each token may need to be sent to experts on different devices, increasing communication cost. Empirically, Switch Transformer found that top-1 performed comparably to top-2 on most tasks while being more computationally efficient. The simplification also eliminated the need to normalize and combine multiple expert outputs, making both forward and backward passes simpler. The main downside is reduced robustness: a wrong routing decision means no correction from a second expert.
Q4: What is the z-loss and why is it needed in addition to the load balance loss?
The z-loss penalizes the magnitude of router logits: $L_z = \beta \cdot \frac{1}{T} \sum_{t=1}^{T} \left( \log \sum_{n=1}^{N} e^{x_{t,n}} \right)^2$, where $x_{t,n}$ is the router logit of token $t$ for expert $n$. It's needed because, without it, router logits can grow to very large values in magnitude, causing the softmax distribution to become essentially one-hot (all probability concentrated on one expert). When logits are very large, gradients vanish in the softmax and numerical round-off in the exponentials grows, making training unstable. The z-loss keeps logit magnitudes reasonable, maintaining softmax entropy and training stability. Zoph et al. (2022, ST-MoE) introduced z-loss after observing instability in large MoE models without it.
Q5: How would you diagnose a poorly routing MoE model in production?
Monitoring approach: (1) Track per-expert token fraction per layer - ideal is $1/N$ for N experts. More than $3/N$ for any single expert (the `load_balance_ok` threshold used in `RouterMonitor`) suggests beginning collapse. (2) Compute the coefficient of variation (CV) of expert utilization - CV greater than 0.5 is a warning sign. (3) Monitor router entropy per layer: $H = -\sum_{i=1}^{N} p_i \log p_i$, where $p_i$ is the routing probability of expert $i$. Decreasing entropy over training means the router is becoming more one-hot, potentially over-specializing. (4) Check token drop rate if using capacity-bounded routing - high drop rates indicate severe load imbalance. (5) Compare quality metrics stratified by input domain: if the model performs well on common domains and poorly on rare ones, rare domains may be routed to undertrained experts. Remediation: increase auxiliary loss alpha, reduce learning rate to slow the feedback loop, or switch to expert choice routing for training.
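The entropy check in point (3) is simple to compute. The following sketch evaluates routing entropy for a uniform and a near-one-hot distribution over 8 experts. Both distributions are hypothetical, and `routing_entropy` is a helper name invented for this example.

```python
import math

def routing_entropy(probs):
    """Shannon entropy (in nats) of one routing distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def mean_entropy(batch_probs):
    """Average routing entropy over a batch of per-token distributions."""
    return sum(routing_entropy(p) for p in batch_probs) / len(batch_probs)

n = 8
uniform = [1.0 / n] * n
peaked = [0.93] + [0.01] * 7  # hypothetical near-one-hot router output

h_uniform = routing_entropy(uniform)
h_peaked = routing_entropy(peaked)
print(h_uniform, math.log(n))  # uniform routing reaches the maximum, log(N)
print(h_peaked)
```

The uniform case attains the maximum $\log N$, so logging entropy as a fraction of $\log N$ gives a scale-free number that can be compared across layers with different expert counts.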
