Two-Tower Models - The Architecture Powering Google, TikTok, and YouTube
Reading time: ~40 minutes | Level: Recommender Systems | Role: MLE, AI Engineer, MLOps
The Constraint That Changed Everything
Picture a TikTok engineer at their desk in 2021. The app has just crossed 1 billion monthly active users. Every time someone opens TikTok (which happens billions of times a day), the system has roughly 100 milliseconds to find the best 50 videos from a catalog of over 100 million items. That is the product requirement: 100ms, 100 million items, 1 billion users, billions of requests per day.
Run the numbers on a NeuMF model. A single forward pass for one user-item pair takes about 0.1 milliseconds on a modern GPU. For one user, scoring all 100 million items one pair at a time would take 100,000,000 × 0.1 ms = 10,000 seconds, nearly three hours. To hit the 100ms target at that per-pair cost, you would need about 100,000 GPUs working in parallel on a single user's request. Batching shrinks that number, but not by the orders of magnitude required, and there are billions of such requests every day. That is not a resource problem you can throw hardware at. It is a fundamental architectural incompatibility.
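The arithmetic above is easy to re-run under different assumptions. Here is a minimal sketch that takes the 0.1 ms per-pair cost and the 100 ms budget from the text and deliberately ignores any batching speedup:

```python
def joint_model_cost(num_items: int, per_pair_ms: float, budget_ms: float) -> tuple[float, float]:
    """Cost of scoring every item with a joint (NeuMF-style) model.

    Returns (sequential latency in seconds, parallel devices needed to meet the budget).
    Assumes a fixed per-pair latency and no batching speedup.
    """
    total_ms = num_items * per_pair_ms
    return total_ms / 1000.0, total_ms / budget_ms


seconds, devices = joint_model_cost(num_items=100_000_000, per_pair_ms=0.1, budget_ms=100)
print(f"sequential: {seconds:,.0f} s (~{seconds / 3600:.1f} h), devices for 100 ms: {devices:,.0f}")
```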
The breakthrough that solved this - that made billion-scale real-time recommendation physically possible - came from a deceptively simple architectural constraint. What if you could design the model so that the user computation and the item computation were completely independent? If you could precompute item representations offline and store them in a fast lookup structure, then at serving time you would only need to: (1) compute the user representation in real time (one forward pass, ~1ms), and (2) find the nearest precomputed item vectors (approximate nearest neighbor search, ~5–10ms). Total: under 15ms for a billion-item catalog.
This is the two-tower model. The name describes the architecture: a user tower that processes user features into a single user embedding, and a separate item tower that processes item features into a single item embedding. The two towers are trained jointly, but they are architecturally independent - the user tower's computation never touches item features, and the item tower's computation never touches user features. This constraint is the entire point. It is what makes precomputation possible.
The two-tower model is now the dominant architecture for the retrieval stage of large-scale recommendation systems. Google uses it in YouTube. TikTok uses it. Pinterest uses it. Spotify uses it. Understanding two-tower is not optional for anyone working on recommendation at scale - it is the foundation.
Why This Exists
The previous lesson covered NeuMF: a powerful model that concatenates user and item embeddings and passes them through a deep network. NeuMF is expressive and accurate. It is also architecturally unsuitable for real-time retrieval at scale.
The problem is user-item coupling at inference time. In any model where the user and item representations are processed jointly - concatenated, attended over, or otherwise mixed before the final prediction - you cannot precompute anything. You must run the entire forward pass for every user-item pair you want to score.
| Model type | Inference cost per user | Scalable to 100M items? |
|---|---|---|
| NeuMF (joint interaction) | 100M forward passes | No |
| Two-tower (independent towers) | 1 forward pass + ANN search | Yes |
Two-tower imposes the strongest possible constraint: the final similarity score must be decomposable as a function of independent user and item representations:

$$s(u, i) = \langle f_\theta(\mathbf{x}_u),\, g_\phi(\mathbf{x}_i) \rangle$$
The user tower maps user features to a user embedding. The item tower maps item features to an item embedding. The score is their dot product (or cosine similarity). Because the two towers share no parameters and have no information flow between them, item embeddings can be computed once and stored. User embeddings are computed at query time. The nearest stored item embeddings are retrieved with FAISS.
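To see what decomposability buys at serving time, here is a minimal NumPy sketch of the pattern. Random unit vectors stand in for precomputed item-tower outputs, and exact brute-force search stands in for FAISS. The function name `retrieve_top_k` is illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)


def l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


# Offline, once per model version: item-tower outputs for the whole catalog.
# Random unit vectors stand in for real item embeddings here.
item_matrix = l2_normalize(rng.standard_normal((50_000, 64)))


def retrieve_top_k(user_vec: np.ndarray, items: np.ndarray, k: int) -> np.ndarray:
    """Exact top-k by inner product: one matrix-vector product, no per-item model call."""
    scores = items @ user_vec                      # (N,) one dot product per item
    top = np.argpartition(-scores, k - 1)[:k]      # unordered top-k in O(N)
    return top[np.argsort(-scores[top])]           # sort only the k winners


# Online, per request: one user-tower pass (a random vector here) plus the lookup.
user_vec = l2_normalize(rng.standard_normal(64))
print(retrieve_top_k(user_vec, item_matrix, k=10))
```

The model is never called per item at request time. Everything item-specific was computed before the request arrived.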
This architectural constraint is the reason two-tower models are slightly less expressive than joint models - the model cannot leverage fine-grained user-item feature interactions during the forward pass. But at retrieval scale, this is an acceptable trade-off. The retrieval stage does not need to be perfect; it needs to produce a set of 1,000 plausible candidates that a more expressive ranking model can then re-score.
The two-stage pipeline - two-tower retrieval followed by NeuMF-style ranking - gives you both scalability and accuracy. Each stage does what it is best at.
Historical Context
The Foundational Papers
DSSM (2013) - Huang et al., "Learning Deep Structured Semantic Models for Web Search using Clickthrough Data," Microsoft Research. The first large-scale dual encoder trained with click signals, used for document retrieval in Bing. It established the recipe of scoring a query against its clicked document plus a few randomly sampled unclicked documents with a softmax, a precursor of the in-batch negative approach.
YouTube DNN (2016) - Covington et al., "Deep Neural Networks for YouTube Recommendations." Google's paper describing their two-stage system: a deep retrieval model followed by a deep ranking model. The retrieval model is a precursor to the modern two-tower: it outputs a single user embedding, finds nearest neighbors in the item embedding space, and passes candidates to the ranking model.
Sampling-Bias-Corrected Neural Modeling (2019) - Yi et al., Google Brain. The paper that formalized modern two-tower training with in-batch negatives. It paired that training with a sampling bias correction (subtracting the log of each item's sampling probability from its logit) and a streaming algorithm for estimating those probabilities. This is the paper that most practitioners reference when they say "two-tower." The correction is now standard in any serious two-tower implementation.
FAISS (2021) - Johnson et al., Meta AI. The library that made billion-scale approximate nearest neighbor search practical. Without efficient ANN, the two-tower architecture would not work - you would be back to brute-force linear scan.
The "Aha Moment"
The key insight that Yi et al. formalized was that the two-tower architecture creates a natural tension: the model is trained to distinguish the correct item for a user from negative items, but at training time the negatives are sampled from the same batch. Popular items appear as negatives more often (because they are more likely to be someone else's positive in the batch). This biases the model against popular items - exactly the opposite of the popularity bias you usually worry about. The solution (sampling bias correction) is both mathematically elegant and practically critical.
Core Concepts
Concept 1: The Two-Tower Architecture
The defining equation:

$$s(u, i) = \langle f_\theta(\mathbf{x}_u),\, g_\phi(\mathbf{x}_i) \rangle$$

where:
- $\mathbf{x}_u$ is the user feature vector (user ID, demographics, watch history, search history, device, time-of-day, etc.)
- $\mathbf{x}_i$ is the item feature vector (item ID, category, tags, duration, creator features, content embeddings, etc.)
- $f_\theta$ is the user tower, a deep neural network mapping user features to a $k$-dimensional embedding
- $g_\phi$ is the item tower, a separate deep neural network mapping item features to the same $k$-dimensional space
- $\langle \cdot, \cdot \rangle$ is the dot product (or cosine similarity after normalization)
Both towers can be arbitrarily complex - they can include attention layers, recurrent layers, cross-feature interactions within each tower. The only architectural constraint is that no information crosses from one tower to the other. The user tower computes its output solely from user features; the item tower computes its output solely from item features.
The output dimension $k$ is typically 64–256. A smaller $k$ means faster ANN search but less expressive representations. A larger $k$ stores more information but increases memory and search latency.
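Here is a minimal NumPy sketch of that structure. Small random ReLU MLPs stand in for trained towers, and the feature widths are arbitrary (20 user features, 48 item features, k = 32). The point is that each tower only ever sees its own side's features. An item's embedding is therefore the same no matter which other items it is scored alongside:

```python
import numpy as np

rng = np.random.default_rng(0)


def init_mlp(sizes: list[int]) -> list[tuple[np.ndarray, np.ndarray]]:
    """Random weights for a ReLU MLP (a stand-in for a trained tower)."""
    return [
        (rng.standard_normal((n_in, n_out)) / np.sqrt(n_in), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def tower(x: np.ndarray, params: list[tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Forward pass: ReLU on hidden layers, then L2-normalize the output."""
    for layer, (w, b) in enumerate(params):
        x = x @ w + b
        if layer < len(params) - 1:
            x = np.maximum(x, 0.0)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


K = 32
user_params = init_mlp([20, 64, K])  # f_theta: user features -> R^k
item_params = init_mlp([48, 64, K])  # g_phi: item features -> R^k, separate weights


def score(x_users: np.ndarray, x_items: np.ndarray) -> np.ndarray:
    """(num_users, num_items) matrix of cosine similarities."""
    return tower(x_users, user_params) @ tower(x_items, item_params).T


x_users = rng.standard_normal((4, 20))
x_items = rng.standard_normal((6, 48))
print(score(x_users, x_items).shape)
```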
Concept 2: Training with In-Batch Negatives
Two-tower models are almost always trained with in-batch negatives (also called batch softmax or sampled softmax).
Given a batch of $B$ positive (user, item) pairs $\{(u_j, i_j)\}_{j=1}^{B}$, treat every other item in the batch as a negative for each user:

$$\mathcal{L} = -\frac{1}{B} \sum_{j=1}^{B} \log \frac{\exp\big(s(u_j, i_j)/\tau\big)}{\sum_{l=1}^{B} \exp\big(s(u_j, i_l)/\tau\big)}$$

where $\tau$ is a temperature parameter (typically 0.05–0.2).
This has the same form as the InfoNCE / NT-Xent loss used in contrastive learning (SimCLR, CLIP). The intuition: treat recommendation as a classification problem over the batch. For user $u_j$, the correct answer is item $i_j$. The model must predict item $i_j$ as more relevant than all other items in the batch.
Why in-batch negatives? Because they are nearly free. Every forward pass already computes all $B$ item embeddings in the batch, so using them as negatives requires no extra tower forward passes. The only added cost is the $B \times B$ matrix of dot products. With batch size 2048, you get 2047 negatives per positive, vastly more signal than the 4 negatives per positive used in standard NCF training.
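The loss is short enough to write out directly. Below is a NumPy sketch of the uncorrected in-batch softmax. The PyTorch implementation later in this lesson computes the same quantity with `F.cross_entropy`. The sketch assumes L2-normalized embeddings, with row j of each matrix forming a positive pair. The `_np` suffix is only there to avoid a name clash with that later function:

```python
import numpy as np


def log_softmax_rows(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def in_batch_softmax_loss_np(user_emb: np.ndarray, item_emb: np.ndarray, tau: float) -> float:
    """Mean cross-entropy over the (B, B) score matrix; row j's positive is column j."""
    logits = user_emb @ item_emb.T / tau  # every other item in the batch is a negative
    return float(-np.mean(np.diag(log_softmax_rows(logits))))


B = 8
# Perfectly separated pairs: user j matches item j and is orthogonal to the rest.
print(in_batch_softmax_loss_np(np.eye(B), np.eye(B), tau=0.1))
# Uninformative embeddings: every item scores the same, so the loss equals log(B).
same = np.full((B, 4), 0.5)
print(in_batch_softmax_loss_np(same, same, tau=0.1), np.log(B))
```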
Temperature $\tau$ controls the sharpness of the distribution:
- Small $\tau$ (≈ 0.05): sharp distribution; the model is penalized heavily for any misranking. Pushes representations apart aggressively. Risk of training instability.
- Large $\tau$ (≈ 1.0): flat distribution; the model receives a weak signal. Embeddings do not separate well.

The optimal $\tau$ is typically found by grid search on validation recall. Values in $[0.05, 0.2]$ work well for most recommendation problems.
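To build intuition for the temperature τ, push one fixed row of similarities through the softmax at several temperatures. The similarity values below are invented for illustration:

```python
import numpy as np


def batch_softmax(sims: np.ndarray, tau: float) -> np.ndarray:
    z = np.exp((sims - sims.max()) / tau)
    return z / z.sum()


def entropy(p: np.ndarray) -> float:
    return float(-(p * np.log(p)).sum())


# One user's cosine similarities to five in-batch items (made-up values); item 0 is the positive.
sims = np.array([0.62, 0.55, 0.30, 0.10, -0.20])
for tau in (0.05, 0.1, 0.2, 1.0):
    p = batch_softmax(sims, tau)
    print(f"tau={tau:<5} p(positive)={p[0]:.3f} entropy={entropy(p):.3f}")
```

At τ = 1.0 the five cosine similarities, which all lie in [-1, 1], produce a nearly uniform distribution. That is why the gradient signal becomes weak at high temperature.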
Concept 3: Sampling Bias Correction (Yi et al.'s Key Contribution)
In-batch negatives introduce a subtle but serious bias. Items are sampled into batches proportionally to their popularity (more popular items have more interactions, so they appear in more training batches). This means popular items appear as negatives much more often than rare items.
The model interprets "item appears as negative often" as "item is generally bad." Over training, it systematically learns to down-rank popular items - even when they are genuinely good recommendations for a particular user. This is the opposite of the usual popularity bias: instead of over-recommending popular items, the model under-recommends them.
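The mechanism is easy to reproduce in simulation. This minimal sketch assumes a Zipf-like popularity distribution over 1,000 items, with batches drawn in proportion to popularity. It counts how often each item serves as someone else's negative:

```python
import numpy as np

rng = np.random.default_rng(0)
num_items, batch_size, num_batches = 1_000, 256, 500

# Assumed Zipf-like popularity: item 0 is the most popular, item 999 the least.
popularity = 1.0 / np.arange(1, num_items + 1)
popularity /= popularity.sum()

neg_counts = np.zeros(num_items)
for _ in range(num_batches):
    batch_items = rng.choice(num_items, size=batch_size, p=popularity)
    # Each column of the (B, B) score matrix is a negative for the other B - 1 rows.
    np.add.at(neg_counts, batch_items, batch_size - 1)

head = neg_counts[:10].mean()
tail = neg_counts[-500:].mean()
print(f"avg times used as a negative: top-10 items {head:,.0f}, bottom-500 items {tail:,.0f}")
```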
Yi et al.'s fix is mathematically clean: subtract the log of the item's sampling probability from its logit before computing the softmax.
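Let $q_i$ be the sampling probability of item $i$ (proportional to its frequency in the training data). The corrected logit is:

$$s^c(u_j, i_l) = \frac{s(u_j, i_l)}{\tau} - \log q_{i_l}$$

The corrected loss:

$$\mathcal{L}^c = -\frac{1}{B} \sum_{j=1}^{B} \log \frac{\exp\big(s^c(u_j, i_j)\big)}{\sum_{l=1}^{B} \exp\big(s^c(u_j, i_l)\big)}$$

Why does this work? In-batch softmax only sees the negatives that happen to land in the batch, and popular items land there far more often. Left uncorrected, the learned score drifts toward the true relevance plus an offset of $-\log q_i$. That offset is large for rare items and small for popular ones, so popular items come out relatively under-scored.

Subtracting $\log q_i$ from every training logit supplies the offset explicitly. A rare item (large $-\log q_i$) receives a large logit bonus during training and a popular item a small one, so the raw dot product no longer has to encode popularity. At serving time the correction is dropped and items are retrieved by the raw score $s(u, i)$. That score now reflects how well the item fits the user rather than how often it appeared as a negative.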
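Here is a NumPy sketch of the corrected loss, where q is each in-batch item's sampling probability (however it was estimated). One useful sanity check: if every item were equally likely, log q would be the same constant in every column. The correction would then change nothing, because a softmax ignores a constant shift:

```python
import numpy as np


def corrected_logits(user_emb: np.ndarray, item_emb: np.ndarray, tau: float, q: np.ndarray) -> np.ndarray:
    """s(u_j, i_l) / tau - log q_l for every (user j, item l) pair in the batch."""
    return user_emb @ item_emb.T / tau - np.log(q)[None, :]


def softmax_loss(logits: np.ndarray) -> float:
    """Mean cross-entropy with the diagonal as the positive for each row."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


rng = np.random.default_rng(0)
B, k, tau = 16, 8, 0.1
u = rng.standard_normal((B, k))
u /= np.linalg.norm(u, axis=1, keepdims=True)
v = rng.standard_normal((B, k))
v /= np.linalg.norm(v, axis=1, keepdims=True)

uniform = np.full(B, 1.0 / B)
skewed = np.linspace(1.0, 20.0, B)
skewed /= skewed.sum()  # made-up popularity: later items are sampled more often

print("no correction:", softmax_loss(u @ v.T / tau))
print("uniform q    :", softmax_loss(corrected_logits(u, v, tau, uniform)))
print("skewed q     :", softmax_loss(corrected_logits(u, v, tau, skewed)))
```

The `-log q` term exists only inside the training loss. Retrieval at serving time uses the raw `u @ v.T` scores.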
The sampling probability $q_i$ is estimated from the training data: count how many times item $i$ appears in the batch stream and divide by the total number of batch positions.
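Below is a minimal sketch of that counting estimate, assuming the batch stream is available as a list of item-ID arrays. The `pseudo_count` argument is an added assumption, not from the text. It keeps never-seen items from getting q = 0, whose log would be -∞, similar in spirit to starting all counts at one:

```python
import numpy as np


def estimate_q_from_stream(
    batch_stream: list[np.ndarray], num_items: int, pseudo_count: float = 0.0
) -> np.ndarray:
    """q(i) = (occurrences of item i + pseudo_count) / (batch positions + pseudo_count * num_items)."""
    counts = np.zeros(num_items, dtype=np.float64)
    positions = 0
    for batch_item_ids in batch_stream:
        counts += np.bincount(batch_item_ids, minlength=num_items)
        positions += len(batch_item_ids)
    return (counts + pseudo_count) / (positions + pseudo_count * num_items)


stream = [np.array([0, 0, 1]), np.array([0, 2, 2])]
print(estimate_q_from_stream(stream, num_items=4))
print(estimate_q_from_stream(stream, num_items=4, pseudo_count=1.0))
```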
Concept 4: Serving - FAISS and Approximate Nearest Neighbor Search
At serving time:
- Offline (precomputed nightly or hourly): run all items through the item tower. Store the resulting embedding matrix in a FAISS index.
- Online (per request): (a) compute user embedding via one forward pass through the user tower, (b) query FAISS for the top-1000 nearest item embeddings, (c) return candidates to the ranking model.
What is FAISS? A library from Meta AI for efficient similarity search. At its core, FAISS does approximate nearest neighbor (ANN) search: finding the vectors most similar to a query vector in a large collection, without computing the exact distance to every vector.
IVF (Inverted File Index) - the workhorse for billion-scale search:
- At index build time, cluster the embedding space into n_list Voronoi cells using k-means.
- For each query, search only the n_probe nearest cluster centers, then do exact search within those clusters.
- Trade-off: higher n_probe = higher recall, slower search.
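The IVF idea fits in a few dozen lines. This is a toy NumPy sketch, not FAISS. It uses random unit vectors, a plain Lloyd's k-means for the coarse clustering, squared-L2 assignment to cells, and inner-product scoring inside the probed cells. It measures recall against exact brute-force search as the number of probed cells (n_probe) grows:

```python
import numpy as np

rng = np.random.default_rng(0)


def nearest(x: np.ndarray, centroids: np.ndarray, n: int) -> np.ndarray:
    """Indices of the n closest centroids (squared L2) for each row of x."""
    d2 = (x**2).sum(1)[:, None] - 2 * x @ centroids.T + (centroids**2).sum(1)[None, :]
    return np.argsort(d2, axis=1)[:, :n]


def kmeans(x: np.ndarray, k: int, iters: int = 10) -> np.ndarray:
    """Plain Lloyd's k-means (squared L2); returns (k, d) centroids."""
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        assign = nearest(x, centroids, 1)[:, 0]
        sums = np.zeros_like(centroids)
        np.add.at(sums, assign, x)
        counts = np.bincount(assign, minlength=k)
        filled = counts > 0  # empty cells keep their previous centroid
        centroids[filled] = sums[filled] / counts[filled, None]
    return centroids


class ToyIVF:
    def __init__(self, x: np.ndarray, n_list: int):
        self.x = x
        self.centroids = kmeans(x, n_list)
        assign = nearest(x, self.centroids, 1)[:, 0]
        self.lists = [np.flatnonzero(assign == c) for c in range(n_list)]  # inverted lists

    def search(self, query: np.ndarray, k: int, n_probe: int) -> np.ndarray:
        probed = nearest(query[None, :], self.centroids, n_probe)[0]
        candidates = np.concatenate([self.lists[c] for c in probed])
        scores = self.x[candidates] @ query  # exact scoring inside the probed cells only
        return candidates[np.argsort(-scores)[:k]]


def recall_vs_exact(index: ToyIVF, queries: np.ndarray, k: int, n_probe: int) -> float:
    hits = 0
    for q in queries:
        exact = np.argsort(-(index.x @ q))[:k]
        hits += len(np.intersect1d(exact, index.search(q, k, n_probe)))
    return hits / (k * len(queries))


data = rng.standard_normal((5_000, 32))
data /= np.linalg.norm(data, axis=1, keepdims=True)
queries = data[:50] + 0.1 * rng.standard_normal((50, 32))
queries /= np.linalg.norm(queries, axis=1, keepdims=True)
index = ToyIVF(data, n_list=64)
for n_probe in (1, 4, 16, 64):
    print(f"n_probe={n_probe:>2}  recall@10 vs exact = {recall_vs_exact(index, queries, 10, n_probe):.2f}")
```

With n_probe equal to n_list, every cell is scanned and the search degenerates to exact brute force.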
PQ (Product Quantization) - for memory compression:
- Split each 128-dim embedding into 16 sub-vectors of 8 dimensions each.
- Quantize each sub-vector to its nearest centroid from a codebook of 256 entries.
- Store the embedding as 16 bytes (one byte per sub-vector) instead of 128 × 4 = 512 bytes.
- Reduces memory footprint by 32x at the cost of some recall degradation.
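Here is a toy NumPy sketch of product quantization with the numbers above: d = 128, 16 sub-vectors of 8 dimensions, and 256 centroids each. It uses random unit vectors and a plain k-means per sub-space. FAISS's own implementation is far more optimized, but the encode/decode idea is the same:

```python
import numpy as np

rng = np.random.default_rng(0)


def kmeans(x: np.ndarray, k: int, iters: int = 8) -> np.ndarray:
    """Plain Lloyd's k-means (squared L2); returns (k, d) centroids."""
    centroids = x[rng.choice(len(x), size=k, replace=False)].copy()
    for _ in range(iters):
        d2 = (x**2).sum(1)[:, None] - 2 * x @ centroids.T + (centroids**2).sum(1)[None, :]
        assign = d2.argmin(1)
        sums = np.zeros_like(centroids)
        np.add.at(sums, assign, x)
        counts = np.bincount(assign, minlength=k)
        filled = counts > 0
        centroids[filled] = sums[filled] / counts[filled, None]
    return centroids


def train_pq(x: np.ndarray, m: int, ksub: int = 256) -> np.ndarray:
    """One codebook per sub-space: returns (m, ksub, d // m)."""
    dsub = x.shape[1] // m
    return np.stack([kmeans(x[:, j * dsub:(j + 1) * dsub], ksub) for j in range(m)])


def pq_encode(x: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Replace each sub-vector by the index of its nearest centroid: one byte each."""
    m, _, dsub = codebooks.shape
    codes = np.empty((len(x), m), dtype=np.uint8)
    for j in range(m):
        sub, cb = x[:, j * dsub:(j + 1) * dsub], codebooks[j]
        d2 = (sub**2).sum(1)[:, None] - 2 * sub @ cb.T + (cb**2).sum(1)[None, :]
        codes[:, j] = d2.argmin(1)
    return codes


def pq_decode(codes: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Approximate reconstruction: concatenate the chosen centroid of every sub-space."""
    return np.concatenate([codebooks[j][codes[:, j]] for j in range(codebooks.shape[0])], axis=1)


x = rng.standard_normal((3_000, 128)).astype(np.float32)
x /= np.linalg.norm(x, axis=1, keepdims=True)
codebooks = train_pq(x, m=16)
codes = pq_encode(x, codebooks)
err = np.mean(np.sum((x - pq_decode(codes, codebooks)) ** 2, axis=1))
print(f"bytes per vector: {x.nbytes // len(x)} -> {codes.nbytes // len(codes)}, mean squared error {err:.3f}")
```

The reconstruction error is the price of the 32x compression, and it shows up as the recall degradation mentioned above.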
FAISS index types:
- IndexFlatL2: exact search, no approximation. Only viable up to ~1M items.
- IndexIVFFlat: IVF with exact inner-cluster search. Good up to ~100M items.
- IndexIVFPQ: IVF + product quantization. Billions of items in RAM.
- IndexHNSW: graph-based ANN, excellent recall-latency trade-off, used when memory is not the bottleneck.
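Recall vs latency trade-off: for the ANN index itself, a common production target is recall above 0.95 relative to exact brute-force search over the same embeddings (95%+ of the true top-1000 neighbors are returned). Model-level Recall@1000 (whether the held-out ground-truth item lands in the top-1000 candidates) is tracked separately. Items that retrieval misses never reach the ranking model, so the ranker cannot compensate for them.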
Concept 5: The Full Retrieval to Ranking Pipeline
```text
User opens app
    |
User features assembled (history, demographics, context)
    |
User Tower: 1 forward pass → user embedding (1–5 ms)
    |
FAISS ANN search: top-1000 candidates (5–15 ms)
    |
Ranking Model (NeuMF / Deep Interest Network / etc.)
  - Scores each of the 1,000 candidates
  - Can use rich user-item interaction features
  - 50–100 ms for 1,000 candidates
    |
Business Logic Layer
  - Diversity enforcement
  - Freshness boost
  - Safety filters
    |
Final 50 items shown to user
```
Each stage has different objectives and constraints:
- Retrieval: maximize recall at low latency. Must find the needle in the haystack.
- Ranking: maximize precision at moderate latency. Because the candidate set is small, it can afford to be expensive.
- Business logic: maximize engagement while respecting diversity, freshness, and safety constraints.
The retrieval model optimizes for: "does the ground truth item appear in my top-1000?" The ranking model optimizes for: "given these 1000 items, which 50 should I show, in what order?" These are fundamentally different objectives, which is why they use different architectures.
Implementation: Two-Tower from Scratch
Model Architecture
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from typing import Optional


# ─── Feature Configuration ────────────────────────────────────────────────────
class UserFeatureCfg:
    """
    User feature specification. In a real system this would include
    watch history embeddings, search query embeddings, etc.
    For this example: user ID + age bucket + gender + device type.
    """
    NUM_USERS: int = 100_000
    NUM_AGE_BUCKETS: int = 10
    NUM_GENDERS: int = 3
    NUM_DEVICES: int = 5
    EMB_DIM: int = 32  # embedding dim for each categorical feature


class ItemFeatureCfg:
    """
    Item feature specification.
    item ID + category + duration bucket + content embedding (pre-trained).
    """
    NUM_ITEMS: int = 500_000
    NUM_CATEGORIES: int = 200
    NUM_DURATION_BUCKETS: int = 10
    CONTENT_EMB_DIM: int = 128  # pre-trained content embedding
    EMB_DIM: int = 32


# ─── User Tower ───────────────────────────────────────────────────────────────
class UserTower(nn.Module):
    """
    Maps user features to a k-dimensional embedding.
    Input features:
      - user_id: categorical (embedding)
      - age_bucket: categorical (embedding)
      - gender: categorical (embedding)
      - device: categorical (embedding)
    All embeddings are concatenated and fed into an MLP.
    Output is L2-normalized for cosine similarity scoring.
    """

    def __init__(
        self,
        output_dim: int = 128,
        hidden_dims: tuple[int, ...] = (256, 128),  # tuple avoids a mutable default
        dropout: float = 0.1,
    ):
        super().__init__()
        cfg = UserFeatureCfg
        # Categorical embeddings
        self.user_emb = nn.Embedding(cfg.NUM_USERS, cfg.EMB_DIM)
        self.age_emb = nn.Embedding(cfg.NUM_AGE_BUCKETS, 8)
        self.gender_emb = nn.Embedding(cfg.NUM_GENDERS, 4)
        self.device_emb = nn.Embedding(cfg.NUM_DEVICES, 8)
        # Total input: 32 + 8 + 4 + 8 = 52
        input_dim = cfg.EMB_DIM + 8 + 4 + 8
        layers: list[nn.Module] = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.mlp = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.01)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        user_ids: torch.Tensor,
        age_buckets: torch.Tensor,
        genders: torch.Tensor,
        devices: torch.Tensor,
    ) -> torch.Tensor:
        user_e = self.user_emb(user_ids)
        age_e = self.age_emb(age_buckets)
        gender_e = self.gender_emb(genders)
        device_e = self.device_emb(devices)
        x = torch.cat([user_e, age_e, gender_e, device_e], dim=-1)  # (B, 52)
        out = self.mlp(x)  # (B, output_dim)
        return F.normalize(out, p=2, dim=-1)  # L2-normalize


# ─── Item Tower ───────────────────────────────────────────────────────────────
class ItemTower(nn.Module):
    """
    Maps item features to a k-dimensional embedding.
    Input features:
      - item_id: categorical (embedding, randomly dropped with p=0.1
        during training to enable cold-start content-only embedding)
      - category: categorical (embedding)
      - duration_bucket: categorical (embedding)
      - content_emb: pre-trained dense embedding (e.g., from video encoder)
    Output is L2-normalized.
    """

    def __init__(
        self,
        output_dim: int = 128,
        hidden_dims: tuple[int, ...] = (256, 128),
        dropout: float = 0.1,
        id_dropout_prob: float = 0.1,  # cold-start training trick
    ):
        super().__init__()
        cfg = ItemFeatureCfg
        self.id_dropout_prob = id_dropout_prob
        # Categorical embeddings
        self.item_emb = nn.Embedding(cfg.NUM_ITEMS, cfg.EMB_DIM)
        self.category_emb = nn.Embedding(cfg.NUM_CATEGORIES, 16)
        self.duration_emb = nn.Embedding(cfg.NUM_DURATION_BUCKETS, 8)
        # Project pre-trained content embedding (128-dim) to 64-dim
        self.content_proj = nn.Sequential(
            nn.Linear(cfg.CONTENT_EMB_DIM, 64),
            nn.ReLU(),
        )
        # Total input: 32 + 16 + 8 + 64 = 120
        input_dim = cfg.EMB_DIM + 16 + 8 + 64
        layers: list[nn.Module] = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                nn.BatchNorm1d(h_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.mlp = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.01)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        item_ids: torch.Tensor,
        categories: torch.Tensor,
        duration_buckets: torch.Tensor,
        content_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        item_e = self.item_emb(item_ids)
        # Cold-start trick: randomly zero out item ID embedding during training.
        # Forces the model to learn to embed items from content features alone.
        if self.training and self.id_dropout_prob > 0:
            mask = torch.rand(item_e.shape[0], device=item_e.device) > self.id_dropout_prob
            item_e = item_e * mask.float().unsqueeze(-1)
        cat_e = self.category_emb(categories)
        dur_e = self.duration_emb(duration_buckets)
        content_e = self.content_proj(content_embeddings)
        x = torch.cat([item_e, cat_e, dur_e, content_e], dim=-1)  # (B, 120)
        out = self.mlp(x)  # (B, output_dim)
        return F.normalize(out, p=2, dim=-1)  # L2-normalize


# ─── Full Two-Tower Model ─────────────────────────────────────────────────────
class TwoTowerModel(nn.Module):
    """
    Full two-tower model.
    Score = cosine similarity(user_emb, item_emb).
    Trained with in-batch negatives + sampling bias correction.
    """

    def __init__(
        self,
        output_dim: int = 128,
        hidden_dims: tuple[int, ...] = (256, 128),
        temperature: float = 0.1,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.user_tower = UserTower(output_dim, hidden_dims, dropout)
        self.item_tower = ItemTower(output_dim, hidden_dims, dropout)
        # Learnable temperature - usually works better than fixed
        self.log_temperature = nn.Parameter(torch.tensor(temperature).log())

    @property
    def temperature(self) -> torch.Tensor:
        # Keep temperature positive via exp
        return self.log_temperature.exp()

    def forward(
        self,
        user_ids: torch.Tensor,
        age_buckets: torch.Tensor,
        genders: torch.Tensor,
        devices: torch.Tensor,
        item_ids: torch.Tensor,
        categories: torch.Tensor,
        duration_buckets: torch.Tensor,
        content_embeddings: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns (user_emb, item_emb), both L2-normalized.
        Loss is computed outside the model to allow flexibility.
        """
        user_emb = self.user_tower(user_ids, age_buckets, genders, devices)
        item_emb = self.item_tower(item_ids, categories, duration_buckets, content_embeddings)
        return user_emb, item_emb
```
Training with In-Batch Negatives and Bias Correction
```python
def in_batch_softmax_loss(
    user_emb: torch.Tensor,
    item_emb: torch.Tensor,
    temperature: torch.Tensor,
    item_sampling_probs: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, dict]:
    """
    In-batch softmax loss with optional sampling bias correction.
    The similarity matrix sim[i, j] = dot(user_i, item_j).
    The diagonal is the positive pair. All off-diagonal entries
    in each row are treated as negatives for that user.
    Args:
        user_emb: (B, k) L2-normalized user embeddings
        item_emb: (B, k) L2-normalized item embeddings
        temperature: scalar - controls softmax sharpness
        item_sampling_probs: (B,) - q(i) for each item. If provided,
            subtract log(q(i)) from the temperature-scaled logits
            (Yi et al. bias correction).
    Returns:
        loss: scalar
        metrics: dict with 'top1_acc' and 'avg_sim_pos' for monitoring
    """
    B = user_emb.shape[0]
    # Similarity matrix: (B, B)
    # sim[i, j] = cosine_sim(user_i, item_j)
    sim_matrix = torch.matmul(user_emb, item_emb.T)
    # Scale by temperature
    logits = sim_matrix / temperature  # (B, B)
    # Apply sampling bias correction (Yi et al. 2019) AFTER temperature scaling,
    # so the strength of the correction does not depend on the temperature.
    if item_sampling_probs is not None:
        log_q = torch.log(item_sampling_probs.clamp(min=1e-9))  # (B,)
        # Column j belongs to item j: subtract its log q from every row
        logits = logits - log_q.unsqueeze(0)  # broadcast over users
    # Ground truth: diagonal matches (user_i with item_i)
    targets = torch.arange(B, device=user_emb.device)
    loss = F.cross_entropy(logits, targets)
    # Monitoring metrics
    with torch.no_grad():
        top1_acc = (logits.argmax(dim=1) == targets).float().mean().item()
        # Average cosine similarity of positive pairs (raw: before temperature and correction)
        pos_sim = sim_matrix.diagonal().mean().item()
    return loss, {"top1_acc": top1_acc, "avg_sim_pos": pos_sim}


def estimate_sampling_probs(
    item_ids: torch.Tensor,
    global_item_counts: torch.Tensor,
    smoothing: float = 1.0,
) -> torch.Tensor:
    """
    Estimate q(i) for each item in the batch from running item counts.
    With smoothing=1.0, q(i) is proportional to item frequency, which is how
    in-batch negatives are actually sampled. Exponents below 1 (e.g. the 0.75
    used for word2vec negative sampling) flatten the distribution and weaken
    the correction.
    Normalizing over the batch rather than the whole catalog shifts every
    log q(i) in a row by the same constant, which the softmax ignores.
    """
    counts = global_item_counts[item_ids].float()
    smoothed = counts ** smoothing
    return smoothed / smoothed.sum()


# ─── Training Loop ────────────────────────────────────────────────────────────
def train_two_tower(
    model: TwoTowerModel,
    train_loader: DataLoader,
    num_epochs: int = 10,
    lr: float = 1e-3,
    weight_decay: float = 1e-5,
    use_bias_correction: bool = True,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
) -> dict:
    model = model.to(device)
    # Separate LR for temperature parameter
    main_params = [p for n, p in model.named_parameters() if "log_temperature" not in n]
    temp_params = [model.log_temperature]
    optimizer = optim.AdamW([
        {"params": main_params, "lr": lr, "weight_decay": weight_decay},
        {"params": temp_params, "lr": lr * 0.1, "weight_decay": 0.0},
    ])
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    # Item frequency tracking for bias correction.
    # Counts start at 1 so items not yet seen never get q = 0.
    global_item_counts = torch.ones(ItemFeatureCfg.NUM_ITEMS, device=device)
    history = {"loss": [], "top1_acc": [], "temperature": []}
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0
        for batch in train_loader:
            # Unpack batch (structure depends on your Dataset)
            user_ids = batch["user_id"].to(device)
            age_buckets = batch["age_bucket"].to(device)
            genders = batch["gender"].to(device)
            devices_feat = batch["device"].to(device)
            item_ids = batch["item_id"].to(device)
            categories = batch["category"].to(device)
            duration_buckets = batch["duration_bucket"].to(device)
            content_embs = batch["content_emb"].to(device)
            # Update item frequency counts. index_add_ is vectorized and, unlike
            # counts[ids] += 1, counts an item that appears twice in a batch twice.
            global_item_counts.index_add_(
                0, item_ids, torch.ones_like(item_ids, dtype=global_item_counts.dtype)
            )
            optimizer.zero_grad()
            user_emb, item_emb = model(
                user_ids, age_buckets, genders, devices_feat,
                item_ids, categories, duration_buckets, content_embs,
            )
            # Compute sampling probabilities for bias correction
            sampling_probs = None
            if use_bias_correction:
                sampling_probs = estimate_sampling_probs(item_ids, global_item_counts)
            loss, metrics = in_batch_softmax_loss(
                user_emb, item_emb,
                model.temperature,
                item_sampling_probs=sampling_probs,
            )
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
            epoch_acc += metrics["top1_acc"]
            num_batches += 1
        scheduler.step()
        avg_loss = epoch_loss / num_batches
        avg_acc = epoch_acc / num_batches
        temp_val = model.temperature.item()
        history["loss"].append(avg_loss)
        history["top1_acc"].append(avg_acc)
        history["temperature"].append(temp_val)
        print(
            f"Epoch {epoch+1:>2}/{num_epochs} | "
            f"Loss: {avg_loss:.4f} | "
            f"In-batch Top-1: {avg_acc:.4f} | "
            f"Temp: {temp_val:.4f}"
        )
    return history
```
FAISS Index: Building and Querying
```python
import faiss
import numpy as np


def precompute_item_embeddings(
    model: TwoTowerModel,
    item_catalog: dict,
    batch_size: int = 2048,
    device: str = "cpu",
) -> np.ndarray:
    """
    Offline step: run all items through the item tower.
    Returns (N, k) float32 embedding matrix.
    Run this nightly after model retraining.
    """
    model.eval()
    model.to(device)
    all_embeddings = []
    num_items = len(item_catalog["item_id"])
    print(f"Precomputing embeddings for {num_items:,} items...")
    with torch.no_grad():
        for start in range(0, num_items, batch_size):
            end = min(start + batch_size, num_items)
            item_ids = torch.tensor(
                item_catalog["item_id"][start:end], dtype=torch.long, device=device
            )
            categories = torch.tensor(
                item_catalog["category"][start:end], dtype=torch.long, device=device
            )
            duration_buckets = torch.tensor(
                item_catalog["duration_bucket"][start:end], dtype=torch.long, device=device
            )
            content_embs = torch.tensor(
                item_catalog["content_emb"][start:end], dtype=torch.float32, device=device
            )
            embs = model.item_tower(item_ids, categories, duration_buckets, content_embs)
            all_embeddings.append(embs.cpu().numpy())
    result = np.vstack(all_embeddings).astype(np.float32)
    print(f"Done. Shape: {result.shape}")
    return result


def build_faiss_index(
    item_embeddings: np.ndarray,
    n_list: int = 4096,
    n_subvectors: int = 16,
    use_gpu: bool = False,
) -> faiss.Index:
    """
    Build an IVF-PQ FAISS index for billion-scale ANN search.
    n_list: number of Voronoi cells. Rule of thumb: sqrt(N) to 4*sqrt(N).
    n_subvectors: number of PQ sub-vectors. embedding_dim must be divisible.
    For d=128, n_subvectors=16 gives 8-dim sub-vectors (256 centroids each).
    Memory: N * n_subvectors bytes (vs N * d * 4 bytes for float32).
    For N=100M, d=128: float32 = 51.2 GB; IVF-PQ = 1.6 GB (32x compression).
    """
    N, d = item_embeddings.shape
    assert d % n_subvectors == 0, f"d={d} must be divisible by n_subvectors={n_subvectors}"
    print(f"Building FAISS IVF-PQ index: {N:,} items, d={d}, n_list={n_list}")
    quantizer = faiss.IndexFlatIP(d)  # inner product on L2-normalized vectors = cosine
    index = faiss.IndexIVFPQ(
        quantizer,
        d,
        n_list,
        n_subvectors,
        8,  # bits per sub-vector code (8 bits = 256 centroids)
        faiss.METRIC_INNER_PRODUCT,  # pass the metric to the constructor
    )
    if use_gpu:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)
    print("Training index (k-means clustering)...")
    # FAISS warns when k-means gets fewer than 39 training points per centroid
    train_size = min(N, max(39 * n_list, 256_000))
    train_indices = np.random.choice(N, train_size, replace=False)
    index.train(item_embeddings[train_indices])
    print("Adding vectors to index...")
    index.add(item_embeddings)
    print(f"Index built. Total vectors: {index.ntotal:,}")
    return index


def retrieve_candidates(
    user_embedding: np.ndarray,
    index: faiss.Index,
    k: int = 1000,
    n_probe: int = 64,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Online serving: retrieve top-k candidates for a user.
    Args:
        user_embedding: (d,) float32 L2-normalized user embedding
        index: trained FAISS index
        k: number of candidates to return
        n_probe: number of IVF clusters to search.
            Higher = better recall, slower. Start at n_list // 64.
    Returns:
        scores: (k,) cosine similarity scores
        item_indices: (k,) integer indices into the item catalog (-1 = no result)
    """
    if hasattr(index, "nprobe"):
        index.nprobe = n_probe
    query = user_embedding.reshape(1, -1).astype(np.float32)
    scores, indices = index.search(query, k)
    return scores[0], indices[0]


def evaluate_recall_at_k(
    model: TwoTowerModel,
    test_data: list[dict],
    faiss_index: faiss.Index,
    k_values: tuple[int, ...] = (10, 50, 100, 500, 1000),
    device: str = "cpu",
) -> dict[int, float]:
    """
    Recall@K: for each test (user, item) pair, check whether the
    ground truth item appears in the top-K retrieved candidates.
    This is the primary metric for retrieval model quality.
    Track it alongside the index's recall against exact search:
    this one measures the model, that one measures the ANN index.
    """
    model.eval()
    hit_counts = {k: 0 for k in k_values}
    total = len(test_data)
    max_k = max(k_values)
    with torch.no_grad():
        for example in test_data:
            user_ids = torch.tensor([example["user_id"]], dtype=torch.long, device=device)
            age_buckets = torch.tensor([example["age_bucket"]], dtype=torch.long, device=device)
            genders = torch.tensor([example["gender"]], dtype=torch.long, device=device)
            devices_feat = torch.tensor([example["device"]], dtype=torch.long, device=device)
            user_emb = model.user_tower(
                user_ids, age_buckets, genders, devices_feat
            ).squeeze(0).cpu().numpy()
            _, retrieved = retrieve_candidates(user_emb, faiss_index, k=max_k)
            gt_idx = example["item_idx"]
            for k in k_values:
                if gt_idx in retrieved[:k]:
                    hit_counts[k] += 1
    recall = {k: hit_counts[k] / total for k in k_values}
    for k, r in recall.items():
        print(f"Recall@{k}: {r:.4f}")
    return recall


# ─── End-to-End Serving Example ──────────────────────────────────────────────
def serve_recommendations(
    user_context: dict,
    model: TwoTowerModel,
    faiss_index: faiss.Index,
    item_metadata: dict,
    k_retrieve: int = 1000,
    k_return: int = 50,
    device: str = "cpu",
) -> list[dict]:
    """
    Full online serving path.
    Step 1: compute user embedding (1 forward pass).
    Step 2: ANN search in FAISS.
    Step 3: return top-k (ranking model would re-score here).
    """
    model.eval()
    with torch.no_grad():
        user_emb = model.user_tower(
            torch.tensor([user_context["user_id"]], dtype=torch.long, device=device),
            torch.tensor([user_context["age_bucket"]], dtype=torch.long, device=device),
            torch.tensor([user_context["gender"]], dtype=torch.long, device=device),
            torch.tensor([user_context["device"]], dtype=torch.long, device=device),
        ).squeeze(0).cpu().numpy()
    scores, item_indices = retrieve_candidates(user_emb, faiss_index, k=k_retrieve)
    results = []
    for score, idx in zip(scores, item_indices):
        if idx < 0:  # FAISS pads with -1 when fewer than k results are found
            continue
        meta = item_metadata.get(int(idx), {})
        results.append({"item_id": int(idx), "retrieval_score": float(score), **meta})
        if len(results) == k_return:
            break
    return results
```
Production Engineering Notes
The Retrieval to Ranking Pipeline in Detail
In production, the two-tower retrieval model is only stage one. The full pipeline at a company like YouTube or TikTok typically has 3–4 stages:
Stage 1 - Retrieval (two-tower): reduce 100M+ items to ~1,000 candidates. Recall@1000 is the metric. Latency budget: 20ms.
Stage 2 - Pre-ranking (light model): reduce 1,000 to 200. A lightweight model (shallow MLP or linear) that runs quickly. Used to trim the candidate set before the expensive ranking model.
Stage 3 - Ranking (heavy model): score 200 candidates with a full NeuMF / DIN / DLRM. Can use features that cross user and item (which would break the two-tower constraint). Latency budget: 50–80ms.
Stage 4 - Post-processing: diversity enforcement (do not show 10 videos from the same creator), freshness boost, safety filters, business rules.
Each successive stage can afford a more expensive model because it scores a smaller candidate set.
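The stage-by-stage narrowing can be sketched as a chain of top-k cuts. This NumPy sketch rests on heavy assumptions. Exact inner-product retrieval stands in for FAISS, and `light_model` and `heavy_model` are hypothetical noisy scorers standing in for the pre-ranking and ranking models. The post-processing step only applies a per-creator cap:

```python
import numpy as np
from typing import Callable

Scorer = Callable[[np.ndarray], np.ndarray]  # candidate item indices -> one score per candidate


def top_k(indices: np.ndarray, scores: np.ndarray, k: int) -> np.ndarray:
    """Keep the k highest-scoring entries of `indices`, best first."""
    return indices[np.argsort(-scores)[:k]]


def funnel(
    user_vec: np.ndarray,
    item_vecs: np.ndarray,
    creator_ids: np.ndarray,
    pre_rank: Scorer,
    rank: Scorer,
    k_retrieve: int = 1000,
    k_prerank: int = 200,
    k_final: int = 50,
    max_per_creator: int = 2,
) -> list[int]:
    # Stage 1 - retrieval: exact inner product here, an ANN index in production
    cands = top_k(np.arange(len(item_vecs)), item_vecs @ user_vec, k_retrieve)
    # Stage 2 - pre-ranking: a cheap scorer trims the set
    cands = top_k(cands, pre_rank(cands), k_prerank)
    # Stage 3 - ranking: an expensive scorer orders what is left
    cands = top_k(cands, rank(cands), len(cands))
    # Stage 4 - post-processing: cap how many items one creator can contribute
    shown: list[int] = []
    per_creator: dict[int, int] = {}
    for item in cands:
        creator = int(creator_ids[item])
        if per_creator.get(creator, 0) < max_per_creator:
            shown.append(int(item))
            per_creator[creator] = per_creator.get(creator, 0) + 1
            if len(shown) == k_final:
                break
    return shown


rng = np.random.default_rng(0)
items = rng.standard_normal((20_000, 32))
creators = rng.integers(0, 500, size=len(items))
user = rng.standard_normal(32)


def light_model(ids: np.ndarray) -> np.ndarray:  # hypothetical pre-ranker: noisy score
    return items[ids] @ user + 0.5 * rng.standard_normal(len(ids))


def heavy_model(ids: np.ndarray) -> np.ndarray:  # hypothetical ranker: less noisy score
    return items[ids] @ user + 0.1 * rng.standard_normal(len(ids))


feed = funnel(user, items, creators, light_model, heavy_model)
print(len(feed), feed[:5])
```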
Freshness: Handling Item Embedding Staleness
Item embeddings are precomputed offline - typically nightly. This means new items (posted today) have no embedding and cannot be retrieved. Solutions:
Content-based warm start: build the item tower to operate on content features alone when the item ID is unknown or untrained. During training, randomly zero out the item ID embedding with probability p (the id_dropout_prob in the implementation above). This forces the model to learn useful content-only embeddings that generalize to new items at serving time.
Streaming incremental updates: after a new item accumulates 50–100 interactions, re-run it through the item tower and push the new embedding to FAISS. FAISS supports incremental add() calls without a full index rebuild.
Exploration injection: explicitly inject new items into candidate sets via a rule-based policy, track their performance, and graduate them to the main retrieval index once they have reliable embeddings.
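A minimal numpy sketch of ID dropout inside a toy item tower. Only `id_dropout_prob` comes from the text; the function name, weight names, and single-layer architecture are assumptions for illustration. During training each example's ID embedding is zeroed with probability `id_dropout_prob`, and at serving time a brand-new item takes the same zero-ID path.

```python
import numpy as np

def item_tower(item_ids, content_feats, id_table, w_content, w_out,
               id_dropout_prob=0.1, training=True, rng=None):
    """Toy item tower: ID embedding + projected content -> L2-normalized item embedding.

    item_ids=None means "no trained ID" (a brand-new item): content-only path.
    """
    content_part = content_feats @ w_content                   # (B, d)
    if item_ids is None:
        id_part = np.zeros_like(content_part)
    else:
        id_part = id_table[item_ids]                           # (B, d)
        if training:
            rng = rng if rng is not None else np.random.default_rng()
            keep = rng.random(len(item_ids)) >= id_dropout_prob    # drop whole ID rows
            # No 1/(1 - p) rescaling: the goal is to mimic a missing ID exactly,
            # so training sees the same zero vector a new item gets at serving time.
            id_part = id_part * keep[:, None]
    h = np.maximum(id_part + content_part, 0.0) @ w_out        # one ReLU layer
    return h / np.maximum(np.linalg.norm(h, axis=1, keepdims=True), 1e-12)
```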
Multi-Task Learning in the Ranking Stage
The ranking model typically optimizes multiple objectives simultaneously - click probability, watch time, share probability, user satisfaction. YouTube uses a Multi-gate Mixture-of-Experts (MMoE) architecture for this. The key insight: different objectives depend on different user signals, and a model trained on one objective tends to optimize whatever is easiest to predict (usually clicks, which are noisy). MMoE shares a set of expert networks across tasks but gives each task its own gating network and output head, so tasks learn shared representations while weighting them differently.
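A minimal numpy sketch of the MMoE idea: shared experts, one softmax gate per task that decides how to mix them, and one output head per task. Single-layer experts and heads and the argument layout are simplifying assumptions; this is a forward pass only, not YouTube's production model.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mmoe_forward(x, expert_ws, gate_ws, tower_ws):
    """Toy MMoE forward pass.

    x:         (B, d_in) shared input features
    expert_ws: E matrices of shape (d_in, d_h), one per shared expert
    gate_ws:   T matrices of shape (d_in, E), one gating network per task
    tower_ws:  T vectors of shape (d_h,), one output head per task
    Returns a (B, T) array of per-task probabilities.
    """
    experts = np.stack([np.maximum(x @ w, 0.0) for w in expert_ws], axis=1)  # (B, E, d_h)
    outputs = []
    for gate_w, tower_w in zip(gate_ws, tower_ws):
        gate = softmax(x @ gate_w)                          # (B, E): this task's expert mix
        mixed = np.einsum("be,beh->bh", gate, experts)      # (B, d_h)
        logit = mixed @ tower_w                             # (B,)
        outputs.append(0.5 * (1.0 + np.tanh(0.5 * logit)))  # sigmoid, via tanh for stability
    return np.stack(outputs, axis=1)
```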
Embedding Storage at Scale
For 100M items at 128 dimensions × 4 bytes per float = 51.2 GB of RAM. Solutions:
- Float16: cut storage in half (25.6 GB) with minimal precision loss for cosine similarity.
- IVF-PQ: 32x compression of the vectors (1.6 GB, before IDs and index overhead) at the cost of ~1–3% recall degradation. A standard choice for billion-scale deployment.
- Binary embeddings: 128 dimensions = 16 bytes. 100M items = 1.6 GB. Hamming distance lookup is extremely fast. Used for first-stage rough filtering before a more precise re-rank.
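The arithmetic can be checked directly, and sign-binarized codes plus Hamming search take only a few numpy calls. Byte counts cover the vectors alone (no IDs or index structures), sign binarization is just one simple way to produce binary codes, and the helper names are hypothetical.

```python
import numpy as np

def footprint_gb(n_items, bytes_per_item):
    """Raw storage for the vectors alone, in GB (1e9 bytes)."""
    return n_items * bytes_per_item / 1e9

N, D = 100_000_000, 128
print(footprint_gb(N, D * 4))    # float32: 51.2
print(footprint_gb(N, D * 2))    # float16: 25.6
print(footprint_gb(N, 16))       # PQ, 16 one-byte codes per item: 1.6
print(footprint_gb(N, D // 8))   # binary, 1 bit per dimension: 1.6

def binarize(x):
    """Sign-binarize float embeddings, packing 8 bits per byte: (n, D) -> (n, D // 8) uint8."""
    return np.packbits(x > 0, axis=1)

def hamming_topk(query_code, codes, k):
    """Indices of the k codes with the smallest Hamming distance to query_code."""
    dist = np.unpackbits(np.bitwise_xor(codes, query_code), axis=1).sum(axis=1)
    return np.argsort(dist, kind="stable")[:k]
```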
Common Mistakes
:::danger Not Correcting for Sampling Bias
If you train with in-batch negatives without sampling bias correction, your model will systematically under-recommend popular items. The mechanism is clear: popular items appear as negatives more often in batches, so the model learns to assign them low scores.
This seems counterintuitive (you usually worry about models over-recommending popular items), but the in-batch negative mechanism reverses the bias.
The fix is one line: subtract log(q(i)), where q(i) is item i's sampling probability, from each item's logit before the softmax. Skipping it is a common mistake that silently degrades retrieval quality without producing any obvious error.
:::
:::danger Using the Same Model for Retrieval and Ranking
Retrieval and ranking have different objectives and completely different latency budgets. Retrieval must scan a billion items; ranking can afford a far more expensive model because it sees only about 1,000 candidates.
Using a two-tower model for ranking leaves accuracy on the table - the two-tower constraint prevents it from learning the cross-feature interactions that ranking depends on.
Using a joint model (NeuMF, DIN) for retrieval is infeasible at billion-item scale within the latency budget.
Always use two separate models: two-tower for retrieval, a joint model for ranking. They are optimizing for different things.
:::
:::warning Forgetting That Item Embeddings Go Stale
A common production bug: train the two-tower model, precompute item embeddings, deploy - and never refresh the embeddings. Over time, the model gets retrained on new data, but the FAISS index still contains embeddings from the old model weights. The retrained user tower produces embeddings in a different space from the old item tower, so their dot products no longer mean what training intended, and retrieval quality silently degrades.
Always build a pipeline to:
- Trigger FAISS index rebuild whenever the model is retrained.
- Push incremental embedding updates for new and high-velocity items.
- Monitor Recall@1000 continuously - sudden drops often indicate embedding staleness.
:::
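One way to make this failure loud rather than silent is to store the model version and build time next to the index and check both before serving. A minimal sketch; `IndexManifest` and `check_index` are hypothetical names, and the 48-hour default borrows the freshness alert from the system design in Q6.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

@dataclass
class IndexManifest:
    """Hypothetical metadata written next to every index build."""
    model_version: str    # version of the item tower that produced the embeddings
    built_at: datetime    # timezone-aware build time

def check_index(manifest, serving_model_version, max_age=timedelta(hours=48), now=None):
    """Return a list of problems; an empty list means the index looks safe to serve."""
    now = now if now is not None else datetime.now(timezone.utc)
    problems = []
    if manifest.model_version != serving_model_version:
        problems.append(
            f"index built by {manifest.model_version} but user tower is "
            f"{serving_model_version}: embeddings live in different spaces"
        )
    age = now - manifest.built_at
    if age > max_age:
        problems.append(f"index is {age} old (limit {max_age})")
    return problems
```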
:::warning Small Batch Size Kills Two-Tower Training
With a batch size of 32, each user sees only 31 negatives. The model barely learns to discriminate between relevant and irrelevant items. Two-tower models need large batches to work effectively.
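Target batch size: 1024–4096. Gradient accumulation is often suggested when GPU memory is the bottleneck, but be clear about what it does: it averages gradients over more examples per optimizer step, yet each micro-batch's softmax still uses only that micro-batch's items as negatives. It stabilizes updates; it does not add negatives.
```python
accumulation_steps = 4  # optimizer averages over batch_size * 4 examples per step
optimizer.zero_grad()
for step, batch in enumerate(dataloader):
    # compute_loss builds its in-batch softmax from this micro-batch only,
    # so each positive still sees batch_size - 1 negatives.
    loss, _ = compute_loss(batch)
    (loss / accumulation_steps).backward()  # gradients add up across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```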
:::
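When memory caps the batch size, one common way to get more negatives is a cross-batch memory: keep item embeddings from the last few micro-batches and score each user against them as extra negatives. A minimal numpy sketch of the forward computation; `NegativeQueue` and `loss_with_queue` are hypothetical names. Queued embeddings come from older weights (slightly stale, no gradient), and a queued item can occasionally be the user's own positive, a false negative this sketch does not filter. The logQ correction from the sampling-bias discussion would apply to queued items too.

```python
import numpy as np
from collections import deque

def log_softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

class NegativeQueue:
    """FIFO memory of item embeddings from the last few micro-batches."""
    def __init__(self, max_batches):
        self.batches = deque(maxlen=max_batches)   # oldest batch is dropped automatically

    def negatives(self, dim):
        if not self.batches:
            return np.empty((0, dim))
        return np.concatenate(list(self.batches), axis=0)

    def push(self, item_embs):
        self.batches.append(item_embs.copy())      # stored as constants: no gradient

def loss_with_queue(user_embs, item_embs, queue, temperature=0.1):
    """In-batch softmax whose candidates are the B batch items plus Q queued items.

    Row b's positive is item b (column b); all other columns act as negatives.
    """
    cands = np.concatenate([item_embs, queue.negatives(item_embs.shape[1])], axis=0)
    logits = user_embs @ cands.T / temperature     # (B, B + Q)
    rows = np.arange(len(user_embs))
    loss = float(-log_softmax_rows(logits)[rows, rows].mean())
    queue.push(item_embs)                          # enqueue after use
    return loss, logits.shape
```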
:::tip Temperature Is a Critical Hyperparameter
The temperature $\tau$ dramatically affects what the model learns:
- Too low (below 0.02): the softmax becomes extremely sharp and gradients concentrate on the few hardest negatives. Training can become unstable and embeddings can collapse into a few clusters (mode collapse); the loss may drop quickly while the representations stay poor.
- Too high (above 0.5): the softmax is nearly flat, so embeddings are not discriminative and the model does not learn to separate positives from negatives.
Start with a moderate value (the design in Q6 below initializes $\tau$ at 0.1), monitor the in-batch top-1 accuracy during training, and tune. A learnable temperature parameter (as in the implementation above) often converges to a better value than manual tuning.
:::
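A small numpy illustration of the sharpening effect and of the top-1 monitoring metric, assuming L2-normalized embeddings so similarities lie in [-1, 1]. The similarity values are made up, and `softmax_at` and `in_batch_top1_accuracy` are hypothetical helper names.

```python
import numpy as np

def softmax_at(sims, tau):
    """Softmax of similarities divided by temperature tau."""
    z = sims / tau
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def in_batch_top1_accuracy(user_embs, item_embs):
    """Fraction of users whose highest-scoring in-batch item is their own positive."""
    logits = user_embs @ item_embs.T
    return float((logits.argmax(axis=1) == np.arange(len(user_embs))).mean())

# Made-up cosine similarities of one user to 4 in-batch items; item 0 is the positive.
sims = np.array([0.62, 0.55, 0.30, 0.10])
for tau in (0.01, 0.1, 1.0):
    # Small tau puts almost all mass on item 0; large tau spreads it far more evenly.
    print(f"tau={tau}: {np.round(softmax_at(sims, tau), 3)}")
```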
YouTube Resources
| Video | Channel | Why Watch |
|---|---|---|
| Two-Tower Networks for Recommendations | Google Research | Overview of dual encoders, directly from the team that built it |
| YouTube Recommendation Deep Dive | Yannic Kilcher | Full walkthrough of the YouTube DNN paper - the production system |
| FAISS Explained | James Briggs | ANN search intuition and hands-on FAISS tutorial |
| Building a Recommendation Engine at Scale | MLOps Community | Real production two-tower system walk-through |
Interview Q&A
Q1: Why do you use a two-tower model for retrieval instead of a model like NeuMF?
Answer: The fundamental constraint is inference time. NeuMF - and any model where user and item features are jointly processed - requires a full forward pass for every user-item pair you want to score. At YouTube scale (2 billion users, 800 million videos), a single user's request would require 800 million forward passes. Even with aggressive batching on modern GPU hardware, this takes hours. The latency requirement is 100 milliseconds.
The two-tower model solves this with one architectural constraint: the user and item towers are completely independent. No user feature ever touches the item tower, and no item feature ever touches the user tower. This means item embeddings can be precomputed offline and stored in a FAISS index. At query time, the model only needs to: (1) compute the user embedding (1 forward pass, ~5ms), and (2) run ANN search (~10ms). Well under 100ms for a billion-item catalog.
The cost is expressiveness - the two-tower model cannot learn fine-grained user-item feature interactions the way NeuMF can. This is why two-tower is used for retrieval (recall@1000 from billions) and a more expressive model handles ranking (precision@50 from 1000 candidates).
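A minimal numpy sketch of the serving split, with random vectors standing in for both towers' outputs: the item matrix is built once offline, and each request costs one user-tower pass plus one matrix-vector product. Exact brute-force top-k is shown for clarity; at 100M+ items an ANN index such as FAISS takes the place of `retrieve`, which is a hypothetical helper.

```python
import numpy as np

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def retrieve(user_emb, item_matrix, k):
    """Score the whole catalog with one matrix-vector product; return exact top-k, best first."""
    scores = item_matrix @ user_emb
    top = np.argpartition(-scores, k)[:k]        # unordered top-k in linear time
    return top[np.argsort(-scores[top])]         # sort only those k

rng = np.random.default_rng(0)
n_items, d = 50_000, 64

# Offline (e.g. nightly): every item passes through the item tower once.
# Random vectors stand in for item-tower outputs.
item_matrix = l2_normalize(rng.normal(size=(n_items, d))).astype(np.float32)

# Online, per request: one user-tower forward pass (random stand-in), then one matmul.
user_emb = l2_normalize(rng.normal(size=d)).astype(np.float32)
candidates = retrieve(user_emb, item_matrix, k=1000)
```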
Q2: What are in-batch negatives and why are they preferred for two-tower training?
Answer: In-batch negatives use the other positive items in the current training batch as negatives for each training example. If the batch contains pairs $(u_1, i_1), (u_2, i_2), \ldots, (u_B, i_B)$, then for user $u_b$, the items $i_{b'}$ with $b' \neq b$ are all treated as negatives.
Why preferred:
First, they are free. Every forward pass already computes all item embeddings in the batch. Using them as negatives adds zero additional computation. With batch size 2048, you get 2047 negatives per positive - far more signal than the 4 negatives per positive typical in NCF training.
Second, they are high-quality. Items that appear in other users' interactions are plausible candidates - not random catalog items. These harder negatives force the model to learn genuine user preferences.
Third, the similarity matrix is a single matrix multiply of two (B, k) tensors (k is the embedding dimension) - it maps perfectly to GPU tensor cores and is extremely efficient.
The one catch: popular items appear as negatives more often, introducing sampling bias. This is corrected by subtracting $\log q_i$ (the log of item $i$'s sampling probability) from each item's logit, which keeps in-batch negatives efficient while approximately removing the popularity bias.
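A minimal numpy sketch of the in-batch softmax loss with the logQ correction (forward pass only; a real trainer would use PyTorch tensors so gradients flow). The correction is subtracted from the temperature-scaled logits because the softmax exponentiates exactly those values: subtracting log q there is the same as dividing each item's exp term by q. The function name and defaults are illustrative assumptions.

```python
import numpy as np

def in_batch_softmax_loss(user_embs, item_embs, item_log_q=None, temperature=0.1):
    """Sampled-softmax loss over in-batch negatives, with optional logQ correction.

    user_embs, item_embs: (B, k) L2-normalized embeddings; row b is a positive pair.
    item_log_q: (B,) log sampling probability of each in-batch item, or None.
    """
    logits = user_embs @ item_embs.T / temperature        # (B, B) similarity matrix
    if item_log_q is not None:
        logits = logits - item_log_q[None, :]             # correct every column, incl. the positive
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))            # positives sit on the diagonal
```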
Q3: Explain the sampling bias correction from Yi et al. Why does it matter and how does it work?
Answer: Without correction, training with in-batch negatives causes the model to systematically under-rank popular items.
The mechanism: items are sampled into batches proportionally to their interaction frequency. Popular items have more interactions, appear in more batches, and therefore appear as negatives more often. The model repeatedly sees "do not recommend this popular item to this user" and learns to assign popular items low scores - even though popularity often signals genuine quality.
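The correction comes from importance sampling theory. Ideally, the model would be trained to minimize the full-softmax loss over the entire catalog $\mathcal{I}$:
$$\mathcal{L}(u, i) = -\log \frac{\exp\big(s(u, i)\big)}{\sum_{j \in \mathcal{I}} \exp\big(s(u, j)\big)}, \qquad s(u, j) = \frac{\langle \mathbf{u}, \mathbf{v}_j \rangle}{\tau}$$
In-batch sampling approximates the denominator with the items in the batch $\mathcal{B}$, but item $j$ enters a batch with probability $q_j$ roughly proportional to its interaction frequency, so popular items are overrepresented. To correct, divide each item's contribution $\exp\big(s(u, j)\big)$ by its sampling probability, which is equivalent to subtracting $\log q_j$ from its logit:
$$s^{c}(u, j) = s(u, j) - \log q_j, \qquad \mathcal{L}_{\mathcal{B}}(u, i) = -\log \frac{\exp\big(s^{c}(u, i)\big)}{\sum_{j \in \mathcal{B}} \exp\big(s^{c}(u, j)\big)}$$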
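Because $\log q_j < 0$, the correction raises every logit, and rare items (small sampling probability $q_j$) get a larger boost than popular ones. During training, a popular item's logit is therefore lowered relative to a rare item's, which offsets the extra times it appears as a negative. The raw score $s(u, j)$ the model learns, which is what serving uses with no correction, is no longer biased against popular items: a popular item ranks highly when it genuinely fits the user, and niche items are no longer favored merely because they rarely show up as negatives.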
In practice: estimate $q_j$ from empirical training frequencies, clamp it away from zero to avoid $\log(0)$, and subtract $\log q_j$ from the temperature-scaled logits, immediately before the softmax. It is a one-line change with significant impact on retrieval quality.
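Estimating q from the training log can be as simple as counting, sketched below with a hypothetical helper. Using count / total rather than the exact per-batch inclusion probability is fine: scaling every q by the same constant shifts all logits in a row equally, which cancels in the softmax. The 1e-8 clamp value is an arbitrary assumption.

```python
import numpy as np

def estimate_log_q(train_item_ids, n_items, min_prob=1e-8):
    """Empirical log sampling probability per item from the training interaction log.

    Items appear in batches roughly in proportion to their interaction count, so
    count / total estimates how often each item shows up as an in-batch candidate.
    Clamping keeps never-seen items away from log(0) = -inf.
    """
    counts = np.bincount(train_item_ids, minlength=n_items).astype(np.float64)
    q = counts / counts.sum()
    return np.log(np.clip(q, min_prob, None))

# At training time, log_q[batch_item_ids] gives the per-column correction.
```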
Q4: Walk me through the FAISS IVF-PQ index. What are the tradeoffs?
Answer: IVF-PQ combines two techniques: Inverted File Index for fast approximate search and Product Quantization for memory compression.
IVF mechanics:
- At build time: cluster the embedding space into $n_{\text{list}}$ Voronoi cells via k-means. Each item is assigned to its nearest centroid.
- At query time: find the $n_{\text{probe}}$ centroids nearest to the query, then do exact search within only those clusters. This returns an approximate top-k after scanning roughly $n_{\text{probe}} \cdot N / n_{\text{list}}$ items instead of all $N$.
$n_{\text{list}}$ tradeoff: too few clusters → each cluster is large, search is slow. Too many clusters → each cluster is small, recall drops (true neighbors may be in clusters you did not probe). A common rule of thumb scales $n_{\text{list}}$ with $\sqrt{N}$.
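$n_{\text{probe}}$ tradeoff: higher $n_{\text{probe}}$ = better recall, slower. $n_{\text{probe}} = 1$ is fastest, lowest recall. $n_{\text{probe}} = n_{\text{list}}$ scans every cluster (exact search with uncompressed vectors; with PQ codes the distances remain approximate). In production, find the minimum $n_{\text{probe}}$ that achieves Recall@1000 > 0.95.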
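The IVF mechanics fit in a toy numpy implementation, assuming L2-normalized embeddings (so the nearest centroid by distance and the highest inner product agree). `ToyIVF` is illustrative only; in FAISS the same two knobs appear as the `nlist` argument of IVF index constructors and the `nprobe` attribute. Probing every cell reproduces exact search.

```python
import numpy as np

def nearest(x, centroids):
    """Index of the nearest centroid for each row of x (squared L2 distance)."""
    # argmin_c ||x - c||^2 == argmin_c (||c||^2 - 2 x.c); ||x||^2 is constant per row
    d = (centroids ** 2).sum(axis=1)[None, :] - 2.0 * (x @ centroids.T)
    return np.argmin(d, axis=1)

def kmeans(x, n_clusters, n_iter=10, seed=0):
    """Plain Lloyd's k-means, initialized from random data points."""
    rng = np.random.default_rng(seed)
    centroids = x[rng.choice(len(x), n_clusters, replace=False)].copy()
    for _ in range(n_iter):
        assign = nearest(x, centroids)
        for c in range(n_clusters):
            members = x[assign == c]
            if len(members):                      # empty clusters keep their old centroid
                centroids[c] = members.mean(axis=0)
    return centroids

class ToyIVF:
    """Inverted file over L2-normalized vectors: k-means cells + per-cell item lists."""
    def __init__(self, item_embs, n_list):
        self.item_embs = item_embs
        self.centroids = kmeans(item_embs, n_list)
        assign = nearest(item_embs, self.centroids)
        self.lists = [np.flatnonzero(assign == c) for c in range(n_list)]

    def search(self, query, k, n_probe):
        # 1) probe only the n_probe cells whose centroids are closest to the query
        probe = np.argsort(((self.centroids - query) ** 2).sum(axis=1))[:n_probe]
        cand = np.concatenate([self.lists[c] for c in probe])
        # 2) exact inner-product scoring, but only inside the probed cells
        scores = self.item_embs[cand] @ query
        top = cand[np.argsort(-scores)[:k]]
        return top, len(cand)                     # results + number of items scanned
```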
Product Quantization: Split each 128-dim embedding into 16 sub-vectors of 8 dimensions. Quantize each sub-vector to one of 256 centroids (trained offline). Store the embedding as 16 bytes instead of 512 bytes - 32x compression. Approximate distances are computed from the quantized sub-vectors using pre-computed lookup tables, which is fast. Recall cost: roughly 1–3% for well-tuned configurations.
For 100M items: float32 = 51 GB, IVF-PQ = 1.6 GB. The compression is what makes billion-scale deployment feasible in RAM.
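A toy numpy sketch of PQ with the parameters above (16 sub-vectors, 256 centroids each, so each code fits in one uint8). Scoring uses the lookup-table trick: the query is compared once against every codeword, and each item's approximate inner product is a sum of 16 table lookups. It is illustrative only, not FAISS's implementation (`IndexPQ` / `IndexIVFPQ`), and the helper names are hypothetical.

```python
import numpy as np

def _assign(sub, codebook):
    """Nearest codeword (squared L2) for each row of sub."""
    d = (codebook ** 2).sum(axis=1)[None, :] - 2.0 * (sub @ codebook.T)
    return np.argmin(d, axis=1)

def train_pq(x, n_sub=16, n_centroids=256, n_iter=8, seed=0):
    """Learn one k-means codebook per sub-vector slice. Returns (n_sub, n_centroids, d_sub)."""
    n, d = x.shape
    d_sub = d // n_sub
    rng = np.random.default_rng(seed)
    books = []
    for m in range(n_sub):
        sub = x[:, m * d_sub:(m + 1) * d_sub]
        cb = sub[rng.choice(n, n_centroids, replace=False)].copy()
        for _ in range(n_iter):
            a = _assign(sub, cb)
            counts = np.bincount(a, minlength=n_centroids)
            sums = np.zeros_like(cb)
            np.add.at(sums, a, sub)               # per-centroid sum of assigned sub-vectors
            filled = counts > 0
            cb[filled] = sums[filled] / counts[filled, None]
        books.append(cb)
    return np.stack(books)

def pq_encode(x, books):
    """Each vector becomes n_sub uint8 codes: one byte per sub-vector."""
    n_sub, _, d_sub = books.shape
    return np.stack([_assign(x[:, m * d_sub:(m + 1) * d_sub], books[m])
                     for m in range(n_sub)], axis=1).astype(np.uint8)

def pq_scores(query, codes, books):
    """Approximate inner products from a per-query lookup table (asymmetric distance)."""
    n_sub, _, d_sub = books.shape
    table = np.einsum("mcd,md->mc", books, query.reshape(n_sub, d_sub))  # (n_sub, n_centroids)
    return table[np.arange(n_sub), codes].sum(axis=1)
```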
Q5: How do you handle new items that have no embeddings yet?
Answer: New items have no interaction history, so the item ID embedding starts as random noise and is meaningless. Without intervention, they cannot be retrieved until the model is retrained.
Content-based warm start (best solution): build the item tower to operate on content features alone, even when the item ID is untrained. During training, randomly zero out the item ID embedding with some probability $p$ (ID dropout). This forces the model to learn to embed items from content features when the ID is unavailable - which generalizes directly to new items at serving time.
At deployment, when a new item arrives, pass it through the item tower with its ID embedding zeroed out, matching what the tower saw under ID dropout, and rely on the content features (title, category, content encoder output). The embedding will not be as accurate as a fully trained one, but it is far better than nothing and improves rapidly once interaction data accumulates.
Streaming updates: after a new item accumulates enough interactions (50–100 is a reasonable threshold) and its ID embedding has been updated, re-run it through the item tower and push the new embedding to FAISS incrementally. FAISS supports index.add() without a full rebuild; replacing a vector that is already indexed also requires removing the old one first (e.g. with remove_ids(), on index types that support removal).
Exploration injection: add a rule-based policy that forces new items into candidate sets regardless of their embedding quality. Track performance and graduate items to main retrieval once they have reliable embeddings.
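The streaming-update and exploration steps can be sketched with a hypothetical in-memory index that supports replace-by-ID. `ToyUpsertIndex` and `inject_exploration` are illustrative names, not FAISS APIs; with FAISS, a replace is a removal followed by an add on index types that support removal.

```python
import numpy as np

class ToyUpsertIndex:
    """Hypothetical in-memory stand-in for an ANN index that supports replace-by-ID."""
    def __init__(self, dim):
        self.ids = []
        self.vecs = np.empty((0, dim))

    def upsert(self, item_id, vec):
        if item_id in self.ids:                     # refresh: remove the stale vector first
            row = self.ids.index(item_id)
            self.ids.pop(row)
            self.vecs = np.delete(self.vecs, row, axis=0)
        self.ids.append(item_id)
        self.vecs = np.vstack([self.vecs, vec[None, :]])

    def search(self, query, k):
        order = np.argsort(-(self.vecs @ query))[:k]
        return [self.ids[i] for i in order]

def inject_exploration(candidates, new_item_ids, n_slots, rng):
    """Replace the last n_slots candidates with random new items not already present."""
    pool = [i for i in rng.permutation(new_item_ids).tolist() if i not in candidates]
    picks = pool[:n_slots]
    return list(candidates[:len(candidates) - len(picks)]) + picks
```

In this sketch a new item is upserted with its content-only embedding at upload time and upserted again once it crosses the interaction threshold, while `inject_exploration` guarantees it some exposure in the meantime.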
Q6: Design a recommendation system for a company with 50 million users and 2 million items.
Answer: This is the full ML system design. Here is every architectural decision:
Data infrastructure: collect implicit signals (clicks, watch time, shares, skips) via Kafka. Process interactions into feature tables in BigQuery or Spark. Use a feature store (Feast or Tecton) for low-latency feature serving at query time.
Retrieval stage - two-tower model:
- User tower: user ID embedding (256-dim) + behavioral sequence (last 50 interactions via a 2-layer transformer) + demographic features → 128-dim L2-normalized output.
- Item tower: item ID embedding + content features (category, tags, content encoder output from a pre-trained video/text model) + creator features → 128-dim L2-normalized output. ID dropout at 10% for cold-start.
- Training: in-batch negatives with sampling bias correction, batch size 2048, learnable temperature initialized at 0.1, AdamW with cosine schedule.
- Offline precompute: run all 2M items through the item tower nightly. Build an IVF-PQ FAISS index (n_list=2000, n_subvectors=16). The PQ codes take ~32 MB (2M items × 16 bytes), plus IDs and centroids.
- Online: compute the user embedding in ~3ms and run FAISS top-1000 retrieval in ~5ms, for roughly 10ms total retrieval latency including overhead.
- Retrain weekly on 90 days of interaction data. Incremental embedding updates for high-velocity new items.
Ranking stage - DIN or DLRM:
- Input: 1,000 candidates from retrieval.
- Features: can now include cross-features that span the user-item boundary.
- Architecture: Deep Interest Network with attention over the user's relevant history for each candidate. Or DLRM-style if latency is tighter.
- Multi-task heads: click probability, watch completion quartile, share probability. MMoE architecture to share representations across tasks.
- Target latency: 40ms for 1,000 candidates on GPU. Batch all 1,000 into a single forward pass.
Post-processing: maximal marginal relevance for diversity enforcement. Freshness boost for items uploaded in the last 24h. Safety classifier hard filter before ranking results are returned. Business rules for sponsored content insertion.
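Maximal marginal relevance is a greedy re-ranker: each step picks the candidate that best balances ranking score against similarity to the items already chosen. A minimal sketch, assuming L2-normalized item embeddings and relevance scores on a comparable scale; the trade-off weight `lam` (0.7 by default) is an illustrative assumption.

```python
import numpy as np

def mmr(relevance, item_embs, k, lam=0.7):
    """Greedy maximal marginal relevance re-ranking.

    relevance: (n,) ranking-model scores; item_embs: (n, d) L2-normalized embeddings.
    Each step picks argmax_i  lam * relevance[i] - (1 - lam) * max_{j in selected} sim(i, j).
    """
    sim = item_embs @ item_embs.T
    selected, remaining = [], list(range(len(relevance)))
    while remaining and len(selected) < k:
        if selected:
            redundancy = sim[np.ix_(remaining, selected)].max(axis=1)
        else:
            redundancy = np.zeros(len(remaining))
        gains = lam * relevance[remaining] - (1 - lam) * redundancy
        best = remaining[int(np.argmax(gains))]
        selected.append(best)
        remaining.remove(best)
    return selected
```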
Monitoring: online metrics (CTR, watch time per session, next-day retention). Retrieval health (Recall@1000 on a daily held-out test set). An embedding freshness alert if the FAISS index is older than 48h. A/B testing infrastructure: shadow-deploy model changes, then run a randomized experiment before full rollout.
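Recall@K, the retrieval health metric used throughout, can be computed from a held-out set as sketched below. This follows one common definition (the average per-user fraction of held-out positives found in the top K); the dict-based inputs and the function name are assumptions for illustration.

```python
import numpy as np

def recall_at_k(retrieved, held_out, k=1000):
    """Mean over users of |top-k retrieved & held-out positives| / |held-out positives|.

    retrieved: dict user -> ranked list of item ids; held_out: dict user -> set of item ids.
    Users without held-out positives are skipped.
    """
    vals = [len(set(retrieved.get(u, [])[:k]) & pos) / len(pos)
            for u, pos in held_out.items() if pos]
    return float(np.mean(vals)) if vals else 0.0
```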
Key Takeaways
- The two-tower constraint is everything: by keeping user and item computations independent, item embeddings can be precomputed offline. This is what makes billion-scale retrieval physically possible.
- In-batch negatives are free: with batch size B, every positive gets B - 1 negatives (2,047 at B = 2048) at no extra cost. Use large batches (1024+); the similarity matrix computation is efficient on GPU.
- Sampling bias correction is not optional: one line of code that prevents the model from systematically under-ranking popular items. Always include it.
- Two-tower is for retrieval; joint models are for ranking: they optimize different objectives and have different latency budgets. Build both.
- FAISS IVF-PQ makes this practical at scale: 32x memory compression with ~1–3% recall cost. Without compression, keeping precomputed embeddings in RAM for billion-item catalogs becomes impractical.
- Build cold-start into the item tower from day one: ID dropout during training gives you content-only embeddings for new items at no extra cost.
