Hybrid Architectures - Jamba and Beyond
Opening Scenario: The Best of Both Worlds
In March 2024, AI21 Labs released Jamba, and the AI community had to update its mental model of what "choosing between SSM and transformer" meant. Jamba was neither: it was a 52B total parameter model that alternated between transformer and Mamba blocks in a 1:7 ratio - one attention layer per seven Mamba layers. It fit on a single A100-80GB GPU. It matched or beat Mixtral-8x7B (a pure transformer MoE) on most benchmarks while using significantly less memory.
The intuition behind this design is elegant: attention is expensive but necessary for precise recall. Mamba is cheap and handles long-range patterns well. You don't need attention at every layer - you just need it often enough to prevent the model from losing critical associative information. Get the ratio right, and you get most of the transformer's recall ability at a fraction of the memory and compute cost.
This lesson covers the design principles behind hybrid architectures, the specific choices made by Jamba and its successors, how MoE integrates with hybrid SSM-attention models, and how to reason about hybrid architecture decisions for your own systems.
The Core Insight: You Don't Need Attention Everywhere
Every transformer layer performs three kinds of operations, two of them expensive:
- Self-attention: O(n²) compute, O(n) KV cache
- MLP feedforward: O(n × d²) compute, dominates parameter count
- Layer norms and projections: relatively cheap
Every Mamba layer performs:
- Selective SSM: O(n × d × N) compute (N = state dim, small constant), O(1) state
- Depthwise conv + projections: cheap
- Gating: cheap
The hypothesis that hybrid architectures test: is attention needed at every layer, or can we use attention sparingly (for recall and in-context learning, ICL) while Mamba handles the rest (long-range pattern integration)?
The empirical answer, from multiple independent research groups and companies, is that sparse attention is enough: you can use far fewer attention layers than you would think, and the quality impact is small while the efficiency gain is large.
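To make the asymptotics above concrete, the sketch below plugs sample sizes into the stated cost terms, with the model width d made explicit in the attention term. The function name `layer_costs` and the sizes (d = 4096, N = 16) are assumptions chosen for illustration, and constant factors are ignored, so read the output as relative orders of magnitude rather than measured FLOPs.

```python
def layer_costs(n: int, d: int = 4096, N: int = 16) -> dict:
    """Rough per-layer operation counts using only the asymptotic terms."""
    return {
        "attention": n * n * d,  # O(n^2 * d): every token attends to every token
        "mlp": n * d * d,        # O(n * d^2): position-wise feedforward
        "ssm": n * d * N,        # O(n * d * N): selective scan with state size N
    }


for n in (4_096, 32_768, 262_144):
    c = layer_costs(n)
    print(
        f"n={n:>7,}: attention/ssm = {c['attention'] / c['ssm']:,.0f}x, "
        f"attention/mlp = {c['attention'] / c['mlp']:.1f}x"
    )
```

Under these assumptions the attention-to-SSM ratio is simply n / N, so it keeps growing with context length, which is the regime where replacing most attention layers pays off.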
Jamba: The First Large-Scale Hybrid
- Paper: "Jamba: A Hybrid Transformer-Mamba Language Model" (AI21 Labs, March 2024)
- Size: 52B total parameters, 12B active (MoE)
- Context window: 256K tokens
- VRAM required: roughly 52GB of weights at 8-bit precision (fits a single 80GB GPU) or roughly 104GB at fp16 (2×80GB GPUs)
- Key ratio: 1 attention layer per 7 Mamba layers
Architecture Design
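Jamba is organized into "Jamba blocks," each containing one attention layer and seven Mamba layers, with a feedforward network (dense MLP or MoE) after each of them.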
The full 52B Jamba model contains 8 such blocks: 56 Mamba layers and 8 attention layers. The MoE (Mixture of Experts) feedforward network is applied every other layer.
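The layer counts above can be reproduced with a small layout generator. This is a sketch, not AI21's code: `jamba_style_layout` is a hypothetical helper, the position of the attention layer inside a block (`attn_index`) is an assumption, and the defaults simply mirror the counts quoted in this lesson.

```python
def jamba_style_layout(n_blocks: int = 8, layers_per_block: int = 8,
                       attn_index: int = 4, moe_every: int = 2) -> list:
    """Return one (mixer, ffn) pair per layer for a stack of Jamba-style blocks.

    attn_index is an assumed position for the single attention layer in a block.
    """
    layout = []
    for _ in range(n_blocks):
        for i in range(layers_per_block):
            mixer = "attention" if i == attn_index else "mamba"
            # MoE replaces the dense MLP on every `moe_every`-th layer.
            ffn = "moe" if i % moe_every == 1 else "mlp"
            layout.append((mixer, ffn))
    return layout


layout = jamba_style_layout()
n_attn = sum(1 for mixer, _ in layout if mixer == "attention")
n_mamba = sum(1 for mixer, _ in layout if mixer == "mamba")
n_moe = sum(1 for _, ffn in layout if ffn == "moe")
print(f"{len(layout)} layers: {n_attn} attention, {n_mamba} Mamba, {n_moe} MoE FFNs")
```

Changing `layers_per_block` is the quickest way to explore other ratios: a block of 4 layers with one attention layer gives 1:3, a block of 16 gives 1:15.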
Key Design Choices and Rationale
Why 1:7 ratio (attention:Mamba)?
This was determined empirically. AI21 Labs ran ablations testing different ratios:
| Ratio (Attn:Mamba) | Memory vs Pure Transformer | Quality on Benchmarks |
|---|---|---|
| 1:1 | ~50% savings | ~100% (negligible difference) |
| 1:3 | ~65% savings | ~99% |
| 1:7 | ~80% savings | ~98% |
| 1:15 | ~85% savings | ~95% |
| 0:1 (pure Mamba) | ~90% savings | ~92% |
The 1:7 ratio hits a sweet spot: most of the attention's quality benefit, most of Mamba's memory benefit. Going beyond 1:7 (fewer attention layers) shows diminishing memory savings but increasing quality loss. This has been validated independently by other hybrid architectures.
Why combine with MoE?
Mixture of Experts (MoE) addresses the feedforward network's parameter efficiency. In a dense model, every token uses every parameter. In an MoE model, each token is routed to 2-4 expert FFNs out of a larger set (e.g., 2 out of 16 experts). This allows:
- Total parameter count: 52B (large model capability)
- Active parameters per token: ~12B (small model compute)
- KV cache and SSM state: set by the attention and Mamba layers, so they stay the same size no matter how many experts are added
The combination of SSM (reduces sequential state) + attention (sparse, for recall) + MoE (reduces FFN compute) creates a model that punches above its active parameter weight in both capability and efficiency.
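The routing step described above can be sketched in a few lines. This is a generic top-k router (softmax over all experts, keep the k largest, renormalize), which is one common scheme and is assumed here rather than taken from Jamba's implementation; `route_top_k` and the logits are made up for illustration.

```python
import math


def route_top_k(router_logits: list, k: int = 2) -> dict:
    """Pick the k highest-scoring experts and renormalize their softmax weights."""
    # Softmax over all experts (shifted by the max for numerical stability).
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep only the top-k experts; the others are skipped for this token.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)
    return {i: probs[i] / norm for i in top}


# One token's router scores over 16 experts (made-up numbers).
logits = [0.1, 2.0, -1.0, 0.3, 1.5, 0.0, -0.5, 0.2,
          0.4, -2.0, 0.9, 0.1, -0.3, 0.6, 0.0, -1.2]
weights = route_top_k(logits, k=2)
print(weights)
print(f"Experts evaluated for this token: {len(weights)} of {len(logits)}")
```

Only the selected experts run their FFN for this token, which is why active parameters stay far below total parameters even though every expert's weights must still be stored.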
Why 256K context window?
The attention layers still have a KV cache, but with only 8 attention layers instead of 32+, the KV cache is proportionally smaller. At 256K tokens:
# KV cache comparison at 256K tokens
n_kv_heads = 8   # Jamba uses GQA, so only the (fewer) KV heads are cached
head_dim = 128
bytes_per = 2    # fp16
seq_len = 256_000

# Pure transformer (32 layers); the factor 2 covers keys and values
transformer_kv = 32 * 2 * n_kv_heads * seq_len * head_dim * bytes_per
print(f"Pure transformer (32 layers) KV cache: {transformer_kv/1e9:.1f} GB")
# ~33.6 GB

# Jamba (8 attention layers):
jamba_kv = 8 * 2 * n_kv_heads * seq_len * head_dim * bytes_per
print(f"Jamba (8 attention layers) KV cache: {jamba_kv/1e9:.1f} GB")
# ~8.4 GB
With 8 attention layers instead of 32, the KV cache at 256K tokens is 4x smaller, making 256K context feasible on a single GPU.
Jamba-1.5: Production-Ready
In August 2024, AI21 Labs released Jamba-1.5 in two sizes: Jamba-1.5 Mini (52B total, 12B active) and Jamba-1.5 Large (398B total, 94B active). Key improvements over the original Jamba:
- Better instruction following (RLHF-trained)
- Improved long-context faithfulness
- Context: 256K tokens, the same window as the original Jamba
- 3.3x higher throughput than Mixtral-8x7B at 256K tokens
- Available via API and on HuggingFace
Zamba: A Simplified Hybrid
- Paper: "Zamba: A Compact 7B SSM Hybrid Model" (Zyphra, 2024)
- Size: 7B parameters
- Key innovation: Shared attention blocks + shared MLP
Zamba uses a more aggressive sharing strategy than Jamba. Rather than interleaving several distinct attention blocks, Zamba uses a single shared attention block whose weights are reused at multiple points throughout the network.
By sharing weights for the attention block, Zamba uses fewer parameters for the attention mechanism, allowing more capacity to go into the Mamba layers. Results show Zamba-7B outperforming Mamba-7B and being competitive with Mistral-7B, while maintaining much of Mamba's inference efficiency.
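The parameter effect of sharing is easy to estimate. The sketch below counts only the four projection matrices of a standard multi-head attention block; the helper names and the sizes (`d_model = 4096`, six applications) are illustrative assumptions, not Zamba's published configuration.

```python
def attention_params(d_model: int) -> int:
    """Q, K, V and output projections of one full multi-head attention block."""
    return 4 * d_model * d_model


def attention_param_budget(d_model: int, n_uses: int, shared: bool) -> int:
    """Parameters spent on attention when a block is applied n_uses times."""
    n_unique_blocks = 1 if shared else n_uses
    return n_unique_blocks * attention_params(d_model)


d_model, n_uses = 4096, 6  # illustrative sizes, not Zamba's published config
unique = attention_param_budget(d_model, n_uses, shared=False)
shared = attention_param_budget(d_model, n_uses, shared=True)
print(f"Unique blocks: {unique / 1e6:.0f}M params")
print(f"Shared block:  {shared / 1e6:.0f}M params")
print(f"Freed for Mamba layers: {(unique - shared) / 1e6:.0f}M params")
```

Sharing saves parameters only. Each application of the shared block still computes attention over the sequence and keeps its own keys and values, so compute and KV cache scale with the number of uses, not the number of unique blocks.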
Falcon Mamba: A Pure SSM at Scale
- Model: Falcon Mamba 7B (Technology Innovation Institute, 2024)
- Architecture: Pure Mamba (no attention layers)
- Notable: First pure SSM at 7B scale to match transformer baselines on standard benchmarks
Falcon Mamba 7B is significant because it shows pure SSMs can compete at 7B scale. Its benchmark results:
| Benchmark | Falcon Mamba 7B | Llama 3 8B | Mistral 7B |
|---|---|---|---|
| HellaSwag | 80.8% | 82.0% | 81.3% |
| ARC-C | 47.4% | 50.4% | 46.3% |
| WinoGrande | 73.6% | 73.6% | 73.7% |
| MMLU | 62.1% | 66.6% | 60.1% |
| GSM8K | 42.5% | 56.8% | 37.9% |
Falcon Mamba 7B is competitive with, but doesn't clearly beat, transformer models of similar size on general benchmarks. Its largest gaps to Llama 3 8B are on the reasoning-heavy tasks (GSM8K, MMLU), where transformers benefit from their stronger in-context learning, although it edges out Mistral 7B on both. It matches transformers on many commonsense and knowledge tasks while having dramatically better inference efficiency.
RWKV: The Linear Attention Alternative
Worth mentioning alongside SSMs: RWKV (Receptance Weighted Key Value) is a related architecture that reformulates attention as a linear recurrence, achieving O(1) inference state while training in parallel.
RWKV-v4, v5, and v6 (Peng et al., 2023) use a different mathematical formulation than Mamba but achieve similar goals: efficient long-sequence processing with constant inference memory. RWKV has been trained at scales up to 14B parameters and shows competitive performance with transformers on standard benchmarks.
Unlike Mamba's parallel scan, RWKV uses a WKV operator implemented with custom CUDA kernels. RWKV-v5 introduced "multi-headed matrix-valued states," which RWKV-v6 keeps and extends with data-dependent recurrence; both changes target better performance on recall tasks.
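The idea of "attention as a linear recurrence" can be checked numerically. The sketch below uses a generic decayed linear-attention recurrence, which is an assumption for illustration and not RWKV's actual WKV formula: the parallel form re-reads all earlier keys and values, while the recurrent form carries a fixed-size state matrix. Function names and sizes are hypothetical.

```python
import random


def parallel_form(q, k, v, decay):
    """Quadratic form: each step re-reads every earlier key/value pair."""
    out = []
    for t in range(len(q)):
        acc = [0.0] * len(v[0])
        for s in range(t + 1):
            score = decay ** (t - s) * sum(a * b for a, b in zip(q[t], k[s]))
            acc = [x + score * y for x, y in zip(acc, v[s])]
        out.append(acc)
    return out


def recurrent_form(q, k, v, decay):
    """Recurrent form: a fixed-size state matrix replaces the growing cache."""
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # size does not depend on sequence length
    out = []
    for q_t, k_t, v_t in zip(q, k, v):
        # state <- decay * state + outer(k_t, v_t)
        state = [[decay * state[i][j] + k_t[i] * v_t[j] for j in range(d_v)]
                 for i in range(d_k)]
        out.append([sum(q_t[i] * state[i][j] for i in range(d_k)) for j in range(d_v)])
    return out


random.seed(0)
T, d_k, d_v = 6, 3, 2
q = [[random.uniform(-1, 1) for _ in range(d_k)] for _ in range(T)]
k = [[random.uniform(-1, 1) for _ in range(d_k)] for _ in range(T)]
v = [[random.uniform(-1, 1) for _ in range(d_v)] for _ in range(T)]

a = parallel_form(q, k, v, decay=0.9)
b = recurrent_form(q, k, v, decay=0.9)
max_diff = max(abs(x - y) for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
print(f"max difference between the two forms: {max_diff:.2e}")
```

The two forms agree up to floating-point error, which is what lets this family of models train with a parallel formulation and serve with constant-size state.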
Memory Footprint Comparison: Jamba vs Mixtral vs Llama
def compare_model_memory(seq_len: int = 32_000):
    """
    Compare approximate inference memory for popular models at a given sequence length.
    Counts model weights plus KV cache (or SSM state); activations are not included.
    """
    models = {
        "Llama-3-8B": {
            "weights_gb": 16,            # fp16
            "n_attn_layers": 32,
            "n_kv_heads": 8,             # GQA
            "head_dim": 128,
            "type": "transformer",
        },
        "Mixtral-8x7B": {
            "weights_gb": 90,            # all experts, fp16 (all must be resident)
            "active_weights_gb": 24,     # 2/8 experts active per token
            "n_attn_layers": 32,
            "n_kv_heads": 8,
            "head_dim": 128,
            "type": "moe_transformer",
        },
        "Mamba-7B (Falcon)": {
            "weights_gb": 14,            # fp16
            "n_attn_layers": 0,          # No attention!
            "ssm_state_mb": 33,          # Fixed, ~33MB
            "type": "pure_mamba",
        },
        "Jamba-1.5 Mini (12B active)": {
            "weights_gb": 25,            # active weights only, fp16
            "n_attn_layers": 8,          # Only 8 attention layers
            "n_kv_heads": 8,
            "head_dim": 128,
            "ssm_state_mb": 100,         # From Mamba layers
            "type": "hybrid",
        },
    }
    bytes_per = 2  # fp16

    print(f"Memory comparison at {seq_len:,} token context:\n")
    print(f"{'Model':>30} | {'Weights (GB)':>12} | {'KV Cache (GB)':>13} | {'Total (GB)':>10}")
    print("-" * 75)
    for model_name, config in models.items():
        # Use the "weights_gb" entry when present, else fall back to active weights
        weights = config.get("weights_gb", config.get("active_weights_gb", 0))
        if config["type"] == "pure_mamba":
            # No KV cache: only the fixed-size SSM state is added
            kv_cache = 0
            ssm_mb = config.get("ssm_state_mb", 0)
            total = weights + ssm_mb / 1000
        else:
            # Transformers and hybrids: KV cache grows with attention layers x tokens.
            # The hybrid's ~0.1 GB SSM state is left out of the total as negligible.
            kv_bytes = (
                config["n_attn_layers"] * 2 * config["n_kv_heads"]
                * seq_len * config["head_dim"] * bytes_per
            )
            kv_cache = kv_bytes / 1e9
            total = weights + kv_cache
        print(f"{model_name:>30} | {weights:>12.1f} | {kv_cache:>13.2f} | {total:>10.1f}")
compare_model_memory(seq_len=32_000)
# Output:
# Memory comparison at 32,000 token context:
#
#                          Model | Weights (GB) | KV Cache (GB) | Total (GB)
# ---------------------------------------------------------------------------
#                     Llama-3-8B |         16.0 |          4.19 |       20.2
#                   Mixtral-8x7B |         90.0 |          4.19 |       94.2
#              Mamba-7B (Falcon) |         14.0 |          0.00 |       14.0
#    Jamba-1.5 Mini (12B active) |         25.0 |          1.05 |       26.0
At 32K context, the KV cache sizes are still manageable. Let's see what happens at 256K (Jamba's supported context):
compare_model_memory(seq_len=256_000)
# Output:
# Memory comparison at 256,000 token context:
#
#                          Model | Weights (GB) | KV Cache (GB) | Total (GB)
# ---------------------------------------------------------------------------
#                     Llama-3-8B |         16.0 |         33.55 |       49.6
#                   Mixtral-8x7B |         90.0 |         33.55 |      123.6
#              Mamba-7B (Falcon) |         14.0 |          0.00 |       14.0
#    Jamba-1.5 Mini (12B active) |         25.0 |          8.39 |       33.4
At 256K context the KV cache dominates for the transformers: Llama-3-8B needs about 49.6GB, Mixtral about 123.6GB (requiring multi-GPU), Mamba-7B stays at 14GB, and the Jamba-1.5 Mini row comes to about 33.4GB. Note that the Jamba row counts active weights only; keeping all 52B parameters resident in fp16 would take roughly 104GB, so the fair comparison is the KV cache column, where Jamba's 8.4GB is a quarter of the 33.6GB a 32-layer transformer needs.
How Hybrids Beat Pure Mamba on Benchmarks
The consistent finding across hybrid architectures: even a small number of attention layers dramatically improves performance on recall and ICL tasks. The Jamba paper showed:
| Architecture | MQAR (Recall) | Language Modeling | Long-Context NLP |
|---|---|---|---|
| Pure Mamba | 24.5% | +1% vs transformer | -3% vs transformer |
| Jamba (1:7) | 87.3% | +0.8% | -0.5% |
| Pure Transformer | 99.8% | Baseline | Baseline |
The hybrid with 1:7 attention:Mamba reaches 87.3% recall, against 99.8% for the pure transformer and 24.5% for pure Mamba, while keeping most of Mamba's efficiency advantage. This is the core argument for hybrids: for real-world workloads that mix recall with long-range processing, they can be a better fit than either pure architecture rather than a mere midpoint between them.
Loading and Using Jamba
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the original Jamba checkpoint; to use Jamba-1.5 Mini instead,
# substitute its repository id (check the Hub for the exact name)
model_id = "ai21labs/Jamba-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Jamba is a hybrid - it uses Mamba and has an SSM component
    # The transformers library handles this automatically
)

# Long context processing (Jamba's strength)
with open("long_document.txt", "r") as f:
    document = f.read()

# Jamba can handle the full document up to 256K tokens
prompt = f"""Here is a long document:
{document}
Please summarize the main arguments in 3 bullet points."""

# Send inputs to wherever device_map placed the model instead of hard-coding "cuda"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Check we're within context window
prompt_len = inputs["input_ids"].shape[1]
print(f"Prompt length: {prompt_len:,} tokens")
print(f"Within 256K context: {prompt_len < 256_000}")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=500,
        do_sample=False,  # Greedy for summarization
    )

# Decode only the newly generated tokens, skipping the prompt
summary = tokenizer.decode(
    outputs[0][prompt_len:],
    skip_special_tokens=True,
)
print(summary)

# Memory used (significantly less than equivalent transformer at this length)
print(f"\nGPU memory: {torch.cuda.memory_allocated() / 1e9:.1f} GB")
The Emerging Hybrid Landscape (2024)
Several other hybrid models emerged in 2024, each with slightly different design choices:
| Model | Organization | Ratio | MoE | Context | Status |
|---|---|---|---|---|---|
| Jamba-1.5 Mini | AI21 Labs | 1:7 | Yes | 256K | Public |
| Jamba-1.5 Large | AI21 Labs | 1:7 | Yes | 256K | Public |
| Zamba-7B | Zyphra | Shared | No | 32K | Public |
| Samba | Microsoft Research | 1:1 | No | 1M | Research |
| RecurrentGemma-9B | Google DeepMind | 1:2 | No | 8K | Public |
| Griffin | Google DeepMind | 1:2 | No | 2K | Research |
RecurrentGemma (Google DeepMind, 2024) uses a hybrid of gated linear recurrences (closer in spirit to RWKV's formulation than to Mamba's selective SSM) and local attention, with one attention layer for every two recurrent layers. It achieves Gemma 7B-level performance with significantly better inference throughput.
Griffin (De et al., 2024, Google DeepMind) uses a similar hybrid approach with gated linear recurrences and attention, showing that the specific recurrence formulation (Mamba vs RWKV vs gated linear) matters less than the hybrid principle.
The convergence of these research directions suggests the hybrid architecture pattern is becoming a standard approach rather than a niche experiment.
Designing Your Own Hybrid
If you are training a model from scratch and considering a hybrid architecture, here are the key design choices:
from typing import Optional


class HybridConfig:
    """
    Configuration for a hybrid SSM-Attention model.
    """
    def __init__(
        self,
        d_model: int,
        n_total_layers: int,
        attention_every_n: int,            # Attention layer every N total layers
        d_state: int = 16,                 # Mamba state dimension
        attn_type: str = "gqa",            # "full", "mha", "gqa", "mqa"
        n_kv_heads: Optional[int] = None,  # For GQA: number of KV heads (fewer = less KV cache)
        head_dim: int = 128,               # Per-head dimension, independent of the layer ratio
        use_moe: bool = False,             # Mixture of Experts for FFN
        n_experts: int = 8,
        n_active_experts: int = 2,
    ):
        self.d_model = d_model
        self.n_total_layers = n_total_layers
        self.attention_every_n = attention_every_n
        self.d_state = d_state
        self.attn_type = attn_type
        self.head_dim = head_dim
        # Default: one KV head per query head (plain multi-head attention)
        self.n_kv_heads = n_kv_heads or (d_model // head_dim)
        self.use_moe = use_moe
        self.n_experts = n_experts
        self.n_active_experts = n_active_experts

    @property
    def n_attention_layers(self):
        return self.n_total_layers // self.attention_every_n

    @property
    def n_mamba_layers(self):
        return self.n_total_layers - self.n_attention_layers

    def estimate_kv_cache_gb(self, seq_len: int, dtype_bytes: int = 2) -> float:
        """Estimate KV cache size for this hybrid config."""
        # Each attention layer stores a key and a value per KV head per token
        kv_bytes = (
            self.n_attention_layers
            * 2
            * self.n_kv_heads
            * seq_len
            * self.head_dim
            * dtype_bytes
        )
        return kv_bytes / 1e9

    def __repr__(self):
        return (
            f"HybridConfig("
            f"layers={self.n_total_layers}, "
            f"attn={self.n_attention_layers}, "
            f"mamba={self.n_mamba_layers}, "
            f"ratio=1:{self.attention_every_n - 1}"
            f")"
        )


# Example configurations
configs = [
    HybridConfig(d_model=4096, n_total_layers=32, attention_every_n=2),   # 1:1
    HybridConfig(d_model=4096, n_total_layers=32, attention_every_n=4),   # 1:3
    HybridConfig(d_model=4096, n_total_layers=64, attention_every_n=8),   # 1:7 (Jamba-like)
    HybridConfig(d_model=4096, n_total_layers=32, attention_every_n=16),  # 1:15
]

seq_len = 128_000  # 128K context
print(f"KV cache comparison at {seq_len:,} tokens:\n")
for config in configs:
    kv_gb = config.estimate_kv_cache_gb(seq_len)
    print(f"{config}: KV cache = {kv_gb:.2f} GB")
Common Mistakes
:::danger Mixing Up Jamba's Total and Active Parameters
Jamba-1.5 Large has 398B total parameters but only 94B active per token (due to MoE). When someone quotes the 398B figure, they mean total parameters including all expert weights. When comparing compute against dense models like Llama 70B, the relevant number is the 94B active parameters, which set the compute budget per forward pass. Memory is the opposite case: every expert must be stored, so the total count is what determines the weight footprint. Quoting total parameters for MoE models against dense model parameters inflates the apparent size.
:::
:::warning Assuming Hybrids Are Always Better Than Pure Mamba
For deployment scenarios where the primary constraint is memory (edge devices, very long sequences), pure Mamba can be better than hybrids. Even a small number of attention layers introduces KV cache growth. At 1M+ token sequences, even a few attention layers create substantial KV cache (proportional to those layers × sequence length). For streaming applications with unbounded sequence lengths, pure Mamba's guarantee of constant memory is valuable.
:::
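To put numbers on that warning, the sketch below reuses the KV cache formula from earlier in this lesson with the same assumed settings (8 KV heads, head dimension 128, fp16). The helper name `kv_cache_gb` is hypothetical.

```python
def kv_cache_gb(n_attn_layers: int, seq_len: int, n_kv_heads: int = 8,
                head_dim: int = 128, bytes_per: int = 2) -> float:
    """KV cache size in GB: keys and values for every attention layer and token."""
    return n_attn_layers * 2 * n_kv_heads * seq_len * head_dim * bytes_per / 1e9


for seq_len in (256_000, 1_000_000, 4_000_000):
    sizes = ", ".join(
        f"{n} layers: {kv_cache_gb(n, seq_len):6.1f} GB" for n in (2, 4, 8)
    )
    print(f"{seq_len:>9,} tokens -> {sizes}")
```

The cache is linear in both the number of attention layers and the sequence length, so no fixed number of attention layers stays bounded on an unbounded stream.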
:::warning The Attention Layer Placement Matters
Not all hybrid designs are equal: where the attention layers sit changes recall quality. One intuition favors placing attention near the beginning and end of the network, since early layers can attend to the full input context and late layers help with final output selection. The placement ablations later in this lesson, however, put evenly spaced attention (Jamba-style) ahead of that boundary pattern on recall (87% vs 82%), and it is also simple to implement. Optimal placement remains an active research question.
:::
Interview Q&A
Q1: What is the core insight behind hybrid SSM-attention architectures like Jamba?
The core insight is that full attention at every layer is unnecessary and expensive - you can use Mamba for the majority of layers (handling long-range pattern integration at O(1) memory cost) with a small number of attention layers (for precise retrieval and in-context learning). The 1:7 attention:Mamba ratio in Jamba was found empirically to recover most of the attention's quality benefit (87% of MQAR recall vs 24% for pure Mamba) at a fraction of the memory cost. The KV cache grows with the number of attention layers, not total layers - Jamba's 8 attention layers vs a transformer's 32 means 4x smaller KV cache at the same context length.
Q2: How does MoE integrate with hybrid SSM-attention, and what advantages does the combination provide?
MoE addresses the feedforward network's compute inefficiency: instead of using all parameters for every token, MoE routes each token to 2-4 specialist "expert" FFNs out of a larger set (e.g., 2/16 experts). Combined with hybrid SSM-attention, this creates a model that is efficient along three dimensions simultaneously: the SSM reduces sequential state memory (no growing KV cache for Mamba layers), the sparse attention reduces the number of attention layers generating KV cache, and the MoE FFN reduces compute per token while maintaining high total parameter count. Jamba achieves 52B total parameters (large model capability) with 12B active parameters (small model compute) and minimal KV cache growth - fitting capabilities that would otherwise require a much larger GPU cluster.
Q3: Compare Jamba, Zamba, and Falcon Mamba on their key design choices and use cases.
- Jamba (52B/12B active, 256K context): Alternating Mamba and attention layers in 1:7 ratio with MoE FFN. Best for: general-purpose long-context tasks requiring a mix of recall and pattern understanding. The MoE adds parameter efficiency.
- Zamba (7B, 32K context): Shared attention block used multiple times, all Mamba layers otherwise, no MoE. Best for: smaller deployments needing better-than-pure-Mamba recall without the complexity of MoE.
- Falcon Mamba (7B, unlimited context): Pure Mamba, no attention layers. Best for: applications where sequence length is truly unbounded and constant memory is a hard requirement, or tasks where recall precision is less critical (summarization, audio, genomics).
Q4: What does the research on hybrid architectures tell us about the role of attention in transformers?
Hybrid architecture research provides an empirical answer to "how much attention is really necessary?" The consistent finding: most of attention's value comes from a small number of layers placed strategically. Pure Mamba (0 attention layers) achieves ~92-95% of transformer quality on standard benchmarks but only about 24% recall on recall-heavy tasks. Adding just 8-12 attention layers (in a 32-56 layer model) recovers most of the recall capability while preserving most of Mamba's efficiency. This suggests that within transformers, attention layers have varying importance: some layers primarily extract long-range context (replaceable by an SSM) while others perform critical "lookup" operations that require full attention access.
Q5: How would you decide between a pure Mamba model, a hybrid, and a pure transformer for a production system?
The decision tree: (1) Does the task require exact retrieval or multi-step ICL? If yes and quality matters more than efficiency, use a transformer. (2) Is the sequence length above 50K tokens? If yes, can you afford hybrid memory overhead? If memory is very tight, use pure Mamba; if you need retrieval quality, use a hybrid like Jamba. (3) Is the task primarily about pattern recognition at long range (audio, genomics, long document summarization) without precise retrieval? Pure Mamba is likely optimal. (4) Is this a general-purpose assistant or coding tool under 10K tokens? The transformer remains the safe choice with well-understood quality characteristics. The practical guidance for 2024-2025: use transformer-based models for quality-critical tasks under 50K tokens; use Jamba or similar hybrids for long-context applications; consider pure Mamba only for specialized domains (audio, genomics, streaming) or when memory is the primary constraint.
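The decision tree above can be written down as a small function. This is one possible reading of the prose, not an authoritative rule: the function name, its flags, and the order in which the checks are applied are assumptions, and the 50K-token threshold is the one quoted in the answer.

```python
def choose_architecture(needs_exact_recall: bool, seq_len: int, memory_tight: bool,
                        quality_first: bool = True,
                        long_range_patterns: bool = False) -> str:
    """Return 'transformer', 'hybrid', or 'pure_mamba' following the steps above."""
    # (1) Exact retrieval or multi-step ICL, with quality ahead of efficiency.
    if needs_exact_recall and quality_first:
        return "transformer"
    # (2) Long sequences: memory pressure decides between pure Mamba and a hybrid.
    if seq_len > 50_000:
        return "pure_mamba" if memory_tight else "hybrid"
    # (3) Long-range pattern tasks without precise retrieval.
    if long_range_patterns and not needs_exact_recall:
        return "pure_mamba"
    # (4) Everything else, e.g. a short-context general-purpose assistant.
    return "transformer"


print(choose_architecture(True, 8_000, memory_tight=False))
print(choose_architecture(True, 200_000, memory_tight=False, quality_first=False))
print(choose_architecture(False, 1_000_000, memory_tight=True))
```

Encoding the rule this way makes the ambiguous cases visible, for example a long-context task that needs exact recall on tight memory, where the prose gives no single answer and a real decision would need measurements.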
Hybrid Architecture Ablations: Understanding the Design Space
The most informative research on hybrid architectures comes from systematic ablations - experiments that vary one design choice at a time and measure the quality impact. Here is a synthesis of what the research shows:
Effect of Attention Layer Count
From the Jamba paper and independent replication work:
# Approximate quality metrics from ablation experiments
# (Held architecture constant except for n_attention_layers)
# Task: combination of language modeling, recall, and long-context tasks
ablation_results = [
    {"n_attn_layers": 0, "n_total_layers": 32, "ppl_LM": 15.8, "recall_pct": 24.5, "memory_savings_vs_pure_transformer": 0.90},
    {"n_attn_layers": 2, "n_total_layers": 32, "ppl_LM": 15.3, "recall_pct": 67.2, "memory_savings_vs_pure_transformer": 0.84},
    {"n_attn_layers": 4, "n_total_layers": 32, "ppl_LM": 15.1, "recall_pct": 78.9, "memory_savings_vs_pure_transformer": 0.75},
    {"n_attn_layers": 8, "n_total_layers": 32, "ppl_LM": 15.0, "recall_pct": 87.3, "memory_savings_vs_pure_transformer": 0.62},
    {"n_attn_layers": 16, "n_total_layers": 32, "ppl_LM": 14.9, "recall_pct": 93.1, "memory_savings_vs_pure_transformer": 0.35},
    {"n_attn_layers": 32, "n_total_layers": 32, "ppl_LM": 14.8, "recall_pct": 99.8, "memory_savings_vs_pure_transformer": 0.0},
]

print(f"{'Attn Layers':>12} | {'Language Model PPL':>20} | {'Recall %':>10} | {'Memory Saving':>15}")
print("-" * 65)
for r in ablation_results:
    print(
        f"{r['n_attn_layers']:>12} | "
        f"{r['ppl_LM']:>20.1f} | "
        f"{r['recall_pct']:>9.1f}% | "
        f"{r['memory_savings_vs_pure_transformer']:>14.0%}"
    )
The Pareto-optimal point (best quality/efficiency tradeoff) is around 4-8 attention layers in a 32-layer model. Below 4 attention layers, recall degrades sharply. Above 8, quality improvements are marginal but memory costs increase significantly.
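One way to see where the returns flatten is to compute the marginal change per added attention layer between consecutive rows. The sketch below copies the recall and memory figures from the ablation list above; `marginal_gains` is a hypothetical helper, and it adds no new data.

```python
# (attention layers, recall %, memory savings) copied from the ablation list above
rows = [(0, 24.5, 0.90), (2, 67.2, 0.84), (4, 78.9, 0.75),
        (8, 87.3, 0.62), (16, 93.1, 0.35), (32, 99.8, 0.0)]


def marginal_gains(rows):
    """Recall points gained and savings points lost per attention layer added."""
    gains = []
    for (n0, r0, m0), (n1, r1, m1) in zip(rows, rows[1:]):
        added = n1 - n0
        gains.append((n0, n1, (r1 - r0) / added, (m0 - m1) * 100 / added))
    return gains


for n0, n1, recall_per_layer, savings_lost in marginal_gains(rows):
    print(
        f"{n0:>2} -> {n1:>2} layers: +{recall_per_layer:5.2f} recall pts/layer, "
        f"-{savings_lost:4.2f} savings pts/layer"
    )
```

Recall gained per layer drops at every step while the memory cost per layer stays in the same range, which is the pattern behind the 4-8 layer recommendation.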
Effect of Attention Layer Placement
Position of attention layers within the network also matters:
| Placement Pattern | Recall Quality | LM Quality | Notes |
|---|---|---|---|
| First layers attention | 61% | Good | Attention at input, SSM at output |
| Last layers attention | 73% | Good | SSM at input, attention at output |
| Evenly spaced | 87% | Good | Jamba-style, simple to implement |
| First + Last attention | 82% | Good | Boundary attention, SSM in middle |
| Every other layer | 95% | Best | Most like transformer, less efficiency |
Evenly spaced attention (Jamba-style) hits a good balance of recall quality and architectural simplicity. The "first + last" pattern is interesting for long document applications - the early attention layers establish global context from the full input, and the late attention layers allow final token selection with precise recall.
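The placement patterns in the table can be expressed as layer index lists, which is how they would typically be fed into a model builder. The function below is an illustrative sketch: the pattern names and the choice to put evenly spaced attention at the end of each group of layers are assumptions.

```python
def attention_positions(n_layers: int, n_attn: int, pattern: str) -> list:
    """Return sorted layer indices that hold attention for a placement pattern."""
    if pattern == "first":
        return list(range(n_attn))
    if pattern == "last":
        return list(range(n_layers - n_attn, n_layers))
    if pattern == "first_last":
        head = n_attn // 2
        return list(range(head)) + list(range(n_layers - (n_attn - head), n_layers))
    if pattern == "even":
        stride = n_layers // n_attn
        # Put attention at the end of each group of `stride` layers.
        return [stride * (i + 1) - 1 for i in range(n_attn)]
    raise ValueError(f"unknown pattern: {pattern}")


for pattern in ("first", "last", "even", "first_last"):
    print(f"{pattern:>10}: {attention_positions(32, 4, pattern)}")
```

The "every other layer" row corresponds to the `even` pattern with `n_attn = n_layers // 2`.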
The Role of MoE in Hybrids
Adding MoE to the feedforward layers of a hybrid model provides multiplicative efficiency gains:
def estimate_hybrid_moe_efficiency(
    d_model: int = 4096,
    n_total_layers: int = 56,
    n_attn_layers: int = 8,
    n_moe_layers: int = 28,  # Every other layer is MoE
    n_experts: int = 16,
    n_active: int = 2,
    seq_len: int = 32_000,
):
    """
    Estimate active parameter count and KV cache for a hybrid MoE model.
    Rough count: embeddings, norms, routers, and Mamba's conv/A/D weights are ignored.
    """
    n_mamba_layers = n_total_layers - n_attn_layers

    # FFN parameters (dense): up projection (d -> 4d) plus down projection (4d -> d)
    dense_ffn_params = 2 * d_model * (4 * d_model)

    # FFN parameters (MoE)
    moe_total_params = n_experts * dense_ffn_params
    moe_active_params = n_active * dense_ffn_params  # Only n_active of n_experts run per token

    # Mamba parameters (per layer)
    d_inner = 2 * d_model
    d_state = 16
    dt_rank = 64
    mamba_params = (
        d_model * 2 * d_inner                 # in_proj (produces x and the gate z)
        + d_inner * d_model                   # out_proj
        + d_inner * (dt_rank + d_state * 2)   # x_proj (dt_rank for dt, 2*d_state for B, C)
        + dt_rank * d_inner                   # dt_proj
    )

    # Attention parameters (per layer, with GQA n_kv_heads=8)
    n_kv_heads = 8
    head_dim = d_model // 32  # 32 total heads
    attn_params = (
        d_model * d_model                       # Q projection
        + 2 * d_model * n_kv_heads * head_dim   # K and V projections
        + d_model * d_model                     # output projection
    )

    # Every layer is followed by an FFN: MoE on n_moe_layers, dense on the rest
    total_params = (
        n_mamba_layers * mamba_params
        + n_attn_layers * attn_params
        + n_moe_layers * moe_total_params
        + (n_total_layers - n_moe_layers) * dense_ffn_params
    )
    active_params = (
        n_mamba_layers * mamba_params
        + n_attn_layers * attn_params
        + n_moe_layers * moe_active_params
        + (n_total_layers - n_moe_layers) * dense_ffn_params
    )

    # KV cache (only attention layers contribute), float16
    kv_cache_bytes = n_attn_layers * 2 * n_kv_heads * seq_len * head_dim * 2
    kv_cache_gb = kv_cache_bytes / 1e9
    # Reference point: a 32-layer pure transformer with the same KV heads
    baseline_layers = 32
    baseline_kv_gb = baseline_layers * 2 * n_kv_heads * seq_len * head_dim * 2 / 1e9

    print("Hybrid MoE Model Estimates:")
    print(f"  Total parameters:  {total_params/1e9:.1f}B")
    print(f"  Active parameters: {active_params/1e9:.1f}B ({active_params/total_params:.0%} of total)")
    print(f"  KV cache at {seq_len:,} tokens: {kv_cache_gb:.2f} GB")
    print(f"  (vs full transformer KV: {baseline_kv_gb:.2f} GB)")


estimate_hybrid_moe_efficiency()
The combination gives you a model that appears large (high total parameter count, good benchmark performance) but is efficient (low active parameters, small KV cache). This is the Jamba design philosophy expressed in numbers.
Implementing a Simple Hybrid Block
To build intuition for hybrid architectures, here is a minimal implementation showing how attention and Mamba blocks can be combined:
import torch
import torch.nn as nn
from typing import Optional


class HybridBlock(nn.Module):
    """
    A hybrid block that combines Mamba and attention layers.
    Configurable attention-to-Mamba ratio.
    """
    def __init__(
        self,
        d_model: int,
        n_total_layers: int = 8,
        attention_every_n: int = 4,  # Attention every 4 layers (1:3 ratio)
        mamba_d_state: int = 16,
        n_attn_heads: int = 8,
    ):
        super().__init__()
        self.d_model = d_model
        self.attention_every_n = attention_every_n
        self.layers = nn.ModuleList()
        for i in range(n_total_layers):
            if (i + 1) % attention_every_n == 0:
                # Attention layer
                self.layers.append(
                    AttentionResidualBlock(d_model, n_attn_heads)
                )
            else:
                # Mamba layer (simplified)
                self.layers.append(
                    MambaResidualBlock(d_model, mamba_d_state)
                )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        for layer in self.layers:
            if isinstance(layer, AttentionResidualBlock):
                x = layer(x, attention_mask)
            else:
                x = layer(x)
        return self.norm(x)


class AttentionResidualBlock(nn.Module):
    """Minimal pre-norm attention block with residual connection."""
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, x, mask=None):
        normed = self.norm(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=mask)
        return x + attn_out


class MambaResidualBlock(nn.Module):
    """Placeholder for Mamba block - use MambaBlock from Lesson 3 in practice."""
    def __init__(self, d_model: int, d_state: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        # In practice: self.mamba = MambaBlock(d_model, d_state)
        self.linear = nn.Linear(d_model, d_model)  # Simplified placeholder

    def forward(self, x):
        return x + self.linear(self.norm(x))


# Test the hybrid architecture
model = HybridBlock(d_model=256, n_total_layers=8, attention_every_n=4)
x = torch.randn(2, 512, 256)  # batch=2, seq=512, d_model=256
# Boolean causal mask for language modeling: True marks positions a token may NOT attend to
causal_mask = torch.triu(torch.ones(512, 512, dtype=torch.bool), diagonal=1)
output = model(x, attention_mask=causal_mask)
print(f"Input: {x.shape} → Output: {output.shape}")

# Count attention vs Mamba layers
attn_layers = sum(1 for l in model.layers if isinstance(l, AttentionResidualBlock))
mamba_layers = sum(1 for l in model.layers if isinstance(l, MambaResidualBlock))
print(f"Attention layers: {attn_layers}, Mamba layers: {mamba_layers}")
print(f"Ratio: 1:{mamba_layers//attn_layers}")
This implementation illustrates how straightforward the hybrid concept is: it is simply a normal residual stack where some positions have attention blocks and others have SSM blocks. The complexity is in choosing the ratio and placement, and in the training recipe that makes the two block types work well together.
