Unstructured Pruning: The Theory is Beautiful, the Practice is Complicated

It's 11 PM on a Thursday. A senior engineer at a startup has just read the SparseGPT paper, in which Frantar and Alistarh show that GPT-family models can be pruned to 50% sparsity in a single pass, without retraining, at minimal loss of accuracy. She opens her terminal, prunes their production Llama-2-13B model to 50% sparsity, and runs inference benchmarks. The perplexity numbers look pristine. The model still answers questions perfectly. She's elated - she's just cut their model size in half.

The next morning she shows the results to the team. Their infrastructure lead pulls up the latency numbers. There's silence. The sparse model is 2% slower than the dense model. Not faster - slower. She has a 50% sparse model that takes up exactly the same GPU memory as the dense model (zeros still occupy space in the weight tensors), runs slightly slower due to overhead from the sparsity mask handling, and required three hours of compute to create. This is the "unstructured pruning trap" - impressive paper results that don't translate to production speedups without specialized hardware.

Unstructured pruning is one of the most theoretically rich and practically treacherous topics in model compression. The Lottery Ticket Hypothesis is beautiful science. SparseGPT and Wanda are clever algorithms. 2:4 structured sparsity on NVIDIA Ampere GPUs is a genuinely useful technique. But most practitioners who pick up "let's try 50% sparsity" end up exactly where that engineer did - with a model that has the same cost profile as before, just with a lot of zeros in it. This lesson will explain the theory properly, show you the real code, and be honest about when unstructured pruning actually helps.

Why Unstructured Pruning Exists: The Overparameterization Hypothesis

Before diving into algorithms, understand the core motivation. Neural networks are heavily overparameterized: after training, a large fraction of a model's weights can usually be removed with little change to its outputs. The questions are why that excess capacity helps, and whether it can be removed after training.

The Lottery Ticket Hypothesis (Frankle and Carbin, 2019) crystallized the scientific case for pruning: large networks exist partly because they make optimization easier, not because the final task needs all of their capacity. The winning sparse subnetwork was always there - the full network just provided the conditions to find it.

The Lottery Ticket Hypothesis: From Theory to Code

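Frankle and Carbin demonstrated a four-step procedure that reveals "winning tickets" - sparse subnetworks that, when trained in isolation from their original initialization, match the full network's test accuracy in at most the same number of training iterations:

  1. Train the full network to convergence from initialization $\theta_0$
  2. Prune the smallest-magnitude weights (e.g., keep the top 10% by absolute value; the paper's strongest results prune iteratively, removing a fraction per round and repeating steps 1-4)
  3. Reset the remaining weights to their original values $\theta_0$ (not to zero - to the initial values)
  4. Retrain this sparse network from $\theta_0$

The result: on the networks studied in the paper (fully connected and convolutional networks on MNIST and CIFAR-10), the sparse network trained from $\theta_0$ matches the full network. The same sparse structure with a fresh random initialization trains to lower accuracy. This is extraordinary - it implies the lottery ticket was set at initialization; training just revealed which ticket was the winner.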

```python
import copy
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def magnitude_threshold(abs_weights: torch.Tensor, fraction: float) -> torch.Tensor:
    """Return the k-th smallest magnitude, with k = fraction * numel.

    Weights strictly above the threshold are kept. kthvalue is used instead of
    torch.quantile, which rejects tensors with more than 2**24 elements.
    """
    flat = abs_weights.reshape(-1).float()
    k = int(fraction * flat.numel())
    if k == 0:
        return flat.new_tensor(-1.0)  # nothing to prune: every |w| > -1
    return flat.kthvalue(k).values


def find_lottery_ticket(
    model: nn.Module,
    initial_weights: dict[str, torch.Tensor],
    pruning_fraction: float = 0.9,
    layer_filter: Optional[Callable[[str], bool]] = None,
) -> tuple[nn.Module, dict[str, torch.Tensor]]:
    """
    One round of the Lottery Ticket procedure (Frankle & Carbin, 2019).

    Args:
        model: The fully trained model
        initial_weights: Saved copy of the parameters from before training began
        pruning_fraction: Fraction of weights to remove (0.9 = keep top 10%)
        layer_filter: Optional function(name) -> bool to select layers to prune

    Returns:
        (ticket_model, masks) - model reset to initial weights with the mask applied

    IMPORTANT: You must save initial_weights BEFORE training starts:
        initial_weights = {name: param.detach().clone()
                           for name, param in model.named_parameters()}
    """
    ticket = copy.deepcopy(model)
    masks = {}

    with torch.no_grad():
        for name, param in ticket.named_parameters():
            if name not in initial_weights:
                continue
            init_weight = initial_weights[name].to(param.device)

            prunable = "weight" in name and (layer_filter is None or layer_filter(name))
            if not prunable:
                # Unpruned tensors (biases, filtered layers) are also rewound to theta_0
                param.copy_(init_weight)
                continue

            # Step 2: magnitude mask computed from the TRAINED weights
            abs_weights = param.abs()
            mask = abs_weights > magnitude_threshold(abs_weights, pruning_fraction)

            # Step 3: rewind surviving weights to their INITIAL values; pruned ones stay zero
            param.copy_(init_weight * mask)

            masks[name] = mask
            sparsity = 1.0 - mask.float().mean().item()
            print(f"  {name}: {sparsity*100:.1f}% sparse, "
                  f"{mask.sum().item():,} weights kept")

    return ticket, masks


class MaskedLinear(nn.Module):
    """
    Linear layer with a fixed binary sparsity mask.

    The mask is registered as a buffer: it is saved in the state_dict and moves
    with .to(device), but it is not a trainable parameter.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.register_buffer(
            "mask", torch.ones(out_features, in_features, dtype=torch.bool)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Masking in the forward pass also zeroes the gradients of pruned weights
        masked_weight = self.linear.weight * self.mask
        return F.linear(x, masked_weight, self.linear.bias)

    def prune_by_magnitude(self, fraction: float) -> None:
        """Zero out the lowest-magnitude `fraction` of the weights."""
        with torch.no_grad():
            abs_weights = self.linear.weight.abs()
            self.mask = abs_weights > magnitude_threshold(abs_weights, fraction)
            self.linear.weight.mul_(self.mask)  # apply the mask immediately

    def enforce_mask(self) -> None:
        """Re-apply the mask after an optimizer step (momentum or weight decay
        can move pruned weights away from zero)."""
        with torch.no_grad():
            self.linear.weight.mul_(self.mask)

    @property
    def sparsity(self) -> float:
        return 1.0 - self.mask.float().mean().item()

    @property
    def trainable_params(self) -> int:
        return int(self.mask.sum().item())
```
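To see all four steps end to end, here is a minimal, self-contained sketch on a toy problem. Everything in it is an illustrative assumption rather than the paper's setup: a hypothetical two-layer MLP, synthetic data from a random linear "teacher", one-shot 80% per-layer pruning, and a short full-batch Adam schedule. The sketch works as follows:

- It keeps a copy of $\theta_0$ and trains the dense model.
- It builds magnitude masks from the trained weights.
- It rewinds every parameter to $\theta_0$ and zeroes the pruned weights.
- It retrains with masked gradients, so pruned weights stay at exactly zero.

```python
import copy
from typing import Optional

import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic regression task (illustrative assumption): a random linear "teacher"
X = torch.randn(512, 20)
y = X @ torch.randn(20, 1)


def make_model() -> nn.Module:
    # Hypothetical toy architecture, not from the paper
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))


def train(model: nn.Module, masks: Optional[dict] = None, steps: int = 300) -> float:
    """Full-batch Adam; masked weights get zero gradients, so Adam never moves them."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    params = dict(model.named_parameters())
    for _ in range(steps):
        opt.zero_grad()
        loss = nn.functional.mse_loss(model(X), y)
        loss.backward()
        for name, mask in (masks or {}).items():
            params[name].grad.mul_(mask)
        opt.step()
    with torch.no_grad():
        return nn.functional.mse_loss(model(X), y).item()


model = make_model()
theta_0 = copy.deepcopy(model.state_dict())  # save the initialization BEFORE training
dense_loss = train(model)                    # Step 1: train the full network

# Step 2: per-layer magnitude masks from the TRAINED weights (keep the top 20%)
masks = {}
for name, p in model.named_parameters():
    if name.endswith("weight"):
        k = int(0.8 * p.numel())
        threshold = p.detach().abs().flatten().kthvalue(k).values
        masks[name] = p.detach().abs() > threshold

# Step 3: rewind every parameter to theta_0, then zero the pruned weights
ticket = make_model()
ticket.load_state_dict(theta_0)
with torch.no_grad():
    for name, p in ticket.named_parameters():
        if name in masks:
            p.mul_(masks[name])

# Step 4: retrain the sparse network from theta_0
ticket_loss = train(ticket, masks)
pruned = sum((~m).sum().item() for m in masks.values())
sparsity = pruned / sum(m.numel() for m in masks.values())
print(f"dense loss {dense_loss:.4f} | ticket loss {ticket_loss:.4f} | sparsity {sparsity:.0%}")
```

On a toy like this, the interesting comparison is to rerun step 4 with the same masks but a fresh `make_model()` initialization. That random-reinitialization control is the one the hypothesis is about.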

:::warning The LTH Doesn't Scale to Billion-Parameter LLMs The original Lottery Ticket Hypothesis results apply to smaller models (ResNets, small transformers) trained from scratch. For billion-parameter LLMs pre-trained on trillions of tokens, the "reset to initialization and retrain" step becomes impractical - you'd need to retrain the entire model from the original random weights, which is the original pre-training cost. Modern research focuses on magnitude pruning with fine-tuning (not retraining from scratch) as the practical analog for LLMs. :::

Magnitude Pruning: The Production Workhorse

When the lottery ticket procedure is too expensive, magnitude-based pruning with subsequent fine-tuning is the practical alternative. Remove the smallest-absolute-value weights globally or per-layer, then fine-tune the remaining weights to recover accuracy.
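PyTorch already ships these operations in `torch.nn.utils.prune`, which is handy for quick experiments before you write hand-rolled versions like the functions that follow. The sketch below uses a small hypothetical MLP:

- `global_unstructured` with `L1Unstructured` applies one magnitude threshold across all listed tensors.
- `l1_unstructured` would do the same per layer.
- `prune.remove` bakes the mask into the weight.

As with every approach in this lesson, the result is a dense tensor full of zeros.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))  # toy model
targets = [(model[0], "weight"), (model[2], "weight")]

# One global magnitude threshold across both layers: zero 50% of all listed weights
prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=0.5)

# While pruned, each module holds `weight_orig` (parameter) and `weight_mask` (buffer),
# and `weight` is recomputed as weight_orig * weight_mask on every forward pass
for module, name in targets:
    prune.remove(module, name)  # make pruning permanent: a plain `weight` parameter again

total = sum(m.weight.numel() for m, _ in targets)
zeros = sum(int((m.weight == 0).sum()) for m, _ in targets)
print(f"global sparsity: {zeros / total:.1%}")
for i, (m, _) in enumerate(targets):
    # Global pruning usually leaves the layers at different sparsities
    print(f"layer {i}: {(m.weight == 0).float().mean().item():.1%} zeros")
```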

```python
from typing import Callable, Optional

import torch
import torch.nn as nn


def magnitude_threshold(abs_weights: torch.Tensor, fraction: float) -> torch.Tensor:
    """k-th smallest magnitude (k = fraction * numel); weights above it are kept.
    kthvalue avoids torch.quantile's 2**24-element input limit."""
    flat = abs_weights.reshape(-1).float()
    k = int(fraction * flat.numel())
    if k == 0:
        return flat.new_tensor(-1.0)  # nothing to prune
    return flat.kthvalue(k).values


def magnitude_pruning_global(
    model: nn.Module,
    target_sparsity: float,
    layer_filter: Optional[Callable[[str], bool]] = None,
    protect_biases: bool = True,
) -> dict[str, torch.Tensor]:
    """
    Global magnitude pruning: find a single threshold across ALL eligible weights.

    Advantage: naturally allocates the pruning budget where weights are smallest,
    regardless of which layer they're in.
    Disadvantage: layers with small-magnitude weights can be over-pruned,
    others under-pruned.

    Args:
        model: Model to prune (modified in place)
        target_sparsity: Fraction of eligible weights to zero out
        layer_filter: Optional function(name) -> bool for layer selection
        protect_biases: Skip bias vectors (small and often important)

    Returns:
        Dictionary mapping parameter names to binary masks (True = kept)
    """
    eligible = {}
    for name, param in model.named_parameters():
        is_bias = "bias" in name
        if is_bias and protect_biases:
            continue
        if not is_bias and "weight" not in name:
            continue
        if layer_filter and not layer_filter(name):
            continue
        eligible[name] = param

    # Single global threshold over all eligible magnitudes
    all_abs = torch.cat([p.detach().abs().reshape(-1).float() for p in eligible.values()])
    threshold = magnitude_threshold(all_abs, target_sparsity)
    print(f"Global threshold: {threshold.item():.6f} "
          f"(targeting {target_sparsity*100:.1f}% sparsity)")

    masks = {}
    total_params = 0
    total_pruned = 0
    with torch.no_grad():
        for name, param in eligible.items():
            mask = param.abs() > threshold.to(param.device)
            param.mul_(mask)
            masks[name] = mask
            total_params += mask.numel()
            total_pruned += int((~mask).sum())

    actual_sparsity = total_pruned / max(total_params, 1)
    print(f"Achieved sparsity: {actual_sparsity*100:.2f}% "
          f"({total_pruned:,}/{total_params:,} weights zeroed)")
    return masks


def magnitude_pruning_per_layer(
    model: nn.Module,
    target_sparsity: float,
    min_sparsity: float = 0.0,
    max_sparsity: float = 0.9,
    skip_layers: Optional[list[str]] = None,
) -> dict[str, torch.Tensor]:
    """
    Per-layer magnitude pruning: apply the target sparsity independently to each layer.

    Advantage: every layer keeps the same fraction of its weights, so no layer is
    pruned away entirely. This often preserves accuracy better than global pruning
    at high sparsity.

    Args:
        target_sparsity: Fraction of each layer's weights to zero out
        min_sparsity, max_sparsity: Bounds the per-layer sparsity is clamped to
        skip_layers: Layer name substrings to protect (e.g., ["embed", "lm_head"])
    """
    if skip_layers is None:
        skip_layers = ["embed_tokens", "lm_head", "norm"]

    # The same clamped sparsity for every layer (replace with a per-layer schedule if needed)
    layer_sparsity = max(min_sparsity, min(max_sparsity, target_sparsity))
    masks = {}

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "weight" not in name:
                continue
            if any(skip in name for skip in skip_layers):
                print(f"  Skipping protected layer: {name}")
                continue

            abs_weights = param.abs()
            mask = abs_weights > magnitude_threshold(abs_weights, layer_sparsity)
            param.mul_(mask)
            masks[name] = mask
            print(f"  {name}: {1.0 - mask.float().mean().item():.1%} sparse")

    print(f"Applied per-layer pruning to {len(masks)} layers")
    return masks


def gradual_magnitude_pruning(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader,
    criterion,
    initial_sparsity: float = 0.0,
    final_sparsity: float = 0.9,
    begin_step: int = 0,
    end_step: int = 10_000,
    frequency: int = 200,
    n_total_steps: int = 12_000,
    device: str = "cuda",
) -> dict[str, torch.Tensor]:
    """
    Gradually increase sparsity during training.

    Cubic schedule: prunes quickly at first, then slows as the target is approached:
        sparsity(t) = final + (initial - final) * (1 - progress)^3

    Why gradual? One-shot pruning at 90% sparsity causes large accuracy loss.
    Pruning in small steps lets the remaining weights adapt and compensate.
    """
    current_masks: dict[str, torch.Tensor] = {}
    params = dict(model.named_parameters())
    model.train()

    for step, (inputs, labels) in enumerate(train_loader):
        if step >= n_total_steps:
            break
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard training step
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()

        # Zero gradients of pruned weights so they receive no update
        for name, mask in current_masks.items():
            if params[name].grad is not None:
                params[name].grad.mul_(mask)

        optimizer.step()

        # Re-apply masks: momentum or weight decay can still move pruned weights
        with torch.no_grad():
            for name, mask in current_masks.items():
                params[name].mul_(mask)

        # Prune at the specified frequency during the active window
        if begin_step <= step <= end_step and (step - begin_step) % frequency == 0:
            progress = (step - begin_step) / max(end_step - begin_step, 1)
            target = final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3
            current_masks = magnitude_pruning_global(model, target_sparsity=target)
            print(f"Step {step}/{n_total_steps}: target sparsity = {target*100:.1f}%")

    return current_masks
```
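It helps to tabulate the schedule before committing to a long run. This minimal sketch re-implements only the cubic formula, that is, Zhu and Gupta's (2017) polynomial decay with exponent 3. With `initial = 0` it reduces to `final * (1 - (1 - progress)^3)`. The step counts mirror the defaults of `gradual_magnitude_pruning`.

```python
def cubic_sparsity(step: int, initial: float = 0.0, final: float = 0.9,
                   begin: int = 0, end: int = 10_000) -> float:
    """Cubic schedule: s(t) = final + (initial - final) * (1 - progress)^3."""
    progress = min(max((step - begin) / (end - begin), 0.0), 1.0)
    return final + (initial - final) * (1.0 - progress) ** 3


for step in (0, 2_000, 4_000, 6_000, 8_000, 10_000):
    print(f"step {step:>6}: {cubic_sparsity(step):.1%}")
# prints 0.0%, 43.9%, 70.6%, 84.2%, 89.3%, 90.0%
```

Most of the sparsity arrives early, while the network still has plenty of redundant weights. The last few percent are added slowly, which gives the surviving weights time to compensate.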

The Hardware Reality: Why Zeros Don't Mean Speedup

This is the most important practical fact in this entire lesson. Most practitioners skip it and waste weeks.

```text
MYTH: "50% sparse model = 2× faster inference"

REALITY ON A STANDARD GPU (no sparse hardware support):

Standard dense GEMM (General Matrix Multiply):
┌──────────────────────────────────────────────────────┐
│ GPU loads the weight matrix from HBM in tiles        │
│ (e.g. 128×128; tile shape depends on the kernel)     │
│ Tensor Cores multiply every element in the tile,     │
│ zeros included - no per-element skipping             │
│ ALL tiles execute regardless of the zero pattern     │
└──────────────────────────────────────────────────────┘

A 50% sparse model still has:
  ✗ The same memory footprint (zeros are stored in the weight tensor)
  ✗ The same memory bandwidth usage (the full tensor is loaded)
  ✗ The same or slightly WORSE compute (mask overhead)

Result:
  50% unstructured sparsity × standard GPU ≈ 0% speedup
  (often 2-5% SLOWER if masks are applied at inference time)
```
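The memory half of this claim is easy to verify. This sketch uses an arbitrary 1024×1024 FP32 matrix with a random ~50% mask. It first measures the matrix's bytes before and after zeroing half its entries. It then measures PyTorch's generic CSR format, which stores a 64-bit column index for every surviving value. At 50% sparsity, the "compressed" version ends up larger than the dense one.

```python
import torch

torch.manual_seed(0)
W = torch.randn(1024, 1024)                    # FP32 weight matrix
W_sparse = W * (torch.rand(1024, 1024) > 0.5)  # zero out ~50% of the entries


def nbytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()


dense_bytes = nbytes(W)
zeroed_bytes = nbytes(W_sparse)  # zeros still take 4 bytes each

# Generic CSR: values + an int64 column index per non-zero + int64 row pointers
csr = W_sparse.to_sparse_csr()
csr_bytes = nbytes(csr.values()) + nbytes(csr.col_indices()) + nbytes(csr.crow_indices())

print(f"dense:     {dense_bytes / 2**20:.2f} MiB")
print(f"50% zeros: {zeroed_bytes / 2**20:.2f} MiB")
print(f"CSR:       {csr_bytes / 2**20:.2f} MiB")
```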

The path to actual speedup from sparsity requires hardware that physically skips zero computations or reduces storage:

```python
import torch


def check_hardware_sparsity_support() -> dict:
    """
    Determine which sparsity patterns can accelerate inference on this machine.

    Sparse Tensor Cores (2:4 semi-structured sparsity) first appeared in NVIDIA
    Ampere GPUs (compute capability 8.0+), so the capability check is sufficient;
    matching on GPU names is unnecessary.
    """
    results = {
        "gpu_available": torch.cuda.is_available(),
        "gpu_name": None,
        "compute_capability": None,
        "supports_24_sparsity": False,
        "supports_unstructured_sparse": False,  # Tensor Cores only accelerate the 2:4 pattern
        "recommended_approach": "structured_pruning_or_quantization",
        "expected_speedup": "~1.0x (no hardware acceleration)",
    }

    if not torch.cuda.is_available():
        results["recommended_approach"] = "cpu_sparse_blas_at_very_high_sparsity"
        return results

    results["gpu_name"] = torch.cuda.get_device_name(0)
    major, minor = torch.cuda.get_device_capability(0)
    results["compute_capability"] = (major, minor)

    # Ampere = 8.0/8.6/8.7, Ada Lovelace = 8.9, Hopper = 9.0: all have Sparse Tensor Cores
    if major >= 8:
        results["supports_24_sparsity"] = True
        results["recommended_approach"] = "24_structured_sparsity"
        results["expected_speedup"] = (
            "up to ~2x on the sparse GEMMs; end-to-end gains are usually smaller"
        )

    return results


# Example output on an A100:
# {
#   'gpu_available': True,
#   'gpu_name': 'NVIDIA A100-SXM4-80GB',
#   'compute_capability': (8, 0),
#   'supports_24_sparsity': True,
#   'supports_unstructured_sparse': False,
#   'recommended_approach': '24_structured_sparsity',
#   'expected_speedup': 'up to ~2x on the sparse GEMMs; end-to-end gains are usually smaller'
# }
```

NVIDIA 2:4 Structured Sparsity: The One That Actually Works

NVIDIA's Ampere architecture (A100, RTX 30xx) introduced sparse tensor cores that accelerate exactly one sparsity pattern: at most 2 non-zeros in every group of 4 consecutive elements along each row (the input dimension of a linear layer's weight matrix). This is "2:4 structured sparsity" - a constrained form that is flexible enough to preserve much of the accuracy of unstructured pruning, but regular enough for dedicated hardware acceleration.

The hardware stores only the 2 kept values per group plus a 2-bit index for each, recording its position within the group. This roughly halves weight storage, with the index metadata adding a small overhead. During matrix multiply, the sparse tensor cores use the metadata to select the matching input elements and perform only the non-zero multiplications, for up to 2× the math throughput of dense Tensor Cores.
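The pattern and a compressed layout can be sketched in a few lines of PyTorch. This is a conceptual model, not NVIDIA's actual metadata encoding, and the function names are hypothetical. It does three things:

- It builds a 2:4 mask by magnitude, keeping the 2 largest of every 4 consecutive weights in a row.
- It stores the kept values in a half-width tensor, together with their in-group positions (0-3, so each fits in 2 bits).
- It reconstructs the dense matrix from them.

```python
import torch


def mask_24_by_magnitude(W: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude weights in each group of 4 (per row)."""
    rows, cols = W.shape
    assert cols % 4 == 0, "in_features must be divisible by 4"
    groups = W.abs().reshape(rows, cols // 4, 4)
    keep = groups.topk(2, dim=-1).indices  # in-group positions of the 2 survivors
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
    return mask.reshape(rows, cols)


def compress_24(W: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Keep only the 2 surviving values per group plus their in-group positions."""
    rows, cols = W.shape
    groups = W.reshape(rows, cols // 4, 4)
    # nonzero() lists positions in ascending order, giving one canonical layout
    idx = mask.reshape(rows, cols // 4, 4).nonzero()[:, 2].reshape(rows, cols // 4, 2)
    values = groups.gather(-1, idx).reshape(rows, cols // 2)
    return values, idx.to(torch.uint8).reshape(rows, cols // 2)  # each idx < 4: 2 bits


def decompress_24(values: torch.Tensor, idx: torch.Tensor, cols: int) -> torch.Tensor:
    """Scatter the kept values back into a dense matrix of zeros."""
    rows = values.shape[0]
    dense = torch.zeros(rows, cols // 4, 4, dtype=values.dtype)
    dense.scatter_(-1, idx.long().reshape(rows, cols // 4, 2),
                   values.reshape(rows, cols // 4, 2))
    return dense.reshape(rows, cols)


torch.manual_seed(0)
W = torch.randn(8, 16)
mask = mask_24_by_magnitude(W)
W_masked = torch.where(mask, W, torch.zeros_like(W))
values, idx = compress_24(W, mask)
W_restored = decompress_24(values, idx, W.shape[1])
print(values.shape, idx.shape)             # torch.Size([8, 8]) torch.Size([8, 8])
print(torch.equal(W_restored, W_masked))   # True
```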

```python
from typing import Optional

import torch
import torch.nn as nn

try:
    from torch.ao.pruning import WeightNormSparsifier  # PyTorch 2.0+
except ImportError:
    from torch.ao.sparsity import WeightNormSparsifier  # older name of the same module


def apply_24_sparsity(model: nn.Module, skip_layers: Optional[list[str]] = None) -> nn.Module:
    """
    Apply a magnitude-based 2:4 sparsity pattern to all eligible linear layers.

    After this:
      - Every group of 4 consecutive weights in a row has at most 2 non-zeros
      - Weights are still stored densely; convert_to_sparse_tensor_core_format
        produces the compressed layout Sparse Tensor Cores actually use
        (~50% smaller weights and up to ~2x faster GEMMs on Ampere+)

    Args:
        skip_layers: Layer name substrings to exclude (e.g., embedding, lm_head)
    """
    if skip_layers is None:
        skip_layers = ["embed_tokens", "lm_head", "embed_positions"]

    # Build the list of eligible layers
    layer_config = []
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if any(skip in name for skip in skip_layers):
            continue
        # 2:4 groups run along the input dimension, so in_features must be divisible by 4
        if module.weight.shape[1] % 4 != 0:
            print(f"  Skipping {name}: width {module.weight.shape[1]} not divisible by 4")
            continue
        layer_config.append({"tensor_fqn": f"{name}.weight"})

    print(f"Applying 2:4 sparsity to {len(layer_config)} layers...")

    sparsifier = WeightNormSparsifier(
        sparsity_level=1.0,         # sparsify 100% of the blocks...
        sparse_block_shape=(1, 4),  # ...where a block is 4 consecutive weights in a row...
        zeros_per_block=2,          # ...and each block gets exactly 2 zeros
    )
    sparsifier.prepare(model, layer_config)
    sparsifier.step()         # compute the masks (attached as parametrizations)
    sparsifier.squash_mask()  # fold masks into the weights, remove parametrizations

    # Verify the sparsity pattern
    target_fqns = {c["tensor_fqn"] for c in layer_config}
    n_verified = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and f"{name}.weight" in target_fqns:
            if check_24_sparsity_valid(module.weight.data):
                n_verified += 1
            else:
                print(f"  WARNING: {name} failed 2:4 validation!")

    print(f"Verified {n_verified}/{len(layer_config)} layers satisfy the 2:4 constraint")
    return model


def check_24_sparsity_valid(tensor: torch.Tensor) -> bool:
    """
    Verify a weight tensor satisfies the 2:4 constraint: each group of 4
    consecutive elements in a row has AT MOST 2 non-zeros (a kept weight
    may legitimately be zero).
    """
    # Work with a 2D weight matrix (flatten extra dims if needed)
    if tensor.dim() > 2:
        tensor = tensor.reshape(tensor.shape[0], -1)

    n_rows, n_cols = tensor.shape
    if n_cols % 4 != 0:
        return False  # cannot satisfy 2:4 on a width that is not a multiple of 4

    groups = tensor.reshape(n_rows, -1, 4)
    nonzeros_per_group = (groups != 0).sum(dim=-1)
    return bool((nonzeros_per_group <= 2).all())


def convert_to_sparse_tensor_core_format(
    model: nn.Module,
    save_path: Optional[str] = None,
) -> nn.Module:
    """
    Swap 2:4-pruned Linear weights for PyTorch's compressed semi-structured format,
    which is what lets Sparse Tensor Cores skip the zeros.

    Requires PyTorch 2.1+, a CUDA GPU with compute capability 8.0+, and fp16/bf16
    (or int8) weights; the kernels also impose minimum shape requirements. See:
    https://pytorch.org/docs/stable/sparse.html#sparse-semi-structured-tensors
    """
    if save_path:
        # Save the dense 2:4-pruned weights; re-run the conversion after loading
        torch.save(model.state_dict(), save_path)
        print(f"2:4-pruned model saved to {save_path}")

    try:
        from torch.sparse import to_sparse_semi_structured
    except ImportError:
        print("torch.sparse.to_sparse_semi_structured not available - upgrade to PyTorch 2.1+")
        print("2:4 masks remain applied, but the weights stay in dense format")
        return model

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if not check_24_sparsity_valid(module.weight.data):
            continue
        try:
            # Compressed values + position metadata in the layout the hardware reads
            module.weight = nn.Parameter(to_sparse_semi_structured(module.weight.data))
            print(f"Converted {name} to semi-structured sparse format")
        except (RuntimeError, ValueError) as exc:  # unsupported dtype, device or shape
            print(f"  Could not convert {name}: {exc}")

    return model
```

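:::tip 2:4 Sparsity + Quantization = Extreme Compression 2:4 sparsity and INT8/INT4 quantization compose well. The usual order is sparsify first, then quantize. Prune to the 2:4 pattern (ideally followed by some fine-tuning), then quantize the surviving weights with a method such as GPTQ or AWQ, so that the quantizer's calibration can absorb the pruning error. With 4-bit weights and half of them removed, the stored values average 2 bits per original weight, plus the position metadata. Sparse Tensor Cores accept INT8 as well as FP16/BF16 inputs, so the sparse and quantized representations can be accelerated together. This combination is one of the most aggressive LLM compression recipes in practical use. :::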
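The storage arithmetic behind combining sparsity and quantization is worth writing down. This sketch assumes a layout in which each kept value carries a 2-bit position index, as described above for Sparse Tensor Cores. It ignores quantization scales and any padding a real format may add, and `bits_per_weight` is a hypothetical helper name.

```python
def bits_per_weight(value_bits: int, n_kept: int = 2, group: int = 4,
                    index_bits: int = 2) -> float:
    """Average stored bits per ORIGINAL weight for an n_kept:group sparse layout.

    Assumption: every kept value stores `value_bits` plus an `index_bits` position
    index; quantization scales and padding are ignored.
    """
    return n_kept * (value_bits + index_bits) / group


for label, bits in [("FP16", 16), ("INT8", 8), ("INT4", 4)]:
    sparse = bits_per_weight(bits)
    print(f"{label}: dense {bits} bits -> 2:4 {sparse:.1f} bits/weight")
# FP16 -> 9.0, INT8 -> 5.0, INT4 -> 3.0
```

The index metadata matters more as the values shrink. For FP16, 2:4 keeps about 56% of the dense footprint. For 4-bit weights it keeps 75%, so the combination lands nearer 3 bits per weight than 2.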

SparseGPT: Second-Order Pruning for LLMs

SparseGPT (Frantar & Alistarh, 2023) brings the GPTQ approach - second-order error compensation via the Hessian inverse - to unstructured pruning. Instead of quantizing weights, it zeros them out, and uses the inverse Hessian to compensate remaining weights for the error introduced.

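The key insight mirrors GPTQ. Take a linear layer $\hat{y} = Wx$, with calibration inputs stacked as the columns of $X$. The squared reconstruction error of each row of $W$ then has Hessian $H = 2XX^\top$, which measures how strongly each input dimension contributes to that error. Zeroing weight $w_{ij}$, while optimally adjusting the remaining weights in row $i$, increases the error by $\frac{w_{ij}^2}{2\,[H^{-1}]_{jj}}$ (the Optimal Brain Surgeon saliency). Ranking weights by $w_{ij}^2 / [H^{-1}]_{jj}$ therefore identifies the cheapest ones to remove. SparseGPT zeros the lowest-saliency weights and applies the matching compensation update to the weights that remain.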
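The saliency formula can be checked numerically on a single output row. In this sketch the random calibration data, the sizes, and the pruned index are arbitrary assumptions. It proceeds as follows:

- The added reconstruction error for a weight change $\delta$ is $\|\delta^\top X\|^2 = \tfrac{1}{2}\delta^\top H \delta$, with $H = 2XX^\top$.
- One weight $w_q$ is zeroed, and the OBS compensation $\delta = -\frac{w_q}{[H^{-1}]_{qq}} H^{-1} e_q$ is applied.
- The resulting error is compared with $\frac{w_q^2}{2[H^{-1}]_{qq}}$ and with plain zeroing (no compensation).

```python
import torch

torch.manual_seed(0)
d_in, n_tokens = 16, 256
X = torch.randn(d_in, n_tokens, dtype=torch.float64)  # calibration inputs, one column per token
w = torch.randn(d_in, dtype=torch.float64)            # one row of W
H = 2 * X @ X.T
H_inv = torch.linalg.inv(H)


def added_error(delta: torch.Tensor) -> float:
    """Increase in squared reconstruction error: ||(w + delta)^T X - w^T X||^2."""
    return (delta @ X).pow(2).sum().item()


q = 3  # index of the weight to prune (arbitrary)
e_q = torch.zeros(d_in, dtype=torch.float64)
e_q[q] = 1.0

# Plain magnitude-style pruning: set w_q to zero and touch nothing else
naive = added_error(-w[q] * e_q)

# OBS: zero w_q and shift the other weights along H^-1 e_q to compensate
delta_obs = -(w[q] / H_inv[q, q]) * H_inv[:, q]
obs = added_error(delta_obs)
predicted = (w[q] ** 2 / (2 * H_inv[q, q])).item()

print(f"naive {naive:.4f} | OBS {obs:.4f} | formula {predicted:.4f}")
print(f"pruned weight after update: {(w + delta_obs)[q].item():.1e}")  # ~0
```

The compensated error matches the formula, and it is never larger than plain zeroing. The OBS update is the minimum-error change among all updates that set $w_q$ to zero.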

```python
import torch


def sparsegpt_prune_layer(
    weight: torch.Tensor,
    hessian: torch.Tensor,
    target_sparsity: float = 0.5,
    block_size: int = 128,
    percdamp: float = 0.01,
    use_24_pattern: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    SparseGPT: pruning with second-order (OBS) error compensation.

    Columns (input dimensions) are processed left to right in blocks of `block_size`.
    For each block:
      1. Score each weight by its saliency w^2 / [H^-1]_jj, read off the Cholesky
         factor of H^-1 as in GPTQ (lower score = cheaper to prune)
      2. Choose the weights to prune: lowest scores within the block, or the
         2 lowest of every 4 consecutive columns per row for 2:4
      3. Zero them column by column, pushing each column's error onto the
         not-yet-processed columns via the inverse Hessian

    Args:
        weight: Layer weight matrix (d_out x d_in)
        hessian: H = 2 * X X^T over calibration inputs (d_in x d_in)
        target_sparsity: Fraction of weights to zero out (unstructured mode)
        block_size: Columns per block (must be a multiple of 4 for 2:4)
        percdamp: Dampening added to diag(H), as a fraction of its mean
        use_24_pattern: If True, enforce the 2:4 pattern along each row

    Returns:
        (pruned_weight, binary_mask) - True = kept
    """
    d_row, d_col = weight.shape
    if use_24_pattern:
        assert d_col % 4 == 0 and block_size % 4 == 0, "2:4 needs widths divisible by 4"
    W = weight.clone().float()
    H = hessian.clone().float()

    # Input dimensions that never activated carry no information: drop them
    dead = torch.diag(H) == 0
    H[dead, dead] = 1.0
    W[:, dead] = 0.0

    # Dampening for numerical stability (avoids a near-singular H)
    idx = torch.arange(d_col, device=H.device)
    H[idx, idx] += percdamp * torch.diag(H).mean()

    # Upper Cholesky factor of H^-1 (the GPTQ trick): row j already accounts for
    # columns < j having been fixed, so no explicit inverse updates are needed
    H_inv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    H_inv = torch.linalg.cholesky(H_inv, upper=True)

    mask = torch.ones_like(W, dtype=torch.bool)  # True = kept

    for b_start in range(0, d_col, block_size):
        b_end = min(b_start + block_size, d_col)
        W_blk = W[:, b_start:b_end].clone()
        Hinv_blk = H_inv[b_start:b_end, b_start:b_end]
        d = torch.diag(Hinv_blk)
        Err = torch.zeros_like(W_blk)  # error accumulator for cross-block compensation
        prune_blk = torch.zeros_like(W_blk, dtype=torch.bool)

        if not use_24_pattern:
            # Adaptive mask selection: lowest saliencies across the whole block
            score = W_blk.pow(2) / d.pow(2).unsqueeze(0)
            k = int(score.numel() * target_sparsity)
            if k > 0:
                prune_blk = score <= score.flatten().kthvalue(k).values

        for i in range(b_end - b_start):
            if use_24_pattern and i % 4 == 0:
                # 2:4: in each row, prune the 2 lowest-saliency weights of these
                # 4 columns, scored on the already-compensated weights
                score = W_blk[:, i:i + 4].pow(2) / d[i:i + 4].pow(2).unsqueeze(0)
                drop = score.topk(2, dim=1, largest=False).indices
                prune_blk[:, i:i + 4].scatter_(1, drop, True)

            w = W_blk[:, i]
            q = torch.where(prune_blk[:, i], torch.zeros_like(w), w)
            err = (w - q) / d[i]
            # OBS update for the remaining columns of this block (column i becomes q)
            W_blk[:, i:] -= err.unsqueeze(1) * Hinv_blk[i, i:].unsqueeze(0)
            Err[:, i] = err

        W[:, b_start:b_end] = W_blk
        mask[:, b_start:b_end] = ~prune_blk
        # Propagate the block's accumulated error to all later columns
        if b_end < d_col:
            W[:, b_end:] -= Err @ H_inv[b_start:b_end, b_end:]

    W[~mask] = 0.0  # exact zeros (the update leaves tiny rounding residue)
    actual_sparsity = (~mask).float().mean().item()
    print(f"SparseGPT: {actual_sparsity*100:.2f}% sparsity, "
          f"{mask.sum().item():,} weights kept")
    return W.to(weight.dtype), mask


def compute_hessian_for_pruning(
    module: torch.nn.Linear,
    calibration_loader,
    n_batches: int = 128,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute H = 2 * E[x x^T] over the module's inputs on calibration data.

    `calibration_loader` must yield tensors that are valid inputs to THIS module
    (e.g. activations captured from the previous layer), not raw token IDs.
    This is the same Hessian GPTQ uses; SparseGPT reuses that infrastructure.
    """
    d_in = module.weight.shape[1]
    H = torch.zeros(d_in, d_in, device=device)
    n_seen = 0

    def input_hook(m, inp, out):
        nonlocal n_seen
        # Flatten batch and sequence dimensions into rows of shape (N, d_in)
        x = inp[0].detach().reshape(-1, d_in).float()
        H.add_(x.T @ x)
        n_seen += x.shape[0]

    hook = module.register_forward_hook(input_hook)
    try:
        with torch.no_grad():
            for i, batch in enumerate(calibration_loader):
                if i >= n_batches:
                    break
                inputs = batch if isinstance(batch, torch.Tensor) else batch[0]
                module(inputs.to(device))
    finally:
        hook.remove()

    # Average and scale: H = 2 * X^T X / N
    return H * (2.0 / max(n_seen, 1))
```

Wanda: Simple but Surprisingly Effective

Wanda (Sun et al., 2023) achieves similar accuracy to SparseGPT at a fraction of the compute. It only chooses a mask and never updates the surviving weights. The insight: instead of the full Hessian, use a simpler importance metric that combines weight magnitude with activation scale:

$$\text{Importance}(i, j) = |W_{ij}| \cdot \|X_j\|_2$$

where $\|X_j\|_2$ is the L2 norm of input feature $j$ across all calibration tokens.

A weight is important if: (1) it has large magnitude AND (2) its input feature has large activation norm. Neither factor alone is sufficient - large weights connected to silent input dimensions don't matter; small weights connected to loudly active features might matter a lot.
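A two-input example makes the contrast with plain magnitude pruning concrete. The numbers are made up for illustration. The larger weight sits on an input that is nearly silent on the calibration data, so magnitude pruning keeps the wrong one.

```python
import torch

W = torch.tensor([[1.0, 0.5]])     # one output neuron, two inputs
X = torch.tensor([[0.1, 10.0],     # calibration tokens (rows) x input features (cols)
                  [0.0, -8.0],
                  [0.1, 12.0]])

act_norms = X.norm(dim=0)          # ||X_j||_2 per input feature
magnitude_score = W.abs()
wanda_score = W.abs() * act_norms  # |W_ij| * ||X_j||_2, broadcast over rows

# Prune 1 of 2 weights per row (50% sparsity) under each criterion
keep_mag = magnitude_score.argmax(dim=1)
keep_wanda = wanda_score.argmax(dim=1)
print("magnitude keeps input", keep_mag.item())  # 0
print("Wanda keeps input", keep_wanda.item())    # 1

# Output error on the calibration tokens when only the kept weight survives
for name, j in [("magnitude", keep_mag.item()), ("Wanda", keep_wanda.item())]:
    W_pruned = torch.zeros_like(W)
    W_pruned[0, j] = W[0, j]
    err = (X @ W.T - X @ W_pruned.T).norm().item()
    print(f"{name}: output error {err:.3f}")
```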

```python
from typing import Optional

import torch


def compute_activation_norms(
    module: torch.nn.Linear,
    calibration_data: list[torch.Tensor],
    device: str = "cuda",
) -> torch.Tensor:
    """
    Per-input-feature activation norms for Wanda importance scoring.

    Returns: tensor of shape (d_in,) holding ||X_j||_2, the L2 norm of input
    feature j over every calibration token. Only running sums of squares are
    kept, so memory stays O(d_in) regardless of calibration size.
    """
    d_in = module.weight.shape[1]
    sum_sq = torch.zeros(d_in, device=device)

    def hook_fn(m, inp, out):
        x = inp[0].detach().reshape(-1, d_in).float()  # (tokens, d_in)
        sum_sq.add_(x.pow(2).sum(dim=0).to(device))

    hook = module.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for x in calibration_data:
                module(x.to(device))
    finally:
        hook.remove()

    return sum_sq.sqrt()


def wanda_prune_layer(
    weight: torch.Tensor,
    activation_norms: torch.Tensor,
    target_sparsity: float = 0.5,
    use_24_pattern: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Wanda: pruning by Weights and activations (Sun et al., 2023).

    Importance score = |W_ij| * ||X_j||
    Prune row-wise: each output neuron independently decides
    which input connections to keep.

    Why row-wise? Each output neuron computes a sum over input features,
    so it makes sense to decide per output which inputs matter most.

    Args:
        weight: Weight matrix (d_out x d_in)
        activation_norms: Per-input-feature L2 norm (d_in,)
        target_sparsity: Fraction of weights to zero per row
        use_24_pattern: Enforce the 2:4 constraint

    Returns:
        (pruned_weight, mask) - True = kept
    """
    d_out, d_in = weight.shape
    n_prune_per_row = int(d_in * target_sparsity)

    # Importance: |W| * ||X||, broadcasting activation_norms across the output dim
    importance = weight.abs().float() * activation_norms.float().unsqueeze(0)  # (d_out, d_in)

    mask = torch.ones_like(weight, dtype=torch.bool)

    if use_24_pattern:
        # For each row, in each group of 4 columns, keep the top 2 by importance
        assert d_in % 4 == 0, "Width must be divisible by 4 for the 2:4 pattern"
        importance_grouped = importance.reshape(d_out, -1, 4)  # (d_out, n_groups, 4)
        _, top2_indices = importance_grouped.topk(2, dim=-1, largest=True)
        mask_grouped = torch.zeros_like(importance_grouped, dtype=torch.bool)
        mask_grouped.scatter_(-1, top2_indices, True)
        mask = mask_grouped.reshape(d_out, d_in)
    else:
        # Unstructured: per row, prune the n_prune_per_row lowest-importance weights
        sorted_indices = importance.argsort(dim=1)  # ascending
        mask.scatter_(1, sorted_indices[:, :n_prune_per_row], False)

    # Multiplying by a bool mask keeps the weight's original dtype
    pruned_weight = weight * mask

    actual_sparsity = (~mask).float().mean().item()
    print(f"Wanda: {actual_sparsity*100:.2f}% sparsity "
          f"({'2:4 pattern' if use_24_pattern else 'unstructured'})")
    return pruned_weight, mask


def run_wanda_on_model(
    model: torch.nn.Module,
    calibration_loader,
    target_sparsity: float = 0.5,
    use_24_pattern: bool = False,
    skip_layers: Optional[list[str]] = None,
    n_calibration_batches: int = 128,
    device: str = "cuda",
) -> dict[str, torch.Tensor]:
    """
    Apply Wanda pruning to every linear layer of a model, one layer at a time.

    For each layer, a calibration pass through the full model accumulates that
    layer's input sum of squares (O(d_in) memory), and then the layer is pruned.
    Later layers therefore see the outputs of already-pruned earlier layers.
    Simple but slow: it costs one calibration pass per layer.
    """
    if skip_layers is None:
        skip_layers = ["embed_tokens", "lm_head"]

    model.eval()
    all_masks = {}
    linear_layers = [
        (name, module) for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(skip in name for skip in skip_layers)
    ]

    for name, module in linear_layers:
        print(f"Wanda pruning: {name} {tuple(module.weight.shape)}")
        d_in = module.weight.shape[1]
        sum_sq = torch.zeros(d_in, device=device)

        def hook_fn(m, inp, out):
            x = inp[0].detach().reshape(-1, d_in).float()
            sum_sq.add_(x.pow(2).sum(dim=0).to(sum_sq.device))

        hook = module.register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                for i, batch in enumerate(calibration_loader):
                    if i >= n_calibration_batches:
                        break
                    x = batch if isinstance(batch, torch.Tensor) else batch[0]
                    model(x.to(device))
        finally:
            hook.remove()

        # Prune this layer using ||X_j||_2 = sqrt(sum of squares)
        with torch.no_grad():
            pruned_weight, mask = wanda_prune_layer(
                module.weight.data,
                sum_sq.sqrt().to(module.weight.device),
                target_sparsity=target_sparsity,
                use_24_pattern=use_24_pattern,
            )
            module.weight.data = pruned_weight
        all_masks[name] = mask.cpu()

    return all_masks
```

Method Comparison: When to Use What

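| Method | Accuracy at 50% sparsity (LLMs, no fine-tuning) | Compute Cost | Hardware Needed | Memory Savings | Recommendation |
|---|---|---|---|---|---|
| Magnitude (global) | Poor (-5 to -15%) | Minimal | Any | None | Research only |
| Magnitude (per-layer) | Moderate (-2 to -8%) | Minimal | Any | None | Small models only |
| SparseGPT | Excellent (<1% on the largest models; larger loss on smaller ones) | High (Hessian) | Any (no speedup) | None | Research / 2:4 combo |
| Wanda | Very good (close to SparseGPT) | Low (calibration only) | Any (no speedup) | None | Best accuracy/cost ratio |
| 2:4 + Magnitude | Poor on LLMs | Low | Ampere+ GPU | ~44% of FP16 weight storage (after index metadata) | Only with fine-tuning |
| 2:4 + Wanda | Moderate (clearly below unstructured 50%) | Low | Ampere+ GPU | ~44% of FP16 weight storage (after index metadata) | Best production option, ideally with fine-tuning |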

:::danger Unstructured Sparsity Without Sparse Hardware is Useless A 90% sparse model on a standard GPU has identical latency and memory usage to the dense model - sometimes worse due to mask handling overhead. Do not deploy unstructured sparse models expecting a speedup. NVIDIA Sparse Tensor Cores (Ampere+) accelerate only the 2:4 pattern, so either convert to 2:4 and verify that your inference stack actually uses the sparse kernels, or, for extreme (>95%) sparsity on CPU, configure a sparse BLAS library. The weight files may be smaller on disk (if using sparse formats), but GPU inference won't benefit. :::

:::warning The LTH "Reset to Initialization" Step is Impractical for LLMs The original Lottery Ticket Hypothesis finding - that sparse subnetworks exist that can match full network performance when reset to their initialization and retrained - requires retraining from the original random initialization. For pre-trained LLMs, this means running the full pre-training again, which can cost millions of dollars. The practical equivalent is magnitude pruning followed by supervised fine-tuning on the target task, which achieves similar (but not identical) results. :::

Interview Questions

Q: What is the Lottery Ticket Hypothesis and why doesn't it directly apply to production LLM compression?

A: Frankle and Carbin (2019) showed that dense neural networks contain sparse subnetworks ("winning lottery tickets") that match the full network's performance when trained in isolation from the same random initialization. The experiment: train the full network, prune the smallest-magnitude weights, reset the surviving weights to their original initialization values (not their trained values), and retrain. The best tickets come from repeating this cycle over several rounds (iterative magnitude pruning). The winning ticket (sparse mask plus its specific initialization) trains to the same accuracy as the full dense network in at most the same number of steps.
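The prune-reset-retrain loop can be sketched in a few lines. This is a toy under stated assumptions: a least-squares linear model stands in for "training", and the function names, 20% per-round prune fraction, and 3 rounds are hypothetical choices for illustration. It is not the paper's experimental setup.

```python
import numpy as np


def train(w, mask, X, y, steps=200, lr=0.1):
    """Toy stand-in for training: gradient descent on least squares with pruned weights held at zero."""
    w = w * mask
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w = (w - lr * grad) * mask
    return w


def lottery_ticket(w_init, X, y, rounds=3, prune_frac=0.2):
    """Iterative magnitude pruning with reset: every round retrains from w_init, not the trained weights."""
    mask = np.ones_like(w_init)
    for _ in range(rounds):
        w_trained = train(w_init, mask, X, y)     # train the current subnetwork from the original init
        alive = np.flatnonzero(mask)
        k = int(prune_frac * alive.size)          # prune 20% of the surviving weights per round
        weakest = alive[np.argsort(np.abs(w_trained[alive]))[:k]]
        mask[weakest] = 0.0
    ticket = train(w_init, mask, X, y)            # the winning ticket: final mask + original init
    return mask, ticket


rng = np.random.default_rng(0)
d, n = 50, 200
w_true = np.zeros(d)
support = rng.choice(d, size=10, replace=False)
w_true[support] = rng.uniform(1.0, 3.0, size=10) * rng.choice([-1.0, 1.0], size=10)
X = rng.standard_normal((n, d))
y = X @ w_true
w_init = 0.1 * rng.standard_normal(d)

mask, ticket = lottery_ticket(w_init, X, y)
print(int(mask.sum()))                  # 26 weights survive (50 -> 40 -> 32 -> 26)
print(np.mean((X @ ticket - y) ** 2))   # near zero: the sparse ticket still fits the data
```

The key line is the one that calls `train(w_init, mask, ...)` inside the loop. Rewinding to `w_init` instead of continuing from `w_trained` is what distinguishes a lottery ticket from ordinary prune-and-fine-tune.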

The LTH doesn't directly apply to production LLMs for two reasons. First, the "reset to initialization" step is impractical - for a 70B-parameter LLM pre-trained on 2 trillion tokens, "retraining from initialization" means running the entire pre-training again, which costs millions of dollars. Second, the LTH results were demonstrated convincingly on smaller models (ResNets, small transformers) where the winning ticket sparsity was 80-95%; for large pre-trained LLMs, the effective winning ticket sparsity appears much lower, and the accuracy gap from retraining from scratch (vs. fine-tuning the pruned model) is larger. The practical takeaway from LTH for LLM practitioners: train larger models, prune with magnitude + fine-tune (not reset), and accept that some accuracy will be lost compared to the dense model.

Q: Why does 50% unstructured sparsity typically yield zero speedup on standard GPUs?

A: Standard GPU matrix multiplication operates on tiles. When computing $C = AB$, blocks of both matrices (for example 128×128) are staged into on-chip SRAM (shared memory), and Tensor Cores execute fixed-shape matrix multiply-accumulate instructions on fragments of those tiles. They don't inspect individual elements to decide whether to skip zeros. Multiplying one 128×128 tile of $A$ by one 128×128 tile of $B$ issues all $128^3 \approx 2.1$ million multiply-accumulates, regardless of how many operands are zero.

The memory bandwidth problem is equally fundamental: zeros are still stored in the weight matrix (they occupy bytes in HBM), so loading a 50% sparse weight matrix still reads the same number of bytes as the dense matrix. Neither compute nor memory bandwidth improves. Without hardware that: (a) stores zeros in a compressed format and (b) physically skips zero multiplications, unstructured sparsity provides zero inference speedup. NVIDIA Sparse Tensor Cores (Ampere+) are specifically designed to do both for the 2:4 pattern.
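A back-of-the-envelope numpy sketch of the compute and bandwidth argument. It counts the multiply-accumulates a dense kernel issues for one 128×128 tile pair versus the ones that involve a nonzero weight, and the bytes loaded for the weight tile. The helper `dense_tile_work` is hypothetical and models only the arithmetic, not any real GPU kernel.

```python
import numpy as np


def dense_tile_work(w_tile, x_tile):
    """Multiply-accumulates a dense kernel issues for one tile product, and how many involve a nonzero weight."""
    m, k = w_tile.shape
    k2, n = x_tile.shape
    assert k == k2, "inner dimensions must match"
    issued = m * k * n                           # every MAC is executed, zero or not
    useful = int(np.count_nonzero(w_tile)) * n   # each nonzero weight meets n activations
    return issued, useful


rng = np.random.default_rng(0)
w_tile = rng.standard_normal((128, 128)).astype(np.float16)
w_tile[rng.random((128, 128)) < 0.5] = 0         # ~50% unstructured zeros
x_tile = rng.standard_normal((128, 128)).astype(np.float16)

issued, useful = dense_tile_work(w_tile, x_tile)
print(issued)           # 2097152 (= 128**3), independent of the zeros
print(useful / issued)  # roughly 0.5: half the issued work multiplies by zero
print(w_tile.nbytes)    # 32768 bytes loaded from memory either way
```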

Q: What is 2:4 structured sparsity and how does NVIDIA hardware accelerate it?

A: 2:4 sparsity requires at least 2 zeros in every group of 4 consecutive elements along each row of a weight matrix, so each group holds at most 2 non-zero values. NVIDIA Ampere+ GPUs include sparse tensor cores that natively support this format. In the compressed storage format, only the 2 kept values per group are stored (half the original values), along with a 2-bit index per kept value recording its position within the group of 4.
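A simplified numpy sketch of the compressed representation: 2 kept values per group of 4 plus a 2-bit position for each. This illustrates the idea only. The real hardware/cuSPARSELt metadata layout is packed and reordered differently, and `compress_2_4`/`decompress_2_4` are hypothetical helpers.

```python
import numpy as np


def compress_2_4(row):
    """Pack a 2:4-sparse row into 2 kept values per group plus each value's 2-bit position."""
    groups = row.reshape(-1, 4)
    if np.any(np.count_nonzero(groups, axis=1) > 2):
        raise ValueError("row is not 2:4 sparse")
    # Stable argsort on "is zero" lists nonzero slots first, in order; keep the first two.
    idx = np.sort(np.argsort(groups == 0, axis=1, kind="stable")[:, :2], axis=1)
    values = np.take_along_axis(groups, idx, axis=1)
    return values, idx.astype(np.uint8)  # positions 0-3 fit in 2 bits; uint8 used for simplicity


def decompress_2_4(values, idx):
    """Scatter the kept values back into dense groups of 4."""
    groups = np.zeros((values.shape[0], 4), dtype=values.dtype)
    np.put_along_axis(groups, idx.astype(np.int64), values, axis=1)
    return groups.reshape(-1)


row = np.array([0.0, 1.5, 0.0, -2.0, 3.0, 0.0, 0.5, 0.0], dtype=np.float16)
values, idx = compress_2_4(row)
print(values.tolist(), idx.tolist())  # [[1.5, -2.0], [3.0, 0.5]] [[1, 3], [0, 2]]
assert np.array_equal(decompress_2_4(values, idx), row)

dense_bits = row.size * 16                     # FP16
sparse_bits = values.size * 16 + idx.size * 2  # FP16 values + a 2-bit index each
print(sparse_bits / dense_bits)                # 0.5625
```

The 0.5625 ratio is why FP16 2:4 storage saves about 44% rather than a full 50%: the index metadata is not free.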

During matrix multiply, the sparse tensor cores use that metadata to select the matching activations and feed only the kept values into the multiply-accumulate units. This gives up to 2× the peak math throughput of dense computation. End-to-end speedups are smaller, because not every operation in a model is a sparse matmul. The fixed 2:4 constraint is regular enough for hardware to exploit, unlike arbitrary unstructured sparsity. To apply 2:4 sparsity: in each group of 4 consecutive weights in a row, zero out the 2 with the smallest magnitude (or lowest Wanda/SparseGPT importance score). The accuracy cost is lower than the rigid pattern suggests, because which 2 of 4 to keep is chosen per group. Still, 2:4 is a stricter constraint than unconstrained 50% sparsity, so one-shot 2:4 pruning typically loses more accuracy than unstructured 50% pruning with the same criterion, and a short fine-tune is commonly used to recover it.
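The selection rule above can be sketched in numpy. This assumes a weight laid out as `(out_features, in_features)` with groups taken along the input dimension. Scores default to magnitude, and any same-shaped importance score (for example a Wanda-style score) can be passed instead. `prune_2_4` is a hypothetical helper, and unlike SparseGPT it does not update the surviving weights.

```python
import numpy as np


def prune_2_4(weight, scores=None):
    """Zero the 2 lowest-scoring weights in each group of 4 consecutive inputs, row by row.

    scores defaults to |weight| (magnitude pruning); pass importance scores of the
    same shape to change the criterion. Surviving weights are left unchanged.
    """
    out_features, in_features = weight.shape
    if in_features % 4 != 0:
        raise ValueError("input dimension must be a multiple of 4")
    if scores is None:
        scores = np.abs(weight)
    groups = scores.reshape(out_features, in_features // 4, 4)
    drop = np.argsort(groups, axis=-1)[..., :2]   # positions of the 2 smallest scores
    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    mask = mask.reshape(out_features, in_features)
    return weight * mask, mask


rng = np.random.default_rng(0)
W = rng.standard_normal((4, 16))
W_sparse, mask = prune_2_4(W)
print(mask.reshape(4, -1, 4).sum(axis=-1))  # every group keeps exactly 2 weights
print((W_sparse == 0).mean())               # 0.5
```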

Q: How does SparseGPT differ from magnitude pruning, and when is the extra compute justified?

A: Magnitude pruning zeros out the smallest-absolute-value weights, globally or per layer, with no compensation for the error introduced. SparseGPT uses second-order information (the Hessian of the layer's output reconstruction error, $H = 2\mathbb{E}[XX^T]$) to compensate. After zeroing a weight, it adjusts the remaining weights in the same row using the inverse Hessian to partially cancel the output error. The update rule: when weight $w_j$ is zeroed, the row's weights change by $\delta\mathbf{w} = -\frac{w_j}{[H^{-1}]_{jj}} \cdot H^{-1}_{j,:}$.
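The single-weight update can be checked numerically. The sketch below applies the formula to one output row with toy calibration inputs in which input 1 is strongly correlated with input 0. It compares the output reconstruction error against plain zeroing. This is only the one-weight Optimal Brain Surgeon step that SparseGPT builds on, not SparseGPT's full column-by-column algorithm with lazy batched updates. The function names are hypothetical.

```python
import numpy as np


def obs_prune_one(w, H_inv, j):
    """Zero weight j of one output row and compensate: delta = -(w_j / [H^-1]_jj) * H^-1[j, :]."""
    delta = -(w[j] / H_inv[j, j]) * H_inv[j, :]
    w_new = w + delta
    w_new[j] = 0.0   # delta already cancels w_j; this removes floating-point residue
    return w_new


def output_error(w_ref, w_new, X):
    """Squared reconstruction error of the layer output: ||(w_ref - w_new) X||^2."""
    return float(np.sum(((w_ref - w_new) @ X) ** 2))


rng = np.random.default_rng(0)
d, n = 16, 512
X = rng.standard_normal((d, n))   # calibration inputs, one column per token
X[1] = 0.9 * X[0] + 0.1 * X[1]    # input 1 is strongly correlated with input 0
w = rng.standard_normal(d)        # one output row of the weight matrix
H = 2.0 * X @ X.T / n             # H = 2 E[x x^T] estimated from the samples
H_inv = np.linalg.inv(H)

j = 1                             # prune the weight on the correlated input
w_naive = w.copy()
w_naive[j] = 0.0
w_obs = obs_prune_one(w, H_inv, j)
print(output_error(w, w_naive, X), output_error(w, w_obs, X))  # OBS error is much smaller
```

Because the error is exactly quadratic in the weights with curvature proportional to $H$, the OBS update is the minimum-error change that zeroes $w_j$. It can never do worse than plain zeroing. Scaling $H$ (mean versus sum over tokens) leaves the update unchanged.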

The accuracy difference is significant. At 50% sparsity, magnitude pruning can raise WikiText-2 perplexity sharply for LLMs (on some large OPT models it degrades catastrophically). SparseGPT keeps the increase small, and the increase shrinks further as models get larger. The extra compute: Hessian computation from calibration data (128 samples, one forward pass) plus a Cholesky factorization per layer. For the largest LLMs this means hours of preprocessing, but it is a one-time cost. Wanda is a practical middle ground. It scores each weight by $|W_{ij}| \cdot \|X_j\|_2$ (no Hessian and no weight update) and achieves accuracy within 1-2% of SparseGPT at a fraction of the compute.
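A toy numpy comparison of magnitude and Wanda-style scoring. It assumes one input feature with outsized activations (the situation Wanda is designed for) and small weights attached to it. It compares weights within each output row, as Wanda does. The helpers are hypothetical illustrations, not the Wanda reference code.

```python
import numpy as np


def wanda_scores(weight, calib_inputs):
    """Wanda importance |W_ij| * ||X_j||_2, with X_j the j-th input feature over calibration tokens."""
    feature_norms = np.linalg.norm(calib_inputs, axis=0)  # calib_inputs: (tokens, in_features)
    return np.abs(weight) * feature_norms[None, :]        # weight: (out_features, in_features)


def prune_per_output(weight, scores, sparsity=0.5):
    """Zero the lowest-scoring fraction of weights within each output row."""
    k = int(weight.shape[1] * sparsity)
    drop = np.argsort(scores, axis=1)[:, :k]
    mask = np.ones(weight.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=1)
    return weight * mask


rng = np.random.default_rng(0)
X = rng.standard_normal((256, 8))
X[:, 3] *= 50.0   # one outlier input feature, like the activation outliers seen in LLMs
W = rng.standard_normal((4, 8))
W[:, 3] = 0.05    # small weights on that feature, which magnitude pruning tends to discard


def output_error(W_pruned):
    return float(np.linalg.norm(X @ W.T - X @ W_pruned.T))


W_mag = prune_per_output(W, np.abs(W))
W_wanda = prune_per_output(W, wanda_scores(W, X))
print(output_error(W_mag), output_error(W_wanda))  # Wanda's error is far lower here
```

The only difference between the two runs is the score. Multiplying by the input norm keeps small weights that sit on high-energy inputs, which is what the Hessian-based method also accounts for, without inverting anything.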

Q: What compression stack would you use for maximum LLM compression while retaining the best accuracy on NVIDIA A100s?

A: A strong stack for A100 (which supports 2:4 sparsity):

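  1. 2:4 sparsity first: Apply Wanda or SparseGPT with the 2:4 constraint to the FP16 weights. Both methods score (and, for SparseGPT, update) full-precision weights, so pruning normally comes before quantization. SparseGPT can also sparsify and quantize jointly in one pass. This halves the number of stored weight values and enables sparse tensor core acceleration.

  2. Quantization second: Apply AWQ or GPTQ to quantize the surviving weights to INT4, keeping pruned positions at zero. This is the dominant memory reduction: 4× from FP16. Verify perplexity stays within 3% of FP16. Combined with 2:4, this comes to about 2 bits of weight values per original parameter, plus sparsity metadata and quantization scales.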

  3. Calibration data: Use 128-512 samples from the target domain for both AWQ and Wanda/SparseGPT. Domain mismatch in calibration data (e.g., using Wikipedia for a coding model) can add 2-5% accuracy loss.

  4. Recovery fine-tuning (if needed): If combined compression causes >5% task-specific accuracy loss, run 1-3 epochs of fine-tuning with the 2:4 mask held fixed (dequantize → fine-tune, re-applying the mask after each update → re-quantize). This typically recovers 50-70% of the accuracy loss.

  5. Benchmark before shipping: Run task-specific benchmarks (arithmetic, code, retrieval), not just perplexity. Combined compression disproportionately hurts multi-step reasoning.
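Step 4's "mask held fixed" rule can be sketched on a toy problem. A linear layer is pruned to 2:4 by magnitude and then fine-tuned toward the dense layer's outputs, with the mask re-applied after every gradient step. The quantize/dequantize round trip is omitted. The mean-squared-error objective stands in for a real task loss, and all names are hypothetical.

```python
import numpy as np


def two_four_mask(weight):
    """Magnitude-based 2:4 mask: keep the 2 largest |w| in each group of 4 along the input dim."""
    groups = np.abs(weight).reshape(weight.shape[0], -1, 4)
    keep = np.argsort(groups, axis=-1)[..., 2:]
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return mask.reshape(weight.shape)


def masked_finetune(weight, mask, X, Y, steps=200, lr=0.05):
    """Gradient descent on 0.5 * mean ||W x - y||^2 with the sparsity mask held fixed."""
    W = weight * mask
    for _ in range(steps):
        grad = (X @ W.T - Y).T @ X / len(X)
        W = (W - lr * grad) * mask   # re-apply the mask so pruned weights stay exactly zero
    return W


def mse(W, X, Y):
    return float(np.mean((X @ W.T - Y) ** 2))


rng = np.random.default_rng(0)
# Inputs share a common factor, so kept weights can partly compensate for pruned ones.
X = rng.standard_normal((256, 16)) + 0.8 * rng.standard_normal((256, 1))
W_dense = rng.standard_normal((8, 16))
Y = X @ W_dense.T                    # the dense layer's outputs serve as recovery targets

mask = two_four_mask(W_dense)
W_pruned = W_dense * mask
W_recovered = masked_finetune(W_pruned, mask, X, Y)
print(mse(W_pruned, X, Y), mse(W_recovered, X, Y))  # error drops; the 2:4 pattern is unchanged
```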

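Expected result: a 7B model whose weights occupy roughly 2-3 GB, with ~1.7× sparse tensor core speedup on A100 and ~3-5% accuracy loss on standard benchmarks versus full FP16. At 2:4 sparsity, the INT4 values alone come to 7B × 4 bits × 0.5 ≈ 1.75 GB, before sparsity metadata, quantization scales, and any layers left dense.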
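The storage figure can be sanity-checked with plain arithmetic. The sketch assumes every parameter is pruned and quantized identically, ignores quantization scales and zero-points, and borrows the FP16 2:4 format's 2-bit index per kept value as the metadata cost. Actual INT4 sparse kernels may lay out metadata differently. `weight_gb` is a hypothetical helper.

```python
def weight_gb(params, value_bits, density=1.0, meta_bits_per_kept=0.0):
    """Weight storage in GB (1e9 bytes): kept values plus per-value sparsity metadata.

    Ignores quantization scales/zero-points and assumes every parameter is pruned
    and quantized the same way, so real checkpoints come out somewhat larger.
    """
    kept = params * density
    return kept * (value_bits + meta_bits_per_kept) / 8 / 1e9


params = 7e9
print(weight_gb(params, 16))                                    # FP16 dense: 14.0
print(weight_gb(params, 4))                                     # INT4 dense: 3.5
print(weight_gb(params, 4, density=0.5))                        # INT4 at 2:4, values only: 1.75
print(weight_gb(params, 4, density=0.5, meta_bits_per_kept=2))  # plus a 2-bit index per kept value: 2.625
```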

© 2026 EngineersOfAI. All rights reserved.