Training MoE Models
The Night the MoE Diverged
Your team is three weeks into training a 200B MoE model. Everything looks fine until step 45,000. The loss curve, which has been steadily decreasing, suddenly spikes, then stabilizes at a higher loss than before the spike, then spikes again. The training never fully recovers. Three weeks of GPU time, gone.
Post-mortem reveals the culprit: expert collapse had been building quietly for days. By step 45,000, two of sixteen experts were handling over 80% of all tokens. Those experts had accumulated gradients ten times larger than other experts, causing their weights to drift outside the distribution that the optimizer state had been tracking. When the weight update hit, the loss exploded.
This is not a hypothetical. MoE training failures like this happened repeatedly as the field was figuring out how to train these models stably. The solutions - auxiliary losses, capacity factors, careful initialization, z-loss - were discovered the hard way.
Why This Exists - MoE Training is Qualitatively Harder Than Dense Training
Training a dense transformer is, by comparison, a mature engineering problem. Loss curves are usually smooth and predictable, stability issues can usually be traced to the learning rate, batch size, or gradient clipping, and the optimization behavior is relatively well understood.
Training a MoE model introduces several entirely new failure modes:
Expert collapse: already discussed. A few experts dominate, the rest receive little or no gradient signal, and training degenerates.
Token dropping: when more tokens are routed to an expert than its capacity buffer can hold, the excess tokens are dropped. A dropped token skips the expert (its MoE output is zero), so only the residual connection carries it forward. This makes training inconsistent: some tokens contribute gradient to their chosen expert, others contribute none.
Router noise: the routing decisions themselves are discrete (top-k is not differentiable through the selection). This means gradients through the router are noisy and approximate, making the router harder to train than the experts.
Communication bottleneck: in distributed training with expert parallelism, tokens must be sent from their host device to the device hosting the selected expert. This all-to-all communication pattern is expensive and sensitive to load imbalance.
Layer-wise instability: in very deep MoE models, instability in one MoE layer can cascade to others. An expert that suddenly receives 10x more tokens accumulates roughly 10x larger gradients; those gradients flow back into earlier layers, and the expert's shifting outputs change the inputs that later layers (and their routers) see.
Expert Parallelism - Distributing Experts Across Devices
The key infrastructure concept in MoE training is expert parallelism. Because experts are independent (they don't share weights and are applied independently to tokens), they can be distributed across different devices.
In a typical configuration:
- $N$ experts are spread across $D$ devices
- Each device hosts $N/D$ experts
- For each MoE forward pass, tokens must be routed from their source device to the device hosting their selected expert
The communication pattern involves two all-to-all operations per MoE layer:
- Dispatch: send tokens from their host GPU to the GPU hosting their selected expert
- Combine: after the expert processes the token, send the result back to the token's host GPU
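For a model with 32 MoE layers, this means 64 all-to-all communication rounds per forward pass, plus the same number again in the backward pass. All-to-all is also harder to make fast than the all-reduce used by tensor parallelism: it involves every device in the expert-parallel group, and its latency is set by the most heavily loaded device, so load imbalance turns directly into idle time.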
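The traffic behind those rounds can be estimated with back-of-the-envelope arithmetic. The sketch below assumes uniform routing (so a fraction $(D-1)/D$ of token copies leave each device), that each of a token's top-k copies is sent separately with no per-device deduplication, and bf16 activations; `all_to_all_traffic` and its parameters are illustrative, not from any library.

```python
def all_to_all_traffic(
    n_moe_layers: int,
    tokens_per_device: int,
    top_k: int,
    d_model: int,
    world_size: int,
    bytes_per_elem: int = 2,  # bf16 activations (assumption)
) -> tuple:
    """Estimate all-to-all rounds and bytes sent per device in one forward pass.

    Assumes uniform routing and that every one of a token's top_k copies is
    sent separately (no per-device deduplication).
    """
    rounds = 2 * n_moe_layers  # one dispatch + one combine per MoE layer
    copies = tokens_per_device * top_k
    # Under uniform routing, (world_size - 1) / world_size of the copies go to
    # another device; the rest stay local and cost no network traffic.
    remote_copies = copies * (world_size - 1) // world_size
    bytes_per_round = remote_copies * d_model * bytes_per_elem
    return rounds, bytes_per_round, rounds * bytes_per_round


rounds, per_round, total = all_to_all_traffic(
    n_moe_layers=32, tokens_per_device=4096, top_k=2, d_model=4096, world_size=8
)
print(rounds, per_round / 2**20, total / 2**30)  # prints: 64 56.0 3.5 (rounds, MiB/round, GiB)
```

Under these assumptions each device pushes about 3.5 GiB through all-to-all per forward pass; imbalanced routing makes the busiest device's share larger than this average.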
```python
import torch
import torch.distributed as dist


def all_to_all_single_expert_dispatch(
    local_tokens: torch.Tensor,
    routing_indices: torch.Tensor,  # [T_local] - which expert for each token (top-1)
    n_experts: int,
    world_size: int,
    rank: int,
) -> torch.Tensor:
    """
    Dispatch tokens to their assigned experts across GPUs.

    Each GPU hosts n_experts / world_size experts in contiguous blocks:
    GPU g hosts experts [g * experts_per_gpu, (g + 1) * experts_per_gpu).
    Tokens are sent to the GPU hosting their target expert.

    Args:
        local_tokens: [T_local, d_model] - tokens on this GPU
        routing_indices: [T_local] - expert index for each token
        n_experts: Total number of experts (must be divisible by world_size)
        world_size: Number of GPUs in the expert-parallel group
        rank: This GPU's rank (unused in this simplified dispatch)

    Returns:
        received_tokens: [T_received, d_model] - tokens assigned to local experts
    """
    assert n_experts % world_size == 0, "experts must divide evenly across GPUs"
    experts_per_gpu = n_experts // world_size

    # Determine which GPU each token should go to
    target_gpu = routing_indices // experts_per_gpu  # [T_local]

    # Group tokens by target GPU so each destination receives a contiguous chunk
    send_tensors = []
    input_splits = []  # number of tokens this GPU sends to each GPU
    for gpu in range(world_size):
        gpu_tokens = local_tokens[target_gpu == gpu]
        send_tensors.append(gpu_tokens)
        input_splits.append(gpu_tokens.shape[0])

    # Step 1: exchange counts so every GPU knows how many tokens it will receive.
    # The count tensors live on the tokens' device (NCCL requires GPU tensors).
    send_counts = torch.tensor(input_splits, dtype=torch.int64, device=local_tokens.device)
    recv_counts = torch.empty_like(send_counts)
    dist.all_to_all_single(recv_counts, send_counts)
    output_splits = recv_counts.tolist()  # number of tokens received from each GPU

    # Step 2: exchange the tokens themselves
    send_buffer = torch.cat(send_tensors, dim=0)
    recv_buffer = torch.empty(
        sum(output_splits),
        local_tokens.shape[1],
        device=local_tokens.device,
        dtype=local_tokens.dtype,
    )
    dist.all_to_all_single(
        recv_buffer,
        send_buffer,
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
    )
    # The combine step must run the reverse exchange (with the split sizes
    # swapped) and undo the grouping-by-GPU permutation to restore token order.
    return recv_buffer
```
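The count exchange in this function is easy to misread, so here is a single-process simulation that needs no `torch.distributed` setup. It assumes top-1 routing and the same contiguous expert-to-GPU layout as the dispatch code; `simulate_count_exchange` is a hypothetical helper for illustration.

```python
import torch


def simulate_count_exchange(routing_per_rank, n_experts):
    """Single-process simulation of the count exchange in the dispatch step.

    routing_per_rank[r] holds the top-1 expert index of each token on rank r.
    Returns (send, recv): send[r, g] is how many tokens rank r sends to rank g,
    recv[g, r] is how many tokens rank g receives from rank r.
    """
    world_size = len(routing_per_rank)
    experts_per_gpu = n_experts // world_size
    send = torch.stack([
        torch.bincount(idx // experts_per_gpu, minlength=world_size)
        for idx in routing_per_rank
    ])
    # all_to_all_single on the count vectors: rank g's recv_counts[r] equals
    # rank r's send_counts[g], so the receive matrix is the transpose.
    recv = send.t().contiguous()
    return send, recv


# Two ranks, four experts (experts 0-1 on rank 0, experts 2-3 on rank 1)
send, recv = simulate_count_exchange(
    [torch.tensor([0, 1, 2, 3]), torch.tensor([0, 0, 0, 1])], n_experts=4
)
print(recv.sum(dim=1))  # tokens each rank must process: tensor([6, 2])
```

Both ranks start with four tokens, yet rank 0 ends up processing three times as many as rank 1. This is how routing imbalance turns into straggler GPUs during all-to-all.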
The Capacity Factor - Controlling Token Dropping
Each expert has a fixed capacity: the maximum number of tokens it can process per batch, set by a capacity factor. If more tokens are routed to an expert than its capacity allows, the excess tokens are dropped (their expert output is replaced with zeros). For top-1 routing:

$$\text{capacity} = \left\lfloor \text{capacity\_factor} \cdot \frac{T}{N} \right\rfloor$$

where $T$ is the total number of tokens in the batch and $N$ is the number of experts. With top-$k$ routing, each token occupies $k$ expert slots, so $T$ becomes $kT$.

With perfect load balance ($T/N$ tokens per expert), capacity_factor = 1.0 would mean zero dropped tokens. In practice, routing is never perfectly balanced, so capacity_factor > 1.0 is used as a buffer:
- capacity_factor = 1.0: strict - any imbalance causes drops
- capacity_factor = 1.25: common default - allows 25% imbalance per expert
- capacity_factor = 2.0: generous - very rarely drops tokens but wastes memory
```python
from typing import Tuple

import torch
import torch.nn as nn


def compute_expert_capacity(
    batch_tokens: int,
    n_experts: int,
    top_k: int = 2,
    capacity_factor: float = 1.25,
) -> int:
    """
    Compute the maximum number of tokens an expert can handle per batch.

    With top_k routing, each token goes to k experts, so the total
    "demand" across all experts is k * T. With N experts, the average
    demand per expert is k * T / N.

    capacity = floor(capacity_factor * k * T / N)
    """
    avg_tokens_per_expert = top_k * batch_tokens / n_experts
    capacity = int(capacity_factor * avg_tokens_per_expert)  # int() floors positive values
    return max(capacity, 1)  # At least 1 token per expert


class ExpertWithCapacity(nn.Module):
    """
    Expert that enforces a capacity limit, dropping excess tokens.
    This is the mechanism used in Switch Transformer and GShard.
    Tokens beyond capacity are dropped in arrival order; their expert output
    is zero, so in a residual block they pass through unchanged.
    """

    def __init__(self, expert: nn.Module, capacity: int):
        super().__init__()
        self.expert = expert
        self.capacity = capacity

    def forward(
        self,
        tokens: torch.Tensor,           # [T_e, d_model] - all tokens routed here
        routing_weights: torch.Tensor,  # [T_e] - routing weight for each token
    ) -> Tuple[torch.Tensor, int]:
        """
        Process tokens up to capacity, drop the rest.

        Returns:
            output: [T_e, d_model] - weighted expert output (zeros for dropped tokens)
            n_dropped: Number of tokens dropped
        """
        T_e = tokens.shape[0]
        n_kept = min(T_e, self.capacity)
        n_dropped = T_e - n_kept

        # Process only the first n_kept tokens (all of them if under capacity)
        expert_out = self.expert(tokens[:n_kept])
        weighted_out = expert_out * routing_weights[:n_kept].unsqueeze(-1)
        if n_dropped == 0:
            return weighted_out, 0

        # Pad with zeros for dropped tokens (assumes expert output dim == d_model)
        full_output = torch.zeros_like(tokens)
        full_output[:n_kept] = weighted_out
        return full_output, n_dropped
```
Token dropping is controversial. It simplifies distributed training (every expert's buffer has a fixed, static size, which enables efficient batching and predictable communication) but introduces training inconsistency (some tokens are never processed by their preferred expert during the update). In practice, with a capacity_factor around 1.25 and a well-tuned auxiliary loss, drop rates can usually be kept to around 1% or less, at which point the training degradation is minimal.
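Drop rate is cheap to measure directly, and measuring it is the most reliable way to check whether a given capacity factor is adequate for your routing. A minimal sketch, assuming the router's expert choices are available as an index tensor; `drop_rate` is a hypothetical helper, and it only counts overflow without tracking which tokens are dropped.

```python
import torch


def drop_rate(routing_indices: torch.Tensor, n_experts: int, capacity: int) -> float:
    """Fraction of routed token copies that overflow their expert's capacity.

    routing_indices: [T, top_k] (or [T]) expert indices chosen by the router.
    """
    counts = torch.bincount(routing_indices.flatten(), minlength=n_experts)
    dropped = (counts - capacity).clamp(min=0).sum()
    return dropped.item() / max(routing_indices.numel(), 1)


# 16 token copies over 4 experts, capacity 5 per expert
idx = torch.tensor([0] * 10 + [1] * 2 + [2] * 4)
print(drop_rate(idx, n_experts=4, capacity=5))  # 5 of 16 copies dropped: 0.3125
```

Logging this per MoE layer alongside the auxiliary loss shows whether drops are concentrated in a few layers.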
Load Imbalance - The Persistent Problem
Despite auxiliary losses and noisy gating, load imbalance remains a challenge at scale. Commonly reported patterns include:
- Early training imbalance: in the first few thousand steps, before the auxiliary loss has had time to correct routing, severe imbalance can occur and set patterns that are hard to break later.
- Domain-induced imbalance: training data has an unequal topic distribution. If 30% of the data is code and only two of sixteen experts specialize in code, those two experts receive roughly 30% of all tokens between them, 2.4x their balanced share of 12.5%.
- Depth-dependent imbalance: in deep MoE models, imbalance tends to be worse in early layers (where token representations are less semantically meaningful) than in later layers (where tokens are more semantically distinct).
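To see how much a domain skew like this registers on load-balance metrics, the snippet below works through the 30%/70% example. It assumes, purely for illustration, that exactly two of sixteen experts specialize in code and split the code tokens evenly, while the other fourteen split the remaining 70% evenly.

```python
import math

import torch

n_experts = 16
# Idealized assumption: two code specialists at 15% each, fourteen others at 5% each
fractions = torch.tensor([0.15, 0.15] + [0.05] * 14)

ideal = 1.0 / n_experts                              # 0.0625
share_vs_ideal = fractions.max().item() / ideal      # specialist load vs balanced share
cv = (fractions.std() / fractions.mean()).item()     # coefficient of variation (std / mean)
entropy = -(fractions * fractions.log()).sum().item()
normalized_entropy = entropy / math.log(n_experts)   # 1.0 = perfect balance

print(f"specialist load = {share_vs_ideal:.1f}x ideal, cv = {cv:.3f}, "
      f"normalized entropy = {normalized_entropy:.3f}")
```

No expert exceeds a 3x-of-ideal overload threshold, yet the CV of roughly 0.55 is far past a 0.3 warning level, while normalized entropy stays above 0.95. Entropy is forgiving of this kind of skew, so it is worth tracking CV alongside it.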
```python
import math
from typing import Optional

import torch


def monitor_load_balance(
    expert_counts: torch.Tensor,  # [n_experts] - tokens routed to each expert this batch
    n_experts: int,
    batch_tokens: Optional[int] = None,
    log_prefix: str = "",
) -> dict:
    """
    Compute and log load balance statistics.
    Call this every N training steps to monitor router health.

    Fractions are normalized by the total routed count, so they form a
    distribution (sum to 1) even under top-k routing, where the counts sum
    to k * T. batch_tokens is optional and unused in that normalization.
    """
    counts = expert_counts.float()
    fractions = counts / counts.sum().clamp(min=1.0)
    ideal_fraction = 1.0 / n_experts
    max_deviation = (fractions - ideal_fraction).abs().max().item()

    # Coefficient of variation (lower is better)
    mean = fractions.mean().item()
    std = fractions.std().item()
    cv = std / mean if mean > 0 else 0.0

    # Entropy of the expert distribution (higher is better)
    # Perfect balance = log(n_experts)
    entropy = -(fractions * torch.log(fractions + 1e-8)).sum().item()
    normalized_entropy = entropy / math.log(n_experts)  # 1.0 = perfect balance

    overloaded = int((fractions > 3.0 * ideal_fraction).sum().item())
    underloaded = int((fractions < 0.1 * ideal_fraction).sum().item())

    stats = {
        "expert_fractions": fractions.tolist(),
        "max_fraction": fractions.max().item(),
        "min_fraction": fractions.min().item(),
        "max_deviation": max_deviation,
        "coefficient_of_variation": cv,
        "normalized_entropy": normalized_entropy,
        "n_overloaded_experts": overloaded,
        "n_underloaded_experts": underloaded,
        "load_balance_ok": cv < 0.3 and overloaded == 0,
    }

    if log_prefix:
        status = "OK" if stats["load_balance_ok"] else "WARNING"
        print(f"{log_prefix} [{status}] Load balance: "
              f"max={stats['max_fraction']:.2%}, "
              f"min={stats['min_fraction']:.2%}, "
              f"cv={cv:.3f}, "
              f"entropy={normalized_entropy:.2%}")
    return stats
```
GShard - Scaling MoE to 600B Parameters
Lepikhin et al. (2021), "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding", demonstrated MoE training at a then-unprecedented scale: 600 billion parameters for a multilingual neural machine translation model.
Key innovations from GShard:
Automatic sharding: GShard introduced an annotation system where model parallelism patterns are specified declaratively, and a compiler automatically generates the distributed execution plan. This made it practical to specify and train complex distributed models without writing custom distribution code.
Top-2 routing with stochastic second-expert dispatch: each token is always sent to its highest-scoring expert, while the second-highest-scoring expert is used only with probability proportional to its gate weight. Tokens whose second gate weight is small therefore often skip the second expert, which saves expert capacity and communication.
Per-expert capacity buffers: GShard formalized the per-expert capacity buffer: each expert accepts at most a fixed number of tokens per batch, and overflow tokens are dropped. A capacity factor of exactly 1.0 drops tokens whenever routing is even slightly imbalanced, so values of roughly 1.25–2.0 are the usual practical range.
Auxiliary losses for multilingual balance: GShard found that with highly diverse multilingual data, experts naturally specialize by language, which is actually beneficial for translation quality. But within a language, load imbalance still occurs. The auxiliary loss needed to be tuned carefully to encourage within-language balance without destroying cross-language specialization.
Training Instability - Why MoE Models Are Harder to Train
Beyond expert collapse and token dropping, MoE models exhibit higher training instability than dense models:
Loss spikes: sudden increases in training loss, often attributed to expert weight drift after periods of imbalanced routing. Mitigations: gradient clipping (a global-norm threshold of 1.0, applied from the first step), a reduced learning rate for expert weights, and the router z-loss.
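Two of these mitigations are short enough to show directly. The sketch below implements the router z-loss in the form used by ST-MoE (the squared log-sum-exp of each token's router logits, averaged over tokens) and a reduced learning rate for expert weights via optimizer parameter groups. The `"experts"` name filter, the 0.5 scale, and `ToyMoE` are illustrative assumptions, not recommended values.

```python
import torch
import torch.nn as nn


def router_z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """ST-MoE-style router z-loss: mean over tokens of logsumexp(logits)^2.

    router_logits: [T, n_experts]. Penalizing large logits keeps the router
    softmax out of the regime where small logit changes cause large jumps.
    """
    return torch.logsumexp(router_logits, dim=-1).pow(2).mean()


def build_param_groups(model: nn.Module, base_lr: float, expert_lr_scale: float = 0.5):
    """Split parameters so expert weights get a reduced learning rate.

    Assumes (hypothetically) that expert parameters have "experts" in their
    qualified name; adapt the filter to your model's naming.
    """
    expert_params, other_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (expert_params if "experts" in name else other_params).append(param)
    return [
        {"params": other_params, "lr": base_lr},
        {"params": expert_params, "lr": base_lr * expert_lr_scale},
    ]


class ToyMoE(nn.Module):  # minimal stand-in model for the usage example
    def __init__(self, d: int = 8, n_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(d, n_experts)
        self.experts = nn.ModuleList([nn.Linear(d, d) for _ in range(n_experts)])


model = ToyMoE()
optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=3e-4))
z = router_z_loss(torch.zeros(5, 4))  # logsumexp of 4 zeros is log(4), so z = log(4)^2
```

In a training loop, `router_z_loss` would be computed per MoE layer and added to the main loss with a small weight such as the `z_loss_beta` used by the trainer below.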
Router gradient noise: because the routing decisions are discrete (not differentiable), the gradient signal for the router is inherently noisier than for the expert weights. The router updates based on the gradient of routing scores (continuous), not routing decisions (discrete). This means the router can oscillate between routing strategies.
Depth amplification: in a 32-layer MoE model, instability in layer 5's router can propagate forward through 27 more layers of transformation, amplifying the error. Dense models of equivalent depth are less exposed because every layer is a smooth function of its input. In an MoE layer, a small change in the input can flip the top-k selection and produce a discontinuous jump in the output.
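A toy one-dimensional example makes this contrast concrete. The router below sends the input to expert 0 when x > 0 and to expert 1 otherwise; `top1_moe` and `dense_layer` are illustrative constructions, not components of a real model.

```python
import torch


def top1_moe(x: torch.Tensor) -> torch.Tensor:
    """Toy 1-D MoE layer with top-1 routing (illustrative, not a real model)."""
    logits = torch.stack([x, -x], dim=-1)  # expert 0 wins when x > 0
    probs = torch.softmax(logits, dim=-1)
    gate, idx = probs.max(dim=-1)          # discrete top-1 selection
    experts = [lambda v: v + 10.0, lambda v: v - 10.0]
    out = torch.where(idx == 0, experts[0](x), experts[1](x))
    return gate * out


def dense_layer(x: torch.Tensor) -> torch.Tensor:
    """Smooth counterpart: always applies both experts, weighted by the softmax."""
    probs = torch.softmax(torch.stack([x, -x], dim=-1), dim=-1)
    return probs[..., 0] * (x + 10.0) + probs[..., 1] * (x - 10.0)


eps = 1e-3
x = torch.tensor([eps, -eps])
print(top1_moe(x))     # jumps from about +5 to about -5
print(dense_layer(x))  # stays close to 0 on both sides
```

A change of 0.002 in the input moves the top-1 output by about 10, while the soft-weighted counterpart moves by about 0.02. Stacked across many MoE layers, jumps like this are what depth can amplify.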
```python
import torch


class StableMoETrainer:
    """
    Training wrapper with MoE-specific stability techniques.

    Assumes the model's forward returns an object with a `.loss` attribute that
    also supports `.get()` (e.g. a dict-like model output) for optional
    "aux_losses" and "z_losses" lists of per-layer scalar losses, and that the
    model exposes its MoE layers as `model.moe_layers`.
    Uses monitor_load_balance (defined above) for periodic checks.
    """

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        max_grad_norm: float = 1.0,    # Aggressive clipping for MoE stability
        aux_loss_alpha: float = 0.01,  # Load balance loss weight
        z_loss_beta: float = 0.001,    # Z-loss weight for router stability
        check_interval: int = 100,     # How often to check load balance
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.max_grad_norm = max_grad_norm
        self.aux_loss_alpha = aux_loss_alpha
        self.z_loss_beta = z_loss_beta
        self.check_interval = check_interval
        self.step = 0
        self.load_balance_history = []

    def training_step(self, batch: dict) -> dict:
        """
        Single training step with MoE stability checks.
        """
        self.model.train()
        self.optimizer.zero_grad()

        # Forward pass
        outputs = self.model(**batch)
        main_loss = outputs.loss

        # Collect auxiliary losses from all MoE layers
        aux_losses = outputs.get("aux_losses", [])  # List of per-layer aux losses
        z_losses = outputs.get("z_losses", [])      # List of per-layer z-losses
        zero = main_loss.new_zeros(())  # scalar on the loss's device and dtype
        total_aux_loss = sum(aux_losses, zero)
        total_z_loss = sum(z_losses, zero)

        # Combined loss: each auxiliary term is scaled by its own weight
        total_loss = (
            main_loss
            + self.aux_loss_alpha * total_aux_loss
            + self.z_loss_beta * total_z_loss
        )

        # Backward pass
        total_loss.backward()

        # Gradient clipping (critical for MoE stability).
        # clip_grad_norm_ returns the total norm measured *before* clipping.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm,
        )

        # Warn if the pre-clipping norm is far above the threshold
        if grad_norm > 10 * self.max_grad_norm:
            print(f"WARNING: Very high gradient norm: {grad_norm:.2f}. "
                  "Possible training instability.")

        self.optimizer.step()
        self.scheduler.step()
        self.step += 1

        # Periodic load balance monitoring
        if self.step % self.check_interval == 0:
            self._check_load_balance()

        return {
            "loss": main_loss.item(),
            "aux_loss": total_aux_loss.item(),
            "z_loss": total_z_loss.item(),
            "total_loss": total_loss.item(),
            "grad_norm": grad_norm.item(),
        }

    def _check_load_balance(self):
        """Check expert utilization and warn if imbalanced."""
        for layer_idx, layer in enumerate(self.model.moe_layers):
            if hasattr(layer, "expert_counts"):
                counts = layer.expert_counts
                stats = monitor_load_balance(
                    counts,
                    layer.n_experts,
                    # Normalize by the total routed count so fractions sum to 1
                    batch_tokens=max(int(counts.sum().item()), 1),
                    log_prefix=f"Step {self.step}, Layer {layer_idx}",
                )
                self.load_balance_history.append((self.step, layer_idx, stats))
                if not stats["load_balance_ok"]:
                    print(f"LOAD BALANCE WARNING at step {self.step}, layer {layer_idx}: "
                          f"Consider increasing aux_loss_alpha")
```
Gradient Flow Through Sparse Routing
The discrete top-k selection is not differentiable - you can't compute gradients through which experts were selected, only through the routing scores for the selected experts.
In the backward pass:
- Expert weight gradients: flow normally through the expert computations
- Routing weight gradients: flow through the routing scores (the softmax outputs) for selected experts
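The routing weight gradient for expert $i$ for token $x$ is:

$$\frac{\partial \mathcal{L}}{\partial g_i(x)} = \left(\frac{\partial \mathcal{L}}{\partial y}\right)^{\top} E_i(x)$$

where $\mathcal{L}$ is the loss, $y = \sum_{j \in \text{TopK}(x)} g_j(x)\, E_j(x)$ is the MoE output, and $E_i(x)$ is the output of expert $i$. This gradient only exists for selected experts (others get zero gradient through the routing weights).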
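Autograd makes this easy to check numerically. The sketch below builds a single-token top-2 MoE output from fixed expert outputs and confirms two things: the gradient with respect to each routing weight is zero for unselected experts, and for selected experts it equals the dot product of the upstream gradient (the gradient of the loss with respect to the MoE output) with that expert's output. The squared-error loss and the random tensors are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n_experts, d_model, top_k = 4, 3, 2

logits = torch.tensor([2.0, 1.0, 0.5, -1.0], requires_grad=True)  # router logits, one token
expert_out = torch.randn(n_experts, d_model)   # E_i(x), treated as fixed here
target = torch.randn(d_model)

g = torch.softmax(logits, dim=-1)              # routing weights g_i(x)
g.retain_grad()                                # keep dL/dg for inspection
top_vals, top_idx = g.topk(top_k)              # the indices carry no gradient
y = (top_vals.unsqueeze(-1) * expert_out[top_idx]).sum(dim=0)  # MoE output
loss = ((y - target) ** 2).sum()
loss.backward()

dL_dy = 2 * (y.detach() - target)              # analytic dL/dy for squared error
predicted = expert_out @ dL_dy                 # (dL/dy)^T E_i(x) for every expert i
print(g.grad)                                  # zero for the two unselected experts
print(predicted[top_idx], g.grad[top_idx])     # these two should match
```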
The consequence: the router learns which combination of expert outputs best reduces the loss, but has no direct gradient signal about which experts to select - only about how much to weight the experts it already selected.
This is why load balance losses are essential: they provide an indirect gradient that encourages the router to spread its selection across all experts, even though the selection mechanism itself is non-differentiable.
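For concreteness, here is the top-1 auxiliary loss from the Switch Transformer paper, one common formulation (the loss discussed earlier in this series may differ in detail). The token fractions f_i come from the discrete argmax and carry no gradient; the router is trained only through the mean probabilities P_i, which is the indirect signal described above. `switch_load_balance_loss` is an illustrative name.

```python
import torch


def switch_load_balance_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Switch Transformer-style auxiliary loss: N * sum_i f_i * P_i (top-1).

    f_i: fraction of tokens whose argmax expert is i (non-differentiable count)
    P_i: mean router probability assigned to expert i (differentiable)
    """
    n_experts = router_logits.shape[-1]
    probs = torch.softmax(router_logits, dim=-1)      # [T, N]
    assignment = probs.argmax(dim=-1)                 # discrete top-1 choice
    f = torch.bincount(assignment, minlength=n_experts).float() / assignment.numel()
    P = probs.mean(dim=0)                             # gradient flows only through P
    return n_experts * (f * P).sum()


T, N = 8, 4
balanced = 3.0 * torch.nn.functional.one_hot(torch.arange(T) % N, N).float()
collapsed = torch.zeros(T, N)
collapsed[:, 0] = 3.0                                 # every token prefers expert 0
collapsed.requires_grad_(True)

loss = switch_load_balance_loss(collapsed)
loss.backward()
print(switch_load_balance_loss(balanced).item())  # about 1.0 under perfectly balanced routing
print(loss.item(), collapsed.grad[:, 0])          # larger loss; positive grad pushes expert 0 down
```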
:::danger Common Mistake: Ignoring Gradient Clipping MoE models need gradient clipping at least as tight as equivalent dense models, and often tighter. Expert weight gradients can be very large in early training when load imbalance is severe (a few experts receive many tokens and accumulate large gradients). Use torch.nn.utils.clip_grad_norm_ with a max norm of 1.0, and consider lowering it if spikes persist; avoid looser thresholds or disabling clipping. Failure to clip aggressively is a common cause of MoE training instability. :::
:::warning Token Dropping During Inference Token dropping (when expert capacity buffers overflow) is designed for training, where dropping occasional tokens is acceptable as a form of stochastic regularization. At inference time, token dropping causes quality degradation and inconsistent outputs. Set capacity_factor to infinity (no capacity limit) or to a very large value during inference, or switch to routing strategies that guarantee no dropping. :::
:::tip Monitor Expert Utilization From Step 1 Don't wait to monitor load balance. Expert collapse can begin in the first few hundred training steps. Log per-expert token fractions every 100 steps from the very beginning of training. If CV (coefficient of variation) exceeds 0.3 at any point in the first 5,000 steps, immediately increase alpha (the auxiliary loss weight) by 2x. It's much easier to prevent collapse early than to recover from it. :::
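The rule of thumb in this tip is mechanical enough to encode in the training loop. A minimal sketch: `adjust_aux_alpha` is a hypothetical helper, the thresholds default to the values in the tip, and the `max_alpha` cap is an added assumption so that alpha cannot keep doubling without bound if CV stays high across checks.

```python
def adjust_aux_alpha(
    alpha: float,
    cv: float,
    step: int,
    cv_threshold: float = 0.3,
    early_window: int = 5_000,
    factor: float = 2.0,
    max_alpha: float = 0.1,  # assumption: cap so alpha cannot double without bound
) -> float:
    """Double the auxiliary-loss weight if routing CV is too high early in training."""
    if step < early_window and cv > cv_threshold:
        return min(alpha * factor, max_alpha)
    return alpha


alpha = 0.01
for step, cv in [(100, 0.2), (200, 0.45), (300, 0.5), (6_000, 0.6)]:
    alpha = adjust_aux_alpha(alpha, cv, step)
print(alpha)  # doubled twice during the first 5,000 steps: 0.04
```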
Interview Questions and Answers
Q1: What is expert parallelism and how does it differ from tensor parallelism?
Expert parallelism distributes different experts across different devices. Each device hosts a subset of experts and processes only the tokens routed to its experts. Communication happens via all-to-all operations: tokens are dispatched from their host device to the device with their selected expert, then returned after processing. Tensor parallelism, by contrast, splits individual weight matrices across devices - each device holds a shard of every weight matrix and processes all tokens for that portion of each layer. Expert parallelism gives lower communication overhead for MoE models (tokens only go to one or two devices per MoE layer) and naturally maps to the MoE structure. Tensor parallelism gives lower communication overhead for dense models (all-reduce is highly optimized). In practice, large MoE models often combine both: expert parallelism for MoE layers and tensor parallelism for attention layers.
Q2: What is the capacity factor and why is it necessary?
The capacity factor is a multiplier that determines how many tokens an expert can process per batch: $\text{capacity} = \text{capacity\_factor} \cdot T/N$, where $T$ is the number of tokens in the batch and $N$ is the number of experts (with top-$k$ routing, $T$ becomes $kT$). With perfect load balance, capacity_factor = 1.0 would never drop tokens. But routing is never perfectly balanced, so some experts receive more than $T/N$ tokens. The capacity factor provides a buffer: with capacity_factor = 1.25, each expert can handle 25% more tokens than its perfectly balanced share before dropping. Tokens routed to an expert beyond its capacity are dropped (their expert output is zero). A fixed capacity keeps memory bounded and tensor shapes static: every expert's buffer has exactly capacity slots (padded when underfull), so buffers, kernels, and all-to-all communication can be sized ahead of time instead of depending on each batch's routing.
Q3: Why does MoE training exhibit higher instability than dense model training?
Three key reasons: (1) Feedback loops from expert collapse - uneven routing causes uneven gradient updates, which further concentrates routing, creating a runaway feedback loop that can cause sudden loss spikes. (2) Non-differentiable routing - gradients can't flow through the top-k selection, only through routing scores. This makes the router's optimization landscape noisier and more discontinuous than a fully differentiable network. (3) Cross-layer propagation - instability in one MoE layer's routing can cascade through subsequent layers, as the output distribution shifts significantly and subsequent layers' routers have to adapt to a changed distribution. Mitigations: smaller learning rates, aggressive gradient clipping, z-loss for router stability, and careful monitoring.
Q4: What did GShard contribute to MoE training at scale?
GShard (Lepikhin et al., 2021) made several key contributions: (1) Demonstrated MoE scaling to 600B parameters for machine translation, proving the paradigm worked at unprecedented scale. (2) Introduced automatic sharding specifications: developers annotate which dimensions of tensors should be sharded across devices, and a compiler generates the distributed execution plan. This eliminated hand-written distribution code. (3) Formalized the per-expert capacity buffer, making fixed-capacity dispatch with dropping of overflow tokens a standard MoE design choice. (4) Showed that language specialization in multilingual MoE was beneficial and shouldn't be penalized by the auxiliary loss, introducing nuance into load balance objectives.
Q5: How do you handle token dropping at inference time?
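At inference time, token dropping is unacceptable because it causes inconsistent, quality-degraded outputs. Solutions: (1) Set capacity_factor to a very large value (e.g., 100) so the capacity buffer never fills. This adds some memory overhead but ensures no dropping. (2) Use a routing formulation that never drops tokens by design, such as a dropless MoE that processes variable-sized per-expert batches (for example with block-sparse kernels, as in MegaBlocks). Expert choice routing is a poor fit here: each expert picks its tokens from the whole batch, so some tokens may not be picked at all and a token's routing depends on the other tokens in the batch. (3) For autoregressive generation, be careful with small batches: capacity scales with the number of tokens per step, so at decode time it can round down to one or two slots per expert, and two tokens choosing the same expert are enough to overflow it. Disable the capacity limit during decoding or enforce a generous minimum capacity, and verify drop rates empirically. (4) For very large batch inference, monitor drop rates in production and alert if drops exceed 0.1% of tokens per batch.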
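Checking how small the capacity gets at decode-sized batches takes a few lines. The sketch below evaluates the capacity formula floor(capacity_factor * k * T / N) at several batch sizes, assuming a 64-expert, top-2 configuration for illustration; `expert_capacity` is a hypothetical helper with an added `min_capacity` floor.

```python
def expert_capacity(batch_tokens: int, n_experts: int, top_k: int = 2,
                    capacity_factor: float = 1.25, min_capacity: int = 1) -> int:
    """capacity = floor(capacity_factor * top_k * T / N), clamped to min_capacity."""
    return max(int(capacity_factor * top_k * batch_tokens / n_experts), min_capacity)


for tokens in (8, 64, 4096):
    print(tokens, expert_capacity(tokens, n_experts=64))  # capacities: 1, 2, 160
```

With 8 tokens per step each expert has a single slot, so whether a small decode batch overflows depends entirely on how often two tokens pick the same expert. That makes the empirical check worth doing, and a generous `min_capacity` is a cheap safeguard.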
