

AWQ: Activation-Aware Weight Quantization

The Experiment That Rewrote the Rules

In 2023, a 7B language model was in production at an AI startup serving 40 million API calls per day. The model ran in FP16 across two A10G GPUs. The infrastructure team had a mandate: cut serving costs by 60% without degrading user experience. They tried GPTQ INT4 first. The quantized model was 4x smaller, served 2.5x faster, and benchmarks showed only a 1.8% accuracy drop on MMLU. They shipped it. Within 48 hours, they started seeing a pattern in user complaints: the new model consistently got a specific class of multi-step reasoning questions wrong. The benchmark had missed it. Rollback. Back to FP16.

Two weeks later, the MIT Han Lab published AWQ. The team tried the AWQ-quantized version of the same model. Same 4-bit compression, same memory footprint. But the reasoning accuracy on that problem class was within 0.3% of FP16. They shipped it. No rollbacks. The difference: AWQ's mechanism for protecting the tiny fraction of weights most responsible for reasoning accuracy had preserved exactly the capability that GPTQ had eroded.

The MIT Han Lab's key experiment was simple. They took a 7B language model and quantized it to 4-bit using three methods: (1) naive round-to-nearest INT4, (2) GPTQ with full Hessian-based error compensation, and (3) a deceptively simple approach: keep the 1% of weight channels with the largest typical input activations in FP16 and quantize everything else with plain round-to-nearest. The result: the "protect 1%, quantize 99%" method matched GPTQ's accuracy without any Hessian computation. The insight was stark: weights do not contribute equally to model output. A tiny fraction of weight channels, those corresponding to input dimensions that consistently carry large activation magnitudes, cause disproportionate damage when quantized. Protecting these "salient" weights preserves most of the model's accuracy even though the other 99% are quantized.

This experiment became AWQ: Activation-Aware Weight Quantization. Today, AWQ models on Hugging Face typically match or outperform GPTQ models on common benchmarks at 4-bit precision, and with optimized INT4 kernels they serve 20-40% more tokens per second. For production INT4 inference on NVIDIA Ampere or later GPUs, AWQ is a strong default in most scenarios.

The Fundamental Insight: Why Some Weights Matter More

Consider a linear layer output: $y_i = \sum_j W_{ij} \cdot x_j$.

The quantization error for weight $W_{ij}$ contributes to the output error by:

$$\delta y_i \approx \delta W_{ij} \cdot x_j$$

If $x_j$ is typically large (high activation magnitude), even a small quantization error $\delta W_{ij}$ produces a large output error. If $x_j$ is typically near zero, the same quantization error in $W_{ij}$ has negligible output impact.

This is the activation-importance insight formalized mathematically:

$$\text{Impact}(j) = |\delta W_{:,j}| \cdot \mathbb{E}_{x}\!\left[|x_j|\right]$$

The expected absolute activation magnitude $\mathbb{E}[|x_j|]$ tells you how much weight column $j$'s quantization errors are amplified into output errors. High-activation columns are "salient": their quantization errors have outsized impact. Low-activation columns can tolerate large quantization errors because those errors get multiplied by a small $x_j$.
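A small numeric check of this argument, written as a NumPy sketch on synthetic data. The matrix sizes, the 50x outlier channel, and the plain per-row round-to-nearest INT4 quantizer are illustrative assumptions, not AWQ's exact setup. The sketch quantizes a weight matrix, then measures each column's contribution to output error, $\mathbb{E}[|x_j|] \cdot \lVert \delta W_{:,j} \rVert$:

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, n_tokens = 64, 32, 256

W = rng.normal(size=(d_out, d_in))
X = rng.normal(size=(n_tokens, d_in))
X[:, 3] *= 50.0  # channel 3 plays the role of an outlier activation channel

# Plain symmetric INT4 round-to-nearest, one step size per output row
step = np.abs(W).max(axis=1, keepdims=True) / 7
W_q = np.clip(np.round(W / step), -7, 7) * step
dW = W_q - W

# Per-column contribution to output error: mean over tokens of ||dW[:, j] * x_j||
contrib = np.array([
    np.mean(np.abs(X[:, j]) * np.linalg.norm(dW[:, j])) for j in range(d_in)
])
importance = np.abs(X).mean(axis=0)  # E[|x_j|] per input channel

print("most damaging column:", contrib.argmax())
print("most activated column:", importance.argmax())
```

All columns receive rounding errors of similar size, but only the high-activation column turns its error into a large output error.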

AWQ quantifies salience simply and efficiently:

$$\text{importance}(j) = \mathbb{E}_{x \sim \text{calibration}}\!\left[|x_j|\right]$$

Columns with high mean absolute activation are salient; columns with low mean absolute activation can be quantized aggressively. Computing this score requires only forward passes over calibration data, with no Hessian inversion.
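The "protect 1%, quantize 99%" experiment can be reproduced in miniature. This is a NumPy sketch under several assumptions: synthetic weights, activations in which about 1% of channels are 30x larger, group-wise asymmetric INT4 round-to-nearest with group size 128, and the top 1% of input channels by $\mathbb{E}[|x_j|]$ kept in full precision. A random 1% is protected as a control.

```python
import numpy as np

rng = np.random.default_rng(1)
d_out, d_in, n_tokens, group = 256, 1024, 512, 128

W = rng.normal(size=(d_out, d_in)).astype(np.float32)
X = rng.normal(size=(n_tokens, d_in)).astype(np.float32)
heavy = rng.choice(d_in, size=10, replace=False)
X[:, heavy] *= 30.0  # ~1% of channels carry large activations


def rtn_int4(w: np.ndarray, group_size: int) -> np.ndarray:
    """Asymmetric INT4 round-to-nearest with one scale/zero per row group."""
    rows, cols = w.shape
    g = w.reshape(rows, cols // group_size, group_size)
    w_min = g.min(axis=-1, keepdims=True)
    w_max = g.max(axis=-1, keepdims=True)
    scale = np.maximum(w_max - w_min, 1e-8) / 15
    zero = np.clip(-np.round(w_min / scale), 0, 15)
    q = np.clip(np.round(g / scale) + zero, 0, 15)
    return ((q - zero) * scale).reshape(rows, cols)


def output_mse(w_hat: np.ndarray) -> float:
    return float(np.mean((X @ W.T - X @ w_hat.T) ** 2))


importance = np.abs(X).mean(axis=0)                   # E[|x_j|] per input channel
salient = np.argsort(importance)[-int(0.01 * d_in):]  # top 1% of channels

W_rtn = rtn_int4(W, group)
W_protected = W_rtn.copy()
W_protected[:, salient] = W[:, salient]               # keep salient columns exact

W_random = W_rtn.copy()
random_cols = rng.choice(d_in, size=salient.size, replace=False)
W_random[:, random_cols] = W[:, random_cols]          # control: protect random columns

err_rtn, err_protected, err_random = map(output_mse, (W_rtn, W_protected, W_random))
print(f"RTN {err_rtn:.3f} | salient 1% kept {err_protected:.3f} | random 1% kept {err_random:.3f}")
```

On this synthetic data, keeping the activation-selected columns exact removes most of the output error, while keeping a random 1% barely changes it.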

The AWQ Scaling Trick: Protection Without Extra Memory

The key challenge is that you cannot simply keep salient weights in FP16 and quantize the rest to INT4. Mixing data types within one weight matrix complicates the memory layout and forces special cases into the compute kernels. That breaks the uniform INT4 matmul that makes 4-bit inference fast, which is both a performance problem and an engineering nightmare.

AWQ sidesteps this with a scaling invariance. For any scalar $s > 0$:

$$W_{:,j} \cdot x_j = (s \cdot W_{:,j}) \cdot (x_j / s)$$

The matrix product is unchanged. The scaled weight $s \cdot W_{:,j}$ is larger, but the quantization step size $\Delta$ is set by the largest-magnitude weight in the group. As long as the scaled column does not become that maximum, $\Delta$ stays roughly the same. The rounding error (at most $\Delta/2$ per weight) therefore becomes a smaller fraction of the weight's magnitude:

$$\text{Relative quantization error} = \frac{\Delta}{|W_{:,j}|} \;\rightarrow\; \frac{\Delta}{s \cdot |W_{:,j}|}$$

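With $\Delta$ unchanged, scaling the weight up by $s$ divides its relative quantization error by $s$. Because the input is divided by $s$, the error that reaches the output shrinks by the same factor. If $s$ is pushed too far, the scaled column sets the group maximum, $\Delta$ grows, and every other weight in the group is quantized more coarsely. This trade-off is why AWQ searches for $s$ rather than fixing it.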
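A NumPy sketch of the scaling trick. Its assumptions: each row is a single quantization group, the quantizer is symmetric INT4, channel 0 is salient with 20x larger activations, and $s = 2$ is chosen by hand rather than searched. Column 0 is scaled up before quantization and its input is scaled down by the same factor. The result is compared with plain quantization.

```python
import numpy as np

rng = np.random.default_rng(2)
d_out, d_in, n_tokens = 512, 128, 256
s = 2.0  # hand-picked scale for salient channel 0

W = rng.normal(size=(d_out, d_in))
X = rng.normal(size=(n_tokens, d_in))
X[:, 0] *= 20.0  # channel 0 is salient


def quant_int4_rows(w: np.ndarray) -> np.ndarray:
    """Symmetric INT4 round-to-nearest; each row is one group with its own step."""
    step = np.abs(w).max(axis=1, keepdims=True) / 7
    return np.clip(np.round(w / step), -7, 7) * step


scales = np.ones(d_in)
scales[0] = s

y_ref = X @ W.T
y_plain = X @ quant_int4_rows(W).T
# Quantize W * s and feed x / s: identical to X @ W.T before rounding
y_scaled = (X / scales) @ quant_int4_rows(W * scales).T

err_plain = float(np.mean((y_ref - y_plain) ** 2))
err_scaled = float(np.mean((y_ref - y_scaled) ** 2))
print(f"plain: {err_plain:.4f}  scaled: {err_scaled:.4f}")
```

Most of the plain-quantization error comes from channel 0. Halving that channel's effective error lowers the total substantially, even though a few rows now have a slightly larger step size.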

The corresponding input scaling ($x_j \to x_j / s$) is folded into whatever operator produces $x$. During quantization, AWQ divides by $s$ either the weight of the preceding LayerNorm/RMSNorm or the matching output rows of the preceding linear layer. The scaling therefore has zero inference overhead: it is absorbed into the model weights, not computed at runtime.
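A sketch of the folding step for the common case where the input to a projection comes from an RMSNorm. It assumes NumPy, a hand-written RMSNorm with a learnable gain `gamma`, and arbitrary positive scales `s`. AWQ libraries make the equivalent edit on the real norm weights; this code only checks the identity.

```python
import numpy as np

rng = np.random.default_rng(3)
d_model, d_out, n_tokens = 16, 8, 5

h = rng.normal(size=(n_tokens, d_model))     # residual stream
gamma = rng.normal(size=d_model)             # RMSNorm gain
W = rng.normal(size=(d_out, d_model))        # next linear layer
s = rng.uniform(0.5, 4.0, size=d_model)      # per-channel AWQ scales


def rms_norm(x: np.ndarray, gain: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps) * gain


y_original = rms_norm(h, gamma) @ W.T

# Fold: divide the norm gain by s, multiply the matching weight columns by s
gamma_folded = gamma / s
W_folded = W * s  # column j (input channel j) scaled by s[j]

y_folded = rms_norm(h, gamma_folded) @ W_folded.T
print(np.max(np.abs(y_original - y_folded)))
```

When the predecessor is a linear layer, for example `up_proj` feeding `down_proj` in a SwiGLU MLP, the same identity applies to its output rows: `W_prev / s[:, None]`.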

Finding the optimal scale: AWQ uses a grid search over a family of per-channel scales parameterized by a single exponent $\alpha \in [0, 1]$:

$$s_j = \text{importance}(j)^{\alpha} = \mathbb{E}[|x_j|]^{\alpha}$$

When $\alpha = 0$, scaling is uniform (no protection). When $\alpha = 1$, the scale is proportional to activation importance. AWQ evaluates the output error on calibration data for 20 candidate values of $\alpha$ and keeps the best one. This is much cheaper than computing a Hessian: roughly 20 forward passes through each block with quantized weights.

AWQ Implementation From First Principles

```python
import torch
from typing import List, Tuple


def compute_activation_magnitudes(
    calibration_inputs: List[torch.Tensor],
    device: str = "cuda",
    smooth_factor: float = 0.0,
) -> torch.Tensor:
    """
    Compute per-channel activation magnitudes from calibration data.

    Returns the AWQ importance score for each input dimension:
        importance_j = E[|x_j|] averaged over all calibration samples.

    High values indicate salient input dimensions - weight columns
    corresponding to these dimensions cause the most output error when quantized.

    Args:
        calibration_inputs: List of input tensors, each shape (batch, seq_len, d_in)
            or (n_tokens, d_in)
        smooth_factor: If > 0, blend importance with uniform distribution.
            Use 0.1 if a few dimensions are extreme outliers
            that would dominate the scale allocation.

    Returns:
        importance: Tensor of shape (d_in,), importance per input dimension
    """
    importance_accumulator = None
    n_samples = 0

    for inp in calibration_inputs:
        inp = inp.to(device).float()

        # Flatten batch and sequence dimensions to (n_tokens, d_in)
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        elif inp.dim() != 2:
            raise ValueError(f"Expected 2D or 3D input, got {inp.dim()}D")

        # Mean absolute activation per dimension
        dim_importance = inp.abs().mean(dim=0)  # shape: (d_in,)

        if importance_accumulator is None:
            importance_accumulator = dim_importance
        else:
            importance_accumulator = importance_accumulator + dim_importance

        n_samples += 1

    if importance_accumulator is None or n_samples == 0:
        raise ValueError("No calibration inputs provided")

    importance = importance_accumulator / n_samples  # (d_in,)

    # Optional smoothing: blend with uniform to reduce extreme outlier dominance
    if smooth_factor > 0:
        uniform = torch.ones_like(importance) * importance.mean()
        importance = (1 - smooth_factor) * importance + smooth_factor * uniform

    return importance


def find_optimal_awq_scales(
    weight: torch.Tensor,
    activation_importance: torch.Tensor,
    calibration_inputs: List[torch.Tensor],
    n_bits: int = 4,
    n_scale_search_steps: int = 20,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Search for the optimal per-channel scale factors for AWQ.

    AWQ uses a parametric search over scales of the form:
        s_j = importance_j ^ alpha,  for alpha in [0, 1]

      - alpha = 0: no scaling (standard quantization)
      - alpha = 1: scale fully proportional to activation importance

    We grid-search alpha and find the value minimizing output error
    on calibration data.

    Args:
        weight: Weight matrix, shape (d_out, d_in)
        activation_importance: Mean absolute activation per dimension, shape (d_in,)
        calibration_inputs: Calibration activations for computing output error
        n_bits: Quantization bits (4 for INT4)
        n_scale_search_steps: Number of alpha values to evaluate

    Returns:
        optimal_scales: Per-channel scale factors, shape (d_in,)
    """
    d_out, d_in = weight.shape
    Q_MAX = 2 ** (n_bits - 1) - 1  # 7 for INT4 (symmetric)

    weight = weight.to(device).float()
    activation_importance = activation_importance.to(device).float()

    # Prepare calibration inputs for output error measurement
    calibration_sample = calibration_inputs[:min(8, len(calibration_inputs))]
    calibration_batch = []
    for inp in calibration_sample:
        inp = inp.to(device).float()
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        calibration_batch.append(inp)

    if not calibration_batch:
        return torch.ones(d_in, device=device)

    calibration_concat = torch.cat(calibration_batch, dim=0)  # (n_total, d_in)

    # Reference output before any scaling or quantization
    with torch.no_grad():
        y_reference = calibration_concat @ weight.T  # (n_total, d_out)

    best_scales = torch.ones(d_in, device=device)
    best_error = float("inf")

    # Grid search over alpha values
    alpha_values = torch.linspace(0.0, 1.0, n_scale_search_steps)

    for alpha in alpha_values:
        # s_j = importance_j ^ alpha, normalized so mean scale = 1
        scales = activation_importance ** alpha.item()
        scales = scales / (scales.mean() + 1e-8)  # Normalize to prevent scale collapse
        scales = scales.clamp(min=1e-4, max=1e4)

        # Apply scales: scale up salient columns in the weight matrix
        scaled_weight = weight * scales.unsqueeze(0)  # (d_out, d_in)

        # Quantize the scaled weight (per-row symmetric for this search step)
        w_abs_max = scaled_weight.abs().max(dim=1, keepdim=True).values
        q_scale = (w_abs_max / Q_MAX).clamp(min=1e-8)
        q_weight = torch.round(scaled_weight / q_scale).clamp(-Q_MAX, Q_MAX)
        q_weight_dequant = q_weight * q_scale

        # Unscale: recover the original weight space for error measurement
        q_weight_unscaled = q_weight_dequant / scales.unsqueeze(0)  # (d_out, d_in)

        # Compute output error on calibration data
        with torch.no_grad():
            y_quantized = calibration_concat @ q_weight_unscaled.T
            error = ((y_reference - y_quantized) ** 2).mean().item()

        if error < best_error:
            best_error = error
            best_scales = scales.clone()

    return best_scales


def awq_quantize_weight_matrix(
    weight: torch.Tensor,
    optimal_scales: torch.Tensor,
    n_bits: int = 4,
    group_size: int = 128,
    zero_point: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply AWQ quantization to a weight matrix using pre-computed optimal scales.

    Steps:
      1. Apply pre-computed optimal scales: scale up salient columns
      2. Quantize scaled weight with per-group asymmetric quantization
      3. Store quantized INT4 weights with group scales and zeros

    At inference:
      - Dequantize per group: w_hat = (q_weight - zero) * scale
      - The input scaling (divide by channel_scales) is absorbed into
        the PREVIOUS layer's output weights at quantization time.
      - So inference is: standard INT4 matmul + group dequant, nothing else.

    Returns:
        quantized_int: INT4 values as uint8 (packed 2 per byte in real impl)
        group_scales: Per-group quantization scales, shape (d_out, n_groups)
        group_zeros: Per-group zero-points, shape (d_out, n_groups)
    """
    d_out, d_in = weight.shape
    n_groups = (d_in + group_size - 1) // group_size
    Q_MAX = 2 ** n_bits - 1  # 15 for INT4 unsigned asymmetric

    # Scale up salient weight columns using AWQ-found scales
    scales = optimal_scales.to(weight.device).float()
    scaled_weight = weight.float() * scales.unsqueeze(0)  # (d_out, d_in)

    # Quantize with per-group asymmetric quantization
    quantized_int = torch.zeros(d_out, d_in, dtype=torch.uint8, device=weight.device)
    group_scales = torch.zeros(d_out, n_groups, dtype=torch.float16, device=weight.device)
    group_zeros = torch.zeros(d_out, n_groups, dtype=torch.float16, device=weight.device)

    for g in range(n_groups):
        col_start = g * group_size
        col_end = min(col_start + group_size, d_in)
        w_group = scaled_weight[:, col_start:col_end]  # (d_out, group_size)

        if zero_point:
            # Asymmetric: use full [0, Q_MAX] range - better accuracy
            w_min = w_group.min(dim=1, keepdim=True).values
            w_max = w_group.max(dim=1, keepdim=True).values
            scale_g = ((w_max - w_min) / Q_MAX).clamp(min=1e-8)
            # Negate BEFORE clamping: -round(w_min / scale) is the code that
            # represents 0.0. Writing -torch.round(...).clamp(...) would clamp a
            # negative value to 0 first and always produce zero = 0.
            zero_g = (-torch.round(w_min / scale_g)).clamp(0, Q_MAX)
            q_g = (torch.round(w_group / scale_g) + zero_g).clamp(0, Q_MAX).to(torch.uint8)
        else:
            # Symmetric: simpler, slightly lower accuracy for asymmetric distributions
            w_max_abs = w_group.abs().max(dim=1, keepdim=True).values
            half_max = Q_MAX // 2
            scale_g = (w_max_abs / half_max).clamp(min=1e-8)
            zero_g = torch.full_like(scale_g, half_max)
            q_g = (torch.round(w_group / scale_g) + half_max).clamp(0, Q_MAX).to(torch.uint8)

        quantized_int[:, col_start:col_end] = q_g
        group_scales[:, g] = scale_g.squeeze(1).half()
        group_zeros[:, g] = zero_g.squeeze(1).half()

    return quantized_int, group_scales, group_zeros
```
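The docstring above mentions two inference-side pieces that the code leaves out: dequantizing a group back to floating point, and packing two 4-bit codes per byte. Here is a NumPy sketch of both. The nibble order (even column in the low nibble) is an illustrative choice. Real kernels, AutoAWQ's included, use their own interleaved layouts.

```python
import numpy as np


def dequantize_groups(q: np.ndarray, scales: np.ndarray, zeros: np.ndarray,
                      group_size: int) -> np.ndarray:
    """w_hat[:, j] = (q[:, j] - zero[:, g]) * scale[:, g] with g = j // group_size."""
    group_idx = np.arange(q.shape[1]) // group_size
    return (q.astype(np.float32) - zeros[:, group_idx]) * scales[:, group_idx]


def pack_int4(q: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit codes (0..15) into uint8: even column low nibble, odd column high."""
    assert q.shape[1] % 2 == 0, "need an even number of columns"
    return (q[:, 0::2] | (q[:, 1::2] << 4)).astype(np.uint8)


def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Inverse of pack_int4."""
    q = np.empty((packed.shape[0], packed.shape[1] * 2), dtype=np.uint8)
    q[:, 0::2] = packed & 0x0F
    q[:, 1::2] = packed >> 4
    return q


rng = np.random.default_rng(4)
q = rng.integers(0, 16, size=(4, 8), dtype=np.uint8)  # codes for 2 groups of 4 columns
scales = np.full((4, 2), 0.1, dtype=np.float32)
zeros = np.full((4, 2), 8.0, dtype=np.float32)

packed = pack_int4(q)  # half the bytes of the unpacked uint8 codes
w_hat = dequantize_groups(unpack_int4(packed), scales, zeros, group_size=4)
print(packed.shape, w_hat.shape)
```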

Using AutoAWQ: The Production Library

The autoawq library implements the full AWQ pipeline with optimized inference kernels. This is the recommended approach for production deployments:

```python
# pip install autoawq transformers accelerate
import os
from typing import List

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def quantize_model_awq(
    model_name: str,
    output_path: str,
    n_bits: int = 4,
    group_size: int = 128,
    zero_point: bool = True,
    version: str = "GEMM",
    n_calibration_samples: int = 128,
    seq_length: int = 512,
    calib_data: str = "pileval",
) -> None:
    """
    Quantize a model using AWQ and save it for production deployment.

    AWQ quantization is typically faster than GPTQ (no Hessian inversion)
    and achieves comparable or better accuracy at INT4.

    Time estimates on A100 80GB:
        7B model:  30-60 minutes
        13B model: 1-2 hours
        70B model: 3-5 hours (requires 2x A100 for FP16 loading)

    Args:
        model_name: HuggingFace model ID or local path
        output_path: Where to save the quantized model
        n_bits: 4 (AutoAWQ's packed kernels are 4-bit; use bitsandbytes for INT8)
        group_size: 128 (balanced default) or 64 (better accuracy; doubles the
            per-group scale/zero-point overhead)
        zero_point: True = asymmetric quantization (better accuracy, recommended)
            False = symmetric (simpler; required by the Marlin kernel)
        version: Inference kernel. It is fixed at quantization time because it
            determines the packed weight layout:
            "GEMM" = general matrix multiply, good for batch_size >= 2
            "GEMV" = general matrix-vector, best for single-user batch_size=1
            "Marlin" = near-peak bandwidth up to moderate batch sizes;
                       requires zero_point=False
        n_calibration_samples: 128 standard; 256 gives more stable activation statistics
        seq_length: Calibration sequence length; match your typical inference context
        calib_data: "pileval" (default) or a Hugging Face dataset name
            (read with split="train", text_column="text")
    """
    print(f"Starting AWQ quantization of {model_name}")
    print(f"Config: {n_bits}-bit, group_size={group_size}, "
          f"zero_point={zero_point}, version={version}")

    # Load model and tokenizer in FP16 for quantization
    model = AutoAWQForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        use_cache=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # AWQ quantization configuration
    quant_config = {
        "zero_point": zero_point,
        "q_group_size": group_size,
        "w_bit": n_bits,
        "version": version,
    }

    print("Running AWQ scale search and quantization...")
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        calib_data=calib_data,
        split="train",
        text_column="text",
        # Calibration-size arguments of recent AutoAWQ releases (0.2.x)
        max_calib_samples=n_calibration_samples,
        max_calib_seq_len=seq_length,
    )

    # Save in safetensors format (safer than pickle-based .bin files)
    model.save_quantized(output_path, safetensors=True)
    tokenizer.save_pretrained(output_path)

    # Report quantized model size
    total_size_gb = sum(
        os.path.getsize(os.path.join(output_path, f))
        for f in os.listdir(output_path)
        if f.endswith((".safetensors", ".bin"))
    ) / 1e9
    print("\nQuantization complete!")
    print(f"  Saved to: {output_path}")
    print(f"  Model files size: {total_size_gb:.2f} GB")


def quantize_awq_with_custom_calibration(
    model_name: str,
    output_path: str,
    calibration_texts: List[str],
    n_bits: int = 4,
    group_size: int = 128,
    version: str = "GEMM",
    max_seq_len: int = 512,
) -> None:
    """
    AWQ quantization with domain-specific calibration data.

    Use this when default Pile calibration causes accuracy degradation
    due to domain mismatch - e.g., medical, legal, code, or math models.

    The calibration texts teach AWQ which weight channels are salient
    in your deployment domain. A mismatch means the wrong 1% of weights
    are "protected", which can cost extra accuracy on in-domain inputs.

    Args:
        calibration_texts: 128-256 texts from your deployment domain.
            Texts are truncated to max_seq_len tokens; very short ones are skipped.
            Use actual user inputs if available (anonymized).
    """
    model = AutoAWQForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    quant_config = {
        "zero_point": True,
        "q_group_size": group_size,
        "w_bit": n_bits,
        "version": version,
    }

    # AutoAWQ accepts calib_data as a list of strings or a list of token-id
    # lists (not tensors). Pre-tokenize so samples can be truncated and filtered.
    tokenized_data: List[List[int]] = []
    for text in calibration_texts[:256]:
        ids = tokenizer(text, max_length=max_seq_len, truncation=True)["input_ids"]
        if len(ids) >= 64:  # skip very short samples
            tokenized_data.append(ids)

    print(f"Calibrating with {len(tokenized_data)} domain-specific samples")
    model.quantize(
        tokenizer,
        quant_config=quant_config,
        calib_data=tokenized_data,
        max_calib_samples=len(tokenized_data),
        max_calib_seq_len=max_seq_len,
    )
    model.save_quantized(output_path, safetensors=True)
    tokenizer.save_pretrained(output_path)
    print(f"Domain-calibrated AWQ model saved to {output_path}")


def load_and_run_awq(
    model_path: str,
    prompt: str,
    max_new_tokens: int = 512,
    fuse_layers: bool = True,
) -> str:
    """
    Load an AWQ-quantized model and run greedy generation.

    The kernel version (GEMM/GEMV/Marlin) is read from the checkpoint's
    quantization config, so it is not a load-time option.

    Args:
        fuse_layers: Fuse attention QKV, MLP gate+up, and LayerNorm operations.
            Reduces kernel launches and HBM round-trips for intermediate activations.
            Typical speedup: 20-40% on throughput benchmarks.
            Recommended: True for all production inference.
    """
    model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=fuse_layers)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # The AWQ wrapper is a plain nn.Module without a .device attribute;
    # use the device of its first parameter (the embedding table)
    device = next(model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
        )

    # Decode only the newly generated tokens
    generated = tokenizer.decode(
        output[0][inputs.input_ids.shape[1]:],
        skip_special_tokens=True,
    )
    return generated
```
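The `group_size` trade-off in the docstrings above is easy to quantify. Here is a back-of-the-envelope helper. It assumes one FP16 scale and one FP16 zero-point per group, the same accounting used in the memory estimator later in this section. Checkpoints that store zero-points as packed 4-bit values come in slightly lower.

```python
def effective_bits_per_weight(n_bits: int = 4, group_size: int = 128,
                              scale_bits: int = 16, zero_bits: int = 16) -> float:
    """Average storage per weight, including the per-group scale and zero-point."""
    return n_bits + (scale_bits + zero_bits) / group_size


for g in (32, 64, 128):
    bpw = effective_bits_per_weight(group_size=g)
    gb_7b = 7e9 * bpw / 8 / 1e9  # assumes every one of 7B params is quantized
    print(f"group_size={g:>3}: {bpw:.3f} bits/weight, ~{gb_7b:.2f} GB for 7B params")
```

Going from group size 128 to 64 doubles the metadata (0.25 to 0.5 extra bits per weight) but grows total weight memory by only about 6%.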

AWQ Inference Kernels: GEMM vs GEMV vs Marlin

AWQ's inference throughput advantage over GPTQ comes from purpose-built GPU kernels. Understanding the kernel variants lets you configure correctly for your serving pattern:

GEMV (General Matrix-Vector Multiply): During autoregressive decoding at batch size 1, you process one token at a time - a vector, not a matrix. GEMV kernels are optimized for this: they avoid the overhead of setting up large GEMM tiles when the "batch" is just a single vector. Result: lowest latency for single-user serving.

GEMM (General Matrix Multiply): When serving multiple concurrent users batched together, GEMM uses NVIDIA Tensor Cores efficiently for matrix-matrix operations. At batch size 4+, GEMM is typically 20-30% faster than GEMV.

Marlin: A mixed-precision INT4×FP16 kernel from Frantar and Alistarh (2024) that handles both the matrix-vector and matrix-matrix cases efficiently. It streams packed INT4 weights and overlaps dequantization with computation, which keeps memory bandwidth utilization near peak up to moderate batch sizes (roughly 16-32 tokens). Marlin is the recommended kernel for production serving systems where batch size varies. Note that AutoAWQ's Marlin version requires symmetric quantization (`zero_point=False`).
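Which kernel wins depends on arithmetic intensity: the FLOPs performed per byte moved from HBM. Here is a rough sketch. It assumes a single 4096×4096 INT4 layer, FP16 activations, and A100-class peak figures (about 312 TFLOP/s dense FP16 Tensor Core and about 2 TB/s HBM). The numbers are illustrative, not measured, and dequantization cost is ignored.

```python
def arithmetic_intensity(batch: int, d_in: int = 4096, d_out: int = 4096,
                         weight_bits: int = 4) -> float:
    """FLOPs per HBM byte for y = x @ W.T with quantized W and FP16 x, y."""
    flops = 2 * batch * d_in * d_out              # one multiply-add per weight per token
    weight_bytes = d_in * d_out * weight_bits / 8
    act_bytes = 2 * batch * (d_in + d_out)        # FP16 inputs and outputs
    return flops / (weight_bytes + act_bytes)


PEAK_FLOPS = 312e12   # assumed FP16 Tensor Core peak, FLOP/s
PEAK_BW = 2.0e12      # assumed HBM bandwidth, bytes/s
ridge = PEAK_FLOPS / PEAK_BW  # intensity above which compute, not bandwidth, is the limit

for b in (1, 4, 16, 64):
    ai = arithmetic_intensity(b)
    regime = "memory-bound" if ai < ridge else "compute-bound"
    print(f"batch={b:>2}: {ai:6.1f} FLOP/byte ({regime}, ridge ~{ridge:.0f})")
```

At batch 1 the layer performs about 4 FLOPs per byte, far below the ridge, so the kernel's only job is to stream INT4 weights as fast as possible. Tensor Core throughput becomes the bottleneck only at much larger batches.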

```python
import time
from typing import Dict, List, Optional

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def benchmark_awq_kernels(
    model_paths: Dict[str, str],
    prompt: str = "Explain the transformer architecture and its key components in detail:",
    batch_sizes: Optional[List[int]] = None,
    n_new_tokens: int = 100,
    n_warmup: int = 5,
    n_runs: int = 20,
) -> dict:
    """
    Benchmark AWQ GEMM vs GEMV vs Marlin kernels across batch sizes.

    The kernel is fixed when a model is quantized (it determines the packed
    weight layout), so each kernel needs its own checkpoint, e.g.
        {"GEMM": "./model-awq-gemm", "GEMV": "./model-awq-gemv"}

    Expected outcome pattern:
      - batch_size=1: GEMV often wins (lower overhead)
      - batch_size=4+: GEMM or Marlin wins (Tensor Core utilization)
      - All sizes: Marlin competitive with the winner at each batch size

    Use this to select the kernel version for your deployment.
    """
    batch_sizes = batch_sizes or [1, 4, 8, 16, 32]
    results = {}

    for version, model_path in model_paths.items():
        print(f"\nBenchmarking kernel: {version}")

        # Fused layers preallocate their KV cache for a fixed batch size,
        # so load with the largest batch size that will be benchmarked
        model = AutoAWQForCausalLM.from_quantized(
            model_path,
            fuse_layers=True,
            batch_size=max(batch_sizes),
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"  # decoder-only models pad on the left
        device = next(model.parameters()).device

        version_results = {}
        for batch_size in batch_sizes:
            prompts = [prompt] * batch_size
            inputs = tokenizer(prompts, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Warmup (first runs include one-time allocation and setup costs)
            for _ in range(n_warmup):
                with torch.no_grad():
                    model.generate(**inputs, max_new_tokens=10, do_sample=False)

            # Timed runs with synchronization for accurate GPU measurement.
            # min_new_tokens stops early EOS from inflating tokens/sec.
            latencies = []
            for _ in range(n_runs):
                torch.cuda.synchronize()
                t = time.perf_counter()
                with torch.no_grad():
                    model.generate(
                        **inputs,
                        max_new_tokens=n_new_tokens,
                        min_new_tokens=n_new_tokens,
                        do_sample=False,
                    )
                torch.cuda.synchronize()
                latencies.append(time.perf_counter() - t)

            latencies.sort()
            mean_latency = sum(latencies) / len(latencies)
            p50_latency = latencies[len(latencies) // 2]
            p90_latency = latencies[int(len(latencies) * 0.9)]
            tps = batch_size * n_new_tokens / mean_latency

            version_results[batch_size] = {
                "mean_latency_ms": round(mean_latency * 1000, 1),
                "p50_latency_ms": round(p50_latency * 1000, 1),
                "p90_latency_ms": round(p90_latency * 1000, 1),
                "tokens_per_second": round(tps, 1),
            }
            print(
                f"  batch={batch_size}: "
                f"p50={p50_latency * 1000:.1f}ms, "
                f"p90={p90_latency * 1000:.1f}ms, "
                f"{tps:.0f} tok/s"
            )

        results[version] = version_results
        del model
        torch.cuda.empty_cache()

    # Print comparison summary
    kernel_versions = list(results)
    print(f"\n{'=' * 60}")
    print("Kernel comparison (winner per batch size):")
    for bs in batch_sizes:
        scores = {v: results[v][bs]["tokens_per_second"] for v in kernel_versions}
        winner = max(scores, key=scores.get)
        best = scores[winner]
        worst = min(scores.values())
        advantage = best / worst if worst > 0 else 1.0
        tps_str = ", ".join(f"{v}={scores[v]}" for v in kernel_versions)
        print(f"  batch={bs}: {winner} wins ({tps_str}) - {advantage:.2f}x advantage")

    return results
```

Layer Fusion: 20-40% Free Throughput

AWQ supports layer fusion: combining multiple sequential GPU kernels into a single kernel. This cuts kernel-launch overhead and keeps intermediate results in on-chip memory (registers, shared memory, caches) instead of writing them to and reading them back from HBM (the main GPU memory).

Transformer inference involves many sequential operations that are normally separate kernels:

```text
Without fusion (three projection kernels on the same input):
  read x → Q projection → write Q
  read x → K projection → write K
  read x → V projection → write V

With QKV fusion (one kernel over the stacked [Wq; Wk; Wv] weight):
  read x once → QKV projection → write QKV
  Each weight matrix is still read once either way; the savings are two
  kernel launches, two re-reads of x, and one larger, better-utilized GEMM.

MLP gate+up fusion (Llama / Mistral architecture):
  Without: gate_proj kernel → HBM → up_proj kernel → HBM → SiLU(gate) * up
  With:    one gate+up kernel → SiLU(gate) * up

Where the gain comes from: at small batch sizes a decoder block is a long
chain of short kernels, so launch overhead and intermediate activation
round-trips are a large share of step time. Fusion trims both; the weight
bytes read per token do not change.
```
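Fusion rests on a simple identity: stacking the Q, K and V weight matrices row-wise and running one matmul gives exactly the three separate projections. The same holds for gate and up in a SwiGLU MLP. Below is a NumPy sketch of that identity. The shapes are arbitrary assumptions, and real fused kernels also fold in dequantization and other steps.

```python
import numpy as np

rng = np.random.default_rng(5)
d_model, d_q, d_kv, d_ff, n_tokens = 64, 64, 16, 128, 3

x = rng.normal(size=(n_tokens, d_model))
Wq = rng.normal(size=(d_q, d_model))
Wk = rng.normal(size=(d_kv, d_model))
Wv = rng.normal(size=(d_kv, d_model))

# Unfused: three matmuls, each re-reading x
q, k, v = x @ Wq.T, x @ Wk.T, x @ Wv.T

# Fused: one matmul against the stacked weight, then split the output columns
W_qkv = np.concatenate([Wq, Wk, Wv], axis=0)
q_f, k_f, v_f = np.split(x @ W_qkv.T, [d_q, d_q + d_kv], axis=1)

# Gate+up fusion for a SwiGLU MLP works the same way
W_gate = rng.normal(size=(d_ff, d_model))
W_up = rng.normal(size=(d_ff, d_model))


def silu(z: np.ndarray) -> np.ndarray:
    return z / (1.0 + np.exp(-z))


h_unfused = silu(x @ W_gate.T) * (x @ W_up.T)
gate, up = np.split(x @ np.concatenate([W_gate, W_up], axis=0).T, 2, axis=1)
h_fused = silu(gate) * up
```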
```python
import math

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def load_awq_with_all_optimizations(model_path: str) -> tuple:
    """
    Load AWQ model with all available throughput optimizations.

    Fusion types enabled with fuse_layers=True (Llama/Mistral-style models):
      - Attention QKV fusion: single kernel computes Q, K, V projections
      - MLP gate+up fusion: single kernel for the gate and up projections
      - RMSNorm: replaced by a fused RMSNorm kernel

    The kernel version (GEMM/GEMV/Marlin) comes from the checkpoint's
    quantization config.

    Expected throughput improvement over unfused model: 20-40%
    Accuracy impact is negligible: fusion reorganizes the same arithmetic,
    though a different floating-point accumulation order can change the last bits.
    """
    model = AutoAWQForCausalLM.from_quantized(
        model_path,
        fuse_layers=True,  # Enable all available fusions
    )

    # Audit which modules were fused (fused module classes contain "Fused")
    fused_modules = {}
    for name, module in model.named_modules():
        mtype = type(module).__name__
        if "fused" in mtype.lower():
            fused_modules[name] = mtype

    print(f"Fused modules: {len(fused_modules)} total")
    for name, mtype in list(fused_modules.items())[:5]:
        print(f"  {name}: {mtype}")
    if len(fused_modules) > 5:
        print(f"  ... and {len(fused_modules) - 5} more")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer


def estimate_awq_memory_requirements(
    model_params_b: float,
    batch_size: int = 1,
    prompt_length: int = 512,
    n_new_tokens: int = 256,
    n_bits: int = 4,
    group_size: int = 128,
) -> dict:
    """
    Estimate total GPU memory needed for AWQ model inference.

    Memory breakdown:
      1. Weight memory: params × (n_bits/8) bytes (treats every parameter as
         quantized; embeddings and lm_head usually stay FP16, so this is a lower bound)
      2. Scale/zero-point overhead: 2 FP16 values per group (4 bytes per group_size weights)
      3. KV cache: 2 × n_layers × n_kv_heads × head_dim × seq_len × batch × 2 bytes
      4. Activation memory: batch × seq_len × hidden_dim × 2 bytes
      5. Framework overhead: ~1 GB (CUDA context, PyTorch state)
    """
    n_params = model_params_b * 1e9

    # Weight memory (INT4 = 0.5 bytes/param)
    weight_gb = n_params * n_bits / 8 / 1e9

    # Scale and zero-point overhead (2 FP16 values per group)
    scale_overhead_gb = n_params * 2 * 2 / group_size / 1e9

    # Architecture-based KV cache estimation (Llama-family approximation)
    n_layers_approx = max(16, int(32 * (model_params_b / 7) ** 0.4))
    n_kv_heads = 8  # grouped-query attention, as in Llama-3
    head_dim = 128
    total_seq = prompt_length + n_new_tokens

    kv_cache_bytes = (
        2  # K and V
        * n_layers_approx
        * n_kv_heads
        * head_dim
        * total_seq
        * batch_size
        * 2  # FP16/BF16 = 2 bytes
    )
    kv_cache_gb = kv_cache_bytes / 1e9

    # Activation memory approximation: fixed hidden size of 8192 (70B scale;
    # an overestimate for 7-8B models, which use 4096)
    hidden_dim_approx = n_kv_heads * head_dim * 8
    activation_gb = batch_size * total_seq * hidden_dim_approx * 2 / 1e9

    framework_gb = 1.0
    total_gb = weight_gb + scale_overhead_gb + kv_cache_gb + activation_gb + framework_gb

    def gpu_recommendation(gb: float) -> str:
        if gb <= 8:
            return "RTX 4070 Ti / 3080 12GB"
        elif gb <= 16:
            return "RTX 4080 16GB"
        elif gb <= 24:
            return "RTX 4090 24GB"
        elif gb <= 40:
            return "A100 40GB"
        elif gb <= 80:
            return "A100 80GB / H100 80GB"
        else:
            return f"Multi-GPU: {math.ceil(gb / 80)}× A100/H100 80GB"

    return {
        "model_params_b": model_params_b,
        "n_bits": n_bits,
        "batch_size": batch_size,
        "weight_memory_gb": round(weight_gb, 2),
        "scale_overhead_gb": round(scale_overhead_gb, 3),
        "kv_cache_gb": round(kv_cache_gb, 3),
        "activation_gb": round(activation_gb, 3),
        "framework_gb": framework_gb,
        "total_estimated_gb": round(total_gb, 2),
        "gpu_recommendation": gpu_recommendation(total_gb),
    }
```
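To sanity-check the KV-cache term, here is the docstring formula evaluated for a Llama-3-8B-shaped model: 32 layers, 8 KV heads of dimension 128, and an FP16 cache. That configuration is the published one. The batch sizes and context lengths are arbitrary examples.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int, bytes_per_elem: int = 2) -> int:
    """2 (K and V) x layers x kv_heads x head_dim x tokens x batch x element size."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


per_token = kv_cache_bytes(32, 8, 128, seq_len=1, batch=1)
print(f"{per_token / 1024:.0f} KiB per token")

for batch, seq in ((1, 768), (8, 4096), (32, 8192)):
    gb = kv_cache_bytes(32, 8, 128, seq, batch) / 1e9
    print(f"batch={batch:>2}, seq={seq:>4}: {gb:.2f} GB")
```

At 128 KiB per token, 32 concurrent 8K-token sequences need about 34 GB of cache. That dwarfs a few GB of INT4 weights, which is why serving capacity at high concurrency is usually set by the KV cache rather than by weight precision.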

AWQ Accuracy Evaluation: How to Measure What Matters

After quantizing with AWQ, always evaluate accuracy before deploying. Perplexity alone is insufficient - it correlates loosely with task accuracy but does not capture task-specific failure modes.

```python
import math
import time
from typing import Dict, List

import torch
from awq import AutoAWQForCausalLM
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


def evaluate_awq_vs_fp16(
    awq_model_path: str,
    fp16_model_name: str,
    evaluation_prompts: List[str],
    expected_outputs: List[str],
    max_new_tokens: int = 128,
) -> Dict[str, float]:
    """
    Compare AWQ INT4 model accuracy against FP16 baseline on your task.

    This is the production validation step before deploying a quantized model.
    Always run this comparison - perplexity numbers can look fine while task
    accuracy has meaningful degradation.

    Args:
        awq_model_path: Path to quantized AWQ model
        fp16_model_name: Original FP16 model for comparison
        evaluation_prompts: Task-specific prompts (use your real deployment inputs)
        expected_outputs: Ground truth or reference outputs for scoring

    Returns dict with:
        awq_accuracy: Fraction of outputs containing the expected answer
        fp16_accuracy: Same for FP16 baseline
        relative_retention: awq_accuracy / fp16_accuracy (1.0 = no degradation)
        awq_throughput_tps: Tokens per second for AWQ
        fp16_throughput_tps: Tokens per second for FP16
        throughput_speedup: awq_throughput_tps / fp16_throughput_tps
    """
    results = {}

    for model_label, model_path, is_awq in [
        ("AWQ INT4", awq_model_path, True),
        ("FP16 baseline", fp16_model_name, False),
    ]:
        print(f"\nEvaluating: {model_label}")

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if is_awq:
            model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_path, torch_dtype=torch.float16, device_map="auto"
            )

        model.eval()
        # Works for both the AWQ wrapper and a plain Hugging Face model
        device = next(model.parameters()).device
        correct = 0
        total_tokens = 0
        total_time = 0.0

        for prompt, expected in zip(evaluation_prompts, expected_outputs):
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            torch.cuda.synchronize()
            t0 = time.perf_counter()
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                )
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - t0

            new_tokens = output_ids[0][inputs.input_ids.shape[1]:]
            generated = tokenizer.decode(new_tokens, skip_special_tokens=True)
            total_tokens += new_tokens.shape[0]  # tokens actually generated (EOS can stop early)
            total_time += elapsed

            # Simple containment check (replace with a task-specific metric)
            if expected.lower().strip() in generated.lower():
                correct += 1

        accuracy = correct / len(evaluation_prompts)
        throughput = total_tokens / total_time

        results[model_label] = {
            "accuracy": accuracy,
            "throughput_tps": throughput,
            "n_examples": len(evaluation_prompts),
        }
        print(f"  Accuracy: {accuracy:.3f}")
        print(f"  Throughput: {throughput:.1f} tok/s")

        del model
        torch.cuda.empty_cache()

    awq_acc = results["AWQ INT4"]["accuracy"]
    fp16_acc = results["FP16 baseline"]["accuracy"]
    retention = awq_acc / fp16_acc if fp16_acc > 0 else 0.0

    return {
        "awq_accuracy": awq_acc,
        "fp16_accuracy": fp16_acc,
        "relative_retention": retention,
        "awq_throughput_tps": results["AWQ INT4"]["throughput_tps"],
        "fp16_throughput_tps": results["FP16 baseline"]["throughput_tps"],
        "throughput_speedup": (
            results["AWQ INT4"]["throughput_tps"] / results["FP16 baseline"]["throughput_tps"]
        ),
    }


def compute_perplexity_awq(
    model_path: str,
    dataset_name: str = "wikitext",
    dataset_config: str = "wikitext-2-raw-v1",
    n_samples: int = 50,
    max_length: int = 2048,
    stride: int = 512,
) -> float:
    """
    Compute sliding-window perplexity for an AWQ model.

    Each window spans up to max_length tokens. Only the tokens not scored by
    an earlier window (the last `stride` tokens, or all of them in the first
    window) contribute to the loss, so every token is scored exactly once.

    WikiText-2 perplexity reference values (approximate; lower is better):
        FP16 Llama-3.1-8B:   ~6.2
        AWQ INT4 group128:   ~6.5-6.8 (5-10% higher PPL)
        GPTQ INT4 group128:  ~6.6-6.9
        Naive INT4:          ~9-15 (catastrophic)

    Use perplexity as a quick sanity check.
    Always follow with task-specific accuracy evaluation.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Load unfused: fused modules are optimized for generate(), while this
    # loop needs plain forward passes with labels
    model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
    model.eval()
    device = next(model.parameters()).device

    dataset = load_dataset(dataset_name, dataset_config, split="test")
    text = "\n\n".join(dataset["text"])
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    seq_len = input_ids.shape[1]

    nll_sum = 0.0
    n_scored = 0
    prev_end = 0
    for begin_loc in range(0, min(seq_len, n_samples * stride), stride):
        end_loc = min(begin_loc + max_length, seq_len)
        target_len = end_loc - prev_end  # tokens not yet scored

        chunk = input_ids[:, begin_loc:end_loc]
        target_ids = chunk.clone()
        target_ids[:, :-target_len] = -100  # context-only prefix is not scored

        with torch.no_grad():
            loss = model(chunk, labels=target_ids).loss.float().item()

        # The loss is a mean over scored tokens; weight windows by token count
        nll_sum += loss * target_len
        n_scored += target_len
        prev_end = end_loc
        if end_loc == seq_len:
            break

    ppl = math.exp(nll_sum / n_scored)
    return ppl
```

AWQ vs. GPTQ: Full Technical Comparison

| Metric | GPTQ INT4 | AWQ INT4 |
| --- | --- | --- |
| WikiText-2 PPL (Llama-2-70B) | 3.47 | 3.40 |
| MMLU accuracy drop (typical) | -1.9% | -1.0% |
| GSM8K (math) accuracy drop | -6.2% | -4.8% |
| Throughput at batch=1 (tok/s) | ~26 | ~30 |
| Throughput at batch=8 (tok/s) | ~108 | ~134 |
| Quantization time (7B, A100) | 30-60 min | 30-60 min |
| Quantization time (70B, A100) | ~4 hours | ~3.5 hours |
| Calibration data required | Yes | Yes |
| Best inference kernel | ExLlama v2 / Triton | Marlin / GEMM |
| 3-bit support | Yes (good accuracy) | Not recommended |
| CPU inference (GGUF) | Yes | No |
| Column reordering overhead | With desc_act=True | Not needed |
| vLLM integration | Full | Full |

Deploying AWQ with vLLM

For high-throughput production serving, vLLM provides the best AWQ integration:

```python
import asyncio

from vllm import LLM, SamplingParams


def create_awq_vllm_engine(
    model_path: str,
    tensor_parallel_size: int = 1,
    max_model_len: int = 8192,
    gpu_memory_utilization: float = 0.90,
    enable_prefix_caching: bool = True,
) -> LLM:
    """
    Initialize a vLLM engine with an AWQ model for production serving.

    vLLM provides three key benefits over direct AutoAWQ inference:
    1. Continuous batching: new sequences join the running batch as soon as GPU capacity frees
    2. Paged attention: KV cache in non-contiguous pages -> more concurrency
    3. Integrated AWQ kernels, including Marlin kernels for near-peak bandwidth

    Args:
        tensor_parallel_size: Shard across multiple GPUs (for large models).
        gpu_memory_utilization: Fraction of GPU memory vLLM may claim for weights,
            activations, and KV cache. 0.90 leaves 10% for other overhead.
        enable_prefix_caching: Reuse KV cache for shared prompt prefixes.
            Essential if all requests share a system prompt.
    """
    return LLM(
        model=model_path,
        # AWQ checkpoint. Recent vLLM versions also accept "awq_marlin" (and pick it
        # automatically when quantization is omitted) to use the Marlin kernels.
        quantization="awq",
        dtype="float16",
        tensor_parallel_size=tensor_parallel_size,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enable_prefix_caching=enable_prefix_caching,
        max_num_seqs=512,  # Max concurrent sequences
    )


def batch_inference_awq(
    engine: LLM,
    prompts: list[str],
    max_tokens: int = 512,
    temperature: float = 0.0,
) -> list[str]:
    """
    High-throughput batch inference with an AWQ model via vLLM.
    vLLM schedules all prompts in the list together with continuous batching.
    """
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=0.9 if temperature > 0 else 1.0,
    )
    outputs = engine.generate(prompts, sampling_params)
    return [out.outputs[0].text for out in outputs]


def awq_serving_fastapi_example(model_path: str):
    """
    Minimal FastAPI endpoint for AWQ model serving; run the returned app with an
    ASGI server such as uvicorn. For continuous batching across concurrent HTTP
    requests, prefer vLLM's OpenAI-compatible server, which wraps the async engine.
    """
    from fastapi import FastAPI
    from pydantic import BaseModel

    app = FastAPI(title="AWQ Model Serving API")
    # Load the model once at startup, never per request.
    engine = create_awq_vllm_engine(model_path)
    # The offline LLM class is not meant for concurrent calls, so serialize access.
    engine_lock = asyncio.Lock()

    class GenerateRequest(BaseModel):
        prompt: str
        max_tokens: int = 512
        temperature: float = 0.0

    class GenerateResponse(BaseModel):
        text: str
        tokens_generated: int
        model_path: str

    @app.post("/generate", response_model=GenerateResponse)
    async def generate(request: GenerateRequest) -> GenerateResponse:
        sampling_params = SamplingParams(
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=0.9 if request.temperature > 0 else 1.0,
        )
        async with engine_lock:
            # Run the blocking generate() call off the event-loop thread.
            outputs = await asyncio.to_thread(
                engine.generate, [request.prompt], sampling_params
            )
        completion = outputs[0].outputs[0]
        return GenerateResponse(
            text=completion.text,
            tokens_generated=len(completion.token_ids),  # real token count, not words
            model_path=model_path,
        )

    return app
```

Production Deployment Checklist

:::danger Do Not Quantize Embeddings or the LM Head The token embedding table and language model head (output projection) have fundamentally different numerical properties from transformer body layers. The embedding table maps discrete token IDs to continuous vectors - quantizing it loses fine-grained distinctions between similar tokens. The LM head must produce precise logits to correctly rank the next token. Both AWQ and GPTQ skip these layers by default. If implementing custom quantization, explicitly exclude embed_tokens, lm_head, and all layer normalization layers from quantization. :::
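For a custom pipeline, the exclusion rule can be expressed as a name filter. The sketch below is minimal and assumes Llama-style module names (`embed_tokens`, `lm_head`, `*_layernorm`, `norm`). `modules_to_quantize` and `EXCLUDE_PATTERNS` are hypothetical names, not part of AutoAWQ or any GPTQ tooling. In PyTorch the names would come from `model.named_modules()`. Restricting that list to `nn.Linear` already drops embeddings and norms, but `lm_head` is itself a Linear layer, so it must be excluded by name.

```python
import re

# Hypothetical exclusion patterns, assuming Llama-style module names.
EXCLUDE_PATTERNS = (
    re.compile(r"(^|\.)embed_tokens$"),  # token embedding table
    re.compile(r"(^|\.)lm_head$"),       # output projection that produces logits
    re.compile(r"norm$"),                # input_layernorm, post_attention_layernorm, final norm
)


def modules_to_quantize(module_names: list[str]) -> list[str]:
    """Keep only module names that match none of the exclusion patterns."""
    return [
        name for name in module_names
        if not any(p.search(name) for p in EXCLUDE_PATTERNS)
    ]


names = [
    "model.embed_tokens",
    "model.layers.0.input_layernorm",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
    "model.norm",
    "lm_head",
]
print(modules_to_quantize(names))
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.mlp.down_proj']
```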

:::warning Domain-Specific Models Need Domain-Specific Calibration AWQ's activation importance statistics are only as good as the calibration data. For a model fine-tuned on medical Q&A, using Pile or C4 calibration data will misidentify which weight columns are salient for medical language: the activation patterns the model relies on for medical text are underrepresented in general-text calibration statistics. Result: the 1% of weights that matter most for medical accuracy are not protected. Always use calibration data that matches your deployment distribution - this applies equally to code, math, legal, and other domain-specialized models. :::
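The failure mode can be seen in a purely synthetic NumPy toy. It assumes that general text and domain text drive different input channels hard, and that saliency is mean |activation| per channel. The arrays are stand-ins, not real model activations, and `top_channels` is a hypothetical helper. The channels a general-text calibration set would protect do not overlap with the ones deployment traffic actually stresses.

```python
import numpy as np


def top_channels(activations: np.ndarray, k: int) -> set[int]:
    """Indices of the k input channels with the largest mean |activation|."""
    importance = np.abs(activations).mean(axis=0)
    return set(np.argsort(importance)[-k:].tolist())


rng = np.random.default_rng(0)
n_tokens, n_channels, k = 1000, 256, 4

# Synthetic stand-ins: general text drives channels 0-3, domain text drives channels 200-203.
general_calib = rng.normal(size=(n_tokens, n_channels))
general_calib[:, 0:4] *= 20.0
domain_traffic = rng.normal(size=(n_tokens, n_channels))
domain_traffic[:, 200:204] *= 20.0

protected = top_channels(general_calib, k)  # what calibration on general text would protect
needed = top_channels(domain_traffic, k)    # what deployment traffic actually stresses
print(sorted(protected), sorted(needed), len(protected & needed))
# [0, 1, 2, 3] [200, 201, 202, 203] 0
```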

:::tip Fine-Tune First, Then Quantize If you need a fine-tuned, compressed model: (1) Fine-tune the base model using full fine-tuning or LoRA. (2) Quantize the fine-tuned model with AWQ using fine-tune-domain calibration data. Do NOT quantize first and then fine-tune - INT4 quantized weights cannot be directly fine-tuned with standard backpropagation. If you need post-quantization fine-tuning flexibility, use QLoRA on the NF4 base model, merge the adapter, then run AWQ on the merged FP16 model. :::

:::info AWQ with vLLM for Maximum Throughput For production serving of AWQ models, vLLM provides: continuous batching (requests start as soon as capacity frees), paged attention (efficient KV cache management), and integrated AWQ Marlin kernel support. In recent vLLM versions, pass quantization="awq_marlin" to the vLLM LLM constructor, or omit quantization so vLLM detects AWQ from the checkpoint config, to get the Marlin kernels on supported GPUs. An explicit quantization="awq" forces the non-Marlin AWQ kernels. For multi-GPU serving, also set tensor_parallel_size. vLLM handles INT4-to-FP16 dequantization internally - no extra configuration needed. :::

Interview Questions

Q1: What is AWQ's core insight, and how does it differ from GPTQ's approach?

AWQ's insight is that not all weights are equally important for model accuracy when quantized. Specifically, weight columns corresponding to input dimensions with large activation magnitudes - "salient" weights - cause disproportionately large output errors when quantized, because their errors are amplified by the large input values. The impact of quantizing weight column $j$ scales with $\mathbb{E}[|x_j|]$, the expected activation magnitude for that dimension. AWQ identifies salient columns using only a forward pass on calibration data (mean absolute activation per dimension), then scales those columns up before quantization. After scaling, the quantizer allocates more effective precision to salient weights without storing them at higher bit-width. The scaling is absorbed into model weights during quantization - zero runtime memory or compute overhead. GPTQ's approach is complementary but different: it uses second-order Hessian information to propagate each weight's quantization error to remaining weights. AWQ is simpler (no Hessian inversion), achieves better accuracy at INT4, and produces faster inference through optimized kernels without the column-reordering overhead that complicates GPTQ.
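The saliency statistic described above is simple to compute. The NumPy sketch below is minimal and assumes the calibration inputs to one linear layer have already been captured as a (tokens × input-channels) array. The two outlier channels are synthetic, and `salient_channels` is a hypothetical helper, not AutoAWQ's API.

```python
import numpy as np


def salient_channels(activations: np.ndarray, fraction: float = 0.01) -> np.ndarray:
    """Indices of the input channels with the largest mean |activation|, most salient first.

    activations: calibration inputs to one linear layer, shape (n_tokens, in_features).
    """
    # Per-channel importance: the average magnitude each input dimension carries.
    importance = np.abs(activations).mean(axis=0)
    k = max(1, round(fraction * activations.shape[1]))
    # argsort is ascending: take the last k and reverse so the largest comes first.
    return np.argsort(importance)[-k:][::-1]


rng = np.random.default_rng(0)
acts = rng.normal(size=(512, 400))
acts[:, 7] *= 50.0    # synthetic outlier channel
acts[:, 123] *= 30.0  # a second, weaker outlier channel

top = salient_channels(acts, fraction=0.01)  # 1% of 400 channels -> 4 indices
print(top[:2].tolist())  # [7, 123]
```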

Q2: How does AWQ protect salient weights without storing them at higher precision?

AWQ uses the mathematical identity $W_{:,j} \cdot x_j = (s \cdot W_{:,j}) \cdot (x_j / s)$ for any scalar $s > 1$. The matrix product is unchanged, but the scaled column $s \cdot W_{:,j}$ is larger in magnitude. Within a quantization group, the step size $\Delta$ is set by the group's largest weights, so scaling one column usually leaves $\Delta$ almost unchanged. The relative quantization error of the salient weights then drops from $\Delta / |W_{:,j}|$ to roughly $\Delta / (s \cdot |W_{:,j}|)$. The corresponding input scaling ($x_j \to x_j / s$) is folded into the preceding operator at quantization time: the previous LayerNorm weight, or the output rows of the previous linear layer, are multiplied by $1/s$ for dimension $j$. Nothing is computed at runtime. The inference computation is identical to standard INT4 quantization: dequantize weights to FP16, multiply by input, no special cases. Protection without precision overhead.
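The identity and its effect on error can be checked with a toy NumPy experiment. It assumes a single linear layer, asymmetric round-to-nearest INT4 with one group spanning all 128 input channels, and one synthetic salient input channel. `quantize_rtn_int4` and `output_error` are hypothetical helpers. AWQ itself searches the per-channel scale on calibration data rather than fixing $s = 2$ as done here.

```python
import numpy as np


def quantize_rtn_int4(w: np.ndarray, group_size: int = 128) -> np.ndarray:
    """Asymmetric round-to-nearest INT4 with one step size and zero point per group.

    Groups run along the input dimension of each output row. Returns dequantized
    ("fake-quantized") weights so errors can be measured in floating point.
    """
    out_f, in_f = w.shape
    g = w.reshape(out_f, in_f // group_size, group_size)
    w_min = g.min(axis=-1, keepdims=True)
    w_max = g.max(axis=-1, keepdims=True)
    delta = np.maximum((w_max - w_min) / 15.0, 1e-8)  # 16 levels: 0..15
    zero = np.round(-w_min / delta)
    q = np.clip(np.round(g / delta) + zero, 0, 15)
    return ((q - zero) * delta).reshape(out_f, in_f)


rng = np.random.default_rng(0)
W = rng.normal(scale=0.02, size=(256, 128))  # (out_features, in_features)
X = rng.normal(size=(64, 128))               # calibration tokens
X[:, 5] *= 40.0                              # input channel 5 is salient
ref = X @ W.T                                # full-precision output


def output_error(s: float, channel: int = 5) -> float:
    """Mean squared output error after scaling one weight column by s before quantization."""
    scales = np.ones(W.shape[1])
    scales[channel] = s
    w_q = quantize_rtn_int4(W * scales)  # column j of W multiplied by s
    x_scaled = X / scales                # x_j / s; in a model this is folded into the previous op
    return float(np.mean((x_scaled @ w_q.T - ref) ** 2))


print(f"s=1: {output_error(1.0):.3e}   s=2: {output_error(2.0):.3e}")
```

Without quantization the scaled and unscaled products agree to floating-point precision. With quantization, the error on the dominant channel shrinks by about $1/s$.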

Q3: When should you use GEMM vs GEMV vs Marlin kernels in AWQ?

GEMV (General Matrix-Vector Multiply) is optimized for single-vector inputs, i.e. batch_size=1 during autoregressive decoding. The activation is a vector, not a matrix, and GEMV kernels minimize latency by avoiding GEMM tile setup overhead. GEMM (General Matrix Multiply) is optimized for batched requests: it uses NVIDIA Tensor Cores efficiently and achieves high throughput at batch_size 4 or more. Marlin is a mixed-precision INT4×FP16 kernel that covers both regimes. It streams INT4 weights and interleaves dequantization with computation, sustaining near-peak memory bandwidth utilization from batch size 1 up to roughly 16-32. Marlin is the recommended choice for production serving systems where batch size varies. In practice: choose GEMV for always-single-user applications (code completion tools), GEMM for consistently batched workloads, and Marlin for production APIs with variable load.
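The rule of thumb above can be captured as a small decision helper. The sketch assumes the workload is described by the minimum and maximum decode batch sizes you expect, and it takes the batch-size-4 threshold from the text. `choose_awq_kernel` and `AWQKernel` are hypothetical names, not an AutoAWQ or vLLM API, and the helper ignores hardware support.

```python
from enum import Enum


class AWQKernel(str, Enum):
    GEMV = "GEMV"
    GEMM = "GEMM"
    MARLIN = "Marlin"


def choose_awq_kernel(min_batch: int, max_batch: int) -> AWQKernel:
    """Map an expected range of decode batch sizes to a kernel family (rule of thumb only)."""
    if min_batch < 1 or max_batch < min_batch:
        raise ValueError("need 1 <= min_batch <= max_batch")
    if max_batch == 1:
        return AWQKernel.GEMV    # always a single sequence, e.g. a local code-completion tool
    if min_batch >= 4:
        return AWQKernel.GEMM    # consistently batched: Tensor Cores stay busy
    return AWQKernel.MARLIN      # load swings between a single request and many


print(choose_awq_kernel(1, 1).value, choose_awq_kernel(8, 32).value, choose_awq_kernel(1, 64).value)
# GEMV GEMM Marlin
```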

Q4: What is layer fusion in AWQ and what speedup does it provide?

Layer fusion combines multiple sequential GPU kernel launches into a single kernel, reducing the number of times intermediate activations must be written to and read from HBM (GPU main memory). In a standard transformer layer without fusion, each operation is a separate kernel: Q projection, K projection, V projection, attention, output projection, gate, up, activation, down. Each launch reads weights from HBM, computes, writes activations back to HBM, and the next kernel reads those activations again. With QKV fusion: read three weight matrices once, write QKV once - instead of three separate read-compute-write cycles. Gate+up fusion similarly combines two projections. Since transformer inference is typically 70-85% memory-bandwidth-bound, reducing HBM traffic improves throughput. Measured speedup: 20-40% on throughput benchmarks. Enable with fuse_layers=True when loading - no accuracy impact, significant performance gain.
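QKV fusion changes how the work is launched, not what is computed. The NumPy sketch below shows this for plain multi-head attention, with weights laid out as (in_features, out_features) so that y = x @ W. The byte count covers only the reads of the input activation x and is an illustrative FP16 estimate, not a profiler measurement. For grouped-query attention, K and V are narrower, so `np.split` would take explicit split points instead of an equal count.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model = 8, 64
x = rng.normal(size=(n_tokens, d_model))
# Weights laid out as (in_features, out_features) so that y = x @ W.
w_q, w_k, w_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Unfused: three matmuls; on a GPU each is a separate kernel launch that reads x again.
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Fused: concatenate the weights once at load time, then a single matmul reads x once.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)  # (d_model, 3 * d_model)
q_f, k_f, v_f = np.split(x @ w_qkv, 3, axis=1)

# Input-activation bytes read (FP16 = 2 bytes per element).
unfused_reads = 3 * n_tokens * d_model * 2
fused_reads = 1 * n_tokens * d_model * 2
same = np.allclose(q, q_f) and np.allclose(k, k_f) and np.allclose(v, v_f)
print(same, unfused_reads // fused_reads)  # True 3
```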

Q5: How do you choose between AWQ and bitsandbytes NF4 for your deployment?

They serve different primary purposes. AWQ is optimized for inference throughput: the quantized model is stored with precomputed scales, and specialized GEMM kernels handle dequantization during the matmul with minimal overhead. AWQ models are faster at inference but cannot be directly fine-tuned. bitsandbytes NF4 is optimized for training flexibility: the quantized base model supports LoRA fine-tuning (QLoRA), cutting the memory needed for the frozen base weights roughly 4x compared with FP16. NF4 dequantizes at runtime in a way that lets gradients flow through the LoRA adapters. For inference-only deployment, AWQ wins on throughput (10-30% faster than NF4). For fine-tuning on consumer hardware, NF4 with QLoRA is the correct choice. For the common pattern of fine-tuning then deploying: fine-tune with NF4+LoRA, merge the adapter into an FP16 copy of the base model, then quantize the merged FP16 model with AWQ for deployment.
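The "roughly 4x" figure follows from simple storage arithmetic. This sketch assumes a hypothetical 8B-parameter model, 4-bit weights in groups of 128, and 20 bits of per-group metadata (an FP16 scale plus a 4-bit zero point). Real formats differ in their metadata layout, and layers kept in higher precision (embeddings, lm_head) are ignored. `weight_memory_gb` is a hypothetical helper.

```python
from typing import Optional


def weight_memory_gb(n_params: float, bits: int,
                     group_size: Optional[int] = None,
                     metadata_bits_per_group: int = 0) -> float:
    """Approximate weight storage in GB (1e9 bytes) for a uniformly quantized model.

    metadata_bits_per_group covers per-group scales/zero points; layers kept in
    higher precision (embeddings, lm_head) are ignored here.
    """
    total_bits = n_params * bits
    if group_size is not None:
        total_bits += (n_params / group_size) * metadata_bits_per_group
    return total_bits / 8 / 1e9


fp16_gb = weight_memory_gb(8e9, bits=16)
# Assumed metadata: one FP16 scale + one 4-bit zero point per group of 128 weights.
int4_gb = weight_memory_gb(8e9, bits=4, group_size=128, metadata_bits_per_group=20)
print(f"FP16: {fp16_gb:.2f} GB  INT4 g128: {int4_gb:.2f} GB  ratio: {fp16_gb / int4_gb:.2f}x")
# FP16: 16.00 GB  INT4 g128: 4.16 GB  ratio: 3.85x
```

The group metadata is why the effective ratio lands slightly below 4x.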

Q6: Walk through how you would validate that an AWQ model is production-ready before shipping.

Validation has four stages. First, compute perplexity on domain-matched text - not just WikiText-2. If domain perplexity is significantly higher than FP16 perplexity, re-run AWQ with domain-specific calibration data. Second, run task-specific accuracy benchmarks: the specific capabilities your product uses. Compare AWQ INT4 to FP16 on 500+ examples. Target less than 2-3% relative drop for non-critical capabilities, less than 1% for accuracy-critical features. Third, test at production sequence lengths - quantization errors compound over long contexts. Run evaluation at 512, 2K, 4K, 8K tokens if your system uses long contexts. Fourth, load-test: run the AWQ model through vLLM at your expected concurrent request count and measure P50, P90, P99 latency and throughput. Compare to SLA requirements. Only after all four stages pass should you promote to production. The most common failure mode is skipping stage 2 and discovering domain accuracy problems after deployment.
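The stage-2 gate can be made mechanical. This sketch assumes per-capability accuracies have already been measured on the same examples for both models. It uses the 1% budget from the text for accuracy-critical features and the upper end of the 2-3% guidance for everything else. The names and numbers are hypothetical illustrations, not measurements.

```python
from dataclasses import dataclass


@dataclass
class CapabilityResult:
    name: str
    fp16_accuracy: float
    awq_accuracy: float
    critical: bool  # accuracy-critical product feature?


def relative_drop(fp16: float, quantized: float) -> float:
    """Relative accuracy drop of the quantized model versus FP16 (0.02 == 2%)."""
    return (fp16 - quantized) / fp16 if fp16 > 0 else 0.0


def failing_capabilities(results: list[CapabilityResult],
                         critical_budget: float = 0.01,
                         default_budget: float = 0.03) -> list[str]:
    """Names of capabilities whose relative drop exceeds their budget."""
    return [
        r.name for r in results
        if relative_drop(r.fp16_accuracy, r.awq_accuracy)
        > (critical_budget if r.critical else default_budget)
    ]


# Hypothetical numbers for illustration, not measurements.
results = [
    CapabilityResult("multi-step reasoning", fp16_accuracy=0.800, awq_accuracy=0.785, critical=True),
    CapabilityResult("summarization", fp16_accuracy=0.900, awq_accuracy=0.880, critical=False),
]
print(failing_capabilities(results))  # ['multi-step reasoning']
```

A non-empty result blocks promotion. The reasoning capability fails here at a 1.9% relative drop, while summarization passes at 2.2% because it has the looser budget.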

© 2026 EngineersOfAI. All rights reserved.