
Axolotl and TRL Training Frameworks

The Production Fine-Tuning Problem

It is 11 PM on a Tuesday. You have spent the last three weeks curating 8,000 high-quality instruction examples for your company's internal code assistant. The data is clean. The use case is well-defined. Every stakeholder has signed off. All you need to do now is run the actual fine-tuning job.

You open your laptop and start writing training code from scratch. You hook up the HuggingFace Trainer, wire in your LoRA config, set up the tokenizer, write the dataset collation logic, handle the special tokens for your chat template, add gradient checkpointing, configure your optimizer, and - four hours later - you realize you have been reinventing the wheel that three other engineers at your company already reinvented last quarter. Your training loop crashes at step 847 because you forgot to handle sequences longer than the model's context window. You fix that. It crashes again because your GPU runs out of memory on the last batch. You fix that too. By 3 AM, the job is running. By morning, you check the eval metrics and realize your chat template was wrong the entire time, so the model learned to produce outputs with the wrong format.

This is not a hypothetical. It is the story of most first attempts at LLM fine-tuning in production environments. The actual math - LoRA rank decomposition, 4-bit quantization, loss computation - is the easy part. The operational complexity of wiring everything together correctly is where real projects stall.

Two frameworks exist specifically to absorb this complexity: HuggingFace TRL (Transformer Reinforcement Learning) and Axolotl. TRL is the reference implementation from HuggingFace - a set of purpose-built trainers for supervised fine-tuning, preference optimization, and reinforcement learning from human feedback. Axolotl is a config-driven training framework built on top of TRL and HuggingFace, designed for production multi-GPU fine-tuning runs where you want to specify everything in a YAML file and let the framework handle the rest.

Understanding when to use each, how to configure each correctly, and where each one can surprise you is the difference between a fine-tuning project that ships and one that spends three months in infrastructure debugging. This lesson covers both frameworks in depth - from their core abstractions down to the specific configuration knobs that matter for production runs.


Why This Exists - The Training Loop Problem

Before TRL and Axolotl, fine-tuning a language model meant writing a custom training loop using the base HuggingFace Trainer class. The Trainer is a general-purpose tool designed for any sequence-to-sequence or classification task. It has no concept of chat templates, no built-in support for packing short sequences together to maximize GPU utilization, no first-class integration with PEFT/LoRA, and no opinion about how your dataset should be formatted.

The result was that every team doing LLM fine-tuning wrote bespoke glue code. This glue code looked superficially similar across teams but differed in critical details: how padding tokens were masked in the loss function, whether the instruction portion of a prompt was ignored during loss computation (it usually should be), and how special tokens were handled at sequence boundaries. These differences were not cosmetic. A training loop that includes the instruction tokens in the loss computation spends part of its training signal on reproducing the instruction format rather than on the responses you actually care about. A training loop with incorrect padding mask handling will compute loss on padding tokens, introducing noise that degrades model quality in ways that are hard to diagnose.

TRL's SFTTrainer was created to solve exactly this class of problem. It ships with correct defaults and tools for supervised fine-tuning of language models: it can mask instruction tokens from the loss (via DataCollatorForCompletionOnlyLM), it handles chat template application, it supports sequence packing, and it integrates natively with PEFT. The team at HuggingFace used it internally to train their own models, which means the defaults reflect hard-won operational experience.

Axolotl went a step further and asked: what if you did not have to write Python code at all for a standard fine-tuning run? It replaced the Python configuration surface with YAML files that capture all meaningful training decisions - base model, dataset format, LoRA parameters, optimizer settings, hardware configuration - and added first-class support for the distributed training backends (FSDP, DeepSpeed) needed for multi-GPU jobs. For teams running repeated fine-tuning experiments with different hyperparameters or datasets, Axolotl makes it trivial to version-control your exact training configuration and reproduce any run exactly.


Historical Context - From Research Code to Production Tools

TRL was created by Leandro von Werra in 2020, initially as an implementation of the PPO (Proximal Policy Optimization) algorithm for fine-tuning language models against a learned reward, following Ziegler et al. (2019); the project later moved to the HuggingFace organization. The original goal was to make RLHF (reinforcement learning from human feedback) accessible without requiring teams to implement PPO from scratch. The InstructGPT paper from OpenAI (Ouyang et al., 2022) later demonstrated that RLHF was the key step separating raw language model capabilities from useful assistants, and TRL became the leading open-source answer.

Over 2023 and 2024, TRL expanded significantly. The addition of SFTTrainer brought supervised fine-tuning into the same ecosystem. DPOTrainer implemented Direct Preference Optimization (Rafailov et al., 2023), which simplified preference learning by eliminating the need for a separate reward model. RewardTrainer handled the reward modeling step when teams did need traditional RLHF pipelines. By 2024, TRL had become the reference implementation for nearly every open-source alignment and fine-tuning workflow.

Axolotl was created by Wing Lian in 2023 as a practical response to the growing complexity of fine-tuning configurations. The "aha moment" was simple: if you are running dozens of fine-tuning experiments, and each one requires modifying a Python script, you are going to introduce bugs. YAML configuration files are version-controlled, diffable, and shareable in a way that Python scripts with inline hyperparameters are not. Axolotl wrapped TRL and HuggingFace Transformers in a configuration layer that made the most common fine-tuning workflows expressible as pure YAML, while still allowing Python-level customization for teams that needed it.

The framework gained rapid adoption because it solved a real pain point: the gap between "I have the theory" and "I have a working production training run." By 2024, Axolotl had become the de facto choice for serious open-source fine-tuning projects, used by researchers at EleutherAI, independent practitioners on the HuggingFace Hub, and ML teams at companies that needed repeatable fine-tuning pipelines.


TRL - The Reference Implementation

The SFTTrainer

SFTTrainer is the starting point for almost all supervised fine-tuning work with TRL. It extends the standard HuggingFace Trainer with defaults and utilities specific to language model fine-tuning.

The three things SFTTrainer gets right that a raw Trainer does not:

1. Dataset formatting. SFTTrainer accepts datasets in multiple formats: raw text, instruction/response pairs, or conversations. If the dataset has a conversational messages column, it applies your tokenizer's chat template automatically. For other layouts, you point it at a text column with dataset_text_field or pass a formatting_func.

2. Sequence packing. Short training examples waste compute and memory because every sequence in a batch is padded to the length of the longest one. With packing=True, SFTTrainer concatenates multiple short examples into a single long sequence up to max_seq_length, separated by EOS tokens. This dramatically improves GPU utilization when your average example length is much shorter than max_seq_length.

3. PEFT integration. Pass a PeftConfig to SFTTrainer and it handles the LoRA setup for you - wrapping the model, printing trainable parameter counts, saving adapter weights separately at checkpoints.

```python
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig
import torch

# Load base model with 4-bit quantization for QLoRA
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires flash-attn installed
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token = tokenizer.eos_token  # Llama 3 ships without a pad token
tokenizer.padding_side = "right"  # important for causal LMs

# LoRA configuration
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# Dataset - expects a "messages" column in chat format
dataset = load_dataset("your-org/your-dataset", split="train")

def format_chat(example):
    """Apply the chat template to the messages column."""
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    }

dataset = dataset.map(format_chat, remove_columns=dataset.column_names)

# Hold out a small validation split so the eval_* settings below have data to run on
splits = dataset.train_test_split(test_size=0.02, seed=42)

# SFTConfig (replaces TrainingArguments for TRL >= 0.8)
training_args = SFTConfig(
    output_dir="./llama3-8b-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size = 16 on a single GPU
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",  # 8-bit optimizer saves memory
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
    eval_strategy="steps",  # named evaluation_strategy in transformers < 4.41
    eval_steps=100,
    bf16=True,
    max_seq_length=2048,
    packing=True,  # pack short sequences
    dataset_text_field="text",
    report_to="wandb",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],  # required when an eval strategy is set
    peft_config=lora_config,
    tokenizer=tokenizer,  # renamed processing_class in newer TRL releases
)

trainer.train()
trainer.save_model()
```
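To make the packing=True idea concrete, here is a minimal pure-Python sketch, assuming examples are already tokenized into lists of token ids. `pack_examples` is a hypothetical helper that roughly mirrors constant-length packing (append EOS to each example, concatenate everything into one stream, cut it into fixed-size chunks); it is not TRL's actual implementation.

```python
from typing import List


def pack_examples(examples: List[List[int]], eos_id: int, max_seq_length: int) -> List[List[int]]:
    """Concatenate tokenized examples (each followed by EOS) into one stream,
    then cut the stream into chunks of exactly max_seq_length tokens."""
    stream: List[int] = []
    for ids in examples:
        stream.extend(ids + [eos_id])
    # Drop the trailing partial chunk so every packed sequence has the same length
    n_full = len(stream) // max_seq_length
    return [stream[i * max_seq_length:(i + 1) * max_seq_length] for i in range(n_full)]


packed = pack_examples([[1, 2, 3], [4, 5], [6, 7, 8, 9]], eos_id=0, max_seq_length=4)
print(packed)
```

Two design consequences are visible even in this toy version: an example can be split across chunk boundaries, and unless attention is masked per example, tokens in a chunk can attend to an unrelated earlier example packed in front of them.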

Masking the Instruction from Loss

One of the most impactful correctness details in supervised fine-tuning is ensuring the model only learns to predict the response tokens, not the instruction tokens. If you train with loss on the full sequence (instruction + response), the model spends capacity learning to predict your prompt template, which wastes parameters and can cause the model to be brittle to prompt variations.

TRL provides DataCollatorForCompletionOnlyLM for this purpose:

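```python
from trl import DataCollatorForCompletionOnlyLM

# For Llama 3 chat format, responses follow the <|start_header_id|>assistant<|end_header_id|> tokens
response_template = "<|start_header_id|>assistant<|end_header_id|>"
# Marks user turns, so every user turn in a multi-turn conversation is masked as well
instruction_template = "<|start_header_id|>user<|end_header_id|>"

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    instruction_template=instruction_template,
    tokenizer=tokenizer,
)

# The completion-only collator is not compatible with packing
training_args.packing = False

# Hold out a validation split for the eval_* settings in training_args
splits = dataset.train_test_split(test_size=0.02, seed=42)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
    peft_config=lora_config,
    tokenizer=tokenizer,
    data_collator=collator,  # masks instruction tokens from loss
)
```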

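The collator finds the response_template token sequence in each batch item and sets the labels for all tokens up to and including it to -100 (HuggingFace's convention for "ignore this token in loss computation"). If you also pass an instruction_template, it masks every user turn in a multi-turn conversation, not just the prompt before the assistant's reply. The collator does not support sequence packing, so run it with packing=False. This is a small change with a large impact on model quality.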
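The masking rule itself is simple enough to sketch in a few lines. This is a simplified, single-turn illustration on plain token-id lists, assuming the response template has already been tokenized; `find_subsequence` and `mask_prompt_labels` are hypothetical helpers, not TRL functions, and the real collator works on batched tensors and handles multi-turn cases.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # HuggingFace convention: labels with this value are skipped by the loss


def find_subsequence(seq: List[int], pattern: List[int]) -> Optional[int]:
    """Return the start index of the first occurrence of pattern in seq, or None."""
    for i in range(len(seq) - len(pattern) + 1):
        if seq[i:i + len(pattern)] == pattern:
            return i
    return None


def mask_prompt_labels(input_ids: List[int], response_template_ids: List[int]) -> List[int]:
    """Copy input_ids into labels, masking everything up to and including the response template."""
    labels = list(input_ids)
    start = find_subsequence(input_ids, response_template_ids)
    if start is None:
        # No response found: mask the whole example so it contributes no loss
        return [IGNORE_INDEX] * len(labels)
    end = start + len(response_template_ids)
    for i in range(end):
        labels[i] = IGNORE_INDEX
    return labels


# Toy ids: 10-12 are the prompt, [99, 98] stands in for the response template, 5-7 the response
print(mask_prompt_labels([10, 11, 12, 99, 98, 5, 6, 7], [99, 98]))
```

Masking an example entirely when the template is missing is the safe choice: a tokenization mismatch then shows up as examples contributing no loss, rather than as silently training on prompt tokens.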

DPOTrainer - Direct Preference Optimization

After supervised fine-tuning, you often want to align the model's outputs with human preferences - making it more helpful, less harmful, or better at following specific output format conventions. The classical approach (RLHF with PPO) is complex and unstable. DPO (Rafailov et al., 2023) reformulates preference learning as a supervised objective that can be optimized directly without a reward model.

The DPO loss for a preference pair (chosen response $y_w$, rejected response $y_l$) given prompt $x$ is:

$$\mathcal{L}_{DPO} = -\mathbb{E}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w | x)}{\pi_{ref}(y_w | x)} - \beta \log \frac{\pi_\theta(y_l | x)}{\pi_{ref}(y_l | x)}\right)\right]$$

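In plain English: relative to a frozen reference model, raise the probability of chosen responses and lower the probability of rejected ones. $\beta$ sets how strongly the loss rewards that margin, and therefore how far the policy is allowed to drift from the reference.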
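The loss is straightforward to compute once you have the summed log-probability of each response under the policy and under the frozen reference. The sketch below evaluates it for a single pair in pure Python; the function names are hypothetical, and real trainers compute the same quantity batched on GPU tensors.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """DPO loss for one preference pair, from summed response log-probabilities."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp        # log pi_theta(y_w|x) - log pi_ref(y_w|x)
    rejected_logratio = policy_rejected_logp - ref_rejected_logp  # same for y_l
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(m)) == log(1 + exp(-m)) == softplus(-m)
    return softplus(-margin)


print(dpo_loss(-10.0, -10.0, -10.0, -10.0))  # policy == reference
print(dpo_loss(-9.0, -10.0, -10.0, -10.0))   # policy favors y_w more than the reference does
```

When the policy equals the reference the margin is zero and the loss is $\log 2 \approx 0.693$; it falls as the policy's preference for $y_w$ over $y_l$ grows beyond the reference's.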

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType
from trl import DPOTrainer, DPOConfig

# DPO requires a preference dataset with (prompt, chosen, rejected) columns
dpo_dataset = load_dataset("your-org/preference-data", split="train")
# Expected format: {"prompt": "...", "chosen": "...", "rejected": "..."}

# Start from the merged SFT model; it serves as both policy and reference
sft_model = AutoModelForCausalLM.from_pretrained(
    "./llama3-8b-finetuned/merged",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

dpo_lora = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    task_type=TaskType.CAUSAL_LM,
)

dpo_args = DPOConfig(
    output_dir="./llama3-8b-dpo",
    beta=0.1,  # KL penalty strength - lower lets the policy drift further from the reference
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,  # lower LR than SFT
    bf16=True,
    max_length=1024,
    max_prompt_length=512,
    report_to="wandb",
)

dpo_trainer = DPOTrainer(
    model=sft_model,
    ref_model=None,  # with peft_config set, TRL uses the model with adapters disabled as the reference
    args=dpo_args,
    train_dataset=dpo_dataset,
    tokenizer=tokenizer,  # tokenizer from the SFT step (processing_class in newer TRL)
    peft_config=dpo_lora,
)

dpo_trainer.train()
```

RewardTrainer - Training a Reward Model

When you do need a full RLHF pipeline, RewardTrainer handles the reward model training step. The reward model learns to assign a scalar score to (prompt, response) pairs based on human preference labels:

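```python
import torch
from transformers import AutoModelForSequenceClassification
from trl import RewardTrainer, RewardConfig

# A reward model is a sequence classifier with a single output: the scalar reward
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    num_labels=1,
    torch_dtype=torch.bfloat16,
)
# Llama has no pad token; the classification head needs one to locate the last real token
reward_model.config.pad_token_id = tokenizer.pad_token_id

# The reward model scores full (prompt + response) texts, so fold the prompt into each side.
# TRL ~0.9 expects these four pre-tokenized columns; newer releases can tokenize raw text.
def tokenize_pair(example):
    chosen = tokenizer(example["prompt"] + example["chosen"], truncation=True, max_length=512)
    rejected = tokenizer(example["prompt"] + example["rejected"], truncation=True, max_length=512)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }

# Same (prompt, chosen, rejected) preference data as DPO
reward_dataset = dpo_dataset.map(tokenize_pair, remove_columns=dpo_dataset.column_names)

reward_config = RewardConfig(
    output_dir="./reward-model",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    bf16=True,
    max_length=512,
)

reward_trainer = RewardTrainer(
    model=reward_model,
    args=reward_config,
    train_dataset=reward_dataset,
    tokenizer=tokenizer,
)

reward_trainer.train()
```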
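Under the hood, RewardTrainer optimizes a pairwise (Bradley-Terry style) objective: the model should score the chosen text above the rejected one, with loss $-\log \sigma(r_{chosen} - r_{rejected})$ averaged over the batch (TRL also supports an optional per-pair margin, omitted here). The sketch below computes that loss and the pairwise accuracy metric from plain lists of scalar rewards; the helper names are hypothetical.

```python
import math
from typing import List


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))


def pairwise_reward_loss(chosen_rewards: List[float], rejected_rewards: List[float]) -> float:
    """Mean of -log(sigmoid(r_chosen - r_rejected)) over a batch of preference pairs."""
    if not chosen_rewards or len(chosen_rewards) != len(rejected_rewards):
        raise ValueError("need equal-length, non-empty reward lists")
    losses = [softplus(-(c - r)) for c, r in zip(chosen_rewards, rejected_rewards)]
    return sum(losses) / len(losses)


def pairwise_accuracy(chosen_rewards: List[float], rejected_rewards: List[float]) -> float:
    """Fraction of pairs where the chosen response receives the higher reward."""
    wins = sum(1 for c, r in zip(chosen_rewards, rejected_rewards) if c > r)
    return wins / len(chosen_rewards)


print(pairwise_reward_loss([2.0, 0.0], [1.0, 1.0]), pairwise_accuracy([2.0, 0.0], [1.0, 1.0]))
```

Pairwise accuracy on a held-out preference split is usually the first metric to check: a reward model near 50% has learned little beyond chance.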

Axolotl - Config-Driven Production Fine-Tuning

The Core Philosophy

Axolotl replaces Python configuration with YAML. Every decision that matters for a fine-tuning run - base model, quantization, dataset format, LoRA parameters, optimizer, scheduler, distributed training backend - lives in a single YAML file. This has a profound effect on how teams work:

  • Training runs are reproducible by sharing a single file
  • Hyperparameter experiments are documented by default (git diff the YAML)
  • No Python expertise is required to run a training job
  • Several common mistakes (incorrect padding, computing loss on prompt tokens) are handled by the framework, although you still have to choose the right chat template

The tradeoff is flexibility. If your training loop needs custom loss functions, custom data augmentation, or architectural modifications not supported by Axolotl's abstractions, you will eventually fight the framework. TRL with a custom training loop is the better choice in those cases. For standard instruction fine-tuning, DPO, and most QLoRA workflows, Axolotl is significantly faster to operate.

Complete Axolotl Config - LLaMA 3 8B Instruction Fine-Tuning

```yaml
# axolotl_llama3_8b.yaml
# QLoRA fine-tuning of LLaMA 3 8B on instruction data

# --- Base Model ---
base_model: meta-llama/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

# --- Quantization (QLoRA) ---
load_in_4bit: true
bnb_config_kwargs:               # NF4 + double quantization (also Axolotl's defaults)
  bnb_4bit_quant_type: nf4
  bnb_4bit_use_double_quant: true

# --- LoRA Configuration ---
adapter: qlora                   # LoRA on a 4-bit base; use `lora` when not loading in 4-bit
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - up_proj
  - down_proj
lora_modules_to_save:            # trained in full (not via LoRA) and saved with the adapter;
  - embed_tokens                 # useful when template tokens such as <|eot_id|> are poorly
  - lm_head                      # trained in the base model, but adds ~1B trainable params

# --- Dataset Configuration ---
datasets:
  - path: your-org/your-dataset
    type: chat_template          # use the model's chat template
    split: train
    field_messages: messages     # column containing conversation messages

# Alternatively for custom datasets:
# datasets:
#   - path: /local/path/to/data.jsonl
#     type: alpaca_chat.load_no_prompt   # many format types supported

dataset_prepared_path: /tmp/axolotl-prepared   # cache tokenized dataset

val_set_size: 0.02               # 2% validation split

# --- Sequence Handling ---
sequence_len: 4096               # max sequence length
sample_packing: true             # pack multiple samples per sequence
pad_to_sequence_len: true

# --- Training Arguments ---
output_dir: ./outputs/llama3-8b-instruct
num_epochs: 3
micro_batch_size: 2              # per GPU
gradient_accumulation_steps: 4   # effective batch = 2 * 4 * num_gpus
learning_rate: 0.0002
lr_scheduler: cosine
warmup_ratio: 0.05
weight_decay: 0.01
optimizer: paged_adamw_8bit
max_grad_norm: 1.0

# --- Mixed Precision and Performance ---
bf16: auto                       # use bf16 if hardware supports it
tf32: false
flash_attention: true            # Flash Attention 2 - requires flash-attn

# --- Gradient Checkpointing ---
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false           # newer PyTorch implementation, avoids some bugs

# --- Logging and Saving ---
logging_steps: 10
eval_steps: 100
save_steps: 100
save_total_limit: 3
eval_table_size: 5               # number of generation samples logged per eval

wandb_project: llama3-finetuning
wandb_name: llama3-8b-instruct-qlora-r16

# --- Chat Template ---
chat_template: llama3            # applies the correct special tokens
default_system_prompt: "You are a helpful, accurate, and honest assistant."

# --- Special Tokens ---
special_tokens:
  pad_token: "<|end_of_text|>"   # Llama 3 has no pad token by default
tokens:                          # ensure these are in tokenizer vocab
  - "<|begin_of_text|>"
  - "<|end_of_text|>"
  - "<|eot_id|>"
```

Running Axolotl

```bash
# Single GPU training
python -m axolotl.cli.train axolotl_llama3_8b.yaml

# Multi-GPU with accelerate (FSDP or DeepSpeed - see below)
accelerate launch -m axolotl.cli.train axolotl_llama3_8b.yaml

# Preprocess dataset only (useful to check formatting before training)
python -m axolotl.cli.preprocess axolotl_llama3_8b.yaml

# Inference after training
python -m axolotl.cli.inference axolotl_llama3_8b.yaml \
  --lora_model_dir ./outputs/llama3-8b-instruct
```

Dataset Format Types in Axolotl

Axolotl supports many dataset format types through the type field. Understanding which to use is critical - the wrong format silently trains on incorrectly formatted data:

```yaml
# These are alternatives: a YAML file can hold only one `datasets:` key,
# so pick one (or combine several entries in a single list)

# For ShareGPT-style data (list of {from, value} or {role, content} dicts)
datasets:
  - path: your-org/sharegpt-data
    type: sharegpt
    conversation: llama3          # which conversation template to apply

# For Alpaca-style data (instruction, input, output fields)
datasets:
  - path: your-org/alpaca-data
    type: alpaca

# For raw completion (no instruction masking - trains on full sequence)
datasets:
  - path: your-org/raw-text
    type: completion
    field: text

# For custom JSONL with explicit system/user/assistant fields
datasets:
  - path: /path/to/custom.jsonl
    type: chat_template
    field_messages: messages
    message_field_role: role
    message_field_content: content
```
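Because the wrong format silently trains on incorrectly formatted data, it can help to normalize ShareGPT-style {from, value} records into the {role, content} messages layout that the chat_template type reads (via field_messages: messages) before training, and to fail loudly on anything unexpected. This is a minimal sketch; the speaker names human/gpt/system are the conventional ShareGPT labels, but they are an assumption to check against your own data, and `sharegpt_to_messages` is a hypothetical helper.

```python
from typing import Dict, List

# Conventional ShareGPT speaker names mapped to chat-template roles (verify against your data)
SHAREGPT_ROLE_MAP = {"system": "system", "human": "user", "gpt": "assistant"}


def sharegpt_to_messages(conversations: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Convert [{"from": ..., "value": ...}] turns into [{"role": ..., "content": ...}] messages."""
    messages = []
    for turn in conversations:
        speaker = turn["from"]
        if speaker not in SHAREGPT_ROLE_MAP:
            # Raise instead of guessing: an unmapped role would otherwise train silently
            raise ValueError(f"Unknown ShareGPT speaker: {speaker!r}")
        messages.append({"role": SHAREGPT_ROLE_MAP[speaker], "content": turn["value"]})
    return messages


convo = [{"from": "human", "value": "Hi"}, {"from": "gpt", "value": "Hello!"}]
print(sharegpt_to_messages(convo))
```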


Multi-GPU Training - FSDP vs DeepSpeed ZeRO-3

Single-GPU fine-tuning with QLoRA can comfortably train a 7B-13B model on a single A100 80GB. For larger models (70B+), where even 4-bit weights leave little headroom on one GPU, or when you want to speed up training with multiple GPUs, you need a distributed training strategy. Two major options exist: PyTorch FSDP (Fully Sharded Data Parallel) and DeepSpeed ZeRO-3.

Memory Math - Why You Need Sharding

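A 70B model in bf16 takes approximately 140 GB just for parameters. Full fine-tuning adds gradients (another 140 GB in bf16), Adam's two moment estimates (2x parameter memory: 280 GB if kept in bf16, 560 GB in the more common fp32), and activations. That is roughly 560 GB before activations with everything in bf16, and over 1 TB with standard mixed precision (fp32 optimizer states plus an fp32 master copy of the weights). Even with QLoRA (4-bit weights, bf16 LoRA adapters), you need to load about 35 GB of quantized weights plus the LoRA parameters, their gradients, and their optimizer states.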

The key formula for understanding memory requirements during LoRA fine-tuning:

$$M_{total} = M_{weights} + M_{LoRA} + M_{gradients} + M_{optimizer} + M_{activations}$$

Where:

  • $M_{weights}$ = frozen 4-bit weights (not in optimizer, but loaded on GPU)
  • $M_{LoRA}$ = bf16 LoRA parameters (trainable)
  • $M_{gradients}$ = gradients for LoRA parameters only
  • $M_{optimizer}$ = Adam states for LoRA parameters (2x LoRA parameter count)
  • $M_{activations}$ = intermediate activations (reduced by gradient checkpointing)
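The formula translates directly into a back-of-the-envelope calculator. This sketch rests on stated assumptions: 0.5 bytes per frozen 4-bit weight (ignoring the small overhead of quantization constants), bf16 LoRA parameters and gradients (2 bytes each), and Adam moments in fp32 (8 bytes per LoRA parameter; pass 2.0 for an 8-bit optimizer such as paged_adamw_8bit). Activations depend on batch size, sequence length, and checkpointing, so they are left as an input. `qlora_memory_gb` is a hypothetical helper.

```python
from typing import Dict


def qlora_memory_gb(n_params: float, n_lora_params: float,
                    optimizer_bytes_per_param: float = 8.0,
                    activations_gb: float = 0.0) -> Dict[str, float]:
    """Rough GPU memory budget for QLoRA in decimal GB (1e9 bytes)."""
    gb = 1e9
    parts = {
        "weights": n_params * 0.5 / gb,                                  # 4-bit frozen base weights
        "lora": n_lora_params * 2 / gb,                                  # bf16 trainable adapters
        "gradients": n_lora_params * 2 / gb,                             # bf16 grads, LoRA params only
        "optimizer": n_lora_params * optimizer_bytes_per_param / gb,     # Adam first and second moments
        "activations": activations_gb,                                   # measure or estimate separately
    }
    parts["total"] = sum(parts.values())
    return parts


print(qlora_memory_gb(70e9, n_lora_params=0))
```

For example, the 4-bit weights of a 70B model come to 35 GB, matching the figure above, before any LoRA or activation memory is added.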

FSDP Configuration for Axolotl

FSDP shards model parameters, gradients, and optimizer states across GPUs. Each GPU holds a shard of every tensor, and all-gather operations reconstruct full tensors during forward/backward passes:

```yaml
# Add to axolotl config for FSDP
fsdp:
  - full_shard   # shard params, grads, and optimizer states
  - auto_wrap    # automatically wrap transformer layers

fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_offload_params: false   # set true to offload to CPU (slower but less VRAM)
  fsdp_forward_prefetch: true
  fsdp_use_orig_params: true   # lets frozen base and trainable LoRA params share an FSDP unit
```

```bash
# Launch FSDP job on 4 GPUs
accelerate launch --config_file fsdp_config.yaml \
  -m axolotl.cli.train axolotl_llama3_70b_fsdp.yaml
```

The accelerate config file for FSDP:

```yaml
# fsdp_config.yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1   # 1=FULL_SHARD, 2=SHARD_GRAD_OP, 3=NO_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4   # number of GPUs
```
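The sharding idea behind FULL_SHARD can be illustrated without any GPUs: flatten a unit's parameters, pad to a multiple of the number of ranks, give each rank one equal slice, and all-gather the slices back into the full tensor only when a forward or backward pass needs it. This is a conceptual sketch with hypothetical helper names, not PyTorch's implementation.

```python
from typing import List


def shard_flat_param(flat: List[float], world_size: int) -> List[List[float]]:
    """Pad a flattened parameter to a multiple of world_size and split it into equal shards."""
    pad = (-len(flat)) % world_size
    padded = flat + [0.0] * pad
    shard_len = len(padded) // world_size
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and strip the padding to rebuild the full parameter."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


flat = [float(i) for i in range(10)]
shards = shard_flat_param(flat, world_size=4)
print([len(s) for s in shards])
```

Between uses, each GPU keeps only its own slice, which is where the roughly 1/N memory saving for parameters, gradients, and optimizer states comes from; the price is the all-gather communication on every forward and backward pass.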

DeepSpeed ZeRO-3 Configuration

DeepSpeed ZeRO (Zero Redundancy Optimizer) Stage 3 is similar to FSDP in that it shards parameters, gradients, and optimizer states. Its offloading support is more mature: optimizer states and parameters can be moved to CPU memory (or NVMe), enabling training of very large models on limited GPU memory. The configuration below offloads both:

```yaml
# Add to axolotl config for DeepSpeed
deepspeed: deepspeed_zero3.json
```

Contents of deepspeed_zero3.json. The "auto" values let the HuggingFace Trainer integration fill in micro-batch size, gradient accumulation, clipping, and precision from the Axolotl config; hard-coded values that disagree with the training arguments cause a mismatch error at startup.

```json
{
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    }
  },
  "bf16": {
    "enabled": "auto"
  },
  "gradient_clipping": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "steps_per_print": 10,
  "wall_clock_breakdown": false
}
```
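To see why each ZeRO stage matters, it helps to use the accounting from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for bf16/fp16 weights, 2 for gradients, and K = 12 for the fp32 master weights, momentum, and variance. The sketch below covers model states only (no activations) and treats stage 0 as plain data parallelism; `zero_model_state_gb` is a hypothetical helper.

```python
def zero_model_state_gb(n_params: float, n_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU memory (decimal GB) for weights + gradients + optimizer states under ZeRO."""
    psi = n_params
    if stage == 0:      # plain data parallel: every GPU holds everything
        total_bytes = (2 + 2 + k) * psi
    elif stage == 1:    # shard optimizer states
        total_bytes = 2 * psi + 2 * psi + k * psi / n_gpus
    elif stage == 2:    # shard optimizer states and gradients
        total_bytes = 2 * psi + (2 + k) * psi / n_gpus
    elif stage == 3:    # shard optimizer states, gradients, and parameters
        total_bytes = (2 + 2 + k) * psi / n_gpus
    else:
        raise ValueError("stage must be 0, 1, 2 or 3")
    return total_bytes / 1e9


for stage in range(4):
    print(stage, zero_model_state_gb(70e9, n_gpus=8, stage=stage))
```

For a full fine-tune of a 70B model on 8 GPUs, even stage 3 leaves about 140 GB of model states per GPU, which is why the configuration above offloads to CPU. With LoRA, the gradient and optimizer terms apply only to the small adapter, so the picture is far less extreme.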

FSDP vs DeepSpeed - When to Use Each

Key practical differences:

  • FSDP is native PyTorch - no additional installation, better debugging support, integrates cleanly with PyTorch profiler. Preferred for most single-node multi-GPU setups.
  • DeepSpeed is more mature for multi-node training, has better CPU/NVMe offloading support, and handles very large models more efficiently. Preferred when you need CPU offload or have more than 8 GPUs.
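  • LoRA + FSDP generally needs fsdp_use_orig_params: true, which lets frozen base weights and trainable LoRA weights live in the same FSDP unit. Missing it is a common gotcha that causes cryptic errors (the alternative is a wrap policy that puts LoRA layers in their own FSDP units).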
  • LoRA + DeepSpeed ZeRO-3 has historically had compatibility issues with the PEFT library. Check the PEFT and TRL release notes for the specific version combinations known to work.

Flash Attention 2 - The Performance Multiplier

Flash Attention 2 (Dao, 2023) is an IO-aware, exact attention implementation that avoids materializing the full attention matrix in GPU memory. Standard attention computes an $n \times n$ attention matrix, where $n$ is sequence length, requiring $O(n^2)$ memory. Flash Attention splits the computation into tiles that fit in the GPU's fast on-chip SRAM and combines them with a running softmax, reducing the extra memory to $O(n)$ and substantially improving throughput for long sequences. The FLOP count is still $O(n^2)$; the speedup comes from far fewer reads and writes to GPU high-bandwidth memory.

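Practical impact: on A100 GPUs at sequence length 2048, Flash Attention 2 typically makes the attention computation itself 2-3x faster than standard attention. End-to-end training speedups are smaller, because attention is only part of each training step, and they grow with sequence length.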
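The trick that makes tiling possible is the online softmax: process keys and values block by block, keep a running maximum and running denominator, and rescale the partial output whenever the maximum changes. The pure-Python sketch below shows the math for a single query vector; real kernels do this per tile of queries inside GPU SRAM, and the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference: materialize every score for this query, softmax them, weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


def tiled_attention(q: Vector, keys: List[Vector], values: List[Vector], block_size: int = 2) -> Vector:
    """Same result, holding only one block of scores at a time (online softmax)."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [dot(q, k) * scale for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0 on the first block)
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(dim):
                acc[j] += w * v[j]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Both functions return the same output up to floating-point rounding for any block size, which is the sense in which Flash Attention is exact rather than an approximation.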

Enabling it in Axolotl:

```yaml
flash_attention: true   # axolotl config
```

Enabling it in TRL/HuggingFace:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    attn_implementation="flash_attention_2",  # requires flash-attn package
    torch_dtype=torch.bfloat16,
)
```

Flash Attention 2 requires:

  • pip install flash-attn --no-build-isolation (compilation can take 10-15 minutes)
  • Ampere or newer GPU (A100, A10, H100, RTX 3090/4090)
  • bfloat16 or float16 (not float32)

Gradient Checkpointing - Trading Compute for Memory

During the forward pass, PyTorch stores all intermediate activations in memory for use during the backward pass. For a transformer with 32 layers, this is a large amount of memory - roughly proportional to batch size x sequence length x model dimension x number of layers.

Gradient checkpointing (also called activation checkpointing) discards these intermediate activations during the forward pass, then recomputes them on-demand during the backward pass. The tradeoff: approximately 30-40% more compute in exchange for drastically reduced activation memory.

For fine-tuning, where memory is the primary constraint, gradient checkpointing is almost always worth enabling:

```python
# TRL / HuggingFace
training_args = SFTConfig(
    output_dir="./outputs/llama3-8b-sft",
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # ... other training arguments
)
```

```yaml
# Axolotl
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
```

The use_reentrant: False setting uses the newer PyTorch implementation of gradient checkpointing and avoids subtle bugs related to in-place operations on tensors. For any new training run, use use_reentrant: False.
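The 30-40% compute overhead quoted above follows from a common rule of thumb: a backward pass costs about twice a forward pass, so one training step is about three forward-equivalents, and full checkpointing re-runs roughly one extra forward during backward. The sketch below only encodes that arithmetic; the 2x backward multiplier is an assumption, and the real overhead depends on which layers are checkpointed.

```python
def checkpointing_overhead(forward_cost: float = 1.0, backward_multiplier: float = 2.0,
                           recomputed_fraction: float = 1.0) -> float:
    """Fractional extra compute from recomputing `recomputed_fraction` of the forward pass."""
    baseline = forward_cost * (1 + backward_multiplier)  # forward + backward without checkpointing
    extra = forward_cost * recomputed_fraction           # forward work repeated during backward
    return extra / baseline


print(f"{checkpointing_overhead():.0%}")  # full checkpointing under the 2x-backward assumption
```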


Production Engineering Notes

Checkpoint Management

Training crashes happen. Always configure checkpointing at a granularity that balances storage cost with recovery time:

```yaml
# Axolotl
save_steps: 100        # save every 100 steps
save_total_limit: 3    # keep only 3 most recent checkpoints
```

For long runs (10K+ steps), saving every 100 steps may generate too many checkpoints. A practical heuristic: save approximately every 2-5% of total training, with a limit of 3-5 checkpoints. The save_total_limit ensures you do not fill disk.
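The heuristic is easy to turn into numbers before launching a run. This sketch assumes no sequence packing (packing changes the number of sequences, and therefore the step count) and uses a hypothetical helper; the example plugs in the 8,000-example dataset from the opening scenario with the SFT example's batch settings (batch 4, accumulation 4, one GPU, 3 epochs).

```python
import math


def suggest_save_steps(num_examples: int, per_device_batch: int, grad_accum: int,
                       num_gpus: int, num_epochs: int, save_fraction: float = 0.03) -> tuple:
    """Return (approximate total optimizer steps, save interval at ~save_fraction of the run)."""
    effective_batch = per_device_batch * grad_accum * num_gpus
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * num_epochs
    save_steps = max(1, round(total_steps * save_fraction))
    return total_steps, save_steps


print(suggest_save_steps(8000, per_device_batch=4, grad_accum=4, num_gpus=1, num_epochs=3))
```

For that setup, saving every 100 steps would already be close to the 5% upper end of the heuristic; saving every 45 steps lands at 3%.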

Monitoring Training Health

Watch for these signals during training:

Loss curve shape: Loss should decrease smoothly. Spikes indicate learning rate too high, bad batches, or numerical instability. A loss that plateaus immediately suggests the learning rate is too low or the data is too similar to the base model's training distribution.

Gradient norm: The grad_norm metric in your logs should stay below your max_grad_norm (typically 1.0) after the warmup phase. The HuggingFace Trainer logs the norm measured before clipping, so persistent clipping shows up as grad_norm at or above max_grad_norm for many consecutive steps. That pattern suggests the learning rate is too high.

GPU utilization: Should be 85-95% during training. Lower utilization often indicates data loading is the bottleneck - increase dataloader_num_workers.
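The first two signals can be checked automatically from logged metrics, for example in a callback or a post-hoc script over exported W&B history. The sketch below works on plain lists; the window, spike factor, and run length are illustrative thresholds, not established defaults, and the helper names are hypothetical.

```python
from statistics import median
from typing import List


def find_loss_spikes(losses: List[float], window: int = 20, factor: float = 1.5) -> List[int]:
    """Indices whose loss exceeds `factor` x the median of the preceding `window` values."""
    spikes = []
    for i in range(window, len(losses)):
        baseline = median(losses[i - window:i])
        if losses[i] > factor * baseline:
            spikes.append(i)
    return spikes


def persistent_clipping(grad_norms: List[float], max_grad_norm: float = 1.0, run_length: int = 50) -> bool:
    """True if the (pre-clipping) grad norm stayed at or above max_grad_norm for run_length steps in a row."""
    streak = 0
    for g in grad_norms:
        streak = streak + 1 if g >= max_grad_norm else 0
        if streak >= run_length:
            return True
    return False


print(find_loss_spikes([2.0] * 25 + [5.0] + [2.0] * 5))
```

Using a median rather than a mean keeps one earlier spike from inflating the baseline and masking the next one.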

Resuming from Checkpoint

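```yaml
# Axolotl - resume automatically from the latest checkpoint in output_dir
auto_resume_from_checkpoints: true
# ...or point at a specific checkpoint instead
# resume_from_checkpoint: ./outputs/llama3-8b-instruct/checkpoint-500
```

```python
# TRL - explicit checkpoint path (resume_from_checkpoint=True picks the latest one in output_dir)
trainer.train(resume_from_checkpoint="./outputs/checkpoint-500")
```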

Weights and Biases Integration

Both frameworks integrate with W&B for experiment tracking:

```yaml
# Axolotl
wandb_project: my-fine-tuning-project
wandb_name: llama3-8b-r16-run1
wandb_log_model: checkpoint   # upload checkpoints to W&B artifacts
```

Track at minimum: train loss, eval loss, grad norm, learning rate, GPU memory. Add custom metrics for task-specific evaluation using W&B callbacks.

The Saved Adapter vs Merged Model Decision

After training, you have two options for what to save:

Save adapter only (default): Small file (tens of MB for r=16), requires base model loaded separately. Best for iterative development.

Merge and save: Creates a full model with LoRA weights merged into base weights. Larger file, faster inference, deployable without PEFT library.

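```python
# Merge LoRA into base model and save
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model in bf16 (not 4-bit) so the merged weights are full precision
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
)
peft_model = PeftModel.from_pretrained(base_model, "./outputs/llama3-8b-finetuned")
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("./outputs/llama3-8b-merged", safe_serialization=True)
# Save the tokenizer alongside so the merged directory loads on its own
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.save_pretrained("./outputs/llama3-8b-merged")
```

```bash
# Axolotl - merge a trained adapter into the base model after training completes
python -m axolotl.cli.merge_lora axolotl_llama3_8b.yaml \
  --lora_model_dir ./outputs/llama3-8b-instruct
```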
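The "tens of MB" adapter size follows from simple arithmetic: each adapted Linear(in, out) layer gains an $A$ matrix ($r \times in$) and a $B$ matrix ($out \times r$), so $r(in + out)$ parameters. The sketch below counts them for Llama 3 8B, assuming the shapes from its public model config (hidden size 4096, intermediate size 14336, 8 key/value heads of dimension 128, 32 layers); the names are hypothetical.

```python
from typing import Dict, Tuple

# (in_features, out_features) per projection in one Llama 3 8B decoder layer
LLAMA3_8B_PROJ_SHAPES: Dict[str, Tuple[int, int]] = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),   # 8 KV heads x 128 head_dim
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}


def lora_param_count(shapes: Dict[str, Tuple[int, int]], r: int, num_layers: int) -> int:
    """LoRA adds A (r x in) and B (out x r) to each targeted Linear: r * (in + out) params."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_layer * num_layers


n = lora_param_count(LLAMA3_8B_PROJ_SHAPES, r=16, num_layers=32)
print(n, f"{n * 2 / 1e6:.0f} MB in bf16")
```

The r=16, all-projection adapter from the SFT example comes to about 42M parameters, roughly 84 MB in bf16 (twice that if the adapter weights are stored in fp32). If you also keep lora_modules_to_save: [embed_tokens, lm_head] as in the Axolotl config, the saved adapter additionally contains two full 128,256 x 4,096 matrices (about 1.05B parameters), pushing it to roughly 2 GB in bf16.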

Common Mistakes

:::danger Wrong Chat Template
The single most common cause of poor fine-tuning results is applying the wrong chat template. LLaMA 3 uses a different template than LLaMA 2, which uses a different template than Mistral, which uses a different template than Phi-3. If your Axolotl config specifies chat_template: llama3 but your base model is Mistral, every example in your training data is being tokenized incorrectly. The model still trains (loss goes down), but the trained model will produce garbled outputs at inference time because it was trained to predict tokens in the wrong positions.

Always verify the applied template by running python -m axolotl.cli.preprocess your_config.yaml and inspecting the first decoded example in the prepared dataset.
:::

:::danger Gradient Checkpointing + LoRA FSDP Without use_reentrant: False
If you are using LoRA with FSDP and gradient checkpointing, omitting gradient_checkpointing_kwargs: {use_reentrant: false} will cause in-place modification errors during the backward pass on PyTorch >= 2.0. These errors may be intermittent and only appear after several hundred training steps, wasting significant compute. Always set use_reentrant: false explicitly.
:::

:::warning Forgetting to Set padding_side = "right" for Causal LMs
Some HuggingFace tokenizers default to padding_side = "left", which is what you want for batched generation but not for training causal language models (GPT-style, LLaMA). During a training forward pass, many HuggingFace models assign position ids 0, 1, 2, ... regardless of padding, so with left padding the real tokens of shorter sequences start at shifted positions, which does not match how the model sees prompts at inference time. Set tokenizer.padding_side = "right" before any tokenization in your training script. Axolotl handles this automatically, but TRL does not always enforce it.
:::

:::warning Packing Invalidates Per-Example Loss Metrics
When you enable sequence packing (sample_packing: true in Axolotl, packing=True in SFTTrainer), multiple examples are concatenated into a single sequence. Each optimization step then covers a varying number of examples, and if attention is not masked at example boundaries, later examples in a pack can attend to unrelated earlier ones. As a result, you cannot directly compare loss curves between packed and unpacked runs. Compare runs on an eval loss computed the same way for both (for example, on unpacked validation sequences) and use that as your primary quality metric.
:::

:::warning DeepSpeed ZeRO-3 + PEFT Version Pinning
Specific combinations of deepspeed, peft, and transformers versions have known incompatibilities that cause silent training failures or incorrect gradient accumulation. Before starting a production run with DeepSpeed ZeRO-3 and LoRA, check the TRL and PEFT GitHub issues for the version combination you are using, and pin exact versions once you have a working set. As of mid-2024, requiring at least peft 0.11.0, trl 0.9.0, and deepspeed 0.14.0 resolved most known issues.
:::


TRL vs Axolotl Decision Framework


Practical: End-to-End TRL Fine-Tuning with Evaluation Loop

Here is a complete TRL-based fine-tuning script with proper dataset formatting, completion-only loss masking, and built-in evaluation:

"""
Complete TRL SFTTrainer example for LLaMA 3 8B instruction fine-tuning.
Run: python train_sft.py
Requirements: transformers>=4.40, trl>=0.9, peft>=0.11, bitsandbytes>=0.43
"""

import torch
import wandb
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM


def load_model_and_tokenizer(model_name: str):
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
# Disable KV cache during training (incompatible with gradient checkpointing)
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
return model, tokenizer


def format_dataset(dataset, tokenizer):
"""Apply chat template to a dataset with 'messages' column."""
def apply_template(example):
text = tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
add_generation_prompt=False,
)
return {"text": text}

return dataset.map(apply_template, remove_columns=dataset.column_names)


def main():
MODEL_NAME = "meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR = "./outputs/llama3-8b-sft"

wandb.init(
project="llama3-fine-tuning",
name="llama3-8b-qlora-r16",
config={
"model": MODEL_NAME,
"lora_r": 16,
"lora_alpha": 32,
"learning_rate": 2e-4,
}
)

model, tokenizer = load_model_and_tokenizer(MODEL_NAME)

# Dataset loading - replace with your dataset
raw_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:10000]")
eval_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft[:500]")

train_dataset = format_dataset(raw_dataset, tokenizer)
eval_dataset = format_dataset(eval_dataset, tokenizer)

# Completion-only collator - masks instruction tokens from loss
# For Llama 3, responses start after the assistant header token
response_template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
collator = DataCollatorForCompletionOnlyLM(
response_template=response_template,
tokenizer=tokenizer,
)

lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
lora_dropout=0.05,
bias="none",
task_type=TaskType.CAUSAL_LM,
)

training_args = SFTConfig(
output_dir=OUTPUT_DIR,
num_train_epochs=3,
per_device_train_batch_size=2,
per_device_eval_batch_size=2,
gradient_accumulation_steps=8, # effective batch = 16
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
optim="paged_adamw_8bit",
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.05,
weight_decay=0.01,
max_grad_norm=1.0,
bf16=True,
max_seq_length=2048,
packing=False, # disabled when using completion collator
dataset_text_field="text",
logging_steps=10,
eval_strategy="steps",
eval_steps=200,
save_strategy="steps",
save_steps=200,
save_total_limit=3,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="wandb",
dataloader_num_workers=4,
dataloader_pin_memory=True,
)

trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=lora_config,
tokenizer=tokenizer,
data_collator=collator,
)

trainer.train()
trainer.save_model(OUTPUT_DIR + "/final")
tokenizer.save_pretrained(OUTPUT_DIR + "/final")
wandb.finish()


if __name__ == "__main__":
main()
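A wrong chat template is cheap to catch before it costs a training run. The sketch below is a string-level pre-flight check, assuming the formatted datasets expose a "text" column as in the script above; `examples_missing_template` is a hypothetical helper. It cannot detect a template that appears in the text but gets cut off by truncation at max_seq_length, so the collator's own runtime warnings are still worth watching.

```python
from typing import Iterable, List


def examples_missing_template(texts: Iterable[str], template: str) -> List[int]:
    """Indices of formatted examples that never contain the response template."""
    return [i for i, text in enumerate(texts) if template not in text]


template = "<|start_header_id|>assistant<|end_header_id|>\n\n"
texts = [
    "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>",
    # A ChatML-formatted example, as a mismatched template would produce:
    "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>",
]
missing = examples_missing_template(texts, template)  # [1]
```

In the training script, calling `examples_missing_template(train_dataset["text"], response_template)` right after `format_dataset` and aborting on a non-empty result is one way to use it.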

Interview Questions and Answers

Q1: What is the difference between TRL's SFTTrainer and the base HuggingFace Trainer for LLM fine-tuning? When would you choose one over the other?

The base HuggingFace Trainer is a general-purpose training loop designed for any model and task. It has no concept of chat templates, does not handle instruction masking from loss computation, and has no built-in support for sequence packing. Using it for LLM fine-tuning requires manually implementing all of these concerns, which is error-prone.

SFTTrainer is purpose-built for supervised fine-tuning of language models. It handles chat template application, integrates with PEFT for LoRA training, supports sequence packing out of the box, and when combined with DataCollatorForCompletionOnlyLM, automatically masks instruction tokens from the loss. For standard instruction fine-tuning, SFTTrainer is the right choice. The base Trainer is appropriate when you need unusual training objectives (custom loss functions, multi-task learning with multiple losses) that do not fit the standard SFT paradigm.
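To make "masks instruction tokens from the loss" concrete, here is a pure-Python sketch of completion-only labels on integer token IDs. It mirrors the idea behind DataCollatorForCompletionOnlyLM used with both a response and an instruction template, but `completion_only_labels` is a hypothetical helper and the toy IDs are made up; it is not TRL's code.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # labels with this value are skipped by the loss


def find_subsequence(seq: Sequence[int], sub: Sequence[int], start: int = 0) -> int:
    """Index of the first occurrence of `sub` in `seq` at or after `start`, or -1."""
    n = len(sub)
    for i in range(start, len(seq) - n + 1):
        if list(seq[i:i + n]) == list(sub):
            return i
    return -1


def completion_only_labels(
    input_ids: List[int], response_ids: List[int], instruction_ids: List[int]
) -> List[int]:
    """Keep labels only for tokens inside assistant responses."""
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        r = find_subsequence(input_ids, response_ids, pos)
        if r == -1:
            break
        start = r + len(response_ids)  # first token after the response header
        nxt = find_subsequence(input_ids, instruction_ids, start)
        end = len(input_ids) if nxt == -1 else nxt  # response runs to the next user turn
        labels[start:end] = input_ids[start:end]
        pos = end
    return labels


# Toy IDs: [10, 11] = user header, [20, 21] = assistant header.
ids = [1, 10, 11, 5, 6, 20, 21, 7, 8, 10, 11, 9, 20, 21, 3, 2]
labels = completion_only_labels(ids, response_ids=[20, 21], instruction_ids=[10, 11])
```

Only the two assistant replies (7, 8 and 3, 2) keep their labels; everything else becomes -100. This is the bookkeeping you would have to write yourself on top of the base Trainer.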

Q2: Explain sequence packing. Why does it improve training efficiency, and what is the main caveat when using it?

Sequence packing concatenates multiple short training examples into a single sequence up to the model's maximum context length, separated by EOS tokens. Without packing, if a batch containing examples of 200, 150, and 300 tokens is padded to a 2048-token context window, only 650 of the 6,144 token slots hold real data, so about 89% of the compute is spent on padding tokens. With packing, all three examples fit into a single 2048-token sequence using roughly 650 tokens, and the remaining space is filled with further examples instead of padding.
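The padding arithmetic above can be checked directly. This short sketch also shows the dynamic-padding case (pad only to the longest example in the batch), which is what many collators do; `padding_waste` is a hypothetical helper.

```python
from typing import List


def padding_waste(lengths: List[int], padded_len: int) -> float:
    """Fraction of token slots in a padded batch that hold padding."""
    real_tokens = sum(lengths)
    total_slots = padded_len * len(lengths)
    return 1 - real_tokens / total_slots


lengths = [200, 150, 300]
to_context = padding_waste(lengths, 2048)          # pad every example to 2048
to_longest = padding_waste(lengths, max(lengths))  # dynamic padding to 300
print(f"{to_context:.1%} {to_longest:.1%}")        # 89.4% 27.8%
```

Even with dynamic padding, over a quarter of this batch is padding, which is the waste packing removes.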

The main caveat is that, in TRL, packing cannot be combined with DataCollatorForCompletionOnlyLM, because the collator relies on finding instruction/response boundaries within individual examples, and packed sequences blur those boundaries. With packing enabled in SFTTrainer, the model therefore trains on all tokens (including instruction tokens), which can slightly reduce fine-tuning quality. For datasets of very short examples (average length under 512 tokens), the throughput gain from packing typically outweighs this quality tradeoff. For longer examples, disable packing and use the completion collator instead.

Q3: Your 70B LLaMA fine-tuning job with FSDP FULL_SHARD, LoRA, and gradient checkpointing crashes partway through training with an error raised during the backward pass. What is the likely fix?

This class of error occurs when gradient checkpointing runs with the legacy reentrant implementation. Set gradient_checkpointing_kwargs: {use_reentrant: false} in your Axolotl config or SFTConfig. The reentrant implementation is built on a custom torch.autograd.Function that re-runs the forward pass inside the backward pass, which conflicts with FSDP's parameter sharding: FSDP needs all-gather operations to happen in the correct order during the backward pass, and reentrant checkpointing can violate this ordering. The non-reentrant implementation uses autograd's saved-tensor hooks instead, is fully compatible with FSDP, and is the variant PyTorch recommends. It is not yet the default, though: PyTorch has warned since 2.1 when use_reentrant is not passed explicitly, and transformers enables the reentrant variant unless you override it, so always set it explicitly.

Q4: When should you use DeepSpeed ZeRO-3 instead of FSDP for multi-GPU fine-tuning?

The main reasons to choose DeepSpeed ZeRO-3 over FSDP:

First, CPU offloading. DeepSpeed can offload optimizer states and parameters to CPU RAM, enabling training of models whose optimizer states do not fit in GPU VRAM even when sharded. This is critical for full fine-tuning of 70B models on 8xA100 80GB clusters.

Second, multi-node training. DeepSpeed has more mature support for multi-node distributed training with heterogeneous network bandwidth. FSDP works for multi-node but requires more careful tuning of communication parameters.

Third, NVMe offloading. DeepSpeed's ZeRO-Infinity supports offloading to NVMe storage, enabling training of models that do not fit in CPU RAM.

FSDP is preferred when: you want pure PyTorch without additional dependencies, you need fine-grained control over sharding strategy (FSDP supports partial sharding), or you are training on a single node with 4-8 GPUs where the simpler stack reduces debugging surface.
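The CPU-offloading point is easiest to see with rough memory arithmetic. The sketch below uses the usual mixed-precision Adam accounting from the ZeRO paper (with bf16 in place of fp16): 16 bytes per parameter for weights, gradients, and optimizer state. It deliberately excludes activations, temporary buffers, and fragmentation, so real usage is higher; `sharded_state_gb_per_gpu` is a hypothetical helper.

```python
# Bytes per parameter for mixed-precision Adam training (ZeRO-paper-style accounting).
BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam first moment": 4,
    "fp32 Adam second moment": 4,
}


def sharded_state_gb_per_gpu(num_params: float, num_gpus: int) -> float:
    """Weights + gradients + optimizer state per GPU when everything is fully sharded."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / num_gpus / 1e9


per_gpu = sharded_state_gb_per_gpu(70e9, num_gpus=8)
print(f"{per_gpu:.0f} GB per GPU before activations")  # 140 GB per GPU before activations
```

At roughly 140 GB per GPU before any activations, full fine-tuning of a 70B model does not fit on 80 GB cards even with complete sharding, which is why offloading optimizer state and parameters to CPU becomes necessary.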

Q5: You are running an Axolotl fine-tuning job with a custom dataset in JSONL format. After training, the model produces outputs with correct content but completely wrong format - it always outputs in a JSON structure instead of plain text. What is likely wrong, and how do you diagnose it?

This is almost certainly a dataset format type mismatch or an incorrect chat template. The model learned from training data where the responses were formatted as JSON (either because the dataset legitimately contains JSON responses, or because the format type caused incorrect parsing that wrapped responses in JSON-like structure).

To diagnose: run python -m axolotl.cli.preprocess your_config.yaml and inspect the first 5-10 decoded examples from dataset_prepared_path. Decode the token IDs back to strings and examine what the model was actually trained to predict (the tokens with labels != -100). If those decoded strings show JSON formatting that your training data should not have, the format type is wrong.
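The "decode only what is actually supervised" step can be sketched as follows, assuming each prepared row exposes `input_ids` and `labels` lists (with -100 marking ignored positions) and that you pass a decode function such as `tokenizer.decode`. `trained_segments` is a hypothetical helper, and the toy vocabulary below stands in for a real tokenizer.

```python
from typing import Callable, Dict, List

IGNORE_INDEX = -100


def trained_segments(
    input_ids: List[int],
    labels: List[int],
    decode: Callable[[List[int]], str],
    ignore_index: int = IGNORE_INDEX,
) -> List[str]:
    """Decode each contiguous run of tokens that contributes to the loss."""
    segments: List[str] = []
    current: List[int] = []
    for tok, lab in zip(input_ids, labels):
        if lab != ignore_index:
            current.append(tok)
        elif current:
            segments.append(decode(current))
            current = []
    if current:
        segments.append(decode(current))
    return segments


VOCAB: Dict[int, str] = {1: "<s>", 2: "User: hi ", 3: "Bot: ", 4: "hello", 5: "</s>"}


def toy_decode(ids: List[int]) -> str:
    return "".join(VOCAB[i] for i in ids)


segs = trained_segments([1, 2, 3, 4, 5], [-100, -100, -100, 4, 5], toy_decode)
# segs == ["hello</s>"]
```

If JSON braces or field names such as "messages" show up in these segments for your real rows, the format type is parsing the data incorrectly.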

Common specific cause: using type: completion instead of type: chat_template for a dataset with message-style data. With type: completion, Axolotl trains on the raw text including all JSON structure in the messages field. Switch to type: chat_template with the correct field_messages and verify the decoded training examples look like what you actually want the model to output.

Q6: Explain the relationship between gradient accumulation, micro batch size, and effective batch size in distributed training. How do you configure these correctly in Axolotl?

Effective batch size is the number of training examples the model sees before a single optimizer step. In distributed training:

$$\text{effective\_batch} = \text{micro\_batch\_size} \times \text{gradient\_accumulation\_steps} \times \text{num\_gpus}$$

micro_batch_size is how many examples fit on a single GPU per forward/backward pass - limited by VRAM. gradient_accumulation_steps controls how many micro-batches to accumulate gradients over before updating weights - it simulates a larger batch size in memory-constrained settings. num_gpus multiplies the effective batch because each GPU processes its own micro-batch in parallel.

In Axolotl, micro_batch_size maps to per_device_train_batch_size in the underlying TrainingArguments. A typical configuration for 70B QLoRA on 4xA100 might use micro_batch_size: 1 and gradient_accumulation_steps: 16, giving an effective batch of 1 * 16 * 4 = 64. Effective batch sizes of 32-128 are a common starting point for most instruction fine-tuning tasks. Larger batches (256+) can be used, but changing the effective batch size usually means retuning the learning rate to keep training stable.
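The formula is simple, but working backwards from a target effective batch is where configs usually go wrong. A minimal sketch with two hypothetical helpers:

```python
def effective_batch_size(micro_batch_size: int, gradient_accumulation_steps: int, num_gpus: int) -> int:
    """Examples seen per optimizer step across all GPUs."""
    return micro_batch_size * gradient_accumulation_steps * num_gpus


def accumulation_steps_for(target_batch: int, micro_batch_size: int, num_gpus: int) -> int:
    """Gradient accumulation steps needed to reach a target effective batch size."""
    per_step = micro_batch_size * num_gpus
    if target_batch % per_step != 0:
        raise ValueError(f"{target_batch} is not a multiple of {per_step}")
    return target_batch // per_step


print(effective_batch_size(1, 16, 4))    # 64: the 70B QLoRA example above
print(accumulation_steps_for(64, 1, 4))  # 16
```

When you change the number of GPUs, keeping the effective batch fixed means adjusting gradient_accumulation_steps, not micro_batch_size.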

Q7: What is the lora_modules_to_save option in Axolotl, and when should you use it?

lora_modules_to_save specifies modules that are trained in full and saved as complete weights alongside the LoRA adapters, rather than being represented as low-rank updates. Typically this is used for embed_tokens (the embedding layer) and lm_head (the output projection).

The reason to include these: if your fine-tuning data introduces new tokens (special tokens for your domain, code formatting tokens, etc.) that were not in the base model's vocabulary, the embedding matrix is resized and those tokens get new, freshly initialized rows. LoRA adapters that target only the attention and MLP projections never update those rows, so the embeddings for the new tokens must be trained and saved as full weights. The same applies to lm_head: in LLaMA-family models such as LLaMA 2 and LLaMA 3 8B/70B, lm_head is a separate matrix that is not tied to embed_tokens, so it also needs trained rows for the new tokens before the model can predict them. Omitting lora_modules_to_save when you have added new tokens leaves those tokens with untrained, randomly initialized embeddings and output weights at inference time.

© 2026 EngineersOfAI. All rights reserved.