Structured Pruning: Removing Entire Building Blocks from Transformers
The 380ms Problem That Quantization Couldn't Fix
The inference team at a mid-sized fintech company had just shipped INT4 quantization across their document-processing pipeline. Memory usage was excellent - the model now fit on a single A10G instead of two, cutting GPU costs in half. Generation speed was up 2.1x. The team celebrated. Then the product manager opened the dashboard.
"We're at 380 milliseconds per response. The UX team says anything above 200ms feels laggy to users. Finance is asking whether to kill the feature." The product manager wasn't wrong - at 380ms, the interactive assistant felt sluggish. Users would type a question, watch the cursor blink, watch it blink again, then see text start to appear. The product instinct was correct: 200ms was the threshold between "responsive" and "waiting."
The engineering lead dug into the profiling data. INT4 quantization had optimized the memory bandwidth bottleneck - loading weights from HBM (high-bandwidth memory) was 4x faster. But the model was still performing a full forward pass through all 32 transformer layers on every token generation. Each layer contained 32 attention heads and 11,008 intermediate MLP neurons. The total operation count was unchanged. Quantization made reading weights faster; it did nothing to reduce the number of operations.
What the engineer needed was a different kind of compression - one that reduced the model's computational depth rather than just its memory footprint. They ran a quick analysis: measuring the angular distance (change in hidden state) between the input and output of each of the 32 transformer layers using 200 calibration examples. The result was stark. Eight layers had angular distances below 0.02 - their outputs were nearly identical to their inputs, like a relay station that receives a signal and passes it through unchanged. Four of those layers were contiguous blocks in the middle of the network.
Six hours later, after removing 6 layers and 22% of attention heads with low Taylor importance scores, then running 3 epochs of recovery fine-tuning, the pruned model landed at 204ms. Not quite 200ms - but after re-applying INT4 quantization to the pruned model, the full pipeline reached 187ms. The product manager approved the rollout. The feature shipped.
This is what structured pruning delivers: actual latency improvements through reduced operation counts, not just memory savings. This lesson covers how to do it correctly, from measuring importance scores to combining pruning with quantization in a production pipeline.
Why Structure Matters for Hardware
To understand why structured pruning works when unstructured sparsity doesn't, you need a mental model of how GPUs execute matrix multiplies.
An A100 GPU processes matrix multiplications in tiles. When computing $C = AB$ where $A \in \mathbb{R}^{M \times K}$ and $B \in \mathbb{R}^{K \times N}$, both matrices are partitioned into tiles (typically 64×64 or 128×128) that are loaded from HBM into SRAM (the fast on-chip memory), then processed on Tensor Cores. Tensor Cores execute fixed-size matrix-multiply-accumulate instructions over each tile, so a tile costs the same number of instructions whatever values it holds. The GPU does not inspect individual elements to decide whether to skip a multiply. It executes the whole tile regardless of whether some values are zero.
This is why unstructured sparsity fails to deliver latency improvements. Set 50% of a weight matrix's values to zero, and the GPU still:
- Loads the full matrix from HBM (same memory bandwidth)
- Partitions it into the same tiles (same tile count)
- Executes the same number of Tensor Core instructions (same compute)
- The zeros just produce 0 outputs - contributing nothing, but costing the same compute time
Structured pruning changes tensor shapes. Remove 25% of attention heads in a transformer layer: the Q, K, V projection matrices shrink from $d_{\text{model}} \times d_{\text{model}}$ to $0.75\,d_{\text{model}} \times d_{\text{model}}$. The GPU now multiplies genuinely smaller matrices. Fewer tiles. Fewer Tensor Core instructions. Fewer bytes loaded from HBM. The speedup is real and shows up in profiling.
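To make the shape argument concrete, here is a minimal sketch that counts FLOPs for a single Q projection. The helper name `matmul_flops` and the sizes (128 tokens, hidden size 4096, 32 heads) are illustrative assumptions, and the count models a dense kernel that does not skip zeros.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for a dense (m x k) @ (k x n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


# Illustrative sizes (assumptions, not measurements)
seq_len, d_model, n_heads = 128, 4096, 32
head_dim = d_model // n_heads

# Dense Q projection: (seq_len x d_model) @ (d_model x d_model)
dense = matmul_flops(seq_len, d_model, d_model)

# Unstructured 50% sparsity: half the weights are zero, but the tensor shape
# is unchanged, so a dense kernel performs exactly the same work.
unstructured = matmul_flops(seq_len, d_model, d_model)

# Structured pruning: dropping 25% of heads shrinks the output dimension.
kept_heads = n_heads - n_heads // 4
structured = matmul_flops(seq_len, d_model, kept_heads * head_dim)

print(f"dense:        {dense:,}")
print(f"unstructured: {unstructured:,} ({unstructured / dense:.2f}x of dense)")
print(f"structured:   {structured:,} ({structured / dense:.2f}x of dense)")
```

Only the structured case changes the operation count; the unstructured case would need a sparse kernel or a hardware-supported pattern to benefit.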
The one exception worth knowing: NVIDIA's 2:4 structured sparsity format (at most 2 non-zero values in every group of 4 consecutive elements) has dedicated hardware support on Ampere and later GPUs. Sparse Tensor Cores skip the zeroed multiplications in this specific pattern, delivering roughly 1.5x speedup in practice. But this is a constrained pattern that requires the network to be specifically trained or fine-tuned into this sparsity structure - it's not the same as general unstructured pruning.
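As a rough illustration of what the 2:4 constraint means for a weight row, the sketch below zeroes the two smallest-magnitude values in each group of four. Picking by magnitude is an assumption made here for simplicity (a common heuristic); real deployments use NVIDIA's tooling and fine-tune afterwards, and the helper names are hypothetical.

```python
from typing import List


def apply_2_4_mask(row: List[float]) -> List[float]:
    """Zero the two smallest-magnitude values in every group of 4 consecutive elements."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = list(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # Indices of the two smallest |w| in this group are dropped
        drop = sorted(group, key=lambda i: abs(row[i]))[:2]
        for i in drop:
            out[i] = 0.0
    return out


def satisfies_2_4(row: List[float]) -> bool:
    """True if every group of 4 has at most 2 non-zero entries."""
    return all(
        sum(1 for v in row[s:s + 4] if v != 0.0) <= 2
        for s in range(0, len(row), 4)
    )


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
pruned = apply_2_4_mask(weights)
print(pruned)
print(satisfies_2_4(weights), satisfies_2_4(pruned))
```

The tensor keeps its shape; the gain comes from the hardware storing and multiplying only the two surviving values per group.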
The Three Levels of Structured Pruning
Structured pruning operates at three granularities, each with different accuracy/latency tradeoffs:
| Level | What is removed | Typical fraction removed | Latency improvement | Accuracy cost |
|---|---|---|---|---|
| Attention head | Individual heads from multi-head attention | 20-35% of heads | 1.15-1.4x | 0.5-2% |
| MLP neuron | Intermediate neurons in feed-forward layers | 20-40% of neurons | 1.2-1.6x | 1-3% |
| Transformer layer | Entire encoder/decoder blocks | 15-25% of layers | 1.2-1.8x | 1-4% |
| Combined | All three simultaneously | varies | 2-4x | 3-7% |
These techniques are complementary, and their savings compound multiplicatively. A model with 20% of layers removed and 25% of heads removed in the remaining layers retains roughly $0.80 \times 0.75 = 0.60$ of the original attention FLOPs before any quantization. Combined with INT4 quantization, total speedup can reach 4-6x over the original FP16 baseline.
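The compounding can be sketched with a small estimator. This is a back-of-the-envelope model, not a profiler: the function name is hypothetical, `attn_share` (the share of a block's FLOPs spent in attention, set to one third here) is an assumption that roughly fits a block whose MLP is four times the hidden size, and embeddings and the output head are ignored.

```python
def remaining_flop_fraction(
    layers_removed: float,
    heads_removed: float,
    neurons_removed: float,
    attn_share: float = 1 / 3,
) -> float:
    """Estimated fraction of original per-token FLOPs left after combined pruning.

    attn_share is the assumed share of one block's FLOPs spent in attention;
    the remainder is attributed to the MLP.
    """
    attn = attn_share * (1.0 - heads_removed)
    mlp = (1.0 - attn_share) * (1.0 - neurons_removed)
    # Layer removal scales every surviving block's cost by the same factor
    return (1.0 - layers_removed) * (attn + mlp)


# 20% of layers and 25% of heads removed, MLP untouched
frac = remaining_flop_fraction(0.20, 0.25, 0.0)
print(f"whole block:    {frac:.3f} of original FLOPs (ideal speedup {1 / frac:.2f}x)")

# Attention sub-blocks only: 0.80 * 0.75
attn_only = remaining_flop_fraction(0.20, 0.25, 0.0, attn_share=1.0)
print(f"attention only: {attn_only:.2f} of original FLOPs")
```

The gap between the two numbers is why head pruning alone gives modest end-to-end gains: the MLP share of the block is untouched until neurons are pruned as well.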
:::info Why Prune Before Quantize?
The correct order is always prune first, then quantize - not the other way around. Quantization algorithms like AWQ and GPTQ optimize weight scaling for a fixed model structure. If you quantize first, the INT4 weights are packed in a format that cannot be sliced to remove heads or neurons. If you prune first, the pruned model is a normal FP16 model that can be quantized using any standard tool. The exception: if using bitsandbytes for serving (not GPTQ/AWQ export), you can load in 4-bit and still measure importance scores - but recovery fine-tuning requires dequantization first.
:::
Attention Head Pruning: Theory and Implementation
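Multi-head attention computes $h$ independent attention patterns, then concatenates and projects them:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

where each $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.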
Research beginning with Michel et al. (2019) showed that in BERT-base, a large fraction of attention heads can be removed at inference time without meaningful accuracy loss on most tasks. Voita et al. (2019) found that in a Transformer translation encoder, the heads that matter most fall into a handful of interpretable categories (positional, syntactic dependency, rare word attention), while most of the remaining heads can be pruned with little loss in quality. The redundancy is real and has been observed across several model families.
Method 1: Taylor Expansion Importance
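The Taylor expansion head importance estimates how much removing a head would increase the training loss. For head $h$ in layer $l$, the importance is:

$$I_h^{(l)} = \mathbb{E}_{x \sim \mathcal{D}} \left[ \sum_{w \in \theta_h^{(l)}} \left| \frac{\partial \mathcal{L}(x)}{\partial w} \, w \right| \right]$$

where $\theta_h^{(l)}$ is the set of weights belonging to that head and $\mathcal{L}$ is the training loss.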
This is a first-order approximation: if the gradient-weight product is large, changing those weights (or zeroing them out) will have a large effect on the loss. Heads with small gradient-weight products are likely redundant.
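Before the full transformer version, a toy check can show what the score measures. The sketch below is an illustration under stated assumptions, not the method's reference implementation: it uses a three-weight linear model with squared-error loss, hand-picked data, and hypothetical helper names, and compares the Taylor score $|g \cdot w|$ with the true loss change from zeroing each weight.

```python
from typing import List, Sequence, Tuple

Sample = Tuple[Sequence[float], float]


def sample_loss_and_grad(w: Sequence[float], x: Sequence[float], y: float):
    """Squared-error loss of a linear model on one sample, plus dL/dw."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return 0.5 * err * err, [err * xi for xi in x]


def taylor_importance(w: Sequence[float], data: List[Sample]) -> List[float]:
    """Mean over samples of |dL/dw_i * w_i| for each weight."""
    scores = [0.0] * len(w)
    for x, y in data:
        _, grad = sample_loss_and_grad(w, x, y)
        for i, (g, wi) in enumerate(zip(grad, w)):
            scores[i] += abs(g * wi) / len(data)
    return scores


def mean_loss(w: Sequence[float], data: List[Sample]) -> float:
    return sum(sample_loss_and_grad(w, x, y)[0] for x, y in data) / len(data)


def ablation_delta(w: Sequence[float], data: List[Sample]) -> List[float]:
    """True change in mean loss when each weight is zeroed in turn."""
    base = mean_loss(w, data)
    deltas = []
    for i in range(len(w)):
        ablated = list(w)
        ablated[i] = 0.0
        deltas.append(mean_loss(ablated, data) - base)
    return deltas


w = [1.0, 0.02, -0.5]
data = [([1.0, 1.0, 0.0], 1.0), ([0.0, 1.0, 1.0], -0.5)]

scores = taylor_importance(w, data)
deltas = ablation_delta(w, data)
for i, (s, d) in enumerate(zip(scores, deltas)):
    print(f"w[{i}]: taylor={s:.4f}  true loss change={d:+.4f}")
```

In this toy case both measures rank the weights the same way, with the small middle weight least important. The magnitudes differ because the score is only a first-order estimate, so it is best read as a ranking signal rather than a predicted loss increase.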
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple
import numpy as np
from dataclasses import dataclass, field


@dataclass
class PruningConfig:
    """Configuration for structured pruning pipeline."""
    head_pruning_fraction: float = 0.25    # Remove 25% of heads globally
    layer_pruning_fraction: float = 0.20   # Remove 20% of layers
    neuron_pruning_fraction: float = 0.30  # Remove 30% of MLP neurons
    protect_first_n_layers: int = 2        # Never prune first N layers
    protect_last_n_layers: int = 2         # Never prune last N layers
    min_heads_per_layer: int = 2           # Always keep at least N heads per layer
    n_calibration_batches: int = 100       # Batches for importance scoring
    importance_method: str = "taylor"      # "taylor", "magnitude", or "output_norm"
    recovery_epochs: int = 3               # Epochs of fine-tuning after pruning
    recovery_lr: float = 1e-5              # Lower than original training LR
```
```python
def compute_head_importance_taylor(
    model: nn.Module,
    dataloader,
    n_batches: int = 100,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Taylor expansion head importance: E[|gradient * weight|].

    For each attention head, the importance is the expected absolute value
    of the product of gradients and weights - a first-order approximation
    of how much the loss would increase if this head were removed.

    High importance score = removing this head would hurt accuracy significantly.
    Low importance score = this head contributes little and can be safely removed.

    Args:
        model: BERT-style transformer with model.encoder.layer
        dataloader: Yields dicts with 'input_ids', 'attention_mask', 'labels'
        n_batches: Number of batches for Monte Carlo estimation of E[|g*w|]
        device: Compute device

    Returns:
        importance: FloatTensor of shape (n_layers, n_heads)
    """
    # eval() turns off dropout so the scores are not noisy. Gradients still flow
    # because the forward pass is not wrapped in torch.no_grad().
    model.eval()
    model.to(device)

    n_layers = model.config.num_hidden_layers
    n_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // n_heads

    importance_scores = torch.zeros(n_layers, n_heads, device=device)
    n_batches_processed = 0

    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= n_batches:
            break
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

        # Forward and backward pass to populate .grad on the weights
        model.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        # Extract gradient-weight products from query projections
        for layer_idx, layer in enumerate(model.encoder.layer):
            q_weight = layer.attention.self.query.weight  # (hidden_size, hidden_size)
            if q_weight.grad is None:
                continue
            # Reshape to (n_heads, head_dim, hidden_size) for per-head analysis
            grad = q_weight.grad.view(n_heads, head_dim, -1)
            weight = q_weight.data.view(n_heads, head_dim, -1)
            # Taylor importance per head: sum |grad * weight| over all parameters in the head
            head_importance = (grad * weight).abs().sum(dim=[1, 2])  # (n_heads,)
            importance_scores[layer_idx] += head_importance.detach()

        n_batches_processed += 1
        if (batch_idx + 1) % 20 == 0:
            print(f"  Taylor importance: {batch_idx + 1}/{n_batches} batches")

    model.zero_grad()  # Drop gradients left over from the last batch
    return importance_scores / max(n_batches_processed, 1)
```
Method 2: Magnitude-Based Importance
Simpler than Taylor, but surprisingly competitive for moderate pruning targets. The importance of a head is the L2 norm of its query projection weights - larger-norm heads typically encode more information:
```python
def compute_head_importance_magnitude(
    model: nn.Module,
) -> torch.Tensor:
    """
    Magnitude-based head importance: L2 norm of query weights per head.

    No calibration data or gradients needed - purely weight-based.
    Fast to compute, reasonable correlation with task importance for
    moderate pruning (under 30%). Can underperform Taylor at high sparsity
    because weight magnitude doesn't capture activation statistics.

    Returns:
        importance: FloatTensor of shape (n_layers, n_heads)
    """
    n_layers = model.config.num_hidden_layers
    n_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // n_heads

    importance = torch.zeros(n_layers, n_heads)
    for layer_idx, layer in enumerate(model.encoder.layer):
        q_weight = layer.attention.self.query.weight.data  # (hidden_size, hidden_size)
        # Reshape to separate heads
        q_heads = q_weight.view(n_heads, head_dim, -1)
        # L2 norm per head over all weight elements
        head_norms = q_heads.norm(dim=[1, 2])
        importance[layer_idx] = head_norms.cpu()
    return importance
```
```python
def compute_head_importance_output_norm(
    model: nn.Module,
    dataloader,
    n_batches: int = 50,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Output-norm head importance: E[||head_output||_2].

    Heads with consistently small output norms contribute little to the next
    layer's residual stream. This is the activation-space version of magnitude:
    rather than measuring weight norms, we measure output norms at runtime.

    Requires forward hooks into the attention mechanism to capture per-head outputs.
    The hook captures the attention output BEFORE the output projection W^O.

    Returns:
        importance: FloatTensor of shape (n_layers, n_heads)
    """
    model.eval()
    model.to(device)  # Batches are moved to `device` below, so the model must be there too

    n_layers = model.config.num_hidden_layers
    n_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // n_heads

    importance = torch.zeros(n_layers, n_heads, device=device)
    hooks = []
    n_samples = 0

    def make_head_output_hook(layer_idx: int):
        def hook(module, inputs, output):
            # For BERT's BertSelfAttention, output is (context_layer, attention_probs_or_None)
            # context_layer shape: (batch, seq_len, n_heads * head_dim)
            attn_output = output[0] if isinstance(output, tuple) else output
            batch, seq_len, hidden = attn_output.shape
            # Reshape to per-head: (batch, seq_len, n_heads, head_dim)
            per_head = attn_output.view(batch, seq_len, n_heads, head_dim)
            # L2 norm over head_dim, averaged over batch and sequence
            head_norms = per_head.norm(dim=-1).mean(dim=[0, 1])  # (n_heads,)
            importance[layer_idx] += head_norms.detach()
        return hook

    for layer_idx, layer in enumerate(model.encoder.layer):
        h = layer.attention.self.register_forward_hook(make_head_output_hook(layer_idx))
        hooks.append(h)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= n_batches:
                break
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            model(**batch)
            n_samples += 1

    for h in hooks:
        h.remove()

    return importance / max(n_samples, 1)
```
Selecting and Applying Head Pruning
Once you have importance scores, you need a selection strategy that balances global and per-layer constraints:
```python
def select_heads_to_prune(
    importance_scores: torch.Tensor,
    pruning_fraction: float = 0.25,
    protect_layers: Optional[List[int]] = None,
    min_heads_per_layer: int = 2,
) -> Dict[int, List[int]]:
    """
    Select which attention heads to prune using global importance ranking.

    Uses a global threshold rather than per-layer: this finds the globally
    least important heads across ALL layers, which is more effective than
    pruning the bottom N% from each layer independently.
    Per-layer pruning can over-prune layers that are all important,
    and under-prune layers with many redundant heads.

    Args:
        importance_scores: FloatTensor of shape (n_layers, n_heads)
        pruning_fraction: Fraction of ALL heads to remove globally
        protect_layers: List of layer indices to never prune (e.g., [0, 1, -1, -2])
        min_heads_per_layer: Minimum heads that must remain in any layer

    Returns:
        Dict mapping layer_idx -> sorted list of head indices to prune
    """
    n_layers, n_heads = importance_scores.shape

    # Build protection mask
    protected = torch.zeros(n_layers, n_heads, dtype=torch.bool)
    if protect_layers:
        for l in protect_layers:
            l = l % n_layers  # Handle negative indices
            protected[l, :] = True

    # Clone scores; mark protected positions as infinity (never chosen for pruning)
    scores_masked = importance_scores.clone().float()
    scores_masked[protected] = float('inf')

    # Enforce min_heads_per_layer: protect the top-K heads in each layer
    for layer_idx in range(n_layers):
        layer_scores = scores_masked[layer_idx]
        # Protect the top min_heads_per_layer by setting them to inf
        top_k = torch.topk(layer_scores, min_heads_per_layer).indices
        scores_masked[layer_idx, top_k] = float('inf')

    # Global ranking: find the n_to_prune globally lowest-importance heads
    flat_scores = scores_masked.reshape(-1)
    n_total = flat_scores.shape[0]
    n_to_prune = int(n_total * pruning_fraction)
    if n_to_prune == 0:
        print("Warning: 0 heads selected for pruning (fraction too small or all protected)")
        return {}

    sorted_scores, sorted_indices = torch.sort(flat_scores)
    # Find the threshold at position n_to_prune
    threshold = sorted_scores[n_to_prune - 1].item()

    # Convert flat indices back to (layer, head) coordinates
    heads_to_prune: Dict[int, List[int]] = {}
    for layer_idx in range(n_layers):
        heads = []
        for head_idx in range(n_heads):
            score = scores_masked[layer_idx, head_idx].item()
            if score <= threshold and score < float('inf'):
                heads.append(head_idx)
        if heads:
            heads_to_prune[layer_idx] = sorted(heads)

    total_pruned = sum(len(h) for h in heads_to_prune.values())
    total_heads = n_layers * n_heads
    print(f"Head pruning plan: {total_pruned}/{total_heads} heads "
          f"({total_pruned/total_heads*100:.1f}%) across {len(heads_to_prune)} layers")
    return heads_to_prune
```
```python
def prune_attention_heads(
    model: nn.Module,
    heads_to_prune: Dict[int, List[int]],
) -> None:
    """
    Remove attention heads from the model in-place by slicing weight matrices.

    For each head removed, we slice rows from Q, K, V projections
    and the corresponding columns from the output projection W^O.
    The operation is irreversible - pruned weights are deleted.
    Always checkpoint the original model before calling this function.

    Mathematical view:
        Before: [head_0 | head_1 | ... | head_h] @ W^O
        After:  [head_0 | ... | head_{h-k}] @ W^O_pruned
        (W^O columns corresponding to removed heads are also removed)

    Args:
        model: BERT-style transformer
        heads_to_prune: Output of select_heads_to_prune()
    """
    n_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // n_heads

    for layer_idx, head_indices in sorted(heads_to_prune.items()):
        layer = model.encoder.layer[layer_idx]
        attn_self = layer.attention.self
        heads_to_keep = sorted(set(range(n_heads)) - set(head_indices))
        if not heads_to_keep:
            print(f"  Layer {layer_idx}: Would prune ALL heads - skipping (bug in selection logic)")
            continue

        # Compute which rows/columns in the weight matrices correspond to kept heads
        # Q, K, V weights: (hidden_size, hidden_size) = (n_heads*head_dim, hidden_size)
        # Rows 0..head_dim-1 belong to head 0, rows head_dim..2*head_dim-1 to head 1, etc.
        kept_dims = []
        for h in heads_to_keep:
            kept_dims.extend(range(h * head_dim, (h + 1) * head_dim))
        kept_tensor = torch.tensor(kept_dims, dtype=torch.long)

        with torch.no_grad():
            # Prune Q, K, V: remove rows corresponding to pruned heads
            for proj_name in ["query", "key", "value"]:
                proj = getattr(attn_self, proj_name)
                new_weight = proj.weight.data[kept_tensor, :]
                proj.weight = nn.Parameter(new_weight)
                if proj.bias is not None:
                    proj.bias = nn.Parameter(proj.bias.data[kept_tensor])
                proj.out_features = len(kept_dims)

            # Prune output projection W^O: remove columns for pruned heads
            out_proj = layer.attention.output.dense
            new_weight = out_proj.weight.data[:, kept_tensor]
            out_proj.weight = nn.Parameter(new_weight)
            out_proj.in_features = len(kept_dims)

        # The attention module reshapes Q/K/V using its own head bookkeeping, so
        # it must match the sliced weights or the forward pass will fail.
        # (Attribute names as in HuggingFace's BertSelfAttention.)
        attn_self.num_attention_heads = len(heads_to_keep)
        attn_self.all_head_size = len(kept_dims)

        print(f"  Layer {layer_idx}: {n_heads} -> {len(heads_to_keep)} heads "
              f"(removed: {head_indices})")

    # Update the global head count (simplified - assumes uniform pruning across layers)
    # In production: track per-layer head counts separately
    remaining_head_counts = []
    for layer_idx, layer in enumerate(model.encoder.layer):
        q_out_features = layer.attention.self.query.out_features
        remaining_heads = q_out_features // head_dim
        remaining_head_counts.append(remaining_heads)

    min_heads = min(remaining_head_counts)
    model.config.num_attention_heads = min_heads
    print(f"\nAttention heads updated: min={min_heads}, counts per layer: {remaining_head_counts}")
```
:::warning The Config Update Problem
When pruning attention heads, model.config.num_attention_heads needs to be updated. But if you prune different numbers of heads per layer (which global pruning produces), a single config value cannot accurately represent the model. This is a known pain point. Options: (1) store per-layer head counts in a custom config attribute, (2) only do uniform pruning (same fraction per layer), or (3) use HuggingFace's built-in prune_heads() method which handles this correctly for supported model architectures.
:::
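Option (1) can be sketched without any framework. The helper and the `heads_per_layer` key below are hypothetical names chosen for illustration, not part of any library's config schema. The idea is simply to record one head count per layer next to the fixed head dimension, so the expected projection shapes can be rebuilt at load time.

```python
from typing import Dict, List


def per_layer_head_counts(
    n_layers: int,
    n_heads: int,
    heads_to_prune: Dict[int, List[int]],
) -> List[int]:
    """Heads remaining in each layer after applying a pruning plan."""
    return [
        n_heads - len(set(heads_to_prune.get(layer_idx, [])))
        for layer_idx in range(n_layers)
    ]


# Hypothetical plan for a 6-layer, 12-head model with head_dim = 64
plan = {3: [1, 5, 7], 4: [0, 2]}
head_dim = 64
counts = per_layer_head_counts(n_layers=6, n_heads=12, heads_to_prune=plan)

# Extra config fields (names are assumptions made for this sketch)
config_extra = {"heads_per_layer": counts, "head_dim": head_dim}

# Expected number of rows in each layer's Q/K/V projection after pruning
expected_qkv_rows = [c * head_dim for c in counts]
print(config_extra)
print(expected_qkv_rows)
```

Storing `head_dim` explicitly matters because it can no longer be derived as `hidden_size // num_attention_heads` once layers disagree on head count.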
Layer Pruning: The Highest-Impact Technique
Layer pruning removes entire transformer blocks. It is the most aggressive form of structured pruning and also the most hardware-efficient - fewer sequential operations means direct proportional latency reduction, and sequential layers cannot be parallelized the way intra-layer computations sometimes can.
The Angular Distance Metric (ShortGPT)
The key insight from the ShortGPT paper (Men et al., 2024): measure how much each layer changes its input. If a layer's output is nearly identical to its input - if the hidden state barely moves - the layer is doing almost no useful work.
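Formally, the Block Influence (BI) score is:

$$\text{BI}_l = 1 - \mathbb{E}_{x \sim \mathcal{D}} \left[ \frac{h_{\text{in}}^{(l)} \cdot h_{\text{out}}^{(l)}}{\lVert h_{\text{in}}^{(l)} \rVert_2 \, \lVert h_{\text{out}}^{(l)} \rVert_2} \right]$$

where $h_{\text{in}}^{(l)}$ and $h_{\text{out}}^{(l)}$ are the flattened input and output hidden states of layer $l$. Strictly speaking this is a cosine distance; this lesson uses "angular distance" loosely for the same quantity. A score near 0 means the layer barely modifies its input (low importance). A large score means the layer significantly transforms the representation (high importance).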
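A tiny worked example may help calibrate what these scores look like. The sketch below computes one minus the cosine similarity between a layer's input and output for hand-picked three-dimensional vectors; the vectors and helper names are illustrative assumptions, and the 0.02 cutoff is the one used in the opening story.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def block_influence(h_in: Sequence[float], h_out: Sequence[float]) -> float:
    """1 - cosine similarity between a layer's input and output hidden state."""
    return 1.0 - cosine_similarity(h_in, h_out)


h_in = [1.0, 2.0, 2.0]
near_identity_out = [1.1, 2.0, 1.9]   # layer barely moved the hidden state
rotated_out = [2.0, -1.0, 0.0]        # orthogonal to the input

bi_redundant = block_influence(h_in, near_identity_out)
bi_important = block_influence(h_in, rotated_out)
print(f"near-identity layer: BI = {bi_redundant:.4f}")
print(f"transforming layer:  BI = {bi_important:.4f}")
```

Because the score ignores vector length, a layer that only rescales its input also scores near zero, which is one reason to validate a pruning plan with an end-to-end evaluation as well.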
```python
def compute_layer_angular_importance(
    model: nn.Module,
    dataloader,
    n_batches: int = 100,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute layer Block Influence (BI) scores via angular distance.

    From ShortGPT (Men et al., 2024): layers with low BI scores barely change
    their input representations and can be removed with minimal accuracy impact.
    ShortGPT reports removing roughly a quarter of Llama-2-13B's layers while
    retaining most of its benchmark accuracy.

    Returns:
        importance: FloatTensor of shape (n_layers,)
            Higher = more important (do NOT prune)
            Lower = likely redundant (prune these)
    """
    model.eval()
    model.to(device)
    n_layers = model.config.num_hidden_layers

    layer_inputs: Dict[int, torch.Tensor] = {}
    layer_outputs: Dict[int, torch.Tensor] = {}
    hooks = []

    def make_input_hook(layer_idx: int):
        # Forward pre-hooks are called as hook(module, args) unless registered
        # with with_kwargs=True, so the hook takes exactly two parameters.
        def hook(module, args):
            # args[0] is the hidden states passed to the layer
            if args:
                layer_inputs[layer_idx] = args[0].detach().clone()
        return hook

    def make_output_hook(layer_idx: int):
        def hook(module, args, output):
            # BERT layers return (hidden_states, ...) tuple or just hidden_states
            if isinstance(output, tuple):
                layer_outputs[layer_idx] = output[0].detach().clone()
            else:
                layer_outputs[layer_idx] = output.detach().clone()
        return hook

    # Register hooks on each encoder layer
    for layer_idx, layer in enumerate(model.encoder.layer):
        hooks.append(layer.register_forward_pre_hook(make_input_hook(layer_idx)))
        hooks.append(layer.register_forward_hook(make_output_hook(layer_idx)))

    angular_distances = torch.zeros(n_layers)
    n_samples_processed = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= n_batches:
                break
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            model(**batch)

            for layer_idx in range(n_layers):
                if layer_idx not in layer_inputs or layer_idx not in layer_outputs:
                    continue
                inp = layer_inputs[layer_idx].float()   # (batch, seq_len, hidden_size)
                out = layer_outputs[layer_idx].float()  # (batch, seq_len, hidden_size)
                # Flatten sequence and hidden dimensions for cosine similarity computation
                inp_flat = inp.reshape(inp.shape[0], -1)  # (batch, seq_len * hidden_size)
                out_flat = out.reshape(out.shape[0], -1)
                # Cosine similarity between input and output per sample in batch
                cos_sim = F.cosine_similarity(inp_flat, out_flat, dim=1)  # (batch,)
                # BI score: 1 - cos_sim (near 0 = barely changed, near 1 = major change)
                angular_dist = (1.0 - cos_sim).mean().item()
                angular_distances[layer_idx] += angular_dist
            n_samples_processed += 1

    for hook in hooks:
        hook.remove()

    result = angular_distances / max(n_samples_processed, 1)

    # Print sorted importances for inspection
    sorted_idx = torch.argsort(result)
    print("\nLayer angular distances (ascending - most prunable first):")
    for i, idx in enumerate(sorted_idx[:8]):
        print(f"  Layer {idx.item():3d}: BI = {result[idx.item()]:.4f}")
    return result
```
```python
def select_layers_to_prune(
    angular_importance: torch.Tensor,
    pruning_fraction: float = 0.20,
    protect_first_n: int = 2,
    protect_last_n: int = 2,
) -> List[int]:
    """
    Select layers to prune based on angular distance importance scores.

    Protects the first and last N layers unconditionally - these have
    repeatedly been found to be among the most important layers.
    First layers: contextualize token embeddings (typically high angular distance)
    Last layers: prepare representations for output heads (task-critical)

    Args:
        angular_importance: FloatTensor of shape (n_layers,)
        pruning_fraction: Fraction of total layers to remove
        protect_first_n: Number of initial layers to never prune (recommended: 2)
        protect_last_n: Number of final layers to never prune (recommended: 2)

    Returns:
        Sorted list of layer indices to remove
    """
    n_layers = angular_importance.shape[0]

    # Mark protected layers
    protected = torch.zeros(n_layers, dtype=torch.bool)
    protected[:protect_first_n] = True
    protected[n_layers - protect_last_n:] = True

    # Never request more layers than are actually eligible, otherwise the
    # selection below would spill over into protected layers.
    n_eligible = int((~protected).sum().item())
    n_to_prune = min(int(n_layers * pruning_fraction), n_eligible)

    # Clone and mask protected layers so they won't be selected
    masked_importance = angular_importance.clone()
    masked_importance[protected] = float('inf')  # Effectively exclude from selection

    # Select layers with the lowest angular distance (most redundant)
    _, sorted_indices = torch.sort(masked_importance)
    layers_to_prune = sorted(sorted_indices[:n_to_prune].tolist())

    print(f"\nLayer pruning plan: removing {len(layers_to_prune)}/{n_layers} layers")
    print(f"Layers to remove: {layers_to_prune}")
    print(f"Their BI scores: {[f'{angular_importance[i]:.4f}' for i in layers_to_prune]}")
    print(f"Protected (first {protect_first_n}, last {protect_last_n}): not pruned")
    return layers_to_prune
```
```python
def prune_layers(
    model: nn.Module,
    layer_indices_to_remove: List[int],
) -> None:
    """
    Remove entire transformer blocks from the model.

    This is irreversible - always save the original model first.
    After removal, layer indices shift: if you remove layer 5, the old
    layer 6 becomes the new layer 5. The function handles this correctly
    by removing in reverse index order.

    Args:
        model: Transformer model with model.encoder.layer (ModuleList)
        layer_indices_to_remove: List of layer indices to delete
    """
    if not layer_indices_to_remove:
        return

    n_before = len(model.encoder.layer)
    # Remove in reverse index order to preserve validity of remaining indices
    for idx in sorted(layer_indices_to_remove, reverse=True):
        if idx >= len(model.encoder.layer):
            print(f"  Warning: layer index {idx} out of range - skipping")
            continue
        del model.encoder.layer[idx]

    n_after = len(model.encoder.layer)
    model.config.num_hidden_layers = n_after
    removed = n_before - n_after
    print(f"\nLayer pruning applied: {n_before} -> {n_after} layers "
          f"(removed {removed}: indices {sorted(layer_indices_to_remove)})")
```
:::tip The ShortGPT Finding on LLMs
In Llama-2-13B (40 layers), ShortGPT found that the lowest BI scores cluster in a contiguous band of layers away from both ends of the network, while the earliest layers and the final layers score noticeably higher. This is consistent with mechanistic interpretability research suggesting that early layers build syntactic structure, the last layers build task-specific representations, and layers in between often perform less clearly interpretable, partly redundant transformations. When in doubt about which layers to prune first, measure BI on your own calibration data; absent measurements, start with layers away from the first and last blocks and work outward.
:::
MLP Neuron Pruning: Targeting Dead Neurons
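Feed-forward networks in transformers use a two-matrix structure:

$$\text{FFN}(x) = W_2 \, \text{GELU}(W_1 x + b_1) + b_2$$

where $W_1 \in \mathbb{R}^{d_{ff} \times d_{\text{model}}}$ and $W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{ff}}$. The intermediate dimension $d_{ff} = 4\,d_{\text{model}}$ for most architectures (BERT: 3072 for base, 4096 for large; Llama: uses a gated architecture with two up-projection matrices).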
Many of the intermediate neurons are "dead" - they output values near zero for the vast majority of inputs because the GELU activation suppresses negative pre-activations to near-zero. Neurons that are consistently near-zero contribute negligibly to the output and can be removed.
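The sketch below shows how an activation-frequency measurement separates dead neurons from live ones on a toy scale. It is a pure-Python illustration with made-up pre-activation values and a hypothetical helper name; the 0.01 threshold is an assumed default.

```python
import math
from typing import List


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def activation_frequency(pre_acts: List[List[float]], threshold: float = 0.01) -> List[float]:
    """Fraction of tokens on which each neuron's |GELU output| exceeds threshold.

    pre_acts[t][j] is neuron j's pre-activation on token t.
    """
    n_tokens = len(pre_acts)
    counts = [0] * len(pre_acts[0])
    for row in pre_acts:
        for j, z in enumerate(row):
            if abs(gelu(z)) > threshold:
                counts[j] += 1
    return [c / n_tokens for c in counts]


# Rows are tokens, columns are neurons (illustrative values)
pre_acts = [
    [1.5, -6.0, 0.3],
    [0.8, -5.0, -7.0],
    [2.0, -8.0, 1.2],
    [-0.5, -6.5, -6.0],
]
freqs = activation_frequency(pre_acts)
print(freqs)
```

Neuron 1 is dead on every token, neuron 2 fires half the time, and neuron 0 counts as active even at a pre-activation of -0.5, because GELU only suppresses strongly negative inputs.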
```python
def compute_neuron_activation_frequency(
    model: nn.Module,
    dataloader,
    n_batches: int = 100,
    activation_threshold: float = 0.01,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Measure how often each MLP neuron produces output above a threshold.

    Neurons with activation frequency below 1-2% are strong candidates
    for removal - they contribute negligible expected output.

    Uses forward hooks to capture intermediate layer outputs (post-activation,
    after GELU/SiLU). The hook measures the fraction of tokens where
    |neuron_output| > activation_threshold.

    Args:
        model: BERT-style transformer
        dataloader: Calibration data
        n_batches: Number of batches for frequency estimation
        activation_threshold: Neuron considered "active" if |output| > this
        device: Compute device

    Returns:
        activation_freq: FloatTensor of shape (n_layers, intermediate_size)
            Values in [0, 1] - fraction of tokens each neuron was active
    """
    model.eval()
    model.to(device)
    n_layers = model.config.num_hidden_layers
    d_ff = model.config.intermediate_size

    activation_counts = torch.zeros(n_layers, d_ff)
    n_tokens_total = 0
    hooks = []

    def make_intermediate_hook(layer_idx: int):
        def hook(module, inputs, output):
            # output shape: (batch, seq_len, intermediate_size) - post-activation
            activated = (output.detach().abs() > activation_threshold).float()
            # Reduce on the device first, then copy only a (d_ff,) vector to the CPU
            activation_counts[layer_idx] += activated.sum(dim=[0, 1]).cpu()
        return hook

    # For BERT: layer.intermediate contains the linear + activation
    for layer_idx, layer in enumerate(model.encoder.layer):
        h = layer.intermediate.register_forward_hook(make_intermediate_hook(layer_idx))
        hooks.append(h)

    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= n_batches:
                break
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            model(**batch)
            # Note: this counts padding positions too, which slightly skews frequencies
            n_tokens_total += batch["input_ids"].numel()

    for h in hooks:
        h.remove()

    activation_freq = activation_counts / max(n_tokens_total, 1)

    # Print summary statistics
    mean_freq = activation_freq.mean().item()
    dead_fraction = (activation_freq < 0.01).float().mean().item()
    print(f"\nNeuron activation analysis:")
    print(f"  Mean activation frequency: {mean_freq:.3f}")
    print(f"  Neurons active <1% of time: {dead_fraction*100:.1f}%")
    return activation_freq
```
```python
def select_neurons_to_prune(
    activation_freq: torch.Tensor,
    pruning_fraction: float = 0.30,
) -> Dict[int, List[int]]:
    """
    Select MLP neurons to prune based on activation frequency.

    Globally ranks neurons by activation frequency and removes the
    least-active fraction. Dead neurons (< 1% activation frequency)
    are always good candidates.

    Args:
        activation_freq: FloatTensor (n_layers, intermediate_size)
        pruning_fraction: Fraction of all neurons to remove

    Returns:
        Dict mapping layer_idx -> sorted list of neuron indices to prune
    """
    n_layers, d_ff = activation_freq.shape
    n_total = n_layers * d_ff
    n_to_prune = int(n_total * pruning_fraction)
    if n_to_prune == 0:
        # Without this guard, index n_to_prune - 1 == -1 would pick the largest
        # frequency as the threshold and select every neuron.
        print("Warning: 0 neurons selected for pruning (fraction too small)")
        return {}

    flat_freq = activation_freq.reshape(-1)
    sorted_vals, sorted_idx = torch.sort(flat_freq)
    # Threshold: the n_to_prune'th lowest activation frequency
    threshold = sorted_vals[n_to_prune - 1].item()

    neurons_to_prune: Dict[int, List[int]] = {}
    for layer_idx in range(n_layers):
        # Vectorized selection; nonzero() returns indices in ascending order
        mask = activation_freq[layer_idx] <= threshold
        neurons = torch.nonzero(mask).flatten().tolist()
        if neurons:
            neurons_to_prune[layer_idx] = neurons

    total_pruned = sum(len(v) for v in neurons_to_prune.values())
    print(f"Neuron pruning plan: {total_pruned}/{n_total} neurons "
          f"({total_pruned/n_total*100:.1f}%) across {len(neurons_to_prune)} layers")
    return neurons_to_prune
```
def prune_mlp_neurons(
    model: nn.Module,
    neurons_to_prune: Dict[int, List[int]],
) -> None:
    """
    Remove MLP neurons from the model in-place.
    For each removed neuron j in a layer's FFN:
    - Remove row j from W1 (intermediate/up-projection)
    - Remove bias element j from b1
    - Remove column j from W2 (output/down-projection)
    This is the weight surgery view of removing one column from the
    neuron basis of the FFN. The remaining neurons are unaffected.
    Args:
        model: BERT-style transformer
        neurons_to_prune: Output of select_neurons_to_prune()
    """
    for layer_idx, neuron_indices in sorted(neurons_to_prune.items()):
        layer = model.encoder.layer[layer_idx]
        w1 = layer.intermediate.dense
        w2 = layer.output.dense
        # Use this layer's current width, which may differ from the config
        # value if the layer was already pruned in an earlier pass.
        d_ff = w1.out_features
        neurons_to_keep = sorted(set(range(d_ff)) - set(neuron_indices))
        if not neurons_to_keep:
            print(f"  Layer {layer_idx}: Would prune ALL neurons - skipping")
            continue
        kept = torch.tensor(neurons_to_keep, dtype=torch.long, device=w1.weight.device)
        with torch.no_grad():
            # Prune W1 (intermediate dense): rows → output neurons
            # Shape: (intermediate_size, hidden_size) → (kept, hidden_size)
            w1.weight = nn.Parameter(w1.weight.data[kept, :])
            if w1.bias is not None:
                w1.bias = nn.Parameter(w1.bias.data[kept])
            w1.out_features = len(neurons_to_keep)
            # Prune W2 (output dense): columns → input neurons
            # Shape: (hidden_size, intermediate_size) → (hidden_size, kept)
            w2.weight = nn.Parameter(w2.weight.data[:, kept])
            w2.in_features = len(neurons_to_keep)
        n_removed = d_ff - len(neurons_to_keep)
        print(f"  Layer {layer_idx}: {d_ff} -> {len(neurons_to_keep)} neurons "
              f"(removed {n_removed}, {n_removed/d_ff*100:.1f}%)")
    # config.intermediate_size is one integer for the whole model, so only
    # update it when every layer ended up with the same FFN width.
    widths = {l.intermediate.dense.out_features for l in model.encoder.layer}
    if len(widths) == 1:
        model.config.intermediate_size = widths.pop()
    else:
        print("  Warning: layers now have different FFN widths; "
              "config.intermediate_size left unchanged")
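As a sanity check on the weight surgery above, here is a minimal sketch on a toy two-layer FFN. It assumes a ReLU activation (BERT itself uses GELU, where an inactive neuron's output is only approximately zero) and forces one neuron to be dead through a large negative bias. The names `up`, `down`, `ffn`, and `keep` are illustrative and not part of the pipeline code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden, d_ff = 8, 16
up = nn.Linear(hidden, d_ff)      # plays the role of W1 / intermediate.dense
down = nn.Linear(d_ff, hidden)    # plays the role of W2 / output.dense

# Force neuron 5 to be dead: with this bias its ReLU output is 0 for the
# small random inputs used here (an assumption of this toy setup).
dead = 5
with torch.no_grad():
    up.bias[dead] = -1e4


def ffn(x, up_proj, down_proj):
    return down_proj(torch.relu(up_proj(x)))


x = torch.randn(4, hidden)
out_before = ffn(x, up, down)

# Weight surgery: drop row `dead` of W1 (and its bias) and column `dead` of W2.
keep = torch.tensor([j for j in range(d_ff) if j != dead])
up_pruned = nn.Linear(hidden, d_ff - 1)
down_pruned = nn.Linear(d_ff - 1, hidden)
with torch.no_grad():
    up_pruned.weight.copy_(up.weight[keep, :])
    up_pruned.bias.copy_(up.bias[keep])
    down_pruned.weight.copy_(down.weight[:, keep])
    down_pruned.bias.copy_(down.bias)

out_after = ffn(x, up_pruned, down_pruned)
max_diff = (out_before - out_after).abs().max().item()
print(f"W1 shape: {tuple(up.weight.shape)} -> {tuple(up_pruned.weight.shape)}")
print(f"Max output difference after removing a dead neuron: {max_diff:.2e}")
```

For a neuron that does fire, the same surgery changes the output by exactly that neuron's contribution, which is the gap recovery fine-tuning has to close.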
Recovery Fine-Tuning: Adapting to the Pruned Structure
Pruning removes capacity. Even perfectly chosen pruning targets cause some accuracy loss because the remaining weights were trained with the now-removed components in mind. Recovery fine-tuning adapts the surviving weights to their new, smaller context.
The key insight: use a lower learning rate than the original fine-tuning run (and far lower than pre-training). The remaining weights already encode valuable representations. A high learning rate will overwrite them, and the model then has to relearn much of what it knew - losing the benefit of the pre-trained initialization.
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
def recovery_fine_tune(
pruned_model: nn.Module,
tokenizer,
train_dataset,
eval_dataset,
output_dir: str,
n_epochs: int = 3,
learning_rate: float = 1e-5,
per_device_batch_size: int = 32,
warmup_ratio: float = 0.10,
) -> None:
"""
Fine-tune a pruned model to recover accuracy.
Learning rate schedule:
Original pre-training: typically 1e-4 to 5e-4
Original fine-tuning: typically 2e-5 to 5e-5
Recovery after pruning: 5e-6 to 2e-5 ← use this range
Too-high LR after pruning: catastrophic forgetting (worse than before)
Too-low LR after pruning: slow recovery, may not converge in 2-3 epochs
The warmup_ratio=0.10 provides a gradual ramp-up, which is especially
important after pruning because the model's internal activations are
initially incoherent until weights adapt to the removed components.
Args:
n_epochs: 2-5 is typical; more risks overfitting to the task
learning_rate: Lower than original fine-tuning - recommended 1e-5
warmup_ratio: Fraction of steps for LR warmup (0.05-0.15)
"""
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=n_epochs,
per_device_train_batch_size=per_device_batch_size,
per_device_eval_batch_size=per_device_batch_size * 2,
learning_rate=learning_rate,
lr_scheduler_type="cosine",
warmup_ratio=warmup_ratio,
weight_decay=0.01,
eval_strategy="epoch",  # transformers releases before 4.41 name this evaluation_strategy
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
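fp16=True,  # AMP expects FP32 weights; a model loaded with torch_dtype=torch.float16 fails with "Attempting to unscale FP16 gradients"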
logging_steps=50,
dataloader_num_workers=4,
report_to="none", # Disable wandb/tensorboard unless configured
)
trainer = Trainer(
model=pruned_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)
print(f"Starting recovery fine-tuning: {n_epochs} epochs, lr={learning_rate}")
train_result = trainer.train()
print(f"\nRecovery complete:")
print(f" Steps: {train_result.global_step}")
print(f" Train loss: {train_result.training_loss:.4f}")
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Recovered model saved to: {output_dir}")
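To make the warmup and cosine decay concrete, the following sketch computes the learning rate per step under the usual linear-warmup-then-cosine shape. It is a standalone approximation written for illustration, not the scheduler code that `Trainer` runs internally; the helper name `lr_at_step` and the 1,000-step run are assumptions.

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float = 1e-5,
               warmup_ratio: float = 0.10) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total = 1000  # hypothetical number of optimizer steps in the recovery run
for step in (0, 50, 100, 550, 1000):
    print(f"step {step:4d}: lr = {lr_at_step(step, total):.2e}")
```

The first 10% of steps ramp up from zero, so the freshly pruned model takes only small updates while its activations are still adjusting.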
Iterative Pruning: The Right Way to Aggressive Compression
One-shot pruning (remove 30% of heads in one step, then fine-tune) consistently produces worse results than iterative pruning (remove 10%, fine-tune, remove 10%, fine-tune, remove 10%, fine-tune). The iterative approach is more expensive in training compute but achieves better accuracy at the same final sparsity.
The intuition: each fine-tuning step lets the surviving weights reorganize to compensate for what was removed. This reorganization changes the importance landscape - some heads that were marginal become more important because they take over work from removed heads, while new redundancies emerge elsewhere. Re-measuring importance after each recovery step gives you a more accurate picture of what to prune next.
def iterative_pruning_pipeline(
model: nn.Module,
tokenizer,
dataloader,
train_dataset,
eval_dataset,
eval_fn, # callable(model) -> float, returns accuracy metric
config: PruningConfig,
output_dir: str,
device: str = "cuda",
) -> List[Dict]:
"""
Full iterative pruning pipeline: prune a fraction, recover, measure, repeat.
Typical usage: 3 iterations of head pruning, each removing ~8-10% of heads,
with 2 recovery epochs between. Final head count: ~75% of original.
The pipeline targets attention heads (most common), but the same pattern
applies to layer pruning and MLP neuron pruning.
Args:
eval_fn: Function that evaluates the model and returns a scalar metric
(e.g., accuracy, F1, or -1 * perplexity)
config: PruningConfig with all pruning parameters
Returns:
List of dicts with accuracy and pruning stats at each step
"""
n_iterations = 3
fraction_per_iter = config.head_pruning_fraction / n_iterations
history = []
# Baseline measurement
baseline_score = eval_fn(model)
history.append({
"step": "baseline",
"pruning_fraction": 0.0,
"score": baseline_score,
"n_heads": model.config.num_attention_heads * model.config.num_hidden_layers,
})
print(f"Baseline score: {baseline_score:.4f}")
print(f"Target: remove {config.head_pruning_fraction*100:.0f}% of heads "
f"over {n_iterations} iterations\n")
cumulative_fraction_pruned = 0.0
for iteration in range(n_iterations):
print(f"{'='*60}")
print(f"Iteration {iteration + 1}/{n_iterations}")
print(f"Pruning target this step: {fraction_per_iter*100:.1f}% of heads")
# Step 1: Compute importance scores
print(f"\nComputing head importance scores (method: {config.importance_method})...")
if config.importance_method == "taylor":
importance = compute_head_importance_taylor(
model, dataloader, n_batches=config.n_calibration_batches, device=device
)
elif config.importance_method == "magnitude":
importance = compute_head_importance_magnitude(model)
else:
importance = compute_head_importance_output_norm(
model, dataloader, n_batches=50, device=device
)
# Step 2: Select heads to prune
heads_to_prune = select_heads_to_prune(
importance,
pruning_fraction=fraction_per_iter,
protect_layers=[0, 1, -2, -1],
min_heads_per_layer=config.min_heads_per_layer,
)
# Step 3: Apply pruning
prune_attention_heads(model, heads_to_prune)
cumulative_fraction_pruned += fraction_per_iter
post_prune_score = eval_fn(model)
print(f"\nPost-pruning score: {post_prune_score:.4f} "
f"(Δ = {post_prune_score - baseline_score:+.4f})")
# Step 4: Recovery fine-tuning
iter_output_dir = f"{output_dir}/iteration_{iteration + 1}"
print(f"\nRunning {config.recovery_epochs} recovery epochs...")
recovery_fine_tune(
pruned_model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir=iter_output_dir,
n_epochs=config.recovery_epochs,
learning_rate=config.recovery_lr,
)
post_recovery_score = eval_fn(model)
recovered_delta = post_recovery_score - baseline_score
print(f"\nPost-recovery score: {post_recovery_score:.4f} "
f"(Δ from baseline = {recovered_delta:+.4f})")
history.append({
"step": f"iteration_{iteration + 1}",
"pruning_fraction_cumulative": cumulative_fraction_pruned,
"post_prune_score": post_prune_score,
"post_recovery_score": post_recovery_score,
"delta_from_baseline": recovered_delta,
"heads_pruned_this_step": sum(len(h) for h in heads_to_prune.values()),
})
print(f"\n{'='*60}")
print("Iterative pruning complete.")
print(f"Baseline score: {baseline_score:.4f}")
print(f"Final score: {post_recovery_score:.4f} "
f"(Δ = {post_recovery_score - baseline_score:+.4f})")
return history
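One detail worth pinning down when splitting a pruning budget across iterations is whether the per-step fraction applies to the original head count or to the heads still remaining. The sketch below compares the two conventions for a BERT-base-sized model (144 heads) and a 30% target. The function names are hypothetical, and the code only does arithmetic; it does not touch a model.

```python
def remaining_after_fixed_share(n_heads, target_fraction, n_iterations):
    """Each step removes target/n_iterations of the ORIGINAL head count."""
    per_step = target_fraction / n_iterations
    return n_heads * (1.0 - per_step * n_iterations)


def remaining_after_compounding(n_heads, per_step_fraction, n_iterations):
    """Each step removes per_step_fraction of the heads still REMAINING."""
    return n_heads * (1.0 - per_step_fraction) ** n_iterations


def per_step_fraction_for_target(target_fraction, n_iterations):
    """Per-step fraction of remaining heads that compounds to the target."""
    return 1.0 - (1.0 - target_fraction) ** (1.0 / n_iterations)


n_heads, target, iters = 144, 0.30, 3
naive = target / iters  # 10% per step

print("fixed share of original:", remaining_after_fixed_share(n_heads, target, iters))
print("10% of remaining, 3x:   ", remaining_after_compounding(n_heads, naive, iters))

exact = per_step_fraction_for_target(target, iters)
print("per-step fraction needed:", round(exact, 4))
print("heads left with it:      ", remaining_after_compounding(n_heads, exact, iters))
```

If the selection function interprets its fraction relative to the remaining heads, dividing the target evenly undershoots the final sparsity slightly, so it is worth logging the actual head count after every iteration.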
The Combined Compression Pipeline: Pruning + Quantization
Structured pruning and quantization are largely orthogonal compression techniques, so their savings roughly multiply. INT4 quantization cuts weight memory by about 4x versus FP16; removing 25% of layers and 20% of heads cuts parameters and FLOPs by roughly 1.4-1.6x, depending on how much of each layer is attention. Combined: approximately 6x smaller weights than the FP16 baseline.
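The multiplication above can be checked with a back-of-the-envelope parameter count. The sketch assumes a Llama-style layer (four attention projections of size d_model x d_model and three MLP projections of size d_model x d_ff), d_model = 4096, no grouped-query attention, and it ignores embeddings and the per-group scales that INT4 formats store. All of these are simplifying assumptions for illustration.

```python
def layer_params(d_model: int, d_ff: int, head_keep: float = 1.0) -> float:
    """Approximate weight count of one transformer layer."""
    attn = 4 * d_model * d_model * head_keep  # Q, K, V, O projections
    mlp = 3 * d_model * d_ff                  # gate, up, down projections
    return attn + mlp


d_model, d_ff, n_layers = 4096, 11008, 32

baseline = n_layers * layer_params(d_model, d_ff)
# Remove 25% of layers, then 20% of heads in every surviving layer
pruned = int(n_layers * 0.75) * layer_params(d_model, d_ff, head_keep=0.80)

pruning_factor = baseline / pruned
quant_factor = 16 / 4  # FP16 -> INT4 bits per weight
combined_factor = pruning_factor * quant_factor

print(f"Pruning alone:      {pruning_factor:.2f}x fewer weights")
print(f"Quantization alone: {quant_factor:.0f}x smaller weights")
print(f"Combined:           {combined_factor:.1f}x smaller than FP16")
```

Under these assumptions head pruning contributes less than layer pruning, because attention holds only about a third of each layer's weights. Treat the headline figures as ballpark values and re-derive them for your own architecture.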
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from awq import AutoAWQForCausalLM
import os
def run_full_compression_pipeline(
model_name_or_path: str,
output_base_dir: str,
calibration_dataloader,
train_dataset,
eval_dataset,
eval_fn,
pruning_config: PruningConfig,
device: str = "cuda",
) -> str:
"""
Full structured pruning + quantization pipeline.
Order: PRUNE → RECOVER → QUANTIZE (not quantize → prune).
Rationale for this order:
1. Quantization algorithms (AWQ, GPTQ) optimize for a FIXED model structure.
If you quantize first, the INT4 packed tensors cannot be sliced for pruning.
2. Recovery fine-tuning on FP16 is straightforward; fine-tuning after INT4
quantization requires dequantization (if using bitsandbytes) or is not
supported (if using GPTQ/AWQ packed formats).
3. The pruned FP16 model can be fed directly into any quantization tool.
Args:
model_name_or_path: HuggingFace model ID or local path
output_base_dir: Where to save intermediate and final models
calibration_dataloader: Calibration data for importance scoring
train_dataset, eval_dataset: For recovery fine-tuning
eval_fn: callable(model) -> float, returns the quality metric to track
pruning_config: PruningConfig specifying pruning targets
device: GPU device
Returns:
Path to the final quantized model
"""
os.makedirs(output_base_dir, exist_ok=True)
# -------------------------------------------------------------------------
# Phase 1: Load FP16 model
# -------------------------------------------------------------------------
print("\n" + "="*60)
print("PHASE 1: Loading FP16 model")
print("="*60)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=torch.float16,
device_map=device,
)
model.eval()
baseline_score = eval_fn(model)
print(f"FP16 baseline score: {baseline_score:.4f}")
# -------------------------------------------------------------------------
# Phase 2: Layer pruning (angular distance metric)
# -------------------------------------------------------------------------
print("\n" + "="*60)
print("PHASE 2: Layer pruning")
print("="*60)
layer_importance = compute_layer_angular_importance(
model, calibration_dataloader,
n_batches=pruning_config.n_calibration_batches,
device=device,
)
layers_to_remove = select_layers_to_prune(
layer_importance,
pruning_fraction=pruning_config.layer_pruning_fraction,
protect_first_n=pruning_config.protect_first_n_layers,
protect_last_n=pruning_config.protect_last_n_layers,
)
prune_layers(model, layers_to_remove)
post_layer_prune_score = eval_fn(model)
print(f"After layer pruning: {post_layer_prune_score:.4f} "
f"(Δ = {post_layer_prune_score - baseline_score:+.4f})")
# -------------------------------------------------------------------------
# Phase 3: Attention head pruning (Taylor importance)
# -------------------------------------------------------------------------
print("\n" + "="*60)
print("PHASE 3: Attention head pruning")
print("="*60)
head_importance = compute_head_importance_taylor(
model, calibration_dataloader,
n_batches=pruning_config.n_calibration_batches,
device=device,
)
heads_to_prune = select_heads_to_prune(
head_importance,
pruning_fraction=pruning_config.head_pruning_fraction,
protect_layers=[0, 1, -2, -1],
min_heads_per_layer=pruning_config.min_heads_per_layer,
)
prune_attention_heads(model, heads_to_prune)
post_head_prune_score = eval_fn(model)
print(f"After head pruning: {post_head_prune_score:.4f} "
f"(Δ = {post_head_prune_score - baseline_score:+.4f})")
# -------------------------------------------------------------------------
# Phase 4: Recovery fine-tuning
# -------------------------------------------------------------------------
print("\n" + "="*60)
print("PHASE 4: Recovery fine-tuning")
print("="*60)
pruned_model_dir = os.path.join(output_base_dir, "pruned_recovered")
recovery_fine_tune(
pruned_model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir=pruned_model_dir,
n_epochs=pruning_config.recovery_epochs,
learning_rate=pruning_config.recovery_lr,
)
post_recovery_score = eval_fn(model)
print(f"After recovery: {post_recovery_score:.4f} "
f"(Δ = {post_recovery_score - baseline_score:+.4f})")
# -------------------------------------------------------------------------
# Phase 5: AWQ quantization on pruned+recovered model
# -------------------------------------------------------------------------
print("\n" + "="*60)
print("PHASE 5: AWQ INT4 quantization")
print("="*60)
# Save the FP16 pruned model, then reload with AutoAWQ for quantization
model.save_pretrained(pruned_model_dir)
tokenizer.save_pretrained(pruned_model_dir)
del model
torch.cuda.empty_cache()
awq_model = AutoAWQForCausalLM.from_pretrained(pruned_model_dir)
quant_config = {
"zero_point": True,
"q_group_size": 128,
"w_bit": 4,
"version": "GEMM",
}
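# Calibration data for AWQ (a list of text strings).
# PLACEHOLDER: this dummy list only exercises the code path. Replace it with
# real, domain-matched text before trusting the quantized model's quality.
awq_calib_data = ["Sample calibration text for AWQ scaling..." * 10] * 128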
awq_model.quantize(
tokenizer,
quant_config=quant_config,
calib_data=awq_calib_data,
)
final_output_dir = os.path.join(output_base_dir, "pruned_recovered_awq_int4")
awq_model.save_quantized(final_output_dir)
tokenizer.save_pretrained(final_output_dir)
print(f"\nFinal quantized model saved to: {final_output_dir}")
return final_output_dir
:::danger Catastrophic Collapse at High Sparsity There is a hard cliff in structured pruning. For BERT-base, removing more than 85% of attention heads causes model accuracy to collapse toward chance - not graceful degradation. For Llama-style generative models, removing more than 30-35% of layers causes perplexity to spike dramatically. The collapse happens because the remaining network no longer has enough capacity to keep information flowing coherently through the residual stream in the form the output head expects.
Safe pruning ranges (validated in literature):
- Attention heads: up to 30-35% removal with recovery
- Layers (encoder models): up to 6/12 layers (50%), at the cost of a noticeable accuracy loss (about 4% on GLUE average for BERT-base)
- Layers (generative LLMs): up to 20-25% with 2-4% perplexity increase
- MLP neurons: up to 30-40% of intermediate neurons
Never attempt one-shot pruning at the high end of these ranges. Use iterative pruning with validation at each step. :::
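The ranges in the box above can be encoded as a simple pre-flight check that runs before any weights are touched. This is a hypothetical helper (the names `SAFE_LIMITS` and `check_pruning_plan` are made up for this sketch), and the limits are just the upper ends of the ranges quoted above, not hard guarantees.

```python
# Upper ends of the ranges quoted above, expressed as fractions removed
SAFE_LIMITS = {
    "heads": 0.35,
    "layers_encoder": 0.50,
    "layers_generative": 0.25,
    "mlp_neurons": 0.40,
}


def check_pruning_plan(plan: dict) -> list:
    """Return warnings for any fraction above its quoted upper bound."""
    warnings = []
    for key, fraction in plan.items():
        limit = SAFE_LIMITS.get(key)
        if limit is None:
            warnings.append(f"{key}: no reference limit known")
        elif fraction > limit:
            warnings.append(
                f"{key}: {fraction:.0%} exceeds the quoted upper bound of {limit:.0%}"
            )
    return warnings


plan = {"heads": 0.20, "layers_generative": 0.30, "mlp_neurons": 0.25}
for message in check_pruning_plan(plan):
    print("WARNING:", message)
```

A check like this only catches plans that are obviously too aggressive; it does not replace validation on your own task after each pruning step.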
Benchmarks: What to Expect in Practice
| Model | Pruning Strategy | Fraction Removed | Accuracy Retained | Latency Speedup | Paper/Source |
|---|---|---|---|---|---|
| BERT-base | Attention heads (Taylor) | 50% heads | 99.0% (GLUE avg) | 1.4x | Michel et al. 2019 |
| BERT-large | Attention heads (magnitude) | 30% heads | 99.2% (GLUE avg) | 1.3x | Voita et al. 2019 |
| BERT-base | Layer pruning (every other) | 6/12 layers | 96.0% (GLUE avg) | 1.85x | BERT-of-Theseus |
| LLaMA-2-13B | Layer pruning (angular dist.) | 25% layers | 94-97% (0-shot tasks) | 1.25x | ShortGPT 2024 |
| LLaMA-2-7B | Layer pruning (angular dist.) | 20% layers | 95-97% | 1.20x | ShortGPT 2024 |
| Mistral-7B | Layer pruning (angular dist.) | 4/32 layers (12.5%) | 97.5% | 1.14x | ShortGPT 2024 |
| LLaMA-2-7B | Heads + layers combined | 20% heads + 4 layers | 93-95% | 1.6x | Community results |
| BERT-base | Full: heads + neurons + layers | 25% each | 93-95% | 2.2x | Combined approach |
:::warning Task-Specific vs General Benchmarks The "accuracy retained" percentages above are from the original papers' chosen benchmarks. In practice, accuracy degradation is task-dependent:
- Classification tasks (sentiment, NLI): very tolerant of pruning - often 98-99% retained
- QA tasks (SQuAD): moderately tolerant - 95-97% retained
- Generation quality (perplexity): less tolerant - each layer removal adds ~0.5-2 perplexity points
- Long-form generation coherence: most sensitive - pruned models can repetition-loop
Always evaluate on YOUR specific task and output length distribution before declaring a pruning strategy production-ready. :::
:::tip Use Perplexity for Generative Models For encoder-only models (BERT, RoBERTa), standard task benchmarks (accuracy, F1) are reliable proxies for pruning quality. For generative LLMs (Llama, Mistral), measure perplexity on a held-out text corpus at 2048+ token contexts. Short-context benchmarks can look fine while long-form generation is degrading. Additionally, inspect actual generations for repetition loops, topic drift, and grammatical breakdown - these don't show up in perplexity but matter enormously in production. :::
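Both checks from the tip above fit in a few lines. The sketch computes perplexity from per-token log-probabilities and adds a crude repetition measure based on repeated n-grams. The function names are illustrative, and the repetition heuristic is an assumption of this example rather than a standard metric from the cited papers.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood (natural log) per token."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def repeated_ngram_fraction(tokens, n=3):
    """Share of n-grams that duplicate an earlier n-gram in the same output."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)


# A model that assigns probability 0.25 to every token has perplexity 4
uniform_logprobs = [math.log(0.25)] * 8
print("perplexity:", perplexity(uniform_logprobs))

looping = "the model said the model said the model said".split()
healthy = "pruning removes heads that contribute little to the output".split()
print("looping output:", repeated_ngram_fraction(looping, n=3))
print("healthy output:", repeated_ngram_fraction(healthy, n=3))
```

In practice the log-probabilities would come from the model's logits on held-out text, and the repetition check would run on sampled generations at the lengths you serve in production.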
Interview Questions
Q1: Why does unstructured sparsity typically not improve inference latency on modern GPUs, while structured pruning does?
Modern GPU matrix multiplication operates in tiles. The tensor core instruction on NVIDIA hardware takes a fixed-size sub-matrix (e.g., 16x16 or 64x64 tiles on A100) and computes the inner product in a single pipelined operation. The GPU does not inspect individual weight values to decide whether to skip a computation - it executes the entire tile regardless of whether some values are zero. A zero weight still consumes a multiply-accumulate cycle (the product is just zero). Memory bandwidth is also unaffected: a matrix with 50% zeros still occupies the same number of bytes in HBM and requires the same number of memory transfers.
Structured pruning changes tensor dimensions. If you remove 25% of attention heads, the Q projection matrix shrinks from $d_{\text{model}} \times (h \cdot d_{\text{head}})$ to $d_{\text{model}} \times (0.75\,h \cdot d_{\text{head}})$, where $h$ is the number of heads. The GPU loads a genuinely smaller matrix from memory and executes fewer tile operations. The speedup is real and shows up in wall-clock latency profiling. The one hardware exception is NVIDIA 2:4 structured sparsity (at most 2 nonzeros in every group of 4 consecutive elements) on Ampere+ GPUs, which has dedicated sparse Tensor Core support delivering ~1.5x speedup - but this is a specific constrained sparsity pattern, not general unstructured pruning.
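The difference can be made concrete by counting the work a dense kernel does for one token. The sketch below assumes a Q projection with d_model = 4096 and 32 heads of width 128, FP16 weights, and a kernel that does not skip zeros; the helper `dense_matmul_cost` is hypothetical.

```python
def dense_matmul_cost(rows: int, cols: int, bytes_per_weight: int = 2):
    """Per-token cost of y = x @ W for a dense [rows, cols] weight matrix."""
    flops = 2 * rows * cols                    # one multiply and one add per weight
    weight_bytes = rows * cols * bytes_per_weight
    return flops, weight_bytes


d_model, n_heads, d_head = 4096, 32, 128

dense = dense_matmul_cost(d_model, n_heads * d_head)
# Unstructured 50% sparsity stored densely: same shape, therefore same cost
unstructured = dense_matmul_cost(d_model, n_heads * d_head)
# Structured: remove 8 of 32 heads, so the matrix really has fewer columns
structured = dense_matmul_cost(d_model, (n_heads - 8) * d_head)

for name, (flops, nbytes) in [("dense", dense),
                              ("unstructured 50%", unstructured),
                              ("structured 25%", structured)]:
    print(f"{name:18s} {flops / 1e6:7.1f} MFLOPs  {nbytes / 2**20:6.1f} MiB")
```

The unstructured case is identical to the dense one because zeros still occupy memory and still get multiplied; only the structured case shrinks both numbers.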
Q2: Explain the angular distance (Block Influence) metric for layer importance. What are its strengths and limitations?
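Block Influence (BI), introduced in ShortGPT (Men et al., 2024), measures how much a transformer layer changes its input representation:
$$\text{BI}_\ell = 1 - \mathbb{E}\left[\frac{h_\ell^{\text{in}} \cdot h_\ell^{\text{out}}}{\lVert h_\ell^{\text{in}} \rVert_2 \, \lVert h_\ell^{\text{out}} \rVert_2}\right]$$
where $h_\ell^{\text{in}}$ and $h_\ell^{\text{out}}$ are flattened input and output hidden states of layer $\ell$, and the expectation is taken over calibration examples. A BI near 0 means the layer barely modifies its input (residual connection dominates) - these are the safe pruning candidates. A large BI means the layer significantly transforms its input - these are important layers to keep.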
Strengths: Requires only a forward pass on calibration data (no gradients). Directly measures the layer's informational contribution to representation change. Correlates well with perplexity impact when layers are removed. Simple to implement with forward hooks.
Limitations: A layer with low BI might still perform a critical but narrow transformation that matters for rare-but-important inputs. BI is an average over calibration examples - if your calibration data doesn't cover all input types you'll encounter in production, BI scores can be misleading. Also: BI captures the overall hidden state change, but a layer might leave most dimensions unchanged while crucially modifying a small subset. Always validate BI-based pruning decisions with task-specific benchmarks before deploying.
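A minimal sketch of the metric, written with plain Python lists so it runs without a model. It assumes BI is computed as one minus the cosine similarity between a layer's input and output state, averaged per token; in a real pipeline the two lists would be tensors captured with forward hooks. The toy hidden states are made-up numbers.

```python
import math


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def block_influence(h_in, h_out):
    """1 - mean cosine similarity between per-token input and output states."""
    sims = [cosine_similarity(u, v) for u, v in zip(h_in, h_out)]
    return 1.0 - sum(sims) / len(sims)


# Toy hidden states: 3 tokens, 4 dimensions
h_in = [[1.0, 0.0, 2.0, 0.0], [0.5, 1.0, 0.0, 1.0], [0.0, 2.0, 1.0, 0.0]]

# A layer that only rescales its input leaves the direction unchanged
passthrough = [[1.01 * x for x in tok] for tok in h_in]
# A layer whose output is orthogonal to its input
rotating = [[tok[1], -tok[0], tok[3], -tok[2]] for tok in h_in]

print("BI, pass-through layer:", round(block_influence(h_in, passthrough), 6))
print("BI, rotating layer:    ", round(block_influence(h_in, rotating), 6))
```

The pass-through case also illustrates one of the limitations listed above: cosine similarity ignores changes in magnitude, so a layer that only rescales its input looks redundant under this metric.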
Q3: What is the difference between one-shot and iterative pruning, and when should you prefer each?
One-shot pruning: Remove the target fraction of parameters in a single step, then fine-tune for recovery. Simple, requires less total training compute, but produces worse accuracy at high sparsity because removing many interconnected components simultaneously breaks the model's learned representations more severely than gradual removal.
Iterative pruning: Remove a small fraction, fine-tune for recovery, measure new importance scores, repeat. Each recovery step allows remaining weights to reorganize and compensate for removed components. New importance scores after recovery reflect the updated weight landscape - what was marginal may become important (other heads take over), and new redundancies emerge.
When to use one-shot: Pruning less than 15-20% of parameters. The task is simple and the model has clear redundancy (e.g., removing clearly dead neurons). Compute budget is tight. Prior experiments showed graceful one-shot degradation.
When to use iterative: Pruning more than 20-25%. Accuracy is critical and you cannot afford >2-3% degradation. Previous one-shot attempts showed too-large accuracy drops. Targeting aggressive sparsity (>35% heads). In production LLM deployments, iterative pruning with 3-5 cycles typically shrinks the accuracy drop by 30-50% relative to one-shot pruning at the same final sparsity.
Q4: How does structured pruning of generative LLMs differ from pruning encoder-only models?
Encoder-only models (BERT, RoBERTa):
- Very tolerant of head pruning - up to 50% heads can be removed with minimal GLUE impact
- Head removal is safe because classification tasks mainly need final-layer representations
- Angular distance shows clear high/low patterns - easy to identify redundant layers
- Evaluation is clean: classification accuracy on standard benchmarks
Generative LLMs (Llama, Mistral):
- Layer pruning is relatively safe (20-25%) but head pruning is more sensitive
- Generative quality depends on diverse attention patterns for long-range coherence - too much head pruning causes repetition loops and topic drift even when short-context benchmarks look fine
- Middle layers (roughly 30-70% depth) tend to have lower BI scores and are safer to prune; early and late layers are critical
- Evaluation must include long-form generation at 2048+ contexts - not just classification benchmarks
- KV cache analysis can help identify which heads are responsible for long-range dependencies before pruning them
- MLP neuron pruning (targeting dead neurons by activation frequency) is often more effective than head pruning for generative models - many neurons fire <1% of the time and can be removed with minimal perplexity impact
Q5: How would you combine structured pruning with quantization for a production deployment requiring maximum compression?
The recommended pipeline:
Step 1: Baseline evaluation. Measure FP16 model accuracy on your specific task. This is your gold standard. Any compression technique must be measured against this baseline, not theoretical numbers from papers.
Step 2: Importance measurement. Run calibration (100-200 batches of domain-matched data) to compute: layer angular distances (BI scores), Taylor head importance scores, and neuron activation frequencies. No model changes yet - just measurement.
Step 3: Pruning plan. Select layers to remove (lowest BI, protected first/last 2 layers), heads to remove (lowest Taylor importance globally, protect first/last 2 layers), neurons to remove (lowest activation frequency, up to 25-30%). Validate the plan gives reasonable expected compression before executing.
Step 4: Apply pruning. Execute layer removal, then head removal, then neuron removal. Measure accuracy immediately after - this tells you the recovery gap to close.
Step 5: Recovery fine-tuning. 3-5 epochs at lr=1e-5 (lower than original fine-tuning). Use cosine schedule with 10% warmup. Monitor eval loss per epoch; stop if it plateaus.
Step 6: Verify pruned FP16 model. The pruned+recovered FP16 model should be within 1-2% of baseline. If it's worse than 3%, the pruning was too aggressive - start over with a smaller fraction.
Step 7: Quantize the pruned model. Apply AWQ (preferred for deployment flexibility) or GPTQ (preferred for per-sample calibration) to the FP16 pruned+recovered model. Use domain-matched calibration data.
Step 8: Final evaluation. Measure the quantized+pruned model on your benchmark. Expected: 5-6x total compression vs FP16, 3-6% accuracy loss, 3-5x latency improvement.
The reason for prune-before-quantize: GPTQ/AWQ produce packed INT4 tensors in a custom format that cannot be sliced to remove heads or neurons. If you quantize first, you lose the ability to do structured pruning without unpacking and repacking, which loses the quantization calibration. Prune first, recover, then quantize - this order gives you full flexibility and the best results.
Q6: What is the "protection rule" for first and last layers, and why does it matter more for generative models?
The protection rule: never prune the first 2 and last 2 transformer layers during structured pruning.
Why first layers: The first 1-2 layers perform initial contextualization - converting context-free token embeddings into contextual representations where each token's embedding reflects its neighbors. Angular distance is typically high for these layers (they change the representation significantly). Pruning them causes the model to lose initial contextualization, degrading all downstream layers.
Why last layers: For encoder models, the last layer prepares the [CLS] representation for the classification head. For generative models, the last 2-4 layers are responsible for vocabulary distribution sharpening - converting the internal representation space into logits over the vocabulary. Pruning them causes the model to produce incoherent output distributions.
Why generative models are more sensitive: Encoder models make a single prediction from a representation computed over the full sequence in one pass, so an error introduced by last-layer pruning is not fed back into later inputs. Generative models are autoregressive - each token's generation depends on the KV cache from all previous tokens, and each sampled token becomes part of the next input. If the last layers are degraded, errors compound over long sequences, causing the generation quality to degrade super-linearly with sequence length. A pruned encoder might lose 2% accuracy on a 512-token classification task; a pruned generator might seem fine for short responses but fail catastrophically at 1000+ tokens.
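The protection rule is easy to get wrong when it is expressed with negative indices such as `protect_layers=[0, 1, -2, -1]`. The sketch below resolves such a list to absolute layer indices and derives the set of layers that remain eligible for pruning. Both helper names are hypothetical and are not part of the pipeline code shown earlier.

```python
def resolve_protected_layers(protect_layers, n_layers):
    """Map Python-style indices such as -1 to absolute layer indices."""
    resolved = set()
    for idx in protect_layers:
        absolute = idx if idx >= 0 else n_layers + idx
        if not 0 <= absolute < n_layers:
            raise ValueError(f"layer index {idx} is out of range for {n_layers} layers")
        resolved.add(absolute)
    return sorted(resolved)


def prunable_layers(n_layers, protect_layers=(0, 1, -2, -1)):
    """Layers that may be considered for pruning under the protection rule."""
    protected = set(resolve_protected_layers(protect_layers, n_layers))
    return [i for i in range(n_layers) if i not in protected]


print("Protected in a 32-layer model:", resolve_protected_layers([0, 1, -2, -1], 32))
print("Prunable in a 12-layer model: ", prunable_layers(12))
```

Resolving the indices once against the current layer count matters in iterative pipelines, where the number of layers shrinks between rounds and a fixed absolute index would drift away from the true last layers.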
