Sorting and Search for ML
The Inference Bottleneck That Was Not the Model
A team deploying a large language model for a search application had an interesting problem. The model itself was fast - 80ms per forward pass on their A100s. But end-to-end latency was 3 seconds. Three seconds for a query that the model processes in 80ms. The culprit? Post-processing. After the model generated log-probabilities over the 50,000-token vocabulary at each step, the system was calling Python's built-in sorted() on the full vocabulary list to pick the top-5 candidate tokens for beam search.
sorted() on 50,000 floats at every decoding step, 100 steps per generation, batched at 32 requests: that is $100 \times 32 = 3{,}200$ full $O(V \log V)$ sorts per batch, with $V = 50{,}000$. They were sorting the entire vocabulary to find 5 tokens.
The fix was immediate once they saw it: torch.topk(logits, k=5) instead of sorting and slicing. torch.topk uses a partial-sort (selection) algorithm whose cost is roughly $O(V \log k)$, where $V$ is the vocabulary size and $k = 5$. For $k \ll V$, this is nearly $O(V)$ - essentially one linear scan maintaining a small heap. Total latency dropped from 3 seconds to 120ms. The model was no longer the bottleneck.
This is how sorting and search show up in ML: not as academic exercises, but as the hidden cost in retrieval pipelines, beam decoders, ranking systems, and data preprocessing. Understanding the algorithms and knowing which tool to reach for at each scale is what separates engineers who debug latency issues in hours from those who spend days measuring the wrong thing.
Sorting and search appear at every layer of the ML stack. Ranking outputs by score. Finding the most similar embeddings in a database. Selecting the top-k candidates during beam search. Bucketing training samples by sequence length to minimize padding waste. Deduplicating a training corpus. Binary searching for optimal thresholds in a calibration curve. Each of these has a right algorithm, and the wrong choice can make a system 10-100x slower than it needs to be.
This lesson covers the full toolkit: comparison-based sorts, counting-based sorts, binary search variants, top-k selection, sorting in distributed settings, and priority queues for beam search. Each topic is grounded in ML applications and accompanied by working code.
Why This Exists
The fundamental problem sorting solves is ordering: given a collection of items with a defined comparison relation, arrange them from smallest to largest (or most to least relevant). Search extends this: given a sorted structure, find items efficiently without scanning everything.
Before efficient sorting and search algorithms, the only way to find the most relevant items from $n$ candidates was to score all $n$ and scan through them. For small $n$ this is fine. For retrieval from a billion-item database, it is catastrophically slow. Efficient sorting and search are what make large-scale retrieval, ranking, and decoding possible at all.
The history of sorting algorithms is one of the richest in computer science. Bubble sort and insertion sort were the obvious early methods. Merge sort was invented by John von Neumann in 1945. Quicksort - which achieves $O(n \log n)$ on average but $O(n^2)$ in the worst case - was invented by Tony Hoare in 1959 and remains the basis for most production sort implementations. Timsort, which powers Python's list.sort() and sorted() and backs NumPy's stable sort (np.sort(..., kind='stable')) for many dtypes, was developed by Tim Peters in 2002 specifically to exploit natural ordering in real-world data.
Core Concepts
Comparison-Based Sorting: The Floor
Any algorithm that sorts by comparing elements pairwise cannot do better than $\Omega(n \log n)$ comparisons in the worst case. This is a provable lower bound: a comparison sort builds a decision tree where each internal node is a comparison and each leaf is one of the $n!$ possible orderings. A binary decision tree with $n!$ leaves has depth at least $\log_2(n!) = \Theta(n \log n)$ by Stirling's approximation.
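The Stirling step is not strictly needed. For $n \ge 2$, keeping only the larger half of the factors of $n!$ already gives the lower bound, and the upper bound is immediate:

$$
\log_2(n!) = \sum_{i=1}^{n} \log_2 i \;\ge\; \sum_{i=\lceil n/2 \rceil}^{n} \log_2 i \;\ge\; \frac{n}{2}\,\log_2\frac{n}{2},
\qquad
\log_2(n!) \le n \log_2 n,
$$

so $\log_2(n!) = \Theta(n \log n)$, and every comparison sort needs $\Omega(n \log n)$ comparisons on some input.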
So merge sort, heapsort, and timsort are all asymptotically optimal in the comparison model. Quicksort is $O(n \log n)$ on average but $O(n^2)$ in the worst case (though with good pivot selection strategies, the worst case is extremely rare in practice).
Mergesort: Divide-and-conquer. Split the array in half, sort each half recursively, merge the two sorted halves. Guaranteed $O(n \log n)$ always. Stable (equal elements maintain relative order). Requires $O(n)$ extra space for the merge buffer. Its merge step underlies Timsort, which Python uses for list.sort() and Java uses for sorting object arrays.
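A minimal top-down merge sort sketch on Python lists (it returns a new list; the optional `key` argument is an illustrative addition). It shows where the $O(n)$ buffer comes from and why taking from the left half on ties makes the sort stable:

```python
from typing import Any, Callable, List


def merge_sort(a: List[Any], key: Callable[[Any], Any] = lambda x: x) -> List[Any]:
    """Top-down merge sort: O(n log n) time, O(n) extra space, stable."""
    if len(a) <= 1:
        return list(a)
    mid = len(a) // 2
    left = merge_sort(a[:mid], key)
    right = merge_sort(a[mid:], key)
    # Merge into a fresh buffer: this is the O(n) extra space.
    merged: List[Any] = []
    i = j = 0
    while i < len(left) and j < len(right):
        # "<=" takes from the left half on ties, which is what makes the sort stable.
        if key(left[i]) <= key(right[j]):
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # One side is exhausted; the rest of the other side is already sorted.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged
```

For example, `merge_sort([(1, 'b'), (0, 'x'), (1, 'a')], key=lambda t: t[0])` keeps `(1, 'b')` ahead of `(1, 'a')`.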
Heapsort: Build a max-heap in $O(n)$, then repeatedly extract the max in $O(\log n)$. Total $O(n \log n)$. $O(1)$ extra space. Not stable. Poor cache performance because heap operations access memory in non-sequential patterns. In practice slower than quicksort and timsort despite the same asymptotic complexity.
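An in-place heapsort sketch on a Python list. The child index `2 * root + 1` lands further from the root as the heap grows, which is the non-sequential access pattern described above. The function names are illustrative:

```python
from typing import List


def _sift_down(a: List[float], root: int, end: int) -> None:
    """Move a[root] down until the max-heap property holds within a[:end]."""
    while True:
        child = 2 * root + 1  # left child; far away from root for large heaps
        if child >= end:
            return
        # Pick the larger of the two children.
        if child + 1 < end and a[child + 1] > a[child]:
            child += 1
        if a[root] >= a[child]:
            return
        a[root], a[child] = a[child], a[root]
        root = child


def heapsort(a: List[float]) -> None:
    """In-place heapsort: O(n log n) time, O(1) extra space, not stable."""
    n = len(a)
    # Build a max-heap bottom-up in O(n).
    for start in range(n // 2 - 1, -1, -1):
        _sift_down(a, start, n)
    # Repeatedly swap the max to the end and shrink the heap by one.
    for end in range(n - 1, 0, -1):
        a[0], a[end] = a[end], a[0]
        _sift_down(a, 0, end)
```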
Quicksort: Pick a pivot, partition elements into "less than pivot" and "greater than pivot", recurse on each partition. $O(n \log n)$ average, $O(n^2)$ worst case (when the pivot is always the smallest or largest element). Excellent cache performance because partitioning scans elements sequentially. The basis for most production sort implementations, usually in the form of introsort (quicksort that falls back to heapsort when recursion gets too deep).
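A compact in-place quicksort sketch with a random pivot and Lomuto partitioning. Recursing into the smaller partition and looping on the larger keeps the stack depth at $O(\log n)$. Lomuto partitioning is a simplification: inputs with many equal keys still push it toward $O(n^2)$, which production implementations avoid with three-way partitioning or an introsort fallback.

```python
import random
from typing import List, Optional


def quicksort(a: List[float], lo: int = 0, hi: Optional[int] = None) -> None:
    """In-place quicksort of a[lo..hi] with a random pivot (Lomuto partition)."""
    if hi is None:
        hi = len(a) - 1
    while lo < hi:
        # A random pivot makes the O(n^2) case unlikely for any fixed input order.
        p = random.randint(lo, hi)
        a[p], a[hi] = a[hi], a[p]
        pivot = a[hi]
        i = lo
        # One sequential left-to-right scan: cache-friendly.
        for j in range(lo, hi):
            if a[j] < pivot:
                a[i], a[j] = a[j], a[i]
                i += 1
        a[i], a[hi] = a[hi], a[i]  # pivot is now at its final position i
        # Recurse on the smaller side, loop on the larger side.
        if i - lo < hi - i:
            quicksort(a, lo, i - 1)
            lo = i + 1
        else:
            quicksort(a, i + 1, hi)
            hi = i - 1
```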
Timsort: The Sort Python and NumPy Actually Use
Timsort (Tim Peters, 2002) is a hybrid of merge sort and insertion sort, designed to run efficiently on real-world data, which often contains long already-sorted runs. It is the algorithm behind Python's list.sort() and sorted(), Java's Arrays.sort() for objects, and NumPy's np.sort() with kind='stable' (NumPy maps 'stable' to timsort or radix sort depending on the dtype).
The core insight: real data is rarely uniformly random. A list of timestamps is mostly sorted. A list of scores after a training step is partially ordered. A list of document IDs has long monotone runs. Timsort detects "natural runs" (already sorted subsequences) in the input and merges them. For already-sorted input, timsort is $O(n)$. For random input, it degrades to $O(n \log n)$ like merge sort.
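The run-detection idea can be sketched as a simplified natural merge sort. This is not CPython's Timsort: it omits minrun, insertion-sort padding of short runs, galloping, and the merge-stack invariants. It only shows why sorted input costs $O(n)$ (one run, no merges) and why $r$ runs cost about $O(n \log r)$ (pairwise merge rounds). The helper names are illustrative.

```python
from typing import List


def find_runs(a: List[float]) -> List[List[float]]:
    """Split a into maximal non-decreasing runs; strictly decreasing runs are reversed."""
    runs, i, n = [], 0, len(a)
    while i < n:
        j = i + 1
        if j < n and a[j] < a[i]:
            # Strictly decreasing run: reversing it cannot reorder equal elements.
            while j < n and a[j] < a[j - 1]:
                j += 1
            runs.append(a[i:j][::-1])
        else:
            while j < n and a[j] >= a[j - 1]:
                j += 1
            runs.append(a[i:j])
        i = j
    return runs


def merge(left: List[float], right: List[float]) -> List[float]:
    """Stable merge of two sorted lists."""
    out, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            out.append(left[i])
            i += 1
        else:
            out.append(right[j])
            j += 1
    return out + left[i:] + right[j:]


def natural_merge_sort(a: List[float]) -> List[float]:
    runs = find_runs(a)
    if not runs:
        return []
    # Merge adjacent runs pairwise: about log2(r) rounds of O(n) work each.
    while len(runs) > 1:
        merged = [merge(runs[k], runs[k + 1]) for k in range(0, len(runs) - 1, 2)]
        if len(runs) % 2:
            merged.append(runs[-1])
        runs = merged
    return runs[0]
```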
Key properties:
- $O(n \log n)$ worst case (guaranteed, unlike quicksort)
- $O(n)$ best case (fully sorted input)
- Stable (equal elements maintain order - critical for multi-key sorting)
- $O(n)$ extra space in the worst case
In ML: when you sort training samples by sequence length (to pack similar-length samples into batches and minimize padding), timsort's natural-run detection makes re-sorting an already length-sorted list close to $O(n)$, because lengths do not change between epochs.
```python
# NumPy sort options
import numpy as np

x = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])

# Default sort is kind='quicksort' (introsort: fast, not stable)
sorted_x = np.sort(x)  # returns new array
np.sort(x, kind='stable')     # stable sort (timsort for float dtypes)
np.sort(x, kind='quicksort')  # introsort (fast, not stable)
np.sort(x, kind='mergesort')  # stable; same as 'stable'
np.sort(x, kind='heapsort')   # heapsort (not stable; np.sort still returns a copy)

# argsort: returns indices that would sort the array
idx = np.argsort(x)  # O(n log n)
# x[idx] gives the sorted array
# idx gives the original positions - crucial for sorting embeddings

# Example: sort embeddings by similarity score
scores = np.array([0.82, 0.91, 0.76, 0.95, 0.88])
embeddings = np.random.randn(5, 768)  # 5 embeddings, 768-dim
sorted_idx = np.argsort(scores)[::-1]  # descending order
sorted_embeddings = embeddings[sorted_idx]  # embeddings sorted by score
sorted_scores = scores[sorted_idx]

# PyTorch equivalent
import torch

scores_t = torch.tensor([0.82, 0.91, 0.76, 0.95, 0.88])
sorted_vals, sorted_idx_t = torch.sort(scores_t, descending=True)
```
Counting Sort and Radix Sort: Breaking the Floor
When the comparison lower bound does not apply - specifically when keys are non-negative integers from a bounded range - sorting can be done in $O(n + R)$, where $R$ is the range of values. This is faster than $O(n \log n)$ when $R = O(n)$.
Counting sort: Count how many times each value appears, then reconstruct the sorted array from the counts. Time $O(n + R)$, space $O(R)$ for the counts (plus $O(n)$ for the output).
```python
def counting_sort(arr: list, max_val: int) -> list:
    """O(n + R) sort for non-negative integers in [0, max_val]."""
    counts = [0] * (max_val + 1)
    for val in arr:
        counts[val] += 1
    result = []
    for val, count in enumerate(counts):
        result.extend([val] * count)
    return result
```
Radix sort: Sort by the least significant digit first, then the next digit, and so on. Each pass is a stable counting sort over a small range (e.g., 256 values for byte-level radix sort). For $b$-bit integers, $\lceil b/8 \rceil$ byte passes suffice. Total: $O(n \cdot b/8)$, which is $O(n)$ for fixed-width integers.
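A byte-wise LSD radix sort sketch for unsigned 32-bit integers (a simplification: signed integers and floats need a key transform first). Each of the four passes is a stable counting sort on one byte; stability is what lets later passes preserve the order established by earlier ones.

```python
from typing import List


def radix_sort_u32(arr: List[int]) -> List[int]:
    """LSD radix sort for integers in [0, 2**32): 4 passes, one per byte."""
    if any(v < 0 or v >= 1 << 32 for v in arr):
        raise ValueError("values must be unsigned 32-bit integers")
    out = list(arr)
    for shift in (0, 8, 16, 24):
        # Histogram of the current byte (256 buckets).
        counts = [0] * 256
        for v in out:
            counts[(v >> shift) & 0xFF] += 1
        # Exclusive prefix sum: first output slot of each bucket.
        total = 0
        for b in range(256):
            counts[b], total = total, total + counts[b]
        # Scatter in input order, so the pass is stable.
        buf = [0] * len(out)
        for v in out:
            b = (v >> shift) & 0xFF
            buf[counts[b]] = v
            counts[b] += 1
        out = buf
    return out
```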
ML use cases for integer sorts:
- Bucketing by sequence length: sequence lengths are integers in a small range. Counting sort to group samples by length for efficient batching.
- Token ID sorting: vocabulary indices are integers. Radix sort for fast reordering.
- Sorting rank indices in sparse attention masks.
- GPU sorting: CUDA implementations of radix sort are among the fastest sorts on GPUs because they map naturally to parallel prefix-sum (scan) operations.
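For the sequence-length bucketing case, a sketch of a stable counting sort over sample indices: prefix sums give each length its first output slot, and the sorted indices are then sliced into batches. `argsort_by_length` and `make_batches` are illustrative helper names, and lengths are assumed to be integers in `[0, max_len]`.

```python
from typing import List


def argsort_by_length(lengths: List[int], max_len: int) -> List[int]:
    """Stable counting sort of sample indices by length: O(n + max_len).

    Samples with equal length keep their original relative order.
    """
    counts = [0] * (max_len + 1)
    for length in lengths:
        counts[length] += 1
    # Exclusive prefix sum: first output slot for each length.
    starts, running = [0] * (max_len + 1), 0
    for length in range(max_len + 1):
        starts[length] = running
        running += counts[length]
    order = [0] * len(lengths)
    for idx, length in enumerate(lengths):
        order[starts[length]] = idx
        starts[length] += 1
    return order


def make_batches(order: List[int], batch_size: int) -> List[List[int]]:
    """Consecutive slices of the length-sorted indices: similar lengths per batch."""
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
```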
Sorting Networks and GPU Parallelism
A sorting network is a fixed sequence of comparisons that sorts any input, independent of the data values. Unlike comparison-based sorts where the sequence of comparisons depends on the data, a sorting network always makes the same comparisons.
This is ideal for GPUs: the GPU can execute all comparisons in a level of the network in parallel, since they are data-independent. Bitonic sort is one of the most common: for $n$ elements, it uses $O(\log^2 n)$ parallel steps ($O(n \log^2 n)$ comparisons in total). With enough GPU threads to process all comparisons in a step simultaneously, the wall-clock time scales as $O(\log^2 n)$ rather than $O(n \log n)$.
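A sequential Python sketch of the bitonic network, assuming the input length is a power of two. For a fixed compare distance `j`, the pairs `(i, i ^ j)` are disjoint, so each pass of the innermost loop is one parallel step on a GPU. The function also returns the number of such steps, which is $m(m+1)/2$ for $n = 2^m$.

```python
from typing import List, Tuple


def bitonic_sort(a: List[float]) -> Tuple[List[float], int]:
    """Ascending bitonic sort; returns (sorted copy, number of parallel steps)."""
    n = len(a)
    assert (n & (n - 1)) == 0, "length must be a power of two"
    a = list(a)
    steps = 0
    k = 2
    while k <= n:          # size of the bitonic sequences being merged
        j = k // 2
        while j > 0:       # compare distance within this merge stage
            # Every compare-exchange below touches a disjoint pair and does not
            # depend on the others: one data-independent parallel step.
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    ascending = (i & k) == 0
                    if (ascending and a[i] > a[partner]) or (not ascending and a[i] < a[partner]):
                        a[i], a[partner] = a[partner], a[i]
            steps += 1
            j //= 2
        k *= 2
    return a, steps
```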
In practice: torch.sort() on GPU uses a GPU-optimized sort (usually thrust or CUB radix sort for numeric types) that is extremely fast for batch sorting of tensors. For sorting logits during beam search, always use torch.sort or torch.topk on GPU rather than moving data to CPU for sorting.
Binary Search and Its Variants
Binary search finds a target value in a sorted array in $O(\log n)$ time. The key variants:
- bisect_left(a, x): returns the leftmost index where x could be inserted to keep a sorted. If x is in a, returns the index of the first occurrence.
- bisect_right(a, x): returns the rightmost insertion point - the index after the last occurrence of x.
```python
import bisect

# Find exact position
a = [1, 3, 3, 5, 7, 9, 11]
bisect.bisect_left(a, 3)   # returns 1 (first 3)
bisect.bisect_right(a, 3)  # returns 3 (after last 3)
bisect.bisect_left(a, 4)   # returns 3 (where 4 would go)

# Count elements <= x in O(log n) - useful for computing quantiles
def count_le(sorted_arr, x):
    """Number of elements <= x in sorted array."""
    return bisect.bisect_right(sorted_arr, x)

# Find the index where a score threshold should be applied,
# e.g., keep only predictions with confidence >= 0.7
sorted_scores = sorted([0.3, 0.5, 0.55, 0.72, 0.88, 0.91, 0.95])
threshold_idx = bisect.bisect_left(sorted_scores, 0.7)  # 3
# All elements from threshold_idx onward are >= 0.7
```
Binary search for hyperparameter tuning: Some hyperparameter searches can be framed as finding where a monotone yes/no condition flips over a sorted search space. For example, finding the largest batch size that still fits in GPU memory (binary search over batch sizes), or the largest learning rate before training diverges (binary search over a log-scale grid).
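A sketch of the batch-size case. `fits` is a hypothetical callback (for example, one that runs a single training step at that batch size and reports whether it ran out of memory); the search only assumes monotonicity: if a batch size fits, every smaller one fits too.

```python
from typing import Callable


def max_feasible(fits: Callable[[int], bool], lo: int, hi: int) -> int:
    """Largest x in [lo, hi] with fits(x) True, assuming fits is monotone
    (True up to some threshold, False after). Returns lo - 1 if nothing fits.
    Calls fits O(log(hi - lo)) times."""
    best = lo - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(mid):
            best = mid      # mid works; look for something larger
            lo = mid + 1
        else:
            hi = mid - 1    # mid is too large; look lower
    return best


# Toy stand-in for a real memory check: 1.5 units per sample, 100 units available.
largest = max_feasible(lambda bs: 1.5 * bs <= 100, 1, 1024)  # 66
```

For a log-scale search (learning rates, or batch sizes restricted to powers of two), the same function can search over the exponent instead of the value.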
Ternary search for unimodal functions: When a metric like validation loss is unimodal over a hyperparameter (decreases then increases), ternary search finds the minimum in $O(\log(R/\epsilon))$ evaluations, where $R$ is the width of the search range and $\epsilon$ is the tolerance. Each step evaluates the function at the two trisection points and discards the worse third of the interval.
```python
import math

def ternary_search(f, lo: float, hi: float,
                   tolerance: float = 1e-4,
                   max_iter: int = 100) -> float:
    """
    Find minimum of unimodal function f on [lo, hi].
    Classical ternary search: each step evaluates f at two points,
    discards the worse third of the interval.
    Converges in O(log_{1.5}(range/tolerance)) iterations.
    ML use case: find optimal learning rate when loss vs LR is unimodal.
    """
    for _ in range(max_iter):
        if hi - lo < tolerance:
            break
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if f(m1) < f(m2):
            hi = m2  # minimum is in [lo, m2]
        else:
            lo = m1  # minimum is in [m1, hi]
    return (lo + hi) / 2

# Example: find optimal weight decay for a small model
# (assumes validation loss is unimodal over log(weight_decay))
def mock_val_loss(log_wd: float) -> float:
    """Simulated U-shaped validation loss over log(weight_decay)."""
    # Minimum at log_wd = -4.6, i.e. wd = exp(-4.6) ~ 0.01
    return (log_wd + 4.6) ** 2 + 0.5

optimal_log_wd = ternary_search(mock_val_loss, lo=-9.0, hi=0.0, tolerance=0.01)
print(f"Optimal weight decay: {math.exp(optimal_log_wd):.6f}")
```
Top-k Selection: When You Do Not Need Full Sorting
When you only need the top-k elements (not the full sorted order), you can do much better than $O(n \log n)$. The key algorithms:
Quickselect: Partition-based selection. Pick a pivot, partition into "smaller" and "larger". If the pivot lands at position $k$, you are done. Otherwise recurse only on the side that contains position $k$ (unlike quicksort, which recurses on both). Average $O(n)$, worst $O(n^2)$. Same pivot-selection problem as quicksort. np.partition uses introselect, a quickselect variant with a worst-case fallback.
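An out-of-place quickselect sketch with a random pivot and a three-way split, so duplicate values do not stall it. It returns the $k$-th smallest value (0-based); the $k$-th largest is `quickselect(a, len(a) - k)`. For NumPy arrays, np.partition is the practical choice; this is only to show the algorithm.

```python
import random
from typing import List


def quickselect(a: List[float], k: int) -> float:
    """Return the k-th smallest element (0-based) of a. Average O(n)."""
    if not 0 <= k < len(a):
        raise IndexError("k out of range")
    items = list(a)  # leave the caller's list untouched
    while True:
        pivot = random.choice(items)
        lows = [x for x in items if x < pivot]
        n_pivots = sum(1 for x in items if x == pivot)
        if k < len(lows):
            items = lows                      # answer is among the smaller values
        elif k < len(lows) + n_pivots:
            return pivot                      # answer equals the pivot
        else:
            k -= len(lows) + n_pivots         # skip everything <= pivot
            items = [x for x in items if x > pivot]
```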
Heap-based top-k: Maintain a min-heap of size $k$ as you scan the array. For each new element, if it is larger than the heap minimum, pop the minimum and push the new element. After scanning all $n$ elements, the heap contains the top-$k$. Time: $O(n \log k)$. Space: $O(k)$.
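For $n = 50{,}000$ (vocabulary) and $k = 5$: $n \log_2 k \approx 50{,}000 \times 2.3$ vs $n \log_2 n \approx 50{,}000 \times 15.6$, about 6.7x fewer operations.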
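A sketch of the heap-based scan, returning `(score, index)` pairs from best to worst. heapq.nlargest implements the same idea; this version just makes the size-$k$ min-heap explicit. `topk_heap` is an illustrative name.

```python
import heapq
from typing import List, Sequence, Tuple


def topk_heap(scores: Sequence[float], k: int) -> List[Tuple[float, int]]:
    """Top-k (score, index) pairs, highest first, via a size-k min-heap: O(n log k)."""
    if k <= 0:
        return []
    heap: List[Tuple[float, int]] = []  # heap[0] is the worst of the current top-k
    for i, s in enumerate(scores):
        if len(heap) < k:
            heapq.heappush(heap, (s, i))
        elif s > heap[0][0]:
            # New element beats the current k-th best: pop the minimum, push s.
            heapq.heapreplace(heap, (s, i))
    return sorted(heap, reverse=True)  # O(k log k) to order the survivors
```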
```python
import numpy as np
import torch
import heapq
import time

def topk_comparison(n: int = 50_000, k: int = 10, reps: int = 1000):
    """Compare different top-k implementations (CPU timings)."""
    scores = np.random.randn(n).astype(np.float32)
    scores_t = torch.from_numpy(scores)

    # Method 1: Full sort then slice - O(n log n)
    start = time.perf_counter()
    for _ in range(reps):
        idx1 = np.argsort(scores)[-k:][::-1]
    t1 = (time.perf_counter() - start) / reps

    # Method 2: np.argpartition - O(n) average via introselect
    start = time.perf_counter()
    for _ in range(reps):
        partition_idx = np.argpartition(scores, -k)[-k:]
        # Note: argpartition does NOT give sorted order within the top-k
        idx2 = partition_idx[np.argsort(scores[partition_idx])[::-1]]
    t2 = (time.perf_counter() - start) / reps

    # Method 3: torch.topk - partial sort; runs on CPU here (scores_t is a CPU tensor)
    start = time.perf_counter()
    for _ in range(reps):
        vals3, idx3 = torch.topk(scores_t, k=k)
    t3 = (time.perf_counter() - start) / reps

    # Method 4: heapq.nlargest - O(n log k), but a pure-Python loop over n items
    start = time.perf_counter()
    for _ in range(reps):
        result4 = heapq.nlargest(k, range(n), key=lambda i: scores[i])
    t4 = (time.perf_counter() - start) / reps

    print(f"n={n}, k={k}")
    print(f"  np.argsort (full sort):    {t1*1e6:.1f} us")
    print(f"  np.argpartition (partial): {t2*1e6:.1f} us  ({t1/t2:.1f}x vs argsort)")
    print(f"  torch.topk:                {t3*1e6:.1f} us  ({t1/t3:.1f}x vs argsort)")
    print(f"  heapq.nlargest:            {t4*1e6:.1f} us  ({t1/t4:.1f}x vs argsort)")

topk_comparison(n=50_000, k=5)
topk_comparison(n=50_000, k=100)
```
FAISS IVF: Sorting in the Embedding Space
FAISS IVF (Inverted File Index) is one of the most widely deployed approximate nearest neighbor index structures for large-scale ML. It is essentially a clustered (partitioned) structure over embeddings.
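The construction: k-means partitions the dataset of $N$ vectors (dimension $d$) into $C$ clusters, each represented by its centroid. For each embedding, record which cluster it belongs to. At query time:
- Compute the distance from the query to all centroids: $O(Cd)$
- Sort centroids by distance: $O(C \log C)$
- Search the posting lists of the $n_{\text{probe}}$ nearest centroids: $O(n_{\text{probe}} \cdot (N/C) \cdot d)$ for balanced clusters
Total query complexity: $O(Cd + C \log C + n_{\text{probe}} (N/C) d)$. Choosing $C$ on the order of $\sqrt{N}$ balances the first and last terms and gives roughly $O(n_{\text{probe}} \sqrt{N}\, d)$ - much better than brute force $O(Nd)$.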
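A worked count with illustrative numbers chosen for this example (not benchmark results): $N = 10^6$ vectors, $C = 1000 = \sqrt{N}$ clusters, $n_{\text{probe}} = 10$, and perfectly balanced clusters of $N/C = 1000$ vectors each.

$$
\underbrace{C}_{\text{centroids}} + \underbrace{n_{\text{probe}} \cdot \frac{N}{C}}_{\text{probed lists}} = 1000 + 10 \cdot 1000 = 11{,}000 \ \text{distance computations, vs. } N = 10^6 \ \text{for brute force},
$$

about 90x fewer. Real clusters are not balanced, so the probed lists are usually somewhat larger than this idealized count.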
Within each probed cluster, IVF-Flat scans the posting list exhaustively: the speedup comes from skipping the clusters that are not probed, not from sorting or binary search inside a list. The name "Inverted File" comes from information retrieval: each cluster is like a word, and its posting list is the list of documents (embeddings) that "contain" that word (fall within that cluster's region).
```python
# Conceptual IVF implementation to understand the algorithm
import heapq
import numpy as np
from typing import Tuple

class SimpleIVFIndex:
    """
    Simplified IVF index to illustrate the algorithm.
    Use FAISS in production.
    """
    def __init__(self, n_clusters: int = 100):
        self.n_clusters = n_clusters
        self.centroids = None      # shape: (n_clusters, d)
        self.posting_lists = None  # list of lists of (embedding, original_id)

    def train(self, embeddings: np.ndarray):
        """
        k-means to find cluster centroids, then assign each embedding to a list.
        embeddings: (N, d)
        """
        from sklearn.cluster import MiniBatchKMeans
        kmeans = MiniBatchKMeans(n_clusters=self.n_clusters, random_state=42)
        kmeans.fit(embeddings)
        self.centroids = kmeans.cluster_centers_.astype(np.float32)
        # Build posting lists
        labels = kmeans.labels_
        self.posting_lists = [[] for _ in range(self.n_clusters)]
        for idx, label in enumerate(labels):
            self.posting_lists[label].append((embeddings[idx], idx))

    def search(self, query: np.ndarray, k: int = 10,
               n_probe: int = 10) -> Tuple[np.ndarray, np.ndarray]:
        """
        Approximate nearest neighbor search.
        n_probe: number of nearest clusters to search (recall vs speed tradeoff)
        """
        # Step 1: distance from query to every centroid - O(C * d)
        centroid_distances = np.linalg.norm(
            self.centroids - query[np.newaxis, :], axis=1
        )  # shape: (n_clusters,)
        # Step 2: sort centroids by distance - O(C log C)
        nearest_clusters = np.argsort(centroid_distances)[:n_probe]
        # Step 3: exhaustively scan posting lists of the nearest clusters
        candidates = []
        for cluster_id in nearest_clusters:
            for emb, orig_id in self.posting_lists[cluster_id]:
                dist = float(np.linalg.norm(emb - query))
                candidates.append((dist, orig_id))
        # Step 4: top-k among M = n_probe * N/C candidates - O(M log k)
        top_k = heapq.nsmallest(k, candidates)
        distances = np.array([d for d, _ in top_k])
        indices = np.array([i for _, i in top_k])
        return distances, indices
```
Priority Queues for Beam Search
Beam search - a standard decoding algorithm for sequence generation in machine translation, ASR, and LLMs - uses a priority queue to maintain the top-$k$ partial sequences at each step, where $k$ is the beam width.
A priority queue (binary heap) holding $n$ items supports:
- push(item, priority): $O(\log n)$
- pop(): extract the min/max priority item: $O(\log n)$
- peek(): view the min/max without removing it: $O(1)$
Python's heapq is a min-heap (smallest priority at top). For beam search where we want the highest-score sequences, use negative scores.
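A small illustration of the negation trick with heapq; the sequences are placeholder strings and the scores are made-up log-probabilities. If two scores tie, heapq falls back to comparing the payloads, so real code often inserts a counter between score and payload as a tiebreaker.

```python
import heapq

# (log-prob, sequence) pairs; higher log-prob is better.
beam_heap = []
for score, seq in [(-1.2, "a b"), (-0.4, "a c"), (-2.7, "a d")]:
    heapq.heappush(beam_heap, (-score, seq))  # push: O(log n); store the negated score

best_neg, best_seq = beam_heap[0]   # peek: O(1), the best (highest log-prob) entry
popped = heapq.heappop(beam_heap)   # pop: O(log n), removes that best entry
```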
```python
import heapq
import numpy as np
from dataclasses import dataclass, field
from typing import List

@dataclass(order=True)
class BeamHypothesis:
    """A single beam hypothesis."""
    neg_score: float  # Negative log-prob (min-heap pops lowest = best)
    token_ids: List[int] = field(compare=False)
    score: float = field(compare=False)

    @property
    def last_token(self) -> int:
        return self.token_ids[-1]

class BeamSearchDecoder:
    """
    Beam search with a priority queue.
    Complexity per decoding step (k = beam width, V = vocab size),
    excluding the model forward passes:
    - For each of k beams: add the beam score to V log-probs = O(k * V)
    - Per-beam top-k via np.argpartition: O(k * V) average
    - Global top-k over the k * k survivors via heapq.nsmallest: O(k^2 log k)
    - Over T steps: O(T * k * V)
    (A single size-k heap over all k*V candidates would be O(k*V*log k) per step.)
    """
    def __init__(self, beam_width: int = 4, vocab_size: int = 50_000,
                 max_length: int = 128, eos_token_id: int = 2):
        self.beam_width = beam_width
        self.vocab_size = vocab_size
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def decode(self, model_fn, initial_token: int = 1) -> List[BeamHypothesis]:
        """
        model_fn(token_ids) -> log_probs over vocabulary (shape: vocab_size)
        """
        # Initialize with a single hypothesis containing the start token
        initial = BeamHypothesis(neg_score=0.0, token_ids=[initial_token], score=0.0)
        active_beams = [initial]
        completed = []
        for _ in range(self.max_length):
            if not active_beams:
                break
            # Expand each active beam, keeping only its beam_width best tokens:
            # no other token of that beam can reach the global top-beam_width.
            candidates = []
            for hyp in active_beams:
                # Get log-probabilities for next token
                log_probs = model_fn(hyp.token_ids)  # shape: (vocab_size,)
                # Add this hypothesis's current score
                total_log_probs = hyp.score + log_probs  # (vocab_size,)
                # np.argpartition selects the top beam_width in O(V) average,
                # rather than sorting all vocab_size scores
                top_indices = np.argpartition(total_log_probs, -self.beam_width)[-self.beam_width:]
                for token_id in top_indices:
                    new_score = float(total_log_probs[token_id])
                    candidates.append(BeamHypothesis(
                        neg_score=-new_score,
                        token_ids=hyp.token_ids + [int(token_id)],
                        score=new_score,
                    ))
            # Select top beam_width candidates using a heap
            # O(len(candidates) * log(beam_width))
            top_candidates = heapq.nsmallest(self.beam_width, candidates)
            # Separate completed (EOS) from active
            active_beams = []
            for hyp in top_candidates:
                if hyp.last_token == self.eos_token_id:
                    completed.append(hyp)
                else:
                    active_beams.append(hyp)
        # Add any remaining active beams to completed
        completed.extend(active_beams)
        # Return sorted by score (descending)
        completed.sort(key=lambda h: h.score, reverse=True)
        return completed

# Demo with mock model
def mock_lm(token_ids: List[int]) -> np.ndarray:
    """Mock language model returning random, normalized log-probs."""
    rng = np.random.default_rng(seed=sum(token_ids))
    log_probs = rng.standard_normal(50_000).astype(np.float32)
    log_probs -= np.max(log_probs)
    log_probs -= np.log(np.sum(np.exp(log_probs)))
    return log_probs

decoder = BeamSearchDecoder(beam_width=4, vocab_size=50_000, max_length=20)
results = decoder.decode(mock_lm, initial_token=1)
print(f"Top hypothesis: tokens={results[0].token_ids[:10]}... score={results[0].score:.3f}")
```
External Sort for Large ML Datasets
When a dataset does not fit in RAM - common in large-scale pretraining - you cannot sort it in memory. External sort is the solution: sort chunks that fit in memory individually, then merge the sorted chunks.
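Phase 1 (sort runs): Read chunks of $M$ records (as many as fit in memory), sort each in memory, and write back sorted "runs." Phase 2 (merge): $K$-way merge the $K = \lceil N/M \rceil$ sorted runs using a min-heap of size $K$ (one entry per run).
Total I/O: two passes over the data (one to create the runs, one to merge them), as long as all $K$ runs can be merged at once. Total CPU time: $O(N \log M + N \log K) = O(N \log N)$ with a small constant.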
In ML: used for deduplicating pretraining corpora (sort by hash, find adjacent duplicates), sorting training data by sequence length (pack batches efficiently), and merging sharded dataset files.
```python
import heapq
import json
import os
import tempfile

def _seq_length(line: str) -> int:
    """Whitespace-token count of the 'text' field of one JSONL record."""
    return len(json.loads(line).get('text', '').split())

def external_sort_by_length(
    input_file: str,
    output_file: str,
    chunk_size_mb: int = 512
) -> None:
    """
    Sort a JSONL dataset by sequence length using external sort.
    Uses (sequence_length, line) as sort key.
    Phase 1: Sort chunks that fit in chunk_size_mb of RAM.
    Phase 2: K-way merge sorted chunk files.
    """
    chunk_size_bytes = chunk_size_mb * 1024 * 1024
    tmp_dir = tempfile.mkdtemp(prefix="sort_chunks_")
    chunk_files = []

    def flush(chunk):
        # Stable in-memory sort by length, then write one sorted run to disk
        chunk.sort(key=lambda x: x[0])
        path = os.path.join(tmp_dir, f"chunk_{len(chunk_files)}.jsonl")
        with open(path, 'w') as cf:
            for _, l in chunk:
                cf.write(l + '\n')
        chunk_files.append(path)

    # Phase 1: Sort in-memory chunks, write to temp files
    current_chunk, current_size = [], 0
    with open(input_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            current_chunk.append((_seq_length(line), line))
            current_size += len(line)
            if current_size >= chunk_size_bytes:
                flush(current_chunk)
                current_chunk, current_size = [], 0
    if current_chunk:  # handle last chunk
        flush(current_chunk)

    # Phase 2: K-way merge using a min-heap
    # Heap entries: (length, chunk_index, line); chunk_index breaks ties stably
    file_handles = [open(p, 'r') for p in chunk_files]
    try:
        heap = []
        for i, fh in enumerate(file_handles):
            line = fh.readline().strip()
            if line:
                heapq.heappush(heap, (_seq_length(line), i, line))
        with open(output_file, 'w') as out:
            while heap:
                _, i, line = heapq.heappop(heap)
                out.write(line + '\n')
                # Refill from the same chunk
                next_line = file_handles[i].readline().strip()
                if next_line:
                    heapq.heappush(heap, (_seq_length(next_line), i, next_line))
    finally:
        # Cleanup
        for fh in file_handles:
            fh.close()
        for path in chunk_files:
            os.remove(path)
        os.rmdir(tmp_dir)
    print(f"Sorted {len(chunk_files)} chunks into {output_file}")
```
torch.topk vs numpy.partition: Practical Comparison
```python
import heapq
import time

import numpy as np
import torch

def compare_topk_methods(n: int = 100_000, k: int = 100):
    """
    Comprehensive comparison of top-k methods for ML use cases (CPU timings).
    """
    data_np = np.random.randn(n).astype(np.float32)
    data_t = torch.from_numpy(data_np)
    n_reps = 500
    # 1. Full argsort (baseline, O(n log n))
    start = time.perf_counter()
    for _ in range(n_reps):
        idx = np.argsort(data_np)[-k:][::-1]
    argsort_ms = (time.perf_counter() - start) * 1000 / n_reps
    # 2. np.argpartition (O(n) average, unordered top-k)
    start = time.perf_counter()
    for _ in range(n_reps):
        part_idx = np.argpartition(data_np, -k)[-k:]
    partition_ms = (time.perf_counter() - start) * 1000 / n_reps
    # 3. np.argpartition + sort the k results (O(n) + O(k log k))
    start = time.perf_counter()
    for _ in range(n_reps):
        part_idx = np.argpartition(data_np, -k)[-k:]
        top_k_sorted = part_idx[np.argsort(data_np[part_idx])[::-1]]
    partition_sorted_ms = (time.perf_counter() - start) * 1000 / n_reps
    # 4. torch.topk (partial sort, sorted output; CPU tensor here)
    start = time.perf_counter()
    for _ in range(n_reps):
        vals, idx = torch.topk(data_t, k=k)
    topk_ms = (time.perf_counter() - start) * 1000 / n_reps
    # 5. heapq.nlargest (O(n log k), pure Python)
    start = time.perf_counter()
    for _ in range(n_reps):
        result = heapq.nlargest(k, enumerate(data_np), key=lambda x: x[1])
    heapq_ms = (time.perf_counter() - start) * 1000 / n_reps
    print(f"\nTop-{k} from n={n:,}:")
    print(f"  np.argsort (full sort):      {argsort_ms:.3f} ms (baseline)")
    print(f"  np.argpartition (unordered): {partition_ms:.3f} ms ({argsort_ms/partition_ms:.1f}x vs argsort)")
    print(f"  argpartition + sort k:       {partition_sorted_ms:.3f} ms ({argsort_ms/partition_sorted_ms:.1f}x vs argsort)")
    print(f"  torch.topk:                  {topk_ms:.3f} ms ({argsort_ms/topk_ms:.1f}x vs argsort)")
    print(f"  heapq.nlargest:              {heapq_ms:.3f} ms ({argsort_ms/heapq_ms:.1f}x vs argsort)")
    # Recommendation guide
    print("\nRecommendation:")
    if k < n // 100:
        print(f"  k={k} << n={n}: use torch.topk (fast + sorted output)")
    elif k < n // 10:
        print(f"  k={k} < n/10: use np.argpartition + sort k")
    else:
        print(f"  k={k} close to n: full argsort may be comparable")

compare_topk_methods(n=50_000, k=5)      # beam search vocabulary selection
compare_topk_methods(n=100_000, k=100)   # retrieval top-100
compare_topk_methods(n=10_000, k=1000)   # dense retrieval re-ranking
```
Diagram: Sorting and Search Algorithm Selection
Production Engineering Notes
1. Beam search performance in production LLMs. The dominant cost in beam search is not the sort/heap - it is the model forward pass. But the cumulative overhead of poorly chosen top-k selection (full sort vs partial sort) adds up over thousands of tokens. For a vocabulary of 50,000 and beam width 4, replace sorted(enumerate(logits), key=lambda x: -x[1])[:4] with torch.topk(logits, k=4). The speedup is real and requires zero model changes.
2. Argsort vs argpartition for rankings. np.argsort returns the full sorted order. np.argpartition returns the top-k in unspecified order. If you need the top-k items in sorted order (e.g., generating a ranked list for display), use argpartition to identify the candidates and then sort only those items. This is $O(n + k \log k)$ vs $O(n \log n)$ - a significant saving when $n$ is large and $k \ll n$.
3. Sorting stability matters for reproducibility. np.sort with kind='quicksort' (the default) is not stable. If two elements have the same sort key, their relative order is unspecified and may differ across NumPy versions or CPU-specific code paths. For reproducible experiments, use kind='stable'. This matters when sorting training samples by length - two samples with identical token count should stay in a consistent order for deterministic data loading.
4. GPU sort is almost always faster than CPU sort for tensors already on the GPU. If your data is a PyTorch tensor on GPU, do not move it to CPU to sort with NumPy. torch.sort() and torch.topk() on GPU are highly optimized (radix-sort and radix-select kernels under the hood). The CPU-GPU transfer and synchronization overhead alone can dominate. Keep sorts on the device where the data lives.
5. FAISS index selection depends on dataset size and recall requirement. Flat (brute-force) when the dataset is small enough for exact search to meet your latency budget. IVF-Flat for larger datasets with moderate recall requirements. HNSW for the fastest query latency. IVF-PQ for memory-constrained deployments at billion scale. Calibrate nprobe (IVF) or efSearch (HNSW) against your recall-latency target on a validation query set before deployment.
:::danger The Sorting Stability Trap in Multi-Key Ranking
If you sort a ranked list by score, then re-sort by relevance, and your sort is unstable, items with equal relevance scores may have their original ranking scrambled. This breaks re-ranking pipelines silently - the output looks right but the ordering is non-deterministic. Always use a stable sort in ranking pipelines (kind='stable' in NumPy; torch.sort(..., stable=True) in PyTorch, whose default is not guaranteed to be stable), or combine all sort keys into a single sort key tuple.
:::
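A NumPy sketch of the safe patterns, using made-up arrays: stable descending order via a negated key (reversing an ascending result also reverses the order of ties), and multi-key ranking either as two stable passes or with np.lexsort, which NumPy documents as a stable indirect sort.

```python
import numpy as np

# Sequence lengths with ties at 5 and 12.
lengths = np.array([12, 5, 12, 7, 5, 12])
# Ascending and stable: equal lengths keep their original index order.
order = np.argsort(lengths, kind="stable")
# Descending while keeping ties in original order: negate the key.
desc_stable = np.argsort(-lengths, kind="stable")
# Reversing the ascending result also reverses the order of ties.
desc_reversed = order[::-1]

# Multi-key ranking: primary key = relevance (desc), secondary = score (desc).
relevance = np.array([2, 1, 2, 1, 2])
score = np.array([0.7, 0.9, 0.8, 0.6, 0.7])
# Two stable passes: sort by the secondary key first, then by the primary key.
by_score = np.argsort(-score, kind="stable")
order_two_pass = by_score[np.argsort(-relevance[by_score], kind="stable")]
# Same result in one call: np.lexsort treats the LAST key as the primary key.
order_lexsort = np.lexsort((-score, -relevance))
```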
:::warning Plain Quickselect Has an O(n^2) Worst Case
np.argpartition uses introselect (quickselect with a fallback to median-of-medians), which NumPy documents as $O(n)$ in the worst case, so it is protected against quickselect's $O(n^2)$ behavior on adversarially ordered data. A hand-written quickselect without such a fallback is not. In production, validate performance on your actual data distribution, not just synthetic benchmarks. If you observe outlier latency on specific inputs, torch.topk (which uses a different algorithm) is worth benchmarking as an alternative.
:::
:::tip Use np.partition / np.argpartition When You Do Not Need Sorted Order
A very common anti-pattern: np.argsort(scores)[-k:] to get the top-k indices. This sorts the entire array to find $k$ elements. Replace with top_k_idx = np.argpartition(scores, -k)[-k:]. The result is the top-k elements in arbitrary order - if you need them sorted in descending order, use top_k_idx[np.argsort(scores[top_k_idx])[::-1]]. Net result: $O(n + k \log k)$ instead of $O(n \log n)$.
:::
Interview Q&A
Q1: What is the difference between np.argsort and np.argpartition and when would you use each?
np.argsort returns indices that sort the array in $O(n \log n)$. np.argpartition(a, k) returns indices that partition the array so that the element at position $k$ is the one that would be there if the array were sorted, elements before it are no larger, and elements after it are no smaller - in $O(n)$ time (introselect). The catch: argpartition does not guarantee any order within the group before position $k$ or the group after it.
Use argsort when you need the full sorted order. Use argpartition when you need the top-k elements and their relative ordering does not matter, or when you want to sort only the candidates afterward (paying $O(n + k \log k)$ instead of $O(n \log n)$).
ML use cases for argpartition: finding the k nearest neighbors in a small distance array, selecting the top-k logits before sampling (then you softmax only those k values), getting the hardest negative examples in metric learning.
Q2: Why is timsort better than quicksort for real-world data, and when does it matter in ML?
Timsort detects "natural runs" - already-sorted subsequences - in the input and merges them rather than sorting from scratch. Timsort is $O(n)$ for fully sorted input, roughly $O(n + n \log r)$ for input made of $r$ runs, and $O(n \log n)$ for random input. Quicksort is $O(n \log n)$ on average (and $O(n^2)$ in the worst case) no matter how much order the input already has.
In ML: when you sort training samples by sequence length each epoch and lengths do not change, re-sorting is $O(n)$ with timsort vs $O(n \log n)$ with a non-adaptive sort. When you concatenate result sets from multiple retrieval sources that are each already sorted by score, timsort effectively just merges the runs. When sorting DataFrame columns that are nearly in order (timestamps, incrementally added IDs), a timsort-based stable sort can be several times faster than non-adaptive alternatives (in pandas, pass kind='stable' to sort_values, whose default is quicksort).
Q3: How would you implement top-k sampling in a language model decoder efficiently?
Top-k sampling: keep only the k highest-probability tokens and sample from their renormalized distribution. A direct implementation with torch.topk:
```python
import torch

def top_k_sample(logits: torch.Tensor, k: int, temperature: float = 1.0) -> int:
    """Sample a token id from the top-k entries of a 1-D logits tensor."""
    # Scale by temperature
    scaled = logits / temperature
    # Get top-k values and indices without a full sort
    top_vals, top_idx = torch.topk(scaled, k=k)
    # Softmax only over the k values
    probs = torch.softmax(top_vals, dim=-1)
    # Sample a position within the top-k, then map it back to a vocabulary id
    sampled = torch.multinomial(probs, num_samples=1)
    return top_idx[sampled].item()
```
Key optimization: only compute softmax over the $k$ retained logits, not the full vocabulary. Softmax over $k = 50$ values is 1000x cheaper than softmax over the full vocabulary of 50,000.
Q4: How does ternary search work and what ML problem is it useful for?
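Ternary search finds the minimum (or maximum) of a unimodal function $f$ on an interval $[l, r]$. At each step, evaluate $f$ at the two trisection points $m_1 = l + (r - l)/3$ and $m_2 = r - (r - l)/3$. If $f(m_1) < f(m_2)$, the minimum is in $[l, m_2]$; otherwise it is in $[m_1, r]$. Each step discards one third of the interval. Reaching tolerance $\epsilon$ requires $O(\log_{3/2}((r - l)/\epsilon))$ iterations, each with two evaluations of $f$.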
ML use cases: finding the optimal learning rate when you can assume the loss-vs-LR curve is unimodal (it usually is: too small an LR gives slow convergence and too large an LR gives divergence). Finding the optimal temperature for calibration (cross-entropy vs temperature is unimodal). Finding the optimal $k$ in k-NN classifiers. Crucially, ternary search only works for unimodal functions. If the function has multiple local minima, use random search or Bayesian optimization.
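Here is a minimal sketch of ternary search applied to temperature calibration. The data is synthetic (an assumption for illustration): labels are drawn from softmax(logits / 2), so the negative log-likelihood should be minimized near T = 2. The helper names ternary_search_min and nll are illustrative, not from a library.

```python
from typing import Callable

import numpy as np

def ternary_search_min(f: Callable[[float], float], lo: float, hi: float,
                       tol: float = 1e-4, max_iter: int = 200) -> float:
    """Approximate minimizer of a unimodal function f on [lo, hi]."""
    for _ in range(max_iter):
        if hi - lo < tol:
            break
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if f(m1) < f(m2):
            hi = m2  # the minimum cannot lie in (m2, hi]
        else:
            lo = m1  # the minimum cannot lie in [lo, m1)
    return (lo + hi) / 2

# Synthetic calibration set (assumption): the true temperature is 2.0
rng = np.random.default_rng(0)
n, n_classes, true_T = 5000, 10, 2.0
logits = rng.normal(scale=3.0, size=(n, n_classes))
z = logits / true_T
probs = np.exp(z - z.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)
# Inverse-CDF sampling of one label per row; the clip guards against cumsum rounding
u = rng.random((n, 1))
labels = np.minimum((probs.cumsum(axis=1) < u).sum(axis=1), n_classes - 1)

def nll(T: float) -> float:
    """Mean negative log-likelihood of the labels under softmax(logits / T)."""
    s = logits / T
    s = s - s.max(axis=1, keepdims=True)
    log_probs = s - np.log(np.exp(s).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(n), labels].mean())

best_T = ternary_search_min(nll, 0.05, 10.0)
print(f"Calibrated temperature: {best_T:.3f}")
```

Each iteration costs two evaluations of nll and shrinks the interval to 2/3 of its width, so narrowing [0.05, 10] down to 1e-4 takes roughly 30 iterations.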
Q5: Describe how beam search uses a priority queue and what its time complexity is.
Beam search maintains a beam of $B$ partial sequences. At each step:
- Expand each of the $B$ beams by appending every possible next token from a vocabulary of size $V$: $B \times V$ candidate sequences
- Score each candidate (existing beam score + log-probability of the new token)
- Retain only the top-$B$ candidates for the next step
The priority queue (a min-heap of size $B$) is used in step 3. As you score each of the $B \times V$ candidates, push it while the heap holds fewer than $B$ items. After that, if the candidate is better than the worst item in the heap (the heap's root), push it and pop the worst. This is $O(BV \log B)$ per step. Over $T$ decoding steps, total complexity is $O(TBV \log B)$.
In practice, the $V$ factor is handled efficiently with torch.topk per beam rather than a Python heap (fewer Python object allocations). Real LLM decoders also use a batched KV cache so that the full context is not recomputed for each beam at each step.
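The heap-based retention step can be sketched with Python's heapq. This is a toy illustration under stated assumptions. The "model" is a hypothetical lookup table of next-token probabilities over a 3-token vocabulary, chosen so that greedy decoding (beam width 1) and beam width 2 disagree.

```python
import heapq
import math
from typing import Callable, Dict, List, Sequence, Tuple

Beam = Tuple[float, Tuple[int, ...]]  # (cumulative log-prob, token ids)

def beam_step(beams: List[Beam],
              next_log_probs: Callable[[Tuple[int, ...]], Sequence[float]],
              beam_width: int) -> List[Beam]:
    """Expand every beam by every token and keep the best beam_width candidates."""
    heap: List[Beam] = []  # min-heap on score: heap[0] is the worst candidate kept so far
    for score, seq in beams:
        for token, log_p in enumerate(next_log_probs(seq)):
            candidate = (score + log_p, seq + (token,))
            if len(heap) < beam_width:
                heapq.heappush(heap, candidate)
            elif candidate[0] > heap[0][0]:
                heapq.heapreplace(heap, candidate)  # pop the worst, push the new one
    return sorted(heap, reverse=True)  # best first

def beam_search(next_log_probs: Callable[[Tuple[int, ...]], Sequence[float]],
                beam_width: int, steps: int) -> List[Beam]:
    beams: List[Beam] = [(0.0, ())]
    for _ in range(steps):
        beams = beam_step(beams, next_log_probs, beam_width)
    return beams

# Hypothetical toy model: next-token probabilities depend only on the prefix
TABLE: Dict[Tuple[int, ...], List[float]] = {
    (): [0.5, 0.4, 0.1],
    (0,): [0.4, 0.3, 0.3],
    (1,): [0.9, 0.05, 0.05],
    (2,): [1 / 3, 1 / 3, 1 / 3],
}

def toy_model(seq: Tuple[int, ...]) -> List[float]:
    return [math.log(p) for p in TABLE[seq]]

for width in (1, 2):
    score, seq = beam_search(toy_model, beam_width=width, steps=2)[0]
    print(f"beam width {width}: best {seq} with probability {math.exp(score):.2f}")
# beam width 1: best (0, 0) with probability 0.20
# beam width 2: best (1, 0) with probability 0.36
```

Each beam_step call performs $B \times V$ heap operations of cost $O(\log B)$, which matches the per-step bound above. Greedy decoding commits to token 0 and misses the higher-probability sequence (1, 0) that a beam of width 2 keeps alive.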
Q6: What is the difference between IVF-Flat and HNSW in FAISS and when would you use each?
IVF-Flat: k-means partitions the dataset into nlist clusters. At query time, find the nprobe nearest clusters (by centroid distance), then search exhaustively within those clusters. Trade-off: a larger nprobe gives higher recall but slower queries. Construction is fast ($O(N \cdot \text{nlist} \cdot d)$ for assigning $N$ vectors of dimension $d$ after k-means training). Memory is just the raw vectors plus cluster assignments. Best for datasets where you want to control recall vs speed via nprobe.
HNSW (Hierarchical Navigable Small World): builds a multi-layer graph where higher layers are sparse long-range connections and the bottom layer is a dense local neighborhood graph. A query traverses from the coarse top layer to the fine bottom layer, following nearest-neighbor links. Average query complexity is roughly $O(\log N)$. Memory is higher than IVF ($O(N \cdot M)$ graph edges, where $M$ is the number of links per node). Construction is $O(N \log N)$, which is slow for large datasets. Best for the lowest possible query latency with high recall (95-99%) at moderate memory cost. Used in production for applications like Discord's message search, where latency matters more than memory.
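Here is a small NumPy sketch that makes the IVF-Flat idea concrete without depending on FAISS. It illustrates the technique and is not FAISS's implementation. Plain Lloyd's k-means builds nlist inverted lists, and a query scans only the nprobe closest lists. In FAISS itself, the corresponding index is faiss.IndexIVFFlat, with nprobe set on the index before searching. The helper names here are hypothetical.

```python
from typing import List, Tuple

import numpy as np

def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between rows of a and rows of b."""
    return (a ** 2).sum(1)[:, None] - 2 * a @ b.T + (b ** 2).sum(1)[None, :]

def build_ivf(xb: np.ndarray, nlist: int, n_iter: int = 10,
              seed: int = 0) -> Tuple[np.ndarray, List[np.ndarray]]:
    """Train nlist centroids with Lloyd's k-means and build one inverted list per centroid."""
    rng = np.random.default_rng(seed)
    centroids = xb[rng.choice(len(xb), size=nlist, replace=False)].copy()
    for _ in range(n_iter):
        assign = sq_dists(xb, centroids).argmin(axis=1)
        for c in range(nlist):
            members = xb[assign == c]
            if len(members):  # leave empty clusters where they are
                centroids[c] = members.mean(axis=0)
    assign = sq_dists(xb, centroids).argmin(axis=1)
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists

def search_ivf(xq: np.ndarray, xb: np.ndarray, centroids: np.ndarray,
               lists: List[np.ndarray], k: int, nprobe: int) -> np.ndarray:
    """Approximate k-NN for one query: scan only the nprobe closest inverted lists."""
    d_centroids = sq_dists(xq[None, :], centroids)[0]
    probe = np.argpartition(d_centroids, nprobe - 1)[:nprobe]
    candidates = np.concatenate([lists[c] for c in probe])
    if len(candidates) == 0:
        return candidates
    d = sq_dists(xq[None, :], xb[candidates])[0]
    kk = min(k, len(candidates))
    top = np.argpartition(d, kk - 1)[:kk]  # kk nearest candidates, unordered
    return candidates[top[np.argsort(d[top])]]  # nearest first

rng = np.random.default_rng(0)
xb = rng.normal(size=(2000, 16))
xq = rng.normal(size=16)
centroids, lists = build_ivf(xb, nlist=32)

exact = np.argsort(sq_dists(xq[None, :], xb)[0])[:10]
for nprobe in (1, 4, 32):
    approx = search_ivf(xq, xb, centroids, lists, k=10, nprobe=nprobe)
    recall = len(set(approx.tolist()) & set(exact.tolist())) / 10
    print(f"nprobe={nprobe:>2}: recall@10 = {recall:.1f}")
```

With nprobe equal to nlist, the search degenerates to exhaustive (exact) search. Smaller nprobe values scan fewer vectors at the cost of recall, which is the trade-off described above.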
Sorting in NumPy and PyTorch: Practical Reference
Understanding which NumPy and PyTorch sort functions to call, and when, saves significant time when building ML pipelines.
import numpy as np
import torch
# --- NumPy sorting reference ---
arr = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0])
# 1. np.sort - returns sorted copy
sorted_arr = np.sort(arr) # ascending
sorted_desc = np.sort(arr)[::-1] # descending (reverse view)
sorted_stable = np.sort(arr, kind='stable') # timsort, stable
# 2. np.argsort - returns indices that sort the array
idx = np.argsort(arr) # O(n log n), returns indices
idx_desc = np.argsort(arr)[::-1] # descending indices
# 3. np.argpartition - O(n) partial sort, top-k unordered
# Get indices of k LARGEST elements (unordered within the k)
k = 3
top_k_idx_unordered = np.argpartition(arr, -k)[-k:]
# To get them in sorted order:
top_k_idx = top_k_idx_unordered[np.argsort(arr[top_k_idx_unordered])[::-1]]
print(f"Top-{k} values: {arr[top_k_idx]}")
# 4. Multi-key sort: sort rows of a 2D array by multiple columns
data = np.array([[3, 2], [1, 5], [3, 1], [1, 3], [2, 4]])
# Sort by column 0 first, then column 1 (lexicographic)
sorted_idx = np.lexsort((data[:, 1], data[:, 0])) # NOTE: last key is primary
data_sorted = data[sorted_idx]
print(f"Multi-key sorted:\n{data_sorted}")
# 5. Argsort for embedding retrieval - sort by similarity score
queries = np.random.randn(5, 768) # 5 query embeddings
corpus = np.random.randn(1000, 768) # 1000 corpus embeddings (normalized for cosine)
queries /= np.linalg.norm(queries, axis=1, keepdims=True)
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)
# Compute all similarities: (5, 1000)
similarities = queries @ corpus.T
# Top-10 per query using argpartition (faster than argsort for k << n)
for q_idx in range(5):
    top_k = np.argpartition(similarities[q_idx], -10)[-10:]
    top_k_sorted = top_k[np.argsort(similarities[q_idx][top_k])[::-1]]
    print(f"Query {q_idx}: top similarity = {similarities[q_idx][top_k_sorted[0]]:.4f}")
# --- PyTorch sorting reference ---
t = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0])
# torch.sort - returns (values, indices); not guaranteed stable unless stable=True
vals, idx = torch.sort(t)
vals_desc, idx_desc = torch.sort(t, descending=True)
# torch.argsort - just the indices
idx = torch.argsort(t)
idx_desc = torch.argsort(t, descending=True)
# torch.topk - O(n log k), returns (values, indices) in sorted order
top_vals, top_idx = torch.topk(t, k=3) # top-3 largest
bot_vals, bot_idx = torch.topk(t, k=3, largest=False) # top-3 smallest
# Batched top-k along the last dimension (runs on the GPU when the tensor is on a CUDA device)
batch_scores = torch.randn(32, 50000) # 32 batch, 50k vocab
top_k_vals, top_k_idx = torch.topk(batch_scores, k=5, dim=-1)
# Returns shape (32, 5) - top-5 per batch element
# torch.msort - sort along first dimension
matrix = torch.randn(4, 3)
col_sorted = torch.msort(matrix) # sorts each column independently
# Stable vs unstable sort
# torch.sort is NOT guaranteed stable by default; pass stable=True so equal elements keep their original order
vals_stable, idx_stable = torch.sort(t, stable=True)
print("\nPyTorch sort examples:")
print(f"torch.topk(t, k=3): values={top_vals}, indices={top_idx}")
# --- Sorting datasets by sequence length for efficient batching ---
def bucket_by_length(texts, tokenizer, n_buckets=8):
    """
    Sort and bucket texts by tokenized length for efficient training.
    Reduces padding waste by grouping similar-length sequences together.
    Assumes tokenizer(t) returns a sequence of tokens (e.g. a list of ids).
    Returns (sorted_texts, bucket_boundaries).
    """
    lengths = np.array([len(tokenizer(t)) for t in texts])
    # Sort by length (use argsort to track original indices)
    sorted_idx = np.argsort(lengths, kind='stable')
    sorted_texts = [texts[i] for i in sorted_idx]
    sorted_lengths = lengths[sorted_idx]
    # Compute bucket boundaries (the last bucket also takes the remainder)
    total = len(texts)
    bucket_size = total // n_buckets
    starts = [i * bucket_size for i in range(n_buckets)]
    boundaries = [sorted_lengths[s] for s in starts]
    # Padding waste when each bucket is padded to its own longest sequence
    buckets = np.split(sorted_lengths, starts[1:])
    padded_tokens = sum(len(b) * b.max() for b in buckets if len(b))
    avg_padding_waste = 1.0 - sorted_lengths.sum() / padded_tokens
    print(f" Average padding waste after bucketing: {avg_padding_waste:.1%}")
    return sorted_texts, boundaries
def batch_padding_waste(lengths, batch_size=32):
    """Mean fraction of padded slots when each full batch is padded to its own max."""
    wastes = []
    for i in range(0, len(lengths) - batch_size + 1, batch_size):
        batch = lengths[i:i + batch_size]
        wastes.append(1.0 - batch.mean() / batch.max())
    return float(np.mean(wastes))

# Demonstrate the padding savings
lengths = np.concatenate([
    np.random.randint(10, 50, 500),    # short sentences
    np.random.randint(100, 512, 500),  # long documents
])
np.random.shuffle(lengths)  # a shuffled dataset mixes short and long sequences
# Without sorting: each batch contains mixed lengths, padded to the batch max
naive_waste = batch_padding_waste(lengths)
# With sorting: the batch max is much closer to the batch mean
sorted_waste = batch_padding_waste(np.sort(lengths))
print("\nPadding waste:")
print(f" Without length sorting: {naive_waste:.1%}")
print(f" With length sorting: {sorted_waste:.1%}")
Summary
Sorting and search are the workhorses of efficient ML pipelines. The key decisions:
- For full sorted order of floats: np.sort (kind='stable' for a stable timsort) or torch.sort
- For top-k elements when $k \ll n$: torch.topk for tensors, np.argpartition for arrays
- For unimodal hyperparameter search: ternary search in $O(\log(1/\epsilon))$ evaluations
- For finding values in sorted arrays: bisect.bisect_left and bisect.bisect_right in $O(\log n)$
- For approximate nearest neighbor at scale: FAISS IVF (controllable recall/speed) or HNSW (lowest latency)
- For beam search: maintain a min-heap of size $B$ over partial sequences; replace the full vocabulary sort with torch.topk
- For large datasets that do not fit in RAM: external sort (sort chunks, k-way merge with a heap)
- For integer keys with bounded range: counting sort or radix sort for linear-time sorting ($O(n + r)$ for counting sort over $r$ possible keys)
- For GPU operations: keep sorts on-device with torch.sort or torch.topk; never move tensors to the CPU just to sort them
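The external-sort recipe from the list above (sort chunks that fit in memory, then k-way merge them with a heap) can be sketched with the standard library. heapq.merge performs the heap-based k-way merge lazily, so only one value per sorted run is held in memory at a time. The file format (one float per line) and the function names are assumptions for illustration.

```python
import heapq
import os
import random
import tempfile
from typing import Iterable, Iterator, List

def write_sorted_run(chunk: List[float], directory: str) -> str:
    """Sort one in-memory chunk and spill it to a temporary file, one value per line."""
    chunk.sort()
    fd, path = tempfile.mkstemp(dir=directory, suffix=".run")
    with os.fdopen(fd, "w") as f:
        f.writelines(f"{v!r}\n" for v in chunk)  # repr round-trips floats exactly
    return path

def read_run(path: str) -> Iterator[float]:
    """Stream a sorted run back from disk."""
    with open(path) as f:
        for line in f:
            yield float(line)

def external_sort(values: Iterable[float], chunk_size: int, out_path: str) -> None:
    """Sort a stream larger than memory: sorted runs on disk, then a k-way heap merge."""
    with tempfile.TemporaryDirectory() as tmp:
        runs: List[str] = []
        chunk: List[float] = []
        for v in values:
            chunk.append(v)
            if len(chunk) == chunk_size:
                runs.append(write_sorted_run(chunk, tmp))
                chunk = []
        if chunk:
            runs.append(write_sorted_run(chunk, tmp))
        # heapq.merge keeps one head element per run in a heap of size len(runs)
        with open(out_path, "w") as out:
            for v in heapq.merge(*(read_run(p) for p in runs)):
                out.write(f"{v!r}\n")

# Usage: 10,000 values sorted in chunks of 1,000 (10 runs on disk)
random.seed(0)
data = [random.random() for _ in range(10_000)]
out_file = os.path.join(tempfile.gettempdir(), "external_sort_demo.txt")
external_sort(data, chunk_size=1_000, out_path=out_file)
with open(out_file) as f:
    result = [float(line) for line in f]
os.remove(out_file)
print(result == sorted(data))  # → True
```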
The single most impactful optimization in practice: replacing np.argsort(scores)[-k:] with np.argpartition(scores, -k)[-k:] in hot paths, then sorting only those $k$ indices if their order matters. This alone can make retrieval and beam search 5-10x faster when $k$ is small relative to the array size.
