Router Mechanisms - How Tokens Get Assigned to Experts

The Traffic Controller Problem

Imagine running an airport. Thousands of flights arrive every day. Each flight needs to be assigned to a gate. If you assign all flights to gate 1, gate 1 is overwhelmed and every other gate sits empty. If you assign randomly, you get chaos with no optimization. What you want is intelligent routing: planes going to terminal B get assigned to gates in terminal B, and the load is roughly balanced across all gates.

The routing problem in MoE models is structurally identical. Thousands of tokens arrive per second. Each token needs to be assigned to one or more of $N$ experts. If your router sends all tokens to expert 1, that expert handles everything (and is effectively a single dense model - you've wasted the other $N-1$ experts). If you route randomly, you lose all the benefits of expert specialization.

The router is the most critical component in a MoE architecture. A good router enables specialization and load balancing simultaneously. A bad router causes expert collapse, training instability, and wasted capacity.

This lesson covers the algorithms researchers have developed to solve the routing problem.


Why This Exists - The Difficulties of Routing

At first glance, routing seems straightforward: train a linear layer to assign tokens to experts, and the model will learn good assignments. In practice, several problems arise:

Expert collapse: a small subset of experts receives the vast majority of tokens. The other experts receive so few tokens that they never learn meaningful representations. The model degenerates toward a small dense model.

Discrete selection bottleneck: the top-k selection operation is discrete (non-differentiable). Gradients can only flow through the routing weights (the softmax scores), not through the routing decisions (which experts were selected). This creates a challenging optimization landscape.

Load imbalance amplification: early in training, when router weights are random, some experts happen to receive slightly more tokens. Those experts update more, becoming slightly better. Better experts receive higher router scores, receiving even more tokens. The feedback loop amplifies initial random noise into severe imbalance.

Communication bottleneck: in distributed training with expert parallelism, routing tokens to different devices creates communication overhead. Imbalanced routing means some devices are overloaded while others are idle.

These problems motivated the development of the routing mechanisms described in this lesson.


The Linear Router - Standard Token Choice

The most common routing approach: a single linear transformation that scores all $N$ experts, followed by top-k selection.

The Basic Formulation

For token representation $x \in \mathbb{R}^d$:

$$
\text{scores}(x) = \text{Softmax}(x W_r + b_r)
$$

where $W_r \in \mathbb{R}^{d \times N}$ and $b_r \in \mathbb{R}^N$ are the router weight matrix and bias.

The top-k experts are selected, and their scores $s = \text{scores}(x)$ are renormalized:

$$
\hat{s}_i = \begin{cases} \frac{s_i}{\sum_{j \in \text{TopK}(s, k)} s_j} & i \in \text{TopK}(s, k) \\ 0 & \text{otherwise} \end{cases}
$$

The MoE output is then:

$$
y = \sum_{i=1}^{N} \hat{s}_i \cdot E_i(x)
$$
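
The three steps above (softmax, top-k renormalization, weighted sum) can be traced by hand on a tiny example. The following is a minimal pure-Python sketch, not a production router: the scalar "experts" and the logits are made-up values chosen only for illustration, and the helper names `top_k_route` and `moe_output` are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_route(logits, k):
    """Return {expert_index: renormalized weight} for the k highest scores."""
    scores = softmax(logits)
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    norm = sum(scores[i] for i in top)
    return {i: scores[i] / norm for i in top}


def moe_output(x, logits, experts, k):
    """Weighted sum of the selected experts' outputs (scalar toy experts)."""
    weights = top_k_route(logits, k)
    return sum(w * experts[i](x) for i, w in weights.items())


# Hypothetical toy setup: 4 scalar "experts" and hand-picked router logits.
experts = [lambda x: x, lambda x: 2 * x, lambda x: -x, lambda x: x * x]
logits = [2.0, 1.0, 0.1, -1.0]

weights = top_k_route(logits, k=2)
y = moe_output(3.0, logits, experts, k=2)
print(weights)  # only experts 0 and 1 carry weight
print(y)
```

Only the two selected experts are evaluated; the other two contribute nothing to `y`, which is where the compute savings of sparse routing come from.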

The Load Balancing Problem

Without additional mechanisms, linear routers catastrophically collapse. To see why: consider the beginning of training, when all router weights are random. By random chance, expert 3 gets slightly more tokens than expert 4. Expert 3 updates more per batch, becoming slightly better at its tasks. The router learns to assign slightly higher scores to expert 3. This creates a feedback loop:

More tokens → more updates → higher scores → even more tokens → complete collapse

After convergence, you might have expert 3 handling 80% of all tokens and experts 1, 2, 4–8 handling 2–3% each. This is expert collapse.
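
The feedback loop can be illustrated with a deliberately crude toy model. The sketch below is an assumption-laden illustration, not a simulation of real training: it simply lets each expert's router logit grow in proportion to the share of tokens it currently receives, and the function name `simulate_feedback` and all constants are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def simulate_feedback(n_experts=8, lucky_expert=3, head_start=0.01,
                      gain=1.0, steps=200):
    """Toy rich-get-richer loop (an assumption, not real training dynamics).

    Each step, an expert's router logit grows in proportion to the share of
    tokens it currently receives, standing in for
    "more tokens -> more updates -> higher scores".
    """
    logits = [0.0] * n_experts
    logits[lucky_expert] = head_start  # tiny random-looking advantage
    history = []
    for _ in range(steps):
        shares = softmax(logits)  # soft stand-in for token fractions
        history.append(shares[lucky_expert])
        logits = [z + gain * s for z, s in zip(logits, shares)]
    return history, softmax(logits)


history, final_shares = simulate_feedback()
print(f"initial share of expert 3: {history[0]:.3f}")
print(f"final share of expert 3:   {final_shares[3]:.3f}")
```

Even with a head start of only 0.01 in logit space, the favored expert ends up with nearly all of the routing mass in this toy model, because nothing pushes back against the loop.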

import torch
import torch.nn as nn
import torch.nn.functional as F


def load_balance_loss_switch(
    router_logits: torch.Tensor,  # [T, N] - T tokens, N experts
    top_k_indices: torch.Tensor,  # [T, k] - selected expert indices
    n_experts: int,
    alpha: float = 0.01,
) -> torch.Tensor:
    """
    Switch Transformer load balancing loss (Fedus et al., 2022).

    Minimizes the inner product of:
    - f_i: fraction of tokens dispatched to expert i (actual load)
    - p_i: average softmax probability assigned to expert i (soft load)

    This loss is minimized when all experts receive equal fractions of tokens.
    The gradient flows through p_i (the softmax probabilities), not f_i (the discrete counts).

    Args:
        router_logits: Raw router outputs before softmax
        top_k_indices: Expert indices selected for each token
        n_experts: Total number of experts
        alpha: Loss scale factor

    Returns:
        Scalar auxiliary load balancing loss
    """
    T = router_logits.shape[0]

    # Compute routing probabilities (differentiable)
    router_probs = F.softmax(router_logits, dim=-1)  # [T, N]
    p = router_probs.mean(dim=0)  # [N] - mean probability per expert

    # Compute token fraction per expert (non-differentiable counts)
    # One-hot encode expert assignments; scatter_ fills all k columns at once
    one_hot = torch.zeros(T, n_experts, device=router_logits.device)
    one_hot.scatter_(1, top_k_indices, 1.0)
    # Note: if top_k > 1, a token contributes to multiple experts, so f sums to k

    f = one_hot.mean(dim=0)  # [N] - fraction of tokens per expert

    # Load balance loss: alpha * n_experts * sum(f_i * p_i)
    # When perfectly balanced with top-1: f_i = p_i = 1/n_experts for all i
    # n_experts * sum((1/n)^2) = n_experts * n * (1/n)^2 = 1, so the loss equals alpha
    load_balance = alpha * n_experts * (f * p).sum()

    return load_balance
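
To get a feel for the scale of this loss, it helps to plug in numbers. The sketch below is a pure-Python illustration (no PyTorch) that evaluates the same formula for a balanced and a collapsed routing pattern. The collapsed split reuses the 80% figure from the example above, and setting `p` equal to `f` is a simplifying assumption.

```python
def aux_loss(f, p, alpha=0.01):
    """alpha * N * sum_i f_i * p_i for per-expert lists f and p."""
    n = len(f)
    return alpha * n * sum(fi * pi for fi, pi in zip(f, p))


n_experts = 8

# Perfectly balanced top-1 routing: every expert gets 1/N of tokens and probability.
balanced = [1.0 / n_experts] * n_experts

# Collapsed routing: expert 3 takes 80%, the other seven share the remaining 20%.
collapsed = [0.2 / 7] * n_experts
collapsed[3] = 0.8

loss_balanced = aux_loss(balanced, balanced)
loss_collapsed = aux_loss(collapsed, collapsed)
print(f"balanced:  {loss_balanced:.4f}")
print(f"collapsed: {loss_collapsed:.4f}")
```

The balanced case evaluates to `alpha` itself (0.01 here), while the collapsed case is roughly five times larger, which is the signal the optimizer uses to push probability away from the overloaded expert.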

Noisy Top-K Gating - Shazeer et al. (2017)

The original sparse MoE paper (Shazeer et al., 2017) introduced noisy top-k gating to encourage exploration during training and prevent early collapse.

The idea: add Gaussian noise to the router logits before top-k selection during training. This noise ensures that all experts have some probability of being selected even if their baseline scores are low, giving them a chance to learn and develop useful representations.

$$
\text{noisy\_logits}(x) = x W_r + \epsilon \cdot \text{Softplus}(x W_{\text{noise}})
$$

where $\epsilon \sim \mathcal{N}(0, 1)$ is drawn independently for each expert and $W_{\text{noise}}$ is a learned weight matrix that controls per-token noise magnitude.

The noise has magnitude proportional to a second learned transformation $W_{\text{noise}}$, allowing the model to learn appropriate noise levels for different tokens.

class NoisyTopKRouter(nn.Module):
    """
    Noisy Top-K router from Shazeer et al. (2017).

    Adds learned Gaussian noise to router logits during training,
    encouraging exploration and preventing expert collapse.
    """

    def __init__(self, d_model: int, n_experts: int, top_k: int = 2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # Standard router weights
        self.w_router = nn.Linear(d_model, n_experts, bias=False)

        # Noise magnitude weights (learned)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        training: bool = True,
    ) -> tuple:
        """
        Args:
            x: [batch * seq_len, d_model]
            training: Whether to add noise (only during training)

        Returns:
            routing_weights: [T, n_experts] normalized routing weights
            top_k_indices: [T, top_k] selected expert indices
        """
        # Compute base logits
        logits = self.w_router(x)  # [T, N]

        if training:
            # Add Gaussian noise scaled by learned magnitude
            noise_std = F.softplus(self.w_noise(x))  # [T, N], always positive
            noise = torch.randn_like(logits)
            logits = logits + noise * noise_std

        # Softmax over all experts
        scores = F.softmax(logits, dim=-1)  # [T, N]

        # Select top-k
        top_k_scores, top_k_indices = torch.topk(scores, self.top_k, dim=-1)

        # Normalize top-k scores to sum to 1
        top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8)

        # Build full routing weight matrix [T, N] with zeros for non-selected
        routing_weights = torch.zeros_like(scores)
        routing_weights.scatter_(1, top_k_indices, top_k_weights)

        return routing_weights, top_k_indices
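
To see why the noise matters, it can help to count how often a low-scoring expert makes it into the top-k with and without noise. The following pure-Python sketch is an illustration under assumed values: the logits, the `noise_pre` vector (standing in for $x W_{\text{noise}}$ for a single token), and the helper name `selection_counts` are all hypothetical.

```python
import math
import random


def softplus(v):
    """log(1 + exp(v)), always positive."""
    return math.log1p(math.exp(v))


def selection_counts(logits, noise_pre, k, trials, rng, noisy=True):
    """Count how often each expert lands in the top-k over repeated draws."""
    counts = [0] * len(logits)
    for _ in range(trials):
        if noisy:
            # Per-expert Gaussian noise scaled by softplus of the noise pre-activation
            scored = [z + rng.gauss(0.0, 1.0) * softplus(w)
                      for z, w in zip(logits, noise_pre)]
        else:
            scored = list(logits)
        top = sorted(range(len(scored)), key=lambda i: scored[i], reverse=True)[:k]
        for i in top:
            counts[i] += 1
    return counts


# Hypothetical values: expert 3 has the lowest clean logit.
logits = [1.0, 0.8, 0.5, 0.0]
noise_pre = [0.5, 0.5, 0.5, 0.5]
rng = random.Random(0)  # fixed seed so the run is repeatable

clean = selection_counts(logits, noise_pre, k=2, trials=2000, rng=rng, noisy=False)
noisy = selection_counts(logits, noise_pre, k=2, trials=2000, rng=rng, noisy=True)
print("without noise:", clean)
print("with noise:   ", noisy)
```

Without noise, experts 2 and 3 are never selected for this token; with noise, every expert receives some traffic and therefore some gradient signal.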

The Auxiliary Load Balancing Loss in Detail

The Switch Transformer load balance loss is the standard approach today, but it's worth understanding exactly why it works and its limitations.

Why the gradient doesn't flow through $f_i$: the fraction of tokens dispatched to expert $i$, $f_i$, is computed by counting how many tokens selected expert $i$ in their top-k. This counting operation is non-differentiable - you can't backpropagate through an argmax or topk. So $f_i$ has no gradient.

How the gradient flows through $p_i$: the mean routing probability $p_i = \frac{1}{T}\sum_t s_{t,i}$ is computed from the softmax outputs, which are differentiable. So gradients can flow back through $p_i$ to the router weights.

The mechanism: the loss $L_{\text{aux}} = \alpha \cdot N \cdot \sum_i f_i \cdot p_i$ is minimized when $f_i \cdot p_i$ is small for all $i$. Since $f_i$ has no gradient, the optimizer reduces $p_i$ for experts that receive many tokens (high $f_i$). This lowers those experts' routing probabilities, ultimately routing fewer tokens to them in future batches. The feedback loop that caused collapse is reversed.
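
The mechanism can be made explicit with a short derivation. Treat each $f_i$ as a constant and write $z_{t,j}$ for the router logit of token $t$ and expert $j$ (a symbol introduced here only for this sketch), so that $s_{t,i}$ is the softmax of $z_t$. Using the softmax derivative $\partial s_{t,i} / \partial z_{t,j} = s_{t,i}(\delta_{ij} - s_{t,j})$:

$$
\frac{\partial L_{\text{aux}}}{\partial p_i} = \alpha N f_i,
\qquad
\frac{\partial L_{\text{aux}}}{\partial z_{t,j}}
= \frac{\alpha N}{T} \sum_i f_i \, s_{t,i} \, (\delta_{ij} - s_{t,j})
= \frac{\alpha N}{T} \, s_{t,j} \Big( f_j - \sum_i f_i \, s_{t,i} \Big)
$$

A gradient-descent step therefore lowers the logit of expert $j$ for token $t$ when $f_j$ exceeds that token's score-weighted average load $\sum_i f_i s_{t,i}$, and raises it when $f_j$ is below that average.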

def z_loss(
    router_logits: torch.Tensor,
    beta: float = 0.001,
) -> torch.Tensor:
    """
    Z-loss from ST-MoE (Zoph et al., 2022).

    Additional regularization that prevents router logits from becoming
    too large in magnitude, which causes numerical instability in large
    MoE models.

    L_z = beta * (1/T) * sum_t (log sum_n exp(logit_{t,n}))^2

    This penalizes large logit magnitudes, keeping the softmax distribution
    from becoming too sharp (near one-hot).
    """
    # Log-sum-exp of logits for each token
    log_z = torch.logsumexp(router_logits, dim=-1)  # [T]

    # Z-loss: penalize large log-partition function
    # (named `loss` so it does not shadow the function name)
    loss = beta * (log_z ** 2).mean()

    return loss
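
A quick numeric check shows how strongly the z-loss reacts to logit scale. This is a pure-Python sketch of the same formula on nested lists; the logit values and the helper name `z_loss_value` are made up for illustration.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def z_loss_value(router_logits, beta=0.001):
    """beta * mean over tokens of logsumexp(logits)^2, on nested lists."""
    per_token = [logsumexp(row) ** 2 for row in router_logits]
    return beta * sum(per_token) / len(per_token)


# Hypothetical logits for 2 tokens and 4 experts, at two different scales.
small = [[0.5, -0.2, 0.1, 0.0], [0.3, 0.3, -0.4, 0.1]]
large = [[20.0 * v for v in row] for row in small]

print(f"small logits: {z_loss_value(small):.5f}")
print(f"large logits: {z_loss_value(large):.5f}")
```

Scaling the logits up makes the penalty much larger, so the optimizer is nudged to keep router logits in a modest range.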

Expert Choice Routing - An Alternative Paradigm

Standard routing is token choice: each token selects its top-k experts. An alternative is expert choice routing: each expert selects its top-k tokens.

Zhou et al. (2022) "Mixture-of-Experts with Expert Choice Routing" proposed this as a solution to load imbalance. With expert choice, each expert selects exactly $m$ tokens per batch (the $m$ tokens for which the expert is most relevant). This guarantees perfect load balance by construction.

$$
\text{ExpertChoice}_{e,b} = \text{TopK}_{t \in \text{batch}} \left( s(x_t, e), m \right)
$$

Each expert selects the top-$m$ tokens from the batch, where $m = \frac{T \cdot k}{N}$ ensures each expert handles an equal share of tokens.

Advantages:

  • Perfect load balance by construction - no auxiliary loss needed
  • Each token can be processed by a variable number of experts (k on average), so more compute can go to tokens that many experts score highly
  • Simplifies distributed training (experts have predictable workloads)

Disadvantages:

  • Some tokens may be processed by 0 experts, others by more than k (unequal coverage)
  • Cannot be used directly for autoregressive generation at inference time: you don't have the full batch of tokens when generating token-by-token
  • Batch-level coupling makes inference harder
class ExpertChoiceRouter(nn.Module):
    """
    Expert Choice Router (Zhou et al., 2022).

    Each expert selects the top-m tokens from the batch,
    guaranteeing perfect load balance.

    NOTE: This is for training only. Inference requires a different
    approach (e.g., token choice with smaller capacity factor).
    """

    def __init__(self, d_model: int, n_experts: int, expert_capacity: int):
        super().__init__()
        self.n_experts = n_experts
        self.expert_capacity = expert_capacity  # m tokens per expert
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: [T, d_model] - T tokens in the batch

        Returns:
            dispatch_mask: [T, n_experts] sparse assignment matrix
            combine_weights: [T, n_experts] weighted combination matrix
        """
        T = x.shape[0]

        # Guard: an expert cannot select more tokens than the batch contains
        # (torch.topk raises an error if k exceeds the dimension size)
        m = min(self.expert_capacity, T)

        # Compute scores: how relevant is each token for each expert?
        scores = F.softmax(self.router(x), dim=-1)  # [T, n_experts]

        # Each expert selects top-m tokens (token dimension = dim 0)
        # Transpose: [n_experts, T]
        scores_T = scores.T  # [n_experts, T]

        # Top-m selection per expert
        top_m_scores, top_m_indices = torch.topk(
            scores_T,
            m,
            dim=-1,
        )  # Both [n_experts, m]

        # Build dispatch mask [T, n_experts]: 1 if expert e processes token t
        dispatch_mask = torch.zeros(T, self.n_experts, device=x.device)
        for expert_idx in range(self.n_experts):
            dispatch_mask[top_m_indices[expert_idx], expert_idx] = 1.0

        # Normalize combine weights (subset of scores for selected tokens)
        # Tokens selected by no expert keep an all-zero row
        combine_weights = scores * dispatch_mask  # [T, n_experts]
        combine_sum = combine_weights.sum(dim=-1, keepdim=True).clamp(min=1e-8)
        combine_weights = combine_weights / combine_sum

        return dispatch_mask, combine_weights
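
The trade-off between balanced experts and uneven token coverage is easiest to see on a tiny example. The sketch below is a pure-Python illustration of expert choice selection; the score matrix is hand-picked and the helper name `expert_choice` is hypothetical.

```python
def expert_choice(scores, m):
    """scores[t][e] is the router score of token t for expert e.

    Each expert picks its m highest-scoring tokens. Returns, per expert,
    the sorted list of selected token indices.
    """
    n_tokens = len(scores)
    n_experts = len(scores[0])
    picks = []
    for e in range(n_experts):
        ranked = sorted(range(n_tokens), key=lambda t: scores[t][e], reverse=True)
        picks.append(sorted(ranked[:m]))
    return picks


# Hypothetical scores for T = 4 tokens and N = 4 experts (each row sums to 1).
scores = [
    [0.40, 0.40, 0.10, 0.10],  # token 0: favored by experts 0 and 1
    [0.30, 0.30, 0.20, 0.20],  # token 1: nobody's first choice
    [0.10, 0.20, 0.40, 0.30],  # token 2: favored by experts 2 and 3
    [0.25, 0.25, 0.25, 0.25],  # token 3: nobody's first choice
]
T, N, k = 4, 4, 1
m = T * k // N  # capacity per expert: 1 token

picks = expert_choice(scores, m)
coverage = [sum(t in chosen for chosen in picks) for t in range(T)]
print("tokens per expert:", [len(p) for p in picks])
print("experts per token:", coverage)
```

Every expert processes exactly one token, yet tokens 0 and 2 are each handled by two experts while tokens 1 and 3 are handled by none. Load is balanced across experts, but coverage across tokens is not.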

Token Choice vs. Expert Choice - The Trade-Off

| Aspect | Token Choice | Expert Choice |
|---|---|---|
| Load balance | Requires auxiliary loss | Perfect by construction |
| Token coverage | Every token processed by exactly k experts | Some tokens may get 0 or more than k experts |
| Inference compatibility | Works naturally (greedy per-token) | Requires batch-level aggregation |
| Implementation complexity | Simpler | More complex, batch-dependent |
| Typical use | Production models (Mixtral, DeepSeek) | Research, training-time only |
| Drop tokens? | Only if capacity exceeded | No capacity overflow (each expert processes exactly m tokens), but a token selected by no expert skips the layer |

For production models, token choice with auxiliary load balancing is the standard approach because it's compatible with autoregressive inference. Expert choice is primarily a training technique that requires modifications for deployment.


Switch Transformer - Top-1 Routing for Simplicity

Switch Transformer (Fedus et al., 2022) simplified MoE routing by using top-1 routing: each token is sent to exactly one expert. This eliminates the complexity of combining multiple expert outputs and reduces communication overhead in distributed training.

The key insight: empirically, top-1 routing performs comparably to top-2 or higher on most tasks, while being substantially simpler to implement and cheaper to run.

Top-1 routing is also more stable: with only one expert active per token, there's no possibility of two experts "canceling out" if they have conflicting representations.

class SwitchRouter(nn.Module):
    """
    Switch Transformer router: top-1 routing with auxiliary load balance loss.
    (Fedus et al., 2022)

    Simplified MoE with exactly 1 expert per token.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        capacity_factor: float = 1.25,
        aux_loss_alpha: float = 0.01,
    ):
        super().__init__()
        self.n_experts = n_experts
        self.capacity_factor = capacity_factor
        self.aux_loss_alpha = aux_loss_alpha

        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: [T, d_model]

        Returns:
            routing_weights: [T, n_experts] (sparse: at most one non-zero per row)
            top_1_indices: [T] - which expert handles each token
            aux_loss: Load balance auxiliary loss
            n_dropped: Number of tokens dropped due to capacity overflow
        """
        T = x.shape[0]

        # Compute router scores
        logits = self.router(x)  # [T, n_experts]
        probs = F.softmax(logits, dim=-1)  # [T, n_experts]

        # Top-1 selection: each token goes to its highest-scoring expert
        top_1_scores, top_1_indices = probs.max(dim=-1)  # Both [T]

        # Expert capacity: max tokens per expert
        # capacity_factor > 1 gives buffer for load imbalance
        capacity = int(self.capacity_factor * T / self.n_experts)

        # Check for overflow: if an expert receives more tokens than capacity,
        # drop the excess tokens (their row in routing_weights stays all zeros)
        n_dropped = 0
        routing_weights = torch.zeros(T, self.n_experts, device=x.device)

        # Fill in routing weights, respecting capacity constraints
        # (per-token Python loop kept for readability, not speed)
        expert_counts = torch.zeros(self.n_experts, device=x.device)
        for t_idx in range(T):
            e_idx = top_1_indices[t_idx].item()
            if expert_counts[e_idx] < capacity:
                routing_weights[t_idx, e_idx] = top_1_scores[t_idx]
                expert_counts[e_idx] += 1
            else:
                n_dropped += 1  # Token is dropped

        # Auxiliary load balance loss
        # Using mean probabilities and token fractions (computed before dropping)
        token_fraction = torch.zeros(self.n_experts, device=x.device)
        token_fraction.scatter_add_(0, top_1_indices, torch.ones(T, device=x.device))
        token_fraction = token_fraction / T

        mean_probs = probs.mean(dim=0)  # [n_experts]

        aux_loss = self.aux_loss_alpha * self.n_experts * (token_fraction * mean_probs).sum()

        return routing_weights, top_1_indices, aux_loss, n_dropped
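
The capacity arithmetic is worth working through once by hand. The sketch below is a pure-Python illustration of the same first-come-first-served rule used in the loop above; the batch of assignments is made up and the helper name `apply_capacity` is hypothetical.

```python
def apply_capacity(assignments, n_experts, capacity_factor=1.25):
    """First-come-first-served capacity check for top-1 routing.

    assignments[t] is the expert chosen by token t. Returns the capacity,
    the per-expert kept counts, and the indices of dropped tokens.
    """
    capacity = int(capacity_factor * len(assignments) / n_experts)
    kept = [0] * n_experts
    dropped = []
    for t, e in enumerate(assignments):
        if kept[e] < capacity:
            kept[e] += 1
        else:
            dropped.append(t)
    return capacity, kept, dropped


# Hypothetical batch: 16 tokens, 4 experts, expert 0 is over-subscribed (9 tokens).
assignments = [0, 0, 1, 0, 2, 0, 0, 3, 0, 1, 0, 2, 0, 3, 1, 0]
capacity, kept, dropped = apply_capacity(assignments, n_experts=4)
print("capacity per expert:", capacity)  # int(1.25 * 16 / 4) = 5
print("kept per expert:    ", kept)
print("dropped token ids:  ", dropped)
```

With 16 tokens and 4 experts, a capacity factor of 1.25 gives each expert 5 slots instead of 4. Expert 0 still overflows, so its last four tokens are dropped, and which tokens are dropped depends purely on their position in the batch.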

The Importance Factor - Additional Load Balancing

Shazeer et al. (2017) introduced an additional concept: the importance of an expert, defined as the sum of routing scores assigned to it:

$$
\text{Importance}(e) = \sum_{t=1}^{T} s_{t,e}
$$

where $s_{t,e}$ is the routing score (pre-top-k normalization) assigned to expert $e$ for token $t$.

A balanced model should have all experts with similar importance. The importance loss penalizes high variance in importance across experts:

$$
L_{\text{importance}} = \alpha \cdot \text{CV}(\text{Importance})^2
$$

where CV is the coefficient of variation (std/mean). This encourages all experts to receive similar total routing probability, complementing the per-batch load balancing loss.
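
As a concrete check, the importance loss can be computed directly from a small score matrix. This is a pure-Python sketch with made-up scores; it uses the population standard deviation, which is an assumption of this example (an implementation might use the sample variance instead), and the helper name `importance_loss` is hypothetical.

```python
def importance_loss(scores, alpha=0.01):
    """alpha * CV(importance)^2, where importance_e = sum_t scores[t][e].

    Uses the population variance; this is an assumption of the sketch.
    """
    n_experts = len(scores[0])
    importance = [sum(row[e] for row in scores) for e in range(n_experts)]
    mean = sum(importance) / n_experts
    variance = sum((v - mean) ** 2 for v in importance) / n_experts
    return alpha * variance / (mean ** 2)


# Hypothetical routing scores for 3 tokens and 4 experts.
balanced = [[0.25, 0.25, 0.25, 0.25]] * 3
skewed = [[0.85, 0.05, 0.05, 0.05]] * 3

print(importance_loss(balanced))
print(importance_loss(skewed))
```

The balanced scores give zero loss because every expert has the same importance, while the skewed scores give a positive penalty that grows with the spread.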


Production Engineering Notes

Monitoring Router Behavior

In production, monitoring the router is essential for diagnosing problems:

class RouterMonitor:
    """
    Monitor router behavior during training and inference.
    Track expert utilization, load balance, and routing patterns.
    """

    def __init__(self, n_experts: int, n_layers: int):
        self.n_experts = n_experts
        self.n_layers = n_layers
        self.reset_stats()

    def reset_stats(self):
        self.expert_counts = {
            layer: [0] * self.n_experts
            for layer in range(self.n_layers)
        }
        self.total_tokens = 0

    def record_routing(
        self,
        layer_idx: int,
        top_k_indices: torch.Tensor,
    ):
        """Record which experts handled which tokens."""
        self.total_tokens += top_k_indices.shape[0]
        for k_idx in range(top_k_indices.shape[1]):
            for expert_idx in top_k_indices[:, k_idx].tolist():
                self.expert_counts[layer_idx][expert_idx] += 1

    def load_balance_stats(self) -> dict:
        """Compute load balance statistics across all layers."""
        stats = {}
        for layer_idx, counts in self.expert_counts.items():
            total = sum(counts)
            if total == 0:
                continue
            fractions = [c / total for c in counts]
            ideal_fraction = 1.0 / self.n_experts

            # CV of expert utilization
            mean = sum(fractions) / len(fractions)
            variance = sum((f - mean)**2 for f in fractions) / len(fractions)
            cv = (variance ** 0.5) / mean if mean > 0 else 0

            stats[f"layer_{layer_idx}"] = {
                "expert_fractions": fractions,
                "max_fraction": max(fractions),
                "min_fraction": min(fractions),
                "cv": cv,  # coefficient of variation - lower is more balanced
                "load_balance_ok": max(fractions) < 3.0 * ideal_fraction,
            }

        return stats

Auxiliary Loss Weight Tuning

The auxiliary loss coefficient $\alpha$ requires careful tuning:

  • Too low ($\alpha$ less than roughly 0.001): expert collapse can occur during training
  • Too high ($\alpha$ greater than roughly 0.1): the auxiliary loss dominates, hurting primary task performance
  • Recommended starting point: $\alpha = 0.01$ (the value used in the Switch Transformer paper)

Monitor both the auxiliary loss value and the actual expert utilization variance. The goal is not zero variance (some specialization is beneficial) but preventing extreme collapse.


:::danger Common Mistake: Forgetting the Auxiliary Loss

The single most common mistake in implementing MoE models is training without the auxiliary load balancing loss or with $\alpha$ set too small. Without it, expert collapse is virtually guaranteed within the first few thousand training steps. Always include the auxiliary loss and monitor expert utilization during early training.

:::

:::warning Token Dropping in Production

Switch Transformer and similar architectures can "drop" tokens when an expert's capacity buffer overflows. Dropped tokens are replaced with zeros, effectively causing those tokens to skip the MoE layer. This is acceptable during training (dropout-like regularization) but can cause quality degradation at inference. For production serving, either use a capacity factor greater than 1.0 (buffering extra capacity) or switch to an algorithm that never drops tokens.

:::

:::tip Diagnosing Router Issues

If a MoE model's performance is poor despite reasonable perplexity, check router behavior first. Signs of router problems: (1) An expert's token share exceeding roughly 5x the uniform share $1/N$. (2) One or two experts consistently receiving greater than 50% of tokens. (3) Some experts receiving fewer than 1% of tokens. (4) Router confidence increasing monotonically (one-hot routing) - the model is degenerating to a single expert. All of these suggest the auxiliary loss is too small or the router is otherwise misconfigured.

:::


Interview Questions and Answers

Q1: Explain the expert collapse problem and how load balancing loss prevents it.

Expert collapse occurs when the routing process creates a feedback loop: experts that receive more tokens update more and become better, causing the router to assign them even more tokens. This degenerates the MoE into a model with only a few active experts, wasting the remaining capacity. The auxiliary load balancing loss prevents this by penalizing imbalanced routing. The loss is $L_{\text{aux}} = \alpha N \sum_i f_i p_i$, where $f_i$ is the fraction of tokens sent to expert $i$ (non-differentiable) and $p_i$ is the mean routing probability assigned to expert $i$ (differentiable). When expert $i$ receives too many tokens ($f_i$ is high), the gradient reduces $p_i$, lowering the expert's routing scores and eventually reducing the fraction of tokens it receives.

Q2: What is the difference between token choice and expert choice routing?

Token choice: each token selects its top-k experts based on its own routing scores. Standard in production models (Mixtral, DeepSeek). Simple, works for autoregressive inference, but requires auxiliary load balancing loss to prevent collapse. Expert choice: each expert selects the top-m tokens from the batch to process. Guarantees perfect load balance by construction, requires no auxiliary loss. But some tokens may receive 0 or more than k experts' processing, and it cannot be used directly for autoregressive inference (requires the full batch, not available during token-by-token generation). Expert choice is primarily useful as a training technique.

Q3: Why did Switch Transformer use top-1 routing instead of top-2?

Switch Transformer used top-1 to simplify the architecture and reduce communication overhead in distributed training. With top-1, each token is sent to exactly one expert on one device, minimizing all-to-all communication. With top-2 or higher, each token may need to be sent to experts on different devices, increasing communication cost. Empirically, Switch Transformer found that top-1 performed comparably to top-2 on most tasks while being more computationally efficient. The simplification also eliminated the need to normalize and combine multiple expert outputs, making both forward and backward passes simpler. The main downside is reduced robustness: a wrong routing decision means no correction from a second expert.

Q4: What is the z-loss and why is it needed in addition to the load balance loss?

The z-loss penalizes the magnitude of router logits: $L_z = \beta \cdot \frac{1}{T}\sum_t (\log \sum_n \exp(l_{tn}))^2$. It's needed because, without it, router logits can grow to very large values in magnitude, causing the softmax distribution to become essentially one-hot (all probability concentrated on one expert). When logits are very large, gradients vanish in the softmax, making training unstable. The z-loss keeps logit magnitudes reasonable, maintaining softmax entropy and training stability. Zoph et al. (2022, ST-MoE) introduced z-loss after observing instability in large MoE models without it.

Q5: How would you diagnose a poorly routing MoE model in production?

Monitoring approach: (1) Track per-expert token fraction per layer - ideal is $1/N$ for $N$ experts. More than $3/N$ for any single expert suggests beginning collapse. (2) Compute the coefficient of variation (CV) of expert utilization - CV greater than 0.5 is a warning sign. (3) Monitor router entropy per layer: $H = -\sum_i p_i \log p_i$. Decreasing entropy over training means the router is becoming more one-hot, potentially over-specializing. (4) Check token drop rate if using capacity-bounded routing - high drop rates indicate severe load imbalance. (5) Compare quality metrics stratified by input domain: if the model performs well on common domains and poorly on rare ones, rare domains may be routed to undertrained experts. Remediation: increase auxiliary loss alpha, reduce learning rate to slow the feedback loop, or switch to expert choice routing for training.
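
The first three checks in this answer are cheap to compute. The sketch below is a pure-Python illustration that applies the rule-of-thumb thresholds quoted above to two made-up load profiles. The entropy is computed on per-expert token fractions as a stand-in for mean routing probabilities, which is an assumption of this example, and the helper name `diagnose` is hypothetical.

```python
import math


def router_entropy(probs):
    """H = -sum_i p_i * log(p_i), skipping zero-probability experts."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def utilization_cv(fractions):
    """Coefficient of variation (population std / mean) of expert load."""
    mean = sum(fractions) / len(fractions)
    variance = sum((f - mean) ** 2 for f in fractions) / len(fractions)
    return math.sqrt(variance) / mean


def diagnose(fractions):
    """Apply the 3/N and CV > 0.5 rules of thumb to one layer's load profile."""
    n = len(fractions)
    cv = utilization_cv(fractions)
    return {
        "entropy": router_entropy(fractions),
        "max_entropy": math.log(n),  # entropy of a perfectly uniform router
        "cv": cv,
        "overloaded_experts": [i for i, f in enumerate(fractions) if f > 3.0 / n],
        "cv_warning": cv > 0.5,
    }


# Hypothetical per-expert token fractions for one layer with 8 experts.
healthy = [0.14, 0.12, 0.13, 0.11, 0.12, 0.13, 0.12, 0.13]
collapsing = [0.05, 0.04, 0.06, 0.60, 0.07, 0.06, 0.05, 0.07]

for name, fractions in [("healthy", healthy), ("collapsing", collapsing)]:
    print(name, diagnose(fractions))
```

The healthy profile stays close to the maximum entropy and raises no flags, while the collapsing profile trips both the $3/N$ check (expert 3) and the CV warning.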

© 2026 EngineersOfAI. All rights reserved.