Sorting and Searching
Comparison sorts, counting/radix sort, binary search variations, quickselect, and merge sort for external data - with ML-relevant problems (top-K selection, nearest neighbor search) and full Python solutions.
Reading time: ~45 min | Interview relevance: Very High | Roles: MLE, Data Engineer, AI Infrastructure
The Real Interview Moment
You are forty minutes into a virtual onsite at a large language model company. The interviewer, a senior ML infrastructure engineer, shares her screen and says: "We have a production model that generates 10 million confidence scores per hour. We need to efficiently find the top 1000 highest-scoring predictions without sorting the entire stream. The data does not fit in memory all at once - it arrives in chunks of 100,000. How would you design this?"
You pause for two seconds. You know that sorting all 10 million scores would be O(n log n), about 230 million operations. But you only need the top 1000. A min-heap of size 1000 would process each element in O(log 1000), giving you O(n log K) total - and it naturally handles streaming data because you process one chunk at a time, maintaining only the heap across chunks.
You explain this, then write the code. The interviewer pushes further: "What if we need the top 1000 from each of 50 model shards, then merge?" You recognize this as a K-way merge problem - another classic pattern. You reach for heapq.merge and explain the external sort paradigm.
This chapter teaches you the sorting and searching patterns that separate candidates who can handle real-scale ML data from those who cannot.
Sorting Algorithm Overview
Comparison Sorts
Merge Sort
Merge sort is the most important sorting algorithm for ML interviews because it is stable, predictable (always O(n log n)), and forms the basis of external sorting for massive datasets.
def merge_sort(arr):
    """Merge sort - stable, O(n log n) guaranteed.
    ML context: Sorting training examples by loss for curriculum learning,
    sorting predictions by confidence for threshold calibration.
    Time: O(n log n), Space: O(n)
    """
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return _merge(left, right)

def _merge(left, right):
    """Merge two sorted arrays into one sorted array."""
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:  # <= makes it stable
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result
Why merge sort over quicksort in ML contexts? Merge sort is stable (equal elements maintain their original order), which matters when sorting predictions that share the same confidence score but have different metadata. It also has guaranteed O(n log n) worst-case, while quicksort degrades to O(n^2) on adversarial inputs. Python's built-in sorted() uses Timsort, a hybrid of merge sort and insertion sort.
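To see stability in practice: Python's sorted() is guaranteed stable, even with reverse=True, so predictions that tie on confidence keep their arrival order. The prediction IDs and scores below are made-up illustration data.

```python
# Hypothetical predictions: (prediction_id, confidence)
predictions = [("p1", 0.90), ("p2", 0.75), ("p3", 0.90), ("p4", 0.75), ("p5", 0.95)]

# Sort by confidence, highest first. reverse=True preserves stability:
# p1 stays ahead of p3 and p2 stays ahead of p4 because they tie.
ranked = sorted(predictions, key=lambda p: p[1], reverse=True)

print([pid for pid, _ in ranked])  # ['p5', 'p1', 'p3', 'p2', 'p4']
```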
Quicksort
import random
def quicksort(arr):
    """Quicksort with random pivot - O(n log n) average.
    In-place version (Lomuto partition scheme).
    Time: O(n log n) average, O(n^2) worst case
    Space: O(log n) average (recursion stack)
    """
    _quicksort_helper(arr, 0, len(arr) - 1)
    return arr

def _quicksort_helper(arr, low, high):
    if low < high:
        pivot_idx = _partition(arr, low, high)
        _quicksort_helper(arr, low, pivot_idx - 1)
        _quicksort_helper(arr, pivot_idx + 1, high)

def _partition(arr, low, high):
    """Lomuto partition with random pivot."""
    # Random pivot avoids O(n^2) on sorted input
    rand_idx = random.randint(low, high)
    arr[rand_idx], arr[high] = arr[high], arr[rand_idx]
    pivot = arr[high]
    i = low - 1  # Boundary of elements <= pivot
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1
If you implement quicksort and always choose the first or last element as the pivot without randomization, the interviewer will immediately ask: "What happens on sorted input?" The answer is O(n^2). If you cannot fix this by adding random pivot selection or median-of-three, it signals you do not understand the algorithm's failure modes - a critical gap for production ML systems where data often arrives partially sorted.
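A minimal sketch of the median-of-three fix, assuming a Lomuto partition like the one above. The names median_of_three_index and quicksort_mo3 are illustrative. Recursing into the smaller partition and looping on the larger one is an extra safeguard that keeps the stack at O(log n). Median-of-three can still be driven to O(n^2) by crafted inputs or by many duplicate keys, so random pivots remain the simpler default.

```python
def median_of_three_index(arr, low, high):
    """Index of the median of arr[low], arr[mid], arr[high]."""
    mid = (low + high) // 2
    candidates = sorted([(arr[low], low), (arr[mid], mid), (arr[high], high)])
    return candidates[1][1]

def quicksort_mo3(arr, low=0, high=None):
    """In-place quicksort using a median-of-three pivot (Lomuto partition)."""
    if high is None:
        high = len(arr) - 1
    while low < high:
        # Move the median of first/middle/last to the end, then partition
        p = median_of_three_index(arr, low, high)
        arr[p], arr[high] = arr[high], arr[p]
        pivot = arr[high]
        i = low - 1
        for j in range(low, high):
            if arr[j] <= pivot:
                i += 1
                arr[i], arr[j] = arr[j], arr[i]
        arr[i + 1], arr[high] = arr[high], arr[i + 1]
        p = i + 1
        # Recurse into the smaller side, loop on the larger one
        if p - low < high - p:
            quicksort_mo3(arr, low, p - 1)
            low = p + 1
        else:
            quicksort_mo3(arr, p + 1, high)
            high = p - 1
    return arr

# Already-sorted input: the middle element is the true median every time
print(quicksort_mo3(list(range(10))))
```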
Insertion Sort
def insertion_sort(arr):
    """Insertion sort - efficient for small or nearly sorted arrays.
    CPython's Timsort uses binary insertion sort to extend short runs to a
    minimum run length of 32-64 elements. This is why understanding it
    matters even though you would never use it alone on large data.
    Time: O(n^2) worst, O(n) best (nearly sorted)
    Space: O(1)
    """
    for i in range(1, len(arr)):
        key = arr[i]
        j = i - 1
        # Shift larger elements one slot right to open a gap for key
        while j >= 0 and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr
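Why "nearly sorted" means fast: every shift in the inner loop fixes exactly one inversion (a pair i < j with arr[i] > arr[j]), so insertion sort runs in O(n + inversions). This instrumented copy, with the illustrative name insertion_sort_count_shifts, makes that visible.

```python
def insertion_sort_count_shifts(arr):
    """Insertion sort on a copy; returns (sorted_list, number_of_shifts).
    Each shift removes one inversion, so shifts == inversion count."""
    arr = list(arr)  # work on a copy
    shifts = 0
    for i in range(1, len(arr)):
        key = arr[i]
        j = i - 1
        while j >= 0 and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
            shifts += 1
        arr[j + 1] = key
    return arr, shifts

nearly_sorted = [1, 2, 4, 3, 5, 6, 8, 7]  # two adjacent swaps
print(insertion_sort_count_shifts(nearly_sorted)[1])          # 2
print(insertion_sort_count_shifts(list(range(8, 0, -1)))[1])  # 28 = 8 * 7 / 2
```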
Heap Sort
def heap_sort(arr):
    """Heap sort - O(n log n) guaranteed, in-place, but not stable.
    ML context: Understanding heap sort helps with priority queue
    problems (top-K, streaming median, beam search).
    Time: O(n log n), Space: O(1)
    """
    n = len(arr)
    # Build max-heap (heapify from bottom up)
    for i in range(n // 2 - 1, -1, -1):
        _sift_down(arr, n, i)
    # Extract elements one by one
    for i in range(n - 1, 0, -1):
        arr[0], arr[i] = arr[i], arr[0]  # Move max to end
        _sift_down(arr, i, 0)  # Restore heap property
    return arr

def _sift_down(arr, size, root):
    """Sift down element at root to maintain max-heap property.
    Iterative, so heap sort really uses O(1) extra space (no recursion stack).
    """
    while True:
        largest = root
        left = 2 * root + 1
        right = 2 * root + 2
        if left < size and arr[left] > arr[largest]:
            largest = left
        if right < size and arr[right] > arr[largest]:
            largest = right
        if largest == root:
            return
        arr[root], arr[largest] = arr[largest], arr[root]
        root = largest
Sorting Comparison Table
| Algorithm | Best | Average | Worst | Space | Stable | Notes |
|---|---|---|---|---|---|---|
| Merge Sort | O(n log n) | O(n log n) | O(n log n) | O(n) | Yes | Best for linked lists, external sort |
| Quicksort | O(n log n) | O(n log n) | O(n^2) | O(log n) | No | Fastest in practice (cache friendly) |
| Heap Sort | O(n log n) | O(n log n) | O(n log n) | O(1) | No | Guaranteed O(n log n), in-place |
| Insertion Sort | O(n) | O(n^2) | O(n^2) | O(1) | Yes | Best for small/nearly sorted data |
| Timsort | O(n) | O(n log n) | O(n log n) | O(n) | Yes | Python's sorted(), hybrid approach |
Non-Comparison Sorts
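Non-comparison sorts break the O(n log n) lower bound for comparison sorting by exploiting the structure of the keys (a small integer range, fixed-width digits, or a known value distribution) instead of comparing elements pairwise.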
Counting Sort
def counting_sort(arr, max_val=None):
    """Counting sort - O(n + k) where k is the range of values.
    ML context: Sorting discrete labels (class indices 0..C-1),
    histogram computation for feature binning, sorting token IDs.
    Assumes non-negative integers (shift by min(arr) to handle negatives).
    Time: O(n + k), Space: O(n + k)
    Stable: Yes (this implementation preserves relative order)
    """
    if not arr:
        return arr
    if max_val is None:
        max_val = max(arr)
    # Count occurrences
    count = [0] * (max_val + 1)
    for val in arr:
        count[val] += 1
    # Cumulative count: count[v] = number of elements <= v
    for i in range(1, len(count)):
        count[i] += count[i - 1]
    # Build output array (traverse input in reverse for stability)
    output = [0] * len(arr)
    for val in reversed(arr):
        count[val] -= 1
        output[count[val]] = val
    return output
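In ML pipelines you usually sort whole records (examples, events) by an integer label rather than bare integers. A sketch of the same stable counting sort that places records by a key function; counting_sort_by_key and the example tuples are illustrative.

```python
def counting_sort_by_key(records, key, num_keys):
    """Stable counting sort of records by an integer key in [0, num_keys).
    Time: O(n + num_keys), Space: O(n + num_keys)
    """
    count = [0] * num_keys
    for r in records:
        count[key(r)] += 1
    # Turn counts into starting positions (exclusive prefix sum)
    start = 0
    for c in range(num_keys):
        count[c], start = start, start + count[c]
    output = [None] * len(records)
    # Forward pass with start positions keeps equal keys in input order
    for r in records:
        k = key(r)
        output[count[k]] = r
        count[k] += 1
    return output

# Hypothetical training examples: (example_id, class_label)
examples = [("e1", 2), ("e2", 0), ("e3", 1), ("e4", 0), ("e5", 2)]
print(counting_sort_by_key(examples, key=lambda e: e[1], num_keys=3))
# [('e2', 0), ('e4', 0), ('e3', 1), ('e1', 2), ('e5', 2)]
```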
Radix Sort
def radix_sort(arr):
    """LSD radix sort - O(d * (n + b)) where d = digits, b = base.
    ML context: Sorting large collections of integer IDs (user IDs,
    document IDs) or fixed-precision scores converted to integers.
    Assumes non-negative integers.
    Time: O(d * n) for fixed-width integers
    Space: O(n + b)
    Stable: Yes
    """
    if not arr:
        return arr
    max_val = max(arr)
    exp = 1  # Current digit position (ones, tens, hundreds, ...)
    while max_val // exp > 0:
        arr = _counting_sort_by_digit(arr, exp)
        exp *= 10
    return arr

def _counting_sort_by_digit(arr, exp):
    """Stable counting sort on one base-10 digit position."""
    n = len(arr)
    output = [0] * n
    count = [0] * 10  # Base 10 digits
    for val in arr:
        digit = (val // exp) % 10
        count[digit] += 1
    for i in range(1, 10):
        count[i] += count[i - 1]
    # Reverse pass keeps equal digits in their previous order (stability)
    for val in reversed(arr):
        digit = (val // exp) % 10
        count[digit] -= 1
        output[count[digit]] = val
    return output
Counting sort and radix sort appear most frequently at companies dealing with massive categorical data - recommendation systems at Netflix and Spotify (sorting billions of item IDs), ad platforms at Meta and Google (sorting event logs by discrete category), and genomics companies (sorting nucleotide sequences). For general MLE roles, knowing when to use them is more important than implementing them from scratch.
Bucket Sort
def bucket_sort(arr, num_buckets=10):
    """Bucket sort - O(n + k) average for uniformly distributed data (k = num_buckets).
    ML context: Sorting model confidence scores in the [0, 1] range.
    It is fastest when scores spread roughly uniformly across the range.
    Calibration does not guarantee that: scores clustered near 0 or 1
    pile into a few buckets and erode the speedup.
    Time: O(n + k) average; O(n log n) worst case here, because each bucket
    is sorted with Timsort (O(n^2) if buckets used insertion sort)
    Space: O(n + k)
    """
    if not arr:
        return arr
    min_val, max_val = min(arr), max(arr)
    if min_val == max_val:
        return arr[:]
    bucket_range = (max_val - min_val) / num_buckets
    buckets = [[] for _ in range(num_buckets)]
    for val in arr:
        # Clamp so max_val lands in the last bucket instead of overflowing
        idx = min(int((val - min_val) / bucket_range), num_buckets - 1)
        buckets[idx].append(val)
    # Sort each bucket with Timsort (small buckets take its insertion-sort path)
    result = []
    for bucket in buckets:
        bucket.sort()
        result.extend(bucket)
    return result

# Example: sorting 1 million confidence scores in [0, 1]
# With 1000 buckets and roughly uniform scores, each bucket has ~1000 elements
# Sorting each bucket: ~1000 * log2(1000) ≈ 10,000 ops
# Total: 1000 buckets * 10,000 ≈ 10 million ops vs ~20 million (n log2 n) for a comparison sort
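Bucket sort's speed depends on values spreading evenly across buckets. A small check with an illustrative helper, bucket_loads, which uses the same bucket-index formula as bucket_sort, shows what happens when scores are skewed toward 0. The data is synthetic.

```python
def bucket_loads(values, num_buckets=10):
    """Count how many values land in each bucket (same index formula as bucket_sort)."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / num_buckets
    loads = [0] * num_buckets
    for v in values:
        idx = min(int((v - lo) / width), num_buckets - 1)
        loads[idx] += 1
    return loads

uniform = [i / 1000 for i in range(1000)]
skewed = [(i / 1000) ** 4 for i in range(1000)]  # mass piles up near 0

print(bucket_loads(uniform))  # roughly 100 per bucket
print(bucket_loads(skewed))   # the first bucket holds over half the values
```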
Binary Search
Binary search is the most frequently tested searching technique. In ML interviews, it appears in threshold tuning, hyperparameter search, and any problem involving sorted data.
Standard Binary Search
def binary_search(arr, target):
    """Standard binary search in a sorted array.
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(arr) - 1
    while left <= right:
        # Overflow-safe midpoint (matters in C/Java; Python ints never overflow)
        mid = left + (right - left) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
Binary Search Variations
The real power of binary search comes from its variations. These are far more common in interviews than the standard version.
import bisect
def find_first_occurrence(arr, target):
    """Find the FIRST occurrence of target in a sorted array.
    ML context: Locating the first of several tied scores in a sorted list.
    (For "first score >= threshold", bisect.bisect_left does this directly.)
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(arr) - 1
    result = -1
    while left <= right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            result = mid  # Record this position
            right = mid - 1  # Keep searching left for earlier occurrence
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result

def find_last_occurrence(arr, target):
    """Find the LAST occurrence of target in a sorted array.
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(arr) - 1
    result = -1
    while left <= right:
        mid = left + (right - left) // 2
        if arr[mid] == target:
            result = mid  # Record this position
            left = mid + 1  # Keep searching right for later occurrence
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return result

def count_occurrences(arr, target):
    """Count occurrences of target using binary search.
    Time: O(log n), Space: O(1)
    """
    first = find_first_occurrence(arr, target)
    if first == -1:
        return 0
    last = find_last_occurrence(arr, target)
    return last - first + 1

# Python's bisect module provides these efficiently:
# bisect.bisect_left(arr, target) -> index of first element >= target
# bisect.bisect_right(arr, target) -> index of first element > target
# count = bisect.bisect_right(arr, target) - bisect.bisect_left(arr, target)
The most common binary search bug is an off-by-one error in the boundary conditions. Remember: left <= right with left = mid + 1 and right = mid - 1 is the standard template. If you use left < right, you need different update rules. Pick one template and stick with it - do not mix them.
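For contrast, here is the half-open `left < right` template with its own update rules: the search space is [left, right), so `right` starts at len(arr) and is set to `mid`, never `mid - 1`. The names lower_bound and upper_bound are borrowed from C++ for illustration; they mirror bisect.bisect_left and bisect.bisect_right.

```python
import bisect

def lower_bound(arr, target):
    """First index i with arr[i] >= target (len(arr) if none)."""
    left, right = 0, len(arr)  # half-open [left, right)
    while left < right:
        mid = (left + right) // 2
        if arr[mid] < target:
            left = mid + 1  # mid is too small, exclude it
        else:
            right = mid     # mid may be the answer, keep it
    return left

def upper_bound(arr, target):
    """First index i with arr[i] > target (len(arr) if none)."""
    left, right = 0, len(arr)
    while left < right:
        mid = (left + right) // 2
        if arr[mid] <= target:
            left = mid + 1
        else:
            right = mid
    return left

scores = [0.1, 0.3, 0.3, 0.3, 0.7, 0.9]
print(upper_bound(scores, 0.3) - lower_bound(scores, 0.3))  # 3 occurrences of 0.3
print(len(scores) - lower_bound(scores, 0.5))               # 2 scores >= 0.5
```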
Binary Search on Answer (Monotonic Predicate)
This is the most important binary search pattern for ML interviews. Instead of searching for an element in an array, you search over the answer space.
def find_threshold(scores, labels, target_precision):
    """Find the minimum confidence threshold that achieves target precision.
    ML context: Threshold tuning for a classifier. Raising the threshold
    usually increases precision and decreases recall. Find the lowest
    threshold that meets the precision requirement.
    Approach: Binary search on the threshold value in [0, 1].
    Caveat: binary search assumes precision is non-decreasing in the
    threshold. That holds only approximately - precision can dip as the
    threshold rises (recall, by contrast, never increases) - so this can
    miss a lower qualifying threshold. Sorting the scores and sweeping
    once gives the exact answer in O(n log n).
    Time: O(n * log(1/epsilon)) where epsilon is the search tolerance
    Space: O(1)
    """
    def compute_precision(threshold):
        true_pos = sum(1 for s, l in zip(scores, labels)
                       if s >= threshold and l == 1)
        predicted_pos = sum(1 for s in scores if s >= threshold)
        if predicted_pos == 0:
            return 1.0  # No predictions => vacuously precise
        return true_pos / predicted_pos
    left, right = 0.0, 1.0
    epsilon = 1e-6
    while right - left > epsilon:
        mid = (left + right) / 2
        if compute_precision(mid) >= target_precision:
            right = mid  # This threshold works, try lower
        else:
            left = mid  # Need higher threshold for more precision
    return right
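Because precision is not guaranteed to be monotone in the threshold, an exact answer needs a different approach: sort by score once, then sweep thresholds from high to low while keeping running counts. This is a sketch under the same semantics (predict positive when score >= threshold). The name exact_min_threshold is illustrative, and unlike find_threshold it returns None instead of a vacuous threshold when no cut-off reaches the target.

```python
def exact_min_threshold(scores, labels, target_precision):
    """Lowest observed score t such that precision at 'score >= t' meets the target.
    Time: O(n log n) for the sort, then one O(n) sweep.
    """
    pairs = sorted(zip(scores, labels), key=lambda p: -p[0])  # high to low
    best = None
    true_pos = 0
    for i, (score, label) in enumerate(pairs):
        if label == 1:
            true_pos += 1
        # A threshold t = score admits every example with score >= t,
        # so evaluate only after the last of a run of tied scores.
        if i + 1 < len(pairs) and pairs[i + 1][0] == score:
            continue
        if true_pos / (i + 1) >= target_precision:
            best = score  # scores descend, so later hits are lower thresholds
    return best

# Precision here is 1.0, 0.5, 0.67, 0.75, 0.6 as the threshold drops:
# not monotone, yet 0.6 is the lowest threshold reaching 0.7.
print(exact_min_threshold([0.9, 0.8, 0.7, 0.6, 0.5], [1, 0, 1, 1, 0], 0.7))  # 0.6
```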
def min_capacity_to_ship(weights, days):
    """Find the minimum ship capacity to ship all packages within 'days' days.
    Classic binary search on answer problem (LeetCode 1011).
    The predicate: "Can we ship with capacity C in <= days days?"
    is monotonically true above some threshold C.
    Time: O(n * log(sum - max)), Space: O(1)
    """
    def can_ship(capacity):
        current_load = 0
        days_needed = 1
        for w in weights:
            if current_load + w > capacity:
                days_needed += 1
                current_load = 0
            current_load += w
        return days_needed <= days
    left = max(weights)  # Must at least carry the heaviest package
    right = sum(weights)  # Worst case: ship everything in one day
    while left < right:
        mid = left + (right - left) // 2
        if can_ship(mid):
            right = mid  # This capacity works, try smaller
        else:
            left = mid + 1  # Need more capacity
    return left
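Both functions above share one skeleton: a monotone yes/no predicate over a range of candidate answers. Factoring it out makes the pattern reusable. first_true and days_needed are illustrative helper names, not a library API, and the weights/days values are LeetCode 1011's first sample.

```python
def first_true(lo, hi, predicate):
    """Smallest integer x in [lo, hi] with predicate(x) True.
    Assumes predicate is monotone (False...False True...True) and predicate(hi) is True.
    """
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if predicate(mid):
            hi = mid      # mid works, the answer is mid or smaller
        else:
            lo = mid + 1  # mid fails, the answer is larger
    return lo

def days_needed(weights, capacity):
    """Greedy: fill each day until the next package would exceed capacity."""
    days, load = 1, 0
    for w in weights:
        if load + w > capacity:
            days += 1
            load = 0
        load += w
    return days

weights = list(range(1, 11))
cap = first_true(max(weights), sum(weights), lambda c: days_needed(weights, c) <= 5)
print(cap)  # 15
```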
Search in Rotated Sorted Array
def search_rotated(nums, target):
    """Search for target in a rotated sorted array.
    ML context: Searching in circular buffers used for model metric
    logging (e.g., a ring buffer of the last N loss values).
    Key insight: At least one half of the array is always sorted.
    Determine which half is sorted, then check if target is in that half.
    Assumes distinct values (duplicates can hide which half is sorted).
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(nums) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if nums[mid] == target:
            return mid
        # Left half is sorted
        if nums[left] <= nums[mid]:
            if nums[left] <= target < nums[mid]:
                right = mid - 1  # Target is in sorted left half
            else:
                left = mid + 1  # Target is in right half
        # Right half is sorted
        else:
            if nums[mid] < target <= nums[right]:
                left = mid + 1  # Target is in sorted right half
            else:
                right = mid - 1  # Target is in left half
    return -1
Find Peak Element
def find_peak(arr):
    """Find a peak element (greater than its neighbors).
    Assumes adjacent elements are distinct.
    ML context: Locating a peak in validation accuracy during training
    (for early stopping). This returns *a* local peak, which is the global
    best epoch only if the curve rises and then falls once; noisy curves
    with several bumps need a linear scan.
    The key insight is that binary search works even on unsorted arrays
    if the problem has a "hill" structure. If arr[mid] < arr[mid+1],
    there must be a peak to the right.
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(arr) - 1
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] < arr[mid + 1]:
            left = mid + 1  # Peak is to the right
        else:
            right = mid  # Peak is at mid or to the left
    return left  # left == right, the peak index
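A quick demonstration that binary peak-finding can stop at a local bump while the best epoch lies elsewhere. The validation-accuracy curve is hypothetical, and best_epoch is an illustrative O(n) argmax.

```python
def find_peak(arr):
    """Same binary-search logic as above: returns *a* local peak."""
    left, right = 0, len(arr) - 1
    while left < right:
        mid = left + (right - left) // 2
        if arr[mid] < arr[mid + 1]:
            left = mid + 1
        else:
            right = mid
    return left

def best_epoch(val_acc):
    """Global maximum via a linear scan; works on any curve."""
    return max(range(len(val_acc)), key=val_acc.__getitem__)

val_acc = [0.30, 0.80, 0.70, 0.60, 0.50, 0.90, 0.40]  # hypothetical curve
print(find_peak(val_acc))   # 1  (local peak at 0.80)
print(best_epoch(val_acc))  # 5  (global peak at 0.90)
```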
Selection Algorithms: Top-K and Quickselect
Quickselect (Kth Smallest/Largest)
import random
def quickselect(arr, k):
    """Find the kth smallest element (0-indexed) in O(n) average time.
    ML context: Finding the median prediction, or the kth percentile
    of a score distribution without fully sorting.
    numpy.partition() defaults to introselect, a quickselect variant.
    Note: partially reorders arr in place.
    Time: O(n) average, O(n^2) worst case
    Space: O(1) (iterative, no recursion stack)
    """
    if not 0 <= k < len(arr):
        raise IndexError("k out of range")
    left, right = 0, len(arr) - 1
    while left < right:
        # Random pivot
        pivot_idx = random.randint(left, right)
        arr[pivot_idx], arr[right] = arr[right], arr[pivot_idx]
        pivot = arr[right]
        # Partition: elements < pivot move to the front
        store_idx = left
        for i in range(left, right):
            if arr[i] < pivot:
                arr[i], arr[store_idx] = arr[store_idx], arr[i]
                store_idx += 1
        arr[store_idx], arr[right] = arr[right], arr[store_idx]
        # Continue only on the side that contains index k
        if k == store_idx:
            return arr[k]
        elif k < store_idx:
            right = store_idx - 1
        else:
            left = store_idx + 1
    return arr[left]  # left == right == k

def find_kth_largest(arr, k):
    """Find the kth largest element (k is 1-indexed).
    The kth largest is the (n - k)th smallest (0-indexed).
    Note: reorders arr in place; pass a copy to preserve the input.
    Time: O(n) average, Space: O(1)
    """
    n = len(arr)
    return quickselect(arr, n - k)
Top-K Using a Heap
import heapq
def top_k_smallest(arr, k):
    """Find the k smallest elements using a max-heap of size k.
    Time: O(n log k), Space: O(k)
    """
    if k <= 0:
        return []
    if k >= len(arr):
        return sorted(arr)
    # Negate values because heapq is a min-heap; -heap[0] is the largest kept
    heap = [-x for x in arr[:k]]
    heapq.heapify(heap)
    for val in arr[k:]:
        if -val > heap[0]:  # val is smaller than the largest in heap
            heapq.heapreplace(heap, -val)
    return sorted(-x for x in heap)

def top_k_largest(arr, k):
    """Find the k largest elements using a min-heap of size k.
    ML context: Finding top-K predictions, top-K most similar documents,
    top-K features by importance.
    Time: O(n log k), Space: O(k)
    """
    if k <= 0:
        return []
    if k >= len(arr):
        return sorted(arr, reverse=True)
    # Min-heap of size k - smallest of the top-k is at the root
    heap = arr[:k]
    heapq.heapify(heap)
    for val in arr[k:]:
        if val > heap[0]:  # val belongs in top-k
            heapq.heapreplace(heap, val)
    return sorted(heap, reverse=True)

def top_k_frequent(arr, k):
    """Find the k most frequent elements.
    ML context: Finding the most common tokens in a corpus,
    most frequent class labels, most common error types.
    Time: O(n + m log k) where m = unique elements
    Space: O(m)
    """
    from collections import Counter
    counts = Counter(arr)
    # heapq.nlargest keeps a size-k heap internally, keyed on each element's count
    return heapq.nlargest(k, counts.keys(), key=counts.get)
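A count can never exceed len(items), so the counts themselves can be bucket-sorted, giving an O(n) alternative to the heap. This is a sketch with the illustrative name top_k_frequent_buckets. Ties within one count come back in first-seen order, because Counter preserves insertion order.

```python
from collections import Counter

def top_k_frequent_buckets(items, k):
    """k most frequent items via bucket sort on counts. Time: O(n), Space: O(n)."""
    if k <= 0:
        return []
    counts = Counter(items)
    # buckets[c] holds every item that appears exactly c times
    buckets = [[] for _ in range(len(items) + 1)]
    for item, c in counts.items():
        buckets[c].append(item)
    result = []
    for c in range(len(buckets) - 1, 0, -1):  # highest count first
        for item in buckets[c]:
            result.append(item)
            if len(result) == k:
                return result
    return result

tokens = ["a", "b", "a", "c", "a", "b"]  # hypothetical token stream
print(top_k_frequent_buckets(tokens, 2))  # ['a', 'b']
```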
Quickselect vs Heap for top-K: Use quickselect (O(n) average) when you need the kth element and can modify the array. Use a min-heap of size K (O(n log K)) when data is streaming and you cannot store it all. For K much smaller than n, both are far better than full sorting. Python's heapq.nlargest(k, arr) uses a heap internally. NumPy's np.partition(arr, k) uses introselect (quickselect variant).
K-Way Merge and External Sort
K-Way Merge
import heapq
def k_way_merge(sorted_lists):
    """Merge K sorted lists into one sorted list.
    ML context: Merging sorted prediction results from multiple
    model replicas or data shards. Also the core of external merge sort.
    Time: O(N log K) where N = total elements, K = number of lists
    Space: O(K) for the heap + O(N) for output
    """
    # Heap entries: (value, list_index, element_index)
    heap = []
    for i, lst in enumerate(sorted_lists):
        if lst:
            heapq.heappush(heap, (lst[0], i, 0))
    result = []
    while heap:
        val, list_idx, elem_idx = heapq.heappop(heap)
        result.append(val)
        # Push next element from the same list
        if elem_idx + 1 < len(sorted_lists[list_idx]):
            next_val = sorted_lists[list_idx][elem_idx + 1]
            heapq.heappush(heap, (next_val, list_idx, elem_idx + 1))
    return result

# Python's heapq.merge() does the same thing:
# merged = list(heapq.merge(*sorted_lists))
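heapq.merge also accepts key= and reverse= (Python 3.5+), which covers the opening scenario of merging per-shard top lists sorted in descending order. Because it is lazy, itertools.islice stops after n items without reading the rest of each shard. The shard data below is illustrative.

```python
import heapq
from itertools import islice

# Per-shard results, each sorted by score DESCENDING (hypothetical data)
shard_outputs = [
    [(0.95, "A"), (0.87, "B"), (0.72, "C")],
    [(0.93, "D"), (0.85, "E"), (0.70, "F")],
    [(0.91, "G"), (0.88, "H"), (0.75, "I")],
]

# reverse=True tells heapq.merge the inputs are descending; key picks the score
merged = heapq.merge(*shard_outputs, key=lambda pair: pair[0], reverse=True)

top_5 = list(islice(merged, 5))
print(top_5)
# [(0.95, 'A'), (0.93, 'D'), (0.91, 'G'), (0.88, 'H'), (0.87, 'B')]
```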
External Merge Sort
def external_merge_sort(input_file, output_file, chunk_size=100000):
    """Sort a file too large for memory using external merge sort.
    ML context: Sorting massive datasets (billions of training examples)
    by loss, timestamp, or feature value for curriculum learning,
    data deduplication, or temporal splitting.
    Approach:
    1. Read chunks that fit in memory
    2. Sort each chunk, write to temp file
    3. K-way merge all temp files
    Time: O(N log N), Space: O(chunk_size) memory
    """
    import heapq
    import os
    import tempfile
    # Phase 1: Create sorted runs
    temp_files = []
    with open(input_file, 'r') as f:
        while True:
            chunk = []
            for _ in range(chunk_size):
                line = f.readline()
                if not line:
                    break
                chunk.append(float(line.strip()))
            if not chunk:
                break
            chunk.sort()
            # Write sorted chunk to temp file
            tmp = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.tmp')
            for val in chunk:
                tmp.write(f"{val}\n")
            tmp.close()
            temp_files.append(tmp.name)
    # Phase 2: K-way merge of sorted runs
    file_handles = [open(name, 'r') for name in temp_files]
    try:
        # Initialize heap with first element from each file
        heap = []
        for i, fh in enumerate(file_handles):
            line = fh.readline()
            if line:
                heapq.heappush(heap, (float(line.strip()), i))
        with open(output_file, 'w') as out:
            while heap:
                val, file_idx = heapq.heappop(heap)
                out.write(f"{val}\n")
                line = file_handles[file_idx].readline()
                if line:
                    heapq.heappush(heap, (float(line.strip()), file_idx))
    finally:
        # Cleanup runs even if the merge fails
        for fh in file_handles:
            fh.close()
        for name in temp_files:
            os.remove(name)
External merge sort questions are common at data infrastructure companies (Databricks, Snowflake, Palantir) and at any large ML company that deals with massive training datasets (Google, Meta, Amazon). You will rarely be asked to implement the full algorithm, but you must explain the two-phase approach: sort chunks, then K-way merge.
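The same two phases can be written more compactly with heapq.merge, which performs the lazy K-way merge itself. This is a sketch assuming one float per line; external_sort_floats and the ".run" suffix are illustrative. With thousands of runs you can hit the OS limit on open files, and real systems then merge in several passes.

```python
import heapq
import os
import tempfile
from contextlib import ExitStack
from itertools import islice

def external_sort_floats(input_path, output_path, chunk_size=100_000):
    """Two-phase external sort of a file containing one float per line."""
    run_paths = []
    try:
        # Phase 1: read chunk_size lines at a time, sort, write a sorted run
        with open(input_path) as src:
            while True:
                chunk = [float(line) for line in islice(src, chunk_size)]
                if not chunk:
                    break
                chunk.sort()
                fd, path = tempfile.mkstemp(suffix=".run")
                run_paths.append(path)
                with os.fdopen(fd, "w") as run:
                    run.writelines(f"{v}\n" for v in chunk)
        # Phase 2: lazy K-way merge of all runs; ExitStack closes every handle
        with ExitStack() as stack, open(output_path, "w") as out:
            runs = [map(float, stack.enter_context(open(p))) for p in run_paths]
            out.writelines(f"{v}\n" for v in heapq.merge(*runs))
    finally:
        for p in run_paths:
            os.remove(p)
```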
ML-Specific Applications
Nearest Neighbor Search with Sorted Data
def k_nearest_in_sorted(sorted_arr, target, k):
    """Find k elements closest to target in a sorted array.
    ML context: Finding the K nearest neighbors in a 1D feature space.
    First, binary search for the closest element, then expand outward.
    Time: O(log n + k), Space: O(k)
    """
    import bisect
    n = len(sorted_arr)
    # Find insertion point
    pos = bisect.bisect_left(sorted_arr, target)
    left = pos - 1
    right = pos
    result = []
    while len(result) < k and (left >= 0 or right < n):
        if left < 0:
            result.append(sorted_arr[right])
            right += 1
        elif right >= n:
            result.append(sorted_arr[left])
            left -= 1
        elif abs(sorted_arr[left] - target) <= abs(sorted_arr[right] - target):
            result.append(sorted_arr[left])
            left -= 1
        else:
            result.append(sorted_arr[right])
            right += 1
    return result
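When the k neighbors should come back as a sorted contiguous slice (the framing of LeetCode 658, Find K Closest Elements), you can binary search directly for the window's left edge by comparing the element just before the window with the one just after it. A sketch assuming 1 <= k <= len(sorted_arr), with ties going to the smaller value; k_closest_window is an illustrative name.

```python
def k_closest_window(sorted_arr, target, k):
    """k values closest to target, returned as a sorted slice.
    Time: O(log(n - k) + k)
    """
    lo, hi = 0, len(sorted_arr) - k  # candidate window starts
    while lo < hi:
        mid = (lo + hi) // 2
        # If sorted_arr[mid] is farther than sorted_arr[mid + k], shift right
        if target - sorted_arr[mid] > sorted_arr[mid + k] - target:
            lo = mid + 1
        else:
            hi = mid
    return sorted_arr[lo:lo + k]

print(k_closest_window([1, 2, 3, 4, 5], 3, 4))  # [1, 2, 3, 4]
print(k_closest_window([1, 3, 7, 8, 9], 8, 2))  # [7, 8]
```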
Streaming Median
import heapq
class StreamingMedian:
    """Maintain the median of a stream of numbers using two heaps.
    ML context: Computing the running median of model predictions
    or training metrics. The median is more robust to outliers than
    the mean, making it useful for monitoring.
    Two heaps:
    - max_heap (left half): stores the smaller half, max at top
    - min_heap (right half): stores the larger half, min at top
    Invariant: len(max_heap) == len(min_heap) or len(max_heap) == len(min_heap) + 1
    Time: O(log n) per insert, O(1) per query
    Space: O(n)
    """
    def __init__(self):
        self.max_heap = []  # Left half (negated for max-heap)
        self.min_heap = []  # Right half

    def add(self, num):
        # Always add to max_heap first
        heapq.heappush(self.max_heap, -num)
        # Balance: move max of left to right if needed
        if (self.min_heap and
                -self.max_heap[0] > self.min_heap[0]):
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        # Maintain size: left can have at most 1 more than right
        if len(self.max_heap) > len(self.min_heap) + 1:
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        elif len(self.min_heap) > len(self.max_heap):
            val = heapq.heappop(self.min_heap)
            heapq.heappush(self.max_heap, -val)

    def get_median(self):
        if len(self.max_heap) > len(self.min_heap):
            return -self.max_heap[0]
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0

# Usage:
# sm = StreamingMedian()
# for score in model_scores_stream:
#     sm.add(score)
#     print(f"Current median: {sm.get_median()}")
Top-K with Streaming Data (Chunked Processing)
import heapq
def streaming_top_k(data_chunks, k):
    """Find top-K elements across streaming data chunks.
    ML context: The production scenario from the opening - finding
    top-K predictions from a stream of model outputs arriving in chunks.
    Time: O(N log K) where N = total elements across all chunks
    Space: O(K)
    """
    if k <= 0:
        return []
    min_heap = []  # Min-heap of size K
    for chunk in data_chunks:
        for val in chunk:
            if len(min_heap) < k:
                heapq.heappush(min_heap, val)
            elif val > min_heap[0]:
                heapq.heapreplace(min_heap, val)
    return sorted(min_heap, reverse=True)

# With metadata (e.g., prediction ID + score):
def streaming_top_k_with_metadata(data_chunks, k):
    """Top-K with associated metadata.
    Each chunk is a list of (score, metadata) tuples.
    A running counter breaks score ties, so metadata (which may not be
    comparable, e.g. dicts) is never compared by the heap.
    """
    import itertools
    if k <= 0:
        return []
    tiebreak = itertools.count()
    min_heap = []  # Entries: (score, tiebreak, metadata)
    for chunk in data_chunks:
        for score, metadata in chunk:
            entry = (score, next(tiebreak), metadata)
            if len(min_heap) < k:
                heapq.heappush(min_heap, entry)
            elif score > min_heap[0][0]:
                heapq.heapreplace(min_heap, entry)
    return [(score, metadata)
            for score, _, metadata in sorted(min_heap, key=lambda x: -x[0])]
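The standard library can do the same streaming pass: heapq.nlargest accepts any iterable and keeps only k items, and itertools.chain.from_iterable flattens chunks lazily. With key=, nlargest never compares the elements themselves, so non-comparable metadata is safe. The chunk generator and records below are hypothetical.

```python
import heapq
from itertools import chain

def score_chunks():
    """Hypothetical generator yielding chunks of model scores."""
    yield [0.12, 0.97, 0.45]
    yield [0.88, 0.05, 0.91]
    yield [0.33, 0.99]

# Flatten lazily; nlargest holds at most k scores at a time
top_3 = heapq.nlargest(3, chain.from_iterable(score_chunks()))
print(top_3)  # [0.99, 0.97, 0.91]

# (score, metadata) records: key= ranks by score only; ties keep input order
pairs = [(0.7, {"id": 1}), (0.9, {"id": 2}), (0.7, {"id": 3})]
top_2 = heapq.nlargest(2, pairs, key=lambda p: p[0])
print([p[1]["id"] for p in top_2])  # [2, 1]
```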
Practice Problems
Problem 1: Merge Sorted Model Outputs (K-Way Merge)
You have K model replicas, each producing a sorted list of (score, item_id) pairs. Merge these into a single sorted list and return the top N results.
# Input: outputs = [
# [(0.95, "A"), (0.87, "B"), (0.72, "C")],
# [(0.93, "D"), (0.85, "E"), (0.70, "F")],
# [(0.91, "G"), (0.88, "H"), (0.75, "I")]
# ], n = 5
# Output: [(0.95, "A"), (0.93, "D"), (0.91, "G"), (0.88, "H"), (0.87, "B")]
Solution
import heapq
def merge_top_n(outputs, n):
    """Merge K sorted (descending) lists, return top N.
    Time: O((K + n) log K), Space: O(K) for the heap
    """
    # Negate scores for min-heap behavior (we want max)
    heap = []
    for list_idx, lst in enumerate(outputs):
        if lst:
            score, item_id = lst[0]
            heapq.heappush(heap, (-score, item_id, list_idx, 0))
    result = []
    while heap and len(result) < n:
        neg_score, item_id, list_idx, elem_idx = heapq.heappop(heap)
        result.append((-neg_score, item_id))
        # Push next element from the same list
        if elem_idx + 1 < len(outputs[list_idx]):
            next_score, next_id = outputs[list_idx][elem_idx + 1]
            heapq.heappush(heap, (-next_score, next_id, list_idx, elem_idx + 1))
    return result
Problem 2: Find Optimal Learning Rate (Ternary Search on Answer)
Given a function that returns training loss for a given learning rate, find the learning rate in [1e-6, 1.0] that produces the minimum loss. Assume the loss function is unimodal (decreases then increases) - a common shape in learning rate range tests.
# Input: loss_function(lr) returns the training loss for learning rate lr
# Output: The learning rate that minimizes the loss
# Example: loss_function is unimodal with minimum around lr = 0.001
Solution
import math
def find_optimal_lr(loss_function, lr_min=1e-6, lr_max=1.0, iterations=100):
    """Find the learning rate that minimizes a unimodal loss function.
    Uses ternary search on log scale.
    Each iteration costs two loss evaluations and keeps 2/3 of the interval,
    so ~40 iterations already shrink the 6-decade range below 1e-6 decades.
    Time: O(iterations * cost_of_loss_function)
    Space: O(1)
    """
    # Search on log scale
    log_min = math.log10(lr_min)
    log_max = math.log10(lr_max)
    for _ in range(iterations):
        log_m1 = log_min + (log_max - log_min) / 3
        log_m2 = log_max - (log_max - log_min) / 3
        lr1 = 10 ** log_m1
        lr2 = 10 ** log_m2
        loss1 = loss_function(lr1)
        loss2 = loss_function(lr2)
        if loss1 > loss2:
            log_min = log_m1  # Minimum is in right 2/3
        else:
            log_max = log_m2  # Minimum is in left 2/3
    return 10 ** ((log_min + log_max) / 2)
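In a real learning-rate search every loss evaluation is a training run, and ternary search spends two per iteration. Golden-section search, a standard refinement of ternary search swapped in here rather than part of the original solution, places its probes so one of them carries over to the next iteration, so each step costs a single evaluation. This is a sketch on the same log scale; golden_section_lr and tol are illustrative names, and the quadratic loss is synthetic.

```python
import math

def golden_section_lr(loss_fn, lr_min=1e-6, lr_max=1.0, tol=1e-3):
    """Minimize a unimodal loss over log10(lr); tol is measured in decades."""
    inv_phi = (math.sqrt(5) - 1) / 2  # ~0.618
    a, b = math.log10(lr_min), math.log10(lr_max)
    c = b - inv_phi * (b - a)
    d = a + inv_phi * (b - a)
    fc, fd = loss_fn(10 ** c), loss_fn(10 ** d)
    while b - a > tol:
        if fc > fd:  # minimum lies in [c, b]; old d becomes the new c
            a, c, fc = c, d, fd
            d = a + inv_phi * (b - a)
            fd = loss_fn(10 ** d)
        else:        # minimum lies in [a, d]; old c becomes the new d
            b, d, fd = d, c, fc
            c = b - inv_phi * (b - a)
            fc = loss_fn(10 ** c)
    return 10 ** ((a + b) / 2)

# Synthetic unimodal loss with its minimum at lr = 1e-3
print(golden_section_lr(lambda lr: (math.log10(lr) + 3) ** 2))  # close to 0.001
```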
Problem 3: Kth Largest Element in Streaming Predictions
Design a class that maintains a stream of prediction scores and can return the kth largest score at any time.
# kth = KthLargest(3, [4, 5, 8, 2])
# kth.add(3) -> 4 (stream: [2, 3, 4, 5, 8], 3rd largest = 4)
# kth.add(5) -> 5 (stream: [2, 3, 4, 5, 5, 8], 3rd largest = 5)
# kth.add(10) -> 5 (stream: [2, 3, 4, 5, 5, 8, 10], 3rd largest = 5)
# kth.add(9) -> 8 (stream: [2, 3, 4, 5, 5, 8, 9, 10], 3rd largest = 8)
Solution
import heapq
class KthLargest:
    """Maintain the kth largest element in a stream.
    Time: O(log k) per add, O(1) per query
    Space: O(k)
    """
    def __init__(self, k, initial_values):
        self.k = k
        self.min_heap = []
        for val in initial_values:
            self._add_to_heap(val)

    def _add_to_heap(self, val):
        if len(self.min_heap) < self.k:
            heapq.heappush(self.min_heap, val)
        elif val > self.min_heap[0]:
            heapq.heapreplace(self.min_heap, val)

    def add(self, val):
        self._add_to_heap(val)
        return self.min_heap[0]
Problem 4: Sort Colors / Dutch National Flag (Three-Way Partition)
Given an array with elements 0, 1, and 2, sort it in one pass using constant extra space.
# Input: [2, 0, 2, 1, 1, 0]
# Output: [0, 0, 1, 1, 2, 2]
# ML context: Partitioning training samples into three groups
# (negative, neutral, positive) based on sentiment labels.
Solution
def sort_colors(nums):
    """Dutch National Flag partition - sort array of 0s, 1s, 2s in place.
    Time: O(n), Space: O(1)
    """
    low, mid, high = 0, 0, len(nums) - 1
    while mid <= high:
        if nums[mid] == 0:
            nums[low], nums[mid] = nums[mid], nums[low]
            low += 1
            mid += 1
        elif nums[mid] == 1:
            mid += 1
        else:  # nums[mid] == 2
            nums[mid], nums[high] = nums[high], nums[mid]
            high -= 1
            # Do not advance mid - need to check the swapped element
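sort_colors is Dijkstra's three-way partition specialized to the values 0, 1 and 2. The same loop works for any pivot value, which is a common way to keep quicksort and quickselect fast on heavily duplicated keys such as quantized confidence scores. A sketch; three_way_partition is an illustrative name.

```python
def three_way_partition(arr, pivot):
    """Rearrange arr in place into [< pivot | == pivot | > pivot].
    Returns (lt, gt): arr[lt:gt + 1] are exactly the elements equal to pivot.
    """
    lt, i, gt = 0, 0, len(arr) - 1
    while i <= gt:
        if arr[i] < pivot:
            arr[lt], arr[i] = arr[i], arr[lt]
            lt += 1
            i += 1
        elif arr[i] > pivot:
            arr[i], arr[gt] = arr[gt], arr[i]
            gt -= 1  # do not advance i: the swapped-in value is unchecked
        else:
            i += 1
    return lt, gt

scores = [0.7, 0.2, 0.5, 0.9, 0.5, 0.1, 0.5]
lt, gt = three_way_partition(scores, 0.5)
print(scores[:lt], scores[lt:gt + 1], scores[gt + 1:])
# [0.2, 0.1] [0.5, 0.5, 0.5] [0.9, 0.7]
```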
Problem 5: Find Minimum in Rotated Sorted Array
A sorted array is rotated at some pivot. Find the minimum element.
# Input: [3, 4, 5, 1, 2]
# Output: 1
# ML context: Finding the epoch with the lowest validation loss
# in a circular logging buffer.
Solution
def find_min_rotated(nums):
    """Find minimum in a rotated sorted array.
    Assumes distinct values.
    Time: O(log n), Space: O(1)
    """
    left, right = 0, len(nums) - 1
    while left < right:
        mid = left + (right - left) // 2
        if nums[mid] > nums[right]:
            left = mid + 1  # Minimum is in right half
        else:
            right = mid  # Minimum is at mid or in left half
    return nums[left]
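With duplicates (LeetCode 154, the follow-up to this problem), nums[mid] == nums[right] no longer tells you which side holds the minimum. A common fix is to shrink right by one, which is safe because nums[mid] keeps a copy of that value. The worst case degrades to O(n) when most values are equal. find_min_rotated_dups is an illustrative name.

```python
def find_min_rotated_dups(nums):
    """Minimum of a rotated sorted array that may contain duplicates.
    Time: O(log n) typical, O(n) worst case (e.g. mostly equal values)
    """
    left, right = 0, len(nums) - 1
    while left < right:
        mid = left + (right - left) // 2
        if nums[mid] > nums[right]:
            left = mid + 1  # Minimum is strictly right of mid
        elif nums[mid] < nums[right]:
            right = mid     # Minimum is at mid or to its left
        else:
            right -= 1      # Tie: drop nums[right], nums[mid] keeps its value
    return nums[left]

print(find_min_rotated_dups([2, 2, 2, 0, 1]))  # 0
print(find_min_rotated_dups([3, 3, 1, 3]))     # 1
```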
Problem 6: Search a 2D Matrix
Each row of a matrix is sorted, and the first element of each row is greater than the last element of the previous row. Search for a target value.
# Input: matrix = [
# [1, 3, 5, 7],
# [10, 11, 16, 20],
# [23, 30, 34, 60]
# ], target = 3
# Output: True
# ML context: Searching in a feature matrix where features are sorted
# within each sample and samples are sorted by magnitude.
Solution
def search_matrix(matrix, target):
    """Binary search on a row-sorted, column-connected matrix.
    Time: O(log(m * n)), Space: O(1)
    """
    if not matrix or not matrix[0]:
        return False
    rows, cols = len(matrix), len(matrix[0])
    left, right = 0, rows * cols - 1
    while left <= right:
        mid = left + (right - left) // 2
        val = matrix[mid // cols][mid % cols]  # Map flat index to (row, col)
        if val == target:
            return True
        elif val < target:
            left = mid + 1
        else:
            right = mid - 1
    return False
Problem 7: Sort Array by Feature Importance (Custom Sort)
Given a list of features with their importance scores and computation costs, sort them by importance-to-cost ratio in descending order. If two features have the same ratio, sort by importance descending.
# Input: features = [
# ("embed_dim", 0.8, 2.0), # (name, importance, cost)
# ("num_heads", 0.6, 1.0),
# ("learning_rate", 0.9, 3.0),
# ("batch_size", 0.3, 1.0),
# ("dropout", 0.6, 2.0),
# ]
# Output: ["num_heads", "embed_dim", "learning_rate", "dropout", "batch_size"]
# Ratios: num_heads=0.6, embed_dim=0.4, learning_rate=0.3, dropout=0.3, batch_size=0.3
# The tie at 0.3 is broken by importance, descending:
# learning_rate (0.9) > dropout (0.6) > batch_size (0.3)
Solution
def sort_by_efficiency(features):
    """Sort features by importance/cost ratio (descending), then importance (descending).
    Time: O(n log n), Space: O(n)
    """
    sorted_features = sorted(
        features,
        key=lambda f: (-f[1] / f[2], -f[1])  # (-ratio, -importance)
    )
    return [f[0] for f in sorted_features]
Problem 8: Median of Two Sorted Arrays
Find the median of two sorted arrays in O(log(min(m, n))) time.
# Input: nums1 = [1, 3], nums2 = [2]
# Output: 2.0
# Input: nums1 = [1, 2], nums2 = [3, 4]
# Output: 2.5
Solution
```python
def find_median_sorted_arrays(nums1, nums2):
    """Find median of two sorted arrays.
    Time: O(log(min(m, n))), Space: O(1)
    """
    if not nums1 and not nums2:
        raise ValueError("At least one array must be non-empty")
    # Ensure nums1 is the shorter array so the search runs over min(m, n)
    if len(nums1) > len(nums2):
        nums1, nums2 = nums2, nums1
    m, n = len(nums1), len(nums2)
    total_left = (m + n + 1) // 2  # Number of elements in the combined left half
    left, right = 0, m
    while left <= right:
        i = (left + right) // 2  # Partition point in nums1
        j = total_left - i       # Partition point in nums2
        max_left1 = float('-inf') if i == 0 else nums1[i - 1]
        min_right1 = float('inf') if i == m else nums1[i]
        max_left2 = float('-inf') if j == 0 else nums2[j - 1]
        min_right2 = float('inf') if j == n else nums2[j]
        if max_left1 <= min_right2 and max_left2 <= min_right1:
            # Found the correct partition
            if (m + n) % 2 == 1:
                return float(max(max_left1, max_left2))
            else:
                return (max(max_left1, max_left2) +
                        min(min_right1, min_right2)) / 2.0
        elif max_left1 > min_right2:
            right = i - 1  # Move partition left in nums1
        else:
            left = i + 1   # Move partition right in nums1
    raise ValueError("Input arrays are not sorted")
```
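Under interview pressure it helps to have a slow but obviously correct baseline to test the partition search against. The sketch below (hypothetical name `median_by_merge`) reuses the K-way merge idea: `heapq.merge` lazily merges the two sorted inputs, and `itertools.islice` stops just past the middle, giving O(m + n) time.

```python
import heapq
from itertools import islice


def median_by_merge(nums1, nums2):
    """Hypothetical O(m + n) reference median for cross-checking.

    Merges the two sorted inputs lazily and keeps only the first
    total // 2 + 1 values; the median sits at the end of that prefix.
    """
    total = len(nums1) + len(nums2)
    if total == 0:
        raise ValueError("both arrays are empty")
    # heapq.merge yields the union of already-sorted iterables in order
    prefix = list(islice(heapq.merge(nums1, nums2), total // 2 + 1))
    if total % 2 == 1:
        return float(prefix[-1])
    return (prefix[-2] + prefix[-1]) / 2.0


print(median_by_merge([1, 3], [2]))     # 2.0
print(median_by_merge([1, 2], [3, 4]))  # 2.5
```

Comparing this against `find_median_sorted_arrays` on many small random sorted inputs is a quick way to catch off-by-one errors in `total_left` or in the partition bounds.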
Problem 9: Sort a Nearly Sorted Array (K-Sorted)
You are given an array in which every element is at most K positions away from its position in sorted order. Sort it efficiently.
```python
# Input: arr = [6, 5, 3, 2, 8, 10, 9], k = 3
# Output: [2, 3, 5, 6, 8, 9, 10]
# Each element is at most 3 positions from its correct position
# ML context: Re-ordering training batches that are approximately
# sorted by difficulty but have local shuffling for regularization.
```
Solution
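```python
import heapq

def sort_nearly_sorted(arr, k):
    """Sort a K-sorted array using a min-heap.
    Time: O(n log k), Space: O(k) for the heap (plus O(n) for the output)
    """
    # The global minimum has sorted position 0, so it is among the first k + 1 items
    heap = arr[:k + 1]
    heapq.heapify(heap)
    result = []
    for i in range(k + 1, len(arr)):
        # Pop the current minimum, then push the next unseen element
        result.append(heapq.heapreplace(heap, arr[i]))
    # Drain remaining elements from heap
    while heap:
        result.append(heapq.heappop(heap))
    return result
```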
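The list-based version needs the whole array up front, yet the heap never holds more than k + 1 items. The same invariant therefore works on a stream, such as data arriving in chunks as in the opening scenario. The sketch below is a generator with the hypothetical name `sort_k_sorted_stream`. It assumes the input really is K-sorted; if it is not, the output is not guaranteed to be sorted.

```python
import heapq
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def sort_k_sorted_stream(stream: Iterable[T], k: int) -> Iterator[T]:
    """Hypothetical streaming version of sort_nearly_sorted.

    Uses O(k) memory: once the heap holds k + 1 items, its minimum is
    the next element in sorted order (assuming the stream is K-sorted).
    """
    it = iter(stream)
    # Prime the heap with the first k + 1 items (fewer if the stream is short)
    heap = list(islice(it, k + 1))
    heapq.heapify(heap)
    for value in it:
        # Emit the current minimum and admit the next incoming value
        yield heapq.heapreplace(heap, value)
    # Stream exhausted: drain what is left in ascending order
    while heap:
        yield heapq.heappop(heap)


print(list(sort_k_sorted_stream([6, 5, 3, 2, 8, 10, 9], 3)))
# [2, 3, 5, 6, 8, 9, 10]
```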
Problem 10: Interval Scheduling - Maximum Non-Overlapping Predictions
Given a list of model inference tasks with start and end times, find the maximum number of non-overlapping tasks that can be scheduled.
```python
# Input: tasks = [(1, 3), (2, 5), (4, 7), (6, 8), (5, 9), (8, 10)]
# Output: 3 (tasks: (1,3), (4,7), (8,10))
# ML context: Scheduling inference requests on a GPU where each
# request occupies the GPU for a fixed duration.
```
Solution
```python
def max_non_overlapping(tasks):
    """Maximum non-overlapping intervals using greedy + sort.
    Time: O(n log n), Space: O(n) for the sorted copy
    """
    if not tasks:
        return 0
    # Sort by end time; sorted() returns a copy, so the caller's list is untouched
    tasks = sorted(tasks, key=lambda x: x[1])
    count = 1
    last_end = tasks[0][1]
    for start, end in tasks[1:]:
        if start >= last_end:  # No overlap (a task may start exactly when the last one ends)
            count += 1
            last_end = end
    return count
```
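Interviewers often follow up by asking which tasks were chosen, not just how many. Here is a sketch of that variant (hypothetical name `select_non_overlapping`). It keeps the same earliest-end-time greedy rule and records each accepted interval:

```python
def select_non_overlapping(tasks):
    """Hypothetical variant that returns one maximum set of compatible tasks.

    Intervals are treated as half-open [start, end), so (1, 3) and (3, 5)
    are compatible. The input list is not modified.
    """
    chosen = []
    last_end = float("-inf")  # No task chosen yet
    for start, end in sorted(tasks, key=lambda t: t[1]):
        if start >= last_end:  # Starts at or after the last chosen task ends
            chosen.append((start, end))
            last_end = end
    return chosen


tasks = [(1, 3), (2, 5), (4, 7), (6, 8), (5, 9), (8, 10)]
print(select_non_overlapping(tasks))  # [(1, 3), (4, 7), (8, 10)]
```

Starting `last_end` at negative infinity removes the special cases for an empty list and for the first task.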
Interview Cheat Sheet
| Algorithm / Pattern | When to Use | Time | Space | Key Detail |
|---|---|---|---|---|
| Merge Sort | Stable sort needed, linked lists, external sort | O(n log n) | O(n) | <= in merge for stability |
| Quicksort | In-place sort, average case performance | O(n log n) avg | O(log n) | Always use random pivot |
| Counting Sort | Small integer range, known bounds | O(n + k) | O(n + k) | k = range of values |
| Radix Sort | Fixed-width integers, large n | O(d * (n + b)) | O(n + b) | d = number of digits, b = base |
| Binary Search | Sorted array, monotonic predicate | O(log n) | O(1) | Watch off-by-one errors |
| Binary Search on Answer | Optimization with monotonic feasibility | O(log(range) * check) | O(1) | Define predicate first |
| Quickselect | Kth element, no need for full sort | O(n) avg | O(1) | Random pivot essential |
| Min-Heap of size K | Top-K from stream, cannot store all data | O(n log K) | O(K) | heapq.heapreplace |
| K-Way Merge | Merge K sorted sources | O(N log K) | O(K) | heapq.merge in Python |
| Two Heaps | Streaming median | O(log n) insert | O(n) | Max-heap + min-heap |
| External Merge Sort | Data too large for memory | O(N log N) | O(chunk) | Sort chunks, then merge |
Spaced Repetition Checkpoints
| Day | Review Activity | Time |
|---|---|---|
| Day 0 | Read all patterns. Implement binary search (standard + first/last occurrence) and top-K with a heap from scratch. | 40 min |
| Day 3 | Implement quickselect and merge sort without reference. Verify on small examples. | 25 min |
| Day 7 | Solve 3 binary search problems from LeetCode (one standard, one on answer, one in rotated array). | 30 min |
| Day 14 | Without looking, write the streaming median class and the K-way merge function. Trace through examples. | 25 min |
| Day 21 | Time yourself: solve a sorting/searching Medium in under 20 minutes. Then implement external merge sort pseudocode on a whiteboard. | 40 min |
Next Steps
Sorting and searching are the workhorses of data processing in ML systems. Next, move on to NumPy Interviews - the library-specific patterns that test whether you can express array operations efficiently in the NumPy paradigm, whose array model underlies most modern ML frameworks.
