Speculative Decoding
The Bottleneck Nobody Expected
You've just finished quantizing your 70B parameter model to INT4. The model weights went from 140GB to 35GB. You're expecting a massive speedup. You benchmark it: 18 tokens per second. Before quantization: 14 tokens per second. A 28% improvement for a 4x compression in size. Something seems wrong.
You profile the hardware. The A100 is sitting at 85% memory bandwidth utilization and only 12% compute utilization. The model is generating one token at a time, and each token requires reading all 35GB of model weights from HBM (High Bandwidth Memory) into the compute units, performing a relatively small matrix-vector multiply, and discarding the intermediate results. The next token requires reading the same 35GB again. And again. And again.
You've optimized the memory footprint, but the fundamental problem isn't memory size - it's the access pattern. Autoregressive decoding is inherently sequential: you cannot generate token t+1 until token t exists. Each generation step is a tiny computation (one token's worth of multiplications) requiring an enormous data load (billions of parameters). This is the worst possible regime for GPU hardware: memory-bandwidth-bound with extremely low arithmetic intensity.
Quantization helps by reducing how much you read per step, but it doesn't change the fundamental bottleneck structure. To get real speedups, you need to either reduce the number of sequential steps or increase the arithmetic intensity per step. Speculative decoding does both.
The insight is deceptively simple: most of what a large model generates is predictable. Filler words, common phrases, syntactically constrained continuations - a much smaller model can guess these correctly most of the time. If you can verify k candidate tokens in a single forward pass of the large model (which costs about the same as a single-token decode step), and accept most of them, you've effectively parallelized sequential generation.
Why This Exists
The Arithmetic Intensity Problem
To understand why speculative decoding works, you need to understand why autoregressive LLM decoding is so inefficient from a hardware perspective.
During generation, each forward pass processes a "batch" of exactly one new token (assuming batch size 1 for clarity). The model has N parameters - 70 billion in the example above. Each parameter is read once from HBM, used in a matrix-vector multiplication, and contributes to computing the next token's logits. Then those parameters are discarded from registers/cache and the next step starts.
The arithmetic intensity is:

arithmetic intensity = FLOPs performed / bytes read from HBM

For a matrix-vector multiply with a weight matrix of shape [m, n]:
- FLOPs: 2·m·n (one multiply-add per weight)
- Bytes read: 2·m·n (FP16/BF16 weights, 2 bytes per parameter)

The ratio is roughly 1 FLOP/byte. On an A100-80GB, peak compute is 312 TFLOPS (BF16) and peak memory bandwidth is 2 TB/s. The roofline intersection is at 312/2 = 156 FLOP/byte. At 1 FLOP/byte, we're 156x below the compute roofline. We're bottlenecked entirely by memory bandwidth, and we're using roughly 1/156th of the available compute.
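To make these numbers concrete, here is a quick back-of-the-envelope sketch. The hardware constants are A100-80GB datasheet values and the matrix shape is an arbitrary illustrative choice; swap in your own GPU's numbers.

PEAK_COMPUTE_FLOPS = 312e12      # A100 BF16 tensor-core peak, FLOP/s
PEAK_HBM_BANDWIDTH = 2.0e12      # A100 HBM bandwidth, bytes/s

def matvec_arithmetic_intensity(m: int, n: int, bytes_per_param: int = 2) -> float:
    """FLOPs per byte for multiplying an [m, n] weight matrix by a length-n vector."""
    flops = 2 * m * n                      # one multiply-add per weight
    bytes_read = bytes_per_param * m * n   # every weight read once from HBM
    return flops / bytes_read

roofline_crossover = PEAK_COMPUTE_FLOPS / PEAK_HBM_BANDWIDTH    # ~156 FLOP/byte
decode_intensity = matvec_arithmetic_intensity(8192, 8192)      # ~1 FLOP/byte at FP16

print(f"Roofline crossover: {roofline_crossover:.0f} FLOP/byte")
print(f"Decode intensity:   {decode_intensity:.1f} FLOP/byte")
print(f"Unused compute headroom: {roofline_crossover / decode_intensity:.0f}x")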
Every optimization that doesn't change this fundamental ratio (quantization, weight compression, caching) can at best improve memory bandwidth utilization. Speculative decoding fundamentally changes the ratio by processing multiple tokens per "round" of weight reads.
What Approaches Were Tried Before
Batching: Serving multiple requests simultaneously improves GPU utilization because you amortize the weight-reading cost over multiple computations. But batching increases latency (requests wait for others to join a batch) and is limited by memory for KV caches. It doesn't help latency for a single request.
Continuous batching (Orca-style): Improved throughput by dynamically composing batches mid-inference. Still memory-bound at the per-token level.
Flash Attention: Reduced attention's memory footprint by recomputing instead of storing. Helped with long-context training and prefill, but decoding attention is already fast - the linear layers are the bandwidth bottleneck.
Smaller models: Faster per-token generation, worse output quality. Not a general solution.
Hardware improvements: H100 has 3.35 TB/s HBM bandwidth vs A100's 2 TB/s, a 1.67x improvement. Not enough to fundamentally change economics.
Speculative decoding, published by Chen et al. at Google DeepMind in 2023, and independently by Leviathan et al. at Google Research in the same year, took a different approach: change the algorithm to better utilize available compute, rather than trying to read memory faster.
Historical Context
The 2023 Parallel Discovery
Two teams independently published the speculative decoding idea within months of each other in 2023. This is a reliable signal that the idea was "in the air" - the hardware bottleneck was well-understood, the solution was elegant, and multiple groups reached it simultaneously.
Charlie Chen and colleagues at Google DeepMind posted "Accelerating Large Language Model Decoding with Speculative Sampling" in February 2023. Yaniv Leviathan, Matan Kalman, and Yossi Matias at Google Research published "Fast Inference from Transformers via Speculative Decoding" at ICML 2023, with a preprint circulating since late 2022.
Both papers proved the same key result: you can use a small "draft" model to generate candidate tokens, verify them with the large "target" model in a single forward pass, and the output distribution of the resulting algorithm is identical to sampling from the target model alone. The speedup is real, and it's lossless in terms of output quality.
This lossless guarantee is crucial. Previous "draft and verify" ideas had been explored for other inference acceleration techniques, but they always changed the output distribution or required approximations. The speculative sampling algorithm avoids this by using a carefully designed rejection sampling scheme.
Other groups quickly extended the work: Cai et al. introduced self-drafting with Medusa heads, SpecInfer explored tree-based speculation, and the EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) paper from 2024 extended self-speculative decoding with feature-level prediction.
Core Concepts
The Basic Algorithm
Speculative decoding has two models: a small, fast draft model q and a large, capable target model p.
The algorithm proceeds in rounds:
Draft phase: The draft model q autoregressively generates k candidate tokens x_1, ..., x_k. This is fast because q is small (e.g., a 7B model drafting for a 70B target).
Verify phase: The target model p processes the original context PLUS all k draft tokens in a single forward pass. Because we're processing k+1 positions simultaneously (the k draft tokens plus one new position), this is a prefill-like computation, not decode - it runs at much higher arithmetic intensity.
Acceptance/rejection: For each draft token x_i, we compare the draft model's probability q(x_i) to the target model's probability p(x_i). We accept x_i with probability:

accept(x_i) = min(1, p(x_i) / q(x_i))

If x_i is rejected, we sample a replacement from an adjusted distribution:

p'(x) = norm(max(0, p(x) - q(x)))

This adjusted distribution ensures we sample from the part of p that q "missed." The key insight: if we always sample from this adjusted distribution when a rejection occurs, the overall output distribution equals p exactly.
The Speedup Formula
If the draft acceptance rate is alpha (the probability that a single draft token is accepted), the expected number of tokens generated per round is:

E[tokens per round] = 1 + alpha + alpha^2 + ... + alpha^k = (1 - alpha^(k+1)) / (1 - alpha)

This follows from two observations. First, the probability that at least the first i draft tokens are accepted is alpha^i, so the expected number of accepted draft tokens is alpha + alpha^2 + ... + alpha^k. Second, every round produces exactly one additional token from the target model - the corrected replacement when a rejection occurs, or the bonus token sampled from the target's distribution when all k drafts are accepted - which contributes the leading 1 and gives k+1 tokens in the best case.
The cost per round is approximately one target model forward pass (a prefill over the k+1 new positions) plus k draft model forward passes. Denote the target model's single-token decode time as T_p and the draft model's as T_q. The speedup over plain autoregressive decoding is:

speedup ≈ E[tokens per round] / (k · T_q/T_p + 1)
For typical setups where the draft model is 10-15x faster than the target (T_q ≈ T_p/10 to T_p/15) and the acceptance rate is alpha ≈ 0.8, with k = 5 speculative tokens:

E[tokens per round] = (1 - 0.8^6) / (1 - 0.8) ≈ 3.7
cost per round ≈ 5 · T_q/T_p + 1 ≈ 1.33 to 1.5 target-equivalent passes
speedup ≈ 3.7 / 1.5 to 3.7 / 1.33 ≈ 2.5-2.8x

Close to a 3x speedup with no change in output quality. In practice, measured speedups are typically 2-3.5x depending on the task.
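These formulas are easy to sanity-check numerically. A small sketch - the cost model is the simplified one above (one verify pass plus k draft passes per round), and the acceptance rates and cost ratios are illustrative:

def expected_tokens_per_round(alpha: float, k: int) -> float:
    """E[tokens per round] = 1 + alpha + alpha^2 + ... + alpha^k."""
    return k + 1 if alpha >= 1.0 else (1 - alpha ** (k + 1)) / (1 - alpha)

def speculative_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Speedup over autoregressive decoding; draft_cost_ratio = T_q / T_p."""
    cost_per_round = k * draft_cost_ratio + 1.0   # in target-forward-pass units
    return expected_tokens_per_round(alpha, k) / cost_per_round

for alpha in (0.6, 0.7, 0.8, 0.9):
    for ratio in (0.10, 0.067):   # draft 10x or ~15x faster than the target
        s = speculative_speedup(alpha, k=5, draft_cost_ratio=ratio)
        print(f"alpha={alpha:.1f}  T_q/T_p={ratio:.3f}  speedup={s:.2f}x")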
Why Arithmetic Intensity Improves
When the target model verifies k+1 tokens simultaneously, the computation changes from decode mode (matrix-vector multiply) to prefill mode (matrix-matrix multiply). The weight matrices are read once from HBM, but used to compute k+1 output positions instead of 1.
Arithmetic intensity scales linearly with the number of simultaneous positions:

arithmetic intensity ≈ (k + 1) FLOP/byte

With k = 5: arithmetic intensity is 6x higher. At 6 FLOP/byte on the A100, we're still memory-bandwidth-bound, but making 6x better use of the memory bandwidth. This is where the speedup comes from: same weight reads, 6x the useful computation.
Acceptance Rate and Task Dependence
The acceptance rate varies dramatically by task:
- Code generation (completing Python functions, filling in structured code): 0.85-0.95. Code is highly predictable within syntactic constraints. A 7B draft model drafts excellent Python for a 70B target.
- Factual QA (answering factual questions): 0.75-0.85. Well-structured responses, but factual tokens are harder to predict.
- Creative writing: 0.60-0.75. More variable vocabulary, style choices the draft model doesn't share with the target.
- Multilingual tasks: 0.50-0.70. Draft models trained predominantly on English may poorly predict non-English tokens.
- Mathematical reasoning: 0.70-0.85. Structured but specialized notation.
This task-dependence is important for production deployment. Don't report a single "2.5x speedup" - measure acceptance rates per task category and compute expected speedup per workload mix.
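One way to turn per-task acceptance rates into a single number is to weight each task's time savings by its share of traffic. A minimal sketch - the rates and traffic mix below are illustrative placeholders, not measurements:

def speedup(alpha: float, k: int = 5, draft_cost_ratio: float = 0.1) -> float:
    """Expected speedup from the formula in the previous section."""
    tokens = (1 - alpha ** (k + 1)) / (1 - alpha)
    return tokens / (k * draft_cost_ratio + 1.0)

task_alpha = {"code": 0.90, "factual_qa": 0.80, "creative": 0.65, "math": 0.78}
traffic_mix = {"code": 0.40, "factual_qa": 0.30, "creative": 0.20, "math": 0.10}

# Total wall-clock time is the sum of per-task times; each task's time shrinks
# by that task's own speedup factor.
baseline_time = sum(traffic_mix.values())   # = 1.0 by construction
spec_time = sum(share / speedup(task_alpha[task]) for task, share in traffic_mix.items())
print(f"Expected workload-level speedup: {baseline_time / spec_time:.2f}x")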
Draft Model Selection
The draft model must balance two competing requirements: fast enough to not eat the savings, and accurate enough to maintain high acceptance rates.
Architecture-matched drafts (e.g., LLaMA-7B for LLaMA-70B): The same architecture, smaller scale. Share the same tokenizer, same training distribution. High acceptance rates because the models agree on most tokens. This was the original approach in the Chen et al. paper.
Distilled drafts: Train a small model specifically to match the large model's token distribution. Higher acceptance rates than random smaller models, at the cost of requiring a dedicated training run.
Medusa heads: Attach multiple "draft heads" directly to the target model, one per future token position. These are small feedforward networks that predict tokens at positions t+2, t+3, etc., based on the target model's hidden states at position t. During inference, a single forward pass of the target model produces both the current token and draft predictions for the next several tokens. Eliminates the separate draft model entirely. Published by Cai et al. in 2024.
EAGLE (self-speculative with feature prediction): Instead of predicting tokens directly, EAGLE trains a lightweight autoregressive model that predicts the target model's hidden states at the next token, then converts those predicted hidden states to token probabilities using the target model's LM head. Feature-level prediction is easier than token-level prediction, resulting in acceptance rates of 0.85-0.95 across nearly all tasks. EAGLE-2 (2024) extended this with dynamic draft trees.
Self-Speculative Decoding: Layer Skipping
An alternative approach: use the target model itself as its own draft by skipping some layers. The insight: the early layers of a large model already produce a reasonable approximation of the final output. Generate a draft by running only the first few layers (cheap), then verify with the full model.
This avoids holding two separate models in GPU memory - a significant practical advantage for memory-constrained deployments. The tradeoff: lower acceptance rates than using a purpose-trained draft model, because the early-exit hidden states are less accurate than a separately trained draft.
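A minimal sketch of the drafting half of this idea is below, assuming a LLaMA-style HuggingFace model whose decoder stack lives at model.model.layers (true for the LLaMA family; other architectures name the stack differently). The trick is to temporarily truncate the stack, draft greedily with the shallow model, then restore the full stack for verification.

import torch

@torch.no_grad()
def draft_with_early_exit(
    model,                       # the full target model, used as its own draft
    input_ids: torch.Tensor,
    num_draft_tokens: int = 5,
    num_draft_layers: int = 8,   # how many leading decoder layers to keep
) -> torch.Tensor:
    """Greedy drafting using only the first num_draft_layers decoder layers."""
    full_layers = model.model.layers
    model.model.layers = full_layers[:num_draft_layers]   # shallow "draft" model
    draft = input_ids
    try:
        for _ in range(num_draft_tokens):
            logits = model(draft).logits[:, -1, :]
            next_token = logits.argmax(dim=-1, keepdim=True)
            draft = torch.cat([draft, next_token], dim=-1)
    finally:
        model.model.layers = full_layers   # always restore the full stack
    return draft[:, input_ids.shape[1]:]   # just the drafted token ids

A full self-speculative decoder would then run the restored full-depth model once over the context plus these draft tokens and apply the same accept/reject logic as in the speculative_sample() implementation later in this section.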
Architecture Diagrams
Autoregressive vs Speculative Decoding
Speculative Sampling Acceptance Logic
EAGLE vs Standard Draft-Verify
Code Examples
Implementing Speculative Decoding from Scratch
This implementation shows the exact algorithm from the Chen et al. paper, using HuggingFace models.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Optional
import time
def speculative_sample(
draft_model: AutoModelForCausalLM,
target_model: AutoModelForCausalLM,
input_ids: torch.Tensor,
max_new_tokens: int = 200,
num_speculative_tokens: int = 5,
temperature: float = 1.0,
top_p: Optional[float] = None,
) -> tuple[torch.Tensor, dict]:
"""
Speculative decoding with rejection sampling.
Output distribution is identical to sampling from target_model alone.
Returns:
generated_ids: token IDs for full output
stats: dict with acceptance rates and round counts
"""
device = input_ids.device
generated = input_ids.clone()
stats = {
"total_rounds": 0,
"total_draft_tokens": 0,
"total_accepted_tokens": 0,
"tokens_generated": 0,
}
    while stats["tokens_generated"] < max_new_tokens:
        # --- DRAFT PHASE ---
        draft_tokens = []
        draft_probs = []  # full draft distribution at each drafted position
        current = generated.clone()
        prefix_len = generated.shape[1]

        for _ in range(num_speculative_tokens):
            with torch.no_grad():
                draft_out = draft_model(current)
            logits = draft_out.logits[:, -1, :]  # [batch, vocab]
            if temperature != 1.0:
                logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            # Sample from draft distribution
            next_token = torch.multinomial(probs, num_samples=1)
            draft_tokens.append(next_token)
            draft_probs.append(probs[0])  # keep the whole distribution for verification
            current = torch.cat([current, next_token], dim=-1)

        # --- VERIFY PHASE ---
        # Run the target model over context + all draft tokens in one forward pass.
        # Shape: [1, prefix_len + num_speculative_tokens]
        candidate_ids = torch.cat([generated] + draft_tokens, dim=-1)
        with torch.no_grad():
            target_out = target_model(candidate_ids)
        # Logits at position i predict the token at position i+1, so this slice
        # gives the target's distribution over each drafted position.
        target_logits = target_out.logits[0, prefix_len - 1:-1, :]

        # --- ACCEPTANCE LOOP ---
        accepted_count = 0
        n = num_speculative_tokens
        for i in range(n):
            draft_token_id = draft_tokens[i].item()
            if temperature != 1.0:
                target_logits[i] = target_logits[i] / temperature
            target_probs = F.softmax(target_logits[i], dim=-1)
            p_target = target_probs[draft_token_id].item()
            p_draft = draft_probs[i][draft_token_id].item()
            # Acceptance probability: min(1, p_target / p_draft)
            accept_prob = min(1.0, p_target / (p_draft + 1e-8))
            u = torch.rand(1).item()
            if u < accept_prob:
                # Accept this draft token
                generated = torch.cat([generated, draft_tokens[i]], dim=-1)
                accepted_count += 1
                stats["total_accepted_tokens"] += 1
            else:
                # Reject: sample a replacement from the adjusted distribution
                # adjusted(t) = normalize(max(0, p_target(t) - p_draft(t)))
                adjusted = torch.clamp(target_probs - draft_probs[i], min=0)
                adjusted_sum = adjusted.sum()
                if adjusted_sum > 0:
                    adjusted = adjusted / adjusted_sum
                    replacement = torch.multinomial(adjusted, num_samples=1)
                else:
                    # Fallback to the target distribution (numerical edge case)
                    replacement = torch.multinomial(target_probs, num_samples=1)
                generated = torch.cat([generated, replacement.unsqueeze(0)], dim=-1)
                stats["tokens_generated"] += accepted_count + 1  # accepted drafts + replacement
                break  # Stop this round after rejection
        else:
            # All draft tokens accepted: generate one bonus token from the target
            final_logits = target_out.logits[0, -1, :] / temperature
            final_probs = F.softmax(final_logits, dim=-1)
            bonus_token = torch.multinomial(final_probs, num_samples=1)
            generated = torch.cat([generated, bonus_token.unsqueeze(0)], dim=-1)
            stats["tokens_generated"] += accepted_count + 1  # accepted drafts + bonus

        stats["total_rounds"] += 1
        stats["total_draft_tokens"] += n

        # Check EOS (assumes a LLaMA-style EOS id of 2; use your tokenizer's
        # eos_token_id in practice)
        if generated[0, -1].item() in [2]:
            break
stats["acceptance_rate"] = (
stats["total_accepted_tokens"] / max(stats["total_draft_tokens"], 1)
)
return generated, stats
def benchmark_speculative_vs_autoregressive(
draft_model_name: str,
target_model_name: str,
prompts: list[str],
num_speculative_tokens: int = 5,
max_new_tokens: int = 200,
):
"""
Compare speculative decoding vs standard autoregressive decoding.
Reports throughput and acceptance rates.
"""
print(f"Loading tokenizer from {target_model_name}...")
tokenizer = AutoTokenizer.from_pretrained(target_model_name)
print(f"Loading draft model: {draft_model_name}")
draft_model = AutoModelForCausalLM.from_pretrained(
draft_model_name,
torch_dtype=torch.float16,
device_map="cuda:0"
).eval()
print(f"Loading target model: {target_model_name}")
target_model = AutoModelForCausalLM.from_pretrained(
target_model_name,
torch_dtype=torch.float16,
device_map="cuda:1" # separate GPU for target in real setup
).eval()
results = []
for prompt in prompts:
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# --- Autoregressive baseline ---
ar_start = time.time()
with torch.no_grad():
ar_output = target_model.generate(
input_ids.cuda(),
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=1.0,
)
ar_time = time.time() - ar_start
ar_tokens = ar_output.shape[1] - input_ids.shape[1]
# --- Speculative decoding ---
spec_start = time.time()
spec_output, stats = speculative_sample(
draft_model=draft_model,
target_model=target_model,
input_ids=input_ids.cuda(),
max_new_tokens=max_new_tokens,
num_speculative_tokens=num_speculative_tokens,
)
spec_time = time.time() - spec_start
spec_tokens = spec_output.shape[1] - input_ids.shape[1]
result = {
"prompt": prompt[:50] + "...",
"ar_tokens_per_sec": ar_tokens / ar_time,
"spec_tokens_per_sec": spec_tokens / spec_time,
"speedup": (spec_tokens / spec_time) / (ar_tokens / ar_time),
"acceptance_rate": stats["acceptance_rate"],
"avg_tokens_per_round": (
stats["tokens_generated"] / max(stats["total_rounds"], 1)
),
}
results.append(result)
print(f"\nPrompt: {result['prompt']}")
print(f" AR: {result['ar_tokens_per_sec']:.1f} tok/s")
print(f" Spec: {result['spec_tokens_per_sec']:.1f} tok/s")
print(f" Speedup: {result['speedup']:.2f}x")
print(f" Acceptance rate: {result['acceptance_rate']:.1%}")
print(f" Avg tokens/round: {result['avg_tokens_per_round']:.2f}")
return results
Using HuggingFace's Built-in Speculative Decoding
HuggingFace Transformers supports speculative decoding natively through its assisted-generation API (the assistant_model argument to generate()). If you are already serving with Transformers, this is the simplest production path.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
# Load models
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
# Target: large model
target = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-70b-chat-hf",
torch_dtype=torch.float16,
device_map="auto",
)
# Draft: smaller model (same family for best acceptance rates)
draft = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
torch_dtype=torch.float16,
device_map="cuda:0",
)
prompt = "[INST] Write a Python class implementing a binary search tree with insert, search, and delete methods. Include docstrings and type hints. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Standard autoregressive - baseline
print("=== Autoregressive Baseline ===")
start = time.time()
with torch.no_grad():
ar_out = target.generate(
**inputs,
max_new_tokens=512,
do_sample=False, # greedy for fair comparison
)
ar_time = time.time() - start
ar_new_tokens = ar_out.shape[1] - inputs["input_ids"].shape[1]
print(f"Generated {ar_new_tokens} tokens in {ar_time:.2f}s")
print(f"Throughput: {ar_new_tokens / ar_time:.1f} tok/s")
# Speculative decoding
print("\n=== Speculative Decoding ===")
start = time.time()
with torch.no_grad():
spec_out = target.generate(
**inputs,
assistant_model=draft, # HuggingFace API for speculative decoding
max_new_tokens=512,
do_sample=False,
num_assistant_tokens=5, # k speculative tokens per round
num_assistant_tokens_schedule="constant", # or "heuristic" for adaptive k
)
spec_time = time.time() - start
spec_new_tokens = spec_out.shape[1] - inputs["input_ids"].shape[1]
print(f"Generated {spec_new_tokens} tokens in {spec_time:.2f}s")
print(f"Throughput: {spec_new_tokens / spec_time:.1f} tok/s")
print(f"Speedup: {(spec_new_tokens / spec_time) / (ar_new_tokens / ar_time):.2f}x")
# Verify outputs are identical (greedy decoding should be deterministic)
assert torch.all(ar_out == spec_out), "Speculative and AR outputs differ!"
print("Output verification: PASSED (identical tokens)")
Medusa Heads: Self-Speculative Decoding
# Medusa adds multiple draft heads to the base model
# Each head predicts a future token position
# Paper: "Medusa: Simple LLM Inference Acceleration Framework
# with Multiple Decoding Heads" (Cai et al., 2024)
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaConfig
class MedusaHead(nn.Module):
"""
Single Medusa head: predicts token at position t+offset
given hidden state at position t.
"""
def __init__(
self,
hidden_size: int,
vocab_size: int,
num_layers: int = 1,
):
super().__init__()
layers = []
for _ in range(num_layers):
layers.extend([
nn.Linear(hidden_size, hidden_size, bias=False),
nn.SiLU(),
])
layers.append(nn.Linear(hidden_size, vocab_size, bias=False))
self.net = nn.Sequential(*layers)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.net(hidden_states)
class MedusaModel(nn.Module):
"""
Base LLM + multiple Medusa heads for speculative decoding.
The base model is frozen; only heads are trained.
"""
def __init__(
self,
base_model: LlamaForCausalLM,
num_medusa_heads: int = 5, # speculates 5 tokens ahead
num_head_layers: int = 1,
):
super().__init__()
self.base_model = base_model
config = base_model.config
# Freeze the base model
for param in self.base_model.parameters():
param.requires_grad = False
# Medusa heads - one per future token position
self.medusa_heads = nn.ModuleList([
MedusaHead(
hidden_size=config.hidden_size,
vocab_size=config.vocab_size,
num_layers=num_head_layers,
)
for _ in range(num_medusa_heads)
])
def forward(self, input_ids: torch.Tensor, **kwargs):
# Run base model to get hidden states
outputs = self.base_model(
input_ids,
output_hidden_states=True,
**kwargs
)
last_hidden = outputs.hidden_states[-1] # [batch, seq, hidden]
# Original logits from base model head
base_logits = outputs.logits
# Medusa head logits: each head predicts one future position
medusa_logits = [
head(last_hidden) for head in self.medusa_heads
]
return base_logits, medusa_logits
def generate_with_medusa(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
medusa_choices: int = 5, # top-k candidates per head
) -> torch.Tensor:
"""
Tree-based speculative generation with Medusa heads.
Each head generates a candidate; we verify all in parallel.
"""
generated = input_ids.clone()
for _ in range(max_new_tokens):
base_logits, medusa_logits = self.forward(generated)
# Sample candidates from each Medusa head
# In practice: build a tree of candidates, verify with base model
# Simplified here: use top-1 from each head
current_token_logits = base_logits[0, -1, :] / temperature
current_token = torch.argmax(current_token_logits)
medusa_candidates = [
torch.argmax(logits[0, -1, :])
for logits in medusa_logits
]
# Build candidate sequence: current + speculated
# In full implementation: build a tree and verify via attention mask
# For simplicity: linear speculation
candidate = torch.cat([
generated,
current_token.unsqueeze(0).unsqueeze(0)
] + [
tok.unsqueeze(0).unsqueeze(0)
for tok in medusa_candidates
], dim=-1)
            # A full implementation would verify `candidate` with one forward
            # pass of the base model and accept/reject via the speculative
            # sampling scheme; this sketch keeps only the base model's token,
            # so it shows the data flow but gains no speedup.
generated = torch.cat(
[generated, current_token.unsqueeze(0).unsqueeze(0)],
dim=-1
)
return generated
Measuring Acceptance Rates per Task Type
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import json
TASK_PROMPTS = {
"code_python": [
"Write a Python function that implements merge sort:",
"Create a class that implements a thread-safe queue in Python:",
"Write a Python decorator for caching function results:",
],
"factual_qa": [
"What is the capital of France?",
"Who wrote Pride and Prejudice?",
"What year did World War II end?",
],
"creative_writing": [
"Write a short poem about autumn leaves:",
"Begin a science fiction story set on Mars:",
"Describe a mysterious figure arriving at a train station:",
],
"math_reasoning": [
"Solve step by step: If a train travels 120 miles in 2 hours, what is its average speed?",
"Prove that the sum of angles in a triangle is 180 degrees:",
"A rectangle has perimeter 48cm and area 128cm^2. Find its dimensions:",
],
}
def measure_acceptance_rate(
draft_model,
target_model,
tokenizer,
prompts: list[str],
num_speculative_tokens: int = 5,
max_new_tokens: int = 100,
) -> dict:
"""
Measure speculative decoding acceptance rate across prompts.
Uses HuggingFace generation hooks to count accepts/rejects.
"""
total_draft = 0
total_accepted = 0
for prompt in prompts:
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# HuggingFace tracks speculative decoding stats in generation output
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.8,
num_assistant_tokens=num_speculative_tokens,
)
with torch.no_grad():
output = target_model.generate(
**inputs,
assistant_model=draft_model,
generation_config=gen_config,
return_dict_in_generate=True,
)
# Extract stats if available (varies by HF version)
if hasattr(output, "assistant_tokens_num"):
total_draft += output.assistant_tokens_num.sum().item()
total_accepted += output.assistant_tokens_num.sum().item() # placeholder
return {
"total_draft_tokens": total_draft,
"total_accepted": total_accepted,
"acceptance_rate": total_accepted / max(total_draft, 1),
}
def run_task_benchmark(draft_model_name: str, target_model_name: str):
"""Run acceptance rate benchmark across all task types."""
tokenizer = AutoTokenizer.from_pretrained(target_model_name)
draft = AutoModelForCausalLM.from_pretrained(
draft_model_name, torch_dtype=torch.float16, device_map="cuda:0"
).eval()
target = AutoModelForCausalLM.from_pretrained(
target_model_name, torch_dtype=torch.float16, device_map="auto"
).eval()
results = {}
for task, prompts in TASK_PROMPTS.items():
print(f"Benchmarking: {task}...")
stats = measure_acceptance_rate(draft, target, tokenizer, prompts)
results[task] = stats
print(f" Acceptance rate: {stats['acceptance_rate']:.1%}")
print("\n=== Summary ===")
print(json.dumps(
{k: f"{v['acceptance_rate']:.1%}" for k, v in results.items()},
indent=2
))
return results
vLLM Speculative Decoding Configuration
from vllm import LLM, SamplingParams
# vLLM has native speculative decoding support (added around v0.4; the argument
# names below have shifted in newer releases, so check your version's docs)
# Option 1: Draft model speculative decoding
llm = LLM(
model="meta-llama/Llama-2-70b-chat-hf",
speculative_model="meta-llama/Llama-2-7b-chat-hf",
num_speculative_tokens=5, # k tokens per round
speculative_draft_tensor_parallel_size=1, # TP for draft model
tensor_parallel_size=4, # TP for target model
gpu_memory_utilization=0.90,
)
# Option 2: EAGLE speculative decoding (requires EAGLE-trained model)
# EAGLE models are available for common architectures
llm_eagle = LLM(
model="meta-llama/Llama-2-70b-chat-hf",
speculative_model="yuhuili/EAGLE-LLaMA2-Chat-70B", # EAGLE draft
num_speculative_tokens=6,
use_v2_block_manager=True, # required for EAGLE in vLLM
tensor_parallel_size=4,
)
# Option 3: Medusa (requires a Medusa-heads checkpoint trained for the base
# model; pass the base model as `model` and the Medusa checkpoint as the draft)
llm_medusa = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="FasterDecoding/medusa-1-llama-2-7b-chat",  # Medusa heads for this base
    num_speculative_tokens=5,
    tensor_parallel_size=1,
)
# Sampling parameters are identical to non-speculative
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=512,
)
prompts = [
"Implement a Python class for a neural network layer with forward and backward methods.",
"Explain the difference between horizontal and vertical scaling in distributed systems.",
]
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
print(f"Output: {out.outputs[0].text[:100]}...")
Profiling Speedup and Hardware Utilization
# Profile speculative vs standard decoding on A100
# First install profiling dependencies
pip install vllm nvidia-ml-py3
# Benchmark with vLLM's benchmark scripts (they ship in the vLLM repo under
# benchmarks/; exact flag names vary between releases, so check --help)

# Standard autoregressive
python benchmarks/benchmark_throughput.py \
    --model meta-llama/Llama-2-70b-chat-hf \
    --tensor-parallel-size 4 \
    --num-prompts 200 \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --output-json results_ar.json

# With speculative decoding
python benchmarks/benchmark_throughput.py \
    --model meta-llama/Llama-2-70b-chat-hf \
    --speculative-model meta-llama/Llama-2-7b-chat-hf \
    --num-speculative-tokens 5 \
    --tensor-parallel-size 4 \
    --speculative-draft-tensor-parallel-size 1 \
    --num-prompts 200 \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --output-json results_spec.json
# GPU utilization monitoring during benchmark
nvidia-smi dmon -s u -d 1 -o DT > gpu_utilization.log &
NVIDIASMI_PID=$!
# Run benchmark
python benchmarks/benchmark_throughput.py \
--model meta-llama/Llama-2-70b-chat-hf \
--speculative-model meta-llama/Llama-2-7b-chat-hf \
--num-speculative-tokens 5 \
--tensor-parallel-size 4 \
--num-prompts 100
kill $NVIDIASMI_PID
# Profile with NSight Systems to see draft/verify phase breakdown
nsys profile \
--trace=cuda,cudnn,nvtx \
--output speculative_profile \
python benchmarks/benchmark_latency.py \
--model meta-llama/Llama-2-70b-chat-hf \
--speculative-model meta-llama/Llama-2-7b-chat-hf \
--num-speculative-tokens 5 \
--batch-size 1 \
--input-len 256 \
--output-len 128
Production Engineering Notes
When Speculative Decoding Helps (and When It Doesn't)
Speculative decoding improves latency for single-request or low-concurrency workloads. Under high batch concurrency, the picture is different.
Latency improvement: With batch size 1, speculative decoding consistently delivers 2-3x better time-to-last-token. This matters for interactive applications, coding assistants, and any workflow where users wait for a single response to complete.
Throughput at high concurrency: With batch size 32-128, the GPU is already well-utilized during autoregressive decoding (because batching increases arithmetic intensity). Adding speculative decoding may actually hurt throughput here - you're running the draft model for each request, consuming compute that could serve additional batch members.
The break-even point is typically batch size 4-8. Below that, speculative decoding wins on throughput. Above that, it may be neutral or negative.
Rule of thumb: Use speculative decoding for customer-facing latency-sensitive serving at low-to-moderate concurrency. Avoid it for high-throughput batch processing workloads.
Memory Overhead: Two Models
The most significant production constraint is memory. Standard speculative decoding requires holding both draft and target models in GPU memory simultaneously.
For LLaMA-2 70B (target, FP16) + LLaMA-2 7B (draft, FP16):
- Target: 140 GB
- Draft: 14 GB
- Total: 154 GB vs 140 GB alone (10% overhead)
This is manageable on a 4xA100 cluster (320 GB total HBM). But on smaller hardware, the draft model may push you over the limit. Options:
- Use INT4/INT8 for the draft model (reduces its footprint from 14 GB to 3.5-7 GB)
- Use Medusa or EAGLE heads (adds only 100-500 MB over the base model)
- CPU-offload the draft model (adds latency; defeats the purpose for low-latency serving)
Draft Model Selection in Practice
The best draft model is the smallest model that achieves alpha > 0.8 on your target workload.
For code generation workloads: LLaMA-2 7B drafting for LLaMA-2 70B achieves alpha = 0.85-0.92 on Python code. This is the sweet spot.
For general chat: alpha drops to 0.70-0.80, which is still beneficial (expected speedup ~2x) but less dramatic.
Test your specific use case before committing to a draft model. Run 1000 examples from your production traffic, measure alpha per task type, and compute expected speedup using the formula above. If alpha < 0.65 for your workload, speculative decoding may not be worth the operational complexity.
Adaptive Speculation: Dynamic k
The optimal number of speculative tokens k depends on alpha. If alpha is low, a large k wastes draft computation on tokens that will be rejected. If alpha is high, a larger k gives better speedup.
HuggingFace's num_assistant_tokens_schedule="heuristic" adjusts k based on recent acceptance rates. vLLM 0.5+ has adaptive speculation. In production, start with k = 5 and measure whether adaptive scheduling improves your effective speedup.
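The underlying idea can be sketched in a few lines: track a rolling acceptance rate and nudge k up or down. The thresholds below are illustrative, not the values HuggingFace or vLLM actually use.

class AdaptiveSpeculationLength:
    """Adjust the number of speculative tokens from a rolling acceptance rate."""

    def __init__(self, k: int = 5, k_min: int = 1, k_max: int = 10, window: int = 20):
        self.k, self.k_min, self.k_max = k, k_min, k_max
        self.window = window
        self.recent: list[float] = []   # per-round acceptance fractions

    def update(self, accepted: int, drafted: int) -> int:
        """Call once per round; returns the k to use for the next round."""
        self.recent = (self.recent + [accepted / max(drafted, 1)])[-self.window:]
        alpha = sum(self.recent) / len(self.recent)
        if alpha > 0.8 and self.k < self.k_max:
            self.k += 1    # drafts are mostly accepted: speculate deeper
        elif alpha < 0.5 and self.k > self.k_min:
            self.k -= 1    # drafts are mostly rejected: stop wasting draft compute
        return self.k

# Inside a generation loop: k = controller.update(accepted_count, k)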
Common Mistakes
:::danger Using Different Tokenizers for Draft and Target
The speculative decoding acceptance criterion assumes draft and target models share the same vocabulary and tokenization. If you use a draft model with a different tokenizer (e.g., a GPT-2 based draft with a LLaMA-based target), the token IDs don't correspond and the acceptance ratio is computed over mismatched distributions.
The result is unpredictable output quality - the algorithm will reject or accept tokens for the wrong reasons, and the output distribution guarantee no longer holds.
Always verify that the vocabularies match (e.g., compare draft_tokenizer.get_vocab() to target_tokenizer.get_vocab()) before using speculative decoding. In practice, use models from the same family (LLaMA-7B with LLaMA-70B) or explicitly check that the two tokenizers produce identical token IDs, as in the sketch below.
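A quick pre-flight check might look like this (the model names are the LLaMA-2 pair used elsewhere in this section; substitute whatever you actually serve):

from transformers import AutoTokenizer

draft_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")

# The vocabularies must map the same strings to the same token IDs.
assert draft_tok.get_vocab() == target_tok.get_vocab(), "vocab mismatch"

# Spot-check that real text tokenizes identically end to end.
sample = "def binary_search(arr, target):"
assert draft_tok(sample).input_ids == target_tok(sample).input_ids, "tokenization mismatch"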
:::
:::danger Applying Speculative Decoding to High-Concurrency Batch Workloads
At batch sizes above 16-32, the target model forward pass is already compute-bound, not memory-bandwidth-bound. Speculative decoding's benefit comes entirely from the memory-bandwidth-bound regime. In compute-bound operation, the draft model adds overhead without improving throughput.
Measure before deploying. Run a throughput benchmark at your actual batch size with and without speculative decoding. If speculative decoding reduces throughput by more than 5%, it's not worth the complexity for that workload.
The easy mistake: benchmarking speculative decoding at batch size 1 (shows 3x speedup), deploying to a high-concurrency production cluster (shows 0.8x throughput, a regression), and not understanding why the benchmark didn't predict production behavior. :::
:::warning Forgetting That Speedup Is Task-Dependent
A single benchmark number for "speculative decoding speedup" is misleading. Acceptance rates vary from 0.5 (creative writing with a mismatched draft) to 0.95 (code generation with a same-family draft). The resulting speedup varies from 1.2x to 3.5x.
If your production workload mixes code generation (alpha=0.90), factual QA (alpha=0.78), and creative writing (alpha=0.62), your actual speedup may be 1.8x even though the "code generation benchmark" shows 3.0x.
Measure alpha on your actual traffic distribution. Log acceptance rates per request category in production. Set up alerts if alpha drops below your expected threshold - this may indicate a shift in traffic distribution or a bug in prompt construction. :::
:::warning The Synchronization Overhead at Large Tensor Parallel Degree
With tensor parallelism of degree 4 or 8, each forward pass involves AllReduce communication across GPUs (via NVLink). Speculative decoding introduces more forward passes per unit time (one draft pass per token plus the verify pass), each triggering AllReduce synchronizations.
At TP=4, the NVLink overhead per forward pass is small (microseconds). At TP=8, it's larger and the draft passes accumulate. This reduces the effective speedup from theoretical to practical.
Mitigation: run the draft model on a single GPU (TP=1) even if the target is TP=4 or TP=8. The draft model is small enough to fit on one GPU. This avoids the multi-GPU overhead for the draft phase while keeping full parallelism for the verify phase.
In vLLM: speculative_draft_tensor_parallel_size=1 even when tensor_parallel_size=4.
:::
Interview Q&A
Q1: Prove that speculative decoding with rejection sampling produces the same output distribution as sampling directly from the target model.
Answer: This is the core theoretical result from Chen et al. and Leviathan et al. We need to show that the marginal distribution of the token emitted at each position equals the target distribution p(x).

Consider a single draft token x sampled from the draft distribution q. With probability min(1, p(x)/q(x)) we accept it. Otherwise, we sample from the adjusted distribution p'(x) = norm(max(0, p(x) - q(x))).

The probability of outputting token x is:

P(output = x) = q(x) · min(1, p(x)/q(x)) + P(reject) · p'(x)

The first term is the probability of drafting x AND accepting it; it simplifies to min(q(x), p(x)). The second term is the probability of any rejection occurring AND then sampling x from the adjusted distribution.

The probability of drafting and accepting any particular token t is min(q(t), p(t)), so the total acceptance probability is sum_t min(q(t), p(t)). The rejection probability is 1 - sum_t min(q(t), p(t)) = sum_t max(0, p(t) - q(t)), which is exactly the normalizing constant of the adjusted distribution.

So P(reject) · p'(x) = max(0, p(x) - q(x)), and therefore P(output = x) = min(q(x), p(x)) + max(0, p(x) - q(x)) = p(x) for all x. The rejection sampling scheme exactly corrects for the draft distribution's divergence from the target. This is why speculative decoding is "lossless" - it changes latency but not output quality.
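You can also check the result empirically on a toy vocabulary. The simulation below pushes many samples through the accept/reject scheme and compares the empirical output distribution to the target; the two distributions are made up for illustration.

import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.50, 0.30, 0.15, 0.05])   # target distribution (illustrative)
q = np.array([0.25, 0.25, 0.25, 0.25])   # draft distribution (illustrative)

adjusted = np.maximum(p - q, 0.0)
adjusted = adjusted / adjusted.sum()      # p'(x) = norm(max(0, p(x) - q(x)))

def one_speculative_token() -> int:
    x = rng.choice(len(q), p=q)                  # draft proposes x ~ q
    if rng.random() < min(1.0, p[x] / q[x]):     # accept with prob min(1, p/q)
        return x
    return rng.choice(len(p), p=adjusted)        # reject: resample from p'

samples = np.array([one_speculative_token() for _ in range(200_000)])
empirical = np.bincount(samples, minlength=len(p)) / len(samples)
print("target:   ", p)
print("empirical:", empirical.round(3))   # matches p up to sampling noise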
Q2: Your speculative decoding deployment is showing lower-than-expected acceptance rates (alpha = 0.6 vs expected 0.8). What are the likely causes and how do you debug them?
Answer: Several root causes are worth investigating systematically.
First, check tokenizer alignment. Verify that draft_tokenizer.vocab_size == target_tokenizer.vocab_size and spot-check that a few sample sentences tokenize identically. Mismatched tokenizers produce systematically low acceptance because token probabilities are being compared across different vocabularies.
Second, examine prompt distribution. Low acceptance often indicates the traffic has shifted to harder-to-predict content. Query "what percentage of requests are code vs. creative vs. factual?" If creative writing went from 10% to 40% of traffic, aggregate alpha drops significantly.
Third, check temperature settings. If the target model is being sampled at temperature 0.5 (sharp distribution) but the draft model at temperature 1.0 (flat distribution), the draft will propose low-probability tokens more often than the target would. Ensure both models use the same sampling parameters.
Fourth, check model version alignment. If target was updated to a finetuned checkpoint but draft is still the base model, the distributions diverged and alpha will drop.
To debug: log per-request alpha values with the request category. Plot the distribution. Look for bimodal patterns (some requests with alpha=0.90, others with alpha=0.40) which indicate category-based variation rather than a systematic bug.
Q3: Explain the hardware efficiency argument for why speculative decoding works. Why does processing k+1 tokens in the verify phase cost approximately the same as 1 token in standard decoding?
Answer: The cost of an LLM forward pass during decoding is dominated by weight reads from HBM, not compute. For each transformer layer, the dominant operation is multiplying the current activation vector (shape: [batch, seq, d_model]) against weight matrices (shape: [d_model, d_ffn]).
In standard decode (seq=1): we load the weight matrix once from HBM, perform a matrix-vector multiply (1 row of output). Arithmetic intensity: 1 FLOP/byte.
In speculative verify (seq=k+1): we load the same weight matrix once from HBM, perform a matrix-matrix multiply (k+1 rows of output). Arithmetic intensity: (k+1) FLOPs/byte.
Because HBM read is the bottleneck (not compute), the total time is proportional to "bytes read" not "FLOPs performed". Bytes read is the same for seq=1 and seq=k+1 (same weight matrix). So the wall-clock time for the verify forward pass is approximately the same as for a standard single-token decode pass.
There is a small additional cost: the attention computation in the verify pass involves a slightly larger KV cache, and there's more data written back to memory. For k=5 and typical sequence lengths, this adds 5-10% overhead to the verify pass. But processing k+1 positions for the cost of 1 remains the core efficiency gain.
The limits: at very large k (k > 20), the matrix multiplications become large enough to become compute-bound, and the cost per verified token starts rising. In practice, k in the range of 4-8 is the sweet spot where the verify pass is still effectively memory-bandwidth-bound.
Q4: How does Medusa differ from standard draft-then-verify speculative decoding, and when would you prefer one over the other?
Answer: Standard speculative decoding uses a separate, independently trained draft model that generates candidate tokens autoregressively. Medusa attaches multiple small feedforward "heads" to the target model, each trained to predict a future token position directly from the target model's hidden states.
The key differences:
Memory: Medusa requires only one model in memory (the target plus small heads). Standard speculative requires two models. For a 70B target + 7B draft, standard needs 154 GB vs Medusa's 140.5 GB.
Training required: Medusa heads need fine-tuning on the target model's hidden states. Standard speculative works with any compatible pre-trained smaller model from the same family.
Acceptance rates by task: Standard speculative with a same-family draft typically achieves higher acceptance rates on general tasks (the draft model has more capacity and expressiveness). Medusa heads are simpler networks and may achieve slightly lower alpha for complex reasoning tasks. However, for tasks where the model is highly predictable (code completion, template-following), Medusa performs comparably.
Operational complexity: Standard speculative requires serving infrastructure for two separate models, balancing their memory placement, and handling failures in either. Medusa is a single model with minor overhead - much simpler to deploy and monitor.
Prefer Medusa when: you can't afford the memory for two models, you want simplicity, and your workload is structured/predictable. Prefer standard speculative when: you have memory budget, you want maximum acceptance rates across diverse tasks, or you want to swap draft models to experiment with different quality-speed tradeoffs without retraining.
Q5: What happens to speculative decoding performance when you increase the temperature setting? Give the intuition and the math.
Answer: Higher temperature reduces acceptance rates, reducing speculative decoding's effectiveness.
The intuition: temperature scales logits before softmax, flattening the probability distribution. At temperature 0 (greedy), the target model picks one token with probability 1.0. A draft model that's even slightly wrong gets rejected with certainty. But at temperature 1.0, the target has a smoother distribution, and draft tokens that are "close to correct" get accepted more often. Counterintuitively, moderate temperature (0.7-1.0) is actually better for speculative decoding than low temperature.
The math: the acceptance probability for a draft token x is:

accept(x) = min(1, p_tau(x) / q_tau(x))

where p_tau and q_tau are the target and draft distributions at temperature tau. At very high temperature (tau = 5), both distributions approach uniform. The ratio approaches 1 for most tokens, and the acceptance rate approaches 1.0. But generating near-random tokens isn't useful.

At very low temperature (tau approaching 0, greedy), p_tau concentrates on a single token. If the draft picks that exact token, acceptance is 1.0; if it picks anything else, acceptance is 0.0. The acceptance rate equals the probability that the draft picks the target's greedy token.
Empirically: for code generation at temperature 0.2, alpha drops to 0.70-0.80 because the target is sharp and demands exact token matches. At temperature 0.8, alpha rises to 0.85-0.92 because the target is more tolerant. This is why speculative decoding often works better for generation tasks (temperature > 0.5) than for precise completion tasks (temperature = 0).
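The interaction is easy to see numerically: the expected per-token acceptance rate is the distribution overlap sum_x min(p_tau(x), q_tau(x)), which you can evaluate at several temperatures for a pair of toy logit vectors. The logits below are illustrative; the draft's top choice deliberately differs from the target's.

import numpy as np

def tempered_softmax(logits: np.ndarray, tau: float) -> np.ndarray:
    z = np.exp((logits - logits.max()) / tau)
    return z / z.sum()

target_logits = np.array([4.0, 2.5, 1.0, 0.5, 0.0])   # illustrative
draft_logits  = np.array([2.8, 3.0, 1.5, 0.2, 0.1])   # draft prefers a different token

for tau in (0.2, 0.5, 1.0, 2.0, 5.0):
    p = tempered_softmax(target_logits, tau)
    q = tempered_softmax(draft_logits, tau)
    alpha = np.minimum(p, q).sum()    # expected acceptance = E_{x~q}[min(1, p/q)]
    print(f"tau={tau:>3}: expected acceptance rate = {alpha:.3f}")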
Q6: A team is debating whether to use speculative decoding or continuous batching to improve their LLM serving infrastructure. The workload is 500 req/s with average 200 token prompts and 300 token outputs. Walk through how you would make this decision.
Answer: This is a throughput-dominated workload at 500 req/s. The decision framework:
First, calculate the required compute. 500 req/s at 300 output tokens each = 150,000 tokens/s of required decode throughput. For LLaMA-2 70B on a 4xA100 node with standard continuous batching, throughput is roughly 3,000-5,000 tokens/s, so you need on the order of 30-50 such nodes (well over a hundred A100s) just to meet demand.
Second, assess where the bottleneck is. At 500 req/s, continuous batching will naturally form large batches (50-100 concurrent requests). At those batch sizes, the GPU is compute-bound (high arithmetic intensity), not memory-bandwidth-bound. Speculative decoding's benefit is specifically in the memory-bandwidth-bound regime.
Third, run targeted benchmarks. Measure throughput at batch size 50+ with and without speculative decoding on your actual workload mix. If speculative decoding reduces throughput at high batch sizes (likely), it's the wrong tool.
Fourth, consider the use cases. Is latency the bottleneck (users waiting for slow responses) or throughput (the system can't keep up with 500 req/s)? If the complaint is "I have to queue requests", continuous batching and more hardware solve throughput. If the complaint is "individual responses are too slow", speculative decoding helps.
Recommendation for this specific workload: prioritize continuous batching with properly tuned max_num_batched_tokens and max_num_seqs. At 500 req/s you need scale-out, not per-request optimization. Speculative decoding could complement this for latency-sensitive tier-1 users if you can afford the memory overhead of the draft model, but it shouldn't be the primary optimization for a throughput-constrained system.
