
Quantization-Aware Training

The Model That Would Not Quantize

The team had spent three months fine-tuning a 3B-parameter model for document classification in the legal domain. They had 40,000 labeled contracts, a careful annotation pipeline, and a fine-tune that reached 94.2% accuracy on their internal benchmark - 6 points above any open model they had tried. The model was destined for an edge deployment: a small server in each law firm's on-premise infrastructure, no cloud connectivity, running on a single RTX 3090.

The 3090 has 24GB VRAM. A 3B-parameter model in fp16 takes 6GB. That leaves 18GB for KV cache and batch processing, which is plenty. But the business requirement had shifted: the law firms needed to run document comparison in parallel with the classification model, and the comparison model was 7B parameters. Together they would not fit unless the classification model shrank to under 4GB. INT4 was the answer on paper.

They ran GPTQ. The 94.2% model became a 91.8% model. AWQ gave them 92.1%. Both fell well below the 93% floor the legal team had set as the minimum acceptable accuracy for contract classification. The lawyers had been explicit: a model that miscategorizes a material adverse change clause has real liability implications.

Post-training quantization was not going to get them there. The domain was too narrow, the vocabulary too specialized (legal Latin, clause references, jurisdiction-specific terms), and the model had fine-tuned into a representation that was sensitive to the precision of specific weight values in the later transformer layers. Quantizing after the fact was destroying information that could not be recovered by any calibration-based correction.

The solution was quantization-aware training. They took the fine-tuned fp16 model, inserted fake quantization nodes throughout, and continued training for three epochs with a small learning rate on the same 40,000 labeled contracts. The model learned to represent the same information in a way that survived INT4 quantization. Final accuracy: 93.7% at INT4. Within the acceptable window. Deployed.

This story repeats across specialized domain deployments. PTQ is the right first answer because it is fast and usually good enough. QAT is the right answer when PTQ is not good enough and you cannot afford to accept the accuracy gap.


Why This Exists

Post-training quantization methods - RTN, GPTQ, AWQ - share a fundamental limitation. They are applied after the model has finished learning. The model's weights encode information in fp16 or bf16 precision, and quantization is an aggressive approximation that discards roughly 75% of the bits used to represent each weight. No amount of calibration-based correction can fully recover what is lost when you round a weight from 0.31847... to a 4-bit integer.

For large general-purpose models (7B, 13B, 70B parameters), the redundancy is high enough that PTQ works well. There are millions of paths through the network for representing any given concept, and quantization errors in individual weights are averaged out across layers. You lose 0.5-1% accuracy and call it acceptable.

For smaller, specialized models - 1B to 3B parameters fine-tuned on narrow domains - this redundancy does not exist. The model has 1/10th the capacity of a 13B model but must represent equally complex domain knowledge. Every weight is carrying more information. Quantization error is not averaged out; it is concentrated.

The theoretical argument was understood long before large language models existed. In the computer vision literature, Jacob et al. (2018) at Google showed that training with fake quantization consistently outperformed post-training methods at INT8 for convolutional networks. The insight transferred directly to transformers: if the model sees quantization noise during training, it learns representations that are robust to that noise.

The mechanism is straightforward but subtle. A model trained in full precision may have learned to represent a boundary condition by having two nearby weights whose difference encodes critical information. After quantization both weights round to the same INT4 value and the difference disappears. A model trained with fake quantization learns to use the difference between quantization levels rather than sub-level precision, because quantization noise during training destroyed the sub-level information before it could be relied upon. The model finds a different representational strategy that survives quantization.
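
A minimal numeric illustration of this collapse, using the symmetric absmax scheme defined below in Core Concepts (the weight values here are made up):

```python
import torch

def int4_quantize(x: torch.Tensor) -> torch.Tensor:
    # Symmetric per-tensor INT4 quantize-dequantize, scale = absmax / 7
    scale = x.abs().max() / 7
    return torch.clamp(torch.round(x / scale), -8, 7) * scale

w = torch.tensor([0.305, 0.320])   # the 0.015 gap encodes a boundary
print(int4_quantize(w))            # tensor([0.3200, 0.3200]) - the gap is gone
```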


Historical Context

Fake quantization during training was first proposed systematically by Bengio et al. in "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation" (2013), which introduced the Straight-Through Estimator (STE) that QAT depends on. The STE is the key idea: gradients pass through the quantization operator unchanged in the backward pass, even though the quantization operator's true gradient is zero almost everywhere.

Google's landmark 2018 paper "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference" (Jacob et al.) brought QAT into mainstream practice for production deployment on mobile hardware. This paper introduced the fake quantization node abstraction: an operator that quantizes-then-dequantizes a tensor during the forward pass (simulating quantization error) but passes gradients through unchanged during the backward pass. This paper formed the foundation for PyTorch's torch.quantization QAT API.

For large language models specifically, LLM-QAT ("LLM-QAT: Data-Free Quantization Aware Training for Large Language Models", Liu et al., 2023) extended QAT to decoder-only transformers using a data-free approach - generating calibration data from the model itself rather than requiring labeled training data. This was important because LLMs are general-purpose and training data may not be freely redistributable.

The most radical application of quantization-aware training is BitNet (Wang et al., Microsoft Research, 2023): training 1-bit LLMs entirely from scratch where every weight is -1 or +1. BitNet 1.58-bit (Ma et al., 2024) extended this to ternary weights (-1, 0, +1), achieving competitive performance with full-precision models at roughly equal parameter count. BitNet represents the logical endpoint of QAT: not post-training quantization and not even QAT on top of a full-precision model, but a training regime designed around the constraint that weights must be representable in 1-2 bits.


Core Concepts

The Quantization Problem for Backpropagation

Quantization introduces a non-differentiable step function. For INT4 symmetric quantization of a weight $w$ with scale $\Delta$:

$$\hat{w} = \text{clip}\left(\text{round}\left(\frac{w}{\Delta}\right), -8, 7\right) \cdot \Delta$$

The derivative of the round function is zero everywhere except at integer values where it is undefined:

$$\frac{d}{dw}\,\text{round}(w) = 0 \quad \text{almost everywhere}$$

This means the gradient of the quantized weight with respect to the original weight is zero almost everywhere. If you include quantization in the forward pass and backpropagate through it naively, you get zero gradients for all weights - training stops immediately.
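
This failure mode is easy to reproduce in a few lines:

```python
import torch

w = torch.tensor([0.31, -0.72, 1.40], requires_grad=True)
loss = torch.round(w / 0.05).sum()   # quantization step in the forward pass
loss.backward()
print(w.grad)   # tensor([0., 0., 0.]) - every gradient is zero, training stalls
```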

The Straight-Through Estimator

The STE is a principled hack. Instead of computing the true gradient through the quantization operator (which is zero), you pretend the quantization operator is the identity function during the backward pass. The gradient flows through the quantization node unchanged:

$$\frac{\partial \mathcal{L}}{\partial w} \approx \frac{\partial \mathcal{L}}{\partial \hat{w}}$$

This is mathematically wrong. The quantization operator is not the identity function. But it is empirically correct in the sense that it produces useful gradient information that guides the weights toward values that are robust to quantization. The intuition: if the loss would decrease if $\hat{w}$ increased, then increasing $w$ (the pre-quantization weight) is likely to help, even though the relationship is non-linear due to quantization.

The STE works because training with small learning rates makes weight updates small. If $w$ changes by a small $\delta w$, then with high probability $\hat{w}$ changes in the same direction (since $\text{round}(w + \delta w) \geq \text{round}(w)$ when $\delta w > 0$ for most values of $w$). The zero-gradient regions are narrow and the optimizer effectively averages through them over many gradient steps.

A useful way to think about it: the STE treats the quantization operator as having a "soft" gradient of 1 everywhere within the clipping range, and 0 outside it. This is sometimes called the "clipped identity" gradient:

$$\frac{\partial \hat{w}}{\partial w} \approx \mathbf{1}\left[-\Delta_{\text{clip}} \leq w \leq \Delta_{\text{clip}}\right]$$
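
In code, the STE is commonly implemented with the detach trick; a minimal sketch:

```python
import torch

def ste_round(x: torch.Tensor) -> torch.Tensor:
    # Forward: round(x). Backward: identity, because the rounding
    # residual is detached from the autograd graph.
    return x + (torch.round(x) - x).detach()

w = torch.tensor([0.31, -0.72, 1.40], requires_grad=True)
ste_round(w / 0.05).sum().backward()
print(w.grad)   # tensor([20., 20., 20.]) - the gradient of x/0.05 flows through
```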

Fake Quantization Nodes

In practice, QAT is implemented by inserting fake quantization nodes into the model graph. A fake quantization node:

  1. In the forward pass: quantizes the input to INT4 (or INT8) and then dequantizes it back to fp16/bf16. The output is fp16/bf16 but has the distribution of quantized values - only values on the quantization grid are possible.
  2. In the backward pass: passes gradients through unchanged (STE).

The effect on training: the model sees quantization-induced rounding error on every forward pass. Over thousands of gradient steps, the model learns to minimize loss in a way that is robust to this error. The weights converge to values that "snap cleanly" to quantization grid points, reducing the rounding error during inference.

Critically: the weights are stored in fp16/bf16 throughout QAT. The quantization is simulated during training, not actually applied. Only when you export the model for inference do you actually convert to INT4. This matters because the optimizer (Adam, AdamW) needs fp16/bf16 weights and gradients to work correctly - if you stored actual INT4 weights you could not compute meaningful gradient updates.
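
To make the export step concrete, here is a minimal sketch of the final conversion under the symmetric absmax scheme used above; real INT4 kernels additionally pack two 4-bit values per byte in kernel-specific layouts, and `export_weights_int4` is a hypothetical helper:

```python
import torch

def export_weights_int4(weight: torch.Tensor) -> tuple[torch.Tensor, float]:
    # Snap QAT-trained fp16 weights to actual INT4 integers.
    scale = weight.abs().max().item() / 7
    q = torch.clamp(torch.round(weight / scale), -8, 7).to(torch.int8)
    return q, scale

# After QAT the weights already sit near grid points, so this final
# rounding introduces far less error than rounding an unprepared model.
```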

QAT vs PTQ: When to Choose Each

The decision is primarily about the cost of training versus the cost of accuracy loss:

  • PTQ takes minutes to hours. QAT takes hours to days.
  • PTQ accuracy loss is 0.5-1% for general models at INT4. QAT accuracy loss is typically under 0.2%.
  • PTQ requires only a calibration set (128-512 samples). QAT requires the full training set and a working training pipeline.
  • PTQ works on any pretrained model without modification. QAT requires inserting fake quantization nodes and continuing training.

For most production deployments at INT8, PTQ is sufficient and QAT is unnecessary overhead. For INT4 on specialized models, the accuracy difference is usually worth the QAT investment. For INT2 or 1-bit quantization, QAT or from-scratch quantization-aware training (BitNet) is the only viable path.

The crossover point depends on the model size and domain. Smaller models fine-tuned on narrow domains hit the PTQ accuracy floor earlier. A 70B general-purpose model can handle PTQ INT4 with under 1% degradation. A 1B domain-specialized model may show 3-5% degradation with PTQ INT4, which is often unacceptable.


Diagrams

[Diagram: QAT Forward and Backward Pass]

[Diagram: PTQ vs QAT Decision Tree]

[Diagram: BitNet Training vs Standard QAT]


Code Examples

Implementing Fake Quantization from Scratch

Before using PyTorch's built-in QAT tools, it helps to understand what fake quantization actually does at the code level.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FakeQuantize(torch.autograd.Function):
    """
    Fake quantization with Straight-Through Estimator.

    Forward: quantize then dequantize (simulates INT4 rounding error)
    Backward: identity for values within clipping range, zero outside
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float, num_bits: int = 4):
        # Compute quantization range
        qmin = -(2 ** (num_bits - 1))      # -8 for INT4
        qmax = 2 ** (num_bits - 1) - 1     # +7 for INT4

        # Quantize
        x_scaled = x / scale
        x_clipped = torch.clamp(x_scaled, qmin, qmax)
        x_rounded = torch.round(x_clipped)

        # Dequantize (back to fp16 values, but on the INT4 grid)
        x_dequant = x_rounded * scale

        # Save scaled values for the backward pass (STE clips gradient outside range)
        ctx.save_for_backward(x_scaled)
        ctx.qmin = qmin
        ctx.qmax = qmax

        return x_dequant

    @staticmethod
    def backward(ctx, grad_output):
        x_scaled, = ctx.saved_tensors

        # STE: gradient passes through if value was within clipping range;
        # zero gradient for values that were clipped
        in_range = (x_scaled >= ctx.qmin) & (x_scaled <= ctx.qmax)
        grad_input = grad_output * in_range.float()

        # No gradient for scale parameter (simplified; real implementations
        # may also learn the scale)
        return grad_input, None, None


def fake_quantize(x: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Per-tensor fake quantization."""
    # Compute scale from tensor range
    x_max = x.abs().max().item()
    qmax = 2 ** (num_bits - 1) - 1
    scale = x_max / qmax if x_max > 0 else 1.0

    return FakeQuantize.apply(x, scale, num_bits)


class QATLinear(nn.Linear):
    """
    Linear layer with fake quantization for QAT.
    Quantizes weights (and optionally activations) during training.
    """

    def __init__(self, in_features: int, out_features: int, weight_bits: int = 4,
                 act_bits: int = 8, quantize_activations: bool = True, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.weight_bits = weight_bits
        self.act_bits = act_bits
        self.quantize_activations = quantize_activations
        self.qat_enabled = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.qat_enabled:
            # Apply fake quantization to weights
            weight_fq = fake_quantize(self.weight, num_bits=self.weight_bits)

            # Optionally apply fake quantization to activations
            if self.quantize_activations:
                x_fq = fake_quantize(x, num_bits=self.act_bits)
            else:
                x_fq = x

            return F.linear(x_fq, weight_fq, self.bias)
        else:
            # Normal forward pass (used for baseline comparison)
            return F.linear(x, self.weight, self.bias)

    def enable_qat(self):
        self.qat_enabled = True

    def disable_qat(self):
        self.qat_enabled = False


# Example: replace all linear layers in a model with QAT-aware versions
def convert_model_to_qat(model: nn.Module, weight_bits: int = 4) -> nn.Module:
    """
    Replace all nn.Linear layers with QATLinear.
    Copies weights from the original layer.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            qat_layer = QATLinear(
                module.in_features,
                module.out_features,
                weight_bits=weight_bits,
                bias=module.bias is not None,
            )
            # Copy pretrained weights
            qat_layer.weight.data = module.weight.data.clone()
            if module.bias is not None:
                qat_layer.bias.data = module.bias.data.clone()
            setattr(model, name, qat_layer)
        else:
            # Recurse into child modules
            convert_model_to_qat(module, weight_bits=weight_bits)

    return model
```
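
As a quick smoke test of these helpers (a sketch with an arbitrary toy model):

```python
# Toy model: convert, run a forward/backward pass, and confirm the
# fake-quantized weights land on the INT4 grid.
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))
model = convert_model_to_qat(model, weight_bits=4)

x = torch.randn(2, 64)
model(x).sum().backward()        # gradients reach fp16 weights via the STE

w = model[0].weight
scale = w.abs().max().item() / 7
levels = torch.round(fake_quantize(w) / scale)
print(levels.unique())           # only integer levels in [-8, 7]
```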

QAT Training Pipeline for a Classification Model

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def run_qat_training(
    model_name: str,
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    num_labels: int,
    weight_bits: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 2e-5,
    warmup_steps: int = 100,
    output_dir: str = "./qat-model",
) -> nn.Module:
    """
    Full QAT training pipeline for a sequence classification model.

    Strategy:
    1. Load pretrained model
    2. Insert fake quantization nodes in all Linear layers
    3. Fine-tune with quantization noise active
    4. Export quantized model

    Uses convert_model_to_qat / QATLinear from the previous example.
    """
    from transformers import get_linear_schedule_with_warmup

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Step 1: Load pretrained model
    logger.info(f"Loading {model_name}...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        torch_dtype=torch.float16,
    ).to(device)

    # Measure baseline accuracy before QAT
    baseline_acc = evaluate_accuracy(model, eval_dataloader, device)
    logger.info(f"Baseline fp16 accuracy: {baseline_acc:.4f}")

    # Step 2: Insert fake quantization nodes
    logger.info(f"Converting model to {weight_bits}-bit QAT...")
    model = convert_model_to_qat(model, weight_bits=weight_bits)

    # Count QAT layers
    qat_layers = sum(1 for m in model.modules() if isinstance(m, QATLinear))
    logger.info(f"Inserted fake quantization in {qat_layers} linear layers")

    # Step 3: QAT training
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=0.01,
    )

    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    best_accuracy = 0.0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            loss = outputs.loss
            loss.backward()

            # Gradient clipping is important for QAT stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                logger.info(
                    f"Epoch {epoch+1}/{num_epochs}, "
                    f"Step {batch_idx}/{len(train_dataloader)}, "
                    f"Loss: {loss.item():.4f}"
                )

        avg_loss = total_loss / num_batches
        epoch_acc = evaluate_accuracy(model, eval_dataloader, device)
        logger.info(
            f"Epoch {epoch+1} complete. Loss: {avg_loss:.4f}, Accuracy: {epoch_acc:.4f}"
        )

        if epoch_acc > best_accuracy:
            best_accuracy = epoch_acc
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}

    # Restore best checkpoint
    model.load_state_dict(best_model_state)
    logger.info(f"Best QAT accuracy: {best_accuracy:.4f} (baseline: {baseline_acc:.4f})")
    logger.info(f"Accuracy gap from fp16: {baseline_acc - best_accuracy:.4f}")

    return model


def evaluate_accuracy(model: nn.Module, dataloader: DataLoader, device) -> float:
    """Evaluate classification accuracy."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = outputs.logits.argmax(dim=-1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    return correct / total
```

QLoRA: Quantization-Aware Fine-Tuning in Practice

QLoRA is not classical QAT, but it is the most practically important form of quantization-aware fine-tuning for LLMs. It fine-tunes a quantized base model using low-rank adapters, producing a model that is inherently adapted to the quantized representation.

```python
import torch
from typing import Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training


def setup_qlora_model(
    model_name: str,
    lora_rank: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_modules: Optional[list] = None,
) -> tuple:
    """
    Set up a model for QLoRA fine-tuning.

    QLoRA loads the base model in INT4 (NF4 format) and trains
    only the LoRA adapter weights in bf16. This is QAT in the sense
    that gradients from the adapter propagate through quantized
    base model weights.
    """
    # NF4 quantization config (4-bit NormalFloat, double quantization)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,   # Quantize the quantization scales too
        bnb_4bit_quant_type="nf4",        # NormalFloat4 - better for normally distributed weights
        bnb_4bit_compute_dtype=torch.bfloat16,  # bf16 for computations
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # Prepare model for k-bit training (handles gradient checkpointing, etc.)
    model = prepare_model_for_kbit_training(model)

    # Default target modules for LLaMA/Mistral architecture
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"]

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    # Output: trainable params: 83,886,080 || all params: 7,241,748,480 || trainable%: 1.159

    return model, tokenizer


def train_qlora(
    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
    dataset_name: str = "your-domain-dataset",
    output_dir: str = "./qlora-output",
    num_epochs: int = 2,
    batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
):
    """QLoRA fine-tuning with Hugging Face Trainer."""
    from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq
    from datasets import load_dataset

    model, tokenizer = setup_qlora_model(model_name)

    # Assumes the dataset is already tokenized into input_ids / labels
    dataset = load_dataset(dataset_name)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_checkpointing=True,   # Essential for memory efficiency
        optim="paged_adamw_8bit",      # 8-bit paged optimizer - key QLoRA optimization
        learning_rate=2e-4,
        weight_decay=0.001,
        fp16=False,
        bf16=True,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="epoch",
        evaluation_strategy="epoch",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        args=training_args,
        data_collator=DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8),
    )

    trainer.train()
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model
```
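
For deployment, a common follow-up is to merge the adapters into an fp16 copy of the base model and re-quantize the result; merging directly into a 4-bit base is generally not supported. A sketch, reusing the placeholder paths from `train_qlora` above:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Load the base model in fp16 (not 4-bit) and merge the trained adapters
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.float16
)
merged = PeftModel.from_pretrained(base, "./qlora-output").merge_and_unload()
merged.save_pretrained("./merged-fp16")  # re-quantize with GPTQ/AWQ for serving
```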

Implementing BitNet-Style Ternary Weight Training

BitNet represents the extreme end of QAT: training with ternary weights from scratch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TernaryWeightFunction(torch.autograd.Function):
    """
    Ternary weight quantization with STE.
    Weights are mapped to {-1, 0, +1} using an absmean threshold.
    Forward: ternarize weights
    Backward: STE - gradient passes through unchanged
    """

    @staticmethod
    def forward(ctx, weight: torch.Tensor) -> torch.Tensor:
        # Absmean threshold: gamma = mean(|w|), in the spirit of
        # BitNet's 1.58-bit approach (ternary: -1, 0, +1)
        gamma = weight.abs().mean()
        # Ternarize: values within [-gamma, gamma] -> 0, others -> sign
        w_ternary = torch.where(
            weight.abs() > gamma,
            torch.sign(weight),
            torch.zeros_like(weight),
        )
        return w_ternary

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # STE: gradient passes through unchanged
        return grad_output


class BitLinear(nn.Linear):
    """
    Linear layer with ternary weights (BitNet 1.58b style).
    Weights are {-1, 0, +1} during the forward pass.
    Full-precision weights are maintained for gradient updates.
    Normalization before this layer (RMSNorm in BitNet) is handled externally;
    activations are quantized to INT8 with absmax scaling.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize activations: INT8 absmax per token, with the detach
        # trick so gradients pass straight through the round()
        gamma = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        x_q = (x * 127.0 / gamma).round().clamp(-128, 127) / 127.0 * gamma
        x_quantized = x + (x_q - x).detach()

        # Ternarize weights
        w_ternary = TernaryWeightFunction.apply(self.weight)

        # Scale weights back (absmean normalization)
        beta = self.weight.abs().mean()
        w_scaled = w_ternary * beta

        return F.linear(x_quantized, w_scaled, self.bias)


class BitNetBlock(nn.Module):
    """
    A transformer block designed for BitNet training.
    Uses BitLinear instead of nn.Linear throughout.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.RMSNorm(d_model)  # RMSNorm before attention (PyTorch >= 2.4)
        self.norm2 = nn.RMSNorm(d_model)  # RMSNorm before FFN

        # Attention projections - all ternary weights
        self.q_proj = BitLinear(d_model, d_model, bias=False)
        self.k_proj = BitLinear(d_model, d_model, bias=False)
        self.v_proj = BitLinear(d_model, d_model, bias=False)
        self.o_proj = BitLinear(d_model, d_model, bias=False)

        # FFN - all ternary weights
        self.gate_proj = BitLinear(d_model, d_ff, bias=False)
        self.up_proj = BitLinear(d_model, d_ff, bias=False)
        self.down_proj = BitLinear(d_ff, d_model, bias=False)

        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        B, T, C = x.shape

        # Self-attention with ternary weight projections
        normed = self.norm1(x)
        q = self.q_proj(normed).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(normed).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(normed).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0.0
        )
        attn = attn.transpose(1, 2).contiguous().view(B, T, C)
        x = x + self.dropout(self.o_proj(attn))

        # FFN with SwiGLU and ternary weights
        normed = self.norm2(x)
        gate = F.silu(self.gate_proj(normed))
        up = self.up_proj(normed)
        x = x + self.dropout(self.down_proj(gate * up))

        return x
```
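
A quick check, with illustrative dimensions, that the block runs and that the effective weights really are ternary:

```python
block = BitNetBlock(d_model=256, n_heads=4, d_ff=1024)
x = torch.randn(2, 16, 256)
print(block(x).shape)             # torch.Size([2, 16, 256])

w_ternary = TernaryWeightFunction.apply(block.q_proj.weight)
print(torch.unique(w_ternary))    # tensor([-1., 0., 1.])
```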

Comparing QAT vs PTQ Accuracy at INT4

```python
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification
from awq import AutoAWQForCausalLM
from torch.utils.data import DataLoader


def compare_qat_vs_ptq(
    fp16_model_path: str,
    qat_model_path: str,
    awq_model_path: str,
    test_dataloader: DataLoader,
    task: str = "classification",
) -> dict:
    """
    Side-by-side comparison of fp16, PTQ (AWQ), and QAT accuracy.
    Run this before deployment to decide which quantization strategy to use.
    Uses evaluate_accuracy from the training pipeline example above.
    """
    device = torch.device("cuda")
    results = {}

    # Evaluate fp16 baseline
    print("Evaluating fp16 baseline...")
    fp16_model = AutoModelForSequenceClassification.from_pretrained(
        fp16_model_path,
        torch_dtype=torch.float16,
    ).to(device)
    results["fp16"] = {
        "accuracy": evaluate_accuracy(fp16_model, test_dataloader, device),
        "model_size_gb": get_model_size_gb(fp16_model),
        "memory_gb": measure_inference_memory_gb(fp16_model, device),
    }
    del fp16_model
    torch.cuda.empty_cache()

    # Evaluate AWQ PTQ
    print("Evaluating AWQ INT4 (PTQ)...")
    awq_model = AutoAWQForCausalLM.from_quantized(
        awq_model_path, fuse_layers=False, device_map="cuda:0"
    )
    results["awq_int4"] = {
        "accuracy": evaluate_accuracy(awq_model, test_dataloader, device),
        "model_size_gb": get_model_size_gb(awq_model),
        "memory_gb": measure_inference_memory_gb(awq_model, device),
    }
    del awq_model
    torch.cuda.empty_cache()

    # Evaluate QAT INT4
    print("Evaluating QAT INT4...")
    qat_model = AutoModelForSequenceClassification.from_pretrained(
        qat_model_path,
        torch_dtype=torch.float16,
    ).to(device)
    # QAT model has fake quant nodes - for inference, we need to actually quantize
    # In practice: export with convert_fx() after QAT (see the FX example below)
    results["qat_int4"] = {
        "accuracy": evaluate_accuracy(qat_model, test_dataloader, device),
        "model_size_gb": get_model_size_gb(qat_model),
        "memory_gb": measure_inference_memory_gb(qat_model, device),
    }

    # Print comparison
    baseline_acc = results["fp16"]["accuracy"]
    print("\n" + "=" * 60)
    print(f"{'Method':<15} {'Accuracy':>10} {'vs fp16':>10} {'Size GB':>10}")
    print("-" * 60)
    for method, metrics in results.items():
        acc = metrics["accuracy"]
        gap = acc - baseline_acc
        size = metrics["model_size_gb"]
        print(f"{method:<15} {acc:>9.3%} {gap:>+9.3%} {size:>9.2f}")
    print("=" * 60)

    return results


def get_model_size_gb(model: nn.Module) -> float:
    """Estimate model size in GB from parameter count and dtype."""
    total_bytes = sum(
        p.numel() * p.element_size() for p in model.parameters()
    )
    return total_bytes / 1e9


def measure_inference_memory_gb(model: nn.Module, device, batch_size: int = 1) -> float:
    """Measure peak GPU memory during a single forward pass."""
    torch.cuda.reset_peak_memory_stats(device)
    dummy_input = torch.randint(0, 1000, (batch_size, 128)).to(device)
    with torch.no_grad():
        _ = model(input_ids=dummy_input)
    peak_bytes = torch.cuda.max_memory_allocated(device)
    return peak_bytes / 1e9
```

PyTorch FX Graph Mode QAT

For production-grade QAT with more control over which operations are quantized, PyTorch's torch.ao.quantization FX mode is the right tool.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    PerChannelMinMaxObserver,
)


def setup_fx_qat_model(model: nn.Module, example_input: torch.Tensor) -> nn.Module:
    """
    Prepare a model for QAT using PyTorch FX graph mode.

    FX mode has advantages over eager mode QAT:
    - Whole-model graph capture allows cross-layer optimization
    - Automatic handling of residual connections and skip connections
    - Supports quantizing arbitrary operations, not just nn.Module layers
    """
    model.train()  # FX QAT preparation requires training mode

    # Build a custom QConfig for INT4 weights, INT8 activations (W4A8)
    qconfig_w4a8 = QConfig(
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_affine,
        ),
        weight=FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver,  # per-channel qscheme needs a per-channel observer
            quant_min=-8,
            quant_max=7,
            dtype=torch.qint8,  # qint8 storage representing INT4 values [-8, 7]
            qscheme=torch.per_channel_symmetric,
            ch_axis=0,
        ),
    )

    # Apply the qconfig globally (all quantizable layers)
    qconfig_mapping = QConfigMapping().set_global(qconfig_w4a8)

    # Prepare model with FX - inserts fake quantization observers
    model_prepared = prepare_qat_fx(
        model,
        qconfig_mapping,
        example_inputs=(example_input,),
    )

    return model_prepared


def finalize_qat_model(model_prepared: nn.Module, example_input: torch.Tensor) -> nn.Module:
    """
    Convert a trained QAT model to an actual quantized model for deployment.
    This replaces fake quantization nodes with real quantized operations.
    """
    model_prepared.eval()

    # Capture the fake-quant reference output before conversion,
    # since convert_fx transforms the prepared module
    with torch.no_grad():
        out_qat = model_prepared(example_input)

    # Convert to quantized model (fake quant -> real quant)
    quantized_model = convert_fx(model_prepared)

    # Verify output matches
    with torch.no_grad():
        out_quant = quantized_model(example_input)

    max_diff = (out_qat - out_quant).abs().max().item()
    print(f"Max output difference after conversion: {max_diff:.6f}")
    # Should be very small (< 0.001) if conversion is correct

    return quantized_model
```

Production Engineering Notes

Scheduling QAT in the Training Pipeline

QAT should not be applied for the full training duration. The standard approach is a three-phase schedule:

```python
class QATScheduler:
    """
    Manages the three phases of QAT training:
    1. Warmup: train without quantization to stabilize
    2. QAT: enable fake quantization, continue training with lower LR
    3. Freeze observers: stop updating quantization ranges, fine-tune only weights
    """

    def __init__(
        self,
        model: nn.Module,
        total_steps: int,
        warmup_fraction: float = 0.1,
        qat_fraction: float = 0.8,
        freeze_fraction: float = 0.1,
    ):
        assert abs(warmup_fraction + qat_fraction + freeze_fraction - 1.0) < 1e-6

        self.model = model
        self.total_steps = total_steps
        self.warmup_end = int(total_steps * warmup_fraction)
        self.freeze_start = int(total_steps * (warmup_fraction + qat_fraction))

        self._qat_enabled = False
        self._observers_frozen = False

        # Start with fake quantization disabled so the warmup phase
        # really does train without quantization noise
        for module in self.model.modules():
            if isinstance(module, QATLinear):
                module.disable_qat()

    def step(self, current_step: int):
        if current_step == self.warmup_end and not self._qat_enabled:
            self._enable_qat()

        if current_step == self.freeze_start and not self._observers_frozen:
            self._freeze_observers()

    def _enable_qat(self):
        """Enable fake quantization throughout the model."""
        for module in self.model.modules():
            if isinstance(module, QATLinear):
                module.enable_qat()
        # For FX mode: self.model.apply(torch.ao.quantization.enable_fake_quant)
        self._qat_enabled = True
        print(f"QAT enabled at step {self.warmup_end}")

    def _freeze_observers(self):
        """
        Freeze the moving average observers that track activation ranges.
        After freezing, only weight values are updated - the quantization
        scale/zero-point parameters are fixed.
        This improves stability in the final training phase.
        """
        for module in self.model.modules():
            if hasattr(module, "activation_post_process"):
                module.activation_post_process.disable_observer()
        # For FX mode: self.model.apply(torch.ao.quantization.disable_observer)
        self._observers_frozen = True
        print(f"Observers frozen at step {self.freeze_start}")
```

When QAT Takes Too Long: Progressive Quantization

For very large models where full QAT on all layers is prohibitively expensive, progressive quantization provides most of the benefit at a fraction of the cost.

```python
def progressive_qat(
    model: nn.Module,
    train_dataloader: DataLoader,
    num_epochs_per_phase: int = 1,
    weight_bits: int = 4,
) -> nn.Module:
    """
    Quantize layers progressively, from last to first.
    Each phase quantizes the next layer and trains for a few epochs.

    Key insight: early layers are more sensitive to quantization (they affect
    all downstream layers). Quantizing from last to first is safer.
    """
    # Collect all linear layers in order
    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]

    # Process in reverse order (last layers first - less sensitive)
    linear_layers.reverse()

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    for phase, (layer_name, layer) in enumerate(linear_layers):
        print(f"Phase {phase+1}/{len(linear_layers)}: Quantizing {layer_name}")

        # Replace this layer with its QAT version
        parent_name = ".".join(layer_name.split(".")[:-1])
        child_name = layer_name.split(".")[-1]
        parent = model.get_submodule(parent_name)

        qat_layer = QATLinear(
            layer.in_features,
            layer.out_features,
            weight_bits=weight_bits,
            bias=layer.bias is not None,
        )
        qat_layer.weight.data = layer.weight.data.clone()
        if layer.bias is not None:
            qat_layer.bias.data = layer.bias.data.clone()
        setattr(parent, child_name, qat_layer)

        # Fine-tune for a few epochs to recover accuracy
        # (train_one_epoch is a user-supplied training loop helper)
        for epoch in range(num_epochs_per_phase):
            train_one_epoch(model, train_dataloader, optimizer)

        print(f"  Completed. Layers quantized: {phase+1}/{len(linear_layers)}")

    return model
```

Common Mistakes

:::danger Using a Learning Rate That is Too High for QAT

QAT is fine-tuning, not training from scratch. The model already has good weights; you are adjusting them to be robust to quantization noise, not searching for a new solution from scratch. Using a learning rate appropriate for pretraining (1e-3 to 1e-4) will catastrophically forget the pretrained knowledge.

QAT learning rates should be 5-20x lower than the fine-tuning learning rate. If you fine-tuned at 2e-5, run QAT at 1e-6 to 4e-6. The model should make small adjustments to existing weights, not large movements.

A model that produces NaN loss during QAT is almost always a learning rate problem. Check this before anything else.

:::

:::danger Applying Fake Quantization to All Layers Including Embeddings

Embedding layers have fundamentally different weight distributions than linear projection layers. INT4 quantization of embeddings (word embeddings, position embeddings) causes disproportionate accuracy loss because each embedding vector represents a discrete token identity and the values within the vector are not redundant.

Standard practice: exclude embedding layers, the language model head (lm_head), and sometimes the first and last transformer layers from quantization. The accuracy-size tradeoff for these layers is poor - they represent a small fraction of parameters but contribute significantly to quantization error.

```python
# When converting to QAT, skip these layer names:
SKIP_LAYERS = {"embed_tokens", "embed_positions", "lm_head", "model.norm"}

for name, module in model.named_modules():
    should_skip = any(skip in name for skip in SKIP_LAYERS)
    if isinstance(module, nn.Linear) and not should_skip:
        # Replace with QATLinear
        pass
```

:::

:::warning Measuring QAT Accuracy with Fake Quantization Still Active

After QAT training, if you evaluate the model with fake quantization nodes still in training mode, the observers keep updating their range estimates on the evaluation data, which skews the measurement and can give pessimistic results. For evaluation:

  1. Call model.eval() and also disable the observers (model.apply(torch.ao.quantization.disable_observer)) so the range estimates stop updating; eval mode alone does not freeze them.
  2. The fake quantization is still applied but with fixed scale parameters.
  3. For the true deployed accuracy, convert the model with torch.ao.quantization.convert_fx() and evaluate the actually-quantized INT4 model.

The difference between eval-mode QAT accuracy and converted INT4 accuracy should be minimal (under 0.1%). If it is larger, your fake quantization nodes are not faithfully simulating the deployment quantization.

:::

:::warning Confusing QLoRA Fine-Tuning with True QAT

QLoRA is not the same as QAT in the classical sense. In QLoRA, the base model weights are frozen and the only trainable parameters are the LoRA adapters. The gradients flow through the quantized base model weights (as frozen lookup tables) to update the adapters, but the base model weights themselves do not change.

True QAT updates the actual model weights to be robust to quantization. QLoRA adds a small set of full-precision adapter weights on top of a frozen quantized base.

For the same base model, true QAT at INT4 typically outperforms QLoRA at INT4 by 0.5-1.5 percentage points on domain-specific tasks, because QAT can actually reshape the base model weights to better fit the quantization grid. QLoRA is faster to run and requires less GPU memory, which is why it is more widely used, but it is not the same operation.

Use QLoRA when you need efficient domain adaptation of a large model. Use true QAT when you need to maximize accuracy at aggressive quantization levels and have the compute budget.

:::


Interview Q&A

Q1: Explain the Straight-Through Estimator. Why does it work despite being mathematically incorrect?

A: The STE approximates the gradient of the quantization operator as 1 (within the clipping range) instead of 0 (the true derivative of the round function almost everywhere). It is mathematically wrong because the round function's derivative is genuinely zero almost everywhere - there is no first-principles justification for pretending it is 1.

It works empirically for two reasons. First, gradient descent with small learning rates makes small weight updates. For small updates, the expected change in the quantized value has the same sign as the change in the pre-quantization value with high probability. If $w$ increases by $\epsilon$, then $\text{round}(w + \epsilon)$ is usually $\geq \text{round}(w)$ for small $\epsilon$. The gradient direction is usually correct even if the magnitude is wrong.

Second, the STE can be viewed as optimizing a smoothed version of the objective. If you replace the round function with a soft rounding operation (like $w - \frac{1}{2\pi}\sin(2\pi w)$, which is smooth but pushes weights toward integers), its gradient is non-zero and points in a similar direction to the STE gradient. Training with STE is roughly equivalent to training with this smooth approximation.

The deeper intuition: you do not need the exact gradient to optimize a function with gradient descent. You need a descent direction that is correlated with the true descent direction. The STE provides this.

Q2: What is the difference between W4A8 and W4A16 quantization, and when would you choose each for QAT?

A: W4A8 means 4-bit weights and 8-bit activations. W4A16 means 4-bit weights and 16-bit activations (activations stay in fp16).

W4A16 is simpler to implement and is the AWQ/GPTQ standard. It reduces memory bandwidth for weight loading by 4x (loading 4-bit weights instead of 16-bit) while keeping activations in fp16 for numerical stability. The matrix multiply is done in fp16 after dequantizing the weights. This is appropriate for memory-bandwidth-limited inference on GPUs.

W4A8 quantizes both weights and activations to integer formats, enabling the use of INT8 or INT4 tensor cores for the actual multiply-accumulate operations. On hardware with good INT8 support (like NVIDIA's Tensor Cores), this can provide additional throughput beyond the memory bandwidth savings. But it requires careful calibration for activation quantization - activations have much more dynamic range variation than weights and outlier activations (found in LLMs, especially in later layers) can cause significant quantization error.

For QAT specifically: start with W4A16 (simpler, usually enough accuracy). Only move to W4A8 if you need the additional throughput on INT8-capable hardware and you have confirmed that W4A16 throughput is insufficient. W4A8 QAT is more sensitive to learning rate and calibration quality.

Q3: How does LLM-QAT address the problem of not having access to the original training data?

A: LLM-QAT (Liu et al., 2023) uses the LLM itself to generate calibration data for QAT. The insight: you can sample from the model's own distribution to create training data that is representative of what the model "knows" and "does," without needing the original training corpus.

The process: (1) sample prompts from a small seed set (or generate them randomly), (2) run the fp16 model to generate completions, (3) use these self-generated (prompt, completion) pairs as QAT training data, distilling the fp16 model's behavior into the quantized version.

This is a form of knowledge distillation combined with QAT. The loss function trains the quantized model to match the fp16 model's output distribution (KL divergence on logits) rather than maximizing likelihood on the original training data.

The practical advantage is that this approach works for any pretrained LLM, even when the original training data is proprietary or too large to store. The theoretical concern is that self-generated data may have biases or blind spots from the original model that get amplified during QAT. In practice, LLM-QAT has been shown to outperform PTQ methods across a range of model families and bit widths, validating the approach.

Q4: A 3B model fine-tuned for medical QA shows 93.5% accuracy in fp16 and 88.2% after GPTQ INT4. Your team wants to reach at least 92%. Walk through your QAT strategy.

A: The 5.3-point gap at INT4 is large, suggesting the model is using precision in a way that PTQ cannot recover. Here is how I would approach it.

First, diagnose where the accuracy is lost. Run the fp16 and GPTQ models on the test set and categorize failures. If GPTQ is failing on specific question types or medical domains, this tells you whether the problem is uniform (quantization broadly) or concentrated (specific layer types or medical vocabulary).

Second, prepare QAT infrastructure. Insert fake quantization nodes in all linear layers except embeddings and the language model head. Set the learning rate 10-20x lower than the original fine-tuning rate. Use the same labeled medical QA pairs that produced the fine-tune.

Third, use the three-phase QAT schedule. Phase 1 (10% of training): no fake quantization, just verify the training loop is stable. Phase 2 (80%): enable INT4 fake quantization, train with cross-entropy loss on the labeled data. Phase 3 (10%): freeze observer statistics, make final small adjustments to weight values only.

Fourth, if after full QAT you are still below 92%, try mixed precision: quantize 80% of the transformer layers to INT4 and keep the final 20% (the layers closest to the output) in INT8. Late layers in fine-tuned models carry more domain-specific information and can be more sensitive to aggressive quantization.

Fifth, if mixed precision still does not close the gap, consider whether INT4 is the right target. If the 92% accuracy requirement is firm and the hardware constraint is a 4GB limit, perhaps INT6 or a pruned-then-INT8 approach gives a better accuracy-size tradeoff than INT4.

Q5: What makes BitNet architecturally different from standard QAT, and why can it achieve comparable performance to fp16 despite using only 1.58 bits per weight?

A: Standard QAT starts with a full-precision model and quantizes an existing representation. BitNet is designed with quantization as a fundamental architectural constraint from the beginning. Every weight in a BitNet model must be representable in {-1, 0, +1} at runtime. The model never has access to sub-bit precision; it must learn to represent all information using only three possible weight values per parameter.

This is achievable because the capacity of a neural network is determined by both the precision of individual weights and the total number of parameters interacting together. A model with 7 billion weights that are each in {-1, 0, +1} has enormous representational capacity through the combinatorial interactions between weights, even if each individual weight carries minimal information.

BitNet 1.58-bit achieves comparable performance by compensating for the bit-width reduction with per-tensor activation scaling and RMSNorm before each BitLinear layer. The key architectural insight: while each weight is ternary, the matmul output is a sum over many ternary weights times activation values, and this sum has the continuous-valued distribution needed for the next layer's computation.

The efficiency gain is large: a ternary matmul is essentially additions and subtractions (no multiplications needed since weights are -1, 0, or +1). On hardware designed for BitNet (specialized accelerators), this enables dramatic speedups. On standard GPUs with tensor cores optimized for INT8/INT16 arithmetic, the benefit is primarily memory bandwidth (1.58 bits per weight means less data to load from VRAM) rather than compute efficiency.
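
A small check of that claim (illustrative values):

```python
import torch

x = torch.randn(8)
w = torch.tensor([1., -1., 0., 1., 0., -1., 1., 0.])   # ternary weights

y_matmul = w @ x
y_addsub = x[w == 1].sum() - x[w == -1].sum()   # no multiplications needed
print(torch.allclose(y_matmul, y_addsub))        # True
```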

Q6: How does QLoRA differ from standard QAT mechanistically, and what does this mean for the models they produce?

A: Mechanistically, they differ in what parameters are updated during the "aware" training phase.

In standard QAT, the model's actual weights are updated. Fake quantization nodes inject quantization noise into the forward pass, and gradients flow back through those nodes (via STE) to update the full-precision weight values. At the end of QAT, the weights have shifted to positions on or near the quantization grid - the model has structurally adapted to INT4 representation.

In QLoRA, the base model weights are frozen at their INT4 values (they are actually loaded in NF4 format and never updated). The trainable parameters are the LoRA adapter matrices $A$ and $B$ added to each attention and MLP layer. Gradients flow through the frozen quantized base model to reach the adapters, but the base model weights themselves do not change.

The models they produce are different. QAT produces a model with the same architecture as the original but with weights that have been adjusted to be INT4-friendly. You can deploy it as a standard INT4 model with no adapter infrastructure. QLoRA produces a base INT4 model plus a set of fp16 adapter weights. Deployment requires merging the adapter into the base model (which upgrades it back toward fp16, losing the size benefit) or serving with the adapter added at inference time (which adds compute overhead for the adapter matmul).

For pure quantization quality, QAT is superior because it can reshape all weights. For practical domain adaptation of large models with limited compute, QLoRA is superior because the cost is proportional to the number of adapter parameters (1-2% of the model) rather than the full model. Both are valid production tools; the choice depends on whether your constraint is accuracy or compute budget.
