

Multi-Task Learning Systems

Fifty Tasks, One Model, and a News Feed That Feeds Three Billion People

The Facebook News Feed ranking team has a problem that would make any ML engineer's head spin. The ranking model needs to optimize simultaneously for: likes, comments, shares, reactions, clicks, video watch time, story completions, click-through-rate on links, hide-post signals, not-interested signals, reports, and roughly 40 more engagement and safety signals. Each of these represents a distinct user behavior. Each requires its own training signal. Each has its own label distribution, label frequency, and optimization characteristics.

The naive solution - train one model per task - produces 50 separate models, 50 separate feature pipelines, 50 separate serving deployments, 50 separate training schedules, and 50 separate teams responsible for keeping them healthy. This is operationally unsustainable. But more importantly, it is suboptimal: a model that learns only from like signals cannot benefit from the signal in comment data, which tells you something closely related about what content people find engaging.

By 2016, the Facebook team had settled on a different approach: one large neural network with a shared trunk and 50 separate task-specific output heads. The shared trunk learns a single representation of users, content, and context that is useful across all tasks simultaneously. Each task head adapts that shared representation to its specific prediction objective. One model. One feature pipeline. One serving infrastructure. Fifty tasks.

The efficiency gains were immediate and significant. A single MTL model was cheaper to train, cheaper to serve, and - critically - the shared trunk produced better representations than any single-task model because it received gradient signal from all 50 tasks simultaneously. Understanding why this works, where it breaks, and how to manage the interactions between tasks is the subject of this lesson.

Why This Exists

The Redundancy Problem

Separate single-task models share an enormous amount of redundant computation. Consider two models: one predicting video watch time, another predicting video completion rate. Both need to understand user engagement patterns, video content features, creator quality signals, and contextual relevance. They compute essentially the same representation of the user and the video - and they do it twice, with twice the compute, twice the data storage, twice the serving cost.

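Multi-task learning eliminates this redundancy by computing the shared representation once and routing it to multiple task-specific heads. The compute savings scale with the number of tasks: with 50 tasks and 80% of each model's computation shared, the shared portion runs once instead of 50 times. Total cost falls from 50 single-model units to roughly 10.8 (0.8 for the trunk plus 50 x 0.2 for the heads), a saving of about 39 units, or a reduction of about 4.6x.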
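The arithmetic behind this kind of estimate can be sketched in a few lines. The cost model is an assumption made for illustration: every single-task model costs one unit, a fixed fraction of that unit is shareable, and the function name `mtl_compute_cost` is hypothetical.

```python
def mtl_compute_cost(n_tasks: int, shared_fraction: float) -> dict:
    """Compare total compute for separate models vs. one shared-trunk model.

    Assumes every single-task model costs 1 unit, of which `shared_fraction`
    is the shareable representation and the rest is task-specific.
    """
    separate = float(n_tasks)  # each model recomputes everything
    # Shared trunk is computed once; each task still pays for its own head.
    shared = shared_fraction + n_tasks * (1.0 - shared_fraction)
    return {
        "separate": separate,
        "shared": shared,
        "units_saved": separate - shared,
        "reduction_factor": separate / shared,
    }


cost = mtl_compute_cost(n_tasks=50, shared_fraction=0.8)
print(cost)
```

Real savings depend on how much of the pipeline is actually shareable, so treat the shared fraction as something to measure rather than assume.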

The Data Efficiency Problem

Some tasks have abundant training data. "Did the user click?" generates billions of examples per day. "Did the user report this post as misinformation?" generates thousands. A model trained only on the rare report signal has limited generalization - it overfits to the noisy, sparse label distribution.

In MTL, the shared trunk receives dense gradient signal from the click prediction task and sparse signal from the report task. The trunk learns rich general representations from click data, and the report head benefits from those representations even though it cannot drive their learning alone. Rare-task performance often improves in MTL over single-task training, provided the rare task is related to the dense tasks that shape the trunk.

The Transfer Learning Connection

MTL is transfer learning without the sequential step. In sequential transfer learning, you pretrain on Task A, then fine-tune on Task B. In MTL, you train on Tasks A and B simultaneously. Both approaches exploit the same insight: representations learned for one task are useful for related tasks. MTL has the additional advantage that the shared representation adapts to both tasks simultaneously, rather than being fixed by the pretraining task and potentially suboptimal for the fine-tuning task.

Historical Context

Multi-task learning was formalized by Caruana (1997) in the machine learning community, though the idea appears in neural network research as early as the 1980s. Caruana showed that networks trained on multiple related tasks simultaneously learned better representations than single-task networks, attributing this to "hints" - information about related tasks that constrained the learned representation in beneficial ways.

The modern deep learning application to production recommendation systems was popularized by the 2018 Google paper "Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts" (MMoE), which introduced the key insight that different tasks may need different portions of the shared representation. Together with Facebook's 2016 News Feed work and Alibaba's 2018 ESMM (Entire Space Multi-Task Model) for joint click-through-rate and conversion-rate training, it established MTL as the standard approach for large-scale recommendation.

Core Concepts

Hard vs Soft Parameter Sharing

Hard parameter sharing: All tasks share the same neural network trunk with no task-specific adaptation of the shared layers. Each task has its own output head connected to the top of the shared trunk. The shared trunk learns a single representation that must serve all tasks simultaneously.

Advantages: Simple, memory-efficient, strong regularization. Disadvantages: Representation must be a jack of all trades, which hurts tasks with conflicting objectives.
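A minimal PyTorch sketch of hard parameter sharing, assuming a small fully connected trunk and one linear head per task. The class name `HardSharedModel`, the layer sizes, and the task names are illustrative choices, not taken from any production system.

```python
import torch
import torch.nn as nn


class HardSharedModel(nn.Module):
    """One shared trunk feeding one small head per task (hard parameter sharing)."""

    def __init__(self, input_dim: int, hidden_dim: int, task_names: list):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.heads = nn.ModuleDict(
            {task: nn.Linear(hidden_dim, 1) for task in task_names}
        )

    def forward(self, x: torch.Tensor) -> dict:
        shared = self.trunk(x)  # computed once, reused by every head
        return {task: head(shared).squeeze(-1) for task, head in self.heads.items()}


model = HardSharedModel(input_dim=16, hidden_dim=32, task_names=["click", "like"])
outputs = model(torch.randn(4, 16))
print({task: tuple(out.shape) for task, out in outputs.items()})
```

Every head reads the same `shared` tensor, which is exactly why conflicting objectives can pull the trunk in incompatible directions.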

Soft parameter sharing (MMoE): Multiple "expert" networks learn different aspects of the representation. Each task has a gating network that learns a weighted combination of the experts. Tasks that are closely related naturally learn similar gate weights; tasks with conflicting objectives learn different combinations of experts.

The MMoE formulation:

$$h_k = \sum_{i=1}^{n} g_k^i \cdot f_i(x)$$

where $h_k$ is the input to task $k$'s head, $f_i$ are the expert networks, and $g_k^i$ are the softmax gate weights for task $k$ over experts $i = 1, \ldots, n$.

MTL Loss Balancing

The critical challenge in MTL: different tasks have different loss scales, different convergence speeds, and different gradient magnitudes. Without careful loss balancing, the training dynamics are dominated by whichever task has the largest loss values - typically the task with the least well-calibrated output or the noisiest labels.

Naive weighted sum:

$$\mathcal{L}_{\text{total}} = \sum_{k=1}^{K} w_k \mathcal{L}_k$$

The weights $w_k$ must be hand-tuned. This is error-prone and does not adapt as training progresses.

Uncertainty Weighting (Kendall et al., 2018): Learn the task weights as model parameters. The intuition: tasks with higher uncertainty (homoscedastic task uncertainty) should contribute less to the total loss.

For regression tasks:

$$\mathcal{L}_k^{\text{unc}} = \frac{1}{2\sigma_k^2} \mathcal{L}_k + \log \sigma_k$$

For classification tasks:

$$\mathcal{L}_k^{\text{unc}} = \frac{1}{\sigma_k^2} \mathcal{L}_k + \log \sigma_k$$

where $\sigma_k$ is a learnable parameter per task. During training, $\sigma_k$ increases for tasks that are harder to fit, automatically down-weighting them.

GradNorm (Chen et al., 2018): Balance tasks by directly controlling gradient norms. Tasks whose gradients are too large (relative to a target gradient norm) are down-weighted; tasks whose gradients are too small are up-weighted. GradNorm adaptively adjusts task weights during training to maintain a target balance.

In simplified form, the GradNorm rule rescales task weights according to each task's relative learning progress:

$$\hat{w}_k \leftarrow K \cdot \frac{r_k}{\sum_j r_j}$$

where $r_k = \tilde{\mathcal{L}}_k / \mathcal{L}_k^{(0)}$ is the inverse training rate (the ratio of task $k$'s current loss to its initial loss, so a larger value means the task has learned less so far), and $K$ is the number of tasks.
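The weight rule above can be sketched in plain Python. This is not the full GradNorm algorithm, which learns the weights by matching per-task gradient norms to targets; it only illustrates how loss ratios turn into weights that sum to $K$. The function name `loss_ratio_weights` and the loss values are made up for the example.

```python
from typing import Dict


def loss_ratio_weights(
    current_losses: Dict[str, float],
    initial_losses: Dict[str, float],
) -> Dict[str, float]:
    """Weights proportional to each task's loss ratio r_k, normalized to sum to K."""
    ratios = {
        # r_k = current loss / initial loss; a high ratio marks a slow learner
        task: current_losses[task] / initial_losses[task]
        for task in current_losses
    }
    total = sum(ratios.values())
    n_tasks = len(ratios)
    return {task: n_tasks * r / total for task, r in ratios.items()}


weights = loss_ratio_weights(
    current_losses={"click": 0.30, "report": 0.90},
    initial_losses={"click": 0.60, "report": 1.00},
)
print(weights)
```

Here the report task has reduced its loss less than the click task, so it receives the larger weight.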

Gradient Conflict Detection

Two tasks have conflicting gradients when their gradient vectors point in opposite directions in the shared parameter space. If task A wants to increase weight $w_{ij}$ and task B wants to decrease it, training oscillates and neither task improves optimally.

Formally, task A and task B have conflicting gradients for the shared parameters $\theta$ when:

$$\cos(\nabla_\theta \mathcal{L}_A, \nabla_\theta \mathcal{L}_B) < 0$$
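One way to check this condition during training is to compute each task's gradient with respect to the shared parameters and compare the flattened vectors. The sketch below assumes a toy linear trunk, two linear heads, and synthetic regression targets; `shared_grad_vector` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def shared_grad_vector(loss: torch.Tensor, shared_params: list) -> torch.Tensor:
    """Flatten d(loss)/d(shared_params) into a single 1-D vector."""
    grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])


torch.manual_seed(0)
trunk = nn.Linear(8, 4)  # stand-in for the shared trunk
head_a, head_b = nn.Linear(4, 1), nn.Linear(4, 1)

x = torch.randn(32, 8)
y_a, y_b = torch.randn(32, 1), torch.randn(32, 1)  # synthetic targets

shared = trunk(x)
loss_a = F.mse_loss(head_a(shared), y_a)
loss_b = F.mse_loss(head_b(shared), y_b)

params = list(trunk.parameters())
g_a = shared_grad_vector(loss_a, params)
g_b = shared_grad_vector(loss_b, params)

cosine = F.cosine_similarity(g_a, g_b, dim=0).item()
print(f"cosine similarity: {cosine:.3f}, conflict: {cosine < 0}")
```

A single batch is noisy, so in practice it is more informative to log a running average of this value per task pair.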

PCGrad (Project Conflicting Gradients, Yu et al., 2020): When conflict is detected, project each task's gradient onto the normal plane of the conflicting gradient:

$$g_A' = g_A - \frac{g_A \cdot g_B}{\|g_B\|^2} g_B \quad \text{when } g_A \cdot g_B < 0$$

This removes the component of gradient $g_A$ that conflicts with gradient $g_B$, allowing both tasks to make progress without cancellation.

import torch
from typing import List


def pcgrad_update(
    gradients: List[torch.Tensor],  # list of per-task gradients for shared params
) -> torch.Tensor:
    """
    PCGrad: project conflicting gradients.
    Returns the averaged projected gradient for the shared parameters.
    """
    n_tasks = len(gradients)
    projected = [g.clone() for g in gradients]

    for i in range(n_tasks):
        for j in range(n_tasks):
            if i == j:
                continue
            gi = projected[i]
            gj = gradients[j]
            dot = (gi * gj).sum()
            if dot < 0:
                # Project out the conflicting component
                projected[i] = gi - (dot / (gj * gj).sum()) * gj

    return torch.stack(projected).mean(dim=0)

Negative Transfer Detection

Negative transfer occurs when adding a task to an MTL model makes another task worse compared to its single-task baseline. This happens when tasks have incompatible objectives or when the shared representation is stretched in incompatible directions.

Detecting negative transfer:

def compute_transfer_matrix(
    single_task_losses: dict,  # {task: val_loss when trained alone}
    mtl_losses: dict,  # {task: val_loss when trained in MTL}
) -> dict:
    """
    Compute the Transfer Improvement Ratio (TIR) for each task.
    TIR > 0: positive transfer (MTL helps this task)
    TIR < 0: negative transfer (MTL hurts this task)
    """
    tir = {}
    for task in single_task_losses:
        solo_loss = single_task_losses[task]
        mtl_loss = mtl_losses[task]
        # Positive means improvement (lower loss = better)
        tir[task] = (solo_loss - mtl_loss) / solo_loss
    return tir

If a task shows consistent negative transfer (TIR consistently below -0.05), remove it from the MTL group or give it its own dedicated expert in an MMoE architecture.

Routing Networks and Mixture of Experts

The MMoE architecture generalizes naturally to full mixture-of-experts routing:

from typing import List

import torch
import torch.nn as nn


class MMoELayer(nn.Module):
    """
    Multi-gate Mixture-of-Experts layer.
    Each task has its own gating network that selects how to combine expert outputs.
    """

    def __init__(
        self,
        input_dim: int,
        expert_dim: int,
        n_experts: int,
        n_tasks: int,
    ):
        super().__init__()
        # Expert networks (shared across tasks)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_dim),
                nn.ReLU(),
                nn.Linear(expert_dim, expert_dim),
            )
            for _ in range(n_experts)
        ])

        # Task-specific gating networks
        self.gates = nn.ModuleList([
            nn.Linear(input_dim, n_experts)
            for _ in range(n_tasks)
        ])

        self.n_tasks = n_tasks

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Returns a list of n_tasks tensors, one per task.
        Each tensor is a gated combination of expert outputs.
        """
        # Compute all expert outputs
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=1
        )  # (batch, n_experts, expert_dim)

        task_outputs = []
        for gate in self.gates:
            gate_weights = torch.softmax(gate(x), dim=-1)  # (batch, n_experts)
            task_out = (gate_weights.unsqueeze(-1) * expert_outputs).sum(dim=1)
            task_outputs.append(task_out)

        return task_outputs  # list of n_tasks tensors, each (batch, expert_dim)


class MTLRankingModel(nn.Module):
    """
    Full MTL ranking model with MMoE layers and task-specific output heads.
    Illustrative of the architecture style described for News Feed and similar
    systems; not an exact production design.
    """

    def __init__(
        self,
        input_dim: int,
        expert_dim: int = 128,
        n_experts: int = 8,
        task_names: List[str] = ["ctr", "like", "share", "watch_time"],
    ):
        super().__init__()
        # Copy so instances never alias the shared mutable default list
        self.task_names = list(task_names)
        n_tasks = len(self.task_names)

        # Input encoding
        self.input_encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
        )

        # MMoE layer
        self.mmoe = MMoELayer(256, expert_dim, n_experts, n_tasks)

        # Task-specific heads
        self.task_heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Linear(expert_dim, 64),
                nn.ReLU(),
                nn.Linear(64, 1),
            )
            for task in self.task_names
        })

    def forward(self, features: torch.Tensor) -> dict:
        """Returns dict mapping task names to predicted scores."""
        encoded = self.input_encoder(features)
        task_representations = self.mmoe(encoded)

        predictions = {}
        for task, task_repr in zip(self.task_names, task_representations):
            predictions[task] = self.task_heads[task](task_repr).squeeze(-1)

        return predictions


class UncertaintyWeightedLoss(nn.Module):
    """Kendall et al. (2018) uncertainty weighting for MTL losses."""

    def __init__(self, task_names: List[str]):
        super().__init__()
        # Log sigma^2 for each task (log scale for numerical stability)
        self.log_vars = nn.ParameterDict({
            task: nn.Parameter(torch.zeros(1))
            for task in task_names
        })

    def forward(self, task_losses: dict) -> torch.Tensor:
        total_loss = 0
        for task, loss in task_losses.items():
            log_var = self.log_vars[task]
            # Classification form: L / sigma^2 + log(sigma).
            # With log_var = log(sigma^2), log(sigma) = 0.5 * log_var.
            total_loss += loss * torch.exp(-log_var) + 0.5 * log_var
        return total_loss

Production Engineering Notes

Task Selection for MTL

Not every task benefits from MTL. The rule of thumb: tasks that share input features and have related objectives benefit from MTL. Tasks with orthogonal inputs or conflicting objectives do not.

Good MTL groupings:

  • User engagement signals (click, like, share, comment) on the same content type
  • Multiple quality metrics for the same prediction problem (CTR + CVR)
  • Multilingual versions of the same task

Bad MTL groupings:

  • Content recommendation + system load prediction (inputs and objectives are unrelated)
  • Click prediction + fraud detection (adversarial objectives - fraud detection must not transfer to click prediction)
  • Tasks with radically different label frequencies without careful sampling strategy

Debugging MTL Models

The standard debugging toolkit for MTL:

  1. Gradient conflict monitoring: Log the average cosine similarity between task gradients during training. If consistently negative for two tasks, they are in conflict.
  2. Task learning curves: Plot validation loss per task over training. If one task's loss stops decreasing or increases after an initial period, it may be experiencing negative transfer.
  3. Gate weight visualization (MMoE): Plot the distribution of gate weights for each task. If all tasks have identical gate distributions, the experts are not specializing. If one expert dominates all gates, the routing has collapsed.
  4. Single-task comparison: Train each task alone and compare to MTL performance. This is the ground truth for detecting negative transfer.
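For the gate weight check in item 3, a numeric summary is often easier to alert on than a plot. The sketch below assumes you can capture one task's softmax gate weights as a `(batch, n_experts)` tensor; `gate_diagnostics` is a hypothetical helper, and the two example gates are synthetic.

```python
import torch


def gate_diagnostics(gate_weights: torch.Tensor) -> dict:
    """Summarize one task's gate weights of shape (batch, n_experts).

    Returns the mean weight per expert and the mean per-example entropy.
    Entropy near 0 means one expert dominates; entropy near log(n_experts)
    means the gate spreads weight almost uniformly.
    """
    mean_per_expert = gate_weights.mean(dim=0)
    # clamp_min guards against log(0) for experts with zero weight
    entropy = -(gate_weights * torch.log(gate_weights.clamp_min(1e-12))).sum(dim=-1)
    n_experts = float(gate_weights.shape[-1])
    return {
        "mean_per_expert": mean_per_expert,
        "mean_entropy": entropy.mean().item(),
        "max_entropy": torch.log(torch.tensor(n_experts)).item(),
    }


# Synthetic example: a collapsed gate vs. a uniform gate over 4 experts.
collapsed = torch.tensor([[0.97, 0.01, 0.01, 0.01]]).repeat(8, 1)
uniform = torch.full((8, 4), 0.25)

collapsed_stats = gate_diagnostics(collapsed)
uniform_stats = gate_diagnostics(uniform)
print(collapsed_stats["mean_entropy"], uniform_stats["mean_entropy"])
```

Comparing `mean_per_expert` across tasks shows whether the gates differ at all; comparing `mean_entropy` to `max_entropy` shows whether a single expert has taken over.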

Serving MTL Models

MTL models have one network, but tasks may be scored at different frequencies. At inference time, you need the output for all tasks to compute the final ranking score. The final ranking score is typically a weighted combination of task scores, here in multiplicative form:

$$\text{ranking score} = \prod_k p_k^{w_k}$$

where $p_k$ is the predicted probability for task $k$ and $w_k$ is the business weight for that task, applied as an exponent (e.g., purchase = 5, like = 1, hide = -2, so a negative weight lowers the score as the predicted probability rises). These business weights are tuned separately from model training and can be updated without retraining.
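The weighted-product score can be computed in log space, which avoids underflow when many probabilities are multiplied. This is a sketch under assumptions: the function name `ranking_score`, the epsilon guard, and the example probabilities are all illustrative.

```python
import math
from typing import Dict


def ranking_score(probs: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted product prod_k p_k ** w_k, computed in log space for stability."""
    eps = 1e-12  # guard against log(0); the value is an arbitrary choice
    log_score = sum(
        weights[task] * math.log(max(probs[task], eps)) for task in weights
    )
    return math.exp(log_score)


# Hypothetical predictions and business weights for two candidate posts.
weights = {"purchase": 5.0, "like": 1.0, "hide": -2.0}
post_a = {"purchase": 0.5, "like": 0.5, "hide": 0.5}
post_b = {"purchase": 0.5, "like": 0.5, "hide": 0.25}  # less likely to be hidden

score_a = ranking_score(post_a, weights)
score_b = ranking_score(post_b, weights)
print(score_a, score_b)
```

Because the weights live outside the model, changing them only changes this function's inputs, which is what makes retuning possible without retraining.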

Common Mistakes

danger

Mistake: Mixing tasks with very different label sparsities without resampling.

If task A has labels for 100% of training examples (every item has a click/no-click label) and task B has labels for 0.1% (only items that were flagged get a quality label), training with equal-weight data will produce terrible task B performance. The model will overfit to task A's dense signal and never learn task B. Use task-specific sampling rates: oversample task B examples or use task-specific loss masking so task B's loss is computed only on examples with task B labels.
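Task-specific loss masking can be sketched as follows, assuming a binary task with a boolean `has_label` mask per example. The function name `masked_bce_loss` and the tensor values are illustrative.

```python
import torch
import torch.nn.functional as F


def masked_bce_loss(
    logits: torch.Tensor,     # (batch,) raw scores for one task
    labels: torch.Tensor,     # (batch,) 0/1 labels; ignored where unlabeled
    has_label: torch.Tensor,  # (batch,) bool, True where this task has a label
) -> torch.Tensor:
    """Binary cross-entropy averaged only over examples that carry a label."""
    per_example = F.binary_cross_entropy_with_logits(
        logits, labels.float(), reduction="none"
    )
    mask = has_label.float()
    # clamp avoids dividing by zero when a batch has no labels for this task
    return (per_example * mask).sum() / mask.sum().clamp(min=1.0)


logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
labels = torch.tensor([1, 0, 0, 1])
has_label = torch.tensor([True, False, False, True])

loss = masked_bce_loss(logits, labels, has_label)
print(loss.item())
```

Unlabeled examples contribute zero gradient for this task, so placeholder label values never leak into training.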

warning

Mistake: Using a single global learning rate for all task heads.

Task heads that receive sparse gradient signal (rare task) should use a higher learning rate to compensate. Task heads that receive dense gradient signal (common task) should use a lower learning rate to prevent overfitting. Use parameter groups in PyTorch to set per-task learning rates independently. The shared trunk should use the default learning rate.
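Per-task learning rates can be set with optimizer parameter groups. The sketch below assumes a toy trunk and two heads; the multipliers in `head_lr_scale` are hypothetical starting points, not recommended values.

```python
import torch
import torch.nn as nn

trunk = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
heads = nn.ModuleDict({
    "click": nn.Linear(32, 1),   # dense labels
    "report": nn.Linear(32, 1),  # sparse labels
})

base_lr = 1e-3
# Hypothetical multipliers; tune these on validation data.
head_lr_scale = {"click": 0.5, "report": 3.0}

# A group without an "lr" key falls back to the optimizer's default lr.
param_groups = [{"params": trunk.parameters()}]
for task, head in heads.items():
    param_groups.append(
        {"params": head.parameters(), "lr": base_lr * head_lr_scale[task]}
    )

optimizer = torch.optim.Adam(param_groups, lr=base_lr)
print([group["lr"] for group in optimizer.param_groups])
```

Keeping the multipliers in one dictionary makes it easy to sweep them alongside the other hyperparameters.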

warning

Mistake: Treating MTL as a free lunch for rare tasks.

MTL improves rare task performance primarily through better shared representations, not through magic. If the rare task is truly unrelated to other tasks in the group, MTL will not help and may hurt. Always validate rare task improvement by comparing MTL to a well-trained single-task model with careful hyperparameter tuning. The single-task baseline is often competitive if properly tuned.

tip

Tip: Start with hard parameter sharing, then upgrade to MMoE if you detect negative transfer.

Hard parameter sharing is simpler to implement, debug, and serve. Start there. If your monitoring shows consistent negative transfer between specific task pairs, those tasks are candidates for separate expert networks (soft parameter sharing / MMoE). Do not add MMoE complexity preemptively - it adds training time, memory, and debugging complexity.

Interview Q&A

Q: Why does multi-task learning often outperform single-task learning even for the main task?

A: Two mechanisms. First, auxiliary tasks act as regularization. When a shared representation must satisfy constraints imposed by multiple task objectives, it is prevented from overfitting to the idiosyncrasies of any single task's training data. This regularization effect is especially strong when tasks have different noise patterns - overfitting to one task's noise is penalized by another task's gradient. Second, auxiliary tasks provide gradient signal about relevant semantic relationships that the main task alone might not surface. A click prediction model may not learn from a single example that a user tends to engage with cooking content - but if like prediction and share prediction also see the same user's behavior, the shared representation converges faster to a feature space where "cooking affinity" is explicitly represented. This representation is then available to click prediction, improving its predictions even though click prediction didn't directly learn the cooking affinity feature.

Q: What is gradient conflict in MTL and how does PCGrad address it?

A: Gradient conflict occurs when two tasks' gradient vectors for the shared parameters point in opposite directions. Concretely, if the gradient for Task A wants to increase a shared weight $w$ by 0.1 and Task B's gradient wants to decrease $w$ by 0.1, the net gradient is zero - neither task makes progress, and training stalls. PCGrad detects conflict (cosine similarity between task gradients is negative) and projects each task's gradient onto the plane perpendicular to the conflicting gradient. This removes the component of Task A's gradient that directly opposes Task B, allowing both tasks to make progress without cancellation. PCGrad operates on the full shared-parameter gradient vectors, for each pair of conflicting tasks, at each training step. It adds computational overhead proportional to the number of task pairs but can improve convergence noticeably for task groups with many conflicting objectives.

Q: You're designing an MTL system for a social media feed. What tasks would you group together, and which would you separate?

A: Group together: all engagement prediction tasks (CTR, like rate, share rate, comment rate, video watch time, story completion rate). These all predict user responses to content and share the same input features (user embedding, content embedding, context). Group together but with separate experts (MMoE): engagement tasks and safety/quality tasks (misinformation flag rate, report rate, hide-post rate). These share some features (content type, creator history) but have partially conflicting objectives - high engagement and safety are sometimes in tension. Separate completely: feed ranking and ad relevance prediction. Feed ranking optimizes for organic engagement; ad relevance optimizes for ad response. Ads have fundamentally different features (advertiser attributes, bid prices, ad creative features) that have no relevance to organic feed ranking. Mixing them would pollute both representations.

Q: How would you detect and respond to catastrophic negative transfer in production?

A: Negative transfer is detected by continuously monitoring per-task validation metrics alongside the overall MTL metric. If a task's validation metric degrades more than 3-5% relative to its single-task baseline, and this degradation persists across multiple training runs, that task is experiencing negative transfer. Response options in order of increasing complexity: (1) Increase the task's loss weight to give it more influence over shared representation learning. (2) Move the task to a separate expert in an MMoE architecture so it has its own dedicated computation pathway. (3) Freeze the shared trunk for a few steps and train only the task-specific head (fine-tuning within MTL). (4) If none of the above work, remove the task from the MTL group entirely and train it as a separate single-task model. The last resort is the right call when the task truly has incompatible objectives - forcing incompatible tasks together tends to hurt both.

Q: What is the difference between hard parameter sharing and soft parameter sharing, and when would you choose each?

A: Hard parameter sharing uses one shared trunk with no task-specific variation - all tasks see identical intermediate representations. Soft parameter sharing (MMoE, routing networks) allows tasks to adaptively weight different portions of the shared representation through gating mechanisms. Choose hard parameter sharing when tasks are closely related (e.g., CTR and CVR for the same product category), when model size is constrained, or when you're starting out and want simplicity. Choose soft parameter sharing when tasks have different data distributions, when you've detected negative transfer in hard sharing, when tasks require different types of information at different layers, or when you have more than ~10 tasks with varied objectives. The practical rule: start hard, add softness where the data tells you it's needed.

© 2026 EngineersOfAI. All rights reserved.