Matryoshka Representation Learning (MRL)
Reading time: 22 min | Relevance: AI Engineer, ML Engineer, Research Engineer
The Storage Problem Nobody Expected
You've indexed 10 million documents with 1536-dimensional float32 embeddings. Storage: 10M × 1536 × 4 bytes ≈ 61 GB. That's manageable. Then usage grows and you add more documents. At 100M documents: ≈ 614 GB. At 1B documents: ≈ 6.1 TB. Your vector database cost is now larger than your compute cost.
You could switch to a lower-dimensional model. But lower-dimensional models are generally worse at retrieval. You'd be trading quality for storage. Unless... what if you could use a 64-dimensional embedding for the fast first-pass search across all 1B documents, then rerank the top 1,000 results using the full 1536 dimensions? You'd get the storage efficiency of a tiny embedding and the quality of a large one.
That's exactly what Matryoshka Representation Learning enables. MRL trains a single model that produces nested embeddings: the first 64 dimensions are informative by themselves, and so are the first 128, 256, 512, or 1536, with each larger prefix being more accurate. You can truncate to any size and get retrieval quality that scales smoothly with the number of dimensions used.
The name comes from the Russian nesting dolls: each smaller doll fits inside a larger one, and each is complete in itself. The embedding analogy is precise: the first 64 dimensions are a complete (if coarse) representation, contained within the first 128, which are contained within the full 1536.
Historical Context
November 2022 - Kusupati, Bhatt, Rege, et al. at the University of Washington and Google Research publish "Matryoshka Representation Learning" at NeurIPS 2022. The paper introduces MRL and evaluates it on image classification and retrieval (ResNet and ViT backbones), with extensions to vision-language (ALIGN) and language (BERT) models.
January 2024 - OpenAI releases text-embedding-3-small and text-embedding-3-large, both trained with MRL internally. The API exposes this through the dimensions parameter, and OpenAI's announcement presents native embedding shortening as a headline feature of the new models.
2024 - MRL becomes standard practice for new embedding models. The technique applies to any contrastive pre-training setup and adds minimal training cost.
The Core Idea: Nested Representations
Standard embedding training: train a model to produce a fixed-size vector $z \in \mathbb{R}^{D}$, where the entire vector is jointly optimized for a task (retrieval, classification, etc.).
MRL embedding training: train a model to produce a vector $z \in \mathbb{R}^{D}$, but additionally optimize every prefix $z_{1:m}$ for all $m$ in a set of target dimensions $\mathcal{M}$ (e.g., $\mathcal{M} = \{64, 128, 256, 512, 1536\}$).
Standard embedding:
Full vector: [d₁, d₂, ..., d₁₅₃₆] ← Optimized jointly
Truncation to 64: [d₁, ..., d₆₄] ← NOT optimized - may be garbage
MRL embedding:
Full vector: [d₁, d₂, ..., d₁₅₃₆] ← Optimized
Prefix 512: [d₁, ..., d₅₁₂] ← ALSO optimized
Prefix 256: [d₁, ..., d₂₅₆] ← ALSO optimized
Prefix 128: [d₁, ..., d₁₂₈] ← ALSO optimized
Prefix 64: [d₁, ..., d₆₄] ← ALSO optimized
Each prefix is a complete, independently useful embedding.
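In practice, "using a prefix" means slicing the vector and L2-normalizing it again so that cosine similarity stays meaningful. The minimal NumPy sketch below uses a random vector as a stand-in for a real MRL embedding, and `matryoshka_prefixes` is a hypothetical helper name. The last line checks the nesting property directly: the first 64 values of the 128-dim prefix point in the same direction as the 64-dim prefix.

```python
import numpy as np


def matryoshka_prefixes(embedding: np.ndarray, dims=(64, 128, 256, 512)) -> dict[int, np.ndarray]:
    """Return an L2-normalized prefix of `embedding` for each target dimension."""
    prefixes = {}
    for m in dims:
        prefix = embedding[..., :m]  # keep the first m coordinates
        prefixes[m] = prefix / np.linalg.norm(prefix, axis=-1, keepdims=True)
    return prefixes


rng = np.random.default_rng(0)
full = rng.normal(size=1536)  # stand-in for a real 1536-dim MRL embedding
full /= np.linalg.norm(full)
prefixes = matryoshka_prefixes(full)

# Nesting: the first 64 values of the 128-dim prefix are parallel to the 64-dim prefix
inner = prefixes[128][:64]
print(np.allclose(inner / np.linalg.norm(inner), prefixes[64]))  # True
```

Renormalizing matters because a prefix of a unit vector has norm below 1. Without renormalization, inner-product scores at different truncation levels would not be comparable.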
The MRL Training Objective
The MRL loss is a weighted sum of losses at each target dimensionality:

$$
\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \, \mathcal{L}(z_{1:m})
$$

where:
- $z_{1:m}$ is the first $m$ dimensions of the embedding
- $\mathcal{L}$ is the task loss (InfoNCE for contrastive learning, cross-entropy for classification)
- $c_m \ge 0$ are weights for each dimension level (typically $c_m = 1$ for all $m$ in the paper)
- $\mathcal{M}$ is the set of target dimensions
In practice, each loss term uses the same task - the same contrastive retrieval objective - just applied to different-length embedding prefixes.
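The following spells out, for a single prefix length $m$, the symmetric InfoNCE term that the implementation below computes. The notation assumes a batch of $B$ query and positive-document pairs with embeddings $q_i$ and $d_i$ (these play the role of $z$), in-batch negatives, and temperature $\tau$. The first log term scores each query against all documents (rows of the similarity matrix). The second scores each document against all queries (columns).

$$
s^{(m)}_{ij} = \frac{q_{i,1:m}^{\top}\, d_{j,1:m}}{\lVert q_{i,1:m}\rVert \, \lVert d_{j,1:m}\rVert}
$$

$$
\mathcal{L}(z_{1:m}) = -\frac{1}{2B}\sum_{i=1}^{B}\left[\log\frac{\exp\left(s^{(m)}_{ii}/\tau\right)}{\sum_{j=1}^{B}\exp\left(s^{(m)}_{ij}/\tau\right)} + \log\frac{\exp\left(s^{(m)}_{ii}/\tau\right)}{\sum_{j=1}^{B}\exp\left(s^{(m)}_{ji}/\tau\right)}\right]
$$

$$
\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \, \mathcal{L}(z_{1:m})
$$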
Implementation
```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer


class MRLEmbeddingModel(nn.Module):
    """
    Matryoshka Representation Learning wrapper for any embedding model.
    Adds MRL training capability to a standard embedding model.
    """

    def __init__(
        self,
        base_model: str,
        mrl_dimensions: Optional[list[int]] = None,
        mrl_weights: Optional[list[float]] = None,
    ):
        super().__init__()
        self.encoder = SentenceTransformer(base_model)
        dims = mrl_dimensions or [64, 128, 256, 512, 1024, 1536]
        weights = mrl_weights or [1.0] * len(dims)
        # Sort dimensions and weights together so each weight stays with its level
        pairs = sorted(zip(dims, weights))
        self.mrl_dimensions = [dim for dim, _ in pairs]
        self.mrl_weights = [w for _, w in pairs]
        self.max_dim = self.mrl_dimensions[-1]

        # Verify the token embeddings (which forward() mean-pools) are wide enough
        hidden_dim = self.encoder[0].get_word_embedding_dimension()
        assert hidden_dim >= self.max_dim, \
            f"Base model outputs {hidden_dim} dims, need {self.max_dim}"

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Produce full-dimension embeddings by mean-pooling token embeddings."""
        # The first module of a SentenceTransformer takes a features dict
        features = self.encoder[0](
            {"input_ids": input_ids, "attention_mask": attention_mask}
        )
        token_embeddings = features["token_embeddings"]  # (batch, seq, hidden)

        # Mean pooling over non-padding tokens
        mask = attention_mask.unsqueeze(-1).float()
        embeddings = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return embeddings  # (batch, full_dim)

    def mrl_contrastive_loss(
        self,
        query_embeddings: torch.Tensor,  # (batch, full_dim)
        doc_embeddings: torch.Tensor,    # (batch, full_dim)
        temperature: float = 0.05,
    ) -> tuple[torch.Tensor, dict]:
        """
        Compute MRL loss: weighted sum of symmetric InfoNCE losses at each dimensionality.
        """
        total_loss = 0.0
        loss_per_dim = {}
        labels = torch.arange(query_embeddings.shape[0], device=query_embeddings.device)
        for dim, weight in zip(self.mrl_dimensions, self.mrl_weights):
            # Truncate to this dimensionality level, then renormalize
            q = F.normalize(query_embeddings[:, :dim], dim=-1)
            d = F.normalize(doc_embeddings[:, :dim], dim=-1)

            # InfoNCE with in-batch negatives, in both directions
            sim_matrix = q @ d.T / temperature  # (batch, batch)
            loss_dim = (
                F.cross_entropy(sim_matrix, labels) + F.cross_entropy(sim_matrix.T, labels)
            ) / 2.0

            total_loss = total_loss + weight * loss_dim
            loss_per_dim[f"loss_dim_{dim}"] = loss_dim.item()
        return total_loss, loss_per_dim


def train_mrl_model(
    base_model_name: str,
    train_pairs: list[tuple[str, str]],  # (query, positive_doc) pairs
    mrl_dimensions: Optional[list[int]] = None,
    epochs: int = 1,
    batch_size: int = 32,
    learning_rate: float = 2e-5,
    temperature: float = 0.05,
):
    """Train an embedding model with the MRL objective."""
    from torch.optim import AdamW

    model = MRLEmbeddingModel(base_model_name, mrl_dimensions or [64, 128, 256, 512, 1536])
    device = model.encoder.device
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    model.train()  # enable dropout for training

    # (Simplified training loop - production would use shuffling, gradient
    # accumulation, mixed precision, and proper evaluation)
    for epoch in range(epochs):
        for batch_start in range(0, len(train_pairs), batch_size):
            batch = train_pairs[batch_start:batch_start + batch_size]
            queries = [p[0] for p in batch]
            docs = [p[1] for p in batch]

            # Tokenize and move tensors to the model's device
            q_inputs = {k: v.to(device) for k, v in model.encoder.tokenize(queries).items()}
            d_inputs = {k: v.to(device) for k, v in model.encoder.tokenize(docs).items()}

            q_embs = model(q_inputs["input_ids"], q_inputs["attention_mask"])
            d_embs = model(d_inputs["input_ids"], d_inputs["attention_mask"])
            loss, loss_per_dim = model.mrl_contrastive_loss(q_embs, d_embs, temperature)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model
```
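Running the training code above requires downloading a base model. To check shapes and values without one, the sketch below is a forward-only NumPy re-implementation of the same loss: symmetric InfoNCE at each prefix, weighted and summed. It computes no gradients, so it is a testing aid rather than a training path. The helper names are hypothetical.

```python
import numpy as np


def _l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.clip(np.linalg.norm(x, axis=1, keepdims=True), 1e-12, None)


def _mean_diag_cross_entropy(logits: np.ndarray) -> float:
    """Cross-entropy with labels 0..B-1, i.e. the mean of -log softmax(row i)[i]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))


def mrl_infonce_numpy(q, d, dims, weights=None, temperature=0.05):
    """Forward-only MRL loss: weighted sum over prefix lengths of symmetric InfoNCE."""
    if weights is None:
        weights = [1.0] * len(dims)
    total, per_dim = 0.0, {}
    for m, c in zip(dims, weights):
        # Truncate to the first m dims, renormalize, and score all query-doc pairs
        logits = _l2_normalize(q[:, :m]) @ _l2_normalize(d[:, :m]).T / temperature
        loss_m = 0.5 * (_mean_diag_cross_entropy(logits) + _mean_diag_cross_entropy(logits.T))
        per_dim[m] = loss_m
        total += c * loss_m
    return total, per_dim


# Identical rows give no way to tell pairs apart,
# so each level sits at the chance value log(batch_size)
same = np.ones((4, 8))
total, per_dim = mrl_infonce_numpy(same, same, dims=[4, 8])
print(round(total, 4), round(2 * np.log(4), 4))  # 2.7726 2.7726
```

The chance-level value, log of the batch size per level, is a useful reference when reading the per-dimension losses the PyTorch version logs. A level stuck near it has not learned anything yet.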
The Nesting Property: Why It Works
Why does training on the loss sum produce useful truncated embeddings? Two key mechanisms:
1. Information concentration: MRL training forces the model to pack the most important information into the earliest dimensions. The first 64 dimensions must be informative by themselves, because they appear in every truncated version. Dimensions 65-128 must add the most useful information given the first 64, and so on up the nesting levels. This creates a coarse-to-fine, "sorted by importance" structure at the granularity of the target dimensions.
2. Shared optimization: All dimensionality levels use the same parameters. When you optimize for 64-dimensional retrieval, you're also modifying the parameters that affect 1536-dimensional retrieval. The gradients from multiple dimension levels reinforce each other for the shared early dimensions and allow the later dimensions to specialize.
An intuitive picture of what the model learns (illustrative only; MRL does not enforce this exact split):
Dimensions 1-64: Coarse semantic category
(medical vs legal vs technical)
Dimensions 65-256: Topic within category
(cardiology vs oncology vs neurology)
Dimensions 257-512: Specific subtopic
(treatments vs diagnostics vs mechanisms)
Dimensions 513-1536: Fine-grained distinctions
(specific drugs, specific conditions)
Each prefix captures progressively more nuance.
Adaptive Retrieval: The Key Application
MRL enables a two-stage retrieval strategy that achieves near-full-dimension quality at a fraction of the computational cost:
Stage 1 - Coarse search: Embed all documents with truncated dimensions (e.g., 64 dims). Build an index. At query time, embed the query with 64 dims and retrieve the top-1000 candidates. This is fast because each 64-dim similarity needs 24× fewer multiply-adds than a 1536-dim one.
Stage 2 - Fine reranking: Load full-dimension embeddings for the top-1000 candidates. Rerank using full 1536-dim similarity. The final top-10 results are close to what you'd get from full-dimension search over the entire corpus.
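Stripped of indexing infrastructure, the two stages reduce to a few lines of NumPy. The sketch below is a brute-force version under stated assumptions: it does exact search with no FAISS, and both inputs are L2-normalized full-size MRL embeddings. `two_stage_search` and `normalize_rows` are hypothetical names. A real system would compute and index the coarse document vectors once instead of per query.

```python
import numpy as np


def normalize_rows(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def two_stage_search(query, doc_embs, coarse_dims=64, first_stage_k=1000, final_k=10):
    """
    query: (D,) and doc_embs: (N, D), both L2-normalized full-size MRL embeddings.
    Returns (doc indices, full-dimension scores) of the top final_k documents.
    """
    k1 = min(first_stage_k, len(doc_embs))
    # Stage 1: cosine similarity on renormalized prefixes (precomputed once in practice)
    coarse_scores = normalize_rows(doc_embs[:, :coarse_dims]) @ normalize_rows(query[:coarse_dims])
    shortlist = np.argpartition(-coarse_scores, k1 - 1)[:k1]  # unordered top-k1
    # Stage 2: exact full-dimension scores, only for the shortlist
    full_scores = doc_embs[shortlist] @ query
    order = np.argsort(-full_scores)[:final_k]
    return shortlist[order], full_scores[order]


rng = np.random.default_rng(0)
docs = normalize_rows(rng.normal(size=(5000, 256)))  # stand-in for MRL document embeddings
q = normalize_rows(rng.normal(size=256))
ids, scores = two_stage_search(q, docs, coarse_dims=32, first_stage_k=500, final_k=10)
```

`np.argpartition` selects the shortlist in linear time without fully sorting all N coarse scores. Only the small shortlist gets sorted, at full dimensionality.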
```python
import numpy as np
import faiss


class AdaptiveMRLRetriever:
    """
    Two-stage retrieval using MRL embeddings:
    Stage 1: Coarse search with truncated dimensions
    Stage 2: Fine reranking with full dimensions
    """

    def __init__(
        self,
        model,  # MRL-trained model exposing SentenceTransformer-style .encode()
        coarse_dims: int = 64,
        full_dims: int = 1536,
        first_stage_k: int = 1000,
    ):
        self.model = model
        self.coarse_dims = coarse_dims
        self.full_dims = full_dims
        self.first_stage_k = first_stage_k

        # - Coarse: a 64-dim FAISS index over all documents (fast first pass)
        # - Fine: a plain array of full-dim embeddings, looked up only for reranking
        self.coarse_index = None
        self.full_embeddings = None
        self.doc_ids = []

    def index_documents(self, documents: list[str], doc_ids: list[str] = None):
        """
        Index all documents with both coarse and full embeddings.
        """
        self.doc_ids = doc_ids or [str(i) for i in range(len(documents))]

        print("Encoding documents (full dimensions)...")
        full_embs = self.model.encode(
            documents,
            normalize_embeddings=True,
            batch_size=256,
            show_progress_bar=True,
        ).astype(np.float32)  # (n_docs, full_dims); FAISS expects float32

        # Extract coarse embeddings by truncation, then renormalize
        coarse_embs = full_embs[:, :self.coarse_dims]  # (n_docs, coarse_dims)
        coarse_embs = coarse_embs / np.linalg.norm(coarse_embs, axis=1, keepdims=True)

        # Store full embeddings for reranking
        self.full_embeddings = full_embs  # (n_docs, full_dims)

        # Build coarse FAISS index (exact inner product = cosine on unit vectors)
        self.coarse_index = faiss.IndexFlatIP(self.coarse_dims)
        self.coarse_index.add(coarse_embs)

        print(f"Indexed {len(documents)} documents")
        print(f"Coarse index: {self.coarse_dims} dims, "
              f"size = {coarse_embs.nbytes / 1e9:.2f} GB")
        print(f"Full embeddings: {self.full_dims} dims, "
              f"size = {self.full_embeddings.nbytes / 1e9:.2f} GB")

    def search(self, query: str, final_k: int = 10) -> list[dict]:
        """
        Two-stage retrieval:
        Stage 1: Fast coarse search → top-first_stage_k candidates
        Stage 2: Precise reranking with full embeddings → top-k results
        """
        # Embed query at full dimensions
        query_emb = self.model.encode([query], normalize_embeddings=True)[0].astype(np.float32)

        # Stage 1: Coarse search using the truncated, renormalized query embedding
        query_coarse = query_emb[:self.coarse_dims]
        query_coarse = query_coarse / np.linalg.norm(query_coarse)
        k = min(self.first_stage_k, self.coarse_index.ntotal)  # FAISS pads with -1 otherwise
        _, coarse_indices = self.coarse_index.search(query_coarse.reshape(1, -1), k)
        coarse_indices = coarse_indices[0]  # (k,)

        # Stage 2: Fine reranking with full embeddings
        candidate_full_embs = self.full_embeddings[coarse_indices]  # (k, full_dims)
        full_sims = candidate_full_embs @ query_emb  # (k,) cosine similarities

        # Sort candidates by full-dimension similarity
        rerank_order = np.argsort(-full_sims)
        top_indices = coarse_indices[rerank_order[:final_k]]
        top_sims = full_sims[rerank_order[:final_k]]

        return [
            {
                "doc_id": self.doc_ids[idx],
                "score": float(sim),
                "rank": rank + 1,
            }
            for rank, (idx, sim) in enumerate(zip(top_indices, top_sims))
        ]
```
Quality vs Cost Trade-Off
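The figures below illustrate the typical trade-off for a 1536-dim MRL text model; they are not results from one specific benchmark. The original MRL paper (Kusupati et al. 2022) reported the same qualitative pattern for image retrieval, and MRL-trained text embedding models show a similar curve: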
Retrieval quality (nDCG@10) vs embedding dimensions:
Full 1536-dim search: nDCG@10 = 0.85 (baseline)
MRL 512-dim search: nDCG@10 = 0.83 (-2.4%)
MRL 256-dim search: nDCG@10 = 0.80 (-5.9%)
MRL 128-dim search: nDCG@10 = 0.76 (-10.6%)
MRL 64-dim search: nDCG@10 = 0.70 (-17.6%)
MRL adaptive retrieval (64-dim coarse → 1536-dim rerank):
nDCG@10 = 0.84 (-1.2%) ← Near full quality!
Standard model 64-dim (no MRL):
nDCG@10 = 0.45 (-47%) ← Terrible
FLOP comparison (1B documents, top-10 results, one multiply-add per dimension, exhaustive search):
Full 1536-dim: 1B × 1536 ops = 1.536T ops
MRL adaptive: 1B × 64 ops + 1000 × 1536 ops ≈ 64.0B ops
= 24× reduction in FLOPs, ~1% quality loss
This is the key result: MRL adaptive retrieval achieves near-full-dimension quality at ~24× lower computational cost for large corpora. The rerank term (1000 × 1536 ≈ 1.5M ops) is negligible; the 64-dim scan dominates the cost.
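You can re-run this arithmetic for your own corpus size and settings. The sketch assumes one multiply-add per dimension per comparison and ignores index overhead, which matches exhaustive (flat) search but not ANN indexes. `adaptive_retrieval_flops` is a hypothetical helper.

```python
def adaptive_retrieval_flops(n_docs: int, coarse_dims: int, full_dims: int, first_stage_k: int):
    """Multiply-adds for exhaustive full-dim search vs. coarse scan + full-dim rerank."""
    full_search = n_docs * full_dims
    adaptive = n_docs * coarse_dims + first_stage_k * full_dims
    return full_search, adaptive, full_search / adaptive


full_search, adaptive, speedup = adaptive_retrieval_flops(1_000_000_000, 64, 1536, 1000)
print(f"{full_search:,} vs {adaptive:,} ops -> {speedup:.1f}x fewer")
# 1,536,000,000,000 vs 64,001,536,000 ops -> 24.0x fewer
```

Raising first_stage_k from 1,000 to 10,000 adds only about 15M ops to the 64B-op coarse scan. The speedup is therefore set almost entirely by the ratio full_dims / coarse_dims.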
Using MRL with OpenAI's text-embedding-3
OpenAI's text-embedding-3 models use MRL internally. You access this via the dimensions parameter:
```python
from openai import OpenAI
import numpy as np
import faiss

client = OpenAI()


def embed_with_mrl(
    texts: list[str],
    model: str = "text-embedding-3-small",
    dimensions: int = 512,  # Request a shortened embedding
    batch_size: int = 1000,
) -> np.ndarray:
    """
    Get shortened MRL embeddings from OpenAI.
    With the dimensions parameter, the API shortens internally and returns normalized vectors.
    """
    vectors = []
    # Send texts in batches: the API limits how many inputs one request may contain
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            model=model,
            input=texts[start:start + batch_size],
            dimensions=dimensions,
        )
        vectors.extend(item.embedding for item in response.data)
    return np.array(vectors, dtype=np.float32)


# Relative retrieval quality at different dimensions
# (illustrative figures, not official benchmarks - measure on your own data)
dimension_quality = {
    3072: 1.000,  # text-embedding-3-large, full
    1536: 0.989,  # text-embedding-3-large, half dims
    1024: 0.975,
    512: 0.950,
    256: 0.915,
    64: 0.810,  # Still usable for coarse retrieval!
}

# Storage savings
bytes_per_float32 = 4
n_docs = 10_000_000  # 10M documents
print("Storage for 10M documents:")
for dims, relative_quality in dimension_quality.items():
    storage_gb = n_docs * dims * bytes_per_float32 / 1e9
    print(f"  {dims:>5} dims ({relative_quality:.1%} quality): {storage_gb:.1f} GB")
# Storage for 10M documents:
#    3072 dims (100.0% quality): 122.9 GB
#    1536 dims (98.9% quality): 61.4 GB
#    1024 dims (97.5% quality): 41.0 GB
#     512 dims (95.0% quality): 20.5 GB
#     256 dims (91.5% quality): 10.2 GB
#      64 dims (81.0% quality): 2.6 GB   ← 48× storage reduction vs full


# Implementing adaptive retrieval with OpenAI embeddings
class OpenAIMRLRetriever:
    """Adaptive retrieval using OpenAI's MRL embeddings."""

    def __init__(
        self,
        model: str = "text-embedding-3-large",
        coarse_dims: int = 64,
        full_dims: int = 3072,
        first_stage_k: int = 500,
    ):
        self.model = model
        self.coarse_dims = coarse_dims
        self.full_dims = full_dims
        self.first_stage_k = first_stage_k

    def embed_full(self, texts: list[str]) -> np.ndarray:
        return embed_with_mrl(texts, self.model, self.full_dims)

    def to_coarse(self, full_embs: np.ndarray) -> np.ndarray:
        """Shorten locally: keep the first coarse_dims values, then L2-normalize.
        OpenAI's docs describe this as the manual alternative to the dimensions
        parameter; it avoids paying for a second embedding call."""
        coarse = full_embs[:, :self.coarse_dims]
        return coarse / np.linalg.norm(coarse, axis=1, keepdims=True)

    def build_index(self, documents: list[str]):
        """Embed once at full size; index the coarse prefixes for fast search."""
        self.full_embs = self.embed_full(documents)  # kept for reranking
        self.coarse_index = faiss.IndexFlatIP(self.coarse_dims)
        self.coarse_index.add(self.to_coarse(self.full_embs))

    def search(self, query: str, final_k: int = 10) -> list[int]:
        query_full = self.embed_full([query])  # (1, full_dims)

        # Stage 1: Coarse search
        k = min(self.first_stage_k, self.coarse_index.ntotal)
        _, candidate_indices = self.coarse_index.search(self.to_coarse(query_full), k)
        candidates = candidate_indices[0]

        # Stage 2: Full reranking
        full_sims = self.full_embs[candidates] @ query_full[0]
        rerank_order = np.argsort(-full_sims)
        return [int(candidates[i]) for i in rerank_order[:final_k]]
```
Practical Truncation Without MRL
If you have a standard (non-MRL) embedding model, can you truncate embeddings?
Yes, but with significant quality loss. The first dimensions of a standard model are not specially optimized; they're an arbitrary subset of the learned representation. Truncating a standard 1536-dim model to 256 dims can lose roughly 40-60% of retrieval quality, depending on the model, so measure it on your own benchmark before relying on it.
PCA reduction: A better approach for standard models is PCA - project embeddings to a lower-dimensional space that captures maximum variance. This works better than truncation but is still significantly worse than MRL:
```python
from sklearn.decomposition import PCA
import numpy as np


def pca_reduce_embeddings(
    train_embeddings: np.ndarray,  # Sample of embeddings to fit PCA
    target_dims: int = 256,
) -> PCA:
    """
    Fit PCA on a sample of embeddings and return the fitted PCA.
    Apply to new embeddings using pca.transform(), then re-normalize the
    outputs if you search with cosine similarity or inner product.
    """
    pca = PCA(n_components=target_dims)
    pca.fit(train_embeddings)
    explained_variance = pca.explained_variance_ratio_.sum()
    print(f"PCA {target_dims} dims captures {explained_variance:.1%} of variance")
    return pca


# PCA vs MRL truncation quality (approximate, illustrative):
# PCA 256-dim from 1536: ~80% quality
# MRL truncation 256-dim: ~91% quality
# MRL is significantly better
```
The bottom line: if you need compact embeddings, use a model trained with MRL (text-embedding-3, or a fine-tuned model with MRL loss). PCA reduction is a fallback when you have a deployed non-MRL model you can't retrain.
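The PCA fallback can also be written without scikit-learn, which makes the two steps explicit: fit by SVD of the centered sample, then project and re-normalize. The re-normalization is easy to forget, because projected vectors are no longer unit length. The sketch below uses NumPy only, and `fit_pca` and `pca_project` are hypothetical names. Random data stands in for a representative sample of real embeddings.

```python
import numpy as np


def fit_pca(train_embeddings: np.ndarray, target_dims: int):
    """Fit PCA via SVD of the centered sample.
    Returns (mean, components, explained_variance_ratio)."""
    mean = train_embeddings.mean(axis=0)
    _, singular_values, vt = np.linalg.svd(train_embeddings - mean, full_matrices=False)
    components = vt[:target_dims]  # (target_dims, D), orthonormal rows
    variance = singular_values ** 2
    return mean, components, float(variance[:target_dims].sum() / variance.sum())


def pca_project(embeddings: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project onto the principal directions, then L2-normalize for cosine search."""
    reduced = (embeddings - mean) @ components.T
    return reduced / np.linalg.norm(reduced, axis=1, keepdims=True)


rng = np.random.default_rng(0)
sample = rng.normal(size=(2000, 384))  # stand-in for a sample of real embeddings
mean, components, explained = fit_pca(sample, target_dims=64)
reduced = pca_project(sample[:100], mean, components)
print(reduced.shape, f"{explained:.1%} of variance kept")
```

Fit on a sample drawn from the same distribution as the corpus, and apply the same mean and components to both documents and queries.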
Common Mistakes
:::danger Truncating embeddings from non-MRL models Truncating a standard embedding model's output to fewer dimensions produces poor-quality representations - the early dimensions are not more informative than later ones. Quality loss is typically 40-60% for aggressive truncation. Only truncate embeddings from models explicitly trained with MRL (OpenAI text-embedding-3, or your own MRL-trained model). :::
:::danger Using the same dimensions for first-pass and reranking In adaptive MRL retrieval, the power comes from using very small dimensions (64-128) for the first-pass search over the full corpus, then reranking with full dimensions. Using 512 dimensions for both stages makes the rerank a no-op: it reproduces the first-pass ordering. You pay for a larger first pass than necessary and never get the quality of full-dimension reranking. Set coarse_dims to the smallest dimension that maintains reasonable recall (try 64 and 128, and measure recall on your benchmark). :::
:::warning Ignoring the first_stage_k parameter The number of candidates you retrieve in the first stage determines your recall ceiling for the final results. If you retrieve 100 candidates and the true answer is at position 150 in coarse-dim space (due to information loss), you'll miss it in the final ranking. Set first_stage_k conservatively (1000 is typical). The compute cost of reranking 1000 candidates at full dimensions is trivial compared to the first-stage search cost. :::
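To choose first_stage_k, measure the recall ceiling directly. For a sample of queries, count how many of the exact full-dimension top-10 documents survive into the coarse shortlist. The sketch below uses NumPy, assumes L2-normalized full-size embeddings for queries and documents, and uses the hypothetical helper name `shortlist_recall`.

```python
import numpy as np


def normalize_rows(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def shortlist_recall(queries, doc_embs, coarse_dims, first_stage_k, final_k=10):
    """
    Average fraction of each query's exact full-dimension top-final_k documents
    that survive into its coarse top-first_stage_k shortlist. This is the recall
    ceiling that the full-dimension rerank cannot exceed.
    queries: (Q, D), doc_embs: (N, D), both L2-normalized full-size embeddings.
    """
    full_scores = queries @ doc_embs.T  # (Q, N)
    coarse_scores = normalize_rows(queries[:, :coarse_dims]) @ normalize_rows(doc_embs[:, :coarse_dims]).T
    k1 = min(first_stage_k, doc_embs.shape[0])
    recalls = []
    for full_row, coarse_row in zip(full_scores, coarse_scores):
        truth = set(np.argsort(-full_row)[:final_k].tolist())
        shortlist = set(np.argpartition(-coarse_row, k1 - 1)[:k1].tolist())
        recalls.append(len(truth & shortlist) / final_k)
    return float(np.mean(recalls))


# Random vectors stand in for a held-out sample of real queries and documents
rng = np.random.default_rng(1)
docs = normalize_rows(rng.normal(size=(2000, 128)))
queries = normalize_rows(rng.normal(size=(20, 128)))
for k in (100, 1000, 2000):
    print(k, shortlist_recall(queries, docs, coarse_dims=16, first_stage_k=k))
```

On real MRL embeddings, sweep first_stage_k (and coarse_dims) until this number plateaus near 1.0. The random vectors here have no MRL structure, so their recall values say nothing about real quality.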
:::tip Use MRL in production for hybrid storage strategies Store documents at multiple granularities: 64-dim embeddings on fast (expensive) storage for real-time search, and 1536-dim embeddings on cold (cheap) storage for reranking. At 1B documents, the 64-dim float32 index is ~256 GB (1B × 64 × 4 bytes), versus ~6.1 TB at 1536 dims. At 10M documents it is ~2.6 GB, small enough to fit in GPU memory for ultra-fast search. The full-dim embeddings for the top-1000 candidates can then be fetched from cold storage, typically in 10-100ms depending on the storage backend. :::
Interview Q&A
Q1: What is Matryoshka Representation Learning and why is it useful?
MRL trains embedding models such that every prefix of the embedding vector is informative - the first 64 dimensions form a useful (if coarse) embedding, the first 128 are a better one, and so on up to the full dimensionality. This is achieved by computing the task loss (InfoNCE for retrieval) at multiple dimension levels simultaneously and summing them.
The practical value: you can truncate MRL embeddings to smaller dimensions and still get useful retrieval quality. This enables adaptive retrieval: use 64-dim embeddings for fast search over 1B documents, then rerank the top-1000 results with 1536-dim embeddings. In the illustrative numbers above, the result is ~24× fewer FLOPs with roughly 1% quality loss compared to full-dimension search. It also enables storage optimization: index small dimensions on fast/expensive storage, keep full dimensions on cold storage.
Q2: How does MRL training differ from standard embedding training?
Standard training: compute one loss for the full-dimension embedding. MRL training: compute a separate loss for each target dimension in a set (e.g., {64, 128, 256, 512, 1536}) and sum them with weights. The total loss is $\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \, \mathcal{L}(z_{1:m})$. This forces every prefix to be independently informative, not just the full vector. The training overhead is small: the encoder's forward and backward passes are shared across all levels, and only the loss computation (one batch × batch similarity matrix per level) is repeated.
Q3: What is adaptive retrieval and what performance gains does it achieve?
Adaptive retrieval uses MRL embeddings in two stages. First, embed all documents with truncated dimensions (e.g., 64) and retrieve top-1000 candidates using fast approximate nearest neighbor search. Second, rerank those 1000 candidates using full-dimension embeddings.
In the illustrative numbers above, adaptive retrieval with a 64-dim first pass and 1536-dim reranking keeps about 98-99% of the retrieval quality of full-dimension search over the entire corpus, while requiring only ~4% of the FLOPs (24× fewer operations). Kusupati et al. (2022) reported up to 14× real-world speed-ups for large-scale retrieval on ImageNet-1K and 4K. For 1B documents, this is the difference between impractical full-dimension search and a system that runs in real time.
Q4: How does OpenAI's text-embedding-3 use MRL, and how do you access it?
text-embedding-3-small and text-embedding-3-large are trained with MRL internally, so their output embeddings can be safely truncated to any smaller dimension. The API exposes this via the dimensions parameter: client.embeddings.create(model="text-embedding-3-large", input=texts, dimensions=512) returns normalized 512-dimensional embeddings from a model trained so that shortened prefixes stay informative. You pay the same per-token price regardless of dimensions requested, so dimension reduction is free.
Q5: Can you truncate embeddings from models that weren't trained with MRL?
Yes, but with much worse results. Standard model embeddings don't have the property that early dimensions are more informative than later ones - all dimensions were optimized jointly. Truncating a 1536-dim standard model to 256 dims loses ~40-60% of retrieval quality versus only ~9% for an MRL-trained model truncated to the same size.
Alternatives for standard models: PCA reduction preserves maximum variance and typically outperforms truncation (roughly 80% quality at 256 dims from 1536), but still falls well short of MRL. If you need compact embeddings, use an MRL-trained model: the extra training cost of MRL is low (the encoder pass is shared and only the loss is repeated per level), and the benefits are large.
Summary
Matryoshka Representation Learning trains embedding models with nested representations: any prefix of dimensions is independently informative.
Key facts:
- MRL training objective: Weighted sum of contrastive losses at each target dimensionality, $\mathcal{L}_{\text{MRL}} = \sum_{m \in \mathcal{M}} c_m \, \mathcal{L}(z_{1:m})$
- Adaptive retrieval: 64-dim first pass over all documents, 1536-dim reranking of top candidates. Achieves ~99% quality at ~4% of the compute cost.
- OpenAI text-embedding-3: Uses MRL internally; access via the dimensions API parameter at no additional cost
- Storage impact: 64-dim MRL embeddings use 24× less storage than 1536-dim, with only ~17% quality loss on coarse search
- Non-MRL truncation: Produces 40-60% quality loss - only truncate MRL-trained models
MRL is the right approach whenever you need to trade off dimensionality for cost, speed, or storage - which in production systems is almost always.
