Context Window Extension - YaRN, LongRoPE, LongLoRA
The Problem of the Frozen Clock
Imagine a model that learned to read a clock during training. It saw clocks showing times from 12:00 to 4:00, and became excellent at reading them. Now you show it a clock at 4:30. The hands are in positions it's never seen. It may guess correctly based on extrapolation, but the further past 4:00 you go, the worse it gets.
This is exactly the problem with RoPE-based models at long contexts. A model trained on 4K tokens has seen rotation angles for positions 0-4095. At position 8000, the rotation angle is literally outside the training distribution. The model must extrapolate from patterns it's never encountered.
The question that launched a series of papers starting in 2023 was: can we rescale the clock? Instead of asking the model to extrapolate to new rotation angles, can we compress positions so that even long sequences use only the rotation angles the model has already seen?
This is the core idea behind every context extension technique covered in this lesson.
Why the Naive Approach Fails
Just Set max_position_embeddings Higher
The most obvious approach: set max_position_embeddings = 32768 in the model config, load the model, and run inference on 32K token sequences. This costs nothing computationally.
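The result: catastrophic perplexity degradation. RoPE encodes position through rotation angles, and attention scores depend on the angle difference between Q and K in each dimension pair. The model has learned patterns tied to specific values - "when the angle difference between Q and K in dimension pair 7 is X, the tokens are about 300 positions apart." High-frequency pairs wrap around many times within 4K tokens, so training has shown the model their full angle range. Low-frequency pairs rotate so slowly that they never complete a full turn during training. Past 4K, those slow pairs reach angles, and combinations of angles across pairs, for which the model has no learned response. The long-range positional signal becomes unreadable. The model breaks.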
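One way to make this concrete is to count how far each RoPE dimension pair rotates over the training window. The following is a minimal sketch in plain Python, assuming the standard RoPE frequencies `theta ** (-2i / head_dim)` with `theta = 10000` and `head_dim = 128`; the helper name `rotations_per_pair` is illustrative and not taken from any library.

```python
import math

def rotations_per_pair(head_dim: int, seq_len: int, theta: float = 10000.0) -> list[float]:
    """Full rotations each RoPE dimension pair completes over seq_len positions."""
    rotations = []
    for i in range(head_dim // 2):
        inv_freq = theta ** (-2 * i / head_dim)  # angular step per position (radians)
        rotations.append(seq_len * inv_freq / (2 * math.pi))
    return rotations

train = rotations_per_pair(head_dim=128, seq_len=4096)

# Pairs that never finish one full turn during training have only shown the
# model part of their angle range; longer inputs push them into unseen angles.
unseen = [i for i, r in enumerate(train) if r < 1.0]
print(f"fastest pair: {train[0]:.0f} rotations, slowest pair: {train[-1]:.4f} rotations")
print(f"{len(unseen)} of {len(train)} pairs complete less than one rotation in 4096 tokens")
```

The slow pairs at the end of the list are the ones that context extension methods need to handle with the most care.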
The Concrete Perplexity Numbers
From Chen et al. (2023), Llama-7B (trained on 2K) evaluated at longer contexts with no modification:
| Eval Length | Perplexity |
|---|---|
| 2,048 | 6.8 |
| 4,096 | 7.2 |
| 6,144 | 12.9 |
| 8,192 | 48.3 |
| 16,384 | >1000 |
The perplexity explosion past 8K shows the model is effectively producing random text. This is not a gradual degradation - it's a cliff edge.
Position Interpolation - The First Fix
The Insight
Shouyuan Chen et al. (2023) published "Extending Context Window of Large Language Models via Positional Interpolation". Their approach is elegantly simple:
Instead of letting position indices run past the training range (which takes the rotation angles out of distribution), compress the position indices to fit within the training range:

$$
m' = m \cdot \frac{L_{\text{train}}}{L_{\text{target}}}
$$

For a model trained at 2K extending to 8K:

$$
m' = m \cdot \frac{2048}{8192} = \frac{m}{4}
$$

Now position 8000 becomes position 2000 - well within the training range.
```python
import torch


def position_interpolation_rope(
    train_context_len: int,
    target_context_len: int,
) -> torch.Tensor:
    """
    Position indices for RoPE position interpolation.

    Maps positions [0, target_len) to [0, train_len) by scaling.
    All positions remain within the trained distribution.

    Parameters
    ----------
    train_context_len : original training context length
    target_context_len : desired new context length

    Returns
    -------
    Float tensor of shape (target_context_len,) holding the scaled positions.
    The caller combines these with inv_freq to precompute cos/sin.
    """
    scale = train_context_len / target_context_len  # < 1.0 for extension

    # Original: freqs = outer(positions, inv_freq), positions = [0, 1, ..., L_train-1]
    # New:      freqs = outer(positions * scale, inv_freq), positions = [0, 1, ..., L_target-1]
    # Scaled positions range over [0, train_context_len - scale], inside the training range.
    return torch.arange(target_context_len, dtype=torch.float32) * scale


def apply_position_interpolation(
    head_dim: int,
    train_len: int,
    target_len: int,
    theta: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Complete position interpolation: returns cos/sin tables for target_len positions."""
    # Standard RoPE inverse frequencies, one per dimension pair
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Scaled positions: [0, 1, ..., target_len-1] * (train_len / target_len)
    positions = position_interpolation_rope(train_len, target_len)

    # Outer product: positions × inv_freq, shape (target_len, head_dim // 2)
    freqs = torch.outer(positions, inv_freq)
    return freqs.cos(), freqs.sin()
```
Why Position Interpolation Works (And Its Limitation)
Position interpolation works because all rotational angles remain within the training distribution. The model never encounters "new" rotation values.
The limitation: by compressing positions, you've made adjacent tokens closer together in rotation space. Position 0 and position 1 now differ by only scale times the original angle difference. The model must distinguish tokens that are "closer together" in rotation space than it was trained to handle.
This is the trade-off: extrapolation is eliminated but precision is reduced. Models with position interpolation generalize to long contexts but are slightly worse at distinguishing nearby tokens than the original model.
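The precision cost can be quantified directly: interpolation multiplies every per-token angle step by the scale factor. Below is a small sketch in plain Python, assuming standard RoPE frequencies and a 2K-to-8K extension; the function name `angle_step` is illustrative.

```python
def angle_step(pair_index: int, head_dim: int, scale: float = 1.0,
               theta: float = 10000.0) -> float:
    """Rotation in radians between adjacent tokens for one RoPE dimension pair."""
    inv_freq = theta ** (-2 * pair_index / head_dim)
    return scale * inv_freq

scale = 2048 / 8192                      # train_len / target_len = 0.25
original = angle_step(0, head_dim=128)   # fastest pair, no interpolation
squeezed = angle_step(0, head_dim=128, scale=scale)

print(f"adjacent-token step: {original:.2f} rad -> {squeezed:.2f} rad")
# The last token of the 8K window lands on scaled position 2047.75,
# so no position leaves the range covered during training.
print(f"max scaled position: {(8192 - 1) * scale}")
```

Adjacent tokens are now separated by a quarter of the angle the model was trained on, which is the resolution loss that fine-tuning has to recover.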
Chen et al. found this limitation is largely overcome by fine-tuning: just 1000 gradient steps on data containing long sequences is sufficient to recover most of the lost precision.
NTK-Aware Scaling
The NTK (Neural Tangent Kernel) scaling insight came from the community rather than a formal paper. The key observation: position interpolation scales all frequency dimensions uniformly, but high-frequency and low-frequency dimensions have different sensitivity to position errors.
The Frequency-Sensitivity Problem
High-frequency RoPE dimensions (fast-changing) are used to distinguish nearby tokens. If you halve their spacing by interpolation, you've made adjacent-token distinction much harder.
Low-frequency dimensions (slow-changing) encode long-range positional relationships. These are the dimensions that matter most for long-context extension - and they can tolerate more compression without quality loss, because they were already changing slowly.
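NTK-aware scaling applies different scale factors to different frequency bands. With $s = L_{\text{target}} / L_{\text{train}}$ and head dimension $d$, the wavelength $\lambda_i$ of dimension pair $i$ is stretched by:

$$
\lambda_i' = \lambda_i \cdot s^{2i/(d-2)}, \qquad i = 0, 1, \dots, \tfrac{d}{2} - 1
$$

Or equivalently, instead of scaling positions, scale the theta parameter itself:

$$
\theta' = \theta \cdot s^{d/(d-2)}
$$

This ensures high-frequency dimensions are scaled less (the fastest pair, $i = 0$, is left unchanged, preserving local token distinction) while low-frequency dimensions are scaled more (the slowest pair is stretched by the full factor $s$, extending the long-range position range).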
```python
import torch


def ntk_aware_rope(
    head_dim: int,
    train_len: int,
    target_len: int,
    original_theta: float = 10000.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    NTK-aware RoPE scaling for context extension.

    Instead of scaling positions uniformly, scales the theta parameter
    non-uniformly to preserve high-frequency (local) information while
    extending low-frequency (global) range.
    """
    # Effective theta after NTK scaling
    scale = target_len / train_len
    new_theta = original_theta * (scale ** (head_dim / (head_dim - 2)))
    print(f"NTK scaling: theta {original_theta:.0f} -> {new_theta:.0f} "
          f"(scale={scale:.1f}x, context {train_len}->{target_len})")

    # Compute inv_freq with scaled theta
    inv_freq = 1.0 / (new_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Standard positions (no position scaling - theta absorbs the extension)
    positions = torch.arange(target_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    return freqs.cos(), freqs.sin()


# Example: extending Llama-2 7B (theta=10000, train_len=4096) to 32K
cos, sin = ntk_aware_rope(head_dim=128, train_len=4096, target_len=32768, original_theta=10000)
# Prints a new theta of roughly 82,700 (10000 * 8 ** (128 / 126)).
# Note: Llama-3 was trained from the start with a much larger base (theta=500000);
# NTK-aware scaling instead changes theta after training.
```
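NTK-aware scaling is available in many frameworks as a drop-in option. In Hugging Face Transformers the closest setting is "rope_scaling = {'type': 'dynamic', ...}", a dynamic variant that adjusts the base as the sequence grows beyond the training length. It requires no fine-tuning and works immediately at inference, though performance at very long contexts benefits from fine-tuning.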
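The per-band effect of rescaling theta can be checked numerically. The sketch below, in plain Python, assumes the NTK-aware rule `theta_new = theta * scale ** (head_dim / (head_dim - 2))` used in this lesson; `ntk_effective_scale` is an illustrative helper name.

```python
def ntk_effective_scale(head_dim: int, scale: float) -> list[float]:
    """Factor by which each pair's wavelength grows under NTK-aware theta scaling."""
    theta_ratio = scale ** (head_dim / (head_dim - 2))
    # inv_freq_i = theta ** (-2i / head_dim), so the wavelength of pair i grows by
    # theta_ratio ** (2i / head_dim) = scale ** (2i / (head_dim - 2)).
    return [theta_ratio ** (2 * i / head_dim) for i in range(head_dim // 2)]

factors = ntk_effective_scale(head_dim=128, scale=8.0)
print(f"fastest pair stretched by {factors[0]:.2f}x")   # left untouched
print(f"middle pair stretched by  {factors[32]:.2f}x")
print(f"slowest pair stretched by {factors[-1]:.2f}x")  # same as full interpolation
```

The factors grow geometrically from 1 for the fastest pair to the full scale factor for the slowest pair, which is why local token distinction survives better than under uniform interpolation.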
YaRN - Yet Another RoPE extensioN
YaRN (Peng et al. 2023, "YaRN: Efficient Context Window Extension of Large Language Models") is one of the most principled and widely adopted context extension methods. A closely related frequency-dependent scaling scheme underlies Llama-3.1's extension to 128K.
The Core Insight: Ramp Function
YaRN observes that NTK-aware scaling is still too aggressive for high-frequency dimensions. Instead of a single scaling formula, YaRN uses a ramp function that smoothly transitions between:
- No scaling for high-frequency dimensions (where the original training is still adequate)
- Full interpolation scaling for low-frequency dimensions (which need the most extension)
```python
import torch


def yarn_rope(
    head_dim: int,
    train_len: int,
    target_len: int,
    original_theta: float = 10000.0,
    beta_slow: float = 1.0,   # rotations below which a band is fully interpolated
    beta_fast: float = 32.0,  # rotations above which a band is left unscaled
    scale: float | None = None,
    mscale: float = 1.0,      # attention magnitude scaling (applied in attention, not here)
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    YaRN: Yet Another RoPE extensioN.

    Applies a ramp function that scales different RoPE frequency bands
    differently:
    - High-frequency (local): no scaling (extrapolate)
    - Transition: smooth ramp
    - Low-frequency (long-range): full interpolation scaling

    Parameters
    ----------
    beta_slow : bands completing fewer rotations than this over the
                training window get full interpolation
    beta_fast : bands completing more rotations than this over the
                training window get no scaling (pure extrapolation)
    mscale : temperature factor for attention scaling; kept in the signature
             for reference but applied inside the attention computation
    """
    if scale is None:
        scale = target_len / train_len
    print(f"YaRN: extending {train_len} -> {target_len} (scale={scale:.1f}x)")

    inv_freq_original = 1.0 / (original_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

    # Wavelength of each band (in positions) and the number of full
    # rotations it completes over the training window
    wavelength = 2 * torch.pi / inv_freq_original
    rotations = train_len / wavelength

    # Interpolation weight per band:
    #   rotations >= beta_fast -> 0.0 (high frequency: keep original)
    #   rotations <= beta_slow -> 1.0 (low frequency: full interpolation)
    #   in between             -> linear ramp
    ramp = torch.clamp((beta_fast - rotations) / (beta_fast - beta_slow), 0.0, 1.0)

    # Blend between original frequency and interpolated frequency
    # ramp=0 -> use original (extrapolate), ramp=1 -> divide by scale (interpolate)
    inv_freq_scaled = inv_freq_original / scale
    inv_freq_yarn = (1 - ramp) * inv_freq_original + ramp * inv_freq_scaled

    # Compute rotation angles for every (position, band) pair
    positions = torch.arange(target_len, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq_yarn)
    return freqs.cos(), freqs.sin()


# Illustrative YaRN-style extension with Llama-3-like settings:
# Base model: Llama-3 (8K context, theta=500000)
# Extended: Llama-3.1 (128K context)
# Ratio of context lengths: 128K / 8K = 16x
# Llama-3.1's released config uses its own YaRN-like scheme with different
# thresholds, so the beta values below are the YaRN defaults, not Meta's.
yarn_cos, yarn_sin = yarn_rope(
    head_dim=128,
    train_len=8192,
    target_len=131072,  # 128K
    original_theta=500000,
    beta_slow=1.0,
    beta_fast=32.0,
    scale=16.0,
)
```
YaRN's Key Hyperparameters
beta_fast: The rotation-count threshold above which frequencies receive no scaling. Set to 32 by Peng et al. for most experiments. High-frequency dimensions that complete more than 32 full rotations within the training window (wavelength shorter than train_len / 32) are kept as-is because the model has already seen their full angle range many times over.
beta_slow: The rotation-count threshold below which frequencies receive full interpolation. Set to 1 by Peng et al. Low-frequency dimensions that complete less than one rotation within the training window (wavelength longer than train_len) are the ones most in need of extension.
mscale: An attention magnitude correction factor. When frequencies are scaled, attention scores change in magnitude. mscale compensates, typically set to $0.1 \ln(s) + 1$, where $s$ is the extension scale factor.
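To see how these hyperparameters partition a real head, the sketch below counts how many RoPE pairs fall into each band. It reads beta_fast and beta_slow as rotation counts over the training window (the convention in the YaRN paper) and uses the paper's recommended temperature `0.1 * ln(scale) + 1`. The helper names and the choice of `head_dim = 128`, `train_len = 4096` are assumptions made for illustration.

```python
import math

def yarn_band_counts(head_dim: int, train_len: int, theta: float = 10000.0,
                     beta_slow: float = 1.0, beta_fast: float = 32.0) -> dict[str, int]:
    """Count RoPE pairs per YaRN band, measured in rotations per training window."""
    counts = {"unscaled": 0, "ramp": 0, "interpolated": 0}
    for i in range(head_dim // 2):
        inv_freq = theta ** (-2 * i / head_dim)
        rotations = train_len * inv_freq / (2 * math.pi)
        if rotations > beta_fast:
            counts["unscaled"] += 1       # high frequency: left as trained
        elif rotations < beta_slow:
            counts["interpolated"] += 1   # low frequency: divided by the scale factor
        else:
            counts["ramp"] += 1           # blended between the two extremes
    return counts

def yarn_mscale(scale: float) -> float:
    """Attention temperature correction, 0.1 * ln(scale) + 1 for scale > 1."""
    return 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0

print(yarn_band_counts(head_dim=128, train_len=4096))
print(f"mscale for a 16x extension: {yarn_mscale(16.0):.3f}")
```

Raising beta_fast or lowering beta_slow widens the ramp region, trading sharper band separation for a smoother transition.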
YaRN vs Competitors at 128K
From Peng et al. 2023 evaluation on Llama-2 extended to 128K:
| Method | Perplexity at 128K | Fine-tuning required? |
|---|---|---|
| No extension | >1000 | No |
| Position interpolation | 12.4 | Yes (1000+ steps) |
| NTK-aware scaling | 9.8 | No (but better with fine-tune) |
| YaRN | 8.1 | Yes (400 steps sufficient) |
| Original (at 4K) | 7.5 | - |
YaRN nearly matches the original model's perplexity at 128K with only 400 fine-tuning steps on long-context data.
LongRoPE - Progressive 2M Context Extension
LongRoPE (Ding et al. 2024, "LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens") extends the YaRN approach to extreme context lengths (1M-2M tokens).
The Innovations
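Progressive extension: Instead of jumping to the final length in one step, LongRoPE extends in stages. The paper first fine-tunes the model at an intermediate length (256K), then runs a second interpolation search on the fine-tuned model to reach 2M. The expensive fine-tuning therefore never has to run at the full 2M length.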
Non-uniform scaling: Like YaRN, LongRoPE applies different scaling factors to different frequency bands. But rather than using a hand-crafted ramp function, LongRoPE searches for the optimal per-dimension scaling factors using an evolutionary algorithm.
Two-stage extension: (1) Find optimal scaling factors by optimizing on a perplexity objective with a small validation set. (2) Fine-tune with the optimal scaling factors.
This search-based approach finds better scaling factors than any hand-designed formula, at the cost of a hyperparameter search step.
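The search step can be pictured with a toy evolutionary loop. This is only a sketch of the general idea, not LongRoPE's actual algorithm: the objective `toy_cost` is a stand-in (the paper optimizes perplexity on real data), and the mutation range, population size, and the non-decreasing constraint on factors are assumptions made here for illustration.

```python
import random

def evolve_scale_factors(n_dims, target_scale, cost_fn, generations=50,
                         population=16, seed=0):
    """Toy evolutionary search over per-dimension RoPE scale factors."""
    rng = random.Random(seed)
    best = [target_scale] * n_dims           # start from uniform interpolation
    best_cost = cost_fn(best)
    for _ in range(generations):
        for _ in range(population):
            # Mutate: jitter each factor, keeping it within [1, target_scale]
            child = [min(target_scale, max(1.0, f * rng.uniform(0.9, 1.1)))
                     for f in best]
            child.sort()                      # assumed constraint: non-decreasing factors
            cost = cost_fn(child)
            if cost < best_cost:              # keep the child only if it improves
                best, best_cost = child, cost
    return best, best_cost

# Stand-in objective (NOT perplexity): prefers factors rising smoothly from 1 to 8.
def toy_cost(factors):
    n = len(factors)
    ideal = [1.0 + 7.0 * i / (n - 1) for i in range(n)]
    return sum((f - t) ** 2 for f, t in zip(factors, ideal))

uniform_cost = toy_cost([8.0] * 8)
factors, cost = evolve_scale_factors(n_dims=8, target_scale=8.0, cost_fn=toy_cost)
print(f"cost: {uniform_cost:.1f} (uniform) -> {cost:.2f} (searched)")
```

In a real setting `cost_fn` would run the model on held-out long sequences, which is what makes the search step expensive.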
LongLoRA - Efficient Fine-Tuning for Long Context
LongLoRA (Chen et al. 2023, "LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models") addresses the training efficiency problem: even with YaRN or position interpolation, fine-tuning at 128K is expensive because the standard training pipeline requires O(n²) attention compute.
The Shifted Sparse Attention Trick
LongLoRA introduces shifted sparse attention (S²-Attn) during fine-tuning:
- During fine-tuning: each token attends only to a local window of nearby tokens (sparse attention) with shifted windows between attention heads
- During inference: use standard full attention
```text
Standard attention (training):
  Each token attends to ALL tokens (O(n²) compute)

LongLoRA S²-Attn (training):
  First half of heads:  attend within windows [0:w], [w:2w], [2w:3w], ...
  Second half of heads: attend within windows [w/2:3w/2], [3w/2:5w/2], ... (shifted by w/2)

This creates overlapping windows where information crosses boundaries,
while reducing compute from O(n²) to O(n·w) where w is the window size.
```
The key insight: shifted windows allow information from context boundaries to propagate, making sparse attention during training effective for teaching the model to use long contexts. At inference, full attention is restored.
```python
import torch
import torch.nn.functional as F


def shifted_sparse_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    window_size: int = 512,
    shift: bool = True,
) -> torch.Tensor:
    """
    LongLoRA's Shifted Sparse Attention for training efficiency.

    Divides the sequence into non-overlapping windows. The first half of the
    heads uses unshifted windows; the second half uses windows shifted by
    window_size // 2 (when shift=True).

    query/key/value: (batch, n_heads, seq_len, head_dim)
    Simplified sketch: no causal mask, and seq_len must divide evenly.
    """
    batch, n_heads, seq_len, head_dim = query.shape
    if seq_len % window_size != 0:
        raise ValueError("seq_len must be a multiple of window_size")
    scale = head_dim ** -0.5

    # Split heads into two groups: non-shifted and shifted
    half_heads = n_heads // 2
    shift_amount = window_size // 2 if shift else 0

    # Group 1: standard windows (no shift)
    q1, k1, v1 = query[:, :half_heads], key[:, :half_heads], value[:, :half_heads]

    # Group 2: roll the sequence so window boundaries move by shift_amount.
    # The roll wraps around, so the last window mixes tail and head tokens.
    q2 = torch.roll(query[:, half_heads:], shifts=-shift_amount, dims=2)
    k2 = torch.roll(key[:, half_heads:], shifts=-shift_amount, dims=2)
    v2 = torch.roll(value[:, half_heads:], shifts=-shift_amount, dims=2)

    def windowed_attention(q, k, v):
        """Apply attention within non-overlapping windows."""
        b, h, s, d = q.shape
        n_windows = s // window_size
        # Flatten to (batch * heads * n_windows, window_size, head_dim);
        # reshape (not view) because the head slices are non-contiguous
        q_win = q.reshape(b * h * n_windows, window_size, d)
        k_win = k.reshape(b * h * n_windows, window_size, d)
        v_win = v.reshape(b * h * n_windows, window_size, d)
        scores = torch.bmm(q_win, k_win.transpose(-2, -1)) * scale
        weights = F.softmax(scores, dim=-1)
        out = torch.bmm(weights, v_win)
        return out.reshape(b, h, s, d)

    # Apply windowed attention to each group
    out1 = windowed_attention(q1, k1, v1)
    out2 = windowed_attention(q2, k2, v2)

    # Undo shift for group 2 so outputs line up with original positions
    out2 = torch.roll(out2, shifts=shift_amount, dims=2)

    # Concatenate along head dimension
    return torch.cat([out1, out2], dim=1)
```
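A quick way to see why the shift matters is to track which window each token lands in for the two head groups. The sketch below is plain Python with tiny illustrative sizes; `window_id` is a hypothetical helper that mirrors a roll by half a window, and the sizes in the cost comparison are examples rather than values from the paper.

```python
def window_id(position: int, window_size: int, seq_len: int, shifted: bool) -> int:
    """Index of the attention window a token falls into for one head group."""
    if shifted:
        # Shifted heads move every window boundary by half a window (with wrap-around)
        position = (position - window_size // 2) % seq_len
    return position // window_size

seq_len, w = 16, 4
# Tokens 3 and 4 sit on opposite sides of an unshifted window boundary
same_unshifted = window_id(3, w, seq_len, False) == window_id(4, w, seq_len, False)
same_shifted = window_id(3, w, seq_len, True) == window_id(4, w, seq_len, True)
print(same_unshifted, same_shifted)  # False True

# Attention score entries per head: full attention vs. windowed attention
n, window = 32768, 2048  # illustrative sizes
print(f"full: {n * n:,}  windowed: {n * window:,}  ratio: {n // window}x")
```

Tokens that cannot see each other in the unshifted heads share a window in the shifted heads, so information crosses every boundary once the head outputs are combined.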
LongLoRA's Training Efficiency
On Llama-2-7B:
| Method | Context | Training Memory | Training Time |
|---|---|---|---|
| Full fine-tuning | 4K | 55 GB | 1× |
| Full fine-tuning | 32K | OOM (8× A100 80GB) | - |
| LongLoRA (S²-Attn + LoRA) | 32K | 25 GB | 1× |
| LongLoRA (S²-Attn + LoRA) | 100K | 56 GB | 2.6× |
LongLoRA enables 32K context fine-tuning on hardware where standard full fine-tuning at 32K runs out of memory.
How Llama-3.1 Extended From 8K to 128K
Meta's recipe for Llama-3.1's context extension is a practical example of combining multiple techniques.
The data mixture for long-context pre-training is critical:
- Long documents from Common Crawl
- Full books (Books3 dataset)
- Long-form code (complete repositories)
- Synthetic long-context tasks (multi-document QA, long-range reasoning)
Data mixture ratios matter: too much long-document data degrades short-context performance. Meta found roughly 5% long-context data maintained quality across all context lengths.
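A mixture ratio like this is usually enforced at sampling time. The sketch below shows one simple way to do it in plain Python. The corpora are placeholders, `sample_mixture` is a hypothetical helper, and the ratio is applied per document here, whereas a production pipeline might weight by tokens instead.

```python
import random

def sample_mixture(short_docs, long_docs, n_samples, long_fraction=0.05, seed=0):
    """Draw training documents so that about long_fraction are long-context."""
    rng = random.Random(seed)
    batch = []
    for _ in range(n_samples):
        # Pick the long-context pool with probability long_fraction
        pool = long_docs if rng.random() < long_fraction else short_docs
        batch.append(rng.choice(pool))
    return batch

short_docs = [f"short-{i}" for i in range(100)]  # placeholder corpus
long_docs = [f"long-{i}" for i in range(10)]     # placeholder corpus
batch = sample_mixture(short_docs, long_docs, n_samples=10_000)
share = sum(doc.startswith("long") for doc in batch) / len(batch)
print(f"long-context share: {share:.3f}")
```

Because long documents contain far more tokens each, a 5% share by document corresponds to a much larger share by token, so it is worth being explicit about which unit a reported ratio refers to.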
Practical Code: Loading Extended Models
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch


def load_extended_model(
    model_id: str,
    target_context_len: int | None = None,
    rope_scaling_type: str = "yarn",  # "linear", "dynamic" (NTK), or "yarn"
    attn_implementation: str = "flash_attention_2",
):
    """
    Load a model with optional context window extension.

    For models that already support long context (e.g., Llama-3.1):
    just load normally - context extension is baked in.

    For models needing extension (e.g., extending Llama-3-8B from 8K to 32K):
    modify the config's rope_scaling parameters.
    """
    config = AutoConfig.from_pretrained(model_id)
    original_len = config.max_position_embeddings

    if target_context_len and target_context_len > original_len:
        scale = target_context_len / original_len
        print(f"Extending {model_id} from {original_len} to {target_context_len} ({scale:.1f}x)")

        # Key names and supported types vary across transformers versions;
        # check the documentation for the release you have installed.
        if rope_scaling_type == "linear":
            config.rope_scaling = {"type": "linear", "factor": scale}
        elif rope_scaling_type == "dynamic":  # NTK-aware
            config.rope_scaling = {"type": "dynamic", "factor": scale}
        elif rope_scaling_type == "yarn":
            config.rope_scaling = {
                "type": "yarn",
                "factor": scale,
                "original_max_position_embeddings": original_len,
            }
        else:
            raise ValueError(f"Unsupported rope_scaling_type: {rope_scaling_type!r}")
        config.max_position_embeddings = target_context_len

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # FlashAttention-2 needs the flash-attn package and a supported GPU;
        # pass "sdpa" to fall back to PyTorch's built-in attention.
        attn_implementation=attn_implementation,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer


# Example 1: Load Llama-3.1-8B (already extended to 128K by Meta)
model, tokenizer = load_extended_model("meta-llama/Meta-Llama-3.1-8B-Instruct")

# Example 2: Extend Llama-3-8B from 8K to 32K using YaRN (no fine-tuning)
# Note: zero-shot extension degrades quality; fine-tuning is recommended
model, tokenizer = load_extended_model(
    "meta-llama/Meta-Llama-3-8B",
    target_context_len=32768,
    rope_scaling_type="yarn",
)

# Example 3: NTK-aware scaling for quick zero-shot extension
model, tokenizer = load_extended_model(
    "mistralai/Mistral-7B-v0.1",
    target_context_len=65536,
    rope_scaling_type="dynamic",
)
```
Common Mistakes
:::danger Don't apply context extension without testing on your target context length
Context extension methods reduce but don't eliminate perplexity degradation. A model extended with YaRN to 128K still performs worse at 128K than it does at 8K. Always evaluate on examples at your actual target context length. Don't assume that "support for 128K" means "optimal performance at 128K."
:::
:::warning Fine-tuning data must include long-context examples
If you fine-tune for context extension using only short documents, you'll apply the scaling correctly but the model won't learn to attend to distant context. Long-context fine-tuning data must include examples where relevant information is genuinely spread across the full context length. Synthetic long-context tasks (e.g., "given these 50 documents, which contains the answer to...") are often necessary to supplement natural long-document data.
:::
:::tip Don't forget FlashAttention-2 for long context inference
Without FlashAttention-2, even a correctly configured long-context model will OOM or be prohibitively slow at 128K+ context. Always use attn_implementation="flash_attention_2" when loading models for long-context inference. Also check that your GPU supports it (Ampere or newer architecture).
:::
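Testing at the target length is easier when perplexity is broken down by position instead of averaged over the whole sequence. The following sketch assumes you already have per-token negative log-likelihoods (in nats) from your own evaluation loop; `perplexity_by_position` and the sample numbers are hypothetical.

```python
import math

def perplexity_by_position(token_nlls: list[float], bucket_size: int) -> list[tuple[int, float]]:
    """Perplexity per position bucket, from per-token negative log-likelihoods."""
    results = []
    for start in range(0, len(token_nlls), bucket_size):
        bucket = token_nlls[start:start + bucket_size]
        # Perplexity is exp of the mean negative log-likelihood in the bucket
        results.append((start, math.exp(sum(bucket) / len(bucket))))
    return results

# Hypothetical per-token losses: flat for the first 8 positions, then degrading
nlls = [2.0] * 8 + [3.0] * 8
for start, ppl in perplexity_by_position(nlls, bucket_size=8):
    print(f"positions {start}-{start + 7}: perplexity {ppl:.1f}")
```

A sequence-level average would hide the jump in the second bucket, which is exactly the kind of late-context degradation the first warning describes.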
Interview Q&A
Q: Why does simply increasing max_position_embeddings fail, and what does position interpolation do differently?
A: Simply increasing max_position_embeddings causes positions beyond the training maximum to produce rotation angles that were never seen during training. The model has no learned behavior for these angles, causing catastrophic perplexity degradation. Position interpolation instead multiplies all position indices by train_len / target_len, compressing the new longer sequence into the trained position range. Position 10,000 in a 32K sequence becomes position 2,500 in a model trained on 8K - a position the model has actually seen. The tradeoff is reduced precision in distinguishing nearby tokens, which fine-tuning on long sequences can largely recover.
Q: What is the YaRN ramp function and why is it better than uniform position interpolation?
A: Uniform position interpolation scales all RoPE frequency bands by the same factor. This hurts high-frequency bands (which distinguish adjacent tokens) because they become "too close together" - the model's local token-distinction ability degrades. YaRN applies a ramp function that assigns different scaling factors to different frequency bands: high-frequency bands (short wavelength) receive little or no scaling (they extrapolate), while low-frequency bands (long wavelength, responsible for long-range position encoding) receive full interpolation scaling. This preserves local token-distinction ability while extending long-range position encoding, giving better results than uniform interpolation with the same fine-tuning budget.
Q: How does LongLoRA enable efficient fine-tuning at long contexts?
A: LongLoRA uses shifted sparse attention (S²-Attn) during training, where each token attends only to a local window rather than all other tokens. This reduces training compute from O(n²) to O(n × window_size), making 32K and even 100K context fine-tuning feasible on a single node. The windows are shifted differently across attention heads, creating overlapping coverage that allows information to propagate across window boundaries. At inference time, standard full attention is restored - S²-Attn is only used during training. LongLoRA combines S²-Attn with LoRA adapters rather than full fine-tuning, further reducing memory requirements.
Q: What data recipe does Meta use for Llama-3.1's extension to 128K?
A: Meta's recipe (from the Llama-3 technical report) combines: (1) Starting from Llama-3-8B which already uses theta=500,000, giving a wider base position range than Llama-2. (2) Applying YaRN-based frequency scaling for the 16× extension from 8K to 128K. (3) Continued pre-training on a mixture of long documents - approximately 5% of the training mixture - including Common Crawl long documents, books, and full code repositories. (4) Instruction fine-tuning with long-context examples at varying lengths. The 5% long-context ratio is calibrated to extend context without degrading short-context performance.
Q: What is NTK-aware scaling and how does it differ from standard position interpolation?
A: Both NTK-aware scaling and position interpolation aim to map target context positions into the trained position range. Position interpolation scales position indices directly by factor train_len / target_len. NTK-aware scaling instead modifies the theta parameter by a formula derived from the Neural Tangent Kernel analysis: theta_new = theta_original × (target_len / train_len)^(dim / (dim-2)). The key difference: NTK scaling automatically applies different effective scale factors to different frequency bands - high-frequency (short-wavelength) dimensions are scaled less than low-frequency dimensions. This makes NTK scaling more robust to no-fine-tuning use (immediately applicable at inference) compared to position interpolation, which hurts high-frequency precision without fine-tuning to compensate.
