Instruction Tuning at Scale
The Dataset That Changed Everything
It's September 2021. The GPT-3 paper has been out for more than a year. Everyone knows large language models are capable of remarkable things when prompted carefully. Zero-shot performance on classification, translation, summarization - all possible if you write the right prompt. The problem is inconsistency. GPT-3 can answer a question perfectly with one phrasing and completely fail with a slightly different one. The model's behavior is powerful but brittle.
A team at Google Brain - Jason Wei, Maarten Bosma, Vincent Zhao, Kelvin Guu, Adams Wei Yu, Brian Lester, Nan Du, Andrew Dai, Quoc Le - has an idea. What if, instead of hoping the model generalizes from in-context examples, you explicitly train it on many different tasks framed as natural language instructions? Not fine-tuning for a single task. Fine-tuning on a curriculum that spans more than 60 NLP tasks: sentiment analysis, machine translation, reading comprehension, summarization, question answering, commonsense reasoning, coreference resolution - each formatted as an instruction plus an expected response.
The paper is called "Finetuned Language Models Are Zero-Shot Learners" (FLAN). The results are striking: a model trained on dozens of task types generalizes to held-out task types in a way that standard prompting cannot match. Not because it saw those tasks - it didn't. Because training on diverse instructions teaches the model what "following an instruction" means as a general skill.
This single insight - that instruction diversity during training transfers to instruction-following ability on new tasks - is the foundation of every chat model and instruction-following LLM that came after it. ChatGPT, Claude, Mistral Instruct, Llama Instruct - all of them trace back to the FLAN observation.
But there is a gap between "we know this works" and "we can make this work at production scale." Building a robust instruction-tuning pipeline requires decisions about dataset construction, data quality filtering, training infrastructure, and evaluation that the original FLAN paper does not address. This lesson bridges that gap.
Why This Exists - The Problem With Base Models
A base language model is trained on a simple objective: predict the next token. Feed it the internet, books, code, and papers. Minimize cross-entropy loss. After trillions of tokens, you have a model that can complete sequences in a statistically coherent way.
The problem is that "completing sequences" is not what users want. A user who asks "What is the capital of France?" does not want the model to continue the pattern with "What is the capital of Germany? What is the capital of Italy?" - which is what a base model might do if the pattern looks like a list of geography questions. They want a direct answer: "Paris."
This mismatch between pretraining objective (next-token prediction on internet text) and deployment objective (following user instructions to be helpful) is the fundamental problem instruction tuning solves.
Before instruction tuning, the solutions were inadequate. Few-shot prompting worked but consumed context window tokens and required careful prompt engineering. Task-specific fine-tuning worked but produced single-task models that could not generalize. Fine-tuning on a mix of tasks in their raw academic format worked but produced models that responded in academic-paper style rather than conversational style.
Instruction tuning solves this by reformatting every task as a human-readable instruction-response pair and training on a diverse enough set of these pairs that the model learns the meta-skill of instruction following itself.
Historical Context - From FLAN to OpenHermes
2021 - FLAN (Wei et al., Google Brain): Fine-tunes a 137B-parameter LaMDA-PT model on 62 tasks from the academic NLP literature, formatted as instructions. Demonstrates zero-shot generalization to held-out tasks. The key finding: instruction tuning scales - bigger models get more benefit from it.
2022 - Super-NaturalInstructions and FLAN-T5: Wang et al. at the Allen Institute for AI release Super-NaturalInstructions: 1,616 tasks with 64 demonstration examples each, covering 76 distinct task types across 55 languages. Google's FLAN-T5 (Chung et al.) scales the original FLAN recipe to over 1,800 tasks. Both broaden task coverage far beyond the original FLAN.
2022 - InstructGPT (Ouyang et al., OpenAI): Adds RLHF on top of supervised instruction tuning. Crucial finding: human raters prefer outputs from a 1.3B-parameter instruction-tuned model over those of the raw 175B GPT-3 on helpfulness. The paper is known for RLHF, but the SFT stage it describes is essentially instruction tuning - roughly 13K human-written instruction-response pairs covering diverse tasks.
2023 - Alpaca (Stanford CRFM): Generates 52K instruction-following examples from OpenAI's text-davinci-003 (a GPT-3.5-series model) using the self-instruct method (Wang et al. 2022), then fine-tunes LLaMA 7B on them. Demonstrates that open-source instruction tuning is achievable with small datasets and limited compute. Quality is mediocre by modern standards, but it proved the concept.
2023 - Dolly (Databricks): 15K human-written instruction-response pairs, released under an open license (CC BY-SA 3.0). One of the first commercially usable instruction datasets.
2023 - OpenHermes (Teknium): 900K+ high-quality instruction-response pairs, heavily filtered from multiple sources including GPT-4 generated data. Became the standard reference dataset for open-source instruction tuning quality. Mistral 7B fine-tuned on OpenHermes (Hermes 2.5) outperformed many much larger models.
2023-2024 - The data quality insight: The community converges on a counterintuitive finding. 1,000 high-quality examples can outperform 100,000 low-quality examples. LIMA (Zhou et al. 2023) demonstrates this formally: 1,000 carefully curated examples fine-tune LLaMA 65B to strong performance across diverse tasks. Quality beats quantity for instruction tuning above a threshold of roughly 1,000-10,000 examples.
Core Concepts - The Instruction Tuning Framework
What "Instructions" Actually Are
An instruction-tuned dataset is a collection of (instruction, optional input, response) triples. The instruction specifies what to do. The optional input provides context (a document to summarize, code to debug, a question to answer). The response is what the model should output.
Instruction: Summarize the following text in one sentence.
Input: The Amazon rainforest, often called the "lungs of the Earth," covers approximately 5.5 million square kilometers across nine countries in South America. It produces 20% of the world's oxygen and houses 10% of all known species.
Response: The Amazon rainforest is a vast, biologically diverse ecosystem spanning 5.5 million sq km across South America that generates 20% of Earth's oxygen.
The format varies across datasets and models. Modern chat models use a structured chat template with role markers:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Summarize this in one sentence: [text]<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
[response]<|eot_id|>
The structural difference matters: models trained on chat templates learn the turn structure and can handle multi-turn conversations. Models trained on raw instruction-response pairs often struggle with multi-turn.
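To see the turn structure concretely, you can render a conversation through the tokenizer's built-in template. A minimal sketch, assuming the Llama 3 Instruct tokenizer; the message contents are illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this in one sentence: [text]"},
    {"role": "assistant", "content": "[response]"},
    {"role": "user", "content": "Now make it even shorter."},  # second turn
    {"role": "assistant", "content": "[shorter response]"},
]

# Renders the full token layout shown above, including the <|eot_id|>
# turn delimiters, so every example shares one consistent structure.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```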
The FLAN Insight - Task Diversity as Transfer Learning
The key finding from Wei et al. is that instruction tuning performance on held-out tasks scales with the number of tasks seen during training, not just the number of examples per task.
The mechanism: training on diverse instructions teaches the model a task representation that generalizes. The model learns that text starting with "Translate this to French:" has a certain expected completion pattern, and this learning transfers to "Translate this to Spanish:" even without seeing Spanish translation during training.
More precisely, the model learns:
- How to parse an instruction into a goal
- How to identify the relevant context in the input
- What format the response should take given the instruction type
- How to terminate the response appropriately
These four sub-skills are general. They transfer across task types because they are about the structure of following instructions, not about any specific task.
Scaling Laws for Instruction Tuning Data
Unlike pretraining, where loss decreases predictably with data and compute (Chinchilla scaling laws), instruction tuning follows a different curve:
Phase 1 (0 to ~1K examples): Rapid improvement. The model learns basic instruction-following behavior.
Phase 2 (1K to ~10K examples): Continued strong improvement in breadth. Adding new task types continues to help.
Phase 3 (10K to ~100K examples): Diminishing returns. More examples of the same task type provide minimal benefit. New task types still help.
Phase 4 (100K+ examples): Quality becomes dominant over quantity. Additional low-quality examples can hurt. Data cleaning and filtering matter more than adding raw volume.
The mathematical intuition: instruction tuning is teaching a behavior, not teaching new knowledge. The model already has the knowledge from pretraining. You are teaching it how to deploy that knowledge in response to instructions. Once the behavior pattern is learned, more examples of the same pattern do not add new information.
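As a rough illustration (a heuristic for intuition, not a published scaling law):

$$\text{capability gain} \;\approx\; Q \cdot \left(\alpha \log N + \beta T\right)$$

where $N$ is the number of training examples, $T$ the number of distinct task types, $Q$ an average quality multiplier, and $\alpha$, $\beta$ fitted constants.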
This is not a rigorous formula but captures the qualitative relationship: gains are logarithmic in example count (diminishing returns), linear in task diversity (new tasks always help), and quality acts as a multiplier (high-quality examples count as more).
Chat Fine-Tuning vs Instruction Fine-Tuning
These terms are often used interchangeably but they differ in important ways:
Instruction fine-tuning: Single-turn. One instruction, one response. The model learns to respond to commands. Alpaca, FLAN-T5, Dolly are instruction-tuned models.
Chat fine-tuning: Multi-turn. A conversation history with alternating user and assistant messages. The model learns conversational dynamics: following up, asking clarifying questions, maintaining context across turns, handling corrections.
In practice, modern "instruct" models are trained on both: a base of instruction pairs plus multi-turn conversation data. The distinction matters for your dataset construction: if you need a conversational assistant, you need multi-turn examples. If you need a command-following tool, single-turn examples are sufficient.
Building an Instruction Dataset
The Three Pillars: Coverage, Quality, Format Consistency
Coverage means task diversity. A dataset that covers only one type of task (e.g., only summarization) will produce a model that instruction-follows well only for that task. You want representation across:
- Information extraction (NER, relation extraction, table parsing)
- Generation tasks (summarization, translation, paraphrase)
- Reasoning tasks (math, logic, causal reasoning)
- Coding tasks (code generation, debugging, explanation)
- Conversation tasks (Q&A, dialogue, clarification)
- Classification tasks (sentiment, topic, intent)
At minimum, aim for 10+ distinct task categories with at least 500 examples each.
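A quick way to audit a candidate dataset against that bar, sketched here assuming each example carries a task_category field (a hypothetical schema):

```python
from collections import Counter

def audit_coverage(
    examples: list[dict],
    min_categories: int = 10,
    min_per_category: int = 500,
) -> None:
    """Print per-category counts and flag gaps against the coverage targets."""
    counts = Counter(ex["task_category"] for ex in examples)
    for category, n in counts.most_common():
        flag = "" if n >= min_per_category else "  <-- below minimum"
        print(f"{category:25s} {n:7,d}{flag}")
    if len(counts) < min_categories:
        print(f"WARNING: only {len(counts)} categories; aim for {min_categories}+")
```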
Quality means response correctness and appropriateness. The single biggest quality issue in community instruction datasets is hallucinated responses - the model generating authoritative-sounding but incorrect responses to factual questions. Filter for:
- Response length appropriateness (too short = lazy answer, too long = padding)
- Response relevance to the instruction (does it actually follow the instruction?)
- Factual plausibility (use automated metrics or sample-based human review)
Format consistency means all examples follow the same structural template. Inconsistent formatting is one of the most common sources of degraded model behavior. If 60% of your examples use one chat template and 40% use a different one, the model will produce inconsistent outputs at inference.
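A cheap guard is to scan formatted examples for markers that belong to other templates. A sketch, with an illustrative marker list:

```python
# Markers from common templates that should NOT co-occur in one dataset.
FOREIGN_MARKERS = [
    "### Instruction:",  # Alpaca-style
    "<|im_start|>",      # ChatML-style
    "[INST]",            # Llama 2 / Mistral-style
]

def find_mixed_templates(
    formatted_texts: list[str],
    expected_marker: str = "<|start_header_id|>",  # Llama 3 header token
) -> list[int]:
    """Return indices of examples missing the expected marker or containing a foreign one."""
    bad = []
    for i, text in enumerate(formatted_texts):
        if expected_marker not in text or any(m in text for m in FOREIGN_MARKERS):
            bad.append(i)
    return bad
```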
Data Sources and Their Trade-offs
| Source | Quality | License | Cost | Notes |
|---|---|---|---|---|
| Human-written | High | Varies | Very high | Gold standard; expensive to scale |
| GPT-4 generated | High | Restricted | Moderate | Best quality/cost ratio; OpenAI ToS restriction |
| GPT-3.5 generated | Medium | Restricted | Low | Good for scale; quality below GPT-4 |
| Academic NLP tasks | Variable | Open | Low | Covers diverse tasks; often stiff, academic tone |
| Self-instruct | Medium | Open | Low | Bootstraps from seed tasks; high quality variance |
| Community datasets | Variable | Varies | None | OpenHermes, Dolly, Orca, etc. |
The OpenAI Terms of Service prohibit using model outputs to train competing models - a legal risk that many community datasets simply ignore. For commercial production systems, use human-written data, openly licensed datasets, or self-hosted model outputs (e.g., from Llama 3.1 405B, whose license permits using outputs to train other models, subject to Meta's naming and attribution terms).
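If your examples carry provenance metadata, an allowlist filter keeps only sources you can legally ship. A sketch; the source field and allowlist values are hypothetical:

```python
ALLOWED_SOURCES = {"human_written", "dolly_15k", "self_hosted_llama"}

def filter_by_license(examples: list[dict]) -> list[dict]:
    """Drop examples whose provenance is not on the commercial allowlist."""
    kept = [ex for ex in examples if ex.get("source") in ALLOWED_SOURCES]
    print(f"Kept {len(kept)}/{len(examples)} examples after license filtering")
    return kept
```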
Data Cleaning Pipeline
import re
from typing import Optional
def is_valid_instruction_example(
instruction: str,
response: str,
input_context: Optional[str] = None,
min_instruction_words: int = 5,
min_response_words: int = 10,
max_response_words: int = 1000,
) -> tuple[bool, str]:
"""
Returns (is_valid, reason_if_invalid).
Filters out examples that will degrade training quality.
"""
# Basic length checks
instruction_words = len(instruction.split())
response_words = len(response.split())
if instruction_words < min_instruction_words:
return False, f"Instruction too short: {instruction_words} words"
if response_words < min_response_words:
return False, f"Response too short: {response_words} words"
if response_words > max_response_words:
return False, f"Response too long: {response_words} words"
# Check for refusal responses (these can be good but need careful handling)
refusal_patterns = [
r"i (cannot|can't|am unable to)",
r"i (don't|do not) have (the ability|access|information)",
r"as an ai (language model|assistant)",
]
response_lower = response.lower()
for pattern in refusal_patterns:
if re.search(pattern, response_lower):
# Not invalid, but flag for review - refusals should be intentional
pass # return False, f"Contains refusal pattern: {pattern}"
    # Check for repetitive responses (common in low-quality generated data).
    # Drop empty fragments first so the ratio isn't skewed by trailing periods.
    sentences = [s.strip() for s in response.split(".") if s.strip()]
    if len(sentences) > 5:
        unique_sentences = set(sentences)
        repetition_ratio = 1 - (len(unique_sentences) / len(sentences))
        if repetition_ratio > 0.3:
            return False, f"High repetition ratio: {repetition_ratio:.2f}"
# Check instruction-response relevance (simple heuristic)
# A real pipeline uses embedding similarity or an LLM-as-judge
instruction_keywords = set(instruction.lower().split()[:20])
response_first_words = set(response.lower().split()[:50])
overlap = instruction_keywords & response_first_words
# Very low keyword overlap often indicates the response ignores the instruction
if len(instruction_keywords) > 10 and len(overlap) < 2:
# Flag but don't automatically reject - some instructions are abstract
pass
return True, ""
def clean_response_text(response: str) -> str:
"""Remove common artifacts from generated instruction responses."""
# Remove "As an AI language model..." preambles
preamble_patterns = [
r"^As an AI(?: language model)?(?:,| -|:)?\s*",
r"^I(?: am|'m) an AI(?: assistant)?(?:,| -|:)?\s*",
r"^Sure!?\s+",
r"^Of course!?\s+",
r"^Certainly!?\s+",
r"^Absolutely!?\s+",
]
for pattern in preamble_patterns:
response = re.sub(pattern, "", response, flags=re.IGNORECASE)
# Remove trailing "Is there anything else" closings
closing_patterns = [
r"\s*Is there anything else I can (?:help|assist) you with\??$",
r"\s*Feel free to ask if you have (?:more|any|other) questions\.?$",
r"\s*Let me know if you need (?:anything else|further (?:help|clarification))\.?$",
]
for pattern in closing_patterns:
response = re.sub(pattern, "", response, flags=re.IGNORECASE | re.DOTALL)
return response.strip()
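A minimal way to wire the two helpers together over a raw example list (the load_raw_examples loader is hypothetical; the instruction/input/output field names follow the Alpaca-style schema used later in this lesson):

```python
raw_examples: list[dict] = load_raw_examples()  # hypothetical loader for your raw data

cleaned = []
for ex in raw_examples:
    response = clean_response_text(ex["output"])
    ok, reason = is_valid_instruction_example(
        ex["instruction"], response, ex.get("input")
    )
    if ok:
        cleaned.append({**ex, "output": response})
    # else: log `reason` for later inspection of rejected examples

print(f"Kept {len(cleaned)}/{len(raw_examples)} examples")
```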
Code Examples - Training Pipeline
Dataset Formatting for Llama 3
# Format instruction dataset for Llama 3 chat template
from transformers import AutoTokenizer
from datasets import Dataset
import json
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
def format_example_llama3(
instruction: str,
response: str,
system_prompt: str = "You are a helpful, accurate, and thoughtful assistant.",
input_context: str = "",
) -> str:
"""
Format a single instruction-response pair using Llama 3 chat template.
Returns the full formatted string with special tokens.
"""
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": instruction if not input_context else f"{instruction}\n\n{input_context}"
},
{"role": "assistant", "content": response},
]
# apply_chat_template adds <|begin_of_text|>, <|start_header_id|>, etc.
formatted = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False, # False for training, True for inference
)
return formatted
def tokenize_and_mask(
example: dict,
tokenizer,
max_length: int = 2048,
) -> dict:
"""
Tokenize a formatted example and create loss mask.
Only compute loss on assistant tokens, not on system/user tokens.
This is critical for instruction tuning - you want the model to learn
to generate the response, not to predict the instruction.
"""
text = example["text"]
tokens = tokenizer(
text,
max_length=max_length,
truncation=True,
padding="max_length",
return_tensors=None,
)
input_ids = tokens["input_ids"]
labels = input_ids.copy()
# Find where the assistant response starts
# In Llama 3: assistant turn starts after <|start_header_id|>assistant<|end_header_id|>\n\n
assistant_token_ids = tokenizer.encode(
"<|start_header_id|>assistant<|end_header_id|>\n\n",
add_special_tokens=False,
)
# Mask everything before the assistant response with -100
# -100 is ignored by PyTorch's CrossEntropyLoss
assistant_start = None
    for i in range(len(input_ids) - len(assistant_token_ids) + 1):
if input_ids[i:i + len(assistant_token_ids)] == assistant_token_ids:
assistant_start = i + len(assistant_token_ids)
break
if assistant_start is not None:
labels[:assistant_start] = [-100] * assistant_start
    # Mask padding tokens. pad_token is set to eos_token, so comparing token
    # ids would also mask the genuine EOS that ends the response; use the
    # attention mask instead (it is 0 only on padding positions).
    for i, attn in enumerate(tokens["attention_mask"]):
        if attn == 0:
            labels[i] = -100
return {
"input_ids": input_ids,
"attention_mask": tokens["attention_mask"],
"labels": labels,
}
# Load and format a dataset
def prepare_instruction_dataset(
raw_examples: list[dict],
tokenizer,
max_length: int = 2048,
num_proc: int = 8,
) -> Dataset:
"""
Full pipeline: raw examples -> tokenized dataset ready for Trainer.
"""
# Step 1: Format as chat template strings
formatted = [
{
"text": format_example_llama3(
instruction=ex["instruction"],
response=ex["output"],
input_context=ex.get("input", ""),
)
}
for ex in raw_examples
]
dataset = Dataset.from_list(formatted)
# Step 2: Tokenize with loss masking
tokenized = dataset.map(
lambda x: tokenize_and_mask(x, tokenizer, max_length),
num_proc=num_proc,
remove_columns=["text"],
desc="Tokenizing",
)
return tokenized
Multi-Node Training on a 2-Node A100 Cluster
# launch_training.py - Run with:
# torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0
# --master_addr=NODE0_IP --master_port=29500 launch_training.py
# (run the same command on both nodes, changing only --node_rank: 0 on node 0, 1 on node 1)
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model
import torch
import os
MODEL_ID = "meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR = "/shared_storage/llama3-8b-instruct-v1" # must be shared NFS/EFS
def main():
# Model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# LoRA config - using LoRA for this 2-node example
# For full fine-tuning, remove LoRA and add DeepSpeed ZeRO-3 config
lora_config = LoraConfig(
r=32,
lora_alpha=64,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Training arguments optimized for 2-node, 16x A100 80GB
training_args = TrainingArguments(
output_dir=OUTPUT_DIR,
# --- Batch size configuration ---
# Total effective batch = per_device * grad_accum * num_gpus
# = 4 * 4 * 16 = 256
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
# --- Optimization ---
num_train_epochs=3,
learning_rate=2e-4, # Higher for LoRA
lr_scheduler_type="cosine",
warmup_ratio=0.03,
weight_decay=0.001,
max_grad_norm=1.0,
# --- Precision ---
bf16=True,
        tf32=True,  # Ampere: use TF32 for any fp32 matmuls; bf16 ops are unaffected
# --- Memory ---
gradient_checkpointing=True,
# --- Multi-node distributed ---
# torchrun handles WORLD_SIZE, RANK, LOCAL_RANK automatically
# No need to set ddp_* flags when using torchrun
dataloader_num_workers=4,
dataloader_pin_memory=True,
# --- Logging and saving ---
logging_steps=10,
logging_dir=f"{OUTPUT_DIR}/logs",
save_strategy="steps",
save_steps=200,
save_total_limit=3,
evaluation_strategy="steps",
eval_steps=200,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
# --- Reporting ---
report_to="wandb",
run_name="llama3-8b-instruct-v1",
)
# Load datasets (prepared with prepare_instruction_dataset above)
from datasets import load_from_disk
train_dataset = load_from_disk("/shared_storage/data/train")
eval_dataset = load_from_disk("/shared_storage/data/eval")
data_collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
trainer.train()
trainer.save_model(OUTPUT_DIR)
if __name__ == "__main__":
main()
#!/bin/bash
# multi_node_launch.sh - run this on NODE 0 and NODE 1 simultaneously
# NODE_RANK is 0 on first node, 1 on second node
NODE_RANK=${1:-0}
MASTER_ADDR="192.168.1.100" # IP of node 0
MASTER_PORT=29500
NPROC_PER_NODE=8
NNODES=2
torchrun \
--nproc_per_node=$NPROC_PER_NODE \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
launch_training.py
# Usage:
# On node 0: bash multi_node_launch.sh 0
# On node 1: bash multi_node_launch.sh 1
Gradient Accumulation for Large Effective Batch Sizes
# Why gradient accumulation matters for instruction tuning
# The research consensus (from InstructGPT, FLAN, and community experiments):
# Effective batch sizes of roughly 128-512 examples consistently outperform
# smaller batch sizes for instruction tuning. Large batches provide:
# 1. More stable gradient estimates
# 2. Better handling of diverse task types in each batch
# 3. Less noise in loss signal when tasks vary widely in difficulty
# With limited GPU memory, gradient accumulation simulates large batches:
# Instead of batch_size=128, use batch_size=4 with grad_accum=32
# The gradient is accumulated over 32 forward passes before a weight update
# Memory peak: determined by batch_size=4 (small)
# Effective optimization: same as batch_size=128 (large)
# Calculating the right accumulation steps:
def calculate_grad_accumulation(
target_effective_batch: int,
per_device_batch_size: int,
num_gpus: int,
) -> int:
"""
target_effective_batch: the batch size you want for stable training
per_device_batch_size: what fits in GPU memory
num_gpus: total GPU count across all nodes
"""
total_per_step = per_device_batch_size * num_gpus
if target_effective_batch % total_per_step != 0:
raise ValueError(
f"target_effective_batch ({target_effective_batch}) must be "
f"divisible by per_device_batch * num_gpus ({total_per_step})"
)
return target_effective_batch // total_per_step
# Example: 2-node, 16 GPUs, 4 per device, target batch 256
grad_accum = calculate_grad_accumulation(
target_effective_batch=256,
per_device_batch_size=4,
num_gpus=16,
)
print(f"Gradient accumulation steps: {grad_accum}") # 4
Dataset Mixing for Multi-Task Instruction Tuning
# Mixing multiple instruction datasets with controlled proportions
# Critical for preventing task-specific overfitting
from datasets import Dataset, concatenate_datasets
import numpy as np
def create_mixed_dataset(
task_datasets: dict[str, Dataset],
mixing_weights: dict[str, float],
total_examples: int,
seed: int = 42,
) -> Dataset:
"""
Create a mixed dataset from multiple task-specific datasets.
Args:
task_datasets: dict mapping task name to its Dataset
mixing_weights: dict mapping task name to sampling weight (should sum to 1.0)
total_examples: total number of examples in the mixed dataset
seed: random seed for reproducibility
Returns:
Shuffled mixed dataset with specified proportions
"""
assert abs(sum(mixing_weights.values()) - 1.0) < 1e-6, "Weights must sum to 1.0"
rng = np.random.default_rng(seed)
sampled_datasets = []
for task_name, weight in mixing_weights.items():
n_samples = int(total_examples * weight)
source_dataset = task_datasets[task_name]
if n_samples <= len(source_dataset):
# Sample without replacement
indices = rng.choice(len(source_dataset), size=n_samples, replace=False)
sampled = source_dataset.select(indices)
else:
# Oversample with replacement (for small datasets)
indices = rng.choice(len(source_dataset), size=n_samples, replace=True)
sampled = source_dataset.select(indices)
# Add task label for analysis
sampled = sampled.add_column(
"task_type",
[task_name] * len(sampled)
)
sampled_datasets.append(sampled)
print(f" {task_name}: {len(sampled):,} examples ({weight:.1%})")
# Concatenate and shuffle
mixed = concatenate_datasets(sampled_datasets)
mixed = mixed.shuffle(seed=seed)
return mixed
# Example mixing strategy for a general instruction-following model
task_weights = {
"instruction_following": 0.30, # diverse single-turn instructions
"code_generation": 0.20, # code + explanation tasks
"math_reasoning": 0.15, # chain-of-thought math problems
"summarization": 0.10, # document summarization
"question_answering": 0.10, # factual and reading comprehension QA
"creative_writing": 0.05, # creative and open-ended tasks
"classification": 0.05, # text classification examples
"multi_turn_chat": 0.05, # multi-turn conversation examples
}
# Load your task-specific datasets
# task_datasets = {name: load_dataset(...) for name in task_weights}
# Create mixed training set of 500K examples
# mixed_dataset = create_mixed_dataset(
# task_datasets,
# task_weights,
# total_examples=500_000,
# )
LLM-as-Judge for Response Quality Scoring
# Use a capable model to score instruction-response quality
# This is how most production data curation pipelines work
from openai import OpenAI # or use a local model via vLLM
client = OpenAI() # or point to local vLLM endpoint
JUDGE_PROMPT = """You are evaluating the quality of an AI assistant's response to a user instruction.
Rate the response on a scale of 1-5:
1 - Completely wrong, off-topic, or harmful
2 - Partially relevant but missing key content or contains errors
3 - Adequate but could be significantly improved
4 - Good response that correctly addresses the instruction
5 - Excellent response that is accurate, complete, and well-formatted
Output ONLY a JSON object with:
- "score": integer 1-5
- "reason": one sentence explaining the score
Instruction: {instruction}
Response: {response}"""
def score_example(instruction: str, response: str) -> dict:
"""Score a single instruction-response pair using an LLM judge."""
prompt = JUDGE_PROMPT.format(instruction=instruction, response=response)
completion = client.chat.completions.create(
model="gpt-4o-mini", # Use cheapest capable model for bulk scoring
messages=[{"role": "user", "content": prompt}],
response_format={"type": "json_object"},
temperature=0.0,
max_tokens=100,
)
import json
result = json.loads(completion.choices[0].message.content)
return result
def filter_by_quality(
examples: list[dict],
min_score: float = 3.5,
batch_size: int = 100,
) -> list[dict]:
"""
Filter examples by LLM-judged quality score.
Returns only examples with average score >= min_score.
For large datasets (>100K), sample 5-10% for scoring and
use the scores to train a lightweight quality classifier.
"""
scored_examples = []
for i in range(0, len(examples), batch_size):
batch = examples[i:i + batch_size]
for ex in batch:
try:
result = score_example(ex["instruction"], ex["output"])
scored_examples.append({
**ex,
"quality_score": result["score"],
"quality_reason": result.get("reason", ""),
})
except Exception as e:
print(f"Scoring failed for example {i}: {e}")
# Keep example with default score if scoring fails
scored_examples.append({**ex, "quality_score": 3, "quality_reason": "scoring_failed"})
if i % 1000 == 0:
print(f"Scored {i}/{len(examples)} examples")
filtered = [ex for ex in scored_examples if ex["quality_score"] >= min_score]
print(f"Kept {len(filtered)}/{len(examples)} examples ({len(filtered)/len(examples):.1%})")
return filtered
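Quality scoring pairs naturally with deduplication, which the Common Mistakes section below also calls out. A minimal exact/near-exact pass, sketched here assuming the same instruction/output fields (fuzzy near-duplicates would need MinHash or embedding similarity instead):

```python
import hashlib
import re

def dedupe_examples(examples: list[dict]) -> list[dict]:
    """Drop examples whose normalized (instruction, output) pair was already seen."""
    seen: set[str] = set()
    unique = []
    for ex in examples:
        # Lowercase and collapse whitespace so trivial variants hash identically
        normalized = re.sub(
            r"\s+", " ", (ex["instruction"] + " " + ex["output"]).lower()
        ).strip()
        digest = hashlib.sha256(normalized.encode()).hexdigest()
        if digest not in seen:
            seen.add(digest)
            unique.append(ex)
    print(f"Kept {len(unique)}/{len(examples)} examples after exact dedup")
    return unique
```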
Diagrams (omitted): instruction tuning data pipeline; multi-node training architecture; data scaling behavior for instruction tuning.
Production Engineering Notes
Shared Storage Requirements for Multi-Node Training
The biggest operational challenge with multi-node training is not the training code - it is the storage setup. All nodes must be able to read the dataset and write checkpoints to a shared location.
The common options:
- AWS EFS or equivalent NFS: works reliably with moderate throughput. Mount at /shared on all nodes.
- S3 with fsspec streaming: read the dataset directly from S3 via Hugging Face datasets' load_from_disk. Slower, but needs no shared filesystem.
- Pre-copy to each node: copy the dataset to each node's local NVMe before training. Fastest reads, but requires coordination.
For checkpoints, shared NFS is the simplest. Rank 0 writes, all ranks can read if resuming. Use save_total_limit=3 to avoid filling the filesystem.
Evaluation During Training
Monitoring eval loss is necessary but insufficient for instruction tuning. Eval loss can decrease while model quality (as measured by human preference or task benchmarks) plateaus or even regresses.
Complement eval loss with:
- MT-Bench: 80 multi-turn questions across 8 categories, scored by GPT-4. Run every 2-3 checkpoints.
- Task-specific holdout: Hold out 5% of each task category, track accuracy or F1 per task.
- Generation samples: Log 10-20 generation samples per checkpoint for qualitative review.
import torch
from transformers import TrainerCallback
class InstructEvalCallback(TrainerCallback):
"""Run MT-Bench style evaluation on a subset of prompts every N steps."""
def __init__(self, eval_prompts: list[str], tokenizer, every_n_steps: int = 500):
self.eval_prompts = eval_prompts
self.tokenizer = tokenizer
self.every_n_steps = every_n_steps
def on_step_end(self, args, state, control, model=None, **kwargs):
if state.global_step % self.every_n_steps != 0:
return
model.eval()
samples = []
for prompt in self.eval_prompts[:5]: # Run on 5 prompts for speed
inputs = self.tokenizer(
prompt,
return_tensors="pt",
).to(model.device)
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=256,
                    do_sample=False,  # greedy decoding; temperature is irrelevant when sampling is off
)
response = self.tokenizer.decode(
output_ids[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
samples.append({"prompt": prompt, "response": response})
# Log to wandb
import wandb
if wandb.run:
wandb.log({
"generation_samples": wandb.Table(
columns=["prompt", "response"],
data=[[s["prompt"], s["response"]] for s in samples],
),
"step": state.global_step,
})
model.train()
Handling Long Sequences Without OOM
Instruction tuning datasets vary wildly in sequence length. Code generation examples might be 3,000 tokens. Simple classification examples might be 50 tokens. Padding a batch of 50-token examples to 3,000 tokens wastes 98% of compute.
Two strategies:
Packing (sequence packing): Concatenate multiple short examples into a single sequence up to max_length, separated by EOS tokens. SFTTrainer in TRL supports this with packing=True. Best throughput but can cause cross-contamination if EOS handling is wrong.
Dynamic padding per batch: Sort examples by length, batch similar-length examples together. Use DataCollatorForSeq2Seq with pad_to_multiple_of=8 and padding=True (not max_length). Reduces wasted compute significantly.
# In SFTTrainer, enable packing for high-throughput training
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
model=model,
args=SFTConfig(
# ... other args ...
packing=True, # pack short examples together
max_seq_length=2048, # pack up to this length
),
train_dataset=dataset,
tokenizer=tokenizer,
)
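For the dynamic-padding strategy, the Hugging Face Trainer ships a built-in length-grouped sampler: set group_by_length=True on TrainingArguments (or on SFTConfig, which subclasses it). A minimal sketch, with the other arguments elided:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    # ... other args as in the multi-node example ...
    group_by_length=True,  # batch similar-length examples to minimize padding
)
```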
Learning Rate and Warmup for Instruction Tuning
Instruction tuning is brittle to learning rate. Too high: catastrophic forgetting of base model capabilities. Too low: slow convergence, model might not fully learn the instruction format.
General guidelines:
- LoRA: 1e-4 to 3e-4, cosine decay, 3-5% warmup
- Full fine-tuning: 1e-5 to 5e-5, cosine decay, 3% warmup
- Warmup is more important for instruction tuning than for pretraining - the instruction following behavior seems to need a stable initialization period
Use linear warmup for the first 3% of steps, then cosine decay to 10% of the peak LR rather than to zero. A non-zero floor keeps late-stage updates from vanishing entirely, so the model keeps learning from the final portion of the data.
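Recent transformers releases expose this floor directly via the cosine_with_min_lr scheduler; assuming your installed version includes it (worth verifying against your release), the configuration looks roughly like:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    learning_rate=2e-4,
    warmup_ratio=0.03,                         # linear warmup over the first ~3% of steps
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr_rate": 0.1},  # decay to 10% of peak, not to zero
)
```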
Common Mistakes
:::danger Not Masking Instruction Tokens in Loss Computation
The most common technical mistake in instruction tuning: computing loss on both the instruction and the response tokens. When you train on the full sequence (instruction + response), the model learns to predict the instruction tokens too. This wastes model capacity on tokens you never generate at inference time and can confuse the model about what it is supposed to generate. Always set labels to -100 for all tokens up to and including the instruction/user turn. Only compute loss on the assistant's response tokens.
:::
:::danger Using a Single Task Type Dataset
Training on only one task type (e.g., only summarization) will produce a model that follows instructions well for that task and poorly for everything else. The FLAN insight is not that instruction tuning works - it is that task diversity is what makes instruction tuning generalize. If your use case requires only one task, task-specific fine-tuning is more efficient than instruction tuning anyway. If you want general instruction following, you need coverage across at least 10 distinct task categories.
:::
:::warning Applying the Chat Template Inconsistently
Every model has a specific chat template (the arrangement of special tokens, role markers, and turn delimiters). If you apply the template inconsistently across your dataset - some examples with system prompts, some without, some with different delimiters - the model will produce inconsistent outputs at inference. Use tokenizer.apply_chat_template() for every single example in your dataset. Never manually construct the template string.
:::
:::warning Evaluating Only on the Training Distribution
It is easy to achieve low eval loss while the model fails on tasks outside the training distribution. This is the instruction tuning version of overfitting: the model learns to follow the specific instructions it saw during training but does not generalize to new instruction phrasings or new task types. Always evaluate on at least one benchmark that includes task types not in your training data. MT-Bench and MMLU are good choices. If your eval metrics only measure the training distribution, you will not detect generalization failures until they appear in production.
:::
:::warning Forgetting That Data Quality Beats Quantity Above 10K Examples
Many teams spend engineering effort scraping and cleaning large instruction datasets when the return on that effort is low. The LIMA paper (Zhou et al. 2023) showed 1,000 carefully curated examples outperform 50,000 noisy examples. Above 10K examples, your time is better spent on quality filtering (LLM-as-judge scoring, deduplication, task diversity analysis) than on raw data collection. A dataset of 20,000 high-quality examples will outperform 200,000 examples with no quality filtering.
:::
Interview Q&A
Q1: What was the key insight from the FLAN paper and why did it change how people thought about fine-tuning?
The FLAN paper (Wei et al. 2021) showed that training a model on many different NLP tasks formatted as natural language instructions produced zero-shot generalization to new, unseen tasks. The key insight was that it is not the specific tasks in the training set that matter - it is the meta-skill of instruction following that transfers.
Before FLAN, the assumption was that you fine-tuned a model for a specific task. FLAN showed that fine-tuning on a curriculum of diverse tasks teaches the model what "following instructions" means as a general capability. This generalized instruction-following then transfers to new task types without any additional examples.
The implication: you do not need examples of every possible task your model will encounter. You need enough task diversity during fine-tuning to teach the general skill. This insight is the foundation of every chat model and instruction-following LLM that followed.
Q2: How does loss masking work in instruction tuning, and why is it necessary?
Loss masking sets the target labels to -100 for all instruction tokens, telling PyTorch's CrossEntropyLoss to ignore those positions when computing the gradient. Only the assistant response tokens contribute to the loss.
This is necessary because at inference time, you provide the instruction and generate only the response. If you train on both instruction and response tokens, the model optimizes for predicting instruction tokens - which it never does at inference. Worse, it can create a feedback loop where the model learns that the presence of certain instruction patterns predicts specific responses, rather than learning to follow the semantic content of the instruction.
Practically: models trained without loss masking often have degraded instruction-following behavior and tend to "echo" parts of the instruction in the response. The fix is simple: apply loss masking correctly, always.
Q3: Describe the scaling behavior of instruction tuning. At what data scale does quality start to matter more than quantity?
Instruction tuning follows a four-phase curve. Phase 1 (0-1K examples): rapid improvement as the model learns basic instruction-following behavior - it learns what instructions look like and how to format responses. Phase 2 (1K-10K examples): continued strong gains, primarily driven by adding new task types. Phase 3 (10K-100K examples): diminishing returns per example, but new task categories still help. Phase 4 (100K+ examples): quality becomes the dominant factor. Adding low-quality examples actively hurts at this scale.
The transition point where quality beats quantity is approximately 10,000 examples. LIMA (Zhou et al. 2023) demonstrated this empirically: 1,000 carefully selected, human-written examples fine-tuned a 65B model to strong multi-task performance. The intuition: instruction tuning teaches behavior, not knowledge. Once the behavior pattern is learned (which requires relatively few high-quality examples), additional low-quality examples introduce noise without adding signal.
Practically: for datasets up to 10K, focus on coverage (task diversity). For datasets above 10K, focus on quality filtering - LLM-as-judge scoring, deduplication, and task balance matter more than raw data collection.
Q4: How do you set up multi-node training for instruction tuning? What are the key infrastructure requirements?
Multi-node training with PyTorch uses torchrun (formerly torch.distributed.launch) to launch processes across nodes. Each node runs torchrun --nproc_per_node=N --nnodes=M --node_rank=K --master_addr=NODE0_IP with a different node_rank per node.
Key infrastructure requirements:
- Shared storage (NFS/EFS): all nodes must read the same dataset and write checkpoints to a shared path. Without this, nodes will diverge or crash when the master writes a checkpoint.
- Network bandwidth: NCCL AllReduce for gradient synchronization uses the inter-node network. For 16x A100s, you need at least 100 Gbps InfiniBand or equivalent. AWS p4d instances have 400 Gbps EFA for this reason.
- Port accessibility: --master_port (default 29500) must be reachable from all nodes. Security groups or firewall rules must allow this.
- Clock synchronization: all nodes should have synchronized clocks (NTP). Clock skew can cause mysterious NCCL timeouts.
Common failure mode: checkpoints fail because the checkpoint directory does not exist on the shared storage, or the NFS mount is read-only from worker nodes.
Q5: What is the difference between instruction fine-tuning and chat fine-tuning, and when does the distinction matter in practice?
Instruction fine-tuning trains on single-turn (instruction, response) pairs. The model learns to map one instruction to one response. It does not explicitly learn conversational dynamics - turn-taking, context maintenance across turns, graceful handling of corrections.
Chat fine-tuning trains on multi-turn conversation sequences. The model learns the full conversation structure: system prompt behavior, how to maintain context from previous turns, how to handle ambiguity by asking clarifying questions, how to incorporate corrections.
The distinction matters when your production use case is genuinely conversational. A customer support chatbot that needs to remember context across a 10-turn conversation requires multi-turn training data. A document processing pipeline that issues single commands to the model does not.
Common mistake: training on only instruction pairs and deploying in a chat interface. The model will often fail to maintain context across turns, ignore the conversation history, or produce responses that do not acknowledge previous turns. Fix: include at least 10-15% multi-turn examples in your training data even if single-turn is the primary use case, as a guard against context confusion.
Q6: Your instruction-tuned model performs well on your eval set but poorly in production. What are the first three things you check?
First, check for distribution shift between your eval set and production traffic. If your eval set was sampled from the training data distribution, it does not measure generalization. Pull 100 production examples and hand-evaluate them against your quality criteria. Compare the instruction phrasings in production to what was in training - if users phrase instructions differently than your training data, the model may fail even on tasks it handles well in eval.
Second, check for the top-k failure modes in your production logs. Cluster failed examples by similarity. Are they all from one task type? One instruction style? One topic area? If failures cluster, you have a coverage gap in your training data for that cluster.
Third, check whether the model's failures are errors in knowledge (it says the wrong thing) or errors in instruction following (it does the wrong thing). Knowledge errors require adding domain data. Instruction following errors require more diverse instruction examples or better loss masking. Conflating these leads to adding wrong data for the wrong problem.
Q7: Walk through the gradient accumulation calculation for a 2-node, 16 GPU setup targeting an effective batch size of 256.
Effective batch size = per_device_batch_size x gradient_accumulation_steps x num_gpus.
We have 16 GPUs total. If we can fit per_device_batch_size=4 in memory (4 x 2048 tokens x bf16 activations), then without gradient accumulation, our effective batch is 4 x 1 x 16 = 64. To reach effective batch size 256, we need gradient_accumulation_steps = 256 / (4 x 16) = 4.
So: per_device_batch_size=4, gradient_accumulation_steps=4, 16 GPUs, effective batch = 256.
What this means operationally: each GPU runs 4 forward-backward passes without synchronizing gradients. After 4 passes (gradient accumulation), NCCL AllReduce runs across all 16 GPUs to average the accumulated gradients, then the optimizer step updates the weights. This reduces the number of AllReduce calls by 4x, improving efficiency when inter-node bandwidth is a bottleneck.
