Knowledge Distillation: Training Smaller Models to Think Like Larger Ones
The Impossible Requirements Document
The product brief lands on a Monday morning. Your company has been running GPT-4-class inference at $0.03 per thousand tokens for an enterprise customer who uses the system to triage incoming support tickets - about 50,000 per day. The customer's renewal is coming up. The account executive comes to you: "They're asking for a 90% cost reduction and a commitment to 10x lower latency. They'll walk otherwise."
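You do the math. At 50,000 tickets per day, $1,500 per day implies roughly 1,000 tokens per ticket. A 90% cut means $150 per day, or at most $0.003 per thousand tokens. Target $0.001 per thousand tokens for headroom, and you need a model that costs 30x less than GPT-4-class inference. That means a model roughly 10-30x smaller. But it needs to work well enough on your specific task - support ticket triage - that the customer doesn't notice the quality difference.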
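The arithmetic above fits in a quick script. The 1,000 tokens per ticket figure is an assumption. It is what $1,500 per day at $0.03 per thousand tokens and 50,000 tickets per day implies, so substitute your own traffic numbers.

```python
# Back-of-envelope serving cost for the ticket-triage scenario.
# Assumption: ~1,000 tokens per ticket, implied by $1,500/day at
# $0.03 per 1K tokens and 50,000 tickets/day.

TICKETS_PER_DAY = 50_000
TOKENS_PER_TICKET = 1_000          # assumed, see above
TEACHER_PRICE_PER_1K = 0.03        # $ per 1K tokens (GPT-4-class, from the scenario)
TARGET_REDUCTION = 0.90            # customer wants a 90% cost cut


def daily_cost(price_per_1k: float) -> float:
    """Dollars per day at a given price per 1K tokens."""
    return TICKETS_PER_DAY * TOKENS_PER_TICKET / 1_000 * price_per_1k


current = daily_cost(TEACHER_PRICE_PER_1K)
target = current * (1 - TARGET_REDUCTION)
# Highest per-1K-token price that still meets the target daily budget
max_price_per_1k = target / (TICKETS_PER_DAY * TOKENS_PER_TICKET / 1_000)

print(f"current: ${current:,.0f}/day, target: ${target:,.0f}/day")
print(f"max student price: ${max_price_per_1k:.4f} per 1K tokens")
print(f"at $0.001/1K: {TEACHER_PRICE_PER_1K / 0.001:.0f}x cheaper, "
      f"${daily_cost(0.001):,.0f}/day")
```

At $0.001 per thousand tokens the bill drops to about $50 per day. That clears the $150 target with room to spare.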
This is exactly the problem knowledge distillation was designed to solve. You have a large, expensive, capable teacher. You have a target task with a specific distribution. You have a quality bar to hit. Distillation trains a small, cheap student to match the teacher's quality on your specific task - often achieving 90-97% of the teacher's performance with 10-30x fewer parameters, at a fraction of the inference cost.
Three weeks later, a 2B-parameter fine-tuned and distilled model is in production. Task-specific accuracy: 94% of GPT-4-class performance on ticket triage. Inference cost: $0.0008 per thousand tokens. Latency: 85ms versus 4.2 seconds. The customer renews. The distillation project delivers more ROI than any other engineering initiative that quarter. This is the real-world case for distillation - not academic benchmark chasing, but a concrete business survival tool.
Why Training from Scratch Leaves Knowledge on the Table
Training a small model from scratch on labeled data gives it only hard labels: "ticket category is 'billing.'" A hard label carries no information about which other categories are plausible or how uncertain the example is.
A teacher model's output distribution contains far more information:
```text
Hard label (one-hot encoding):
  billing:   1.0
  shipping:  0.0
  technical: 0.0
  account:   0.0
  refund:    0.0

Teacher soft distribution at T=4:
  billing:   0.62  ← clearly the right answer
  refund:    0.21  ← related - ticket mentions a charge dispute
  account:   0.09  ← plausible - might be an account billing issue
  technical: 0.05  ← less likely but possible
  shipping:  0.03  ← probably not
```
What the soft distribution teaches the student:
1. billing and refund are semantically similar
2. account tickets overlap with billing in some dimensions
3. The billing/refund decision boundary is nearby (important for generalization)
4. shipping is far from billing in the semantic space
5. The teacher is 62% confident on this example - not ambiguous, but not certain
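One way to see what the soft distribution adds: for softmax cross-entropy, the gradient with respect to the student's logits is $q - y$. Here $q$ is the student's predicted distribution and $y$ is the target. The sketch below uses the ticket example above and a hypothetical untrained student whose logits are all zero, so its prediction is uniform. It compares that gradient under the one-hot target and under the teacher's soft target. This is an illustration, not a training loop.

```python
import math

CLASSES = ["billing", "refund", "account", "technical", "shipping"]
HARD = [1.0, 0.0, 0.0, 0.0, 0.0]           # one-hot: billing
SOFT = [0.62, 0.21, 0.09, 0.05, 0.03]      # teacher at T=4, from the example above


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ce_grad(student_logits, target):
    """Gradient of cross-entropy(target, softmax(logits)) w.r.t. the logits: q - y."""
    q = softmax(student_logits)
    return [qk - yk for qk, yk in zip(q, target)]


student_logits = [0.0] * len(CLASSES)   # hypothetical untrained student: q = 0.2 everywhere
for name, target in [("hard", HARD), ("soft", SOFT)]:
    grad = ce_grad(student_logits, target)
    print(name, {c: round(g, 2) for c, g in zip(CLASSES, grad)})
```

With the hard target, every wrong class gets the same gradient (+0.2), so all of them are pushed down equally. With the soft target, a gradient-descent step pushes shipping down hardest and barely moves refund. That is the similarity structure listed above, delivered as a training signal.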
This additional information - called "dark knowledge" by Hinton et al. - is encoded in the teacher's non-maximum probabilities. A student trained on soft labels learns the similarity structure of the problem, not just the decision boundary. This is why distillation typically outperforms training from scratch at the same parameter count: the student receives a richer supervision signal per training example.
The effect is substantial in landmark systems. DistilBERT retained 97% of BERT's language-understanding performance with 40% fewer parameters. TinyBERT (4 layers) reached 96.8% of its BERT-base teacher's GLUE performance while being 7.5x smaller. These are not marginal improvements - distillation is among the most efficient known paths to high-quality small models when you have the training budget.
Temperature Scaling: Controlling the Richness of Soft Labels
The softmax temperature $T$ controls how peaked or flat the teacher's output distribution is:
$$p_i^{(T)} = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
At $T = 1$ (standard inference), the distribution is peaked around the correct class. At higher temperatures the distribution softens, spreading probability to similar classes. This softer distribution carries more information about the similarity structure of the problem.
The sweet spot for distillation is typically $T$ between 3 and 6. Too low, and the soft labels are nearly equivalent to hard labels. Too high, and the distribution becomes nearly uniform, washing out the similarity signal. Always tune $T$ as a hyperparameter - it is the most sensitive configuration choice in distillation.
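Here is a small sketch of temperature scaling on hypothetical teacher logits for the five ticket classes. The logit values are made up for illustration. As $T$ grows, the top class loses mass to the others, and the entropy of the distribution rises toward the uniform maximum, $\log 5$.

```python
import math

CLASSES = ["billing", "refund", "account", "technical", "shipping"]
teacher_logits = [6.0, 4.0, 2.5, 1.0, 0.0]   # hypothetical values


def softmax_with_temperature(logits, T):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    return -sum(pk * math.log(pk) for pk in p if pk > 0)


for T in [1.0, 4.0, 20.0]:
    p = softmax_with_temperature(teacher_logits, T)
    probs = ", ".join(f"{c}={pk:.2f}" for c, pk in zip(CLASSES, p))
    print(f"T={T:>4}: {probs}  entropy={entropy(p):.2f} (max {math.log(5):.2f})")
```

At T=1 nearly all the mass sits on billing. At T=4 the ranking is preserved, but refund and account receive visible probability. At T=20 the distribution is close to uniform and the ranking is hard to read, which is the "too high" failure mode described above.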
The Distillation Loss: KL Divergence + Task Loss
The complete distillation loss combines two terms:
1. KL divergence loss - minimize the divergence between the student's and teacher's soft distributions, $p_S^{(T)}$ and $p_T^{(T)}$, at temperature $T$:
$$\mathcal{L}_{\text{KL}} = T^2 \cdot \text{KL}\!\left(p_T^{(T)} \,\big\|\, p_S^{(T)}\right) = T^2 \sum_k p^{(T)}_{T,k} \log \frac{p^{(T)}_{T,k}}{p^{(T)}_{S,k}}$$
2. Task loss - standard cross-entropy against ground truth labels:
$$\mathcal{L}_{\text{task}} = -\sum_k y_k \log p^{(1)}_{S,k} = \text{CrossEntropy}(z_S, y)$$
3. Combined loss - weighted average:
$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{KL}} + (1 - \alpha) \, \mathcal{L}_{\text{task}}$$
The $T^2$ multiplier on the KL loss is critical. Dividing the logits by $T$ shrinks the differences between them. For large $T$, the gradient of the soft-target term with respect to the student logits therefore scales as $1/T^2$. Without the correction, the distillation term produces gradients roughly $T^2$ times smaller than the task loss, so at high temperature the distillation signal becomes negligible. Multiplying by $T^2$ restores the gradient magnitude, so $\alpha$ remains the controlling tradeoff between the two objectives.
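A quick numerical check of the $T^2$ argument uses PyTorch autograd on random logits. The logits and the class count are arbitrary choices for illustration. The sketch measures the norm of the gradient of the soft-target KL term with respect to the student logits, with and without the $T^2$ factor.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N_CLASSES = 10
teacher_logits = torch.randn(N_CLASSES)   # arbitrary illustrative logits
student_init = torch.randn(N_CLASSES)


def kl_grad_norm(T: float, scale_by_t2: bool) -> float:
    """Norm of d KL(p_teacher^T || p_student^T) / d student_logits."""
    z = student_init.clone().requires_grad_(True)
    log_q = F.log_softmax(z / T, dim=-1)
    p = F.softmax(teacher_logits / T, dim=-1)
    loss = F.kl_div(log_q, p, reduction="sum")   # sum_k p_k (log p_k - log q_k)
    if scale_by_t2:
        loss = loss * T ** 2
    loss.backward()
    return z.grad.norm().item()


for T in [1.0, 2.0, 4.0, 8.0, 16.0]:
    print(f"T={T}: unscaled={kl_grad_norm(T, False):.6f}  "
          f"T^2-scaled={kl_grad_norm(T, True):.6f}")
```

Once $T$ is large relative to the spread of the logits, the unscaled norms shrink roughly fourfold each time $T$ doubles. The $T^2$-scaled norms stay roughly constant, which is the regime the $1/T^2$ argument describes.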
The Distillation Loss Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class DistillationConfig:
    """Configuration for knowledge distillation training."""
    temperature: float = 4.0            # Softmax temperature for soft labels
    alpha: float = 0.7                  # Weight on KL loss vs task CE loss (0.7 = 70% KL)
    intermediate_layers: bool = True    # Match intermediate hidden states (TinyBERT-style)
    attention_matching: bool = True     # Match attention patterns
    layer_mapping: str = "uniform"      # "uniform" = evenly spaced, "last" = last N layers
    hidden_loss_weight: float = 0.1     # Weight for intermediate representation loss
    attention_loss_weight: float = 0.1  # Weight for attention matrix loss


class DistillationLoss(nn.Module):
    """
    Knowledge distillation loss combining:
    1. KL divergence between student and teacher soft distributions (primary signal)
    2. Cross-entropy against hard labels (limits drift toward teacher errors)
    3. Optional: MSE between intermediate hidden states (TinyBERT-style)
    4. Optional: MSE between attention matrices (attention transfer)

    Configuration:
        alpha=0.7: 70% distillation + 30% task loss - balanced default
        alpha=1.0: pure distillation - only if teacher labels are very high quality
        alpha=0.0: pure task loss - equivalent to no distillation
    """

    def __init__(self, config: DistillationConfig):
        super().__init__()
        self.config = config
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.mse_loss = nn.MSELoss()

    def forward(
        self,
        student_logits: torch.Tensor,  # (batch, n_classes) or (batch, seq, vocab)
        teacher_logits: torch.Tensor,  # Same shape as student_logits
        hard_labels: torch.Tensor,     # (batch,) or (batch, seq); -100 marks ignored positions
        student_hiddens: Optional[List[torch.Tensor]] = None,
        teacher_hiddens: Optional[List[torch.Tensor]] = None,
        student_attentions: Optional[List[torch.Tensor]] = None,
        teacher_attentions: Optional[List[torch.Tensor]] = None,
        hidden_projectors: Optional[List[nn.Linear]] = None,
    ) -> Dict:
        T = self.config.temperature
        alpha = self.config.alpha

        # Flatten any sequence dimension (LM case) so both losses see
        # (n_positions, n_classes) logits and (n_positions,) labels.
        n_classes = student_logits.size(-1)
        s_logits_flat = student_logits.reshape(-1, n_classes)
        t_logits_flat = teacher_logits.reshape(-1, n_classes).detach()
        labels_flat = hard_labels.reshape(-1)

        # --- 1. Task loss: standard cross-entropy on hard labels ---
        task_loss = self.ce_loss(s_logits_flat, labels_flat)

        # --- 2. Distillation loss: KL divergence at temperature T ---
        # Drop ignored positions (e.g. padding) so they don't dilute the KL term
        valid = labels_flat != -100
        soft_student = F.log_softmax(s_logits_flat[valid] / T, dim=-1)
        soft_teacher = F.softmax(t_logits_flat[valid] / T, dim=-1)
        # F.kl_div(log_input, target) computes sum(target * (log(target) - log_input)),
        # i.e. KL(teacher || student): the student is pushed toward the teacher.
        kl_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction="batchmean",
        ) * (T ** 2)  # T^2 restores gradient magnitude across temperatures

        # --- 3. Combined output loss ---
        total_loss = alpha * kl_loss + (1.0 - alpha) * task_loss
        losses = {
            "task_loss": task_loss.item(),
            "kl_loss": kl_loss.item(),
        }

        # --- 4. Intermediate representation matching (TinyBERT-style) ---
        if (self.config.intermediate_layers
                and student_hiddens is not None
                and teacher_hiddens is not None):
            hidden_loss = self._compute_hidden_loss(
                student_hiddens, teacher_hiddens, hidden_projectors
            )
            total_loss = total_loss + self.config.hidden_loss_weight * hidden_loss
            losses["hidden_loss"] = hidden_loss.item()

        # --- 5. Attention matrix matching ---
        if (self.config.attention_matching
                and student_attentions is not None
                and teacher_attentions is not None):
            attn_loss = self._compute_attention_loss(
                student_attentions, teacher_attentions
            )
            total_loss = total_loss + self.config.attention_loss_weight * attn_loss
            losses["attention_loss"] = attn_loss.item()

        losses["total_loss"] = total_loss  # Tensor: call .backward() on this
        return losses

    def _compute_hidden_loss(
        self,
        student_hiddens: List[torch.Tensor],
        teacher_hiddens: List[torch.Tensor],
        projectors: Optional[List[nn.Linear]],
    ) -> torch.Tensor:
        """
        Compute MSE loss between student and teacher hidden states.

        Student and teacher typically have different hidden dimensions
        (e.g., student: 768, teacher: 4096). A linear projector maps
        student hidden states to the teacher's dimensionality.

        Layer mapping (uniform): map student layers evenly onto teacher layers.
            Student layer 0 → Teacher layer 0
            Student layer k → Teacher layer k * (n_teacher / n_student)
        """
        n_student = len(student_hiddens)
        n_teacher = len(teacher_hiddens)
        if n_student == 0 or n_teacher == 0:
            return torch.tensor(0.0)

        total_hidden_loss = torch.tensor(0.0, device=student_hiddens[0].device)
        matched_layers = 0
        for i, s_hidden in enumerate(student_hiddens):
            # Map student layer i to the corresponding teacher layer
            if self.config.layer_mapping == "uniform":
                teacher_idx = int(i * n_teacher / n_student)
            else:  # "last" - map to the last n_student teacher layers
                teacher_idx = n_teacher - n_student + i
            teacher_idx = max(0, min(teacher_idx, n_teacher - 1))
            t_hidden = teacher_hiddens[teacher_idx].detach()  # No gradient through teacher

            # Project student hidden state if dimensions differ
            if projectors is not None and i < len(projectors):
                s_projected = projectors[i](s_hidden)
            elif s_hidden.shape[-1] == t_hidden.shape[-1]:
                s_projected = s_hidden
            else:
                # Skip if there is no projector and the dimensions mismatch
                continue

            # MSE over all token positions and the batch
            total_hidden_loss = total_hidden_loss + self.mse_loss(s_projected, t_hidden)
            matched_layers += 1

        # Average over the layers that were actually compared
        return total_hidden_loss / max(matched_layers, 1)

    def _compute_attention_loss(
        self,
        student_attentions: List[torch.Tensor],
        teacher_attentions: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Compute MSE loss between student and teacher attention patterns.

        Attention patterns capture structural information about which
        tokens attend to which - syntax, coreference, semantic relations.
        Matching them helps the student learn the same inductive biases.

        Note: Student and teacher may have different numbers of attention heads.
        If so, attention is averaged over heads before comparing.
        """
        n_student = len(student_attentions)
        n_teacher = len(teacher_attentions)
        if n_student == 0 or n_teacher == 0:
            return torch.tensor(0.0)

        total_attn_loss = torch.tensor(0.0, device=student_attentions[0].device)
        matched_layers = 0
        for i, s_attn in enumerate(student_attentions):
            teacher_idx = int(i * n_teacher / n_student)
            teacher_idx = min(teacher_idx, n_teacher - 1)
            t_attn = teacher_attentions[teacher_idx].detach()

            # Attention shape: (batch, n_heads, seq, seq)
            if s_attn.shape[1] != t_attn.shape[1]:
                # Head counts differ: average over heads to get (batch, seq, seq)
                s_attn = s_attn.mean(dim=1)
                t_attn = t_attn.mean(dim=1)

            total_attn_loss = total_attn_loss + self.mse_loss(s_attn, t_attn)
            matched_layers += 1

        return total_attn_loss / max(matched_layers, 1)
```
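`F.kl_div` is easy to call backwards. Its first argument is the student's log-probabilities and its second is the teacher's probabilities. Below is a small check, on random logits chosen only for illustration, that the call used in `DistillationLoss` computes $\text{KL}(p_T \,\|\, p_S)$ rather than the reverse. It also confirms that `reduction="batchmean"` divides the summed KL by the batch size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4.0
student_logits = torch.randn(8, 5)   # (batch, n_classes), arbitrary values
teacher_logits = torch.randn(8, 5)

log_q = F.log_softmax(student_logits / T, dim=-1)   # student log-probs
p = F.softmax(teacher_logits / T, dim=-1)           # teacher probs

# What the distillation loss computes (before the T^2 factor)
kl_torch = F.kl_div(log_q, p, reduction="batchmean")

# Manual forward KL(teacher || student), averaged over the batch
kl_forward = (p * (p.log() - log_q)).sum(dim=-1).mean()
# Manual reverse KL(student || teacher), for contrast
q = log_q.exp()
kl_reverse = (q * (log_q - p.log())).sum(dim=-1).mean()

print(f"F.kl_div: {kl_torch.item():.6f}  forward: {kl_forward.item():.6f}  "
      f"reverse: {kl_reverse.item():.6f}")
```

The `F.kl_div` value matches the manual forward KL, not the reverse one. Swapping the two arguments would silently train toward a different objective.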
Distillation Variants: Response, Feature, and API-Based
API-Based Distillation: Using GPT-4 as a Teacher
When the teacher is a proprietary API (GPT-4, Claude, Gemini), you cannot access its full output logits. The student therefore trains on the teacher's generated outputs as hard targets. The process:
- Collect task inputs: gather real or synthetic inputs from your deployment domain
- Generate teacher outputs: run each input through the teacher API, collecting both the response and (if available) confidence or ranking information
- Fine-tune a small student: train a 1-3B model on the teacher-generated (input, output) pairs
This is sometimes called "data distillation" or "data augmentation via teacher" - you are using the teacher's knowledge to create a high-quality fine-tuning dataset, not to provide soft labels during training. The resulting technique is simpler than true distillation but often effective for task-specific deployment.
```python
import asyncio
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import anthropic
import openai


@dataclass
class TaskExample:
    """A single input-output pair for distillation training."""
    input_text: str
    teacher_output: str
    metadata: Optional[Dict] = None


class APIDistillationDataCollector:
    """
    Collect teacher outputs from proprietary APIs for distillation.

    This implements the "API distillation" pattern:
    1. Sample diverse inputs from your task distribution
    2. Query the teacher API for high-quality outputs
    3. Use these (input, output) pairs to fine-tune a small student

    The quality of the final distilled model depends heavily on:
    - Input diversity (cover the full task distribution)
    - Quality of teacher outputs (teacher must be good at your task)
    - Quantity of examples (more is better; roughly 1,000+ for simple tasks)
    """

    def __init__(
        self,
        teacher_provider: str = "openai",  # "openai" or "anthropic"
        teacher_model: str = "gpt-4o",
        system_prompt: str = "",
        temperature: float = 0.0,
        max_tokens: int = 1024,
    ):
        self.teacher_provider = teacher_provider
        self.teacher_model = teacher_model
        self.system_prompt = system_prompt
        self.temperature = temperature
        self.max_tokens = max_tokens
        # Both clients read their API keys from environment variables by default
        if teacher_provider == "openai":
            self.client = openai.AsyncOpenAI()
        elif teacher_provider == "anthropic":
            self.client = anthropic.AsyncAnthropic()
        else:
            raise ValueError(f"Unsupported teacher_provider: {teacher_provider!r}")

    async def query_teacher_single(self, input_text: str) -> str:
        """Query the teacher API for a single input. Returns "" on failure."""
        try:
            if self.teacher_provider == "openai":
                messages = []
                if self.system_prompt:
                    messages.append({"role": "system", "content": self.system_prompt})
                messages.append({"role": "user", "content": input_text})
                response = await self.client.chat.completions.create(
                    model=self.teacher_model,
                    messages=messages,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content or ""
            else:  # anthropic
                # Only send a system prompt when one is configured
                extra = {"system": self.system_prompt} if self.system_prompt else {}
                response = await self.client.messages.create(
                    model=self.teacher_model,
                    messages=[{"role": "user", "content": input_text}],
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    **extra,
                )
                return response.content[0].text
        except Exception as e:
            print(f"API error for input '{input_text[:50]}...': {e}")
            return ""

    async def collect_distillation_data(
        self,
        input_texts: List[str],
        output_path: str,
        batch_size: int = 10,
        max_concurrent: int = 5,
    ) -> List[TaskExample]:
        """
        Collect teacher outputs for all inputs, with a cap on concurrent requests.

        Args:
            input_texts: List of task-representative inputs
            output_path: JSONL file path to save collected examples (appended to)
            batch_size: Examples to process per batch (for checkpointing)
            max_concurrent: Max concurrent API calls (respect rate limits)
        """
        semaphore = asyncio.Semaphore(max_concurrent)

        async def bounded_query(text: str) -> Tuple[str, str]:
            async with semaphore:
                output = await self.query_teacher_single(text)
                return text, output

        all_examples = []
        total = len(input_texts)
        print(f"Collecting teacher outputs for {total} inputs...")
        print(f"  Teacher: {self.teacher_provider}/{self.teacher_model}")
        print(f"  Max concurrent: {max_concurrent}")

        for batch_start in range(0, total, batch_size):
            batch_end = min(batch_start + batch_size, total)
            batch = input_texts[batch_start:batch_end]
            tasks = [bounded_query(text) for text in batch]
            results = await asyncio.gather(*tasks)

            batch_examples = []
            for inp, out in results:
                if out:  # Skip failed queries
                    batch_examples.append(TaskExample(input_text=inp, teacher_output=out))
            all_examples.extend(batch_examples)

            # Checkpoint: append after each batch in case of interruption
            with open(output_path, "a") as f:
                for ex in batch_examples:
                    record = {"input": ex.input_text, "output": ex.teacher_output}
                    f.write(json.dumps(record) + "\n")

            success_rate = len(batch_examples) / len(batch) * 100
            print(
                f"  Batch {batch_start // batch_size + 1}: "
                f"{len(batch_examples)}/{len(batch)} succeeded ({success_rate:.0f}%)"
            )

        print(f"\nCollection complete: {len(all_examples)}/{total} examples")
        print(f"  Saved to: {output_path}")
        return all_examples


def estimate_api_distillation_cost(
    n_examples: int,
    avg_input_tokens: int = 200,
    avg_output_tokens: int = 300,
    teacher_model: str = "gpt-4o",
) -> dict:
    """
    Estimate the cost of collecting teacher data from an API.

    Prices are approximate USD per 1K tokens; verify on the provider's pricing page:
    - GPT-4o: $0.005 input, $0.015 output (launch pricing; later snapshots are cheaper)
    - GPT-4o-mini: $0.00015 input, $0.0006 output
    - Claude 3.5 Sonnet: $0.003 input, $0.015 output
    - Claude 3 Haiku: $0.00025 input, $0.00125 output
    Unknown model names fall back to a placeholder of $0.01 / $0.03.
    """
    pricing = {
        "gpt-4o": (0.005, 0.015),  # (input, output) per 1K tokens
        "gpt-4o-mini": (0.00015, 0.0006),
        "claude-3-5-sonnet": (0.003, 0.015),
        "claude-3-haiku": (0.00025, 0.00125),
    }
    inp_price, out_price = pricing.get(teacher_model, (0.01, 0.03))
    total_input_cost = n_examples * avg_input_tokens * inp_price / 1000
    total_output_cost = n_examples * avg_output_tokens * out_price / 1000
    total_cost = total_input_cost + total_output_cost
    return {
        "n_examples": n_examples,
        "teacher_model": teacher_model,
        "input_tokens": n_examples * avg_input_tokens,
        "output_tokens": n_examples * avg_output_tokens,
        "input_cost_usd": round(total_input_cost, 2),
        "output_cost_usd": round(total_output_cost, 2),
        "total_cost_usd": round(total_cost, 2),
        "cost_per_example_cents": round(total_cost / max(n_examples, 1) * 100, 3),
    }
```
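The collector writes one `{"input": ..., "output": ...}` record per line. Many supervised fine-tuning tools instead expect chat-style records with a `messages` list. The exact schema depends on the trainer you use, so treat the format below as an assumption to adapt. This is a minimal converter that also drops empty and duplicate pairs; `to_chat_records` and `convert_file` are hypothetical helper names.

```python
import json
from typing import Iterable, List, Optional


def to_chat_records(lines: Iterable[str], system_prompt: Optional[str] = None) -> List[dict]:
    """Convert collector JSONL lines into chat-format SFT records.

    Assumed output schema: {"messages": [{"role": ..., "content": ...}, ...]}.
    Skips blank lines, records with an empty output, and exact duplicate pairs.
    """
    records, seen = [], set()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        row = json.loads(line)
        inp, out = row.get("input", ""), row.get("output", "")
        if not out or (inp, out) in seen:
            continue
        seen.add((inp, out))
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": inp})
        messages.append({"role": "assistant", "content": out})
        records.append({"messages": messages})
    return records


def convert_file(src_path: str, dst_path: str, system_prompt: Optional[str] = None) -> int:
    """Read collector output from src_path, write chat records to dst_path."""
    with open(src_path) as f:
        records = to_chat_records(f, system_prompt)
    with open(dst_path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
    return len(records)
```

Deduplication matters here because the collector opens its output file in append mode. Re-running an interrupted job can therefore write the same pairs twice.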
Full Training Pipeline: Student Fine-Tuning with Distillation
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_cosine_schedule_with_warmup,
)
from typing import Dict, List, Optional

# DistillationConfig and DistillationLoss come from the loss implementation above.


class DistillationDataset(Dataset):
    """
    Dataset for response distillation training.

    Returns input_ids, attention_mask, hard labels, and teacher logits per example.
    Teacher logits are pre-computed and cached - don't recompute them every epoch.
    Running the teacher inside the training loop adds its memory footprint and
    forward-pass compute to every step.
    """

    def __init__(
        self,
        texts: List[str],
        labels: List[int],
        teacher_logits: torch.Tensor,  # Pre-computed teacher outputs, (n_examples, n_classes)
        tokenizer,
        max_length: int = 512,
    ):
        self.texts = texts
        self.labels = labels
        self.teacher_logits = teacher_logits
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> dict:
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
            "teacher_logits": self.teacher_logits[idx],
        }


def precompute_teacher_logits(
    teacher_model_name: str,
    texts: List[str],
    tokenizer,  # must be the teacher's tokenizer
    batch_size: int = 16,
    device: str = "cuda",
    max_length: int = 512,
) -> torch.Tensor:
    """
    Pre-compute teacher logits for all training examples.

    Pre-computing is much more efficient than computing teacher logits
    during each training step:
    - Computes once instead of once per epoch (n_epochs-fold less teacher compute)
    - Teacher runs in inference mode - no gradient computation overhead
    - Results can be cached to disk for reuse across experiments

    Returns:
        teacher_logits: Tensor of shape (n_examples, n_classes)
    """
    print(f"Pre-computing teacher logits with {teacher_model_name}...")
    teacher = AutoModelForSequenceClassification.from_pretrained(
        teacher_model_name,
        # Half precision for GPU inference; full precision on CPU
        torch_dtype=torch.float16 if device.startswith("cuda") else torch.float32,
    ).to(device)
    teacher.eval()

    all_logits = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        encoding = tokenizer(
            batch_texts,
            max_length=max_length,
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        encoding = {k: v.to(device) for k, v in encoding.items()}
        with torch.no_grad():
            outputs = teacher(**encoding)
        all_logits.append(outputs.logits.float().cpu())
        if (i // batch_size) % 20 == 0:
            print(f"  {i}/{len(texts)} examples processed")

    del teacher
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    all_logits_tensor = torch.cat(all_logits, dim=0)
    print(f"  Teacher logits computed: shape {tuple(all_logits_tensor.shape)}")
    return all_logits_tensor


def train_student_with_distillation(
    student_model_name: str,
    train_dataset: DistillationDataset,
    val_dataset: DistillationDataset,
    distillation_config: DistillationConfig,
    n_epochs: int = 5,
    learning_rate: float = 2e-5,
    batch_size: int = 32,
    warmup_ratio: float = 0.1,
    device: str = "cuda",
    output_dir: str = "./distilled_model",
    use_wandb: bool = False,
) -> nn.Module:
    """
    Full distillation training loop with validation and checkpointing.

    Training strategy:
    1. Load a small student model (initialized from pretrained weights)
    2. Train with combined distillation + task loss
    3. Validate every epoch and save the best checkpoint
    4. Reload and return the best checkpoint

    Note: Initialize the student from a pretrained model (e.g., DistilBERT,
    which was itself initialized from BERT-base layers), not from scratch.
    Pretrained initialization dramatically improves final accuracy.
    """
    n_classes = train_dataset.teacher_logits.shape[1]
    student = AutoModelForSequenceClassification.from_pretrained(
        student_model_name,
        num_labels=n_classes,
    ).to(device)

    # Response-only distillation: only teacher logits are cached, so no hidden
    # states or attentions are passed and the intermediate losses stay inactive.
    distillation_loss_fn = DistillationLoss(distillation_config)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4
    )

    # Optimizer: AdamW with weight decay on everything except biases and LayerNorm weights
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in student.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in student.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    total_steps = len(train_loader) * n_epochs
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    if use_wandb:
        import wandb  # optional dependency, only needed for logging
        wandb.init(project="distillation", config={
            "student_model": student_model_name,
            "temperature": distillation_config.temperature,
            "alpha": distillation_config.alpha,
            "n_epochs": n_epochs,
            "learning_rate": learning_rate,
        })

    best_val_acc = -1.0  # ensures the first epoch is always checkpointed
    best_model_path = f"{output_dir}/best_checkpoint"

    for epoch in range(n_epochs):
        # --- Training ---
        student.train()
        train_losses = {"task": [], "kl": [], "total": []}

        for batch in train_loader:
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            teacher_logits = batch["teacher_logits"].to(device)

            # Forward pass through student
            student_outputs = student(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            # Compute distillation loss
            loss_dict = distillation_loss_fn(
                student_logits=student_logits,
                teacher_logits=teacher_logits,
                hard_labels=labels,
            )
            loss = loss_dict["total_loss"]
            loss.backward()

            # Gradient clipping: a common safeguard against loss spikes
            torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            train_losses["task"].append(loss_dict["task_loss"])
            train_losses["kl"].append(loss_dict["kl_loss"])
            train_losses["total"].append(loss.item())

        avg_task_loss = sum(train_losses["task"]) / len(train_losses["task"])
        avg_kl_loss = sum(train_losses["kl"]) / len(train_losses["kl"])
        avg_total_loss = sum(train_losses["total"]) / len(train_losses["total"])

        # --- Validation ---
        student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                outputs = student(input_ids=input_ids, attention_mask=attention_mask)
                predictions = outputs.logits.argmax(dim=-1)
                correct += (predictions == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total

        print(
            f"Epoch {epoch + 1}/{n_epochs}: "
            f"task={avg_task_loss:.4f}, "
            f"kl={avg_kl_loss:.4f}, "
            f"total={avg_total_loss:.4f}, "
            f"val_acc={val_acc:.4f}"
        )
        if use_wandb:
            wandb.log({
                "train/task_loss": avg_task_loss,
                "train/kl_loss": avg_kl_loss,
                "train/total_loss": avg_total_loss,
                "val/accuracy": val_acc,
                "epoch": epoch + 1,
            })

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            student.save_pretrained(best_model_path)
            print(f"  New best val_acc: {best_val_acc:.4f} - checkpoint saved")

    print(f"\nTraining complete. Best val_acc: {best_val_acc:.4f}")
    print(f"Best model: {best_model_path}")
    if use_wandb:
        wandb.finish()

    # Return the best checkpoint, not the last epoch's weights
    student = AutoModelForSequenceClassification.from_pretrained(best_model_path).to(device)
    return student
```
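The `precompute_teacher_logits` docstring notes that teacher logits can be cached to disk for reuse across experiments. Below is a minimal sketch of such a cache. It is keyed on the teacher name and a hash of the input texts, so a changed dataset never silently reuses stale logits. `cached_teacher_logits` and its cache layout are hypothetical, not part of any library.

```python
import hashlib
import os
from typing import Callable, List

import torch


def cached_teacher_logits(
    teacher_name: str,
    texts: List[str],
    compute_fn: Callable[[], torch.Tensor],
    cache_dir: str = "./teacher_logit_cache",
) -> torch.Tensor:
    """Return teacher logits from disk if present, else compute and save them.

    The cache key covers the teacher name and every input text (in order), so
    any change to either forces a recompute. Hypothetical helper for illustration.
    """
    h = hashlib.sha256(teacher_name.encode("utf-8"))
    for text in texts:
        h.update(text.encode("utf-8"))
        h.update(b"\x00")  # separator so ["ab", "c"] and ["a", "bc"] hash differently
    path = os.path.join(cache_dir, f"{h.hexdigest()[:16]}.pt")

    if os.path.exists(path):
        return torch.load(path)
    logits = compute_fn()
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(logits, path)
    return logits
```

To use it, pass `lambda: precompute_teacher_logits(teacher_name, texts, teacher_tokenizer)` as `compute_fn`. If you vary `max_length` or the tokenizer between experiments, fold them into the key as well.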
LLM Distillation: Sequence-Level and Token-Level
For generative LLMs (as opposed to classifiers), distillation happens at every token position. The teacher scores a sequence, often one it generated itself. At each position, the student is trained to match the teacher's next-token probability distribution:
```python
from typing import Dict

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# DistillationConfig comes from the loss implementation above.


def compute_token_level_distillation_loss(
    student_logits: torch.Tensor,  # (batch, seq_len, vocab_size)
    teacher_logits: torch.Tensor,  # (batch, seq_len, vocab_size)
    input_ids: torch.Tensor,       # (batch, seq_len) - the scored (e.g. teacher-generated) tokens
    attention_mask: torch.Tensor,  # (batch, seq_len) - which tokens to include
    temperature: float = 4.0,
    alpha: float = 0.7,
    label_smoothing: float = 0.0,
) -> Dict[str, torch.Tensor]:
    """
    Token-level distillation loss for autoregressive language models.

    At each position, the student learns to match the teacher's full
    next-token probability distribution, not just a single target token.
    This is richer than next-token prediction with hard targets alone.

    Student and teacher must share a tokenizer and vocabulary size, so that
    index k in both logit tensors refers to the same token.

    The "forward KL" is standard here: KL(teacher || student).
    The student is updated to cover the teacher's distribution.
    For open-ended generation, some papers use "reverse KL": KL(student || teacher).
    Reverse KL makes the student concentrate on modes of the teacher distribution
    rather than covering all of them.
    """
    vocab_size = student_logits.size(-1)

    # Shift: logits at position i predict the token at position i+1
    student_shift = student_logits[:, :-1, :].reshape(-1, vocab_size)
    teacher_shift = teacher_logits[:, :-1, :].reshape(-1, vocab_size)
    target_tokens = input_ids[:, 1:].reshape(-1)
    mask_shift = attention_mask[:, 1:].reshape(-1).bool()

    # Keep only non-padding target positions; upcast to fp32 for a stable softmax
    student_masked = student_shift[mask_shift].float()  # (n_valid_tokens, vocab)
    teacher_masked = teacher_shift[mask_shift].float()
    targets_masked = target_tokens[mask_shift]

    # --- KL divergence loss ---
    T = temperature
    soft_student = F.log_softmax(student_masked / T, dim=-1)
    soft_teacher = F.softmax(teacher_masked / T, dim=-1)
    kl_loss = F.kl_div(
        soft_student,
        soft_teacher,
        reduction="batchmean",
    ) * (T ** 2)

    # --- Task loss (next-token prediction on the actual sequence tokens) ---
    # If the sequences were sampled from the teacher, these are its generated tokens.
    task_loss = F.cross_entropy(
        student_masked,
        targets_masked,
        label_smoothing=label_smoothing,
    )

    total_loss = alpha * kl_loss + (1 - alpha) * task_loss
    return {
        "kl_loss": kl_loss,
        "task_loss": task_loss,
        "total_loss": total_loss,
        "n_valid_tokens": int(mask_shift.sum().item()),
    }


class LLMDistillationTrainer:
    """
    White-box LLM distillation trainer: the teacher's full next-token logits
    are computed on the fly for each batch.

    Black-box (API) distillation, where only the teacher's generated text is
    available, reduces to ordinary fine-tuning on (input, output) pairs and is
    not handled here; neither is QLoRA or other parameter-efficient training.
    Teacher and student must share a tokenizer.
    """

    def __init__(
        self,
        teacher_model_name: str,
        student_model_name: str,
        distillation_config: DistillationConfig,
        device: str = "cuda",
    ):
        self.config = distillation_config
        self.device = device

        print("Loading teacher model...")
        self.teacher = AutoModelForCausalLM.from_pretrained(
            teacher_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False  # Teacher is always frozen

        print("Loading student model...")
        self.student = AutoModelForCausalLM.from_pretrained(
            student_model_name,
            torch_dtype=torch.bfloat16,
        ).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(student_model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_teacher_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Get the teacher's next-token distribution for all sequence positions."""
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids=input_ids.to(self.teacher.device),
                attention_mask=attention_mask.to(self.teacher.device),
            )
        return teacher_outputs.logits.to(self.device).detach()

    def training_step(
        self,
        batch: Dict[str, torch.Tensor],
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """Single training step with distillation."""
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)

        # Teacher logits on the same tokens (no gradient)
        teacher_logits = self.get_teacher_logits(input_ids, attention_mask)

        # Forward through student
        student_outputs = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        student_logits = student_outputs.logits

        # Compute distillation loss
        losses = compute_token_level_distillation_loss(
            student_logits=student_logits,
            teacher_logits=teacher_logits,
            input_ids=input_ids,
            attention_mask=attention_mask,
            temperature=self.config.temperature,
            alpha=self.config.alpha,
        )

        losses["total_loss"].backward()
        torch.nn.utils.clip_grad_norm_(self.student.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

        return {k: v.item() if torch.is_tensor(v) else v
                for k, v in losses.items()}
```
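The docstring above contrasts forward KL, $\text{KL}(p_T \,\|\, p_S)$, with reverse KL, $\text{KL}(p_S \,\|\, p_T)$. A toy comparison on a single next-token distribution makes the difference concrete. It uses a hypothetical teacher split between two tokens, one student that covers both, and one that commits to a single token. The probabilities are made up for illustration.

```python
import torch


def forward_kl(student_log_probs, teacher_log_probs):
    """KL(teacher || student) = sum_k p_k (log p_k - log q_k). Mode-covering."""
    p = teacher_log_probs.exp()
    return (p * (teacher_log_probs - student_log_probs)).sum(-1)


def reverse_kl(student_log_probs, teacher_log_probs):
    """KL(student || teacher) = sum_k q_k (log q_k - log p_k). Mode-seeking."""
    q = student_log_probs.exp()
    return (q * (student_log_probs - teacher_log_probs)).sum(-1)


teacher = torch.tensor([0.49, 0.49, 0.01, 0.01]).log()     # two equally likely tokens
covering = torch.tensor([0.30, 0.30, 0.20, 0.20]).log()    # spreads mass everywhere
committed = torch.tensor([0.97, 0.01, 0.01, 0.01]).log()   # picks one mode

for name, student in [("covering", covering), ("committed", committed)]:
    print(f"{name:>9}: forward={forward_kl(student, teacher).item():.3f}  "
          f"reverse={reverse_kl(student, teacher).item():.3f}")
```

Forward KL prefers the covering student (about 0.42 vs 1.57) because it heavily penalizes little student mass where the teacher has a lot. Reverse KL prefers the committed student (about 0.62 vs 0.90) because it penalizes student mass where the teacher has almost none. Both functions are written with plain tensor ops, so gradients flow through the student in either case. To try mode-seeking distillation, one option is to swap `reverse_kl`, applied to temperature-scaled log-probabilities, in for the `F.kl_div` term.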
Landmark Distillation Systems: DistilBERT and TinyBERT
Understanding how these systems were built gives you patterns applicable to modern LLMs:
| System | Teacher | Student | Method | Size Reduction | Accuracy Retention |
|---|---|---|---|---|---|
| DistilBERT | BERT-base | DistilBERT (6L) | Response + intermediate | 40% fewer params | 97% of BERT-base |
| TinyBERT | BERT-base | TinyBERT (4L) | Full feature (attn + hidden) | 7.5x fewer params | 96.8% of BERT-base on GLUE |
| MiniLM | Various | MiniLM | Attention transfer | 50% fewer params | 99%+ on many tasks |
| Phi-2 (partial) | Multiple | Phi-2 (2.7B) | Data-driven + distillation | 10-40x smaller | Strong on benchmarks |
DistilBERT's key insight: Initialize the student from the teacher by taking one layer out of every two. BERT-base has 12 layers; DistilBERT keeps 6, and because it also keeps BERT's hidden size (768), the teacher's layer weights can be copied over directly. This warm initialization is crucial - it gives the student 6 layers that already capture diverse linguistic phenomena, rather than random weights.
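To make the warm start concrete, here is a minimal PyTorch sketch that builds a 6-layer student by deep-copying every other layer of a 12-layer stack. It uses generic `nn.TransformerEncoderLayer` modules as a stand-in for BERT's layers (an assumption for self-containment; this is not the DistilBERT extraction script), and `init_student_from_teacher` is a hypothetical helper name. Which layers to keep is a design choice; this sketch simply takes layers 0, 2, ..., 10.

```python
import copy

import torch
from torch import nn


def init_student_from_teacher(teacher_layers: nn.ModuleList, stride: int = 2) -> nn.ModuleList:
    """Build a shallower student by copying every `stride`-th teacher layer.

    This only works because student and teacher layers share the same hidden
    size, so the copied weights have compatible shapes.
    """
    # Deep-copy so that training the student never modifies the teacher's weights.
    return nn.ModuleList(copy.deepcopy(layer) for layer in teacher_layers[::stride])


# Toy 12-layer "teacher" built from standard PyTorch encoder layers (stand-in for BERT-base).
teacher_layers = nn.ModuleList(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True) for _ in range(12)
)
student_layers = init_student_from_teacher(teacher_layers)

print(len(student_layers))  # 6
```

The deep copy matters: sharing modules would mean every student update silently changes the teacher that is supposed to supply fixed targets.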
TinyBERT's key insight: Match not just the outputs, but also attention patterns and hidden states at corresponding layers. The attention matrices encode syntactic structure (which tokens attend to which). By matching them, TinyBERT inherits the same structural priors as BERT, and the paper's ablations show that these intermediate-layer losses add substantially over matching predictions alone.
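A simplified sketch of the intermediate-layer losses described above, assuming the caller already has per-layer attention matrices and hidden states for both models (for Hugging Face models, `output_attentions=True` and `output_hidden_states=True` expose post-softmax attention weights and hidden states, the latter with the embedding output prepended). `FeatureDistillationLoss` is a hypothetical class: it maps student layer m to teacher layer 3m (for a 4-layer student and 12-layer teacher) and learns a linear projection from the student's hidden width to the teacher's. The TinyBERT paper also distils the embedding and prediction layers and matches unnormalized (pre-softmax) attention scores; this sketch covers only the attention and hidden-state terms and assumes both models use the same number of heads.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FeatureDistillationLoss(nn.Module):
    """Simplified TinyBERT-style intermediate-layer loss (a sketch, not the paper's code)."""

    def __init__(self, student_hidden: int, teacher_hidden: int,
                 n_student_layers: int, n_teacher_layers: int) -> None:
        super().__init__()
        # Learnable projection: student hidden width -> teacher hidden width.
        self.proj = nn.Linear(student_hidden, teacher_hidden, bias=False)
        # Uniform layer mapping: student layer m (1-indexed) -> teacher layer m * step.
        step = n_teacher_layers // n_student_layers
        self.layer_map = [(s, (s + 1) * step - 1) for s in range(n_student_layers)]

    def forward(self, student_attn, teacher_attn, student_hidden, teacher_hidden):
        # *_attn: per-layer lists of [batch, heads, seq, seq] (same head count assumed)
        # *_hidden: per-layer lists of [batch, seq, width] (embedding output excluded)
        attn_loss = sum(F.mse_loss(student_attn[s], teacher_attn[t]) for s, t in self.layer_map)
        hidden_loss = sum(F.mse_loss(self.proj(student_hidden[s]), teacher_hidden[t])
                          for s, t in self.layer_map)
        return attn_loss + hidden_loss


# Usage with random tensors: a 4-layer, 32-wide student vs a 12-layer, 64-wide teacher.
torch.manual_seed(0)
batch, heads, seq = 2, 4, 5
loss_fn = FeatureDistillationLoss(32, 64, n_student_layers=4, n_teacher_layers=12)
student_attn = [torch.rand(batch, heads, seq, seq) for _ in range(4)]
teacher_attn = [torch.rand(batch, heads, seq, seq) for _ in range(12)]
student_hidden = [torch.randn(batch, seq, 32, requires_grad=True) for _ in range(4)]
teacher_hidden = [torch.randn(batch, seq, 64) for _ in range(12)]

loss = loss_fn(student_attn, teacher_attn, student_hidden, teacher_hidden)
loss.backward()  # gradients reach the projection and the student's hidden states
print(loss_fn.layer_map)  # [(0, 2), (1, 5), (2, 8), (3, 11)]
```

In a real run this term would be added to the response-distillation loss, and the projection's parameters would be passed to the optimizer alongside the student's.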
Hyperparameter Tuning for Distillation
The most impactful hyperparameters in distillation, in order of sensitivity:
```python
import random
from itertools import product
from typing import List

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Subset
from transformers import AutoModelForSequenceClassification

# DistillationConfig, DistillationDataset and DistillationLoss are the classes
# defined earlier in this article.


def distillation_hyperparameter_grid() -> List[DistillationConfig]:
    """
    Generate configurations for a systematic distillation hyperparameter search.

    Most sensitive to tune (in order):
    1. temperature: 3-6 is the typical effective range
    2. alpha: 0.5-0.9 is the effective range
    3. learning_rate: distillation typically needs a lower LR than fine-tuning
    4. intermediate_layers: True usually helps when the architectures are similar
    """
    temperatures = [2.0, 4.0, 6.0]
    alphas = [0.5, 0.7, 0.9]
    intermediate_options = [False, True]

    configs = []
    for T, alpha, intermediate in product(temperatures, alphas, intermediate_options):
        configs.append(DistillationConfig(
            temperature=T,
            alpha=alpha,
            intermediate_layers=intermediate,
        ))
    return configs  # 3 x 3 x 2 = 18 configurations - feasible with a small dev set


def run_distillation_hyperparam_search(
    configs: List[DistillationConfig],
    train_dataset: DistillationDataset,
    val_dataset: DistillationDataset,
    student_model_name: str,
    device: str = "cuda",
    quick_epochs: int = 2,        # fewer epochs than the final run, for speed
    learning_rate: float = 2e-5,
    seed: int = 0,
) -> DistillationConfig:
    """
    Run hyperparameter search on a small subset of training data.

    Pattern: train for quick_epochs on a random 20% of the training data,
    evaluate on the full validation set, pick the best config,
    then retrain on the full training data with the winner.
    """
    best_config = None
    best_val_acc = -1.0  # below any real accuracy, so some config is always selected

    # Random 20% subset of the training data (at least 100 examples, at most all of them)
    n_quick = min(len(train_dataset), max(100, len(train_dataset) // 5))
    quick_indices = random.Random(seed).sample(range(len(train_dataset)), n_quick)
    quick_dataset = Subset(train_dataset, quick_indices)
    val_loader = DataLoader(val_dataset, batch_size=64)

    for i, config in enumerate(configs):
        print(f"\nConfig {i+1}/{len(configs)}: T={config.temperature}, "
              f"alpha={config.alpha}, intermediate={config.intermediate_layers}")
        torch.manual_seed(seed)  # same head init and shuffling for every config
        student = AutoModelForSequenceClassification.from_pretrained(
            student_model_name,
            num_labels=train_dataset.teacher_logits.shape[1],
        ).to(device)

        # Quick training
        loss_fn = DistillationLoss(config)
        optimizer = AdamW(student.parameters(), lr=learning_rate)
        loader = DataLoader(quick_dataset, batch_size=32, shuffle=True)
        for epoch in range(quick_epochs):
            student.train()
            for batch in loader:
                optimizer.zero_grad()
                outputs = student(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                )
                losses = loss_fn(
                    student_logits=outputs.logits,
                    teacher_logits=batch["teacher_logits"].to(device),
                    hard_labels=batch["labels"].to(device),
                )
                losses["total_loss"].backward()
                optimizer.step()

        # Evaluate on the full validation set
        student.eval()
        correct = total = 0
        with torch.no_grad():
            for batch in val_loader:
                outputs = student(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                )
                preds = outputs.logits.argmax(dim=-1)
                correct += (preds == batch["labels"].to(device)).sum().item()
                total += batch["labels"].size(0)
        val_acc = correct / total
        print(f"  Quick val accuracy: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_config = config

        # Free GPU memory before the next config
        del student, optimizer
        torch.cuda.empty_cache()

    print(f"\nBest config: T={best_config.temperature}, alpha={best_config.alpha}")
    print(f"Best quick val acc: {best_val_acc:.4f}")
    return best_config
```
When to Use Distillation vs. Quantization
Distillation vs. Quantization Decision Matrix:
| Scenario | Best Approach | Reasoning |
|---|---|---|
| No retraining budget | Quantization (AWQ/GPTQ) | Zero training cost, ~4x weight-memory reduction (FP16 → INT4) |
| General-purpose serving | Quantization first, then evaluate | Fast, good-enough accuracy |
| High-volume specific task | Task-specific distillation | Permanently cheaper per token |
| Edge deployment | Distill to small arch + quantize | Both techniques needed |
| Fine-tuned model serving | QLoRA + AWQ | Fine-tune flexibility + inference speed |
| Regulatory accuracy requirements | INT8 quantization only | Minimal accuracy impact |
| 10x+ compression needed | Distillation required | Quantization alone tops out around 4x (FP16 → INT4) at usable accuracy |
Common Mistakes and Production Pitfalls
:::danger Never Use Hard Labels Alone When You Have Teacher Access If you have access to the teacher model's logits, always use soft labels. Training only on hard labels (the teacher's argmax) wastes the teacher's probability information. The gap between soft-label distillation and hard-label fine-tuning on teacher outputs is typically 3-8 points of accuracy. The soft-label signal is free if the teacher model is available - always use it. :::
:::danger Do Not Freeze Student Layers During Early Training Some practitioners freeze the student's lower layers during early distillation to preserve pretrained representations. This is counterproductive when the student architecture differs significantly from the teacher. Keep all student parameters trainable throughout distillation. If training stability is an issue, use a lower learning rate (1e-5 instead of 2e-5) rather than freezing layers. :::
:::warning Calibration-Quality Mismatch in API Distillation When using API teachers (GPT-4, Claude), the quality of teacher outputs for your task is not guaranteed. If the teacher makes systematic errors on a subset of your task distribution, the student learns those errors too. Always: (1) sample 100-200 teacher outputs and manually review quality, (2) filter out low-quality teacher outputs before training, (3) include a hard-label task loss component (alpha less than 1.0) to anchor the student to verified labels. :::
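Steps (1) and (2) can be sketched for a classification-style task where the teacher's text output is expected to be one of a fixed set of category names. Everything here is hypothetical: `filter_teacher_outputs` drops outputs that are not a valid category, plus any reviewed example where the human reviewer disagreed with the teacher, and `reviewed_agreement` reports how often the teacher matched the reviewer on the manually checked sample.

```python
from typing import Dict, List, Sequence, Set


def normalize(label: str) -> str:
    return label.strip().lower()


def reviewed_agreement(teacher_outputs: Sequence[str], reviewed_labels: Dict[int, str]) -> float:
    """Fraction of manually reviewed examples where the teacher matched the reviewer."""
    if not reviewed_labels:
        raise ValueError("review a sample of teacher outputs first")
    hits = sum(normalize(teacher_outputs[i]) == normalize(lbl) for i, lbl in reviewed_labels.items())
    return hits / len(reviewed_labels)


def filter_teacher_outputs(
    teacher_outputs: Sequence[str],
    reviewed_labels: Dict[int, str],  # index -> human-verified label (only for reviewed items)
    valid_labels: Set[str],           # allowed category names, lower-case
) -> List[int]:
    """Return indices of teacher outputs to keep for training (hypothetical rules)."""
    keep = []
    for i, raw in enumerate(teacher_outputs):
        label = normalize(raw)
        if label not in valid_labels:
            continue  # malformed or off-schema teacher output
        if i in reviewed_labels and normalize(reviewed_labels[i]) != label:
            continue  # teacher error caught during manual review
        keep.append(i)
    return keep


outputs = ["payment_issue", "Refund ", "lol idk", "bug_report"]
reviewed = {0: "payment_issue", 3: "feature_request"}
valid = {"payment_issue", "refund", "bug_report", "feature_request"}
print(filter_teacher_outputs(outputs, reviewed, valid))  # [0, 1]
print(reviewed_agreement(outputs, reviewed))             # 0.5
```

Filtering only removes the errors you happened to catch; a low agreement rate on the reviewed sample suggests the teacher's unreviewed outputs contain similar errors too.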
:::warning Do Not Use the Same Data for Hyperparameter Search and Final Evaluation If you select hyperparameters on a split and then report results on that same split, the reported numbers are optimistically biased: the winning configuration was chosen precisely because it did well on that data. Use a three-way split: train / validation (for hyperparameter selection) / test (for final evaluation). Report only the test set numbers in your evaluation. :::
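A minimal sketch of a reproducible three-way split, assuming examples can be addressed by integer index; `three_way_split` and its default 80/10/10 proportions are illustrative choices, not a prescription. Shuffling once with a fixed seed keeps the test indices identical across every hyperparameter run.

```python
import random
from typing import List, Tuple


def three_way_split(n_examples: int, val_frac: float = 0.1, test_frac: float = 0.1,
                    seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle indices once with a fixed seed and cut them into train/val/test."""
    if val_frac + test_frac >= 1.0:
        raise ValueError("val_frac + test_frac must be < 1")
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    n_test = int(n_examples * test_frac)
    n_val = int(n_examples * val_frac)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = three_way_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 800 100 100
```

Hyperparameter search reads only `train_idx` and `val_idx`; `test_idx` is touched once, for the final reported numbers.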
:::tip Start with Response Distillation Before Feature Distillation Feature distillation (matching intermediate hidden states) adds complexity - architecture-specific projection layers, careful layer mapping, additional loss terms to tune. Start with response distillation only (KL loss + task loss). If the accuracy is insufficient and you have additional engineering budget, add feature distillation. In many practical cases, response distillation alone achieves 90%+ of the accuracy benefit at much lower implementation complexity. :::
Interview Questions
Q1: Explain knowledge distillation from first principles. What problem does it solve, and why does it work better than training from scratch?
Knowledge distillation solves the problem of training a small model that performs nearly as well as a large, already-trained model on a target task. The fundamental insight is that hard labels (0/1 one-hot vectors) are information-sparse. When a large teacher model predicts class distributions, its non-maximum probabilities encode rich information about similarity structure: if a teacher gives 62% probability to "billing" and 21% to "refund", this tells the student that billing and refund are semantically related. Training from scratch on hard labels loses this information entirely. The soft probability distributions Hinton called "dark knowledge" encode the teacher's knowledge about which classes are similar, which boundaries are close, and how confident each decision should be. The student trained on soft labels learns the same similarity structure, not just the decision boundary. This is why distillation typically outperforms training from scratch at the same parameter count: the student receives a richer training signal per example, leading to better generalization on the same task.
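The billing/refund example can be made concrete in a few lines of Python. The category names and logits below are hypothetical; the point is that the one-hot hard label carries no information about `refund`, while the teacher's softmax does, and raising the temperature increases the share of probability that related classes receive.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


categories = ["billing", "refund", "login", "bug_report"]
teacher_logits = [4.0, 2.9, 0.5, 0.2]  # hypothetical teacher logits for one ticket

hard_label = [1.0 if c == "billing" else 0.0 for c in categories]  # one-hot: no similarity info
soft_t1 = softmax_with_temperature(teacher_logits, temperature=1.0)
soft_t4 = softmax_with_temperature(teacher_logits, temperature=4.0)

for name, dist in [("hard", hard_label), ("T=1", soft_t1), ("T=4", soft_t4)]:
    print(name, {c: round(p, 3) for c, p in zip(categories, dist)})
```

The ranking of classes never changes with temperature; only how much of the secondary structure is visible to the student does.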
Q2: Explain the mathematical role of the temperature parameter in distillation. Why is the correction factor needed?
Temperature $T$ controls the entropy of the teacher's soft distribution. At $T = 1$, softmax gives a peaked distribution where the correct class dominates. At $T > 1$, the distribution spreads more probability onto related classes, revealing similarity structure. The $T^2$ correction addresses a gradient-magnitude problem. With teacher probabilities $p_i = \mathrm{softmax}(v/T)_i$ and student probabilities $q_i = \mathrm{softmax}(z/T)_i$, the gradient of the soft-target loss with respect to a student logit is $\partial \mathcal{L}_{\text{soft}} / \partial z_i = (q_i - p_i)/T$. Dividing the logits by $T$ also shrinks the differences between $q_i$ and $p_i$ by roughly a factor of $T$, so for large $T$ the gradient as a whole scales roughly as $1/T^2$. Without the correction, the distillation term would produce gradients on the order of $T^2$ times smaller than at $T = 1$, making the distillation signal negligible at high temperature. Multiplying the soft loss by $T^2$ keeps its gradient magnitude roughly constant across temperature choices, so $\alpha$ remains the hyperparameter that controls the task-vs-distillation tradeoff.
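A small numeric check of the gradient-scaling argument, using the analytic gradient $(q_i - p_i)/T$ of the soft-target cross-entropy (which has the same gradient as the KL divergence, because the teacher's entropy does not depend on the student). The logits are arbitrary illustrative values.

```python
import math
from typing import Dict, List


def softmax(logits: List[float], temperature: float) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def soft_loss_grad(student_logits: List[float], teacher_logits: List[float],
                   temperature: float) -> List[float]:
    """Gradient of -sum_i p_i log q_i w.r.t. student logits, with q = softmax(z/T), p = softmax(v/T)."""
    q = softmax(student_logits, temperature)
    p = softmax(teacher_logits, temperature)
    return [(qi - pi) / temperature for qi, pi in zip(q, p)]


def l2_norm(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


# Illustrative logits for a 4-class problem.
student = [2.0, 1.0, 0.1, -1.0]
teacher = [3.0, 0.5, 0.2, -2.0]

raw: Dict[float, float] = {}
corrected: Dict[float, float] = {}
for T in (1.0, 2.0, 4.0, 8.0, 16.0):
    raw[T] = l2_norm(soft_loss_grad(student, teacher, T))
    corrected[T] = T * T * raw[T]  # gradient norm after the T^2 correction
    print(f"T={T:>4}: |grad|={raw[T]:.5f}   T^2 * |grad|={corrected[T]:.4f}")
```

The raw gradient norm collapses as $T$ grows, while $T^2 \cdot |\nabla|$ settles to a stable value, which is the behavior the correction is meant to produce.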
Q3: Compare response distillation with feature distillation (TinyBERT-style). When would you choose each?
Response distillation matches only the final output distributions - the teacher's and student's next-token or class probabilities. Feature distillation also matches intermediate hidden states and attention matrices between corresponding layers. Response distillation is simpler: one loss term, no projectors needed, and it works even when student and teacher have very different architectures, provided they share an output space (the same label set or, for token-level LLM distillation, the same vocabulary). It is the right default for most practitioners. Feature distillation transfers structural knowledge: attention patterns encode which tokens attend to which (syntax, coreference, semantics). By matching them, TinyBERT learned the same inductive biases as BERT, achieving 7.5x compression while retaining 96.8% of its teacher's GLUE score - better than response distillation alone at the same compression ratio. Choose feature distillation when: the student and teacher have the same or similar architecture families, the task requires syntactic understanding (parsing, NER, structured prediction), you have engineering budget for the additional complexity, and response distillation alone did not meet your accuracy requirements. For most classification and generation tasks, response distillation is sufficient.
Q4: How does API-based distillation work, and what are its limitations compared to white-box distillation?
API-based distillation uses a proprietary teacher model (GPT-4, Claude, Gemini) to generate high-quality training data for a small open-source student. The process: sample diverse inputs from your task distribution, query the teacher API for each, collect the (input, output) pairs, and fine-tune a small student on these pairs. The key limitation: you cannot access the teacher's full logits, only its generated text (some APIs also return log-probabilities for the top few tokens at each position). In practice you train mostly on hard targets (the teacher's generated output), not full soft probability distributions, so you lose most of the "dark knowledge" in the teacher's non-maximum probabilities. As a result, API distillation typically achieves 85-92% of teacher quality at the target size, versus 90-97% with white-box distillation that has access to full logit distributions. The other practical limitations are cost (thousands of API calls add up, up to $15 per thousand examples depending on the model) and teacher error propagation (if the teacher makes systematic errors on your task, the student learns those errors).
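When an API does expose top-k log-probabilities (OpenAI's Chat Completions API, for example, has `logprobs` and `top_logprobs` options), one possible approximation is to renormalize those few probabilities into a truncated soft target. The sketch below assumes a classification-style output where each candidate token is a label name the student can also score; with free-form generation and mismatched tokenizers, aligning teacher and student tokens is a separate, harder problem. The helper names are hypothetical.

```python
import math
from typing import Dict


def topk_soft_targets(top_logprobs: Dict[str, float]) -> Dict[str, float]:
    """Turn a teacher's top-k log-probabilities into a renormalized distribution.

    Probability mass outside the top k is unknown, so the result only
    approximates the teacher's full soft label.
    """
    m = max(top_logprobs.values())  # subtract the max for numerical stability
    weights = {tok: math.exp(lp - m) for tok, lp in top_logprobs.items()}
    total = sum(weights.values())
    return {tok: w / total for tok, w in weights.items()}


def truncated_soft_cross_entropy(student_logprobs: Dict[str, float],
                                 soft_targets: Dict[str, float]) -> float:
    """Cross-entropy of the student against the truncated teacher targets.

    `student_logprobs` must hold the student's log-probability for every target token.
    """
    return -sum(p * student_logprobs[tok] for tok, p in soft_targets.items())


# Hypothetical top-3 log-probabilities returned for a ticket-triage prompt.
teacher_top = {"billing": math.log(0.60), "refund": math.log(0.25), "login": math.log(0.05)}
targets = topk_soft_targets(teacher_top)  # renormalized over the 0.90 of visible mass
print({tok: round(p, 3) for tok, p in targets.items()})
```

Because the invisible 10% of mass is dropped rather than modeled, this recovers only part of the dark knowledge that white-box distillation would see.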
Q5: What is the key engineering insight that makes distillation work better than quantization for a 10x compression target?
Quantization attacks numeric precision while preserving the model architecture. Moving from FP16 to INT4 provides at most a 4x reduction in weight memory. Reaching 10x from FP16 with quantization alone would require about 1.6 bits per weight (16 / 10), far below what post-training quantization can reach without severe accuracy loss. Distillation changes the architecture instead: a 70B model distilled into a 7B model achieves 10x compression while the student keeps a reasonable numeric precision (FP16 or INT4). The student can then be further quantized to INT4, reaching 40x compression (10x from architecture × 4x from precision). This compounding is the key: distillation and quantization attack orthogonal dimensions of model size. Quantization attacks the bits-per-weight dimension; distillation attacks the number-of-weights dimension. Quantization alone cannot reach 10x with acceptable accuracy, and the Pareto frontier for extreme compression (well beyond 10x) requires both techniques - distill to a smaller architecture, then quantize the result.
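The arithmetic behind the 4x / 10x / 40x figures, counting weight memory only (activations, KV cache, and quantization metadata such as scales are ignored, so real INT4 footprints are somewhat larger). `weight_memory_gb` is an illustrative helper.

```python
def weight_memory_gb(n_params: float, bits_per_weight: float) -> float:
    """Memory for the weights alone, in GB (1e9 bytes)."""
    return n_params * bits_per_weight / 8 / 1e9


baseline = weight_memory_gb(70e9, 16)  # 70B teacher in FP16: 140 GB
options = {
    "quantize only (70B, INT4)": weight_memory_gb(70e9, 4),
    "distill only (7B, FP16)": weight_memory_gb(7e9, 16),
    "distill + quantize (7B, INT4)": weight_memory_gb(7e9, 4),
}
for name, gb in options.items():
    print(f"{name:32s} {gb:6.1f} GB  ({baseline / gb:.0f}x smaller)")

# Bits per weight that quantization alone would need for 10x over FP16.
bits_for_10x = 16 / 10
print(bits_for_10x)
```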
Q6: How would you set up a production distillation pipeline for a task-specific LLM deployment?
The production pipeline has five stages. Stage 1: collect and curate calibration data. Gather 10,000-50,000 representative inputs from your deployment domain, including edge cases and the hardest examples you expect in production. Stage 2: generate teacher outputs. Run the teacher model (GPT-4, Claude, or your large internal model) on all inputs, collecting both the generated text and, if accessible, the full logit distributions. Budget $500-5000 for API costs depending on scale. Stage 3: select student architecture. Choose a model that fits your compute budget - for classification, 100M-500M parameters; for generation, 1B-7B. Initialize from the best available pretrained checkpoint of that size. Stage 4: run hyperparameter search. Use a 20% subset of training data to tune temperature (3-6), alpha (0.5-0.9), and learning rate (1e-5 to 5e-5). Choose the config with the best validation accuracy, then retrain on the full dataset. Stage 5: evaluate and compress further. Benchmark the distilled model on your task metrics, comparing to the teacher. If meeting the accuracy bar, quantize to INT4 with AWQ using domain-calibrated data for an additional 4x memory reduction. Deploy and monitor closely for the first two weeks, tracking not just average accuracy but accuracy on the edge cases from stage 1.
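Stage 5's "not just average accuracy" check could be expressed as a simple deploy gate. This is a hypothetical sketch: `passes_quality_bar` and the 90% retention threshold are illustrative, and a real pipeline would likely use task-specific metrics rather than plain accuracy.

```python
from typing import Sequence


def accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def passes_quality_bar(
    student_preds: Sequence[int],
    teacher_preds: Sequence[int],
    labels: Sequence[int],
    is_edge_case: Sequence[bool],
    min_retention: float = 0.9,  # hypothetical bar: keep 90% of the teacher's accuracy
) -> bool:
    """Require the student to retain `min_retention` of the teacher's accuracy
    both overall and on the edge-case slice collected in stage 1."""

    def retains(idx):
        s = accuracy([student_preds[i] for i in idx], [labels[i] for i in idx])
        t = accuracy([teacher_preds[i] for i in idx], [labels[i] for i in idx])
        return s >= min_retention * t

    all_idx = list(range(len(labels)))
    edge_idx = [i for i, edge in enumerate(is_edge_case) if edge]
    return retains(all_idx) and (not edge_idx or retains(edge_idx))


labels = [i % 3 for i in range(20)]
teacher_preds = list(labels)
student_preds = list(labels)
student_preds[19] = (labels[19] + 1) % 3   # one mistake, on an edge case
is_edge = [False] * 18 + [True, True]

# 95% overall clears the bar, but 50% on the edge-case slice does not.
print(passes_quality_bar(student_preds, teacher_preds, labels, is_edge))  # False
```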
