Continual Learning and Domain Adaptation

The Production Scenario: Building the Medical AI That Cannot Forget Medicine

A healthcare startup contracted your team to build an AI assistant for hospital clinical staff. The use case is narrow and well-defined: ICU nurses need quick answers to medication interaction questions, dosage calculations, and protocol lookup. Patients are at risk when nurses spend three minutes finding an answer they should get in ten seconds.

You start with Llama 3.1 8B, the best open-source base model at the time of the project. You collect 50,000 clinical question-answer pairs from your hospital partner and run standard instruction tuning. The model learns to follow question-answer format quickly. Two weeks later, you demo it to the clinical team. They are initially impressed. Then a senior nurse asks about a drug interaction involving tacrolimus and fluconazole - a well-known combination with serious immunosuppression implications. The model gives a generic response about "consulting a pharmacist." It has no idea what tacrolimus is for. It does not know that this combination requires immediate dose adjustment. It is producing confident-sounding hedges about information it fundamentally does not have.

The problem is that instruction tuning did not teach the model medical knowledge - it taught the model medical conversation format. The base model never encountered enough clinical pharmacology text during its pre-training to build a useful internal representation of drug mechanisms, interactions, and dosing. You fine-tuned the presentation layer without ever addressing the knowledge layer. The model learned to answer in the style of a clinical expert without acquiring clinical expertise.

What you actually needed was two phases. First, continual pre-training (CPT): take the base model and continue pre-training it on a large corpus of medical text - PubMed abstracts, clinical guidelines, pharmacology textbooks, drug databases. This would build internal representations of medical concepts, teach the model the vocabulary of clinical medicine, and establish the knowledge scaffolding that instruction tuning can then activate. Second, instruction tuning on clinical Q&A to teach the model to use that knowledge in the format your users need. Skipping the first phase and going straight to the second is the most common mistake teams make when deploying AI in specialized domains.

But CPT introduces its own dangerous failure mode: catastrophic forgetting. The model trained for months on web-scale general text. Its ability to reason, follow instructions, write coherently, and handle diverse linguistic patterns is baked into billions of parameters. CPT on a narrow medical corpus risks overwriting those general capabilities. You might end up with a model that knows more about tacrolimus but cannot parse a compound sentence written by someone who is not a physician. In production, your users do not ask only about drugs - they describe situations, embed questions in context, and expect the model to understand them. A model that forgot how to read general English to become better at clinical English is not an improvement.

This lesson covers how to do CPT right: how much general text to mix in, what learning rates prevent forgetting, how to measure whether you are winning on domain knowledge without losing on general capability, and how examples like BioMedLM and LegalBERT solved these trade-offs at production scale.


Why This Exists: The Knowledge Gap That Fine-Tuning Cannot Fill

Standard fine-tuning assumes the knowledge is already there. You take a capable base model and teach it a new format, a new tone, or a new task structure. The model has already absorbed vast knowledge about the world from pre-training - fine-tuning activates and directs that knowledge. This assumption is mostly valid for general tasks. When you fine-tune a model on coding instructions, it is not learning what Python is - it already knows Python from pre-training on GitHub. It is learning to apply that knowledge in a Q&A format.

The assumption breaks in specialized technical domains that are underrepresented in general pre-training corpora. The internet contains vastly more general English than clinical pharmacology. The original Llama was trained on roughly a trillion tokens of web text, books, and code, and the Llama 3 generation on more than 15 trillion. Clinical pharmacology literature occupies perhaps a few billion tokens of such a corpus - a fraction of a percent of the training budget. The model has surface-level familiarity with medical terms but shallow understanding of the mechanisms, interactions, and evidence behind them.

The evidence for this knowledge gap is empirical. Ran et al. (2023) showed that models trained on general corpora score 15-25 points below domain-specialist models on domain-specific QA benchmarks, even when the general model is 3x larger. Instruction tuning does not close this gap - it narrows it slightly by making the general model better at presenting its limited domain knowledge, but the knowledge limit remains. The only way to genuinely raise domain expertise is to give the model more domain text at the pre-training stage or to extend pre-training on domain text after the fact. The latter is continual pre-training.


Historical Context: From EWC to Modern CPT Pipelines

The neural network catastrophic forgetting problem was first rigorously described by McCloskey and Cohen in 1989 in the context of simple feedforward networks trained with backpropagation. The observation was that when you train a neural network on task B after training it on task A, it "forgets" task A with surprising speed - the gradient updates that improve performance on B actively degrade performance on A because they shift shared weights away from the values task A relied on.

For shallow networks, forgetting is total and nearly immediate. For deep transformers, it is more gradual but still dangerous. Kirkpatrick et al. (2017) proposed Elastic Weight Consolidation (EWC) as a principled fix: compute the Fisher information matrix of the loss with respect to all parameters after task A training, then add a regularization term during task B training that penalizes changes to parameters that were important for task A. The Fisher diagonal approximates parameter importance, and the quadratic penalty keeps important parameters near their task A values.

EWC was designed for small networks on toy tasks (Atari games, MNIST). Scaling it to billion-parameter transformers is computationally impractical in its original form - computing the full Fisher matrix for a 7B parameter model requires 7B x 7B storage. Practitioners use diagonal approximations and partial application (only applying EWC to specific layers), which works well enough in practice.
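To make the storage argument concrete, here is a small back-of-the-envelope calculation. It assumes 4 bytes per stored value (fp32) and a round 7B parameter count; both are illustrative assumptions rather than figures from a specific model card.

```python
def fisher_storage_bytes(num_params: int, bytes_per_value: int = 4):
    """Return (full_matrix_bytes, diagonal_bytes) for a Fisher estimate."""
    full = num_params ** 2 * bytes_per_value   # one value per parameter pair
    diagonal = num_params * bytes_per_value    # one value per parameter
    return full, diagonal


full, diag = fisher_storage_bytes(7_000_000_000)
print(f"Full Fisher matrix: {full / 1e18:.0f} exabytes")
print(f"Diagonal only:      {diag / 1e9:.0f} gigabytes")
```

The diagonal costs about as much memory as one extra fp32 copy of the model, which is why it is the only form of EWC that is practical at this scale.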

The modern domain-specialization approach for LLMs was popularized by BioMedLM (Bolton et al., 2022) and the Med-PaLM series. BioMedLM trained a GPT-2-style model on PubMed abstracts and full-text biomedical articles (from scratch rather than from a general checkpoint) and reported strong results on biomedical QA benchmarks such as BioASQ and MedQA. The key insight was that you did not need a specialized architecture - just domain text at sufficient scale. Earlier work on LegalBERT (Chalkidis et al., 2020) and FinBERT (Araci, 2019) had demonstrated the same pattern in law and finance with BERT-scale models.

The "aha moment" for the broader community came from the Don't Stop Pretraining paper (Gururangan et al., 2020). They systematically studied domain-adaptive pre-training across four domains (biomedical, computer science, news, reviews) and showed that continued pre-training on domain text consistently improved downstream task performance, even when followed by task-specific fine-tuning. Crucially, they found that the order mattered: base model -> domain CPT -> task fine-tuning outperformed base model -> task fine-tuning by 2-5 points on every benchmark they tested. This established domain CPT as a first-class step in the specialization pipeline.


Core Concepts: CPT, Forgetting, and the Data Mixing Trade-Off

Continual Pre-Training vs Fine-Tuning: What Each Phase Teaches

The distinction between CPT and instruction fine-tuning is sometimes blurred, but it is conceptually important.

Continual pre-training trains the model on raw domain text using the same next-token prediction objective as the original pre-training. The input is unstructured domain text: PubMed abstracts, legal opinions, financial reports, medical textbooks. There is no instruction-following format, no Q&A structure, no system prompt. The model is learning the statistical patterns of domain language: which words co-occur with which concepts, what syntactic patterns are common in clinical writing, which entities are related, what the flow of reasoning looks like in a domain expert's writing. This is knowledge acquisition.

Instruction fine-tuning trains the model on (instruction, response) pairs using a supervised next-token loss over the response tokens. The model learns to map user instructions to appropriate responses in the format you define. The knowledge must already exist in the model's weights for instruction fine-tuning to elicit it correctly. This is knowledge activation.

CPT teaches the model what to know. Instruction fine-tuning teaches the model how to express what it knows.
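The difference between the two phases shows up directly in how training labels are built. The sketch below uses made-up token ids (real ids come from a tokenizer) and the common convention that a label of -100 is ignored by the cross-entropy loss; the function names are hypothetical.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch cross-entropy by default


def cpt_labels(token_ids):
    """CPT: every token of the raw document is a prediction target."""
    return list(token_ids)


def sft_labels(prompt_ids, response_ids):
    """Instruction tuning: only the response tokens contribute to the loss."""
    return [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)


# Toy token ids, purely illustrative
document = [11, 12, 13, 14, 15]
prompt, response = [21, 22, 23], [31, 32]

print("CPT labels:", cpt_labels(document))
print("SFT labels:", sft_labels(prompt, response))
```

In CPT the model is graded on reproducing the domain text itself; in instruction tuning it is graded only on the answer, so whatever the answer requires must already be in the weights.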

Catastrophic Forgetting in Large Language Models

Catastrophic forgetting in transformers is subtler than in shallow networks. A large language model has many components that generalize across domains: attention patterns for long-range dependency, position encodings, general reasoning circuits, instruction-following behavior, code comprehension. These are spread across many layers and many parameters. When you CPT on a narrow domain, gradient updates optimize for predicting domain text. Parameters that were critical for general capabilities but are not strongly used when predicting biomedical text will drift.

The severity of forgetting depends on:

Learning rate: Higher learning rates cause larger weight updates per step, faster forgetting. CPT uses learning rates 10-30x lower than original pre-training. If the base model was trained at $3 \times 10^{-4}$, CPT typically uses $1 \times 10^{-5}$ to $3 \times 10^{-5}$.

Domain distance: A model CPT'd on clinical text will forget less general capability than one CPT'd on a highly specialized formal language (e.g., legal Latin). The further the domain vocabulary and syntax are from general English, the more aggressively CPT reshapes the representations.

Data volume: More domain text means more forgetting. 1B domain tokens causes measurably more forgetting than 100M domain tokens. The relationship is roughly logarithmic - early domain tokens cause the most learning and the most forgetting, later tokens have diminishing returns on both.

Model size: Larger models forget less per update because their representations are higher-dimensional and more distributed. A 70B parameter model can absorb more domain specialization before general capability degrades noticeably compared to a 7B model.

Elastic Weight Consolidation for Transformers

EWC adds a regularization term to the CPT loss that penalizes changes to weights that were important for general capabilities. The augmented loss is:

$$\mathcal{L}_{EWC}(\theta) = \mathcal{L}_{CPT}(\theta) + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2$$

where:

  • $\theta_i^*$ are the model weights after general pre-training (or SFT)
  • $F_i$ is the Fisher information for parameter $i$, approximating how important that parameter is
  • $\lambda$ controls the forgetting penalty strength
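A toy calculation can make the penalty term less abstract. The three parameters, their Fisher values, and the value of lambda below are invented for illustration; the point is only that the same drift costs very different amounts depending on importance.

```python
def ewc_penalty_terms(theta, theta_star, fisher, lam):
    """Per-parameter contributions (lambda / 2) * F_i * (theta_i - theta_i*)^2."""
    return [
        0.5 * lam * f * (t - t_ref) ** 2
        for t, t_ref, f in zip(theta, theta_star, fisher)
    ]


theta_star = [1.0, -2.0, 0.5]   # weights before CPT (reference point)
fisher = [10.0, 0.1, 0.0]       # important, marginal, irrelevant
theta = [1.1, -1.9, 0.6]        # each weight has drifted by the same 0.1

terms = ewc_penalty_terms(theta, theta_star, fisher, lam=100.0)
for i, term in enumerate(terms):
    print(f"param {i}: penalty contribution {term:.3f}")
print(f"total penalty: {sum(terms):.3f}")
```

Nearly all of the penalty comes from the first parameter, so gradient descent is pushed to make its domain adjustments through the weights that general capability does not depend on.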

The Fisher information diagonal is computed as:

$$F_i = \mathbb{E}_{\mathcal{D}} \left[ \left( \frac{\partial \log p_{\theta^*}(y|x)}{\partial \theta_i} \right)^2 \right]$$

This is expensive to compute exactly for all parameters. Practical approximations:

  • Compute on a small sample (1,000-5,000 examples from general text)
  • Apply only to specific parameter groups (attention weights, not MLP layers)
  • Use a running average updated incrementally rather than a single batch computation

For most practitioners, EWC in full form is overly complex. The simpler and often better alternative is data mixing.

Data Mixing: The Practical Forgetting Defense

The most widely used and empirically validated approach to managing forgetting in CPT is mixing domain text with a small fraction of general text during CPT. If 20% of every training batch comes from general web text and 80% comes from domain text, the model continuously receives gradient signal that preserves general capabilities while still learning the domain.

The trade-off is that you learn the domain more slowly per step - 20% of your compute is spent on non-domain text. But in practice, this is a good trade: the model needs roughly 25% more steps to see the same amount of domain text (1 / 0.8 = 1.25), and it retains nearly all of its general capability.

In practice, a 95%/5% (domain/general) split is a reasonable starting point for most domains: you get most of the domain specialization while preventing severe forgetting. For domains that are very far from general English (e.g., specialized code, mathematical notation), shift the ratio toward 85%/15% to allocate more replay.
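One way to enforce a replay ratio is to fix it per batch rather than rely on a global shuffle. The following is a minimal sketch using only the standard library; `mixed_batches` is a hypothetical helper, and sampling the general documents with replacement is an assumption made to keep the example short.

```python
import random


def mixed_batches(domain_docs, general_docs, batch_size, general_fraction, seed=0):
    """Yield batches with a fixed number of general (replay) documents each."""
    n_general = round(batch_size * general_fraction)
    n_domain = batch_size - n_general
    if n_domain <= 0:
        raise ValueError("general_fraction leaves no room for domain documents")

    rng = random.Random(seed)
    domain = list(domain_docs)
    rng.shuffle(domain)

    # Walk through the domain data once; draw replay documents at random
    for start in range(0, len(domain) - n_domain + 1, n_domain):
        batch = domain[start:start + n_domain]
        batch += rng.choices(general_docs, k=n_general)
        rng.shuffle(batch)
        yield batch


domain = [f"domain-{i}" for i in range(38)]
general = [f"general-{i}" for i in range(5)]
for batch in mixed_batches(domain, general, batch_size=20, general_fraction=0.05):
    replay = sum(doc.startswith("general") for doc in batch)
    print(f"batch of {len(batch)} with {replay} replay document(s)")
```

Fixing the count per batch means every optimizer step carries some general-text gradient, which a corpus-level shuffle only guarantees on average.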

Learning Rate Schedule for CPT

CPT should use a lower learning rate than the original pre-training and typically a cosine decay schedule without a warmup phase (or with very short warmup). The reasoning: the model is already well-initialized. You are making targeted adjustments, not training from a random initialization. Large learning rate warmups are needed when weights are random because you need to explore the loss landscape before committing. In CPT, the loss landscape is already well-explored - aggressive exploration causes forgetting.

Recommended schedule for CPT:

  • Peak LR: $1 \times 10^{-5}$ to $5 \times 10^{-5}$ (vs $3 \times 10^{-4}$ for pre-training from scratch)
  • Warmup: 1% of total steps (or none)
  • Decay: cosine to 10% of peak LR over the full CPT budget
  • Gradient clipping: 1.0 (same as pre-training)
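The schedule above can be written as a small function of the step index. This is a sketch under the stated settings (linear warmup over 1% of steps, cosine decay to 10% of peak); the function name and defaults are illustrative, not taken from any library.

```python
import math


def cpt_learning_rate(step, total_steps, peak_lr=2e-5, warmup_frac=0.01, final_frac=0.10):
    """Linear warmup to peak_lr, then cosine decay to final_frac * peak_lr."""
    warmup_steps = int(total_steps * warmup_frac)
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps

    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    min_lr = peak_lr * final_frac
    return min_lr + (peak_lr - min_lr) * cosine


total = 10_000
for step in (0, 99, 100, 5_000, 10_000):
    print(f"step {step:>6}: lr = {cpt_learning_rate(step, total):.2e}")
```

Note that the stock `cosine` scheduler in Hugging Face Transformers decays toward zero rather than to a floor, so matching the 10% floor exactly requires a scheduler that supports a minimum learning rate.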

Domain-Specific Tokenizer Extension

Base model tokenizers are built on general web text. They tokenize domain-specific terms inefficiently because those terms were rare in the tokenizer training data. A medical tokenizer trained on PubMed will represent "tacrolimus" as a single token. Llama's general tokenizer might represent it as ["t", "ac", "rol", "imus"] - four tokens. This inefficiency wastes context window and makes it harder for the model to build meaningful representations of the concept.

Extending the tokenizer involves:

  1. Training a domain-specific BPE or sentencepiece tokenizer on domain text
  2. Finding tokens in the domain tokenizer that are not in the base tokenizer (new domain tokens)
  3. Adding those tokens to the base vocabulary
  4. Initializing their embeddings as the average of the constituent token embeddings from the base tokenizer (e.g., the embedding for "tacrolimus" is initialized as the mean of the embeddings for "t", "ac", "rol", "imus")
  5. Running CPT so the model can learn domain-specific representations for the new tokens

Tokenizer extension is most valuable for domains with dense specialized terminology (medical, legal, chemistry). For domains that mostly use general English vocabulary in specialized ways (e.g., business email writing), tokenizer extension provides negligible benefit.
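Before extending a tokenizer, it helps to measure how badly the current one fragments the terms you care about. The sketch below counts tokens per term for any `tokenize` callable; the toy tokenizer and its four-entry vocabulary are stand-ins invented for the example. With a real Hugging Face tokenizer you would pass its `tokenize` method instead.

```python
def tokens_per_term(terms, tokenize):
    """Return {term: token count}, most fragmented terms first."""
    counts = {term: len(tokenize(term)) for term in terms}
    return dict(sorted(counts.items(), key=lambda item: item[1], reverse=True))


def toy_tokenize(word):
    """Greedy longest-match over a tiny made-up vocabulary (illustration only)."""
    vocab = ("tac", "rol", "imus", "dose")
    pieces, i = [], 0
    while i < len(word):
        # Fall back to a single character when no vocabulary entry matches
        match = max((v for v in vocab if word.startswith(v, i)), key=len, default=word[i])
        pieces.append(match)
        i += len(match)
    return pieces


report = tokens_per_term(["dose", "tacrolimus"], toy_tokenize)
for term, n in report.items():
    print(f"{term}: {n} token(s)")
```

If the high-frequency terms in your corpus already map to one or two tokens, extension is unlikely to pay for the extra embedding rows and the CPT needed to train them.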


Code Examples

Preparing a Domain Corpus for CPT

from datasets import concatenate_datasets, load_dataset
from transformers import AutoTokenizer


def prepare_domain_corpus(
    domain_data_path: str,
    general_data_fraction: float = 0.05,
    output_path: str = "./cpt_corpus",
    tokenizer_name: str = "meta-llama/Llama-3.1-8B",
    max_sequence_length: int = 2048,
):
    """
    Prepare a mixed domain/general corpus for continual pre-training.

    Args:
        domain_data_path: Path to domain text (JSONL, one document per line, with a 'text' field)
        general_data_fraction: Fraction of general documents to mix in (0.05 = 5%).
            The ratio is by document count, not by token count.
        output_path: Where to save the processed corpus
        tokenizer_name: Tokenizer to use for tokenization
        max_sequence_length: Chunk text into sequences of this length
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Load domain data; keep only 'text' so the schemas match when concatenating
    print(f"Loading domain data from {domain_data_path}...")
    domain_dataset = load_dataset("json", data_files=domain_data_path, split="train")
    domain_dataset = domain_dataset.select_columns(["text"])
    print(f"  Domain examples: {len(domain_dataset):,}")

    # Load general data (small fraction of C4 or similar).
    # Note: slicing a split may still download the full split first;
    # consider streaming=True for a quick experiment.
    num_general = int(len(domain_dataset) * general_data_fraction / (1 - general_data_fraction))
    print(f"Loading {num_general:,} general text examples (C4)...")
    general_dataset = load_dataset(
        "allenai/c4", "en",
        split=f"train[:{num_general}]",
    ).select_columns(["text"])
    print(f"  General examples: {len(general_dataset):,}")

    # Combine and shuffle so general documents are spread through training
    combined = concatenate_datasets([domain_dataset, general_dataset])
    combined = combined.shuffle(seed=42)
    print(f"  Total combined: {len(combined):,}")
    print(f"  Domain fraction: {len(domain_dataset) / len(combined):.1%}")

    # Tokenize and pack into fixed-length sequences
    def tokenize_and_chunk(examples):
        all_input_ids = []
        for text in examples["text"]:
            tokens = tokenizer(
                text,
                truncation=False,
                padding=False,
                return_attention_mask=False,
            )["input_ids"]
            # Add EOS token between documents
            tokens.append(tokenizer.eos_token_id)
            all_input_ids.extend(tokens)

        # Chunk into max_sequence_length-sized blocks. The "+ 1" keeps the last
        # full block; the leftover tail of each map batch is dropped.
        chunks = []
        for i in range(0, len(all_input_ids) - max_sequence_length + 1, max_sequence_length):
            chunks.append(all_input_ids[i: i + max_sequence_length])

        return {
            "input_ids": chunks,
            "labels": [list(chunk) for chunk in chunks],  # CLM: labels = input_ids
        }

    tokenized = combined.map(
        tokenize_and_chunk,
        batched=True,
        batch_size=1000,
        remove_columns=combined.column_names,
        num_proc=4,
        desc="Tokenizing and chunking",
    )

    print(f"  Tokenized chunks: {len(tokenized):,}")
    print(f"  Approx tokens: {len(tokenized) * max_sequence_length / 1e9:.2f}B")

    tokenized.save_to_disk(output_path)
    print(f"Corpus saved to {output_path}")
    return tokenized

Continual Pre-Training with Transformers Trainer

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
)
from datasets import load_from_disk
from peft import LoraConfig, TaskType, get_peft_model

# -- Load model --
model_name = "meta-llama/Llama-3.1-8B"
output_dir = "./medical-llm-cpt"

# device_map="auto" shards the model across visible GPUs in one process;
# drop it when launching data-parallel training with torchrun or accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",  # requires flash-attn package
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Optional: use LoRA for parameter-efficient CPT
# Full CPT is better for maximum domain specialization but requires more memory
USE_LORA = True
if USE_LORA:
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=64,  # Higher rank than alignment LoRA - CPT benefits from more parameters
        lora_alpha=128,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.0,  # No dropout for CPT - it is pretraining, not fine-tuning
        bias="none",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

# -- Load CPT corpus --
corpus = load_from_disk("./cpt_corpus")
corpus = corpus.train_test_split(test_size=0.005, seed=42)  # 0.5% eval
train_dataset = corpus["train"]
eval_dataset = corpus["test"]

print(f"Train: {len(train_dataset):,} chunks")
print(f"Eval: {len(eval_dataset):,} chunks")

# -- Data collator --
# Chunks are fixed-length and already carry labels, so simple stacking is enough.
# DataCollatorForLanguageModeling(mlm=False) would rebuild the labels and mask
# every pad token; with pad_token == eos_token that also masks the EOS
# separators between documents.
data_collator = default_data_collator

# -- Training arguments - NOTE the low learning rate --
training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=1,  # CPT usually 1-3 epochs over domain corpus
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,  # effective batch size = 64 per device
    learning_rate=2e-5,  # well below typical pre-training rates (~3e-4)
    lr_scheduler_type="cosine",
    warmup_ratio=0.01,  # very short warmup for CPT
    weight_decay=0.1,
    max_grad_norm=1.0,
    bf16=True,
    logging_steps=50,
    eval_steps=500,
    save_steps=1000,  # must be a multiple of eval_steps for load_best_model_at_end
    eval_strategy="steps",  # named evaluation_strategy in older transformers releases
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="wandb",  # requires the wandb package; use "none" to disable
    dataloader_num_workers=4,
    gradient_checkpointing=True,  # saves memory at cost of compute
    gradient_checkpointing_kwargs={"use_reentrant": False},  # plays well with LoRA
    optim="adamw_torch_fused",  # faster than default adamw
)

# -- Trainer --
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# -- Train --
trainer.train()
trainer.save_model(output_dir)  # with LoRA this saves the adapter weights only
tokenizer.save_pretrained(output_dir)
print(f"CPT complete. Model saved to {output_dir}")

Elastic Weight Consolidation (Simplified for Transformers)

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_fisher_diagonal(
    model: nn.Module,
    tokenizer,
    general_text_samples: list[str],
    num_samples: int = 2000,
    device: str = "cuda",
) -> dict[str, torch.Tensor]:
    """
    Estimate the diagonal of the (empirical) Fisher information for EWC.
    Returns a dict mapping parameter names to importance weights.
    """
    model.to(device)
    model.eval()
    fisher = {
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    samples = general_text_samples[:num_samples]
    total = 0

    for text in samples:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
        ).to(device)

        # A causal LM loss needs at least two tokens
        if inputs["input_ids"].shape[1] < 2:
            continue

        # Forward pass with labels = input_ids for CLM
        model.zero_grad()
        outputs = model(**inputs, labels=inputs["input_ids"])
        outputs.loss.backward()

        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                # Fisher diagonal approximation: squared gradient of the loss
                fisher[name] += param.grad.detach().pow(2)

        total += 1
        if total % 100 == 0:
            print(f"  Fisher computation: {total}/{len(samples)}")

    model.zero_grad()

    # Normalize by number of samples actually used
    for name in fisher:
        fisher[name] /= max(total, 1)

    return fisher


class EWCLoss(nn.Module):
    """Adds EWC regularization to the standard CPT loss."""

    def __init__(
        self,
        model: nn.Module,
        fisher: dict[str, torch.Tensor],
        original_params: dict[str, torch.Tensor],
        lambda_ewc: float = 5000.0,
    ):
        super().__init__()
        self.model = model
        self.fisher = fisher
        self.original_params = original_params
        self.lambda_ewc = lambda_ewc

    def compute_ewc_penalty(self) -> torch.Tensor:
        """Compute the EWC regularization term (lambda / 2) * sum_i F_i (theta_i - theta_i*)^2."""
        device = next(self.model.parameters()).device
        penalty = torch.zeros((), device=device)
        for name, param in self.model.named_parameters():
            if name in self.fisher and param.requires_grad:
                importance = self.fisher[name].to(param.device)
                reference = self.original_params[name].to(param.device)
                penalty = penalty + (importance * (param - reference).pow(2)).sum()
        return self.lambda_ewc * 0.5 * penalty

    def forward(self, cpt_loss: torch.Tensor) -> torch.Tensor:
        ewc_penalty = self.compute_ewc_penalty()
        return cpt_loss + ewc_penalty


# Usage example:
# 1. Load base model
# 2. Compute Fisher on general text
# 3. Save original parameter values
# 4. Run CPT with EWC loss

def setup_ewc_training(model_name: str, general_samples: list[str], device: str = "cuda"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

    # Memory note: the Fisher estimate and the reference weights are each a
    # full extra copy of the trainable parameters.
    print("Computing Fisher information diagonal...")
    fisher = compute_fisher_diagonal(model, tokenizer, general_samples, device=device)

    # Save original parameter values (reference point)
    original_params = {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    ewc_loss_fn = EWCLoss(model, fisher, original_params, lambda_ewc=1000.0)
    print("EWC setup complete.")
    return model, ewc_loss_fn

Extending the Tokenizer with Domain Vocabulary

from transformers import AutoTokenizer
from tokenizers import ByteLevelBPETokenizer


def extend_tokenizer_with_domain_vocab(
    base_tokenizer_name: str,
    domain_corpus_path: str,
    domain_vocab_size: int = 5000,
    min_frequency: int = 50,
    output_path: str = "./domain-tokenizer",
):
    """
    Train a domain BPE tokenizer and merge new tokens into the base tokenizer.
    """
    base_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name)
    print(f"Base vocab size: {len(base_tokenizer)}")

    # Train domain BPE tokenizer on domain corpus (plain text file)
    print("Training domain tokenizer...")
    domain_tokenizer = ByteLevelBPETokenizer()
    domain_tokenizer.train(
        files=[domain_corpus_path],
        vocab_size=domain_vocab_size,
        min_frequency=min_frequency,
        special_tokens=["<s>", "</s>", "<unk>", "<pad>", "<mask>"],
    )

    # Byte-level BPE marks a leading space with "Ġ"; strip it to recover plain
    # words, otherwise the added tokens would never match real text.
    candidates = {token.lstrip("Ġ") for token in domain_tokenizer.get_vocab()}

    # Keep plain alphabetic words of reasonable length that the base tokenizer
    # currently splits into several pieces (filters out noise and known words)
    new_tokens = sorted(
        word for word in candidates
        if len(word) >= 4
        and word.isascii()
        and word.isalpha()
        and len(base_tokenizer.tokenize(word)) > 1
    )
    print(f"New domain tokens to add: {len(new_tokens)}")
    print(f"Sample new tokens: {new_tokens[:20]}")

    # Add new tokens to base tokenizer. Added tokens are matched as literal
    # strings, including inside longer words, so review the list before use.
    num_added = base_tokenizer.add_tokens(new_tokens)
    print(f"Successfully added {num_added} new tokens")
    print(f"New vocab size: {len(base_tokenizer)}")

    base_tokenizer.save_pretrained(output_path)
    print(f"Extended tokenizer saved to {output_path}")

    return base_tokenizer, new_tokens


def initialize_new_token_embeddings(
    model,
    tokenizer,
    new_tokens: list[str],
    old_tokenizer,
):
    """
    Initialize new token embeddings as the average of their constituent
    subword token embeddings from the original tokenizer.

    old_tokenizer must be a separately loaded copy of the base tokenizer,
    because add_tokens modifies the extended tokenizer in place.
    """
    model.resize_token_embeddings(len(tokenizer))
    input_embeddings = model.get_input_embeddings().weight.data

    # Models with untied embeddings keep a separate output matrix (lm_head)
    output_layer = model.get_output_embeddings()
    output_embeddings = output_layer.weight.data if output_layer is not None else None

    initialized = 0
    for token in new_tokens:
        new_token_id = tokenizer.convert_tokens_to_ids(token)

        # Tokenize the new token string with the OLD tokenizer
        constituent_ids = old_tokenizer.encode(token, add_special_tokens=False)
        if len(constituent_ids) == 0:
            continue

        # Average the constituent rows for both input and output embeddings
        input_embeddings[new_token_id] = input_embeddings[constituent_ids].mean(dim=0)
        if output_embeddings is not None:
            output_embeddings[new_token_id] = output_embeddings[constituent_ids].mean(dim=0)
        initialized += 1

    print(f"Initialized embeddings for {initialized} of {len(new_tokens)} new tokens")
    return model

Evaluating Domain Knowledge Gain and General Capability Loss

import gc

import torch
from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM


def _print_deltas(results: dict, tasks: list[str], metric: str = "acc,none"):
    """Print base vs CPT scores for each task; missing metrics show as n/a."""
    for task in tasks:
        base_score = results["base"].get(task, {}).get(metric)
        cpt_score = results["cpt"].get(task, {}).get(metric)
        if base_score is None or cpt_score is None:
            print(f"  {task}: n/a (metric '{metric}' not reported)")
            continue
        delta = cpt_score - base_score
        print(f"  {task}: base={base_score:.3f} | cpt={cpt_score:.3f} | delta={delta:+.3f}")


def evaluate_before_after_cpt(
    base_model_path: str,
    cpt_model_path: str,
    domain_tasks: list[str],
    general_tasks: list[str],
) -> dict:
    """
    Compare base model vs CPT model on domain-specific and general benchmarks.
    Reports domain gain and general capability change.
    """
    results = {}

    for model_name, model_path in [("base", base_model_path), ("cpt", cpt_model_path)]:
        print(f"\nEvaluating {model_name} model: {model_path}")

        lm = HFLM(
            pretrained=model_path,
            dtype="bfloat16",
            batch_size=8,
        )

        all_tasks = domain_tasks + general_tasks
        task_results = evaluator.simple_evaluate(
            model=lm,
            tasks=all_tasks,
            num_fewshot=5,
            batch_size=8,
            device="cuda",
        )
        results[model_name] = task_results["results"]

        # Free GPU memory before loading the next model
        del lm
        gc.collect()
        torch.cuda.empty_cache()

    # Compute deltas
    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)

    print("\nDomain tasks (positive delta = domain knowledge gained):")
    _print_deltas(results, domain_tasks)

    print("\nGeneral tasks (delta near zero = little forgetting):")
    _print_deltas(results, general_tasks)

    return results


# Example evaluation setup for medical domain.
# Task names depend on the installed lm-evaluation-harness version.
domain_tasks = [
    "medqa_4options",  # MedQA (USMLE-style medical QA)
    "medmcqa",         # MedMCQA (medical multiple choice)
    "pubmedqa",        # PubMedQA (research question answering)
]

general_tasks = [
    "mmlu",            # General knowledge
    "arc_challenge",   # Reasoning
    "hellaswag",       # Commonsense completion
    "truthfulqa_mc1",  # Truthfulness
]

# results = evaluate_before_after_cpt(
#     base_model_path="meta-llama/Llama-3.1-8B",
#     cpt_model_path="./medical-llm-cpt",
#     domain_tasks=domain_tasks,
#     general_tasks=general_tasks,
# )

Full CPT Pipeline - Putting It Together

"""
Full continual pre-training pipeline for medical domain adaptation.
Run this script after preparing domain data. The stage calls are left
commented out as a template; uncomment the ones you need.
"""

# Stage 1: Prepare corpus with data mixing
print("Stage 1: Preparing mixed corpus...")
# prepare_domain_corpus(
#     domain_data_path="pubmed_abstracts.jsonl",
#     general_data_fraction=0.05,  # 5% general text
#     output_path="./cpt_corpus",
#     tokenizer_name="meta-llama/Llama-3.1-8B",
#     max_sequence_length=2048,
# )

# Stage 2: (Optional) Extend tokenizer with medical vocabulary
print("Stage 2: Extending tokenizer...")
# old_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
# extended_tokenizer, new_tokens = extend_tokenizer_with_domain_vocab(
#     base_tokenizer_name="meta-llama/Llama-3.1-8B",
#     domain_corpus_path="pubmed_abstracts.txt",
#     domain_vocab_size=3000,  # Add up to 3000 new medical tokens
#     output_path="./medical-tokenizer",
# )

# Stage 3: Continual pre-training
print("Stage 3: Running continual pre-training...")
# See the full CPT training script above

# Stage 4: Instruction fine-tuning on medical Q&A
# After CPT, run standard SFT on (clinical_question, clinical_answer) pairs
# This activates the domain knowledge in the conversational format

print("\nCPT Pipeline complete.")
print("Next step: run instruction fine-tuning with medical QA pairs")
print("Then optionally run DPO with medical preference data for alignment")

Architecture Diagrams

The Full Domain Adaptation Pipeline

Catastrophic Forgetting: What Data Mixing Prevents

CPT vs Fine-Tuning: What Each Layer Learns


Production Engineering Notes

Estimating How Much Domain Text You Need

The domain knowledge gain from CPT roughly follows a logarithmic curve with the amount of domain text. Practical guidelines based on the Don't Stop Pretraining and related papers:

  • Under 100M domain tokens: minimal domain improvement, probably not worth the compute
  • 100M - 1B domain tokens: meaningful domain gains (8-15 points on domain benchmarks), acceptable forgetting with 5% replay
  • 1B - 10B domain tokens: strong domain specialization (15-25 points), requires careful replay and EWC for sensitive applications
  • Over 10B domain tokens: specialist-level domain performance, significant forgetting risk, consider full retraining from scratch instead

For medical domain: PubMed abstracts (37M abstracts x ~150 tokens average = ~5.5B tokens) is an excellent starting corpus. Adding full-text PMC articles (4M papers x ~4000 tokens = ~16B tokens) brings you to specialist-level territory.

Compute Budget for CPT

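CPT over a 7B model on 1B tokens takes approximately:

  • 1B tokens / 2048 sequence length = ~488,000 sequences per epoch
  • With an effective batch size of 64 sequences (4 GPUs, batch 4 per GPU, grad accum 4): ~7,600 optimizer steps
  • Training compute at roughly 6 FLOPs per parameter per token: 6 x 7B x 1B = ~4.2 x 10^19 FLOPs
  • Assuming about 40% utilization of an A100's 312 TFLOP/s bf16 peak: ~340,000 GPU-seconds, or roughly 90-100 A100 GPU-hours

Gradient checkpointing adds roughly a third more compute on top of this. Full CPT of a 7B model does not fit on a single A100 with standard AdamW (weights, gradients, and optimizer state alone exceed 80 GB), so plan on a multi-GPU node: on 8 A100s, 1B tokens takes on the order of 12-16 hours with good scaling, and a 10B-token corpus takes several days. LoRA CPT reduces memory sharply and compute more modestly, at the cost of some specialization depth.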
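As a cross-check on estimates like these, the widely used rule of thumb of about 6 FLOPs per parameter per training token gives a quick budget. The sketch below assumes an A100 bf16 peak of 312 TFLOP/s and 40% sustained utilization; the utilization figure is an assumption and varies a lot with the training stack.

```python
def cpt_compute_estimate(
    num_params: float,
    num_tokens: float,
    seq_len: int = 2048,
    sequences_per_step: int = 64,
    peak_flops: float = 312e12,   # A100 bf16 dense peak
    utilization: float = 0.4,     # assumed fraction of peak actually achieved
):
    """Return (optimizer_steps, gpu_hours) for one pass over num_tokens."""
    optimizer_steps = num_tokens / (seq_len * sequences_per_step)
    total_flops = 6 * num_params * num_tokens
    gpu_hours = total_flops / (peak_flops * utilization) / 3600
    return optimizer_steps, gpu_hours


steps, hours = cpt_compute_estimate(num_params=7e9, num_tokens=1e9)
print(f"Optimizer steps: {steps:,.0f}")
print(f"A100 GPU-hours:  {hours:,.0f}")
print(f"Wall clock on 8 GPUs (ideal scaling): {hours / 8:,.1f} hours")
```

Treat the result as a lower bound: gradient checkpointing, evaluation passes, and imperfect multi-GPU scaling all add to it.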

Checkpoint Strategy for CPT

Save checkpoints more frequently in CPT than in instruction fine-tuning because forgetting happens gradually. The optimal checkpoint (maximum domain gain + minimum forgetting) often occurs at 60-80% of the training budget, not at the end. Monitor both domain perplexity (should decrease) and general perplexity on held-out general text (should stay flat or increase only slightly). If general perplexity increases by more than 15% from baseline, roll back to the previous checkpoint and reduce learning rate or increase replay fraction.
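The rollback rule can be turned into a simple selection function over logged checkpoints. This is a sketch with a hypothetical record format (`step`, `domain_ppl`, `general_ppl`) and invented numbers; the 15% threshold is the one stated above.

```python
def pick_checkpoint(checkpoints, baseline_general_ppl, max_general_increase=0.15):
    """Return the checkpoint with the lowest domain perplexity among those
    whose general perplexity stays within the allowed increase, or None."""
    limit = baseline_general_ppl * (1 + max_general_increase)
    eligible = [c for c in checkpoints if c["general_ppl"] <= limit]
    if not eligible:
        return None
    return min(eligible, key=lambda c: c["domain_ppl"])


# Invented measurements, for illustration only
history = [
    {"step": 1000, "domain_ppl": 9.0, "general_ppl": 10.2},
    {"step": 2000, "domain_ppl": 7.5, "general_ppl": 10.9},
    {"step": 3000, "domain_ppl": 7.0, "general_ppl": 12.0},  # too much forgetting
]

best = pick_checkpoint(history, baseline_general_ppl=10.0)
print("Selected checkpoint:", best)
```

Here the final checkpoint has the best domain perplexity but is rejected because general perplexity rose by 20%, which matches the observation that the best checkpoint is often not the last one.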

Production Deployment Considerations

After CPT, you have a model that is specialized at the representation level but may not have learned to express that specialization in a useful format. Always run instruction fine-tuning after CPT before deployment. The CPT model is not more useful to end users than the base model - it is more capable of becoming useful after instruction fine-tuning.

One anti-pattern to avoid: deploying the CPT model directly with a system prompt instead of running instruction fine-tuning. The CPT model has not learned to follow instructions - it will continue text in the style of its domain corpus, not respond to user queries. A medical CPT model without instruction tuning will respond to "What is the dose of metformin?" by generating text that sounds like a PubMed abstract continuing from the question as a title.


Common Mistakes

:::danger Using a high learning rate for CPT

The most damaging mistake in CPT is using the same learning rate as instruction fine-tuning ($1 \times 10^{-4}$ or higher). At this learning rate, the model updates weights aggressively on domain text, and general representations are overwritten within the first few thousand steps. You can lose 10-20 points on MMLU in a single CPT run at high learning rate. Always start CPT at $1 \times 10^{-5}$ to $5 \times 10^{-5}$ and verify general benchmark stability before committing to a full CPT run.

:::

:::danger Skipping instruction fine-tuning after CPT and deploying directly

A continually pre-trained model is NOT ready for user-facing deployment. CPT teaches knowledge, not instruction following. If you deploy a CPT model with a system prompt and expect it to answer questions helpfully, you will get text that sounds like an expert's notebook, not a helpful assistant. CPT and instruction fine-tuning are complementary and sequential steps - CPT is always followed by instruction tuning before deployment.

:::

:::warning Using 100% domain text with no general replay

Many teams running CPT for the first time omit general replay to maximize domain learning speed. This is a mistake even when forgetting seems acceptable. Evaluation on general benchmarks after CPT without replay consistently shows 5-15 point drops on MMLU and similar benchmarks. More insidiously, the model loses subtle general capabilities (reasoning under ambiguity, handling edge cases in questions, graceful degradation when context is unclear) that are hard to measure but obvious to users. Always include at least 5% general replay. The domain learning cost is small and the forgetting prevention benefit is large.

:::
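One simple way to guarantee replay in every step is to fix the number of general examples per batch. The generator below is a sketch under that assumption. `mixed_batches` is a hypothetical name, the documents are stand-in tuples, and a real pipeline would stream tokenized, packed sequences.

```python
import random


def mixed_batches(domain_docs, general_docs, batch_size,
                  replay_fraction=0.05, seed=0):
    """Yield batches with a fixed share of general-text replay examples."""
    rng = random.Random(seed)
    # At least one replay example per batch, even for small batch sizes.
    n_general = max(1, round(batch_size * replay_fraction))
    n_domain = batch_size - n_general
    while True:
        batch = (rng.choices(domain_docs, k=n_domain)
                 + rng.choices(general_docs, k=n_general))
        rng.shuffle(batch)  # avoid a fixed position for replay examples
        yield batch


domain_docs = [("domain", i) for i in range(1000)]
general_docs = [("general", i) for i in range(1000)]

batch = next(mixed_batches(domain_docs, general_docs, batch_size=64))
n_replay = sum(1 for source, _ in batch if source == "general")
print(f"{n_replay} of {len(batch)} examples are general replay")
```

With a batch of 64, a 5% target rounds to 3 replay examples per step. Sampling each example's source independently with probability 0.05 is an alternative that hits the ratio only on average.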

:::warning Evaluating only domain performance after CPT

Teams naturally evaluate the metric they optimized for - domain benchmark performance. But shipping a model that scores 15 points higher on MedQA but 8 points lower on reasoning benchmarks is a regression in overall capability. Always run a full evaluation sweep after CPT that includes at minimum: your domain benchmark, MMLU (general knowledge), and a reasoning benchmark like ARC-Challenge or HellaSwag. If any general benchmark drops more than 3 points relative to the base model, diagnose before deploying: increase replay fraction, reduce learning rate, or reduce CPT token budget.

:::
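The 3-point rule above is straightforward to automate as a gate in your evaluation pipeline. In this sketch the function name `regression_report` and all scores are hypothetical. The numbers only mirror the "15 up, 8 down" scenario described in the warning.

```python
def regression_report(base_scores, cpt_scores, general_benchmarks, max_drop=3.0):
    """Compare CPT scores to base scores; flag general benchmarks that
    dropped by more than max_drop points."""
    deltas = {name: cpt_scores[name] - base_scores[name] for name in base_scores}
    failures = [name for name in general_benchmarks if deltas[name] < -max_drop]
    return deltas, failures


# Illustrative scores only, not real measurements.
base = {"medqa": 48.0, "mmlu": 65.0, "arc_challenge": 55.0}
cpt = {"medqa": 63.0, "mmlu": 63.5, "arc_challenge": 47.0}

deltas, failures = regression_report(base, cpt, ["mmlu", "arc_challenge"])
print(deltas)
print("blocked by:", failures)
```

Here the domain gain looks excellent in isolation, but the gate blocks the checkpoint because the reasoning benchmark fell by 8 points.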

:::warning Applying tokenizer extension without re-initializing output head weights

When you add new tokens to the vocabulary and resize the embedding matrix, both the input embeddings and the output (logit) projection gain new rows. Most practitioners remember to initialize new input embedding rows as averages of constituent tokens. Fewer remember that the output weight matrix (lm_head) also has new rows that need initialization, unless the model ties its input and output embeddings. Those rows carry no pre-trained signal: depending on the framework and version they come back as small random values or some other default, so check what your resize call actually does. A near-zero lm_head row gives the new token a logit near zero in every context, which is uninformative and, relative to the trained tokens' logits, can be either far too low (the token is never generated) or far too high (the token floods the output). Always initialize both the embedding and lm_head rows for new tokens.

:::
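The initialization itself is a small operation. The sketch below uses plain Python lists so the logic is visible without a framework. `mean_rows` and `add_token` are hypothetical helpers and the matrices are toy values. With Hugging Face models the same averaging would typically be applied to the weight tensors returned by `get_input_embeddings()` and `get_output_embeddings()` after resizing.

```python
def mean_rows(matrix, row_ids):
    """Element-wise mean of the selected rows."""
    dim = len(matrix[0])
    return [sum(matrix[r][j] for r in row_ids) / len(row_ids) for j in range(dim)]


def add_token(embed, lm_head, constituent_ids):
    """Append one new token to BOTH matrices, initializing each new row as the
    mean of the rows of the sub-tokens the new token used to be split into."""
    embed.append(mean_rows(embed, constituent_ids))
    lm_head.append(mean_rows(lm_head, constituent_ids))
    return len(embed) - 1  # id of the new token


# Toy vocabulary of 3 tokens with 2-dimensional vectors.
embed = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
lm_head = [[0.5, 0.5], [1.5, -0.5], [0.0, 0.0]]

# Suppose a new domain token was previously tokenized as pieces 0 and 1.
new_id = add_token(embed, lm_head, constituent_ids=[0, 1])
print(new_id, embed[new_id], lm_head[new_id])
```

Averaging the constituent rows starts the new token near where its pieces already lived, both as an input and as a prediction target, which gives CPT a sensible starting point to refine.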


Interview Q&A

Q1: What is the difference between continual pre-training and instruction fine-tuning? When do you use each?

Continual pre-training (CPT) uses the same next-token prediction objective as the original pre-training, applied to raw domain text without any instruction-following structure. It teaches the model what to know: the vocabulary, factual associations, reasoning patterns, and distributional properties of a target domain. Instruction fine-tuning trains on (instruction, response) pairs using a supervised loss on response tokens only. It teaches the model how to express what it knows in a helpful, structured format.
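The difference in objectives shows up concretely in how labels are built. The sketch below assumes the common convention of masking positions with -100, the default `ignore_index` of PyTorch's cross-entropy loss. The helper names and token ids are made up for illustration.

```python
IGNORE_INDEX = -100  # positions with this label are excluded from the loss


def cpt_labels(token_ids):
    """CPT: every token of the raw document is a prediction target."""
    return list(token_ids)


def sft_labels(prompt_ids, response_ids):
    """Instruction tuning: mask the prompt so only response tokens are scored."""
    return [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)


# Made-up token ids. The one-position shift for next-token prediction is
# assumed to happen inside the loss computation, as in most training libraries.
document = [11, 12, 13, 14, 15]
prompt, response = [21, 22, 23], [31, 32]

print(cpt_labels(document))
print(sft_labels(prompt, response))
```

The model and the loss function are identical in both stages. Only the data and the set of positions that contribute to the loss change.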

Use CPT when the base model lacks deep knowledge of your domain - when it cannot recall key concepts, misuses domain terminology, or fails domain-specific benchmarks not because it cannot follow instructions but because the knowledge is simply absent. Use instruction fine-tuning when the base model has the knowledge but does not apply it in the format you need. In most specialized deployments, you need both in sequence: CPT to build domain knowledge, followed by instruction fine-tuning to activate it in the right format.

Q2: Explain catastrophic forgetting in the context of LLMs and how data mixing addresses it.

Catastrophic forgetting occurs because neural network weights are shared across all learned capabilities. When you train on domain text using gradient descent, every parameter update is a small step in the direction that reduces loss on domain text. Parameters encoding general English syntax, reasoning patterns, and factual knowledge from general pre-training are not "protected" - they shift toward whatever minimizes loss on the current training distribution. After enough updates, the parameters that encoded general capabilities have moved far enough that general performance degrades significantly.

Data mixing prevents this by ensuring the gradient signal is never purely from domain text. In each training step, a fraction of the batch (typically 5%) comes from general text. The gradient for that general text sample pushes parameters back toward general capabilities. The net result is that the optimizer is simultaneously pulled toward domain specialization and toward general capability retention. The equilibrium point - where domain gain and general forgetting balance - can be controlled by adjusting the mixing ratio. More general replay means less forgetting and slower domain specialization.

Q3: How does Elastic Weight Consolidation work and what are the practical limitations for large language models?

EWC estimates parameter importance using the diagonal of the Fisher information matrix: for each parameter, the expected squared gradient of the log-likelihood, which measures how sensitive the model's predictions are to changes in that parameter. Parameters with high Fisher information are important for current capabilities and should be penalized more when they change. EWC adds a quadratic penalty to the training loss proportional to Fisher information times the squared deviation from the original parameter values.

The fundamental limitation for LLMs is computational: the full Fisher information matrix is $N \times N$ where $N$ is the number of parameters, which is far too large to form. For a 7B parameter model, even the diagonal takes about 28GB to store in fp32 (7B values x 4 bytes), and estimating it accurately requires running backpropagation on thousands of samples. Practical approximations (diagonal Fisher estimated on a small subsample) work reasonably well but the computational overhead is still substantial - computing the diagonal Fisher for a 7B model takes several GPU-hours.

More importantly, EWC assumes that what the model "currently knows" can be captured in a single Fisher diagonal computed before CPT. In practice, the model's capabilities interact in complex ways, and a scalar importance weight per parameter is a coarse approximation. For most teams, data mixing achieves comparable forgetting prevention with none of the implementation complexity.
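A toy version of the EWC penalty makes the mechanics concrete. This sketch uses plain Python lists with two parameters. `diagonal_fisher` and `ewc_penalty` are hypothetical names, the gradients are made-up values, and a real implementation would operate on framework tensors and accumulate squared gradients over many batches.

```python
def diagonal_fisher(per_sample_grads):
    """Diagonal Fisher estimate: mean squared gradient per parameter."""
    n = len(per_sample_grads)
    num_params = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(num_params)]


def ewc_penalty(params, anchor_params, fisher, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta_i_star)^2"""
    return 0.5 * lam * sum(
        f * (p - a) ** 2 for f, p, a in zip(fisher, params, anchor_params)
    )


# Parameter 0 has large gradients on general data (important);
# parameter 1 has zero gradients (unimportant).
grads = [[1.0, 0.0], [3.0, 0.0]]
fisher = diagonal_fisher(grads)

anchor = [0.0, 0.0]    # weights before CPT
current = [0.1, 2.0]   # weights during CPT
penalty = ewc_penalty(current, anchor, fisher, lam=2.0)
print(fisher, penalty)
```

The unimportant parameter moved twenty times further than the important one yet contributes nothing to the penalty, which is exactly the selective protection EWC is meant to provide. The total training loss would be the domain loss plus this penalty.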

Q4: Why does the learning rate need to be much lower for CPT than for instruction fine-tuning?

Instruction fine-tuning trains on a much smaller, curated dataset - typically tens of thousands to a few million examples. A very low learning rate is less critical there because training ends after relatively few updates, before the weights can drift far from their pre-trained values, and the format of the training data is sufficiently different from pre-training that the gradient signal is relatively clean.

CPT involves training on potentially billions of tokens with gradient updates that touch every parameter. The model is already well-initialized from its original pre-training. The goal is to nudge the model toward domain specialization without overwriting general capabilities. A high learning rate causes large weight updates that quickly shift the model away from its pre-trained initialization, and over a long CPT run those shifts accumulate. General capabilities, which were built over many billions of tokens of diverse training, survive a brief instruction-tuning run at a high learning rate but not a long CPT run at the same rate.

The intuition: the more pre-training compute went into the model, the more "expensive" each unit of forgetting is, and the lower the learning rate needed to preserve what was built.

Q5: What are BioMedLM and LegalBERT, and what do they tell us about domain specialization at scale?

BioMedLM (Bolton et al., 2022, Stanford CRFM) is a 2.7B parameter language model trained from scratch on PubMed abstracts and full-text articles - approximately 21 billion tokens of biomedical text. It achieved state-of-the-art results on BioASQ (biomedical QA) and competitive results on MedQA despite being dramatically smaller than general models like GPT-3.

LegalBERT (Chalkidis et al., 2020) continued pre-training BERT-base on a large corpus of legal text from EU legislation, UK legislation, US court opinions, and contracts. It outperformed general BERT by 5-15 points on multiple legal NLP tasks (contract clause classification, legal judgment prediction) despite using the same architecture and no task-specific changes.

The lesson from both: domain specialization through pre-training or CPT is extremely high value per parameter. A smaller model with domain CPT routinely beats a larger general model on domain tasks. These results established domain-adaptive pre-training as a fundamental tool in the specialized AI deployment toolkit, separate from and complementary to instruction fine-tuning.

Q6: How do you decide whether to do full parameter CPT vs LoRA-based CPT?

Full parameter CPT updates every weight in the model, maximizing the depth of domain specialization. LoRA CPT updates only low-rank adapter matrices, leaving the base model weights frozen. The trade-offs:

Full parameter CPT is better when: you have billions of domain tokens (enough to justify the compute), you need deep representation-level specialization (the model needs to genuinely think in domain language, not just pattern-match on surface features), and you have the GPU budget (multiple A100s for weeks).

LoRA CPT is better when: you have limited compute, you are experimenting with domain curricula and need to iterate quickly, your domain corpus is under 500M tokens (LoRA can absorb that much specialization effectively), or you need to serve the same base model with multiple domain adapters (swap LoRA adapters at inference time without separate full models for each domain).

The empirical finding is that LoRA CPT with rank 64 on all projection layers captures roughly 70-80% of the domain specialization of full CPT at 10-20% of the compute cost. For most production use cases, that trade-off favors LoRA.
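To get a feel for the size of a rank-64 adapter, you can count its parameters directly. The sketch below assumes Llama 3.1 8B-style dimensions (hidden size 4096, MLP size 14336, key/value projection width 1024, 32 layers). Treat these as assumptions and read the real values from your model's config.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """A LoRA adapter adds A (rank x d_in) and B (d_out x rank)."""
    return rank * (d_in + d_out)


# Assumed dimensions; check your model's config before relying on them.
HIDDEN, INTERMEDIATE, KV_DIM, LAYERS, RANK = 4096, 14336, 1024, 32, 64

projections = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_params(d_in, d_out, RANK) for d_in, d_out in projections.values())
total = per_layer * LAYERS
print(f"{per_layer:,} per layer, {total:,} total trainable parameters")
```

Under these assumptions the adapter has roughly 168M trainable parameters, about 2% of an 8B model. That share mainly reduces gradient and optimizer memory. The forward and backward passes still run through the frozen base weights, so compute savings are smaller than the parameter ratio suggests.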

Q7: How do you evaluate whether CPT actually improved domain knowledge vs just changing response style?

Benchmark-based evaluation: run domain-specific QA benchmarks (MedQA, LegalBench, FinanceQA) that test factual knowledge, not just conversational ability. Run them in few-shot or zero-shot format on the model before instruction fine-tuning, so that the instruction-tuning stage cannot mask differences in underlying knowledge. If the CPT model scores higher than the base model before instruction fine-tuning, domain knowledge increased at the representation level.

Perplexity evaluation: compute perplexity on held-out domain text (papers, documents not seen during CPT). Lower perplexity indicates the model has a better internal model of domain language. This is the most direct measure of what CPT is actually doing.
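Perplexity is the exponential of the average negative log-likelihood per token, so it can be computed from per-token log-probabilities. The sketch below assumes you have already extracted those log-probabilities (natural log) from each model on the same held-out text. The values are synthetic.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-probability per token."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


# Synthetic example: the base model assigns each held-out token probability
# 0.25, the CPT model assigns 0.5.
base_logprobs = [math.log(0.25)] * 4
cpt_logprobs = [math.log(0.5)] * 4

base_ppl = perplexity(base_logprobs)
cpt_ppl = perplexity(cpt_logprobs)
print(f"base: {base_ppl:.2f}, after CPT: {cpt_ppl:.2f}")
```

Perplexities are only comparable when both models use the same tokenizer and the same evaluation text. If CPT included tokenizer extension, compare on a per-character or per-byte basis instead.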

Probing tasks: design tasks that test specific knowledge the model should have acquired. For medical CPT: can the model correctly order drugs by half-life? Can it identify contraindications between two drug classes? These targeted probes reveal specific knowledge gaps rather than aggregate benchmark performance.

The key methodological point: evaluate the CPT model directly (without instruction fine-tuning) to measure knowledge gain in isolation. Once you add instruction fine-tuning, the evaluation conflates knowledge with instruction-following ability, making it hard to attribute performance changes to either stage.
