
Mamba - Selective State Space Models

Opening Scenario: The Paper That Arrived in December 2023

On December 1, 2023, Albert Gu and Tri Dao posted "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" to arXiv, about a week before the NeurIPS conference. Researchers were expecting incremental progress on SSMs - better initialization, faster kernels, minor quality improvements.

What they got was a step change. Mamba-3B outperformed transformers of the same size and matched transformers twice its size. In the paper's inference benchmark (2,048-token prompts), Mamba reached up to 5× the generation throughput of a similarly sized transformer. The results held across language, audio, and genomics benchmarks. And the architecture was elegant: a single design decision - making the SSM parameters depend on the input - unlocked a qualitative jump in capability.

The key insight of the paper: S4 and its predecessors used input-independent SSM parameters. The state transition matrix A, the input projection B, and the output projection C were learned during training and fixed during inference. Every input was processed with the same "compression strategy," regardless of its content. Mamba changed this by making B, C, and the step size $\Delta$ functions of the current input - effectively letting the model decide, at each position, how much to let the new input update the hidden state.

This is selective state space modeling. The model can selectively remember information that seems important (by setting the step size large) and selectively ignore irrelevant information (by setting the step size small). The rest of this lesson explains exactly how this works, how it is implemented efficiently on GPUs, and what it means architecturally.

Why Selectivity Changes Everything

To understand why input-dependent parameters matter so much, consider a simple task: selective copying.

Given the sequence: A B C D E F G H ... [COPY] A B C D

The model is given a long sequence of tokens, then a special COPY token, and must reproduce a specific earlier subsequence. A fixed-parameter SSM must decide during training what to keep in its hidden state. Since the same state transition applies to every input token regardless of content, the model cannot make content-aware decisions about what to compress. It must use a generic compression strategy that works "on average."

A transformer handles this easily: attention lets the output tokens explicitly look back at the tokens they need to copy. An input-independent SSM struggles.

Mamba's solution: make $\Delta$, $B$, and $C$ depend on the current input $u_t$. Now:

  • When the input is a token worth remembering (e.g., after COPY, the tokens to reproduce), $\Delta$ can be large, causing the current input to strongly update the state
  • When the input is background noise, $\Delta$ can be small, causing the state to barely change

The model can learn, from data, when to pay attention and when to forget. This is analogous to what attention does - but implemented inside the recurrent state update rather than through explicit key-query-value computation.
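A toy scalar version makes this concrete. The sketch below is illustrative only: it assumes a single state with $A = -1$, $B = 1$, exact zero-order-hold discretization (which for $A = -1$ turns the update into $h_t = e^{-\Delta_t} h_{t-1} + (1 - e^{-\Delta_t}) u_t$), and a hand-written rule for $\Delta_t$ standing in for the learned projection. Under those assumptions the state behaves like a moving average whose rate is chosen per token:

```python
import numpy as np


def selective_state(u, delta, A=-1.0):
    """Toy scalar selective SSM with exact ZOH discretization and B = 1.

    h_t = exp(delta_t * A) * h_{t-1} + (exp(delta_t * A) - 1) / A * u_t
    """
    h, states = 0.0, []
    for u_t, d_t in zip(u, delta):
        a_bar = np.exp(d_t * A)               # how much of the old state survives
        h = a_bar * h + (a_bar - 1.0) / A * u_t
        states.append(h)
    return np.array(states)


u = np.array([5.0, 0.3, -0.2, 0.1, 7.0, 0.4])
# Hypothetical selection rule standing in for the learned delta projection:
# a large step for "important" tokens (|u| > 1), a tiny step for everything else.
delta = np.where(np.abs(u) > 1.0, 5.0, 0.01)
print(np.round(selective_state(u, delta), 2))  # [4.97 4.92 4.87 4.82 6.99 6.92]
```

The state holds on to 5 through three small-$\Delta$ noise tokens (drifting only slightly) and is overwritten when the next large-$\Delta$ token, 7, arrives.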

The Mamba Block: Architecture in Detail

A Mamba block takes an input of shape [batch, seq_len, d_model] and produces output of the same shape. Let's trace the data flow through each component:

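1. Linear Expansion (x2): The input is projected to 2 × d_inner features, where d_inner = expand × d_model (expand = 2 by default), so each of the two branches below is twice the model width. This gives the SSM more internal bandwidth to work with.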

2. Split: The expanded tensor is split into two equal halves: the x branch that goes through the SSM, and the z branch that serves as a gating signal.

3. Depthwise Conv1d: Before the SSM, a short (kernel size 4) depthwise convolution is applied to the x branch. This provides local context - a tiny receptive field that helps the model condition on immediately adjacent tokens before the long-range SSM processing.

4. SiLU Activation: The x branch passes through the sigmoid linear unit (SiLU, also called Swish): $\text{SiLU}(x) = x \cdot \sigma(x)$.

5. Selective SSM: The core computation. Details below.

6. Gating: The SSM output is element-wise multiplied by SiLU(z). This gating mechanism (similar to GLU - Gated Linear Units) gives the model a second chance to suppress or amplify features based on the original input.

7. Output Projection: Project back to d_model.
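Step 3 is the easiest place to break causality in a from-scratch implementation. A minimal NumPy sketch of a causal depthwise convolution (the function name, toy sizes, and random weights are illustrative assumptions) pads only on the left, so output $t$ sees inputs $t-3$ through $t$ for a kernel of size 4, and then checks that perturbing token 5 leaves outputs 0 to 4 unchanged:

```python
import numpy as np


def causal_depthwise_conv(x, w, b):
    """x: [L, D] sequence, w: [D, K] one kernel per channel, b: [D] bias.

    y[t, d] = b[d] + sum_k w[d, k] * x[t - (K - 1) + k, d], with zeros before t = 0.
    """
    L, D = x.shape
    K = w.shape[1]
    x_pad = np.vstack([np.zeros((K - 1, D)), x])  # left padding only -> causal
    y = np.empty_like(x)
    for t in range(L):
        window = x_pad[t:t + K]                   # original positions t-K+1 .. t
        y[t] = (window * w.T).sum(axis=0) + b     # each channel uses its own kernel
    return y


rng = np.random.default_rng(0)
L, D, K = 10, 3, 4
x = rng.standard_normal((L, D))
w = rng.standard_normal((D, K))
b = np.zeros(D)

y1 = causal_depthwise_conv(x, w, b)
x2 = x.copy()
x2[5] += 100.0                                    # perturb a "future" token
y2 = causal_depthwise_conv(x2, w, b)
print(np.allclose(y1[:5], y2[:5]), np.allclose(y1[5:], y2[5:]))  # True False
```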

The Selective SSM: The Core Innovation

Inside the selective SSM, the key parameters are computed from the current input:

```python
# Inside the Mamba SSM block:
# x: [batch, L, d_inner]  (d_inner = expand * d_model, 2 * d_model by default)

# Fixed parameters (learned during training, input-INDEPENDENT)
A = -torch.exp(self.A_log)  # shape [d_inner, d_state], always negative for stability

# Input-DEPENDENT parameters (computed from x at each position)
# This is the selective part:
x_proj = self.x_proj(x)  # [batch, L, dt_rank + 2*d_state]
delta, B, C = x_proj.split([self.dt_rank, self.d_state, self.d_state], dim=-1)

# Delta (step size): determines how much each token updates the state
delta = F.softplus(self.dt_proj(delta))  # [batch, L, d_inner]

# B: input projection (how input maps to state update)
# B: [batch, L, d_state]

# C: output projection (how state maps to output)
# C: [batch, L, d_state]
```

The critical observation: A is still input-independent (it is shared across all positions), but B, C, and Delta are computed from the current token.
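The fragment above is not runnable on its own. A NumPy sketch of the same projections with toy sizes (all weights are random; `W_x`, `W_dt`, and `b_dt` are stand-ins for `x_proj` and `dt_proj`, and the bias of -2 is an arbitrary assumption that keeps initial steps small) shows which tensors carry a sequence dimension and which do not:

```python
import numpy as np


def softplus(v):
    return np.logaddexp(0.0, v)  # log(1 + e^v), numerically stable


rng = np.random.default_rng(0)
L, d_inner, d_state, dt_rank = 6, 8, 4, 2   # toy sizes (hypothetical)
x = rng.standard_normal((L, d_inner))       # one sequence after conv + SiLU

W_x = rng.standard_normal((d_inner, dt_rank + 2 * d_state)) * 0.3  # plays x_proj
W_dt = rng.standard_normal((dt_rank, d_inner)) * 0.3               # plays dt_proj weight
b_dt = np.full(d_inner, -2.0)                                      # plays dt_proj bias

# Per-token parameters: one row per position in the sequence
delta_raw, B, C = np.split(x @ W_x, [dt_rank, dt_rank + d_state], axis=-1)
delta = softplus(delta_raw @ W_dt + b_dt)   # [L, d_inner], strictly positive

# Shared parameter: no sequence dimension (A = -exp(A_log), A_log = log(1..d_state))
A = -np.tile(np.arange(1, d_state + 1, dtype=float), (d_inner, 1))  # [d_inner, d_state]

print(delta.shape, B.shape, C.shape, A.shape)  # (6, 8) (6, 4) (6, 4) (8, 4)
```

The low-rank `dt_rank` bottleneck keeps the delta projection cheap, and softplus guarantees a positive step size so that $e^{\Delta A}$ stays between 0 and 1.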

Why keep A fixed? It is not a requirement of the parallel scan (explained below), which handles a transition coefficient that changes at every step; in fact the discrete $\bar{A}_t = e^{\Delta_t A}$ already varies with position because $\Delta_t$ does. The paper's reasoning is that $A$ affects the model only through its interaction with $\Delta$, so a selective $\Delta$ is enough to make the state transition selective. As $\Delta$ gets large, $\bar{A} = e^{\Delta A}$ approaches 0 (because $A$ is negative), meaning the previous state is nearly forgotten. As $\Delta$ gets small, $\bar{A}$ approaches $e^0 = I$, meaning the previous state is preserved.

This gives a content-dependent forgetting mechanism:

```python
# Simplified: effect of delta on state retention
import numpy as np

A_cont = -1.0  # continuous A value (negative for stability)

print("State retention as function of delta:")
print(f"{'Delta':>10} | {'A_bar = e^(Δ*A)':>20} | {'Interpretation':>30}")
print("-" * 70)
for delta in [0.001, 0.01, 0.1, 0.5, 1.0, 5.0]:
    A_bar = np.exp(delta * A_cont)
    if A_bar > 0.99:
        interp = "Almost no update (ignore input)"
    elif A_bar > 0.5:
        interp = "Partial update"
    elif A_bar > 0.1:
        interp = "Significant update"
    else:
        interp = "Nearly forget past state"
    print(f"{delta:>10.3f} | {A_bar:>20.4f} | {interp:>30}")
```

Output:

```
State retention as function of delta:
     Delta |      A_bar = e^(Δ*A) |                 Interpretation
----------------------------------------------------------------------
     0.001 |               0.9990 | Almost no update (ignore input)
     0.010 |               0.9900 | Almost no update (ignore input)
     0.100 |               0.9048 |                 Partial update
     0.500 |               0.6065 |                 Partial update
     1.000 |               0.3679 |             Significant update
     5.000 |               0.0067 |       Nearly forget past state
```

By computing $\Delta$ from the current input, the model learns to output large $\Delta$ for important tokens (causing the state to be updated significantly) and small $\Delta$ for unimportant tokens (causing the state to barely change).
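The retention table only tracks $\bar{A}$. For the input term, the implementation in this lesson uses the simplified $\bar{B} = \Delta B$ (its code comment calls the scheme "simplified ZOH") instead of the exact zero-order-hold $\bar{B} = (e^{\Delta A} - 1)/A \cdot B$ for a diagonal entry $A$. A short sketch comparing the two factors per unit of $B$ (the function name is hypothetical):

```python
import numpy as np


def discretize(delta, A):
    """Discretize one diagonal entry A (< 0) with step size delta.

    Returns A_bar (exact), the exact ZOH factor for B, and the simplified factor.
    """
    A_bar = np.exp(delta * A)
    B_factor_zoh = (A_bar - 1.0) / A   # exact zero-order hold: B_bar = (e^(dA) - 1) / A * B
    B_factor_simple = delta            # simplified: B_bar = delta * B
    return A_bar, B_factor_zoh, B_factor_simple


for delta in [0.01, 0.1, 1.0, 5.0]:
    A_bar, zoh, simple = discretize(delta, A=-1.0)
    print(f"delta={delta:<5} A_bar={A_bar:.4f} ZOH={zoh:.4f} simplified={simple:.4f}")
```

For small $\Delta$ the two agree to first order; for large $\Delta$ the exact factor saturates at $1/|A|$ while the simplified one keeps growing with $\Delta$.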

The Hardware-Aware Parallel Scan

The naive implementation of the selective SSM has a problem: because B, C, and Delta depend on the input, we cannot express the computation as a simple convolution (the kernel would be different at each position). We are back to a sequential recurrence. For a sequence of length L, this means L serial steps - slow during training.

The solution is the parallel scan algorithm (also called the prefix sum or scan operation), specialized for SSMs by the Mamba paper with hardware-aware CUDA kernels.

The Parallel Scan Principle

A linear recurrence $h_k = a_k h_{k-1} + b_k$ (where $a_k$ and $b_k$ can vary per step) can be computed in parallel using a tree-structured computation:

```
Sequential:
  h1 = a1*h0 + b1
  h2 = a2*h1 + b2 = a2*(a1*h0 + b1) + b2 = a2*a1*h0 + a2*b1 + b2
  ...
  (L serial steps, cannot parallelize)

Parallel scan: compute all h_k simultaneously in O(log L) parallel steps
```

The parallel scan works by observing that adjacent steps can be combined. Write each step as the affine map $h \mapsto a h + b$; applying a step $(a, b)$ and then the next step $(a', b')$ gives $(a, b) \circ (a', b') = (a' a,\; a' b + b')$, matching the expansion of $h_2$ above, and this operation is associative. We can therefore use a parallel prefix computation (like a parallel prefix sum) to compute all outputs simultaneously.
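A minimal NumPy sketch of the idea, assuming scalar $a_k$ and $b_k$ for readability: `combine` is the associative operator above, and `parallel_scan` is the simple Hillis-Steele variant, which finishes in about $\log_2 L$ rounds but does $O(L \log L)$ total work. It illustrates the algorithmic structure only and is not the kernel the Mamba library ships:

```python
import numpy as np


def combine(earlier, later):
    """Compose two affine steps h -> a*h + b: apply `earlier`, then `later`."""
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2


def sequential_scan(a, b, h0=0.0):
    h, out = h0, []
    for a_k, b_k in zip(a, b):
        h = a_k * h + b_k
        out.append(h)
    return np.array(out)


def parallel_scan(a, b, h0=0.0):
    """Hillis-Steele inclusive scan over (a_k, b_k) pairs.

    Each round, every position k >= offset absorbs the prefix ending at k - offset.
    All positions in a round are independent, which is what a GPU parallelizes.
    """
    a, b = np.array(a, dtype=float), np.array(b, dtype=float)
    offset = 1
    while offset < len(a):
        a_new, b_new = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = a_new, b_new
        offset *= 2
    # Now (a[k], b[k]) is the composition of steps 1..k, so h_k = a[k] * h0 + b[k]
    return a * h0 + b


rng = np.random.default_rng(0)
a = rng.uniform(0.2, 1.0, 13)  # L = 13, deliberately not a power of two
b = rng.standard_normal(13)
print(np.allclose(parallel_scan(a, b, h0=0.5), sequential_scan(a, b, h0=0.5)))  # True
```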

In practice, the Mamba implementation uses a CUDA kernel that:

  1. Loads the sequences for $\Delta$, $B$, $C$, and $u$ into fast on-chip SRAM
  2. Performs the parallel scan entirely in SRAM, without writing intermediate states to HBM
  3. Writes only the final output sequence to HBM

This is the "hardware-aware" part: like Flash Attention, the implementation is designed around the memory hierarchy of modern GPUs (fast SRAM vs slow HBM), minimizing expensive HBM reads and writes.

The result: the selective SSM runs in $O(L)$ time with good hardware utilization during training (parallel scan), and with $O(1)$ memory per step, independent of sequence length, in a sequential recurrent loop during inference.

Full Mamba Implementation

Here is a working implementation of the Mamba block, simplified for clarity but faithful to the paper:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """
    Mamba block: the core building block of the Mamba architecture.

    Reference: Gu & Dao, "Mamba: Linear-Time Sequence Modeling
    with Selective State Spaces" (2023)
    """

    def __init__(
        self,
        d_model: int,
        d_state: int = 16,        # SSM state dimension N
        d_conv: int = 4,          # depthwise conv kernel size
        expand: int = 2,          # expansion factor
        dt_rank="auto",           # rank of the delta projection ("auto" or an int)
        dt_min: float = 0.001,
        dt_max: float = 0.1,
        dt_init: str = "random",  # kept for parity; only uniform weight init is implemented here
        dt_scale: float = 1.0,
        dt_init_floor: float = 1e-4,
        bias: bool = False,
        conv_bias: bool = True,
    ):
        super().__init__()

        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)  # expand * d_model

        if dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)
        else:
            self.dt_rank = dt_rank

        # Input projection: d_model -> 2*d_inner (for x and z branches)
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias)

        # Depthwise conv for local context
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,  # depthwise
            padding=d_conv - 1,   # pads both sides; forward() keeps the first L outputs (causal)
            bias=conv_bias,
        )

        # Projection to get input-dependent SSM parameters
        # Output: [dt_rank + 2*d_state] for (delta, B, C)
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False
        )

        # Delta (step size) projection
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)

        # Initialize dt_proj following the paper
        dt_init_std = self.dt_rank ** -0.5 * dt_scale
        nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        # Sample initial step sizes log-uniformly in [dt_min, dt_max]
        dt = torch.exp(
            torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse softplus, so that softplus(bias) == dt at initialization
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # A: fixed SSM transition matrix (diagonal, negative for stability)
        # A_log stores log(-A), so A = -exp(A_log) < 0 always
        A = torch.arange(1, self.d_state + 1, dtype=torch.float32).unsqueeze(0)
        A = A.expand(self.d_inner, -1)  # [d_inner, d_state]
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True

        # D: skip connection (learned scalar per channel)
        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.D._no_weight_decay = True

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias)

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.
        u: [batch, seq_len, d_model]
        Returns: [batch, seq_len, d_model]
        """
        batch, L, d = u.shape
        assert d == self.d_model

        # Step 1: Input projection
        xz = self.in_proj(u)        # [batch, L, 2*d_inner]
        x, z = xz.chunk(2, dim=-1)  # each [batch, L, d_inner]

        # Step 2: Depthwise conv (local context)
        # Conv1d expects [batch, channels, length]
        x = x.transpose(1, 2)         # [batch, d_inner, L]
        x = self.conv1d(x)[:, :, :L]  # causal: keep the first L outputs
        x = x.transpose(1, 2)         # [batch, L, d_inner]

        # Step 3: SiLU activation
        x = F.silu(x)

        # Step 4: Selective SSM
        y = self.selective_scan(x)

        # Step 5: Gating
        y = y * F.silu(z)

        # Step 6: Output projection
        output = self.out_proj(y)

        return output

    def selective_scan(self, x: torch.Tensor) -> torch.Tensor:
        """
        The core selective SSM computation.
        x: [batch, L, d_inner]
        Returns: [batch, L, d_inner]

        Note: This is a simplified sequential implementation.
        The actual Mamba uses a custom CUDA kernel with parallel scan.
        """
        batch, L, d_inner = x.shape

        # Get fixed A matrix
        A = -torch.exp(self.A_log)  # [d_inner, d_state], always negative

        # Compute input-dependent parameters B, C, delta from x
        xBC_delta = self.x_proj(x)  # [batch, L, dt_rank + 2*d_state]
        delta_raw, B, C = xBC_delta.split(
            [self.dt_rank, self.d_state, self.d_state], dim=-1
        )

        # Delta: step size - determines how much each token updates state
        delta = F.softplus(self.dt_proj(delta_raw))  # [batch, L, d_inner]

        # Discretize: Ā = exp(Δ * A) (ZOH), B̄ = Δ * B (simplified)
        # delta: [batch, L, d_inner], A: [d_inner, d_state]
        # delta_A: [batch, L, d_inner, d_state]
        delta_A = torch.einsum("b l d, d n -> b l d n", delta, A)
        A_bar = torch.exp(delta_A)  # [batch, L, d_inner, d_state]

        # B̄: [batch, L, d_inner, d_state]
        B_bar = torch.einsum("b l d, b l n -> b l d n", delta, B)

        # Sequential scan (simplified; production uses parallel scan)
        h = torch.zeros(batch, d_inner, self.d_state, device=x.device, dtype=x.dtype)
        ys = []

        for t in range(L):
            # State update: h = Ā * h + B̄ * x
            h = A_bar[:, t] * h + B_bar[:, t] * x[:, t, :].unsqueeze(-1)

            # Output: y = C * h (batched dot product over state dim)
            y_t = torch.einsum("b n, b d n -> b d", C[:, t], h)
            ys.append(y_t)

        y = torch.stack(ys, dim=1)  # [batch, L, d_inner]

        # Skip connection
        y = y + x * self.D

        return y


class MambaModel(nn.Module):
    """A stack of Mamba blocks for sequence modeling."""

    def __init__(
        self,
        d_model: int,
        n_layers: int,
        vocab_size: int,
        d_state: int = 16,
        d_conv: int = 4,
        expand: int = 2,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(d_model),  # pre-norm
                MambaBlock(d_model, d_state, d_conv, expand),
            )
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Tie embedding weights with LM head (standard practice)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Parallel (training-style) forward only; recurrent inference is not implemented here
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = x + layer(x)  # residual connection

        x = self.norm(x)
        logits = self.lm_head(x)
        return logits


# Quick test
if __name__ == "__main__":
    torch.manual_seed(42)

    model = MambaModel(
        d_model=256,
        n_layers=4,
        vocab_size=50257,
        d_state=16,
    )

    # Count parameters (tied weights are counted once)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    # Test forward pass
    input_ids = torch.randint(0, 50257, (2, 512))
    logits = model(input_ids)
    print(f"Input shape: {input_ids.shape}")
    print(f"Output logits shape: {logits.shape}")
    print("Expected: [2, 512, 50257]")
```

Mamba-2: State Space Duality

At the end of May 2024, Gu and Dao published the Mamba-2 paper (Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality).

The key insight of Mamba-2: there is a deep mathematical connection between SSMs and attention. Specifically, the Mamba-2 paper showed that the SSM computation can be written as a form of attention - a "structured" attention where the attention matrix has specific mathematical structure.

This duality (called State Space Duality, or SSD) has practical implications:

  1. Larger state dimensions: the SSD algorithm makes large states cheap to compute, so the state dimension grows from d_state=16 in Mamba-1 to d_state=64 or more in Mamba-2

  2. Tensor parallelism: The SSD formulation enables the same tensor-parallel training strategies used for transformers, making it easier to train Mamba-2 on large GPU clusters

  3. New hybrid possibilities: The mathematical equivalence makes it clear exactly what SSMs can and cannot express, guiding hybrid architecture design
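To see the duality on a small scale: with one head and the Mamba-2 restriction that the transition is a single scalar per step ($a_t = e^{\Delta_t A}$ with scalar $A$), the SSM output is $y = Mx$ for a lower-triangular $L \times L$ matrix $M_{ij} = (C_i \cdot B_j)\, a_{j+1} \cdots a_i$, which the paper calls semiseparable. The NumPy sketch below (function names hypothetical; inputs assumed already scaled by $\Delta$) builds $M$ explicitly, like an attention matrix, and checks it against the recurrence:

```python
import numpy as np


def recurrent_form(x, a, B, C):
    """h_t = a_t * h_{t-1} + outer(B_t, x_t); y_t = C_t @ h_t. x: [L, P], a: [L], B, C: [L, N]."""
    L, P = x.shape
    h = np.zeros((B.shape[1], P))
    y = np.empty((L, P))
    for t in range(L):
        h = a[t] * h + np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def attention_form(x, a, B, C):
    """Materialize M[i, j] = (C_i . B_j) * a_{j+1} * ... * a_i for j <= i, else 0."""
    cum = np.cumsum(np.log(a))
    # exp(cum_i - cum_j) is the decay from step j to step i; the clip only touches j > i
    decay = np.exp(np.minimum(cum[:, None] - cum[None, :], 0.0))
    M = np.tril((C @ B.T) * decay)  # causal mask, like masked attention scores
    return M @ x, M


rng = np.random.default_rng(0)
L, P, N = 12, 4, 8
x = rng.standard_normal((L, P))
a = rng.uniform(0.5, 1.0, L)
B = rng.standard_normal((L, N))
C = rng.standard_normal((L, N))
y_att, M = attention_form(x, a, B, C)
print(M.shape, np.allclose(y_att, recurrent_form(x, a, B, C)))  # (12, 12) True
```

The recurrent form costs $O(L)$ sequential steps with a small state; the attention form costs $O(L^2)$ but is a single matrix multiply.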

The Mamba-2 block replaces the selective SSM with the SSD layer. SSD restricts the state transition A to a scalar times the identity (per head) and computes this restricted selective SSM with a chunked algorithm that is more efficient on modern hardware:

```python
import torch


def ssd_chunk_scan(
    x: torch.Tensor,      # [batch, L, d_inner]
    A: torch.Tensor,      # [d_inner, d_state] (Mamba-2 itself uses one scalar per head)
    B: torch.Tensor,      # [batch, L, d_state] - input dependent
    C: torch.Tensor,      # [batch, L, d_state] - input dependent
    delta: torch.Tensor,  # [batch, L, d_inner] - input dependent
    chunk_size: int = 64,
) -> torch.Tensor:
    """
    SSD (State Space Duality) chunked computation.

    Processes the sequence in chunks of size chunk_size.
    Within each chunk: use fast matrix multiply (attention-like).
    Between chunks: use recurrent state passing.

    This is O(L * chunk_size * d_state) - linear in L.
    """
    batch, L, d_inner = x.shape
    n_chunks = -(-L // chunk_size)  # ceil division: keep a partial last chunk

    # ... (full implementation is complex, see the Mamba-2 codebase)
    # Key idea: the within-chunk computation maps to matrix multiply,
    # which is highly optimized on GPUs (BLAS routines).

    raise NotImplementedError("Sketch only; refer to the mamba-ssm library for the full implementation")
```
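The stub above leaves the chunked algorithm out. A minimal NumPy sketch of the idea, under simplifying assumptions matching a single Mamba-2 head (scalar decay $a_t$ per step, inputs pre-scaled by $\Delta$; function names hypothetical, not the mamba-ssm API): inside each chunk it forms the masked, attention-like matrix and applies it with matrix multiplies, and between chunks it carries the state recurrently. The loop also handles a partial last chunk:

```python
import numpy as np


def sequential_ssd(x, a, B, C):
    """Reference: h_t = a_t * h_{t-1} + outer(B_t, x_t), y_t = C_t @ h_t.

    x: [L, P] (already scaled by delta), a: [L] decays in (0, 1], B, C: [L, N].
    """
    L, P = x.shape
    h = np.zeros((B.shape[1], P))
    y = np.empty((L, P))
    for t in range(L):
        h = a[t] * h + np.outer(B[t], x[t])
        y[t] = C[t] @ h
    return y


def chunked_ssd(x, a, B, C, chunk_size=64):
    """Same outputs: matrix multiplies inside each chunk, a recurrence between chunks."""
    L, P = x.shape
    h = np.zeros((B.shape[1], P))  # state carried from chunk to chunk
    ys = []
    for s in range(0, L, chunk_size):  # the last chunk may be shorter
        e = min(s + chunk_size, L)
        xc, Bc, Cc = x[s:e], B[s:e], C[s:e]
        cum = np.cumsum(np.log(a[s:e]))  # log of the running decay product inside the chunk
        # decay[i, j] = a_{j+1} * ... * a_i for j <= i; the upper triangle is masked to 0
        diff = np.minimum(cum[:, None] - cum[None, :], 0.0)  # clip avoids overflow above the diagonal
        decay = np.tril(np.exp(diff))
        # Intra-chunk: attention-like matrix (C_i . B_j) * decay, applied with a matmul
        y_intra = ((Cc @ Bc.T) * decay) @ xc
        # Inter-chunk: contribution of the carried state, decayed to each position
        y_inter = np.exp(cum)[:, None] * (Cc @ h)
        ys.append(y_intra + y_inter)
        # Pass the state on: decay the old state across the chunk, add this chunk's inputs
        to_end = np.exp(cum[-1] - cum)  # decay from each position to the chunk's end
        h = np.exp(cum[-1]) * h + (Bc * to_end[:, None]).T @ xc
    return np.concatenate(ys, axis=0)


rng = np.random.default_rng(0)
L, P, N = 100, 8, 16
x = rng.standard_normal((L, P))
a = rng.uniform(0.5, 1.0, L)
B = rng.standard_normal((L, N))
C = rng.standard_normal((L, N))
print(np.allclose(sequential_ssd(x, a, B, C), chunked_ssd(x, a, B, C, chunk_size=16)))  # True
```

Each chunk costs $O(Q^2)$ for chunk size $Q$, so the total is $O(L \cdot Q)$: linear in $L$, with most of the work in dense matrix multiplies.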

Mamba in the Real World: Using Pre-trained Models

```python
from transformers import AutoTokenizer, MambaForCausalLM
import torch

# Load the Mamba-2.8B model (available on HuggingFace)
model_name = "state-spaces/mamba-2.8b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MambaForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Mamba uses the same generate() API as transformer models
prompt = "State space models are interesting because"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

# Key difference: Mamba has no KV cache.
# Its recurrent state is fixed-size, so inference memory is dominated by the
# weights (~5.6 GB for 2.8B parameters in fp16) and does not grow with sequence length.
memory_allocated = torch.cuda.memory_allocated() / 1e9
print(f"\nGPU memory at inference: {memory_allocated:.2f} GB")
print("This does NOT grow as we generate more tokens (unlike a transformer's KV cache)")
```

Parameter Count and Architecture Comparison

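| Model | Params | d_model | Layers | Mechanism |
|---|---|---|---|---|
| Mamba-130M | 130M | 768 | 24 | Selective SSM |
| Mamba-370M | 370M | 1024 | 48 | Selective SSM |
| Mamba-790M | 790M | 1536 | 48 | Selective SSM |
| Mamba-1.4B | 1.4B | 2048 | 48 | Selective SSM |
| Mamba-2.8B | 2.8B | 2560 | 64 | Selective SSM |
| GPT-Neo-2.7B | 2.7B | 2560 | 32 | Full Attention |
| Mamba-2 (8B) | 8B | 4096 | 64 | SSD (Mamba-2) |
| Falcon Mamba 7B | 7B | 4096 | 64 | Selective SSM |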

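Mamba models reach parameter counts similar to transformers but distribute them differently. Each Mamba layer is a single Mamba block with no separate MLP, costing roughly 3 × expand × d_model² parameters (mostly in in_proj and out_proj), versus about 12 × d_model² for a transformer layer's attention plus MLP. That is why Mamba uses about twice as many layers at the same width (Mamba-2.8B has 64 layers where GPT-Neo-2.7B has 32). The SSM-specific parameters (A, D, x_proj, dt_proj, and the conv) are only a small fraction of the total.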
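The per-layer arithmetic can be checked with a short calculator that mirrors the `MambaBlock` constructor in this lesson (no bias on `in_proj`/`out_proj`, biased conv and `dt_proj`, LayerNorm excluded; the function name is hypothetical). At the Mamba-130M width of 768 it gives about 3.8M parameters per block, roughly half of a transformer layer's $12 \cdot 768^2 \approx 7.1$M:

```python
import math


def mamba_block_params(d_model, d_state=16, d_conv=4, expand=2):
    """Parameter count per tensor for one MambaBlock as defined in this lesson."""
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    return {
        "in_proj": d_model * 2 * d_inner,
        "out_proj": d_inner * d_model,
        "conv1d": d_inner * d_conv + d_inner,             # depthwise weights + bias
        "x_proj": d_inner * (dt_rank + 2 * d_state),
        "dt_proj": dt_rank * d_inner + d_inner,           # weights + bias
        "A_log": d_inner * d_state,
        "D": d_inner,
    }


counts = mamba_block_params(768)  # Mamba-130M width
total = sum(counts.values())
ssm_specific = total - counts["in_proj"] - counts["out_proj"]
print(f"per block: {total:,}  SSM-specific share: {ssm_specific / total:.1%}")
# per block: 3,770,880  SSM-specific share: 6.2%
print(f"24 blocks: {24 * total:,}")  # 24 blocks: 90,501,120
```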

Production Engineering Notes

Memory at inference: Unlike transformers, Mamba's inference memory is constant with respect to sequence length. The hidden state for a 7B Mamba model is approximately:

n_layers × d_inner × d_state × dtype_bytes
= 64 × 8192 × 16 × 2 = 16 MB

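That is about 16MB of SSM state (plus a small rolling buffer for the depthwise convolution) for a 7B model at any sequence length. Compare that to a KV cache of 52GB or more at 100K tokens for a 7B transformer with full multi-head attention (32 layers, d_model 4096, fp16).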
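The same arithmetic as a short calculator (function names hypothetical). It assumes fp16 storage, a batch of one sequence, full multi-head attention for the transformer (keys and values for every layer and token, no grouped-query attention), and, as an extra assumption, a convolution buffer of `d_conv = 4` inputs per channel for Mamba:

```python
def ssm_state_bytes(n_layers, d_inner, d_state, dtype_bytes=2):
    return n_layers * d_inner * d_state * dtype_bytes


def conv_state_bytes(n_layers, d_inner, d_conv=4, dtype_bytes=2):
    # Assumption: the rolling conv buffer keeps d_conv inputs per channel
    return n_layers * d_inner * d_conv * dtype_bytes


def kv_cache_bytes(n_layers, d_model, n_tokens, dtype_bytes=2):
    # Keys + values for every layer and every cached token
    return 2 * n_layers * d_model * n_tokens * dtype_bytes


ssm = ssm_state_bytes(64, 8192, 16)       # 7B Mamba: 64 layers, d_inner 8192
conv = conv_state_bytes(64, 8192)
kv = kv_cache_bytes(32, 4096, 100_000)    # 7B transformer at 100K tokens
print(f"SSM state: {ssm / 2**20:.1f} MiB, conv buffer: {conv / 2**20:.1f} MiB")
# SSM state: 16.0 MiB, conv buffer: 4.0 MiB
print(f"KV cache at 100K tokens: {kv / 1e9:.1f} GB")
# KV cache at 100K tokens: 52.4 GB
```

Only the KV cache term has an `n_tokens` factor, which is the whole argument in one line.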

Streaming inference: Mamba is ideal for streaming applications. You can process tokens one at a time, updating the hidden state incrementally, with no memory growth. For real-time transcription, live document processing, or continuous monitoring applications, this is a significant architectural advantage.

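Throughput vs latency: For the parallel computation used in training and prefill, the Mamba paper reports its scan kernel overtaking FlashAttention-2 only beyond about 2K tokens, so at short sequences transformers can be as fast or faster. For generation, Mamba's fixed-size state leaves room for much larger batches, and the paper reports about 5× higher generation throughput than a similarly sized transformer. Mamba's advantage grows with sequence length, so for batch inference at long sequences it is dramatically faster.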

```python
import torch


# Streaming inference with Mamba (manual decoding loop, e.g. with MambaForCausalLM)
def mamba_streaming_generate(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.7,
    stream_callback=None,
):
    """
    Generate tokens one at a time; each step updates a fixed-size state.
    stream_callback(token_str): called with each generated token.

    Manual decoding with cache_params is version-sensitive in transformers:
    recent releases also expect cache_position during decoding. Check the
    MambaForCausalLM docs for the version you have installed.
    """
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model.device)
    generated_ids = input_ids.clone()

    with torch.no_grad():
        # Prefill: process the whole prompt in one parallel (scan) pass to build the state
        outputs = model(input_ids=input_ids, use_cache=True)

        for step in range(max_new_tokens):
            logits = outputs.logits[:, -1, :]
            cache = outputs.cache_params  # fixed-size conv + SSM state, no KV cache

            # Sample next token
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            # Stream the token
            if stream_callback:
                stream_callback(tokenizer.decode(next_token[0]))

            # Decode step: feed only the new token; the state absorbs it in O(1)
            outputs = model(
                input_ids=next_token,
                cache_params=cache,
                use_cache=True,
                cache_position=torch.tensor([generated_ids.shape[1] - 1], device=model.device),
            )

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
```

Common Mistakes

:::danger Using Mamba for Tasks Requiring Precise Lookup Mamba compresses past context into a fixed-size hidden state. For tasks requiring precise retrieval ("what is the exact value mentioned on page 47?"), Mamba will underperform transformers. The compressed state loses specific details that attention's full KV cache preserves. Do not use Mamba for applications where exact verbatim recall of earlier content is required. :::

:::warning Confusing Mamba-1 and Mamba-2 APIs Mamba-1 and Mamba-2 use different internal computations (selective scan vs SSD), different state dimensions, and different checkpoint formats. The HuggingFace models mamba-*-hf (Mamba-1) and mamba2-* (Mamba-2) have slightly different APIs. Always check which variant you are using and use the appropriate configuration. Mamba-2 requires the mamba_ssm package installed separately. :::

:::warning The Depthwise Conv Has Causal Padding Requirements The Mamba block includes a depthwise conv1d that must be causal: position t may only receive contributions from positions t and earlier. Either pad with d_conv - 1 zeros on the left only, or (as the implementation above does) pass padding=d_conv - 1 to nn.Conv1d, which pads both sides, and keep only the first L outputs. If you implement Mamba from scratch and get this wrong, for example by keeping the centered outputs of a symmetrically padded conv, the conv will leak future information into past positions, breaking causality. :::

Interview Q&A

Q1: What is the key architectural difference between Mamba and S4?

S4 uses input-independent SSM parameters: matrices A, B, and C are learned during training and fixed during inference. Every token at every position is processed with the same state transition regardless of its content. Mamba makes B, C, and the discretization step size Delta input-dependent - they are computed from the current token via a linear projection. This allows the model to selectively update the hidden state based on content: important tokens can cause large state updates (large Delta), while irrelevant tokens cause small updates. A itself remains input-independent, but because the discrete transition A_bar = exp(Delta * A) depends on Delta, the effective state transition still varies with the input. This selectivity enables Mamba to model content-dependent memory in a way that S4 cannot.

Q2: Explain the Mamba block's architecture and the role of each component.

The Mamba block processes input [batch, L, d_model] through: (1) Linear expansion to 2*d_inner and split into x and z branches; (2) Depthwise Conv1d (kernel size 4) on the x branch for local context - a short receptive field before the long-range SSM; (3) SiLU activation; (4) The selective SSM, which computes input-dependent B, C, Delta and runs the parallel scan; (5) Element-wise gating: SSM output multiplied by SiLU(z), giving the model control over which features to pass through; (6) Output projection back to d_model. The z branch gating is inspired by GLU (Gated Linear Units) and provides an additional multiplicative control mechanism.

Q3: Why does Mamba keep A input-independent when B, C, and Delta are input-dependent?

Keeping A fixed is a design simplification, not a requirement of the parallel scan. The scan only needs the recurrence $h_k = a_k h_{k-1} + b_k$, whose combine step is associative even when $a_k$ changes at every position, and in Mamba $a_k$ already changes per position: the discrete A_bar = exp(Delta * A) depends on the input through Delta. That is the paper's argument: A affects the model only through its interaction with Delta, so making Delta selective already makes the state transition selective. By varying Delta, A_bar varies dramatically even with a fixed continuous A. Small Delta → A_bar near I (preserve state). Large Delta → A_bar near 0 (forget state). The fixed A constrains the "shape" of memory (which modes decay fastest) while Delta controls the "speed."

Q4: What is the State Space Duality (SSD) introduced in Mamba-2, and why does it matter?

SSD shows that the Mamba SSM computation can be mathematically rewritten as a form of linear attention - specifically, an attention mechanism where the attention matrix has particular structure (it is the product of the A_bar matrices and the B, C projections). This equivalence has several practical implications: (1) It reveals exactly what SSMs can and cannot represent relative to attention; (2) It enables the chunked SSD algorithm, where within-chunk computation uses matrix multiply (highly GPU-optimized) while between-chunk computation uses recurrent state passing; (3) It supports larger state dimensions (64+ vs 16 in Mamba-1), which improves quality; (4) It enables standard tensor-parallel training techniques used for transformers.

Q5: How does Mamba's inference compare to a transformer in terms of memory and speed? At what sequence lengths does Mamba win?

Mamba's inference memory is constant: only the hidden SSM state (~16MB for a 7B model, plus a small convolution buffer) regardless of sequence length, compared to the transformer's KV cache that grows linearly (about 52GB at 100K tokens for a 7B model with full multi-head attention). The speed picture depends on the phase. For the parallel computation used in training and prefill, the Mamba paper reports its scan kernel overtaking FlashAttention-2 beyond roughly 2K tokens; at shorter lengths, attention's dense matrix operations can be faster. For token-by-token generation, the paper reports about 5x higher throughput than a similarly sized transformer (with 2K-token prompts), largely because the fixed-size state leaves room for much larger batches. The exact crossover depends on hardware, batch size, and implementation. For latency-sensitive applications at short sequence lengths (chatbots, short Q&A), a transformer with FlashAttention is often as fast or faster per token. For throughput-sensitive applications at long sequence lengths (document processing, genomics, audio), Mamba is significantly more efficient.


© 2026 EngineersOfAI. All rights reserved.