Supervised Fine-Tuning
The Production Fine-Tuning Problem
A fintech company has deployed a general-purpose LLM to answer customer questions. The model is decent: it answers questions coherently and does not hallucinate too badly. But it keeps doing two things wrong. First, it gives investment advice with confident numbers that are subtly incorrect. Second, it apologizes too much - every answer starts with "I'm sorry but" as if the model is afraid of being wrong.
The team has 2,000 carefully human-annotated examples of ideal answers: factually correct, confident in tone, specific to their domain. They do not want to retrain the model from scratch ($5 million+). They do not want to use RLHF (complex, unstable, requires a reward model). They want to do something targeted: take this specific model and move it toward their specific behavior using these specific examples.
This is supervised fine-tuning (SFT): continue training the pretrained model on your labeled (instruction, response) pairs, optimizing the same language modeling objective on your data. It sounds simple, and in many ways it is - but the details matter enormously.
Three weeks later, the team has a model that gives confident, accurate financial answers with near-zero apologetic preamble. They used 2,000 examples, two epochs, a learning rate of 2e-5, and a single A100. The whole run cost $47. The improvement in customer satisfaction scores was 23%.
Why This Exists: The Gap Between Pretraining and Task Performance
A pretrained LLM knows an enormous amount. It knows grammar, facts, reasoning patterns, writing styles. But it is trained to continue text, not to solve specific problems on demand. Ask a base model "What is the capital of France?" and it might respond with a continuation of the text style it thinks this belongs to - perhaps "What is the capital of France? This is a common geography question asked in schools across..." rather than just "Paris."
Supervised fine-tuning bridges this gap. By showing the model thousands of (question, answer) pairs, you teach it what the expected format and style of a response looks like for your use case. The model already knows the answer - it knows that Paris is the capital of France. Fine-tuning teaches it how to express that knowledge in the form you want.
The key insight: fine-tuning is much more about format and style than about knowledge. The base model's knowledge is baked in from pretraining. Fine-tuning shapes how that knowledge is expressed.
Historical Context: From Task-Specific to General Fine-Tuning
2018 - BERT (Devlin et al.) demonstrated that fine-tuning a pretrained model on task-specific labeled data outperforms training task-specific models from scratch. The field shifted to "pretrain then fine-tune" as the default paradigm.
2019-2021 - Fine-tuning BERT-family models on NLU benchmarks (GLUE, SuperGLUE, SQuAD) became standard practice. But this was task-specific - a different fine-tuned model for each task.
2021 - FLAN (Wei et al.) showed that fine-tuning on a diverse mixture of tasks (instruction tuning) produced a model that generalized to new tasks zero-shot. This opened the door to general-purpose fine-tuning.
2022 - InstructGPT (Ouyang et al.) showed that fine-tuning GPT-3 on high-quality human demonstrations (followed by RLHF) produced a model dramatically better at following instructions. Human raters preferred outputs from the 1.3B-parameter InstructGPT over outputs from the 175B-parameter GPT-3.
2023 - Alpaca (Stanford), Vicuna (Berkeley), and dozens of open-source SFT models showed that a 7B or 13B model fine-tuned on high-quality instruction data could approach GPT-3.5 quality on many tasks.
The Data is Everything
The most important lesson from years of SFT research: data quality dominates data quantity.
The Alpaca paper (Taori et al., 2023) fine-tuned LLaMA-7B on just 52,000 instruction-following examples generated by GPT-3.5, producing a model that performed surprisingly well on instruction following. The LIMA paper (Zhou et al., 2023) went further: fine-tuning on just 1,000 carefully selected examples produced a model competitive with RLHF-trained models. Their conclusion: "Almost all knowledge in large language models is learned during pretraining, and only limited instruction tuning data is necessary to teach models to produce high quality output."
What makes data high-quality?
- Diversity of tasks and domains: The model should see classification, extraction, generation, reasoning, summarization, coding, and domain-specific tasks. A fine-tuning dataset containing only customer-service Q&A will produce a narrow model.
- Correct and complete responses: Incorrect responses are worse than no data - they teach the model wrong answers. Verify each example.
- Appropriate response length: Responses should match what a real expert would say - neither too terse (loses helpfulness) nor too verbose (wastes tokens and teaches padding).
- Consistent format: If your deployment uses a specific prompt template, your training data must use that exact template.
- No instruction-following failures: Filter out examples where the response does not actually follow the instruction. These confuse the model about what the task is.
The SFT Training Objective
SFT uses the same cross-entropy language modeling loss as pretraining, but with one important modification: you typically only compute loss on the response tokens, not on the instruction/prompt tokens.
Naive approach - compute loss on everything:
Input: "Answer the question: What is the capital of France? The capital of France is Paris."
Labels: the same token sequence (Hugging Face causal LMs shift labels by one position internally)
Loss is computed on every token, including "Answer the question: What is the capital of France?". You are training the model to predict the prompt too - this is wasteful and sometimes harmful.
Better approach - compute loss only on the response:
Input: "Answer the question: What is the capital of France? The capital of France is Paris."
Labels: [-100, -100, ..., -100, "The", "capital", "of", "France", "is", "Paris", "."] (response token ids shown as text)
Set all prompt token labels to -100 (the ignore index). Only the response tokens contribute to the loss. This is cleaner - the model is trained to produce good responses, not to predict the prompt.
def create_training_labels(
    input_ids,  # Full sequence (prompt + response)
    prompt_length,  # Number of prompt tokens
    ignore_index=-100,
):
    """Create labels that only train on response tokens."""
    labels = input_ids.clone()
    labels[:prompt_length] = ignore_index  # Ignore prompt tokens
    return labels
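To see what the -100 labels do inside the loss, here is a minimal pure-Python sketch of a masked causal language-modeling loss. It assumes per-position log-probabilities are already available; a real model produces logits and applies log-softmax. Hugging Face causal LM models perform the one-position shift internally when you pass labels, and PyTorch's cross_entropy ignores index -100 by default. The function name here is illustrative only.

```python
import math
from typing import List

IGNORE_INDEX = -100  # the sentinel PyTorch's cross_entropy skips by default


def masked_causal_lm_loss(log_probs: List[List[float]], labels: List[int]) -> float:
    """Mean next-token negative log-likelihood over unmasked positions.

    log_probs[t][v] is the log-probability the model assigns to vocabulary
    item v after reading input tokens 0..t. labels has one entry per input
    token, with prompt positions already set to IGNORE_INDEX.
    """
    if len(log_probs) != len(labels):
        raise ValueError("need one log-prob row per input token")
    # Causal shift: the prediction at position t is scored against token t + 1,
    # so the last prediction and the first label have no partner.
    shifted_preds = log_probs[:-1]
    shifted_labels = labels[1:]
    total, count = 0.0, 0
    for row, target in zip(shifted_preds, shifted_labels):
        if target == IGNORE_INDEX:
            continue  # prompt (or padding) token: excluded from sum AND count
        total -= row[target]
        count += 1
    return total / count if count else 0.0


# Toy example: vocabulary of 3 tokens, sequence [P0, P1, R0, R1], prompt_length = 2.
uniform = [math.log(1 / 3)] * 3
confident = [math.log(0.98), math.log(0.01), math.log(0.01)]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 1]
loss_a = masked_causal_lm_loss([confident, uniform, uniform, uniform], labels)
loss_b = masked_causal_lm_loss([uniform, uniform, uniform, uniform], labels)
print(loss_a, loss_b)  # identical: row 0 only predicts a prompt token
```

Masked positions are skipped in both the sum and the count, so the loss is an average over response tokens only. Changing the model's predictions for prompt tokens leaves it unchanged.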
Fine-Tuning Hyperparameters
Fine-tuning hyperparameters are very different from pretraining hyperparameters. The key difference: you are making small adjustments to an already-well-trained model, not training from scratch.
Learning rate: 1e-5 to 5e-5 for full fine-tuning. This is roughly 6-100x smaller than typical pretraining learning rates (3e-4 to 1e-3). Too high and you destroy the pretrained knowledge. The LIMA paper used 1e-5, decayed linearly to 1e-6 over training.
Epochs: 1 to 3 epochs is the usual default. More than 3 epochs often hurts - the model memorizes the training examples rather than generalizing. Unlike pretraining (roughly 1 epoch over massive data), fine-tuning datasets are small enough to train on multiple times, but not many times.
Batch size: 8 to 128. Larger batches give less noisy gradient estimates. When you increase the batch size, the linear scaling rule suggests increasing the learning rate proportionally. The LIMA paper used batch size 32.
Warmup: 3-6% of total steps. Start with a very small learning rate and ramp up to the target LR. This prevents early destructive updates to pretrained weights.
Sequence length: Use the maximum sequence length your use case requires. Do not pad to the model's maximum (often 4096 or 8192) if your actual examples are 512 tokens - you waste compute.
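Weight decay: 0.0 to 0.1 is the common range, and 0.1 is a frequent choice (LIMA used it). Note that AdamW's decoupled weight decay pulls weights toward zero, not toward their pretrained values. It therefore acts as a mild regularizer, not as a direct guard against forgetting.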
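The arithmetic behind these settings is easy to get wrong once gradient accumulation is involved. Here is a minimal sketch, assuming one optimizer step per effective batch and a linear-warmup, cosine-decay schedule (the shape the Trainer's lr_scheduler_type="cosine" produces). The helper names are hypothetical. The Trainer's own step count may differ slightly in how it rounds a final partial batch.

```python
import math
from typing import Tuple


def schedule_summary(
    num_examples: int,
    per_device_batch: int,
    grad_accum_steps: int,
    num_devices: int,
    epochs: int,
    warmup_ratio: float,
) -> Tuple[int, int, int]:
    """Return (effective_batch, total_optimizer_steps, warmup_steps)."""
    effective_batch = per_device_batch * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps


def lr_at_step(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to peak_lr, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


def linear_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Linear scaling rule: scale the learning rate in proportion to batch size."""
    return base_lr * new_batch / base_batch


# The fintech run from the introduction (2,000 examples, two epochs), assuming
# per-device batch 4 and 8 accumulation steps on one GPU, with 5% warmup.
effective, total, warmup = schedule_summary(2000, 4, 8, 1, 2, 0.05)
print(effective, total, warmup)  # 32 126 7
```

With these assumptions, the learning rate reaches its 2e-5 peak after 7 of the 126 optimizer steps and decays to zero by the last step.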
Catastrophic Forgetting
When you fine-tune a pretrained model on your specific task, there is a risk that the model forgets things it knew from pretraining. This is catastrophic forgetting - training on new data overwrites previously learned patterns.
For SFT with small, high-quality datasets, this risk is real. If you fine-tune only on financial Q&A, the model may "forget" how to write code or do general reasoning. Signs of catastrophic forgetting:
- Performance on general benchmarks drops significantly
- The model's vocabulary becomes repetitive (overfit to your domain's style)
- Out-of-domain queries receive strange responses
Mitigations:
- Small learning rate: keeps updates small, limits drift from pretrained weights
- Fewer epochs: 1-2 epochs, not 10
- Data mixing: include some general instruction-following data (e.g., FLAN) alongside your domain-specific data, even at a small fraction (5-10%)
- LoRA: instead of updating all weights, only update small adapter matrices. The pretrained weights are preserved exactly. (Covered in Lesson 07)
- Elastic Weight Consolidation (EWC): penalize updates to parameters that were important for previous tasks. Rarely used in practice due to complexity.
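The EWC penalty itself is a single quadratic term: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2. Here theta* are the weights before fine-tuning, and F_i is a diagonal Fisher-information estimate of how important weight i was for the original task. Below is a minimal sketch over plain Python floats standing in for flattened parameters. It assumes F is estimated elsewhere (typically from squared gradients on the original task's data), and the function names are illustrative.

```python
from typing import Sequence


def ewc_penalty(
    params: Sequence[float],
    anchor_params: Sequence[float],
    fisher_diag: Sequence[float],
    lam: float,
) -> float:
    """(lam / 2) * sum_i F_i * (theta_i - theta_star_i) ** 2."""
    if not (len(params) == len(anchor_params) == len(fisher_diag)):
        raise ValueError("params, anchor_params and fisher_diag must align")
    return 0.5 * lam * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher_diag)
    )


def ewc_loss(
    task_loss: float,
    params: Sequence[float],
    anchor_params: Sequence[float],
    fisher_diag: Sequence[float],
    lam: float,
) -> float:
    """Fine-tuning loss plus the quadratic pull toward the anchor weights."""
    return task_loss + ewc_penalty(params, anchor_params, fisher_diag, lam)


# Weight 0 is important (F = 10) but unchanged; weight 1 moved by 1.0 but matters little.
print(ewc_penalty([1.0, 2.0], [1.0, 1.0], [10.0, 0.5], lam=2.0))  # 0.5
```

The Fisher weights are what make EWC selective: moving an important weight costs far more than moving an unimportant one by the same amount. Storing an extra copy of the weights plus a Fisher estimate is part of the complexity that keeps it rare in practice.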
When SFT is Enough vs When You Need RLHF
SFT is sufficient when:
- You have high-quality labeled examples of desired behavior
- The target behavior can be expressed in supervised demonstrations
- Safety and alignment are not the primary concern
- The task has a relatively clear ground truth (code generation, factual QA, classification)
You likely need RLHF (or DPO) when:
- Human preferences are hard to express as demonstrations (what makes one response "better" than another?)
- Safety is critical - SFT on safe examples does not necessarily produce a safe model
- You want to optimize for subtle qualities like helpfulness, harmlessness, and honesty simultaneously
- The task requires the model to balance multiple objectives (be helpful but not harmful)
The practical heuristic: start with SFT. If evaluation shows the model is "following instructions" correctly but still producing subtly wrong outputs, add RLHF or DPO. For most production use cases, good SFT on quality data gets you most of the way there.
Code: Full Fine-Tuning with HuggingFace Trainer
"""
Full supervised fine-tuning of an instruction-following LLM.
- Prompt-token masking plus a custom padding collator
- Training on instruction-response pairs
- Evaluation with held-out set
- Gradient checkpointing for memory efficiency
"""
import torch
from dataclasses import dataclass
from typing import Dict, List, Optional
from torch import Tensor
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
PreTrainedTokenizer,
)
from datasets import Dataset
# ---- Data preparation ----
PROMPT_TEMPLATE = """Below is an instruction that describes a task.
Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""
def format_example(example: dict, tokenizer: PreTrainedTokenizer, max_length: int = 2048) -> dict:
    """
    Format a single (instruction, response) example.
    Returns input_ids, attention_mask and labels, with prompt tokens masked to -100.
    """
    prompt = PROMPT_TEMPLATE.format(instruction=example["instruction"])
    full_text = prompt + example["response"] + tokenizer.eos_token

    # Tokenize full text (this adds BOS for tokenizers that use one)
    tokenized = tokenizer(
        full_text,
        truncation=True,
        max_length=max_length,
        padding=False,
        return_tensors=None,
    )

    # Find where the response starts. Tokenize the prompt with the SAME
    # special-token settings, so a BOS prepended to full_text is counted as
    # part of the prompt. Assumes the tokenizer adds special tokens only at the
    # start (as LLaMA-style tokenizers do) and that no token spans the
    # prompt/response boundary; the template's trailing newline usually
    # keeps them apart.
    prompt_length = len(tokenizer(prompt)["input_ids"])

    # Create labels: -100 for prompt tokens, actual ids for response tokens.
    # min() matters when truncation cut into the prompt: slice-assigning a
    # longer list would make labels longer than input_ids. Fully masked
    # examples carry no training signal and are worth filtering out.
    labels = list(tokenized["input_ids"])
    n_masked = min(prompt_length, len(labels))
    labels[:n_masked] = [-100] * n_masked

    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "labels": labels,
    }


@dataclass
class SFTDataCollator:
    """
    Pad sequences in a batch to a common length. Padding positions get
    attention_mask 0 and label -100, so they never contribute to the loss.
    """
    tokenizer: PreTrainedTokenizer
    padding_side: str = "right"

    def __call__(self, features: List[Dict]) -> Dict[str, Tensor]:
        max_len = max(len(f["input_ids"]) for f in features)
        pad_id = self.tokenizer.pad_token_id
        batch_input_ids = []
        batch_attention_mask = []
        batch_labels = []
        for feature in features:
            padding_length = max_len - len(feature["input_ids"])
            if self.padding_side == "right":
                input_ids = feature["input_ids"] + [pad_id] * padding_length
                attention_mask = feature["attention_mask"] + [0] * padding_length
                labels = feature["labels"] + [-100] * padding_length
            else:  # left padding
                input_ids = [pad_id] * padding_length + feature["input_ids"]
                attention_mask = [0] * padding_length + feature["attention_mask"]
                labels = [-100] * padding_length + feature["labels"]
            batch_input_ids.append(input_ids)
            batch_attention_mask.append(attention_mask)
            batch_labels.append(labels)
        return {
            "input_ids": torch.tensor(batch_input_ids),
            "attention_mask": torch.tensor(batch_attention_mask),
            "labels": torch.tensor(batch_labels),
        }


# ---- Training setup ----
def run_sft(
    model_name: str,
    train_data: List[dict],
    eval_data: List[dict],
    output_dir: str,
):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_cache=False,  # KV cache is unused in training and incompatible with gradient checkpointing
    )
    model.gradient_checkpointing_enable()

    # Prepare datasets
    train_dataset = Dataset.from_list(train_data)
    eval_dataset = Dataset.from_list(eval_data)
    train_dataset = train_dataset.map(
        lambda x: format_example(x, tokenizer),
        remove_columns=["instruction", "response"],
    )
    eval_dataset = eval_dataset.map(
        lambda x: format_example(x, tokenizer),
        remove_columns=["instruction", "response"],
    )
    data_collator = SFTDataCollator(tokenizer=tokenizer)

    # Hyperparameters - calibrated for 7B model, 1K-10K examples
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,  # Effective batch size: 4 * 8 = 32
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,  # 5% warmup
        weight_decay=0.1,
        max_grad_norm=1.0,
        bf16=True,
        eval_strategy="epoch",  # named evaluation_strategy in older transformers releases
        save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
        load_best_model_at_end=True,
        logging_steps=10,
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return trainer


# ---- Using TRL SFTTrainer (simpler API) ----
def run_sft_with_trl(model_name: str, train_data: list, output_dir: str):
    """
    TRL's SFTTrainer masks the prompt only when the data has separate
    'prompt' and 'completion' fields (recent TRL releases then compute loss
    on the completion alone). With a single 'text' field it trains on every
    token, prompt included. TRL argument names have changed between
    releases, so check the documentation for your installed version.
    """
    from trl import SFTConfig, SFTTrainer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )

    # Prompt-completion format: the prompt is context, the completion is trained on
    def format_for_trl(example):
        return {
            "prompt": PROMPT_TEMPLATE.format(instruction=example["instruction"]),
            "completion": example["response"] + tokenizer.eos_token,
        }

    dataset = Dataset.from_list(train_data)
    dataset = dataset.map(format_for_trl, remove_columns=["instruction", "response"])

    sft_config = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        bf16=True,
        max_length=2048,  # named max_seq_length in older TRL releases
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )
    trainer.train()
    return trainer
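A cheap sanity check before any run is to decode exactly the tokens that carry labels and confirm they are the response. Below is a small sketch with illustrative helper names. It relies only on format_example returning parallel input_ids and labels lists.

```python
from typing import List

IGNORE_INDEX = -100


def trained_token_ids(input_ids: List[int], labels: List[int]) -> List[int]:
    """Return the token ids that actually contribute to the loss."""
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    return [tok for tok, lab in zip(input_ids, labels) if lab != IGNORE_INDEX]


def first_trained_position(labels: List[int]) -> int:
    """Index of the first unmasked label, or -1 if every label is masked."""
    for i, lab in enumerate(labels):
        if lab != IGNORE_INDEX:
            return i
    return -1


# Usage with a real tokenizer (not run here):
#   features = format_example(example, tokenizer)
#   ids = trained_token_ids(features["input_ids"], features["labels"])
#   print(repr(tokenizer.decode(ids)))  # should be the response followed by the EOS token
#   first_trained_position(features["labels"]) == -1  -> prompt filled max_length
```

If the decoded text starts with part of the template, the prompt mask is too short. If it is missing the first words of the response, the mask is too long.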
Evaluation
Evaluating fine-tuned models is harder than training them. Loss curves tell you the model is learning the training distribution, but they do not tell you if the model is actually good.
Loss curves: Watch for:
- Training loss decreasing smoothly: normal
- Validation loss diverging upward after epoch 1: overfitting - reduce epochs or use more data
- Both losses plateauing high: model not learning - check data format, learning rate, tokenizer
Qualitative evaluation: run your model on 20-50 held-out examples and read the outputs manually. This is irreplaceable. No benchmark tells you "the model is adding unnecessary apologies" or "the model confuses product SKUs."
Benchmark evaluation: run on standard benchmarks (MMLU, HellaSwag, TruthfulQA) to check for catastrophic forgetting. If fine-tuning dropped MMLU by more than 5 points, your learning rate is probably too high or you are training for too many epochs.
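Here is a small sketch of that check. It assumes you already have per-benchmark scores on a 0-100 scale for both the base and the fine-tuned model (for example, from the same evaluation harness run twice). The function name and the 5-point default simply encode the rule of thumb above.

```python
from typing import Dict, List


def forgetting_report(
    base_scores: Dict[str, float],
    tuned_scores: Dict[str, float],
    max_drop: float = 5.0,
) -> List[str]:
    """List benchmarks where the fine-tuned model fell more than max_drop points."""
    flagged = []
    for name, base in sorted(base_scores.items()):
        if name not in tuned_scores:
            continue  # benchmark not run on the fine-tuned model
        drop = base - tuned_scores[name]
        if drop > max_drop:
            flagged.append(f"{name}: {base:.1f} -> {tuned_scores[name]:.1f} (-{drop:.1f})")
    return flagged
```

An empty list means no benchmark crossed the threshold. Small drops still deserve a look if they cluster in one capability, such as math or code.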
MT-Bench and Arena: for instruction-following quality, two methods are widely used: LLM-as-judge evaluation, where a strong model such as GPT-4 rates the outputs, and arena comparisons, which collect blind pairwise preferences between the base model and the fine-tuned model.
Start small, evaluate fast: before running your full fine-tuning job, run a 100-step test on a small data sample to verify that (1) there are no OOM errors, (2) the loss is decreasing (not NaN), and (3) the output format looks correct. A 100-step smoke test takes a few minutes and catches most setup bugs before you commit to a multi-hour run.
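Checks (1) and (3) need the run itself or a human, but check (2) can be automated. Below is a minimal sketch that inspects logged training losses. It assumes you ran with max_steps=100 in TrainingArguments (which overrides num_train_epochs) on a subset such as train_dataset.select(range(500)), and that you collected the losses from trainer.state.log_history. The helper name and the 3-entry window are illustrative choices.

```python
import math
from typing import List


def smoke_test_problems(losses: List[float], window: int = 3) -> List[str]:
    """Return red flags found in a short run's logged training losses."""
    if not losses:
        return ["no losses were logged"]
    if any(math.isnan(x) or math.isinf(x) for x in losses):
        return ["loss became NaN or inf"]
    if len(losses) < 2 * window:
        return [f"need at least {2 * window} logged losses to judge the trend"]
    first = sum(losses[:window]) / window
    last = sum(losses[-window:]) / window
    if last >= first:
        return [f"loss did not decrease ({first:.3f} -> {last:.3f})"]
    return []


# Usage after trainer.train() (not run here). Training entries carry "loss";
# eval entries carry "eval_loss" instead, so this keeps only training losses:
#   losses = [e["loss"] for e in trainer.state.log_history if "loss" in e]
#   print(smoke_test_problems(losses) or "smoke test passed")
```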
Common Mistakes
Leaking test data into training
Fine-tuning datasets are often curated by selecting "good examples" from a larger pool. If you evaluated your model on the test set before curation, and then selected training examples that happen to cover the test distribution, you have data leakage. Your evaluation metrics will look better than reality. Always set aside your evaluation data before starting any data curation.
Training on the full prompt+response without masking
Computing loss on the prompt tokens teaches the model to predict the prompt, which is wasteful and sometimes harmful. If your prompt template is "Answer the following: [question]", the model will start outputting "Answer the following:" when asked questions, because that pattern was heavily reinforced. Always set prompt token labels to -100.
Too many epochs on too little data
With a 1,000-example dataset and 10 epochs, your model can memorize the training examples and perform poorly on anything outside that distribution. Watch validation loss - if it starts climbing after epoch 2, stop training. The LIMA paper trained for 15 epochs and still reported strong results, but its 1,000 examples were carefully curated for diversity. The authors also chose checkpoints manually using a held-out development set instead of simply taking the final one. Without that level of curation, fewer epochs are safer.
Forgetting to set pad_token for GPT-style models
GPT-2 and LLaMA do not have a pad token by default - they are autoregressive models with no padding in their original design. If you try to batch-train without setting tokenizer.pad_token = tokenizer.eos_token, you will get an error or subtly wrong behavior. Always check and set this before training.
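Reusing EOS as the pad token has one subtle failure mode. Some code masks labels wherever the token id equals pad_token_id; transformers' DataCollatorForLanguageModeling does this for causal LM. That rule also masks the real EOS at the end of each response, so the model never learns to stop. Masking by position, as SFTDataCollator above does, avoids the problem. A small pure-Python illustration with hypothetical token ids:

```python
from typing import List

IGNORE_INDEX = -100


def mask_padding_by_id(input_ids: List[int], pad_token_id: int) -> List[int]:
    """Ignore every position whose id equals the pad id (also hits a real EOS)."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]


def mask_padding_by_position(input_ids: List[int], seq_len: int) -> List[int]:
    """Ignore only positions past the real sequence length."""
    return [tok if i < seq_len else IGNORE_INDEX for i, tok in enumerate(input_ids)]


# Hypothetical ids: EOS = 2 is also used as the pad token.
EOS = PAD = 2
padded = [5, 7, 9, EOS, PAD, PAD]  # the real sequence is the first 4 tokens
print(mask_padding_by_id(padded, PAD))      # [5, 7, 9, -100, -100, -100]: EOS lost
print(mask_padding_by_position(padded, 4))  # [5, 7, 9, 2, -100, -100]: EOS kept
```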
Interview Q&A
Q1: What is the difference between pretraining and supervised fine-tuning in terms of objective, data, and outcome?
Pretraining uses the cross-entropy language modeling loss on raw text (no labeled pairs), trained on terabytes of data, learning general language representations. SFT uses the same loss but on (instruction, response) pairs - curated labeled data, typically thousands to hundreds of thousands of examples. The outcome: pretraining teaches the model language, facts, and reasoning; SFT teaches the model how to express that knowledge in a specific format, style, and domain. SFT learns format and style more than new knowledge. Crucially, SFT typically computes loss only on response tokens (not the prompt), while pretraining computes loss on all tokens.
Q2: The LIMA paper showed that 1,000 examples could compete with much larger datasets. How is this possible?
LIMA (Zhou et al., 2023) fine-tuned LLaMA-65B on a carefully curated 1,000-example dataset and found its responses competitive with RLHF-trained models in human preference evaluations. The authors' explanation is that "almost all knowledge in large language models is learned during pretraining." SFT is teaching the model response format and style, not new knowledge. If your 1,000 examples are genuinely diverse (covering many task types) and high quality (demonstrating the exact format and quality you want), the model learns the style quickly. The key is diversity + quality over quantity: a dataset of 100,000 low-quality or redundant examples can produce a worse model than 1,000 high-quality, diverse ones.
Q3: What is catastrophic forgetting and how do you detect and mitigate it?
Catastrophic forgetting occurs when fine-tuning on a narrow dataset overwrites the general representations learned during pretraining. Detection: run MMLU, HellaSwag, or other broad benchmarks on your fine-tuned model and compare to the base model. A drop greater than 5 points signals significant forgetting. Mitigation: (1) use a small learning rate (1e-5 to 3e-5); (2) limit to 1-3 epochs; (3) mix in general instruction data (5-10% of training set); (4) use LoRA - which keeps pretrained weights frozen entirely; (5) data diversity - a wider distribution of tasks reduces the risk of the model becoming narrow.
Q4: Why do you only compute loss on response tokens and not the full prompt+response?
Computing loss on prompt tokens trains the model to predict each prompt token from the tokens before it. This spends training signal on modeling user input and template boilerplate instead of on producing responses. More problematically, if the prompt template appears many times in training data, the model will learn to produce the prompt template pattern even in inappropriate contexts. A model trained with loss on the full "Answer the following question: {question}" template will sometimes respond to questions with "Answer the following question:" as if it is predicting what comes before the answer. Setting prompt labels to -100 (the ignore index) ensures the model only learns to produce good responses, not to predict the input.
Q5: When would you choose full fine-tuning over LoRA?
Full fine-tuning updates all parameters, so all of the model's capacity is available for adapting to your task. Choose full fine-tuning when: (1) you have a large high-quality dataset (100K+ examples) and the compute budget to match; (2) you need the best achievable performance and quality is worth the cost; (3) you are doing domain-specific pretraining on a specialized corpus (medicine, law, code) where you want the entire representation space to shift; (4) you are doing continual pretraining to update knowledge. Choose LoRA when: compute and memory are constrained; you need to fine-tune frequently (e.g., per-user adapters); you want to preserve the base model's general capabilities; you are fine-tuning on a small dataset (LoRA's regularization effect helps prevent overfitting). In practice, LoRA at a reasonably high rank (r=64 or above) often comes close to full fine-tuning quality on instruction-tuning tasks. The gap depends on the task, though, and tends to be larger when the model must absorb a lot of new knowledge, as in continued pretraining.
Advanced: SFT Data Curation Pipeline
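Building a high-quality SFT dataset is where most real-world fine-tuning projects succeed or fail. Here is a simplified curation pipeline that illustrates the main stages: length filtering, heuristic quality checks, deduplication, and source mixing.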
"""
SFT data curation pipeline.
Filters, deduplicates, and quality-scores instruction-response pairs.
"""
import hashlib
import math
import random
import re
import statistics
from typing import Dict, List

from transformers import AutoTokenizer


class SFTDataCurator:
    """
    Pipeline for curating high-quality SFT training data.
    Applies multiple quality filters and deduplication.
    """

    def __init__(self, tokenizer_name: str = "meta-llama/Llama-2-7b-hf"):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def filter_and_curate(
        self,
        examples: List[Dict],
        min_response_tokens: int = 20,
        max_response_tokens: int = 2048,
        min_instruction_tokens: int = 5,
        max_instruction_tokens: int = 512,
    ) -> List[Dict]:
        """
        Apply quality filters and deduplication.
        Filters:
        1. Minimum/maximum length constraints
        2. Language quality checks (repetition, truncation, near-empty responses)
        3. Exact-duplicate removal (after normalizing case and whitespace)
        """
        print(f"Starting with {len(examples)} examples")

        # Step 1: Length filtering
        filtered = [
            e for e in examples
            if self._passes_length_check(
                e, min_response_tokens, max_response_tokens,
                min_instruction_tokens, max_instruction_tokens,
            )
        ]
        print(f"After length filtering: {len(filtered)}")

        # Step 2: Quality checks
        filtered = [e for e in filtered if self._passes_quality_check(e)]
        print(f"After quality filtering: {len(filtered)}")

        # Step 3: Deduplication
        filtered = self._deduplicate(filtered)
        print(f"After deduplication: {len(filtered)}")
        return filtered

    def _passes_length_check(
        self, example: Dict,
        min_resp: int, max_resp: int,
        min_inst: int, max_inst: int,
    ) -> bool:
        # Count content tokens only (no BOS/EOS)
        inst_tokens = len(self.tokenizer(example["instruction"], add_special_tokens=False)["input_ids"])
        resp_tokens = len(self.tokenizer(example["response"], add_special_tokens=False)["input_ids"])
        return min_inst <= inst_tokens <= max_inst and min_resp <= resp_tokens <= max_resp

    def _passes_quality_check(self, example: Dict) -> bool:
        response = example["response"]

        # Reject empty or near-empty responses
        if len(response.strip()) < 20:
            return False

        # Reject heavy repetition (a sign of degenerate generation):
        # fewer than 30% of the words are unique
        words = response.lower().split()
        if len(words) > 10 and len(set(words)) / len(words) < 0.3:
            return False

        # Reject likely truncation: a long response that ends mid-word
        # instead of with punctuation, a quote, or a backtick
        stripped = response.rstrip()
        if len(stripped) > 50 and stripped[-1].isalpha():
            return False

        # Flag refusals for manual inspection rather than dropping them:
        # some refusals are valid, but too many hurt helpfulness.
        # Note: this adds a key to the caller's example dict.
        refusal_patterns = [
            r"^i cannot",
            r"^i'm unable",
            r"^as an ai",
            r"^i don't have access",
        ]
        response_lower = response.lower().strip()
        if any(re.match(p, response_lower) for p in refusal_patterns):
            example["_is_refusal"] = True
        return True

    def _deduplicate(self, examples: List[Dict]) -> List[Dict]:
        """Remove exact-duplicate instructions after normalizing case and whitespace.

        This only catches exact repeats; near-duplicate detection needs a
        similarity measure (e.g. Jaccard over shingles, approximated at scale by MinHash).
        """
        seen_hashes = set()
        unique_examples = []
        for example in examples:
            instruction = re.sub(r"\s+", " ", example["instruction"].lower().strip())
            instruction_hash = hashlib.md5(instruction.encode("utf-8")).hexdigest()
            if instruction_hash not in seen_hashes:
                seen_hashes.add(instruction_hash)
                unique_examples.append(example)
        return unique_examples

    def compute_diversity_score(self, examples: List[Dict]) -> Dict:
        """
        Compute dataset diversity metrics.
        A diverse dataset produces better instruction-following models.
        """
        if len(examples) < 2:
            raise ValueError("need at least two examples to compute diversity statistics")
        instruction_lengths = []
        response_lengths = []
        task_types = {
            "question_answering": 0,
            "code_generation": 0,
            "summarization": 0,
            "classification": 0,
            "creative_writing": 0,
            "math_reasoning": 0,
            "other": 0,
        }
        for ex in examples:
            inst = ex["instruction"].lower()
            # Match whole words so that e.g. "show" does not count as "how"
            words = set(re.findall(r"[a-z]+", inst))
            instruction_lengths.append(len(inst.split()))
            response_lengths.append(len(ex["response"].split()))
            # Simple keyword heuristic; the first matching rule wins
            if words & {"what", "who", "when", "where", "why", "how"}:
                task_types["question_answering"] += 1
            elif words & {"code", "function", "program", "implement"}:
                task_types["code_generation"] += 1
            elif words & {"summarize", "summary", "tldr"}:
                task_types["summarization"] += 1
            elif words & {"classify", "category", "label"}:
                task_types["classification"] += 1
            elif words & {"write", "story", "poem", "creative"}:
                task_types["creative_writing"] += 1
            elif words & {"calculate", "solve", "math", "equation"}:
                task_types["math_reasoning"] += 1
            else:
                task_types["other"] += 1

        return {
            "total_examples": len(examples),
            "avg_instruction_length": statistics.mean(instruction_lengths),
            "avg_response_length": statistics.mean(response_lengths),
            "instruction_length_std": statistics.stdev(instruction_lengths),
            "task_distribution": {k: v / len(examples) for k, v in task_types.items()},
            "task_diversity_entropy": self._entropy(list(task_types.values())),
        }

    def _entropy(self, counts: List[int]) -> float:
        """Shannon entropy (bits) of the task type distribution. Higher = more diverse."""
        total = sum(counts)
        if total == 0:
            return 0.0
        probs = [c / total for c in counts if c > 0]
        return -sum(p * math.log2(p) for p in probs)


# ---- Data mixing for multi-task SFT ----
def create_balanced_sft_mix(
    domain_data: List[Dict],
    general_data: List[Dict],
    safety_data: List[Dict],
    code_data: List[Dict],
    target_total: int = 50000,
    seed: int = 0,
) -> List[Dict]:
    """
    Create a mixed SFT dataset from multiple sources.
    The proportions below are illustrative starting points, not a standard;
    tune them against your own evaluations.
    """
    rng = random.Random(seed)  # seeded for reproducible mixes
    proportions = {
        "general": 0.35,  # General instruction following
        "domain": 0.30,   # Domain-specific task
        "code": 0.20,     # Code/reasoning
        "safety": 0.15,   # Safety, refusals, edge cases
    }
    sources = {
        "general": general_data,
        "domain": domain_data,
        "code": code_data,
        "safety": safety_data,
    }
    mixed = []
    for source_name, proportion in proportions.items():
        target_count = int(target_total * proportion)
        source = sources[source_name]
        if not source:
            raise ValueError(f"{source_name} data is empty")
        if len(source) >= target_count:
            selected = rng.sample(source, target_count)
        else:
            # Upsample by repetition; repeated examples raise the risk of overfitting
            selected = (source * (target_count // len(source) + 1))[:target_count]
            print(f"WARNING: {source_name} data upsampled from {len(source)} to {target_count}")
        # Copy each example so tagging does not mutate the caller's data
        mixed.extend({**ex, "_source": source_name} for ex in selected)
    rng.shuffle(mixed)
    return mixed
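The _deduplicate method above removes exact repeats only. Below is a minimal sketch of genuine near-duplicate filtering, using exact Jaccard similarity over word 3-gram shingles of the instruction. Pairwise comparison is O(N^2), which is fine for a few thousand examples. MinHash with locality-sensitive hashing approximates the same similarity for larger datasets. The names, the shingle size, and the 0.9 threshold are illustrative.

```python
import re
from typing import Dict, List, Set


def word_shingles(text: str, n: int = 3) -> Set[str]:
    """Set of n-word shingles over lowercased alphanumeric words."""
    words = re.findall(r"[a-z0-9]+", text.lower())
    if len(words) < n:
        return {" ".join(words)} if words else set()
    return {" ".join(words[i:i + n]) for i in range(len(words) - n + 1)}


def jaccard(a: Set[str], b: Set[str]) -> float:
    """|A & B| / |A | B|; two empty sets count as identical."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def near_dedup(examples: List[Dict], threshold: float = 0.9, n: int = 3) -> List[Dict]:
    """Keep an example only if it is below threshold similarity to every kept one."""
    kept: List[Dict] = []
    kept_shingles: List[Set[str]] = []
    for ex in examples:
        sh = word_shingles(ex["instruction"], n)
        if all(jaccard(sh, other) < threshold for other in kept_shingles):
            kept.append(ex)
            kept_shingles.append(sh)
    return kept
```

Because the first occurrence wins, shuffle or sort the examples by quality before filtering if you care which copy survives.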
Production Notes: Fine-Tuning at Different Scales
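Fine-tuning decisions change significantly based on model and dataset scale. (The learning rates below are in the full fine-tuning range; LoRA and QLoRA runs typically use higher rates, around 1e-4 to 2e-4.)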
Scale 1: Small model, small data (1B model, 1K examples)
- Use LoRA r=8, learning rate 2e-5, 5-10 epochs
- Risk: overfitting (use validation loss to monitor)
- Recommendation: heavy regularization (higher dropout, lower LR)
Scale 2: Medium model, medium data (7B model, 50K examples)
- Use LoRA r=16 or full fine-tuning with gradient checkpointing
- Learning rate 2e-5, 2-3 epochs
- This is the "sweet spot" - most production fine-tuning projects
Scale 3: Large model, large data (70B model, 500K examples)
- QLoRA or LoRA r=64 (full fine-tuning requires 8+ A100s)
- Learning rate 1e-5, 1-2 epochs
- Focus on data quality at this scale - model is capable enough, data is the bottleneck
Scale 4: Continual pretraining (any size, 1B+ tokens)
- Full fine-tuning required to deeply update representations
- Learning rate 1e-4 (higher than SFT - closer to pretraining scale)
- 1 epoch maximum to avoid catastrophic forgetting of base knowledge
Checkpoint Selection and Early Stopping
SFT models often peak in quality before the final checkpoint. Monitoring validation metrics - not just training loss - is essential.
```python
from transformers import TrainerCallback, TrainerState, TrainerControl
import numpy as np


class SFTEarlyStoppingCallback(TrainerCallback):
    """
    Custom early stopping for SFT training.
    Stops when validation loss stops improving. Unlike HF's built-in
    EarlyStoppingCallback, it does not require load_best_model_at_end or
    metric_for_best_model to be configured; it reads eval_loss directly.
    It records the best step but does not reload that checkpoint.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_eval_loss = float("inf")
        self.patience_counter = 0
        self.best_checkpoint_step = 0

    def on_evaluate(
        self,
        args,
        state: TrainerState,
        control: TrainerControl,
        metrics=None,
        **kwargs,
    ):
        if not metrics or "eval_loss" not in metrics:
            return control  # No eval loss to compare; leave patience unchanged
        eval_loss = metrics["eval_loss"]
        if eval_loss < self.best_eval_loss - self.min_delta:
            self.best_eval_loss = eval_loss
            self.patience_counter = 0
            self.best_checkpoint_step = state.global_step
            print(f"[Step {state.global_step}] New best eval loss: {eval_loss:.4f}")
        else:
            self.patience_counter += 1
            print(
                f"[Step {state.global_step}] No improvement ({self.patience_counter}/{self.patience}). "
                f"Best: {self.best_eval_loss:.4f}, Current: {eval_loss:.4f}"
            )
            if self.patience_counter >= self.patience:
                print(f"Early stopping triggered. Best checkpoint: step {self.best_checkpoint_step}")
                control.should_training_stop = True
        return control


class LossComponentMonitor(TrainerCallback):
    """Track training and eval loss to flag overfitting early."""

    def __init__(self, train_window: int = 50):
        self.train_window = train_window  # Number of recent train-loss logs to average
        self.train_losses = []
        self.eval_losses = []

    def on_log(self, args, state: TrainerState, control, logs=None, **kwargs):
        if logs and "loss" in logs:
            self.train_losses.append((state.global_step, logs["loss"]))

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics and "eval_loss" in metrics:
            self.eval_losses.append((state.global_step, metrics["eval_loss"]))
            self._check_overfitting()

    def _check_overfitting(self):
        if len(self.eval_losses) < 3:
            return
        recent_eval = [loss for _, loss in self.eval_losses[-3:]]
        recent_train = [loss for _, loss in self.train_losses[-self.train_window:]]
        if recent_train:
            train_avg = np.mean(recent_train)
            eval_avg = np.mean(recent_eval)
            gap = eval_avg - train_avg
            if gap > 0.3:  # Heuristic threshold; tune for your task
                print(f"WARNING: Significant train/eval loss gap ({gap:.3f}). Possible overfitting.")
                print(f"  Training loss: {train_avg:.4f}, Eval loss: {eval_avg:.4f}")
                print("  Consider: fewer epochs, a lower learning rate, more dropout, or more (diverse) data")
        if recent_eval[-1] > min(recent_eval) * 1.01:
            print("WARNING: Eval loss is rising - likely overfitting. Consider early stopping.")
```
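The early-stopping callback above reports the best step but does not reload that checkpoint. As a complement, the same patience rule can be replayed offline over the `log_history` list that the Hugging Face Trainer saves in `trainer_state.json`, to decide which saved checkpoint to keep. This is a minimal sketch. It assumes each evaluation entry carries `step` and `eval_loss` keys, and the function name `select_checkpoint` is hypothetical.

```python
from typing import Optional


def select_checkpoint(
    log_history: list[dict], patience: int = 3, min_delta: float = 0.001
) -> tuple[Optional[int], Optional[int]]:
    """
    Replay a patience-based early-stopping rule over logged evaluations.
    Returns (best_step, stop_step); stop_step is None if patience never ran out.
    """
    best_loss = float("inf")
    best_step: Optional[int] = None
    bad_evals = 0
    for entry in log_history:
        if "eval_loss" not in entry:
            continue  # Skip training-loss log entries
        loss, step = entry["eval_loss"], entry["step"]
        if loss < best_loss - min_delta:
            best_loss, best_step, bad_evals = loss, step, 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return best_step, step
    return best_step, None


# Hypothetical log: eval loss bottoms out at step 300, then drifts up.
history = [
    {"step": 100, "loss": 1.20},
    {"step": 100, "eval_loss": 1.10},
    {"step": 200, "eval_loss": 0.95},
    {"step": 300, "eval_loss": 0.90},
    {"step": 400, "eval_loss": 0.91},
    {"step": 500, "eval_loss": 0.93},
    {"step": 600, "eval_loss": 0.96},
]
print(select_checkpoint(history))  # (300, 600)
```

The best step can only be recovered if a checkpoint was actually saved there. Keeping `save_steps` aligned with `eval_steps` ensures that every evaluated step has a checkpoint on disk.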
Multi-Turn Conversation SFT
Single-turn instruction fine-tuning is straightforward, but most production applications require multi-turn capability. Multi-turn SFT requires careful masking: the loss should cover only the assistant turns, so the model learns to produce responses rather than to predict the entire conversation, including the user's messages.
```python
from typing import Any


def format_multi_turn_conversation(
    conversation: list[dict],  # [{"role": "user"|"assistant", "content": "..."}]
    tokenizer,
    system_prompt: str = "",
    max_length: int = 4096,
) -> dict[str, Any]:
    """
    Format a multi-turn conversation for SFT.
    Masks out user turns and the system prompt - only assistant turns contribute to loss.
    The <|system|>/<|user|>/<|assistant|> markers are an illustrative template; use the
    chat format your base model expects. Matching header tokens by ID assumes the markers
    tokenize the same way in isolation as in context (true when they are special tokens).
    """
    # Build full text
    parts = []
    if system_prompt:
        parts.append(f"<|system|>\n{system_prompt}\n<|endoftext|>")
    for turn in conversation:
        if turn["role"] == "user":
            parts.append(f"<|user|>\n{turn['content']}\n<|endoftext|>")
        else:
            parts.append(f"<|assistant|>\n{turn['content']}\n<|endoftext|>")
    full_text = "".join(parts)

    encoding = tokenizer(full_text, max_length=max_length, truncation=True)
    input_ids = list(encoding["input_ids"])
    labels = [-100] * len(input_ids)  # Start with all tokens masked

    # Token-ID patterns for the assistant header and the end-of-turn marker
    assistant_ids = tokenizer.encode("<|assistant|>", add_special_tokens=False)
    eot_ids = tokenizer.encode("<|endoftext|>", add_special_tokens=False)

    def matches_at(i: int, pattern: list[int]) -> bool:
        return input_ids[i:i + len(pattern)] == pattern

    in_assistant_turn = False
    i = 0
    while i < len(input_ids):
        if not in_assistant_turn and matches_at(i, assistant_ids):
            # Skip the "<|assistant|>" header tokens; the response starts after them
            in_assistant_turn = True
            i += len(assistant_ids)
        elif in_assistant_turn and matches_at(i, eot_ids):
            # Include the end-of-turn marker in the loss so the model learns when to stop
            labels[i:i + len(eot_ids)] = eot_ids
            in_assistant_turn = False
            i += len(eot_ids)
        else:
            if in_assistant_turn:
                labels[i] = input_ids[i]  # Assistant response token: contributes to loss
            i += 1

    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }


# Example usage
conversation = [
    {"role": "user", "content": "What is gradient descent?"},
    {"role": "assistant", "content": "Gradient descent is an optimization algorithm..."},
    {"role": "user", "content": "How does the learning rate affect it?"},
    {"role": "assistant", "content": "The learning rate controls step size..."},
]
# Only the two assistant responses (and their end-of-turn markers) contribute to the loss.
# The user messages and system prompt are masked (-100 labels).
```
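The formatter above scans token IDs for header patterns. That approach depends on the markers tokenizing the same way in isolation as in context. A common alternative is to tokenize each turn separately and build the labels segment by segment, which makes the mask boundaries exact by construction. The sketch below takes a plain `encode` callable (any function from string to token IDs, for example a wrapper around a Hugging Face tokenizer with `add_special_tokens=False`) and reuses the same illustrative template. The names `build_masked_example` and `char_encode` are hypothetical. One caveat: tokenizing segments separately can produce slightly different IDs at segment boundaries than tokenizing the joined string.

```python
from typing import Callable


def build_masked_example(
    conversation: list[dict],
    encode: Callable[[str], list[int]],
    system_prompt: str = "",
    max_length: int = 4096,
) -> dict[str, list[int]]:
    """Tokenize turn by turn; only assistant content and its end marker get real labels."""
    input_ids: list[int] = []
    labels: list[int] = []

    def add(text: str, trainable: bool) -> None:
        ids = encode(text)
        input_ids.extend(ids)
        labels.extend(ids if trainable else [-100] * len(ids))

    if system_prompt:
        add(f"<|system|>\n{system_prompt}\n<|endoftext|>", trainable=False)
    for turn in conversation:
        if turn["role"] == "assistant":
            add("<|assistant|>\n", trainable=False)  # Header is context, not a target
            add(f"{turn['content']}\n<|endoftext|>", trainable=True)
        else:
            add(f"<|user|>\n{turn['content']}\n<|endoftext|>", trainable=False)

    # Truncate inputs and labels identically so they stay aligned
    input_ids, labels = input_ids[:max_length], labels[:max_length]
    return {"input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids)}


# Toy character-level "tokenizer" so the example runs without downloading a model
def char_encode(text: str) -> list[int]:
    return [ord(c) for c in text]


conv = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
ex = build_masked_example(conv, char_encode)
trained = "".join(chr(t) for t in ex["labels"] if t != -100)
print(repr(trained))  # 'Hello\n<|endoftext|>'
```

Because each segment's trainability is decided when it is appended, no pattern search over token IDs is needed.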
Interview Q&A
Q1: Why do you compute loss only on response tokens during SFT, not on all tokens?
During SFT, the instruction and conversation history are given inputs - they are the context that the model uses to generate its response. Computing loss on these tokens would penalize the model for not perfectly predicting text that was externally provided, which is not what we want to train. More importantly, the instruction text was not produced by the model - it is the human-provided prompt. Training on it would introduce noise and potentially confuse the model about when it should be speaking vs. when it should be listening. Loss only on response tokens is analogous to a teacher/student setup: the student (model) is evaluated only on what it says, not on its ability to predict what the teacher says.
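Mechanically, "loss only on response tokens" means that positions labeled -100 are skipped when the cross-entropy is averaged. The NumPy sketch below mirrors what PyTorch's `CrossEntropyLoss(ignore_index=-100)` computes for a causal LM. It shifts the sequence so that position t predicts token t+1, then averages the negative log-likelihood over unmasked positions only. The logits are random toy values, and `masked_causal_lm_loss` is a hypothetical name.

```python
import numpy as np


def masked_causal_lm_loss(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """
    logits: (seq_len, vocab_size) scores; labels: (seq_len,) token IDs or ignore_index.
    Position t predicts token t+1, so shift before comparing.
    """
    shift_logits = logits[:-1]
    shift_labels = labels[1:]
    # Numerically stable log-softmax over the vocabulary
    shifted = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    mask = shift_labels != ignore_index
    if not mask.any():
        return 0.0  # No trainable tokens; avoid averaging over an empty set
    picked = log_probs[np.arange(len(shift_labels))[mask], shift_labels[mask]]
    return float(-picked.mean())


rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 10))                # 6 positions, vocabulary of 10
labels = np.array([-100, -100, -100, 4, 7, 2])   # first 3 tokens are the prompt
print(masked_causal_lm_loss(logits, labels))
```

Changing the model's predictions at prompt positions has no effect on this loss. Only the predictions of response tokens are penalized.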
Q2: What is catastrophic forgetting in SFT and how do you mitigate it?
Catastrophic forgetting occurs when fine-tuning on a narrow distribution of SFT data overwrites the model's broadly pretrained representations. After fine-tuning, the model may become better at the target task but worse at everything else: the general knowledge, reasoning, and coding ability it acquired during pretraining. Mitigations: (1) Use LoRA instead of full fine-tuning - the base weights stay frozen and only a low-rank update is learned, which reduces (but does not eliminate) forgetting; (2) Mix 5–10% "general" data into your SFT dataset to prevent distribution collapse; (3) Use a small learning rate (for full fine-tuning, around 1e-5 rather than 1e-4); (4) Fine-tune for fewer epochs (2-3 max, not 10+); (5) Measure a general benchmark such as MMLU or HellaSwag before and after fine-tuning - if accuracy drops more than 2 points, investigate.
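Mitigation (2) depends on what the percentage refers to. In the sketch below, the fraction is interpreted as a share of the final mixed dataset, not as an amount added on top of the domain data. It assumes both datasets are plain Python lists of examples. The function name `mix_in_general_data` is hypothetical.

```python
import random


def mix_in_general_data(
    domain: list, general: list, general_fraction: float = 0.1, seed: int = 0
) -> list:
    """Return a shuffled mix in which roughly `general_fraction` of examples come from `general`."""
    if not 0 <= general_fraction < 1:
        raise ValueError("general_fraction must be in [0, 1)")
    # Solve n_general / (n_domain + n_general) = general_fraction for n_general
    n_general = round(general_fraction * len(domain) / (1 - general_fraction))
    rng = random.Random(seed)
    sampled = rng.sample(general, min(n_general, len(general)))
    mixed = list(domain) + sampled
    rng.shuffle(mixed)
    return mixed


domain = [{"src": "domain", "i": i} for i in range(2000)]
general = [{"src": "general", "i": i} for i in range(10_000)]
mixed = mix_in_general_data(domain, general, general_fraction=0.1)
print(len(mixed), sum(ex["src"] == "general" for ex in mixed))  # 2222 222
```

Shuffling after mixing spreads the general examples throughout training instead of concentrating them at the end of an epoch.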
Q3: What is the LIMA finding and what does it imply for SFT data collection?
LIMA (Zhou et al., 2023) fine-tuned LLaMA-65B on just 1,000 carefully curated examples. In human preference evaluations, its responses were competitive with those of models tuned on far more data, including Alpaca 65B (about 52K examples). The core claim, which the authors call the Superficial Alignment Hypothesis: a model's knowledge and capabilities are learned almost entirely during pretraining, and alignment mainly teaches it which format and style to use when responding to users. The implication: data quality matters far more than quantity for SFT. A small set of excellent examples (clear instruction, appropriate response, correct format) can beat a much larger set of mediocre ones. In practice, invest in data curation: write your own examples rather than scraping, use multiple human reviewers, filter aggressively for quality, and prioritize coverage of failure modes over absolute volume.
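Two cheap checks are a reasonable place to start this kind of curation: dropping duplicate instructions after light normalization, and counting how many examples cover each failure mode you care about. This is a minimal sketch. The `instruction` and `tags` field names and the tag values are hypothetical, and exact-match deduplication will miss paraphrases.

```python
import re
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace/punctuation so trivial variants compare equal."""
    return re.sub(r"[\W_]+", " ", text.lower()).strip()


def dedupe_and_count(examples: list[dict]) -> tuple[list[dict], Counter]:
    seen: set[str] = set()
    kept: list[dict] = []
    for ex in examples:
        key = normalize(ex["instruction"])
        if key in seen:
            continue  # Drop exact duplicates after normalization
        seen.add(key)
        kept.append(ex)
    # Count surviving examples per failure-mode tag to spot under-covered areas
    coverage = Counter(tag for ex in kept for tag in ex.get("tags", []))
    return kept, coverage


data = [
    {"instruction": "Should I buy index funds?", "tags": ["investment_advice"]},
    {"instruction": "should I buy  index funds", "tags": ["investment_advice"]},
    {"instruction": "Why was my card declined?", "tags": ["apology_tone"]},
]
kept, coverage = dedupe_and_count(data)
print(len(kept), dict(coverage))  # 2 {'investment_advice': 1, 'apology_tone': 1}
```

The coverage counter makes "prioritize failure modes over volume" actionable: tags with few examples show where new examples are most needed.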
Q4: How should you set the learning rate for SFT on a 7B model?
For SFT of a pretrained 7B model using LoRA: peak learning rate 2e-4 with cosine decay and a warmup ratio of 0.03 (the first 3% of steps). For full fine-tuning of a 7B model: peak learning rate 2e-5 with cosine decay and a warmup ratio of 0.03. The full fine-tuning LR is 10x lower because full fine-tuning updates every pretrained weight directly, so large steps quickly damage existing representations. LoRA trains only small low-rank adapters (initialized so the update starts at zero, and scaled by alpha/r), which tolerate larger learning rates. Both benefit from warmup: without it, the first updates are taken at the full learning rate before the optimizer's moment estimates have stabilized, which can destabilize the model. If loss diverges in the first 100 steps, the learning rate is too high. If the model doesn't improve after 500 steps, the learning rate may be too low or the data format is wrong.
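The schedule described above (linear warmup over the first 3% of steps, then cosine decay) can be written out directly. This sketch uses the same formula as Hugging Face's cosine-with-warmup schedule decaying to zero, but `lr_at_step` is a hypothetical helper, not a library function. The 250-step run (2,000 examples × 2 epochs at an assumed batch size of 16) is illustrative.

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float, warmup_ratio: float = 0.03) -> float:
    """Linear warmup to peak_lr, then cosine decay to 0 at total_steps."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# 2,000 examples x 2 epochs at an assumed batch size of 16 -> 250 optimizer steps
total = 2000 * 2 // 16
for step in (0, 4, 8, 129, 250):
    print(step, f"{lr_at_step(step, total, peak_lr=2e-5):.2e}")
```

With only 250 steps, warmup lasts 8 steps. Short SFT runs spend very little time at the peak learning rate, which is part of why the peak value can be set relatively high without diverging.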
Q5: What makes SFT data "high quality" from a training perspective?
Five properties of high-quality SFT data: (1) Correct: the response is factually accurate and does what the instruction asks; (2) Complete: the response fully addresses the instruction - not too terse, not padded; (3) Format-appropriate: if the instruction asks for a list, the response is a list; if it asks for code, the response includes code; (4) Consistent: the same instruction format throughout - inconsistent formatting confuses the model's learned "interface"; (5) Diverse: covers a wide range of tasks, topics, lengths, and styles - training on narrow SFT data produces narrow capabilities. The LIMA paper adds a sixth property: exemplary - not just "good" but the kind of response you would be proud to show as your model's best output.
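Only some of these properties can be checked mechanically before human review. The sketch below flags missing fields (consistency), empty, very short, or very long responses (completeness), and list requests without list-formatted answers (format). The field names, word thresholds, and keywords are hypothetical heuristics. Correctness and exemplary quality still need human reviewers.

```python
import re

REQUIRED_FIELDS = ("instruction", "response")  # Hypothetical schema
LIST_LINE = re.compile(r"^\s*(?:[-*]|\d+[.)])\s+", re.MULTILINE)


def lint_example(ex: dict, min_words: int = 5, max_words: int = 800) -> list[str]:
    """Return heuristic problems found in one SFT example (empty list = no flags)."""
    missing = [f for f in REQUIRED_FIELDS if not str(ex.get(f, "")).strip()]
    if missing:
        return [f"missing field(s): {', '.join(missing)}"]
    problems = []
    n_words = len(ex["response"].split())
    if n_words < min_words:
        problems.append(f"response too short ({n_words} words)")
    if n_words > max_words:
        problems.append(f"response possibly padded ({n_words} words)")
    asks_for_list = re.search(r"\b(list|enumerate|bullet)", ex["instruction"], re.IGNORECASE)
    if asks_for_list and not LIST_LINE.search(ex["response"]):
        problems.append("instruction asks for a list but response has no list items")
    return problems


print(lint_example({"instruction": "List three index fund risks.",
                    "response": "Market risk, tracking error, and fees."}))
# ['instruction asks for a list but response has no list items']
```

Flags from a linter like this are best treated as a review queue, not as an automatic filter, since the heuristics will produce some false positives.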
