
Data Structures for ML Systems

Reading time: ~45 min · Interview relevance: Very High · Target roles: ML Engineer, MLOps Engineer, Systems Engineer, RecSys Engineer


The Vector Database That Didn't Scale

A startup was building a semantic search product. Their initial prototype worked: embed a query with BERT, brute-force compare against 100,000 stored embeddings using cosine similarity, return the top-10 results. Latency was 200ms - acceptable for a demo.

When they onboarded their first enterprise customer with 50 million documents, the math changed immediately. A single 768-dimensional BERT embedding is 3KB in float32 (768 × 4 bytes), so 50 million embeddings take roughly 144GB just for the vectors. Brute-force cosine similarity against all of them costs 50M × 768 = 38.4 billion multiply-adds (about 77 GFLOP) per query. At 10 TFLOPS that is under 8 milliseconds of compute alone, but every query must also stream the full ~144GB of vectors through memory. Realistically, that means several seconds per query.
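The back-of-envelope numbers above can be reproduced in a few lines. This is a sketch with two assumptions: it counts 2 FLOP per multiply-add (the usual convention), and it uses an illustrative 100 GB/s of sustained memory bandwidth, which is not a figure from the story.

```python
# Back-of-envelope cost of brute-force search over 50M BERT embeddings.
n_vectors = 50_000_000
dim = 768
bytes_per_float32 = 4

bytes_per_vector = dim * bytes_per_float32        # 3,072 bytes (~3KB)
total_bytes = n_vectors * bytes_per_vector        # 153.6e9 bytes
total_gib = total_bytes / 2**30                   # ~143 GiB

# One dot product per stored vector: `dim` multiply-adds each.
macs_per_query = n_vectors * dim                  # 38.4e9 multiply-adds
flops_per_query = 2 * macs_per_query              # ~76.8 GFLOP (multiply + add)

peak_flops = 10e12                                # 10 TFLOPS, as in the text
compute_ms = flops_per_query / peak_flops * 1e3   # ~7.7 ms

# Assumption for illustration only: 100 GB/s of sustained memory bandwidth.
bandwidth_bytes_per_s = 100e9
memory_s = total_bytes / bandwidth_bytes_per_s    # ~1.5 s just to stream the vectors

print(f"{bytes_per_vector} B/vector, {total_gib:.0f} GiB total")
print(f"{macs_per_query / 1e9:.1f}B MACs/query, compute ~ {compute_ms:.1f} ms")
print(f"memory streaming ~ {memory_s:.2f} s at 100 GB/s")
```

The compute term is small and the memory term dominates. That is why the fix is to touch fewer vectors per query with an index, rather than to buy more FLOPS.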

The customer wanted 10ms latency at P99. This required a complete rethink. The solution was HNSW (Hierarchical Navigable Small World graphs): a data structure that enables approximate nearest neighbor search with logarithmic construction complexity and sub-linear query time. At 50M vectors, HNSW could answer queries in under 3ms with 95% recall@10. The tradeoff was 2GB of additional index memory - trivial compared to the 144GB of vectors.

This story repeats across ML infrastructure. Many tokenizers use a trie (not a hash map) for fast longest-prefix matching over their subword vocabulary. Feature stores that ingest millions of updates per second often sit on an LSM-tree (not a B-tree) because writes dominate. The full-text search that retrieves context for RAG systems uses an inverted index (not a sequential scan) to reduce query time from O(n) to O(relevant documents).

Every ML system beyond prototype scale is constrained by data structure choices. The data structures you choose determine whether your system handles production load or collapses at 10x traffic. This lesson covers the data structures that appear repeatedly in ML infrastructure: why each exists, what guarantees it provides, and how to implement or use it correctly.


Why This Exists

General-purpose data structures (arrays, hash maps, balanced BSTs) solve general problems. ML systems have specific access patterns that general structures handle inefficiently:

  • Tokenizers need prefix lookup: "what is the longest match for this text prefix in my vocabulary?" A hash map requires hashing every possible prefix. A trie answers this in $O(L)$ time, where $L$ is the prefix length.

  • Vector search needs approximate nearest neighbor: Exact nearest neighbor requires $O(n \cdot d)$ work per query. For 768-dimensional vectors and millions of examples, this is too slow. HNSW reduces query time to roughly $O(\log n \cdot d)$ (empirically), with controlled approximation error.

  • Feature stores have write-heavy workloads: B-trees update pages in place at random locations, which is expensive on SSDs. LSM-trees convert random writes to sequential appends and can achieve 10-100x better write throughput.

  • Time series features need range queries: "what is the sum of events in window [t, t+h]?" Naive linear scan: $O(n)$. Segment tree: $O(\log n)$ per query with $O(n)$ preprocessing, plus $O(\log n)$ point updates. For a static array, a prefix-sum array answers the same query in $O(1)$. The segment tree earns its place when values change between queries.
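When the array is static (no updates between queries), prefix sums answer the window-sum question in $O(1)$ per query after $O(n)$ setup. Below is a minimal NumPy sketch. `window_sums` is an illustrative helper name, not a library function.

```python
import numpy as np

def window_sums(events: np.ndarray, window: int) -> np.ndarray:
    """Sum of events[max(0, i - window + 1) : i + 1] for every i, via prefix sums."""
    # prefix[k] = events[0] + ... + events[k-1], with prefix[0] = 0
    prefix = np.concatenate(([0.0], np.cumsum(events, dtype=float)))
    right = np.arange(1, len(events) + 1)     # exclusive end index (i + 1)
    left = np.maximum(0, right - window)      # inclusive start index
    return prefix[right] - prefix[left]       # O(1) work per window

events = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
print(window_sums(events, 3))  # [ 3.  4.  8.  6. 10. 15.]
```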

Understanding which data structure fits which access pattern is the difference between a system that works and one that works at scale. This requires knowing the structures well enough to identify when a general-purpose tool (hash map, B-tree) is being pushed beyond its design.


Historical Context

The trie (from "retrieval") was invented by René de la Briandais in 1959 and independently by Fredkin (1960), who coined the name. It became standard for dictionary implementations, spell checkers, and prefix search. In modern ML, tries appear in tokenizers that need fast longest-prefix lookup over a subword vocabulary; SentencePiece, for example, uses a double-array trie. GPT-style byte-level BPE tokenizers, by contrast, apply merges from a rank table that is typically a hash map.

Segment trees originated in computational geometry (Bentley, 1977), and Fenwick introduced the Fenwick tree (Binary Indexed Tree) in 1994. Both became standard tools for range sum queries in competitive programming, and they increasingly appear in ML feature engineering pipelines for time series.

Skip lists were invented by Pugh (1990) as a probabilistic alternative to balanced BSTs. They back the memtable in LevelDB and are RocksDB's default memtable implementation.

HNSW was introduced by Malkov and Yashunin (2018), building on earlier NSW (Navigable Small World) graphs. It rapidly became the dominant algorithm for approximate nearest neighbor search. It is implemented in Faiss, hnswlib, Weaviate, Pinecone, and many other modern vector databases.

LSM trees (Log-Structured Merge-trees) were introduced by O'Neil et al. in 1996, inspired by write-ahead logs. Google's Bigtable (2006) popularized LSM at scale. LevelDB (2011) and RocksDB (2012) made LSM standard for write-heavy storage. LSM-based stores such as RocksDB, Bigtable, and Cassandra are common choices for write-heavy online feature storage, but they are not universal. Redis, a popular Feast online store, keeps data in in-memory hash tables rather than an LSM-tree.


Trie for Tokenizer Vocabulary

The Tokenization Problem

Modern LLMs use byte-pair encoding (BPE) or related subword algorithms with vocabularies of ~50,000-200,000 tokens. Greedy longest-match tokenizers (WordPiece-style) must find the longest matching vocabulary entry at each position of the input. With 100,000 vocabulary entries of varying lengths, this is a prefix matching problem.

A hash map approach checks each possible prefix at the current position against the vocabulary. For a word 15 characters long, this requires 15 hash lookups. Each lookup must also hash a key of up to 15 characters, so the total work grows as $O(L^2)$ in the match length $L$. The trie approach walks a single path down the tree, one child step per character, and stops at the first character with no matching child. That costs $O(L)$ character comparisons in total, with no redundant work.
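The sketch below counts work for both strategies, using characters hashed or visited as a rough proxy for cost. It assumes a plain Python `set` as the hash map and a dict-of-dicts trie. `hash_longest_match`, `build_trie`, and `trie_longest_match` are illustrative names, not library functions.

```python
from typing import Dict, Set, Tuple

def hash_longest_match(text: str, vocab: Set[str], start: int = 0) -> Tuple[int, int]:
    """Longest vocab entry starting at `start`, found by probing every prefix.
    Returns (end, chars_hashed); each probe slices and hashes a fresh key."""
    best_end, chars_hashed = start, 0
    for end in range(start + 1, len(text) + 1):
        candidate = text[start:end]
        chars_hashed += len(candidate)  # hashing cost grows with key length
        if candidate in vocab:
            best_end = end
    return best_end, chars_hashed

def build_trie(tokens: Set[str]) -> Dict:
    """Dict-of-dicts trie; the key "$end" marks a complete token."""
    root: Dict = {}
    for tok in tokens:
        node = root
        for ch in tok:
            node = node.setdefault(ch, {})
        node["$end"] = True
    return root

def trie_longest_match(text: str, root: Dict, start: int = 0) -> Tuple[int, int]:
    """Same query on the trie. Returns (end, steps): one step per matched character."""
    node, best_end, steps = root, start, 0
    for i in range(start, len(text)):
        node = node.get(text[i])
        if node is None:  # no child for this character: stop walking
            break
        steps += 1
        if "$end" in node:
            best_end = i + 1
    return best_end, steps

vocab = {"token", "tokenize", "tokenizer"}
print(hash_longest_match("tokenizers", vocab))              # (9, 55)
print(trie_longest_match("tokenizers", build_trie(vocab)))  # (9, 9)
```

A real hash-based matcher would cap its probes at the length of the longest vocabulary entry. That bounds the cost, but the work per position stays quadratic in the probe length.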

from typing import Optional, Dict, List, Tuple

class TrieNode:
    """A single node in the trie."""
    __slots__ = ['children', 'token_id', 'is_end']

    def __init__(self):
        self.children: Dict[str, 'TrieNode'] = {}
        self.token_id: Optional[int] = None
        self.is_end: bool = False


class Trie:
    """
    Trie (prefix tree) for tokenizer vocabulary lookup.
    Provides O(L) prefix search where L is the length of the prefix.

    The vocabulary is a set of subword tokens. The trie enables efficient
    greedy longest-match tokenization (the WordPiece-style algorithm): at
    each position, traverse the trie as far as possible to find the
    longest valid token. Merge-based BPE tokenizers (GPT-2, LLaMA) apply
    ranked merges instead, so their output can differ from greedy matching.
    """

    def __init__(self):
        self.root = TrieNode()
        self.n_tokens = 0

    def insert(self, token: str, token_id: int) -> None:
        """Insert a token with its ID into the trie. O(L)."""
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        if not node.is_end:
            self.n_tokens += 1  # count each distinct token once
        node.is_end = True
        node.token_id = token_id

    def search(self, text: str) -> Optional[int]:
        """Check if text is a complete token. Returns token_id or None."""
        node = self.root
        for char in text:
            if char not in node.children:
                return None
            node = node.children[char]
        return node.token_id if node.is_end else None

    def longest_prefix_match(self, text: str,
                             start: int = 0) -> Tuple[int, Optional[int]]:
        """
        Find longest token in vocabulary starting at position start.
        Returns (end_position, token_id) of the longest match,
        or (start, None) if no token matches.

        This is the core operation of greedy longest-match tokenization.
        """
        node = self.root
        last_match_end = start
        last_match_id = None

        i = start
        while i < len(text):
            char = text[i]
            if char not in node.children:
                break
            node = node.children[char]
            i += 1
            if node.is_end:
                last_match_end = i
                last_match_id = node.token_id

        return last_match_end, last_match_id

    def tokenize(self, text: str) -> List[int]:
        """
        Greedy longest-match tokenization.
        At each position, find and consume the longest matching token.
        Returns list of token IDs.
        Falls back to an UNK id (-1) for characters with no match.
        """
        tokens = []
        pos = 0

        while pos < len(text):
            end, token_id = self.longest_prefix_match(text, pos)

            if token_id is not None and end > pos:
                tokens.append(token_id)
                pos = end
            else:
                # Unknown character: emit UNK and advance one character.
                # Byte-level tokenizers avoid UNK by falling back to raw bytes.
                tokens.append(-1)  # UNK
                pos += 1

        return tokens

    def all_prefixes(self, prefix: str) -> List[Tuple[str, int]]:
        """
        Find all vocabulary tokens that start with prefix.
        Used for: autocomplete, trie-based beam search,
        constrained decoding in LLMs.
        """
        node = self.root
        for char in prefix:
            if char not in node.children:
                return []
            node = node.children[char]

        # DFS to collect all tokens under this prefix node
        results = []
        self._dfs_collect(node, prefix, results)
        return results

    def _dfs_collect(self, node: TrieNode, current: str,
                     results: List) -> None:
        if node.is_end:
            results.append((current, node.token_id))
        for char, child in node.children.items():
            self._dfs_collect(child, current + char, results)

    def memory_estimate_bytes(self) -> int:
        """Rough trie memory estimate."""
        # Assume ~64 bytes of node overhead + ~56 bytes for its entry in the parent's dict
        node_count = self._count_nodes(self.root)
        return node_count * 120  # approximate

    def _count_nodes(self, node: TrieNode) -> int:
        return 1 + sum(self._count_nodes(child)
                       for child in node.children.values())


# Demonstrate with a small vocabulary
def demo_tokenizer_trie():
    vocab = [
        ("the", 0), ("the ", 1), ("there", 2), ("they", 3),
        ("a", 4), ("and", 5), ("an", 6),
        ("token", 7), ("tokenize", 8), ("tokenizer", 9),
        ("##ing", 10), ("##tion", 11),  # WordPiece-style subwords
    ]

    trie = Trie()
    for token, token_id in vocab:
        trie.insert(token, token_id)

    # Tokenize a sentence
    text = "the tokenizer"
    tokens = trie.tokenize(text)
    print(f"Input: '{text}'")
    print(f"Token IDs: {tokens}")

    # Prefix search (for constrained decoding)
    completions = trie.all_prefixes("the")
    print("\nAll tokens starting with 'the':")
    for token, tid in completions:
        print(f" '{token}' -> id {tid}")

    print(f"\nTrie memory estimate: ~{trie.memory_estimate_bytes()} bytes")
    print(f"Hash map equivalent for same vocab: "
          f"~{sum(len(t) for t, _ in vocab) * 40} bytes "
          f"(rough, with hash overhead)")

demo_tokenizer_trie()
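Greedy longest match, as implemented by the `Trie` above, is what WordPiece-style tokenizers do. Merge-based BPE (GPT-2 style) instead applies learned merges in rank order, and the two can disagree on the same vocabulary. The toy sketch below uses a hypothetical two-rule merge table. `bpe_encode` is a simplified merge loop (one merge per step, leftmost tie-break), not the exact code of any production tokenizer.

```python
from typing import Dict, List, Set, Tuple

def bpe_encode(word: str, merge_ranks: Dict[Tuple[str, str], int]) -> List[str]:
    """Merge-based BPE: repeatedly apply the lowest-rank (earliest-learned) merge."""
    parts = list(word)
    while len(parts) > 1:
        # Rank every adjacent pair; pairs without a merge rule get infinity.
        ranked = [(merge_ranks.get((a, b), float("inf")), i)
                  for i, (a, b) in enumerate(zip(parts, parts[1:]))]
        best_rank, i = min(ranked)
        if best_rank == float("inf"):
            break  # no applicable merge left
        parts[i:i + 2] = [parts[i] + parts[i + 1]]
    return parts

def greedy_longest_match(word: str, vocab: Set[str]) -> List[str]:
    """WordPiece-style greedy tokenization (single characters always allowed here)."""
    out, pos = [], 0
    while pos < len(word):
        end = len(word)
        # Shrink the candidate until it is a vocabulary entry or one character.
        while end > pos + 1 and word[pos:end] not in vocab:
            end -= 1
        out.append(word[pos:end])
        pos = end
    return out

merges = {("b", "c"): 0, ("a", "b"): 1}  # hypothetical: "bc" learned before "ab"
vocab = {"a", "b", "c", "ab", "bc"}
print(bpe_encode("abc", merges))            # ['a', 'bc']
print(greedy_longest_match("abc", vocab))   # ['ab', 'c']
```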

Segment Tree for Range Queries

Segment trees answer range aggregate queries (sum, min, max) in $O(\log n)$ time after $O(n)$ preprocessing. This appears in time series feature engineering, where sliding window features ask "what is the sum of events in window $[l, r]$?"

from typing import List

class SegmentTree:
    """
    Segment tree for range sum queries with point updates.
    O(n) build, O(log n) query, O(log n) update.

    ML applications:
    - Time series sliding window features: sum/max over any window
    - Priority sampling: weighted random sampling by cumulative weight
    - Range feature extraction for sequence models
    """

    def __init__(self, data: List[float]):
        self.n = len(data)
        self.tree = [0.0] * (4 * self.n)  # 4n slots always suffice for this layout
        if self.n > 0:  # an empty array has nothing to build
            self._build(data, 1, 0, self.n - 1)

    def _build(self, data: List[float], node: int,
               start: int, end: int) -> None:
        if start == end:
            self.tree[node] = data[start]
        else:
            mid = (start + end) // 2
            self._build(data, 2 * node, start, mid)
            self._build(data, 2 * node + 1, mid + 1, end)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, idx: int, val: float) -> None:
        """Update value at index idx to val. O(log n)."""
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node: int, start: int, end: int,
                idx: int, val: float) -> None:
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(2 * node, start, mid, idx, val)
            else:
                self._update(2 * node + 1, mid + 1, end, idx, val)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query(self, l: int, r: int) -> float:
        """Range sum query for [l, r]. O(log n)."""
        return self._query(1, 0, self.n - 1, l, r)

    def _query(self, node: int, start: int, end: int,
               l: int, r: int) -> float:
        if r < start or end < l:
            return 0.0  # segment entirely outside [l, r]
        if l <= start and end <= r:
            return self.tree[node]  # segment entirely inside [l, r]
        mid = (start + end) // 2
        return (self._query(2 * node, start, mid, l, r) +
                self._query(2 * node + 1, mid + 1, end, l, r))


# Feature engineering: sliding window sum over event stream
import numpy as np

def compute_sliding_window_features(events: List[float],
                                    windows: List[int]) -> np.ndarray:
    """
    Compute multiple sliding window sums using a segment tree.
    For a static array, NumPy prefix sums (np.cumsum) are simpler and
    faster; the segment tree pays off when events are updated in place
    between queries.
    """
    n = len(events)
    st = SegmentTree(events)
    features = np.zeros((n, len(windows)))

    for i in range(n):
        for j, w in enumerate(windows):
            l = max(0, i - w + 1)
            features[i, j] = st.query(l, i)

    return features

# Simulate user click events, extract multi-scale features
events = list(np.random.poisson(3.0, 10000).astype(float))
window_sizes = [1, 5, 10, 30, 60, 120]  # 1min ... 2hr, assuming one bucket per minute
features = compute_sliding_window_features(events[:100], window_sizes)
print(f"Feature shape: {features.shape}")  # (100, 6)
print(f"First row features: {features[0].tolist()}")
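The `SegmentTree` docstring lists priority sampling as an application. The same idea is usually called a sum tree, and it is the structure behind prioritized experience replay. The sketch below uses an iterative array layout. To sample, descend from the root and go left whenever the target falls inside the left subtree's sum. `SumTree` is an illustrative name, and padding the leaf count to a power of two is a simplifying assumption.

```python
import random
from typing import List

class SumTree:
    """Array-backed binary tree of weights; each internal node stores its subtree sum.
    O(log n) weight updates and O(log n) sampling proportional to weight."""

    def __init__(self, weights: List[float]):
        self.size = 1
        while self.size < len(weights):
            self.size *= 2                        # pad leaf count to a power of two
        self.tree = [0.0] * (2 * self.size)       # leaves live at [size, 2*size)
        for i, w in enumerate(weights):
            self.tree[self.size + i] = w
        for node in range(self.size - 1, 0, -1):  # fill internal sums bottom-up
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def update(self, idx: int, weight: float) -> None:
        node = self.size + idx
        self.tree[node] = weight
        node //= 2
        while node >= 1:                          # refresh sums on the path to the root
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
            node //= 2

    def find(self, target: float) -> int:
        """Leaf index whose cumulative-weight interval contains target."""
        node = 1
        while node < self.size:
            left = 2 * node
            if target < self.tree[left]:
                node = left                       # target lies in the left subtree
            else:
                target -= self.tree[left]         # skip the left subtree's mass
                node = left + 1
        return node - self.size

    def sample(self, rng: random.Random) -> int:
        return self.find(rng.random() * self.tree[1])  # tree[1] is the total weight

tree = SumTree([1.0, 0.0, 3.0, 0.0, 4.0])
rng = random.Random(0)
counts = [0] * 5
for _ in range(10_000):
    counts[tree.sample(rng)] += 1
print(counts)  # roughly proportional to 1 : 0 : 3 : 0 : 4
```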

Union-Find (DSU) for Clustering

Union-Find (Disjoint Set Union) tracks which elements belong to the same set, supporting two operations in near-constant amortized time: find(x) returns the canonical representative of x's set, and union(x, y) merges the sets containing x and y.

In ML, Union-Find is used for offline clustering (e.g., finding connected components in similarity graphs) and for data deduplication (grouping near-duplicate documents into clusters). It also powers Kruskal's MST algorithm, whose minimum spanning tree yields single-linkage hierarchical clustering.

from typing import Dict, List

import numpy as np

class UnionFind:
    """
    Disjoint Set Union with path compression and union by rank.
    Near O(1) amortized per operation (inverse Ackermann function).

    ML applications:
    - Connected components for graph-based clustering
    - Data deduplication: group near-duplicate training examples
    - Online clustering: merge clusters when distance threshold met
    - Kruskal's MST for single-linkage hierarchical clustering
    """

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.n_components = n

    def find(self, x: int) -> int:
        """Find root of x's set with path compression."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        """Merge sets containing x and y. Returns True if merged."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return False

        # Union by rank: attach the shallower tree under the deeper one
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1

        self.n_components -= 1
        return True

    def same_set(self, x: int, y: int) -> bool:
        return self.find(x) == self.find(y)

    def get_clusters(self) -> Dict[int, List[int]]:
        """Return all clusters as {root: [members]}."""
        clusters: Dict[int, List[int]] = {}
        for i in range(len(self.parent)):
            root = self.find(i)
            if root not in clusters:
                clusters[root] = []
            clusters[root].append(i)
        return clusters


def deduplicate_embeddings(embeddings: np.ndarray,
                           threshold: float = 0.95) -> List[int]:
    """
    Group near-duplicate embeddings using Union-Find.
    Returns list of cluster representative indices.

    Used in LLM training data deduplication:
    - Embed all documents
    - Find pairs with cosine similarity > threshold
    - Union those pairs
    - Keep one representative from each cluster
    """
    n = len(embeddings)
    # Normalize for cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normed = embeddings / (norms + 1e-10)

    uf = UnionFind(n)

    # Pairwise similarity (use ANN in production for large n)
    for i in range(n):
        for j in range(i + 1, n):
            sim = np.dot(normed[i], normed[j])
            if sim >= threshold:
                uf.union(i, j)

    # Keep one representative per cluster (the smallest index)
    clusters = uf.get_clusters()
    representatives = [min(members) for members in clusters.values()]
    print(f"Original: {n} docs, After dedup: {len(representatives)} clusters")
    return sorted(representatives)
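The double Python loop in `deduplicate_embeddings` costs $O(n^2)$ interpreted iterations. For modest $n$, one matrix product finds all candidate pairs at once. The sketch below is a hypothetical helper, `near_duplicate_groups`, with an inline union-find that uses path halving instead of the recursive path compression above. It labels each row with the smallest index in its cluster. It still needs $O(n^2)$ memory, so large corpora need ANN search for candidate pairs instead.

```python
from typing import List

import numpy as np

def near_duplicate_groups(embeddings: np.ndarray, threshold: float = 0.95) -> List[int]:
    """Label each row with the smallest index of its near-duplicate cluster."""
    n = len(embeddings)
    normed = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10)
    sims = normed @ normed.T                      # [n, n] cosine similarities
    rows, cols = np.triu_indices(n, k=1)          # every pair i < j, once
    keep = sims[rows, cols] >= threshold          # candidate duplicate pairs

    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]         # path halving
            x = parent[x]
        return x

    for i, j in zip(rows[keep].tolist(), cols[keep].tolist()):
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[max(ri, rj)] = min(ri, rj)     # smallest index stays the root
    return [find(i) for i in range(n)]

X = np.array([[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [0.0, 1.0]])
print(near_duplicate_groups(X))  # [0, 0, 2, 2]
```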

HNSW - Hierarchical Navigable Small World

HNSW is the dominant algorithm for approximate nearest neighbor search. It builds a multi-layer graph where upper layers are sparse (for fast coarse navigation) and lower layers are dense (for precise local search). Querying starts at the top layer and greedily descends, narrowing the search region at each level.

Construction: Each new vector is assigned a random maximum layer $l$, the floor of an exponentially distributed variable, so layer occupancy decays geometrically. The vector is inserted into all layers $0, 1, \ldots, l$. At each layer, it connects to its $M$ nearest neighbors found by beam search.

Query: Start at the single entry point in the top layer. Greedily move to the nearest neighbor at each layer. When you cannot improve, descend to the next layer. At layer 0 (densest), run a wider beam search (width ef) and return the top-$k$ candidates.

Time complexity: $O(\log n)$ for insertion and query (empirically). Space: $O(nM)$, where $M$ is the number of connections per node per layer.
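After flooring, the level distribution is geometric. With normalization factor $m_L = 1/\ln M$, a node reaches layer $l$ or above with probability $P(\lfloor -\ln U \cdot m_L \rfloor \ge l) = P(U \le M^{-l}) = M^{-l}$. Each layer therefore holds roughly $1/M$ as many nodes as the one below. The quick empirical check below uses the same formula as `_random_level` in the code that follows. `random_level` is an illustrative standalone version.

```python
import math
import random

def random_level(rng: random.Random, M: int) -> int:
    """HNSW level assignment: floor(-ln(U) * mL) with mL = 1 / ln(M)."""
    mL = 1.0 / math.log(M)
    u = 1.0 - rng.random()  # in (0, 1], so log() never sees 0
    return int(-math.log(u) * mL)

rng = random.Random(0)
M = 16
levels = [random_level(rng, M) for _ in range(200_000)]
for l in range(4):
    frac = sum(1 for x in levels if x >= l) / len(levels)
    print(f"P(level >= {l}) ~ {frac:.5f}   (theory: {M ** -l:.5f})")
```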

import heapq
import math
import random
from typing import List, Tuple, Optional, Set

import numpy as np

class HNSWNode:
    """Single node in the HNSW graph."""
    __slots__ = ['vector', 'neighbors']

    def __init__(self, vector: np.ndarray, n_layers: int):
        self.vector = vector
        # neighbors[layer] = list of neighbor indices, for layers 0..n_layers
        self.neighbors: List[List[int]] = [[] for _ in range(n_layers + 1)]


class HNSWIndex:
    """
    Simplified HNSW implementation for educational clarity.
    Production: use hnswlib (C++ with Python bindings, far faster) or Faiss.

    Key hyperparameters:
    - M: connections made per new node per layer (typ. 16-64)
    - ef_construction: beam width during construction (typ. 100-200)
    - ef_search: beam width during query (tradeoff recall vs speed)

    M=16, ef=100: standard quality. M=32, ef=200: high quality.
    """

    def __init__(self, dim: int, M: int = 16, ef_construction: int = 100):
        self.dim = dim
        self.M = M
        self.Mmax = M  # max connections per node at layers >= 1
        self.Mmax0 = 2 * M  # more connections allowed at base layer 0
        self.ef_construction = ef_construction
        self.ml = 1.0 / math.log(M)  # level generation parameter

        self.nodes: List[HNSWNode] = []
        self.entry_point: Optional[int] = None
        self.max_layer: int = 0

    def _distance(self, a: np.ndarray, b: np.ndarray) -> float:
        """Squared Euclidean distance (same ranking as Euclidean).
        For cosine, normalize vectors and use negative dot product."""
        return float(np.sum((a - b) ** 2))

    def _random_level(self) -> int:
        """Sample maximum layer for new node: floor of an exponential variable."""
        # 1 - random() lies in (0, 1], so log() never sees 0
        level = int(-math.log(1.0 - random.random()) * self.ml)
        return min(level, 16)  # cap at 16 layers

    def _search_layer(self, query: np.ndarray, entry_points: List[int],
                      ef: int, layer: int) -> List[Tuple[float, int]]:
        """
        Greedy beam search at a single layer.
        Returns up to ef nearest neighbors found, as (distance, index).
        """
        visited: Set[int] = set(entry_points)
        candidates = []  # min-heap (dist, idx): closest unexpanded node first
        W = []  # max-heap via negation (neg_dist, idx): current ef best

        for ep in entry_points:
            d = self._distance(query, self.nodes[ep].vector)
            heapq.heappush(candidates, (d, ep))
            heapq.heappush(W, (-d, ep))

        while candidates:
            d_c, c = heapq.heappop(candidates)

            # Pruning: if closest candidate is farther than the worst in W, stop
            if len(W) >= ef and d_c > -W[0][0]:
                break

            for neighbor_idx in self.nodes[c].neighbors[layer]:
                if neighbor_idx in visited:
                    continue
                visited.add(neighbor_idx)

                d_n = self._distance(query, self.nodes[neighbor_idx].vector)

                if len(W) < ef or d_n < -W[0][0]:
                    heapq.heappush(candidates, (d_n, neighbor_idx))
                    heapq.heappush(W, (-d_n, neighbor_idx))
                    if len(W) > ef:
                        heapq.heappop(W)  # drop the current worst

        # Return (distance, index) pairs, sorted by distance
        return sorted((-neg_d, idx) for neg_d, idx in W)

    def add(self, vector: np.ndarray) -> int:
        """Insert vector into the HNSW index. Returns index of new node."""
        level = self._random_level()
        node_idx = len(self.nodes)
        self.nodes.append(HNSWNode(vector.copy(), level))

        if self.entry_point is None:
            # First node becomes entry point
            self.entry_point = node_idx
            self.max_layer = level
            return node_idx

        # Phase 1: Greedy descent (ef=1) from top layer down to level+1
        ep = [self.entry_point]
        for lc in range(self.max_layer, level, -1):
            ep_results = self._search_layer(vector, ep, ef=1, layer=lc)
            ep = [ep_results[0][1]] if ep_results else ep

        # Phase 2: Insert at each layer from min(level, max_layer) down to 0
        for lc in range(min(level, self.max_layer), -1, -1):
            # Neighbor lists may hold up to Mmax0 links at layer 0, Mmax above
            M_at_layer = self.Mmax0 if lc == 0 else self.Mmax
            candidates = self._search_layer(
                vector, ep, ef=self.ef_construction, layer=lc
            )
            # Connect the new node to its M nearest candidates
            neighbors = candidates[:self.M]

            for _, nb_idx in neighbors:
                # Bidirectional connection
                self.nodes[node_idx].neighbors[lc].append(nb_idx)
                self.nodes[nb_idx].neighbors[lc].append(node_idx)

                # Prune the neighbor's links if it now exceeds M_at_layer
                if len(self.nodes[nb_idx].neighbors[lc]) > M_at_layer:
                    # Keep the M_at_layer closest connections
                    nb_vec = self.nodes[nb_idx].vector
                    conn_dists = [
                        (self._distance(nb_vec, self.nodes[c].vector), c)
                        for c in self.nodes[nb_idx].neighbors[lc]
                    ]
                    conn_dists.sort()
                    self.nodes[nb_idx].neighbors[lc] = [
                        c for _, c in conn_dists[:M_at_layer]
                    ]

            # All candidates found at this layer seed the search one layer down
            ep = [idx for _, idx in candidates] if candidates else ep

        # Update entry point if new node has higher layer
        if level > self.max_layer:
            self.max_layer = level
            self.entry_point = node_idx

        return node_idx

    def search(self, query: np.ndarray, k: int = 10,
               ef: int = 50) -> List[Tuple[float, int]]:
        """
        Find k approximate nearest neighbors.
        ef: search beam width (higher = better recall, slower)
        """
        if self.entry_point is None:
            return []

        # Descend from top to layer 1
        ep = [self.entry_point]
        for lc in range(self.max_layer, 0, -1):
            ep_results = self._search_layer(query, ep, ef=1, layer=lc)
            ep = [ep_results[0][1]] if ep_results else ep

        # Search at layer 0 with full ef
        results = self._search_layer(query, ep, ef=max(ef, k), layer=0)
        return results[:k]


def benchmark_hnsw_vs_brute_force():
    """Compare HNSW recall and speed against exact brute-force search."""
    import time

    np.random.seed(42)
    random.seed(42)  # level assignment uses the random module
    n, d = 10000, 128

    # Build index
    index = HNSWIndex(dim=d, M=16, ef_construction=100)
    vectors = np.random.randn(n, d)

    print("Building HNSW index...")
    t0 = time.time()
    for v in vectors:
        index.add(v)
    build_time = time.time() - t0
    print(f"Build time: {build_time:.2f}s for {n} vectors")

    # Test queries
    n_queries = 100
    queries = np.random.randn(n_queries, d)
    k = 10

    # Brute-force ground truth
    bf_results = []
    for q in queries:
        dists = np.sum((vectors - q) ** 2, axis=1)
        bf_results.append(set(np.argsort(dists)[:k]))

    # HNSW results
    hnsw_results = []
    t0 = time.time()
    for q in queries:
        results = index.search(q, k=k, ef=50)
        hnsw_results.append(set(idx for _, idx in results))
    hnsw_time = time.time() - t0

    # Brute-force timing
    t0 = time.time()
    for q in queries:
        np.sum((vectors - q) ** 2, axis=1).argsort()[:k]
    bf_time = time.time() - t0

    # Recall
    recalls = [len(hnsw_results[i] & bf_results[i]) / k
               for i in range(n_queries)]

    # Note: this pure-Python index pays heavy interpreter overhead, so at
    # n=10,000 vectorized NumPy brute force may well be faster. Recall is
    # the meaningful number here; compiled libraries show the real speedup.
    print(f"\nResults for n={n}, d={d}, k={k}:")
    print(f"HNSW Recall@10: {np.mean(recalls):.2%}")
    print(f"HNSW query time: {1000*hnsw_time/n_queries:.2f}ms per query")
    print(f"Brute-force time: {1000*bf_time/n_queries:.2f}ms per query")
    print(f"Speedup: {bf_time/hnsw_time:.1f}x")

benchmark_hnsw_vs_brute_force()

KD-Tree vs HNSW: When to Use Which


Inverted Index for Sparse Retrieval

The inverted index is the foundation of search engines and of the sparse retrieval side of many RAG (Retrieval-Augmented Generation) systems. Instead of "document -> list of words," it stores "word -> list of documents containing that word."

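A query for "machine learning" looks up the posting lists index["machine"] and index["learning"], then ranks the candidate documents by BM25 score. A conjunctive (AND) query keeps only documents present in both lists. The implementation below is disjunctive (OR): it scores every document that contains at least one query term.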
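A conjunctive (AND) query reduces to intersecting posting lists. If each list is kept sorted by document ID, a two-pointer merge intersects them in time linear in the list lengths, without touching documents that contain neither term. Below is a minimal sketch with made-up posting lists. `intersect_postings` is an illustrative name.

```python
from typing import List

def intersect_postings(a: List[int], b: List[int]) -> List[int]:
    """AND query: doc IDs present in both sorted posting lists. O(len(a) + len(b))."""
    i, j, out = 0, 0, []
    while i < len(a) and j < len(b):
        if a[i] == b[j]:
            out.append(a[i])  # document contains both terms
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1            # advance whichever list has the smaller doc ID
        else:
            j += 1
    return out

# Hypothetical posting lists (sorted doc IDs)
index = {"machine": [2, 5, 9, 14, 21], "learning": [1, 5, 14, 30]}
print(intersect_postings(index["machine"], index["learning"]))  # [5, 14]
```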

import math
from collections import defaultdict
from typing import Dict, List, Tuple

class InvertedIndex:
    """
    Inverted index for full-text retrieval.
    Supports BM25 ranking (a standard baseline for RAG retrieval).

    Used in: Elasticsearch, OpenSearch, and Solr (all built on Lucene),
    and in the sparse retrieval component of many RAG systems. Dense
    retrieval (HNSW) + sparse retrieval (BM25) + reciprocal rank fusion
    is a common hybrid setup.
    """

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        """
        k1: term frequency saturation (1.2-2.0 typical)
        b: length normalization (0 = no normalization, 1 = full)
        """
        self.k1 = k1
        self.b = b
        # {term: {doc_id: term_frequency}}
        self.index: Dict[str, Dict[int, int]] = defaultdict(dict)
        self.doc_lengths: Dict[int, int] = {}
        self.doc_store: Dict[int, str] = {}
        self.avg_doc_length: float = 0.0
        self.n_docs: int = 0

    def _tokenize(self, text: str) -> List[str]:
        """Simple whitespace tokenizer. Production: use BPE or WordPiece."""
        return text.lower().split()

    def add_document(self, doc_id: int, text: str) -> None:
        """Index a document. O(|doc|) time."""
        tokens = self._tokenize(text)
        self.doc_store[doc_id] = text
        self.doc_lengths[doc_id] = len(tokens)
        self.n_docs += 1

        # Update average document length (incremental)
        self.avg_doc_length = (
            (self.avg_doc_length * (self.n_docs - 1) + len(tokens))
            / self.n_docs
        )

        # Count term frequencies in this document
        term_counts: Dict[str, int] = defaultdict(int)
        for token in tokens:
            term_counts[token] += 1

        # Update inverted index
        for term, tf in term_counts.items():
            self.index[term][doc_id] = tf

    def bm25_score(self, term_freqs: Dict[int, int],
                   doc_ids: List[int]) -> Dict[int, float]:
        """
        BM25 scoring for a single term across candidate documents.
        BM25 = IDF * (tf * (k1+1)) / (tf + k1 * (1 - b + b * dl/avgdl))
        """
        n = self.n_docs
        df = len(term_freqs)  # document frequency for this term

        # IDF: log((N - df + 0.5) / (df + 0.5) + 1)
        idf = math.log((n - df + 0.5) / (df + 0.5) + 1)

        scores = {}
        for doc_id in doc_ids:
            tf = term_freqs.get(doc_id, 0)
            dl = self.doc_lengths.get(doc_id, 0)
            avgdl = self.avg_doc_length if self.avg_doc_length > 0 else 1

            # BM25 TF normalization
            normalized_tf = (tf * (self.k1 + 1)) / (
                tf + self.k1 * (1 - self.b + self.b * dl / avgdl)
            )
            scores[doc_id] = idf * normalized_tf

        return scores

    def search(self, query: str, top_k: int = 10) -> List[Tuple[float, int, str]]:
        """
        BM25 search over indexed documents.
        Returns top_k (score, doc_id, text) tuples.
        """
        query_terms = self._tokenize(query)

        # Accumulate BM25 scores across query terms
        total_scores: Dict[int, float] = defaultdict(float)

        for term in query_terms:
            if term not in self.index:
                continue
            term_postings = self.index[term]
            # Candidate documents for this term
            candidate_docs = list(term_postings.keys())
            term_scores = self.bm25_score(term_postings, candidate_docs)
            for doc_id, score in term_scores.items():
                total_scores[doc_id] += score

        # Sort by score and return top_k
        ranked = sorted(total_scores.items(), key=lambda x: -x[1])[:top_k]
        return [(score, doc_id, self.doc_store[doc_id])
                for doc_id, score in ranked]


def demo_inverted_index():
    """Build a small RAG retrieval corpus."""
    corpus = [
        "Graph neural networks learn representations on graph-structured data",
        "Transformers use self-attention to model long-range dependencies",
        "HNSW enables fast approximate nearest neighbor search in high dimensions",
        "Gradient descent minimizes the loss by following the negative gradient",
        "BPE tokenization splits words into subword units for vocabulary efficiency",
        "Adam optimizer combines momentum and adaptive learning rates",
        "Inverted indexes enable fast full-text search over large document collections",
        "Reservoir sampling maintains a uniform sample from a data stream",
    ]

    idx = InvertedIndex()
    for i, doc in enumerate(corpus):
        idx.add_document(i, doc)

    # Search
    query = "fast approximate search"
    results = idx.search(query, top_k=3)
    print(f"Query: '{query}'\nTop results:")
    for score, doc_id, text in results:
        print(f" [{score:.3f}] {text}")

demo_inverted_index()
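The `InvertedIndex` docstring mentions fusing dense and sparse results with reciprocal rank fusion (RRF). RRF needs only each retriever's ranked list. A document's fused score is $\sum_r 1/(k + \text{rank}_r(d))$, with ranks starting at 1 and $k = 60$, the constant used in the original RRF paper. Below is a sketch with hypothetical rankings. `reciprocal_rank_fusion` is an illustrative name.

```python
from collections import defaultdict
from typing import Dict, List

def reciprocal_rank_fusion(rankings: List[List[int]], k: int = 60) -> List[int]:
    """Fuse ranked lists of doc IDs: score(d) = sum over lists of 1 / (k + rank)."""
    scores: Dict[int, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):  # ranks are 1-based
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by doc ID for determinism
    return sorted(scores, key=lambda d: (-scores[d], d))

bm25_ranking = [2, 6, 0]   # hypothetical sparse (BM25) results
dense_ranking = [6, 4, 2]  # hypothetical dense (HNSW) results
print(reciprocal_rank_fusion([bm25_ranking, dense_ranking]))  # [6, 2, 4, 0]
```

Because RRF uses only ranks, it needs no calibration between BM25 scores and embedding distances, which live on unrelated scales.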

Product Quantization for Memory-Efficient Vector Storage

Storing 100M vectors of dimension 768 in float32 requires 100M × 768 × 4 bytes ≈ 307GB. This is impractical for most systems. Product quantization (PQ) compresses each vector by dividing it into $M$ sub-vectors and quantizing each sub-vector with a codebook of $K = 256$ centroids. Storage per vector is $M$ bytes: one byte per sub-vector, indexing into a 256-entry codebook. The compression ratio is $(768 \times 4)/M$, e.g. 96x for $M = 32$, ignoring the small shared codebook overhead.

import numpy as np
from sklearn.cluster import MiniBatchKMeans

class ProductQuantizer:
    """
    Product Quantization for memory-efficient vector compression.
    Jégou, Douze, Schmid, TPAMI 2011.

    Divides each d-dimensional vector into M sub-vectors of d/M dimensions.
    Each sub-vector is quantized to the nearest centroid in a codebook of K entries.
    Storage: M bytes per vector (vs 4*d bytes uncompressed).

    Used in: Faiss (IndexPQ, IndexIVFPQ), Milvus (IVF_PQ), Pinecone internal compression.
    Enables large memory reductions (e.g., 96x for d=768, M=32); the
    recall cost depends on M, K, and the data.
    """

    def __init__(self, d: int, M: int = 8, K: int = 256):
        """
        d: vector dimension (must be divisible by M)
        M: number of sub-quantizers (sub-vector count)
        K: codebook size per sub-quantizer (usually 256 = 1 byte)
        """
        assert d % M == 0, f"d={d} must be divisible by M={M}"
        assert K <= 256, "codes are stored as uint8, so K must be <= 256"
        self.d = d
        self.M = M
        self.K = K
        self.d_sub = d // M  # sub-vector dimension
        self.codebooks = None  # will be [M, K, d_sub] after training

    def train(self, X: np.ndarray, n_iter: int = 100) -> None:
        """
        Train codebooks using k-means on each sub-space.
        X: [n_train, d]
        """
        self.codebooks = np.zeros((self.M, self.K, self.d_sub))

        for m in range(self.M):
            # Extract m-th sub-vector for all training points
            sub_vecs = X[:, m * self.d_sub:(m + 1) * self.d_sub]

            # Fit k-means on sub-vectors
            km = MiniBatchKMeans(n_clusters=self.K, max_iter=n_iter,
                                 random_state=42, n_init='auto')
            km.fit(sub_vecs)
            self.codebooks[m] = km.cluster_centers_

    def encode(self, X: np.ndarray) -> np.ndarray:
        """
        Encode vectors to PQ codes.
        X: [n, d]
        Returns: [n, M] uint8 array (M bytes per vector)
        """
        assert self.codebooks is not None, "Train the PQ before encoding"
        n = X.shape[0]
        codes = np.zeros((n, self.M), dtype=np.uint8)

        for m in range(self.M):
            sub_vecs = X[:, m * self.d_sub:(m + 1) * self.d_sub]  # [n, d_sub]
            centroids = self.codebooks[m]  # [K, d_sub]
            # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, avoiding an [n, K, d_sub] temporary
            dists = (
                np.sum(sub_vecs ** 2, axis=1, keepdims=True)
                - 2.0 * sub_vecs @ centroids.T
                + np.sum(centroids ** 2, axis=1)
            )  # [n, K]
            codes[:, m] = dists.argmin(axis=1).astype(np.uint8)

        return codes

    def decode(self, codes: np.ndarray) -> np.ndarray:
        """Decode PQ codes back to approximate vectors."""
        n = codes.shape[0]
        X_approx = np.zeros((n, self.d))

        for m in range(self.M):
            centroid_indices = codes[:, m].astype(int)
            X_approx[:, m * self.d_sub:(m + 1) * self.d_sub] = (
                self.codebooks[m][centroid_indices]
            )

        return X_approx

    def memory_usage(self, n_vectors: int) -> dict:
        """Compare memory usage: PQ vs exact float32."""
        pq_bytes = n_vectors * self.M  # M bytes per vector
        exact_bytes = n_vectors * self.d * 4  # 4 bytes per float32
        codebook_bytes = self.M * self.K * self.d_sub * 4  # codebook overhead

        return {
            'n_vectors': n_vectors,
            'pq_storage_gb': (pq_bytes + codebook_bytes) / 1e9,
            'exact_storage_gb': exact_bytes / 1e9,
            'compression_ratio': exact_bytes / (pq_bytes + codebook_bytes),
        }


# Demonstrate PQ compression
np.random.seed(42)
n_train, n_query, d, M = 10000, 1000, 128, 8

X_train = np.random.randn(n_train, d)
X_query = np.random.randn(n_query, d)

pq = ProductQuantizer(d=d, M=M, K=256)
pq.train(X_train)

codes = pq.encode(X_train)
X_reconstructed = pq.decode(codes)

# Reconstruction error
rel_error = np.linalg.norm(X_train - X_reconstructed) / np.linalg.norm(X_train)
print(f"PQ reconstruction error: {rel_error:.4f}")

mem = pq.memory_usage(n_train)
print(f"Exact storage: {mem['exact_storage_gb']*1000:.1f} MB")
print(f"PQ storage: {mem['pq_storage_gb']*1000:.1f} MB")
print(f"Compression ratio: {mem['compression_ratio']:.1f}x")
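The demo above measures only reconstruction error. PQ is fast at query time because of asymmetric distance computation (ADC), one of the two distance schemes in the Jégou et al. paper. The query stays uncompressed, and its distance to every codebook centroid is precomputed into an $M \times K$ table. Each database distance then becomes $M$ table lookups and a sum, with no decoding. The standalone sketch below uses random codebooks in place of trained ones. `adc_distances` is an illustrative name.

```python
import numpy as np

def adc_distances(query: np.ndarray, codes: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Approximate squared L2 distances from an uncompressed query to PQ-encoded vectors.
    query: [d]; codes: [n, M] uint8; codebooks: [M, K, d_sub]."""
    M, K, d_sub = codebooks.shape
    q_sub = query.reshape(M, d_sub)                               # split query into M pieces
    # table[m, k] = ||q_m - c_{m,k}||^2, computed once per query
    table = np.sum((codebooks - q_sub[:, None, :]) ** 2, axis=2)  # [M, K]
    # Row i sums table[m, codes[i, m]] over m: M lookups per database vector
    return table[np.arange(M), codes].sum(axis=1)                 # [n]

rng = np.random.default_rng(0)
M, K, d_sub, n = 4, 16, 8, 50
codebooks = rng.normal(size=(M, K, d_sub))
codes = rng.integers(0, K, size=(n, M)).astype(np.uint8)
query = rng.normal(size=M * d_sub)

# Reference: decode every vector, then compute exact distances to the query
decoded = np.concatenate([codebooks[m][codes[:, m]] for m in range(M)], axis=1)
exact = np.sum((decoded - query) ** 2, axis=1)
approx = adc_distances(query, codes, codebooks)
print(np.allclose(exact, approx))  # True
```

The table costs $M \cdot K$ small distance computations per query. After that, scanning millions of codes touches only $M$ bytes per vector.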

LSM Tree for Write-Heavy Feature Stores

Why B-Trees Fail for Feature Stores

ML feature stores may need to absorb millions of writes per second as user activity events, model predictions, and feature updates arrive continuously. B-trees are write-inefficient here. Inserting or updating a key requires a random read to find the leaf page, an in-place modification of that page, and potentially rebalancing changes propagated up the tree. On SSDs, small random writes are expensive and cause write amplification.

LSM-trees convert random writes to sequential appends. All writes go to an in-memory buffer (MemTable). When the buffer fills, it is flushed to disk as a sorted file (SST, Sorted String Table). Reads check the MemTable first, then the SST files from newest to oldest, and return the first match. Periodically, SST files are compacted (merged and sorted) to reduce the number of files and improve read performance.
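The read path and compaction described above fit in a few lines. This toy sketch makes several simplifying assumptions. "SST files" are in-memory sorted lists, a tiny hypothetical `memtable_limit` triggers flushes, deletes (tombstones) are omitted, and compaction uses a dict merge rather than the streaming k-way merge real engines use. `TinyLSM` is an illustrative name, separate from the fuller `SimpleLSMTree` that follows.

```python
from bisect import bisect_left
from typing import Dict, List, Optional, Tuple

Run = List[Tuple[str, str]]  # an immutable, key-sorted "SST", held in memory here

class TinyLSM:
    """Toy LSM: dict memtable + list of sorted runs (newest last). Illustrative only."""

    def __init__(self, memtable_limit: int = 2):
        self.memtable: Dict[str, str] = {}
        self.runs: List[Run] = []
        self.memtable_limit = memtable_limit

    def put(self, key: str, value: str) -> None:
        self.memtable[key] = value  # O(1) in-memory write
        if len(self.memtable) >= self.memtable_limit:
            self.runs.append(sorted(self.memtable.items()))  # "flush": one sorted, sequential write
            self.memtable = {}

    def get(self, key: str) -> Optional[str]:
        if key in self.memtable:  # newest data lives in the memtable
            return self.memtable[key]
        for run in reversed(self.runs):  # then runs, newest to oldest
            i = bisect_left(run, (key,))  # (key,) sorts just before (key, value)
            if i < len(run) and run[i][0] == key:
                return run[i][1]
        return None

    def compact(self) -> None:
        """Merge all runs into one; newer values win for duplicate keys."""
        merged: Dict[str, str] = {}
        for run in self.runs:  # oldest to newest, so newer values overwrite
            merged.update(run)
        self.runs = [sorted(merged.items())] if merged else []

db = TinyLSM(memtable_limit=2)
db.put("user:1", "a"); db.put("user:2", "b")  # flush -> run 0
db.put("user:1", "c"); db.put("user:3", "d")  # flush -> run 1 (newer user:1)
print(db.get("user:1"), len(db.runs))         # c 2
db.compact()
print(db.get("user:1"), db.get("user:2"), len(db.runs))  # c b 1
```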

import os
import json
import time
from bisect import bisect_left
from typing import Any, List, Optional


class MemTable:
    """In-memory write buffer for an LSM-tree.

    A dict here (O(1) writes) that is sorted only at flush time; RocksDB and
    LevelDB use a skip list so the buffer stays sorted at all times.
    """

    def __init__(self, max_size_bytes: int = 4 * 1024 * 1024):
        self.data: dict = {}  # key -> (value, timestamp)
        self.max_size = max_size_bytes
        self.current_size = 0

    def put(self, key: str, value: Any) -> None:
        # Approximate the buffer size by the length of the stringified entries
        if key in self.data:
            old_value = self.data[key][0]
            self.current_size += len(str(value)) - len(str(old_value))
        else:
            self.current_size += len(str(key)) + len(str(value))
        self.data[key] = (value, time.time())

    def get(self, key: str) -> Optional[Any]:
        entry = self.data.get(key)
        return entry[0] if entry else None

    def is_full(self) -> bool:
        return self.current_size >= self.max_size

    def to_sorted_list(self) -> list:
        return sorted(self.data.items())

    def size(self) -> int:
        return self.current_size


class SimpleLSMTree:
    """
    Simplified LSM-tree demonstrating the write path.
    Production: use RocksDB (Python: rocksdb3 package) or LevelDB.

    LSM-based engines commonly found behind online feature stores:
    - RocksDB (embedded key-value store)
    - Cassandra / ScyllaDB
    - Bigtable / HBase (distributed LSM-trees)
    """

    def __init__(self, data_dir: str = '/tmp/lsm_demo'):
        os.makedirs(data_dir, exist_ok=True)
        self.data_dir = data_dir
        self.memtable = MemTable()
        self.sst_files: List[str] = []  # ordered oldest -> newest
        self.level0_limit = 4  # compact when 4 SST files accumulate in level 0
        self._seq = 0  # monotonically increasing file number (no name collisions)

    def _next_path(self, prefix: str) -> str:
        self._seq += 1
        return os.path.join(self.data_dir, f"{prefix}_{self._seq:08d}.json")

    def put(self, key: str, value: Any) -> None:
        """Write to the MemTable: O(1) average, no disk I/O until a flush."""
        # A real LSM-tree first appends the write to a write-ahead log for durability
        self.memtable.put(key, value)

        if self.memtable.is_full():
            self._flush_memtable()

    def _flush_memtable(self) -> None:
        """Flush the MemTable to an immutable sorted SST file (one sequential write)."""
        sst_path = self._next_path("sst")
        with open(sst_path, 'w') as f:
            json.dump(self.memtable.to_sorted_list(), f)
        self.sst_files.append(sst_path)
        self.memtable = MemTable()

        if len(self.sst_files) >= self.level0_limit:
            self._compact()

    def _compact(self) -> None:
        """
        Merge all SST files into a single larger file (L0 -> L1 compaction),
        keeping only the newest version of each key.
        This is what RocksDB's Leveled Compaction does, but much simplified.
        """
        merged = {}
        # Files are visited oldest to newest, so a later entry supersedes an earlier one
        for sst_path in self.sst_files:
            with open(sst_path, 'r') as f:
                entries = json.load(f)
            for key, (value, ts) in entries:
                merged[key] = (value, ts)

        # Write the compacted file before deleting its inputs
        compacted_path = self._next_path("compacted")
        with open(compacted_path, 'w') as f:
            json.dump(sorted(merged.items()), f)

        # Remove old SST files
        for sst_path in self.sst_files:
            try:
                os.remove(sst_path)
            except FileNotFoundError:
                pass
        self.sst_files = [compacted_path]

    def get(self, key: str) -> Optional[Any]:
        """
        Read path: check the MemTable, then SST files from newest to oldest.
        The first hit is the latest version. A production LSM-tree consults a
        per-file Bloom filter first to skip files that cannot contain the key.
        """
        # Check MemTable (most recent writes)
        val = self.memtable.get(key)
        if val is not None:
            return val

        # Check SST files from newest to oldest; stop at the first match
        for sst_path in reversed(self.sst_files):
            with open(sst_path, 'r') as f:
                entries = json.load(f)
            # Binary search over the sorted keys. Loading the whole file is O(n) here;
            # real SST files keep an index block so only one data block is read.
            keys_only = [k for k, _ in entries]
            idx = bisect_left(keys_only, key)
            if idx < len(entries) and entries[idx][0] == key:
                return entries[idx][1][0]  # stored as [value, timestamp]

        return None

Production Engineering Notes

HNSW Hyperparameter Tuning in Production

The three key HNSW parameters control the recall-speed-memory tradeoff:

  • M (connections per node): higher M improves recall at the cost of memory and build time. M=16 is a common default that typically reaches 95%+ recall; M=32 can push recall toward 98%+ while roughly doubling the graph's link storage.
  • ef_construction: a larger value improves index quality (better graph structure) but slows builds. 100-200 is typical.
  • ef_search: controls the query-time beam width. ef=50 typically gives fast queries at ~95% recall; ef=200 gives ~99% recall with 2-3x slower queries. These figures vary by dataset, so measure recall on your own data.
# Production HNSW with hnswlib (C++ core with Python bindings)
# pip install hnswlib
import hnswlib
import numpy as np

def build_production_hnsw(vectors: np.ndarray,
                          M: int = 16,
                          ef_construction: int = 100) -> hnswlib.Index:
    """
    Build a production-grade HNSW index with hnswlib.
    Backed by a C++ implementation, much faster than pure Python.
    """
    n, d = vectors.shape
    index = hnswlib.Index(space='cosine', dim=d)

    # Capacity is fixed at init: leave room for future additions
    # (hnswlib can also grow an index later with resize_index)
    index.init_index(max_elements=n * 2,
                     ef_construction=ef_construction,
                     M=M)

    # Multi-threaded insertion
    index.add_items(vectors, num_threads=4)

    # Query-time beam width; can be changed later and must be >= k
    index.set_ef(50)

    return index
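Choosing ef_search in practice means measuring recall against exact results on a sample of queries. The sketch below is one way to compute that ground truth and score an approximate result, assuming NumPy and cosine similarity to match `space='cosine'` above; `exact_top_k` and `recall_at_k` are hypothetical helper names, not part of hnswlib.

```python
import numpy as np


def exact_top_k(database: np.ndarray, queries: np.ndarray, k: int) -> np.ndarray:
    """Brute-force cosine top-k: the ground truth an ANN index is scored against."""
    db = database / np.linalg.norm(database, axis=1, keepdims=True)
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    sims = q @ db.T  # (n_queries, n_database)
    # argpartition finds the k most similar in linear time; then sort only those k
    top = np.argpartition(-sims, k - 1, axis=1)[:, :k]
    order = np.argsort(-np.take_along_axis(sims, top, axis=1), axis=1)
    return np.take_along_axis(top, order, axis=1)


def recall_at_k(approx_ids: np.ndarray, exact_ids: np.ndarray) -> float:
    """Fraction of the true top-k neighbors that the approximate search returned."""
    hits = sum(len(set(a) & set(e)) for a, e in zip(approx_ids, exact_ids))
    return hits / exact_ids.size
```

A tuning loop would call `index.set_ef(ef)` for a few candidate values, take `labels, _ = index.knn_query(queries, k=10)`, and keep the smallest ef whose recall against `exact_top_k(vectors, queries, 10)` meets the target. hnswlib requires ef to be at least k.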

HNSW Does Not Support Deletions Natively

Most HNSW implementations handle deletion by marking vectors as deleted (for example, hnswlib's mark_deleted) while keeping them in the graph structure. Deleted nodes still participate in graph navigation; they are only excluded from results. For datasets with frequent deletions (recommendation systems with item expiry), the graph degrades over time: a growing share of the traversed nodes are deleted, so searches need longer paths to collect enough live results. Solutions: periodically rebuild the index, use an index designed for dynamic updates such as Microsoft's FreshDiskANN (the streaming variant of DiskANN), or track the deletion ratio and trigger a rebuild when it exceeds 10-20%.
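Tracking the deletion ratio needs only two counters. A minimal sketch, assuming soft deletes of the kind described above; `DeletionTracker` and its 15% default threshold are hypothetical choices within the 10-20% band, not a library API.

```python
from dataclasses import dataclass


@dataclass
class DeletionTracker:
    """Hypothetical helper: counts soft-deleted vectors and signals when to rebuild."""
    live: int = 0
    deleted: int = 0
    rebuild_threshold: float = 0.15  # assumption: a value inside the 10-20% band

    def on_insert(self, n: int = 1) -> None:
        self.live += n

    def on_delete(self, n: int = 1) -> None:
        # Soft delete: the node leaves the results but stays in the graph
        n = min(n, self.live)
        self.live -= n
        self.deleted += n

    def deletion_ratio(self) -> float:
        total = self.live + self.deleted
        return self.deleted / total if total else 0.0

    def needs_rebuild(self) -> bool:
        return self.deletion_ratio() > self.rebuild_threshold

    def on_rebuild(self) -> None:
        # A rebuild reinserts only live vectors, so no deleted nodes remain
        self.deleted = 0
```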


LSM-Tree Read Amplification Without Bloom Filters

An LSM-tree can accumulate many SST files, and a read that misses the MemTable may have to search several of them. Without Bloom filters, a read for a missing key (e.g., a feature lookup for a new user) must probe every candidate file: all level-0 files, whose key ranges overlap, plus one file per deeper level under leveled compaction. Each probe can cost a disk read, and a miss pays for all of them. RocksDB and LevelDB can attach a Bloom filter to each SST file: a space-efficient probabilistic structure that answers "is this key definitely NOT in this file?" with no false negatives and, at about 10 bits per key, roughly a 1% false-positive rate, so about 99% of unnecessary file probes are skipped. Always ensure Bloom filters are enabled in your LSM-tree-backed feature store configuration.
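A per-file Bloom filter can be sketched in a few lines. This minimal version is an assumption, not RocksDB's implementation: it sizes the bit array with the standard formulas $m = -n \ln p / (\ln 2)^2$ bits and $k = (m/n) \ln 2$ hash functions, and derives the $k$ bit positions from one SHA-256 digest by double hashing. Production engines use faster non-cryptographic hashes and cache-friendly layouts.

```python
import hashlib
import math
from typing import List


class BloomFilter:
    """Bit-array membership filter: no false negatives, tunable false-positive rate."""

    def __init__(self, n_items: int, fp_rate: float = 0.01):
        # Optimal sizing: m = -n ln(p) / (ln 2)^2 bits and k = (m / n) ln 2 hashes
        self.m = max(8, int(-n_items * math.log(fp_rate) / math.log(2) ** 2))
        self.k = max(1, round(self.m / n_items * math.log(2)))
        self.bits = bytearray((self.m + 7) // 8)

    def _positions(self, key: str) -> List[int]:
        # Double hashing: k positions from two 64-bit values taken from one digest
        digest = hashlib.sha256(key.encode("utf-8")).digest()
        h1 = int.from_bytes(digest[:8], "little")
        h2 = int.from_bytes(digest[8:16], "little") | 1  # odd step size
        return [(h1 + i * h2) % self.m for i in range(self.k)]

    def add(self, key: str) -> None:
        for p in self._positions(key):
            self.bits[p // 8] |= 1 << (p % 8)

    def might_contain(self, key: str) -> bool:
        # False means the key is definitely absent, so the SST file can be skipped
        return all((self.bits[p // 8] >> (p % 8)) & 1 for p in self._positions(key))
```

In the `SimpleLSMTree` above, one filter would be built per SST file at flush time, and `get` would skip any file whose filter returns False. At 1% target error the sizing works out to about 9.6 bits and 7 hash functions per key.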


Interview Questions and Answers

Q1: Why does a trie work better than a hash map for tokenizer vocabulary lookup?

The tokenizer needs to find the longest vocabulary match starting at each position in the text. With a hash map, for a position that could match a token of up to length $L$, you must hash all $L$ possible prefixes to check which ones are in the vocabulary: $L$ hash computations and $L$ memory lookups. Because hashing a prefix of length $i$ reads $i$ characters, that is $O(L^2)$ character work per position.

A trie organizes the vocabulary as a tree where each edge represents a character. Starting from the root, you traverse the trie following the characters of the input text. At each step, you check whether the current node is a complete token and, if so, record it as the last match. You stop when you hit a node with no child for the next character. This takes at most $L$ character comparisons with no redundant work: one step per character.
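The walk described above fits in a few lines. A minimal sketch, assuming a plain character-level trie and greedy longest-match-first segmentation, which is how WordPiece-style tokenizers apply a vocabulary (BPE proper applies learned merge rules, so treat this as an illustration of the trie walk only). The class names are hypothetical.

```python
from typing import Dict, List, Optional, Tuple


class TrieNode:
    __slots__ = ("children", "token_id")

    def __init__(self) -> None:
        self.children: Dict[str, "TrieNode"] = {}
        self.token_id: Optional[int] = None  # set if the root-to-here path spells a token


class VocabTrie:
    def __init__(self, vocab: Dict[str, int]) -> None:
        self.root = TrieNode()
        for token, tid in vocab.items():
            node = self.root
            for ch in token:
                node = node.children.setdefault(ch, TrieNode())
            node.token_id = tid

    def longest_match(self, text: str, start: int) -> Optional[Tuple[int, int]]:
        """Return (token_id, length) of the longest token at text[start:], or None."""
        node, best = self.root, None
        for i in range(start, len(text)):
            node = node.children.get(text[i])
            if node is None:  # no token continues with this character: stop
                break
            if node.token_id is not None:
                best = (node.token_id, i - start + 1)  # remember the last complete token
        return best

    def tokenize(self, text: str) -> List[int]:
        """Greedy longest-match-first segmentation."""
        ids, pos = [], 0
        while pos < len(text):
            match = self.longest_match(text, pos)
            if match is None:
                raise ValueError(f"no vocabulary token matches at position {pos}")
            ids.append(match[0])
            pos += match[1]
        return ids
```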

Additional advantages of tries: (1) every prefix of a token is a node on its path, so the valid continuations at any point can be read off the current node, which is essential for constrained decoding in LLMs where you want to limit the model to only generate valid tokens; (2) once you reach a prefix's node, you can iterate over all completions of that prefix in $O(\text{completions})$ time, which is useful for autocomplete.

The main disadvantage is memory: each trie node stores a dictionary of children, and for a 100K-token vocabulary with average token length 6, you might have 500K-600K nodes, each with dictionary overhead. Production implementations reduce this with compact layouts, for example byte-level tries stored as flat arrays or double-array tries (SentencePiece uses the latter).

Q2: Explain how HNSW achieves sub-linear time approximate nearest neighbor search.

HNSW builds a hierarchical graph where each level is a progressively sparser subgraph of the complete neighbor graph. The top layer has very few nodes connected over long distances; the bottom layer (level 0) has all nodes with dense local connections.

During a query, you start at the single entry point in the top layer and greedily move toward the query's nearest neighbor. Since the top layer is sparse, each greedy step covers a large region of the space. Once you cannot improve at the current layer, you descend to the next (denser) layer and continue greedy search in the local neighborhood you identified.

This is the same "zoom in" strategy as navigating with a map: first find the right city (sparse layer), then the right neighborhood (denser layer), then the right street (densest layer). Empirically, query time scales as $O(\log n \cdot M \cdot d)$, where $M$ is the number of connections per node and $d$ is the vector dimension.

The approximation comes from the greedy search: you might navigate into a local minimum in the upper layers and miss the true global nearest neighbor. The ef (search beam width) parameter controls this. The upper layers are searched greedily with a beam of one candidate, but the final search in layer 0 maintains a beam of ef candidates instead of a single path, dramatically reducing missed neighbors at the cost of more distance computations.
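The query phase can be sketched directly from this description. A minimal version, assuming a hand-built toy graph: `layers[0]` is the dense bottom layer, upper layers are searched with a beam of 1, and layer 0 with a beam of `ef`. Index construction (random level assignment and neighbor selection) is omitted, and the function names are hypothetical.

```python
import heapq
import math
from typing import Dict, List, Sequence, Tuple

Graph = Dict[int, List[int]]  # adjacency lists for one layer
Vectors = Dict[int, Sequence[float]]


def search_layer(query: Sequence[float], entry: int, ef: int,
                 graph: Graph, vectors: Vectors) -> List[Tuple[float, int]]:
    """Beam search within one layer; returns up to ef (distance, node) pairs, nearest first."""
    d0 = math.dist(query, vectors[entry])
    visited = {entry}
    candidates = [(d0, entry)]  # min-heap: nearest unexpanded node on top
    results = [(-d0, entry)]    # max-heap via negation: worst kept result on top
    while candidates:
        d, node = heapq.heappop(candidates)
        if d > -results[0][0]:  # nearest candidate is worse than the worst result: stop
            break
        for nb in graph.get(node, []):
            if nb in visited:
                continue
            visited.add(nb)
            dn = math.dist(query, vectors[nb])
            if len(results) < ef or dn < -results[0][0]:
                heapq.heappush(candidates, (dn, nb))
                heapq.heappush(results, (-dn, nb))
                if len(results) > ef:
                    heapq.heappop(results)  # evict the current worst result
    return sorted((-neg_d, n) for neg_d, n in results)


def hnsw_search(query: Sequence[float], entry_point: int, layers: List[Graph],
                vectors: Vectors, k: int, ef: int) -> List[Tuple[float, int]]:
    """layers[0] is the dense bottom layer; layers[-1] is the sparse top layer."""
    ep = entry_point
    for graph in reversed(layers[1:]):  # upper layers: greedy descent, beam width 1
        ep = search_layer(query, ep, 1, graph, vectors)[0][1]
    # Bottom layer: beam of ef candidates, then keep the k nearest
    return search_layer(query, ep, max(ef, k), layers[0], vectors)[:k]
```

On a toy set of 32 points on a line (layer 0 linking adjacent points, layer 1 every 4th point, layer 2 every 16th), a query at 21.2 descends from node 0 to 16 to 20, and the layer-0 beam returns 21, 22, 20.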

Q3: What is the LSM-tree write path and why is it faster than B-tree writes?

B-tree write path: (1) find the leaf page containing the key, a random read from disk or the OS cache; (2) modify the page in memory; (3) write the dirty page back to disk, a random write; (4) potentially split pages and update parents, causing more random writes. Because storage is written in whole pages, changing a few bytes rewrites an entire page (typically 4-16KB), and the SSD controller's own read-modify-write of flash pages adds to this; for small updates the combined write amplification can reach 10-100x at the storage layer.

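LSM-tree write path: (1) append the write to a write-ahead log (a small sequential write) and insert it into the in-memory MemTable, with no random disk I/O; (2) when the MemTable fills, flush it to a sorted SST file as one large sequential write; (3) periodically compact (merge and sort) SST files in the background. All disk writes are large sequential writes, which are much faster than small random writes (by orders of magnitude on HDDs) and friendlier to SSD garbage collection. The cost does not disappear: compaction rewrites each key several times, so LSM-trees have their own write amplification, but it is paid in the background with sequential I/O rather than on the write path.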

The tradeoff: LSM reads are slower than B-tree reads. A read must check the MemTable plus potentially several SST files. Bloom filters reduce the cost of negative reads (key not present), but reads that touch several levels still pay for multiple file accesses. This is why LSM-trees suit feature stores (write-heavy: continuous feature updates), while B-trees suit transactional databases with read-heavy or balanced read/write workloads.

Q4: Explain product quantization. What accuracy do you sacrifice for what compression?

Product quantization compresses a $d$-dimensional float32 vector by dividing it into $M$ sub-vectors of dimension $d/M$ and replacing each sub-vector with the index of its nearest centroid in a codebook of $K$ entries (typically $K=256$, requiring 1 byte per sub-vector).

Storage: $M$ bytes per vector instead of $4d$ bytes. Compression ratio: $4d/M$. For $d=768$, $M=8$: compression from 3072 bytes to 8 bytes, a 384x compression.

Accuracy impact: each sub-vector is approximated by its nearest codebook centroid. Distance computations use precomputed lookup tables: for a query $q$, precompute the distance from $q$'s $m$-th sub-vector to each of the $K$ centroids in codebook $m$. For any database vector, its approximate distance to $q$ is then the sum of $M$ table lookups. This is much faster than exact distance computation but introduces quantization error.
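The lookup-table trick is short in NumPy. A minimal sketch of this asymmetric distance computation, assuming squared L2 distance and codebooks that are already trained (for example by the `ProductQuantizer` above); `adc_distances` is a hypothetical name. The result is exactly the distance from $q$ to the reconstructed vector, which is where quantization error enters.

```python
import numpy as np


def adc_distances(query: np.ndarray, codebooks: np.ndarray, codes: np.ndarray) -> np.ndarray:
    """
    Asymmetric distance computation for PQ.
    query:     (d,) float vector, kept uncompressed
    codebooks: (M, K, d_sub) centroids, with d = M * d_sub
    codes:     (N, M) uint8 centroid indices, one byte per sub-vector
    Returns approximate squared L2 distances, shape (N,).
    """
    M, K, d_sub = codebooks.shape
    q_sub = query.reshape(M, 1, d_sub)
    # (M, K) table: distance from each query sub-vector to every centroid of its codebook
    table = ((codebooks - q_sub) ** 2).sum(axis=2)
    # Each database distance is M table lookups summed; no d-dimensional arithmetic
    return table[np.arange(M), codes].sum(axis=1)
```

The table costs $O(K \cdot d)$ once per query; after that, each of the $N$ database vectors costs only $M$ lookups and additions.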

In practice (Faiss IVFPQ benchmarks): on 1 billion 128-dimensional SIFT vectors, which are stored natively as 128 uint8 bytes rather than float32, M=8 gives ~90% recall@10 with 16x compression and M=16 gives ~95% recall@10 with 8x compression. The sweet spot for production (e.g., Pinecone's internal storage) is M=16-32 with recall@10 > 95%.

Q5: What is the difference between an inverted index and a forward index? When would you use each?

A forward index maps document IDs to the list of terms they contain. A query "what terms are in document 5?" is O(1). But a query "which documents contain term X?" requires scanning all documents - O(n).

An inverted index maps terms to the list of documents containing them (with term frequencies). A query "which documents contain term X?" is O(df_X) where df_X is the document frequency of term X - potentially much less than O(n). Conjunction queries ("documents containing X AND Y") use intersection of the posting lists.
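A minimal inverted index with merge-based AND queries, assuming whitespace tokenization and omitting the term frequencies and positions that BM25 and phrase queries would need; the function names are hypothetical:

```python
from collections import defaultdict
from typing import Dict, List


def build_inverted_index(docs: Dict[int, str]) -> Dict[str, List[int]]:
    """Map each term to a sorted posting list of the document IDs containing it."""
    postings: Dict[str, set] = defaultdict(set)
    for doc_id, text in docs.items():
        for term in text.lower().split():
            postings[term].add(doc_id)
    return {term: sorted(ids) for term, ids in postings.items()}


def intersect(a: List[int], b: List[int]) -> List[int]:
    """AND of two sorted posting lists via a linear merge: O(len(a) + len(b))."""
    i = j = 0
    out: List[int] = []
    while i < len(a) and j < len(b):
        if a[i] == b[j]:
            out.append(a[i])
            i += 1
            j += 1
        elif a[i] < b[j]:
            i += 1
        else:
            j += 1
    return out


def and_query(index: Dict[str, List[int]], terms: List[str]) -> List[int]:
    """Documents containing every term; starting from the shortest list keeps merges small."""
    lists = sorted((index.get(t.lower(), []) for t in terms), key=len)
    if not lists:
        return []
    result = lists[0]
    for plist in lists[1:]:
        result = intersect(result, plist)
    return result
```

The query never touches a document that lacks the rarest term, which is why cost tracks document frequency rather than corpus size.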

Use forward index: when you need to analyze document content given an ID (e.g., retrieving a document for display, computing per-document features). Use inverted index: for keyword search, BM25 retrieval, faceted search, and the sparse retrieval component of hybrid RAG systems.

Modern RAG systems often use both: an inverted index for BM25 sparse retrieval and an HNSW index for dense vector retrieval. Reciprocal Rank Fusion (RRF) combines their rankings. This hybrid approach frequently outperforms either retriever alone on retrieval benchmarks such as BEIR and TREC-DL.
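Reciprocal Rank Fusion scores each document as $\sum_i 1/(k + \text{rank}_i(d))$ over the input rankings, with $k = 60$ in the original RRF paper. A minimal sketch, assuming each retriever returns an ordered list of document IDs and that a document missing from a list simply contributes nothing for it:

```python
from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(rankings: List[List[str]], k: int = 60) -> List[str]:
    """Fuse ranked lists: score(d) = sum over lists of 1 / (k + rank of d), ranks from 1."""
    scores: Dict[str, float] = defaultdict(float)
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by ID for a deterministic order
    return sorted(scores, key=lambda d: (-scores[d], d))
```

Because RRF uses only ranks, BM25 scores and cosine similarities never need to be put on a common scale.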

Q6: You need to build a vector search system for 500M embeddings of dimension 512. Walk through your design.

Storage: 500M x 512 x 4 bytes = 1TB uncompressed. Use product quantization with M=32 (32 bytes/vector, 16-dimensional sub-vectors): 16GB total. Add codebook overhead: 32 x 256 x 16 x 4 bytes = 524KB, negligible. Total storage with PQ: ~16GB.

Index: HNSW with M=32 for high recall. But HNSW graph storage is large: common implementations keep up to 2M = 64 links per node in layer 0, so 500M x 64 links x 4 bytes = ~128GB. This needs a memory-mapped file or distribution across machines.

Production architecture: Faiss IVFPQ, an inverted file with PQ compression. The IVF (inverted file) partitions vectors into nlist clusters (~65536 for 500M vectors). At query time, probe the nprobe nearest centroids (~128), then use PQ to rank candidates within those clusters. Memory: ~16GB of PQ codes, ~4GB of 64-bit vector IDs stored in the inverted lists, and ~134MB of centroids (65536 x 512 x 4 bytes). Query time: ~1-5ms on a single GPU.
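The memory figures above are simple arithmetic, so they are easy to recompute for other configurations. A hypothetical back-of-envelope helper, assuming one-byte codes (K=256), float32 centroids, and one 8-byte ID per vector in the inverted lists (Faiss stores 64-bit IDs):

```python
def ivfpq_memory_gb(n: int, d: int, m_subquantizers: int, nlist: int,
                    bits_per_code: int = 8, id_bytes: int = 8) -> dict:
    """Hypothetical back-of-envelope IVF-PQ memory estimate, in GB (1 GB = 1e9 bytes)."""
    k = 2 ** bits_per_code          # centroids per sub-quantizer (256 for 8 bits)
    d_sub = d // m_subquantizers    # dimension of each sub-vector
    return {
        "raw_float32": n * d * 4 / 1e9,
        "pq_codes": n * m_subquantizers * bits_per_code / 8 / 1e9,
        "ids": n * id_bytes / 1e9,  # assumption: one int64 ID per stored vector
        "pq_codebooks": m_subquantizers * k * d_sub * 4 / 1e9,
        "ivf_centroids": nlist * d * 4 / 1e9,  # coarse quantizer, float32
    }
```

For n=500M, d=512, 32 sub-quantizers and nlist=65536 it gives about 1024GB raw, 16GB of codes, 4GB of IDs, 0.0005GB of codebooks, and 0.13GB of centroids.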

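Alternatively, use DiskANN (Microsoft): it stores a Vamana proximity graph (a single-layer graph, unlike HNSW) together with full-precision vectors on SSD, keeping only PQ-compressed vectors in RAM, and serves queries at roughly 5-20ms latency with ~99% recall. It is cost-effective when RAM or GPU memory is the constraint. Microsoft uses DiskANN in production services, for example vector search in Azure Cosmos DB.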


Summary

Data structure choices in ML systems are not an afterthought - they are the foundation that determines whether a system handles production scale or collapses at 10x the demo load. Every component of a production ML system has a specialized structure that fits its access pattern:

The trie at the heart of every BPE tokenizer enables $O(L)$ prefix matching that makes real-time tokenization fast. HNSW enables million-scale vector search at millisecond latency by exploiting hierarchical navigability in the proximity graph. The inverted index enables sub-linear keyword search that is the backbone of every RAG retrieval pipeline. Product quantization compresses vector storage by 50-100x, making billion-scale semantic search economically viable. LSM trees handle the write-heavy access patterns of feature stores without the random-write penalty that would cripple B-trees at streaming event rates.

The common thread: each structure exploits a specific structural property of its data (shared prefixes for tries, hierarchical graph structure for HNSW, sorted immutable files for LSM) to achieve performance that general-purpose structures cannot match. Knowing when to use each structure, and understanding its failure modes, is what separates ML engineers who build systems that scale from those who build systems that demo.

© 2026 EngineersOfAI. All rights reserved.