K-Means Clustering
The Real Interview Moment
You are in a system design round at a growth-stage e-commerce startup. The interviewer describes the problem: five million users, forty behavioral features each - days since last purchase, average order value, category preferences, time-of-day browsing patterns, support ticket history, mobile vs. desktop ratio. None of it is labeled. The marketing team needs to send different email campaigns to different user types, but nobody has defined what those types are. Product wants three distinct onboarding flows. The data science team keeps hearing "our users are different" without any quantitative definition of how. How would you discover those user types from the raw behavioral data?
You say K-means. The interviewer nods and asks you to walk through it. You explain Lloyd's algorithm step by step. They ask why it converges - you derive the monotone decrease argument, explaining that both the assignment step and the update step provably cannot increase the objective. They ask what happens with bad initialization - you explain K-means++ and its approximation guarantee, stepping through D² sampling. They ask how you would pick K - you lay out the elbow method, the silhouette score, and BIC via Gaussian Mixture Models. They ask what happens at 10 million users - you explain Mini-batch K-means, batch sizes, and partial_fit() for online updates. You walk out with an offer.
This scenario is not far-fetched. K-means is one of the most commonly tested algorithms in data science and ML engineering interviews. It appears in system design rounds, ML fundamentals rounds, and practical coding assessments alike. Interviewers use it because it is deceptively simple on the surface - any practitioner can recite the two-step loop - but deep knowledge separates junior candidates from senior ones. The convergence proof, the initialization theory, the evaluation metrics, the scaling strategies, and the production failure modes are all fair game.
This lesson covers everything you need to understand K-means deeply - not just the mechanics, but the convergence proofs, initialization theory, evaluation metrics, and production scaling patterns. By the end you will be able to implement it from scratch, explain every design decision, and discuss its limitations and alternatives with confidence.
Why K-Means Exists: The Clustering Problem
Before supervised learning existed at scale, before neural networks were practical, statisticians faced a fundamental question: how do you find structure in data without labels? The naive answer is to enumerate all possible groupings - but with $n$ points and $K$ groups, the number of possible label assignments is $K^n$. With 1,000 users and 5 clusters, that is $5^{1000}$ - a number with roughly 700 digits. No computer in any era can enumerate that space.
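The counting argument is easy to check with exact integer arithmetic. The sketch below counts labeled assignments ($K^n$) and, for comparison, the Stirling number of the second kind, which counts partitions into exactly $K$ non-empty, unlabeled groups. The helper names are illustrative, not from any library.

```python
from math import comb, factorial


def n_labeled_assignments(n: int, k: int) -> int:
    """Ways to give each of n points one of k cluster labels: k ** n."""
    return k ** n


def stirling2(n: int, k: int) -> int:
    """Stirling number of the second kind S(n, k): partitions of n points into
    exactly k non-empty, unlabeled groups (inclusion-exclusion formula)."""
    total = sum((-1) ** j * comb(k, j) * (k - j) ** n for j in range(k + 1))
    return total // factorial(k)  # the sum is always divisible by k!


n, k = 1_000, 5
print(len(str(n_labeled_assignments(n, k))))  # digits in 5**1000 -> 699
print(len(str(stirling2(n, k))))              # digits in S(1000, 5) -> 697
```

Even after discounting label permutations (dividing by roughly $K!$), the search space stays astronomically large.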
Early hierarchical approaches (agglomerative clustering) built trees of merges but ran in $O(n^2 \log n)$ or $O(n^3)$ time, making them impractical for even moderately large datasets. The fundamental challenge was finding a tractable approximation to the exponentially large optimization problem.
The insight behind K-means, developed by Stuart Lloyd at Bell Labs in 1957 (published in an internal technical report, then formally in IEEE Transactions in 1982), was that you could attack this problem through coordinate descent: fix the cluster assignments and optimize the centroids (a trivial mean computation), then fix the centroids and optimize the assignments (a simple nearest-neighbor lookup). Neither step is hard in isolation. Alternating between them drives the objective downward monotonically until convergence.
This made clustering tractable for the first time. It ran on 1950s hardware. It produced interpretable results - centroids you could inspect, print, and label. And it scaled linearly in the number of data points per iteration, making it usable on datasets that would choke any combinatorial approach. K-means remains the most widely deployed clustering algorithm over sixty years later for these same reasons.
Historical Context
1957: Stuart Lloyd derives the algorithm at Bell Labs for pulse-code modulation - quantizing continuous signals into discrete levels. The "clusters" are signal values, the "centroids" are the quantization levels. The problem is exactly K-means: minimize the sum of squared quantization errors.
1965: Edward Forgy publishes essentially the same algorithm, leading to the occasional name "Lloyd-Forgy algorithm."
1967: James MacQueen coins the term "K-means" and proves convergence.
2007: David Arthur and Sergei Vassilvitskii publish K-means++ with the first theoretical approximation guarantee for initialization.
2010: David Sculley introduces Mini-batch K-means, enabling K-means on web-scale datasets.
The key insight that makes K-means special is not the algorithm itself - the two-step loop is obvious once you frame it as coordinate descent. The insight is the choice of objective: within-cluster sum of squared distances. This specific loss function makes both steps analytically solvable, creating the fast closed-form update that powers the algorithm.
The Objective: Within-Cluster Sum of Squares
K-means minimizes the within-cluster sum of squared distances (WCSS), also called inertia:

$$J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2$$

where $\mu_k$ is the centroid of cluster $k$ and $C_k$ is the set of points assigned to it.
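To make the WCSS objective concrete, here is a minimal NumPy sketch. The function name `wcss` is illustrative; it assumes `labels` holds integer cluster indices into `centroids`.

```python
import numpy as np


def wcss(X: np.ndarray, labels: np.ndarray, centroids: np.ndarray) -> float:
    """Within-cluster sum of squares: sum over i of ||x_i - mu_{c_i}||^2."""
    diffs = X - centroids[labels]  # (n, d): each point minus its own centroid
    return float(np.sum(diffs ** 2))


# Two tight pairs of points; each cluster's centroid is the mean of its pair
X = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
labels = np.array([0, 0, 1, 1])
centroids = np.array([X[labels == k].mean(axis=0) for k in range(2)])
print(wcss(X, labels, centroids))  # each point is 1 unit from its centroid -> 4.0
```

This is the same quantity scikit-learn reports as `inertia_` for a fitted `KMeans` model.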
This objective has a clear geometric interpretation: we want each point to be as close as possible to the center of its cluster. Minimizing this is equivalent to minimizing the average squared distance from any point to its cluster mean - which captures our intuitive notion of "tight, compact clusters."
The Voronoi interpretation: once centroids are fixed, the optimal assignment partitions the space into Voronoi cells - each cell contains the points closer to its centroid than to any other centroid. The boundary between two neighboring clusters lies on the perpendicular bisector hyperplane of their centroids, so every cell is an intersection of half-spaces. This is why K-means always produces convex, linearly bounded cluster regions - a critical limitation for non-spherical data.
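The half-space claim follows from expanding the squared distances: $\lVert x - a \rVert^2 < \lVert x - b \rVert^2$ simplifies to $2(b - a)^\top x < \lVert b \rVert^2 - \lVert a \rVert^2$ because the $\lVert x \rVert^2$ terms cancel. The sketch below checks numerically that the direct distance test and the linear inequality agree. Function names are illustrative.

```python
import numpy as np


def closer_to_a(x, a, b):
    """Direct test: is x strictly closer to centroid a than to centroid b?"""
    return np.sum((x - a) ** 2, axis=-1) < np.sum((x - b) ** 2, axis=-1)


def closer_to_a_linear(x, a, b):
    """The same test as a linear inequality w.x < c (a half-space)."""
    w = 2.0 * (b - a)
    c = np.dot(b, b) - np.dot(a, a)
    return x @ w < c


rng = np.random.default_rng(0)
a, b = rng.normal(size=3), rng.normal(size=3)  # two centroids in 3-D
pts = rng.normal(size=(1_000, 3))
agree = np.mean(closer_to_a(pts, a, b) == closer_to_a_linear(pts, a, b))
print(agree)  # fraction of points where both tests give the same answer
```

A Voronoi cell for one centroid is the intersection of $K - 1$ such half-spaces, one per competing centroid, which is why it is always convex.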
Note that the optimal WCSS never increases as $K$ grows: more clusters let each point sit closer to a centroid, and at $K = n$ (with distinct points) WCSS reaches zero. This is why you cannot pick $K$ by minimizing WCSS alone; you need a separate selection criterion that accounts for model complexity.
Lloyd's Algorithm: The Core Loop
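Step 1 - Assignment: Assign each point to its nearest centroid: $c_i = \arg\min_{k} \lVert x_i - \mu_k \rVert^2$.
Step 2 - Update: Recompute each centroid as the mean of its assigned points: $\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i$.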
Repeat until convergence (no assignments change, or the centroid shift drops below a tolerance).
Convergence Proof: Monotone Decrease of the Objective
Why does this always converge? Both steps provably decrease (or maintain) the WCSS objective.
Assignment step: Given fixed centroids $\mu_1, \dots, \mu_K$, assigning each point to its nearest centroid minimizes $J$ exactly. If a point were assigned to a non-nearest centroid, moving it to the nearest one would decrease its squared distance contribution. So the assignment step produces the globally optimal assignment for the current centroids - it cannot increase $J$.
Update step: Given fixed assignments $C_1, \dots, C_K$, the mean $\mu_k = \frac{1}{|C_k|} \sum_{x_i \in C_k} x_i$ minimizes the sum of squared distances within each cluster. This is a classical result: for a set of points $x_1, \dots, x_m$, the minimizer of $\sum_{j=1}^{m} \lVert x_j - \mu \rVert^2$ over $\mu$ is their arithmetic mean $\bar{x} = \frac{1}{m} \sum_{j=1}^{m} x_j$ (the objective is convex in $\mu$, and setting its gradient $-2 \sum_{j} (x_j - \mu)$ to zero gives $\mu = \bar{x}$). So the update step produces the globally optimal centroids for the current assignments - it cannot increase $J$.
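The mean's optimality can also be seen through the decomposition $\sum_{j} \lVert x_j - \mu \rVert^2 = \sum_{j} \lVert x_j - \bar{x} \rVert^2 + m \lVert \mu - \bar{x} \rVert^2$: any other center pays an extra, non-negative penalty. A quick numerical check of that identity, with illustrative names:

```python
import numpy as np


def sse(points: np.ndarray, mu: np.ndarray) -> float:
    """Sum of squared distances from each point to a candidate center mu."""
    return float(np.sum((points - mu) ** 2))


rng = np.random.default_rng(1)
pts = rng.normal(size=(50, 4))  # m = 50 points in 4-D
mean = pts.mean(axis=0)

# Any other candidate center pays an extra m * ||mu - mean||^2
for _ in range(3):
    mu = mean + rng.normal(scale=0.5, size=4)
    lhs = sse(pts, mu)
    rhs = sse(pts, mean) + len(pts) * float(np.sum((mu - mean) ** 2))
    print(f"{lhs:.6f}  {rhs:.6f}  mean is better: {sse(pts, mean) < lhs}")
```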
Since $J$ is bounded below by zero and never increases, the sequence of $J$ values must converge. There are finitely many possible assignments of $n$ points to $K$ clusters ($K^n$ possibilities), and with a consistent tie-breaking rule $J$ strictly decreases whenever an assignment changes, so no assignment can repeat and the algorithm terminates in a finite number of steps.
The catch: convergence is guaranteed, but only to a local minimum, not the global one. The final solution depends on initialization. Different starting centroids can lead to very different local optima. This is why initialization strategy and multiple restarts matter so much in practice.
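A four-point toy example makes the local-optimum problem concrete. The sketch below runs plain Lloyd iterations (a hypothetical helper `lloyd`, not a library function) on the corners of a 4 x 1 rectangle from two different starting pairs. Both runs reach a fixed point, but one has 16 times the WCSS of the other.

```python
import numpy as np


def lloyd(X: np.ndarray, init: np.ndarray, n_iter: int = 20):
    """Plain Lloyd iterations from given starting centroids.
    Assumes no cluster ever becomes empty (true for this toy example)."""
    centroids = init.astype(float).copy()
    for _ in range(n_iter):
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # (n, K)
        labels = d2.argmin(axis=1)
        centroids = np.array([X[labels == k].mean(axis=0)
                              for k in range(len(centroids))])
    # Final assignment and WCSS for the converged centroids
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = d2.argmin(axis=1)
    return centroids, labels, float(d2[np.arange(len(X)), labels].sum())


# Corners of a 4 x 1 rectangle; the best 2-cluster split is left vs. right
X = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.0]])

_, _, bad = lloyd(X, X[[0, 1]])   # both centroids start on the left edge
_, _, good = lloyd(X, X[[0, 2]])  # one centroid starts on each side
print(bad, good)                  # 16.0 1.0
```

The "bad" run splits the rectangle into top and bottom halves; no single assignment or update step can improve it, so Lloyd's algorithm stays there.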
Full NumPy Implementation from Scratch
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs


def kmeans(X: np.ndarray, K: int, max_iter: int = 300,
           tol: float = 1e-4, random_state: int = 42):
    """
    K-means clustering using Lloyd's algorithm.

    Parameters
    ----------
    X : (n, d) array of data points
    K : number of clusters
    max_iter : maximum number of iterations
    tol : convergence tolerance on centroid shift (L2 norm)

    Returns
    -------
    centroids : (K, d) array
    labels : (n,) integer cluster assignments
    inertia_history : list of WCSS at each iteration
    """
    rng = np.random.default_rng(random_state)
    n, d = X.shape

    # Random initialization - naive, will improve with K-means++
    idx = rng.choice(n, size=K, replace=False)
    centroids = X[idx].copy()  # (K, d)

    inertia_history = []
    labels = np.zeros(n, dtype=int)

    for iteration in range(max_iter):
        # --- Assignment step ---
        # Broadcast: (n, 1, d) - (1, K, d) = (n, K, d) -> sum -> (n, K)
        dists = np.sum(
            (X[:, None, :] - centroids[None, :, :]) ** 2,
            axis=2
        )
        new_labels = np.argmin(dists, axis=1)  # (n,)

        # Inertia: sum of squared distances to assigned centroids
        inertia = dists[np.arange(n), new_labels].sum()
        inertia_history.append(inertia)

        # --- Update step ---
        new_centroids = np.zeros_like(centroids)
        for k in range(K):
            mask = new_labels == k
            if mask.sum() > 0:
                new_centroids[k] = X[mask].mean(axis=0)
            else:
                # Empty cluster: reinitialize to a random point
                # This avoids the pathological case where K > number of modes
                new_centroids[k] = X[rng.integers(n)]

        # Check convergence: how much did centroids move?
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        labels = new_labels

        if shift < tol:
            print(f"Converged at iteration {iteration + 1} (shift={shift:.6f})")
            break

    return centroids, labels, inertia_history


# Generate synthetic data with 4 blobs
X, y_true = make_blobs(n_samples=500, centers=4,
                       cluster_std=0.8, random_state=42)
centroids, labels, history = kmeans(X, K=4)

# Visualize results
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].scatter(X[:, 0], X[:, 1], c=labels, cmap='tab10',
                alpha=0.6, s=20)
axes[0].scatter(centroids[:, 0], centroids[:, 1],
                c='red', marker='X', s=200, zorder=5, label='Centroids')
axes[0].set_title("K-Means Clustering Result (K=4)")
axes[0].legend()

axes[1].plot(history, marker='o', ms=4, color='steelblue')
axes[1].set_xlabel("Iteration")
axes[1].set_ylabel("Inertia (WCSS)")
axes[1].set_title("Convergence Curve - Monotone Decrease")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final inertia: {history[-1]:.2f}")
print(f"Cluster sizes: {np.bincount(labels)}")
```
K-Means++ Initialization: D² Sampling
Naive random initialization picks $K$ random points from the dataset as starting centroids. This often places multiple centroids in the same dense region while leaving other modes uncovered, leading to poor local optima and slow convergence. The pathological case: all centroids land inside the same cluster, and the algorithm needs many more iterations (or multiple restarts) to reach a reasonable solution.
K-means++, introduced by Arthur and Vassilvitskii (2007), uses a probabilistic initialization that deliberately spreads centroids apart. The core idea is D² sampling - choosing each new centroid with probability proportional to the squared distance from the nearest already-chosen centroid.
The D² Sampling Algorithm
- Choose the first centroid uniformly at random from the data: $\mu_1 = x_j$ where $j \sim \text{Uniform}\{1, \dots, n\}$.
- For each subsequent centroid $\mu_k$ (where $k = 2, \dots, K$):
  - Compute $D(x_i)^2 = \min_{j < k} \lVert x_i - \mu_j \rVert^2$ for each data point - the squared distance to the nearest already-chosen centroid.
  - Sample the next centroid $\mu_k = x_i$ with probability $P(x_i) = \dfrac{D(x_i)^2}{\sum_{l=1}^{n} D(x_l)^2}$.
- Repeat until $K$ centroids are chosen, then run Lloyd's algorithm normally.

The intuition: points far from existing centroids are more likely to be chosen as the next centroid. A point in an already-covered dense region has small $D(x)^2$ and is unlikely to be selected again. A point in an unexplored region has large $D(x)^2$ and is likely to be selected. This tends to spread the initial centroids across the data.
Theoretical Guarantee
K-means++ guarantees an expected cost within a factor of $O(\log K)$ of the optimal WCSS before any Lloyd iterations. Formally:

$$\mathbb{E}[J] \le 8(\ln K + 2)\, J^*$$

where $J^*$ is the globally optimal WCSS. This is a multiplicative approximation guarantee in expectation. Random initialization has no such guarantee and can be arbitrarily bad - consider $K$ well-separated clusters where random initialization places all centroids in one cluster. K-means++ rarely makes that mistake, because once one cluster is covered, points in the uncovered clusters carry almost all of the sampling weight.
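Plugging numbers into the bound shows how loose it is as a worst-case statement. The multiplier grows only logarithmically in $K$, but it is still a factor of 20 to 50 for typical $K$. In practice K-means++ usually lands much closer to a good solution, and Lloyd iterations plus restarts close the remaining gap. The helper name below is illustrative.

```python
import math


def kmeanspp_bound_factor(k: int) -> float:
    """Multiplier in the K-means++ guarantee E[J] <= 8 (ln K + 2) J*."""
    return 8.0 * (math.log(k) + 2.0)


for k in (2, 4, 10, 100):
    print(f"K={k:3d}: E[J] <= {kmeanspp_bound_factor(k):.1f} x J*")
```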
In practice, K-means++ reduces the number of Lloyd iterations needed, reduces the frequency of poor local optima, and gives lower variance across runs.
```python
def kmeans_plus_plus_init(X: np.ndarray, K: int,
                          rng: np.random.Generator) -> np.ndarray:
    """
    K-means++ centroid initialization (D² sampling).

    Each centroid is chosen with probability proportional to its
    squared distance from the nearest existing centroid.

    Parameters
    ----------
    X : (n, d) data array
    K : number of centroids to initialize
    rng : numpy random Generator

    Returns
    -------
    centroids : (K, d) initial centroid positions
    """
    n, d = X.shape
    centroids = np.empty((K, d))

    # Step 1: Choose first centroid uniformly at random
    first_idx = rng.integers(n)
    centroids[0] = X[first_idx]

    # D(x)² for each point: squared distance to its nearest chosen centroid.
    # Updated incrementally, so each round costs O(n*d) instead of O(n*m*d).
    min_sq_dists = np.sum((X - centroids[0]) ** 2, axis=1)  # (n,)

    for m in range(1, K):
        total = min_sq_dists.sum()
        if total > 0:
            # Normalize to get probabilities: P[x selected] ∝ D(x)²
            probs = min_sq_dists / total
        else:
            # Every point coincides with a chosen centroid (fewer than K
            # distinct points): fall back to uniform sampling
            probs = np.full(n, 1.0 / n)
        next_idx = rng.choice(n, p=probs)
        centroids[m] = X[next_idx]

        # Only the newest centroid can shrink a point's D(x)²
        new_sq_dists = np.sum((X - centroids[m]) ** 2, axis=1)
        min_sq_dists = np.minimum(min_sq_dists, new_sq_dists)

    return centroids  # (K, d)


def kmeans_pp(X: np.ndarray, K: int, max_iter: int = 300,
              tol: float = 1e-4, random_state: int = 42):
    """K-means with K-means++ initialization."""
    rng = np.random.default_rng(random_state)
    n, d = X.shape

    # K-means++ initialization - O(n*K*d) extra work, usually repaid
    # by fewer Lloyd iterations and better local optima
    centroids = kmeans_plus_plus_init(X, K, rng)
    labels = np.zeros(n, dtype=int)

    for iteration in range(max_iter):
        # Assignment
        dists = np.sum(
            (X[:, None, :] - centroids[None, :, :]) ** 2, axis=2
        )
        new_labels = np.argmin(dists, axis=1)

        # Update
        new_centroids = np.zeros_like(centroids)
        for k in range(K):
            mask = new_labels == k
            if mask.sum() > 0:
                new_centroids[k] = X[mask].mean(axis=0)
            else:
                new_centroids[k] = X[rng.integers(n)]

        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        labels = new_labels
        if shift < tol:
            break

    inertia = np.sum((X - centroids[labels]) ** 2)
    return centroids, labels, inertia


# Compare naive vs K-means++ initialization over 20 random seeds
inertias_naive = []
inertias_pp = []
for seed in range(20):
    _, _, inertia_hist = kmeans(X, K=4, random_state=seed)
    inertias_naive.append(inertia_hist[-1])
    _, _, inertia = kmeans_pp(X, K=4, random_state=seed)
    inertias_pp.append(inertia)

print(f"Naive init:     mean inertia {np.mean(inertias_naive):.1f}, "
      f"std {np.std(inertias_naive):.1f}, "
      f"worst {np.max(inertias_naive):.1f}")
print(f"K-means++ init: mean inertia {np.mean(inertias_pp):.1f}, "
      f"std {np.std(inertias_pp):.1f}, "
      f"worst {np.max(inertias_pp):.1f}")
# Expect K-means++ to show a smaller spread and a better worst case
```
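In practice, use sklearn.cluster.KMeans, whose default is init='k-means++'. The n_init parameter controls how many independent initializations are run; the result with the lowest inertia is kept. Since scikit-learn 1.4 the default is n_init='auto', which runs a single initialization when init='k-means++', so set n_init explicitly if you want multiple restarts.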
```python
from sklearn.cluster import KMeans
import numpy as np

kmeans_sklearn = KMeans(
    n_clusters=4,
    init='k-means++',  # default - D² sampling
    n_init=10,         # 10 independent restarts; keep best by inertia
    max_iter=300,
    tol=1e-4,
    random_state=42
)
kmeans_sklearn.fit(X)

print(f"Inertia: {kmeans_sklearn.inertia_:.2f}")
print(f"Iterations to converge: {kmeans_sklearn.n_iter_}")
print(f"Cluster sizes: {np.bincount(kmeans_sklearn.labels_)}")
print(f"Centroids shape: {kmeans_sklearn.cluster_centers_.shape}")
```
Choosing K: Three Methods
The central challenge in K-means is choosing $K$. With too few clusters, you miss structure; with too many, you overfit noise and create clusters that correspond to nothing meaningful in the real world. There is no universally correct answer - the right $K$ depends on both the data geometry and the downstream use case.
Method 1: Elbow Method
Plot inertia (WCSS) vs $K$. As $K$ increases, inertia typically decreases - adding more clusters lets each point be closer to its centroid. But the rate of decrease usually slows once $K$ passes the "true" number of clusters. Look for the bend (elbow) in the curve where the marginal gain from adding one more cluster drops sharply.
```python
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np

inertias = []
K_range = range(2, 15)
for k in K_range:
    km = KMeans(n_clusters=k, init='k-means++', n_init=10, random_state=42)
    km.fit(X)
    inertias.append(km.inertia_)

# Find elbow automatically using second derivative (largest curvature point)
inertias_arr = np.array(inertias)
second_diff = np.diff(np.diff(inertias_arr))
elbow_k = list(K_range)[np.argmax(second_diff) + 1]

plt.figure(figsize=(9, 5))
plt.plot(K_range, inertias, 'bo-', ms=6, lw=2)
plt.axvline(x=elbow_k, color='red', linestyle='--',
            label=f'Elbow at K={elbow_k}')
plt.xlabel("Number of Clusters K")
plt.ylabel("Inertia (WCSS)")
plt.title("Elbow Method for Optimal K")
plt.xticks(K_range)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Suggested K by elbow: {elbow_k}")
```
:::warning Elbow Ambiguity The elbow method often produces ambiguous results on real data. Plots with smooth, gradual curves have no clear elbow - the curve looks like an arm with no obvious bend. Use it as a coarse filter (narrow the range to 2–3 candidate values) and combine with the silhouette score for a final decision. Never rely on the elbow method alone in production. :::
Method 2: Silhouette Score
The silhouette score measures how similar a point is to its own cluster versus other clusters. For each point $i$:

$$s(i) = \frac{b(i) - a(i)}{\max\{a(i),\, b(i)\}}$$

- $a(i)$: mean distance from point $i$ to all other points in the same cluster (within-cluster cohesion)
- $b(i)$: mean distance from point $i$ to all points in the nearest other cluster (separation from others)

Interpretation:
- $s(i) \approx 1$: point is well inside its cluster, far from others - correctly clustered
- $s(i) \approx 0$: point is near the boundary between clusters - ambiguous assignment
- $s(i) < 0$: point is closer to another cluster than its own - likely misassigned

The average silhouette score across all points measures overall clustering quality. Higher is better, and as a rule of thumb, scores above 0.5 indicate reasonable structure. Exact computation needs all pairwise distances ($O(n^2)$), so on large datasets pass sample_size to silhouette_score to estimate it on a random subset.
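The definition translates directly into code. This from-scratch sketch (the function name is illustrative) computes $a(i)$, $b(i)$, and $s(i)$ with a full pairwise-distance matrix and checks the mean against scikit-learn's `silhouette_score`. Points in singleton clusters get $s(i) = 0$, which follows the convention scikit-learn documents.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score


def silhouette_from_scratch(X: np.ndarray, labels: np.ndarray) -> float:
    """Mean of s(i) = (b - a) / max(a, b) using Euclidean distances."""
    D = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2))  # (n, n)
    clusters = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        if own.sum() == 1:
            continue  # singleton cluster: leave s(i) = 0
        a = D[i, own].sum() / (own.sum() - 1)  # excludes the zero self-distance
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return float(s.mean())


X, _ = make_blobs(n_samples=200, centers=3, random_state=0)
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)
print(silhouette_from_scratch(X, labels), silhouette_score(X, labels))
```

The $(n, n)$ distance matrix is exactly the quadratic cost mentioned above, which is why subsampling matters at scale.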
```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, silhouette_samples
import matplotlib.pyplot as plt
import numpy as np

silhouette_scores = []
K_range = range(2, 12)
for k in K_range:
    km = KMeans(n_clusters=k, init='k-means++', n_init=10, random_state=42)
    labels = km.fit_predict(X)
    score = silhouette_score(X, labels)
    silhouette_scores.append(score)
    print(f"K={k:2d} silhouette={score:.4f}")

best_k = list(K_range)[np.argmax(silhouette_scores)]
print(f"\nBest K by silhouette score: {best_k}")

# Silhouette plot for the best K - shows per-point scores grouped by cluster
km_best = KMeans(n_clusters=best_k, n_init=10, random_state=42)
labels_best = km_best.fit_predict(X)
sample_silhouettes = silhouette_samples(X, labels_best)

fig, ax = plt.subplots(figsize=(8, 5))
y_lower = 10
for k in range(best_k):
    cluster_sils = np.sort(sample_silhouettes[labels_best == k])
    size = cluster_sils.shape[0]
    y_upper = y_lower + size
    ax.fill_betweenx(np.arange(y_lower, y_upper),
                     0, cluster_sils, alpha=0.7)
    ax.text(-0.05, y_lower + 0.5 * size, str(k))
    y_lower = y_upper + 10  # vertical gap between clusters

avg_sil = silhouette_score(X, labels_best)
ax.axvline(x=avg_sil, color='red', linestyle='--',
           label=f'Mean silhouette = {avg_sil:.3f}')
ax.set_xlabel("Silhouette Coefficient")
ax.set_ylabel("Cluster")
ax.set_title(f"Silhouette Plot - K={best_k}")
ax.legend()
plt.tight_layout()
plt.show()
```
The silhouette plot reveals which individual clusters are well-formed (wide, positive bars) and which are problematic (narrow, or bars crossing the zero line - misassigned points). This per-cluster view is more informative than the average score alone.
Method 3: BIC with Gaussian Mixture Models
BIC (Bayesian Information Criterion) penalizes model complexity, so it naturally resists adding unnecessary clusters. BIC needs a likelihood, which plain K-means does not define, so in practice you fit a Gaussian Mixture Model for each candidate $K$ and compare their BIC values:

$$\text{BIC} = -2 \ln \hat{L} + p \ln n$$

where $\hat{L}$ is the maximized model likelihood, $p$ is the number of free parameters, and $n$ is the number of samples. Lower BIC is better. It balances goodness of fit against model complexity - adding more clusters increases $\hat{L}$ but also increases $p$, so BIC only rewards clusters that meaningfully improve the likelihood.
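To see where the penalty term comes from, the sketch below counts the free parameters of a full-covariance GMM ($K - 1$ mixture weights, $Kd$ mean entries, $Kd(d+1)/2$ covariance entries) and rebuilds BIC by hand from `GaussianMixture.score`, which returns the mean log-likelihood per sample. The helper name `n_params_full_gmm` is illustrative; the comparison against `gmm.bic` is the check.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture


def n_params_full_gmm(k: int, d: int) -> int:
    """Free parameters of a K-component, full-covariance GMM in d dimensions."""
    weights = k - 1                 # mixture weights sum to 1
    means = k * d
    covariances = k * d * (d + 1) // 2  # symmetric d x d matrices
    return weights + means + covariances


X, _ = make_blobs(n_samples=500, centers=3, random_state=42)
n, d = X.shape
for k in (2, 3, 4):
    gmm = GaussianMixture(n_components=k, covariance_type='full',
                          random_state=42).fit(X)
    log_lik = gmm.score(X) * n  # total log-likelihood
    bic_manual = -2.0 * log_lik + n_params_full_gmm(k, d) * np.log(n)
    print(k, round(bic_manual, 2), round(gmm.bic(X), 2))
```

Note how quickly the covariance term grows with $d$: with 40 features, each extra full-covariance component adds 860 covariance parameters, which is one reason high-dimensional GMMs often use diagonal covariances.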
```python
from sklearn.mixture import GaussianMixture
import numpy as np

bic_scores = []
aic_scores = []
K_range = range(2, 15)
for k in K_range:
    gmm = GaussianMixture(
        n_components=k,
        covariance_type='full',
        n_init=3,
        random_state=42
    )
    gmm.fit(X)
    bic_scores.append(gmm.bic(X))
    aic_scores.append(gmm.aic(X))

best_k_bic = list(K_range)[np.argmin(bic_scores)]
best_k_aic = list(K_range)[np.argmin(aic_scores)]
print(f"Best K by BIC: {best_k_bic}")
print(f"Best K by AIC: {best_k_aic}")
```
Rule of thumb: Use silhouette score as your primary metric for K selection. Use BIC when you want a more principled, probabilistic justification and can afford the GMM fitting cost. Use the elbow method only as a sanity check. In production, combine all three: if they agree on K, you have strong evidence; if they disagree, examine the cluster profiles at each candidate K and choose based on business interpretability.
Evaluation Metrics Beyond Silhouette
Davies-Bouldin Index
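The Davies-Bouldin Index (DBI) measures the ratio of within-cluster scatter to between-cluster separation. For each cluster $i$, take the worst-case ratio against any other cluster, then average over clusters:

$$\text{DB} = \frac{1}{K} \sum_{i=1}^{K} \max_{j \neq i} \frac{s_i + s_j}{d_{ij}}$$

where $s_i$ is the average distance of points in cluster $i$ to their centroid, and $d_{ij}$ is the distance between centroids $i$ and $j$. Lower DBI is better - you want tight clusters ($s_i$ small) that are far from each other ($d_{ij}$ large).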
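A from-scratch sketch of the Davies-Bouldin Index, using Euclidean distances for both $s_i$ and $d_{ij}$, checked against scikit-learn's `davies_bouldin_score`. The function name is illustrative, and the sketch assumes all centroids are distinct.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import davies_bouldin_score


def davies_bouldin_from_scratch(X: np.ndarray, labels: np.ndarray) -> float:
    """DB = (1/K) * sum_i max_{j != i} (s_i + s_j) / d_ij."""
    ks = np.unique(labels)
    centroids = np.array([X[labels == k].mean(axis=0) for k in ks])
    # s_i: mean distance of cluster i's points to its own centroid
    s = np.array([np.linalg.norm(X[labels == k] - c, axis=1).mean()
                  for k, c in zip(ks, centroids)])
    # d_ij: pairwise distances between centroids, shape (K, K)
    d = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=2)
    np.fill_diagonal(d, np.inf)  # makes the i == j ratio 0, so it is never the max
    ratios = (s[:, None] + s[None, :]) / d
    return float(ratios.max(axis=1).mean())


X, _ = make_blobs(n_samples=400, centers=4, cluster_std=0.8, random_state=42)
labels = KMeans(n_clusters=4, n_init=10, random_state=42).fit_predict(X)
print(davies_bouldin_from_scratch(X, labels), davies_bouldin_score(X, labels))
```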
Calinski-Harabasz Index
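Also called the Variance Ratio Criterion:

$$\text{CH} = \frac{B / (K - 1)}{W / (n - K)}, \qquad B = \sum_{k=1}^{K} |C_k| \, \lVert \mu_k - \bar{x} \rVert^2, \qquad W = \sum_{k=1}^{K} \sum_{x_i \in C_k} \lVert x_i - \mu_k \rVert^2$$

where $\bar{x}$ is the overall mean of the data, $B$ is the between-cluster dispersion, and $W$ is the within-cluster dispersion (the WCSS).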
Higher CH is better. It is an F-ratio comparing the variance between clusters to the variance within clusters - the same logic as ANOVA. It rewards compact, well-separated clusters and penalizes many small clusters that are close together.
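The Calinski-Harabasz computation is short enough to write out directly. This sketch (illustrative function name) accumulates the between-cluster dispersion $B$ and the within-cluster dispersion $W$, then forms the F-style ratio and compares it with scikit-learn's `calinski_harabasz_score`.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import calinski_harabasz_score


def calinski_harabasz_from_scratch(X: np.ndarray, labels: np.ndarray) -> float:
    n = len(X)
    ks = np.unique(labels)
    K = len(ks)
    overall_mean = X.mean(axis=0)
    between = 0.0  # B: size-weighted squared distance of centroids to the overall mean
    within = 0.0   # W: within-cluster sum of squares (WCSS)
    for k in ks:
        members = X[labels == k]
        mu_k = members.mean(axis=0)
        between += len(members) * np.sum((mu_k - overall_mean) ** 2)
        within += np.sum((members - mu_k) ** 2)
    return float((between / (K - 1)) / (within / (n - K)))


X, _ = make_blobs(n_samples=400, centers=4, cluster_std=0.8, random_state=42)
labels = KMeans(n_clusters=4, n_init=10, random_state=42).fit_predict(X)
print(calinski_harabasz_from_scratch(X, labels), calinski_harabasz_score(X, labels))
```

Because $W$ is exactly the K-means objective, CH rewards the same compactness K-means optimizes, scaled against how far apart the centroids sit.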
```python
from sklearn.cluster import KMeans
from sklearn.metrics import (davies_bouldin_score,
                             calinski_harabasz_score,
                             silhouette_score)
import numpy as np

km_final = KMeans(n_clusters=4, init='k-means++',
                  n_init=20, random_state=42)
labels = km_final.fit_predict(X)

sil = silhouette_score(X, labels)
db = davies_bouldin_score(X, labels)     # Lower is better
ch = calinski_harabasz_score(X, labels)  # Higher is better

print(f"Silhouette Score:        {sil:.4f}  (higher = better, max 1.0)")
print(f"Davies-Bouldin Index:    {db:.4f}  (lower = better, min 0)")
print(f"Calinski-Harabasz Index: {ch:.2f}  (higher = better)")

# Full cluster summary with per-cluster statistics
for k in range(4):
    mask = labels == k
    print(f"\nCluster {k}: {mask.sum()} points")
    print(f"  Centroid: {km_final.cluster_centers_[k].round(2)}")
    cluster_dists = np.linalg.norm(
        X[mask] - km_final.cluster_centers_[k], axis=1
    )
    print(f"  Mean dist to centroid: {cluster_dists.mean():.3f}")
    print(f"  Max dist to centroid:  {cluster_dists.max():.3f}")
    print(f"  Radius (95th pct):     {np.percentile(cluster_dists, 95):.3f}")
```
:::tip Combining Metrics No single metric is universally best. Use all three in practice: silhouette rewards well-separated clusters, DBI rewards compact clusters relative to their separation, and CH rewards between-cluster variance. If all three agree on a K value, it is a strong signal. Report all three metrics in any clustering analysis - reviewers and stakeholders will ask about them. :::
K-Means Limitations
Limitation 1: Assumes Spherical, Equal-Sized Clusters
K-means partitions space into Voronoi cells - convex regions separated by linear (hyperplane) boundaries. This implicitly assumes clusters are spherical and roughly equal in size. It fails catastrophically on elongated, crescent-shaped, or concentric ring clusters.
```python
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_moons

# K-means fails on moons: each K-means cluster typically picks up
# the tip of the other crescent
X_moons, y_moons = make_moons(n_samples=300, noise=0.05, random_state=42)
km_moons = KMeans(n_clusters=2, n_init=10, random_state=42)
labels_moons = km_moons.fit_predict(X_moons)
# The result: K-means cuts through the moons with a straight line,
# because that is the Voronoi boundary between the two centroids

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].scatter(X_moons[:, 0], X_moons[:, 1],
                c=y_moons, cmap='bwr', s=20, alpha=0.7)
axes[0].set_title("True Labels (two crescents)")
axes[1].scatter(X_moons[:, 0], X_moons[:, 1],
                c=labels_moons, cmap='bwr', s=20, alpha=0.7)
axes[1].scatter(km_moons.cluster_centers_[:, 0],
                km_moons.cluster_centers_[:, 1],
                c='red', marker='X', s=200, zorder=5)
axes[1].set_title("K-Means Result - fails with linear boundary")
plt.tight_layout()
plt.show()
```
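For non-spherical clusters, use DBSCAN (Lesson 03) or spectral clustering, which embeds the data using eigenvectors of a graph Laplacian built from pairwise similarities. In that embedding, shapes such as the two moons often become compact and well separated, and K-means is then run on the embedded points.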
Limitation 2: Sensitive to Scale
K-means uses Euclidean distance. Features on different scales dominate the distance computation - a feature measured in thousands (e.g., income) will dominate over a feature measured in units (e.g., number of children). Always standardize before clustering:
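```python
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.cluster import KMeans

pipeline = Pipeline([
    ('scaler', StandardScaler()),  # mean=0, std=1 for each feature
    ('kmeans', KMeans(n_clusters=4,
                      init='k-means++',
                      n_init=10,
                      random_state=42))
])
pipeline.fit(X)
labels = pipeline.named_steps['kmeans'].labels_
# cluster_centers_ are in standardized units; map them back with
# pipeline.named_steps['scaler'].inverse_transform(...) to read them in original units
```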
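To see the scale effect numerically, the sketch below uses four hypothetical customers described by annual income (dollars) and number of children. It reports what fraction of the squared Euclidean distance between the first two customers comes from each feature, before and after `StandardScaler`. The data and helper name are made up for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical customers: [annual income in dollars, number of children]
X_raw = np.array([[42_000, 0], [45_000, 4], [88_000, 1], [91_000, 3]], dtype=float)


def per_feature_share(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Fraction of the squared Euclidean distance contributed by each feature."""
    sq = (a - b) ** 2
    return sq / sq.sum()


print(per_feature_share(X_raw[0], X_raw[1]).round(6))  # income dominates
X_std = StandardScaler().fit_transform(X_raw)
print(per_feature_share(X_std[0], X_std[1]).round(3))  # children now matter
```

Before scaling, a $3,000 income gap swamps a four-child difference; after scaling, the children feature accounts for almost all of the distance between these two customers.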
Limitation 3: Sensitive to Outliers
A single outlier can pull a centroid far from the cluster mass, distorting the entire cluster boundary. Solutions:
- Remove outliers beforehand with DBSCAN or Isolation Forest
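- Use K-medoids (uses actual data points as centroids and sums unsquared distances - more robust, but the classic PAM variant costs roughly quadratic time in $n$ per iteration)
- Use robust scaling (RobustScaler, based on median and IQR) before K-means, so that outliers do not distort the feature scaling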
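A one-dimensional sketch shows why medoids resist outliers. The medoid here is the data point with the smallest total (unsquared) distance to all other points in the cluster; the `medoid` helper is illustrative, not a library function. Its $(m, m)$ distance matrix is also where the quadratic cost comes from.

```python
import numpy as np


def medoid(points: np.ndarray) -> np.ndarray:
    """The data point minimizing the total Euclidean distance to all others."""
    D = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=2)  # (m, m)
    return points[D.sum(axis=1).argmin()]


cluster = np.array([[1.0], [2.0], [3.0], [4.0], [100.0]])  # one outlier at 100
print(cluster.mean(axis=0))  # [22.] - the mean is dragged toward the outlier
print(medoid(cluster))       # [3.] - the medoid stays with the bulk of the data
```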
Limitation 4: Converges to Local Optima
Multiple restarts (n_init=10 or more) reduce but do not eliminate the risk of poor local optima. For critical applications, run 20–50 restarts. With K-means++, 10 restarts is typically sufficient for most datasets.
Mini-Batch K-Means for Large Datasets
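Standard K-means requires computing distances from every point to every centroid at each iteration. At 10 million rows and $K$ clusters in 50-dimensional space, the assignment step alone is $O(nKd) = 5 \times 10^8 \cdot K$ operations per iteration, and every iteration is another full pass over the data. This quickly becomes expensive, especially when the data does not fit in memory.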
Mini-batch K-means, introduced by David Sculley (2010), updates centroids using small random batches at each step, drastically reducing memory usage and computation time while sacrificing only a small amount of quality.
The Mini-Batch Update Rule
At each step, sample a mini-batch $M$ of size $b$ from the data. Assign each point in $M$ to its nearest centroid, then update the centroids with a per-centroid running average:

$$N_k \leftarrow N_k + |M_k|, \qquad \mu_k \leftarrow \mu_k + \frac{1}{N_k} \sum_{x \in M_k} (x - \mu_k)$$

where $N_k$ is the count of points assigned to centroid $k$ so far (across all batches), and $M_k$ is the subset of the mini-batch assigned to cluster $k$. This is a stochastic gradient descent step with per-point learning rate $1/N_k$ that naturally decays as $N_k$ grows.
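The update rule fits in a few lines of NumPy. The sketch below (a hypothetical `minibatch_step` helper, not scikit-learn's internal code) applies one mini-batch update in place. The demo uses a single centroid so every point is assigned to it; in that case the per-centroid counts make the update an exact running mean, so after one pass the centroid equals the mean of all points seen. With $K > 1$, assignments shift between batches, which is what makes the procedure stochastic.

```python
import numpy as np


def minibatch_step(centroids: np.ndarray, counts: np.ndarray, batch: np.ndarray):
    """One mini-batch update (modifies centroids and counts in place):
    N_k += |M_k|, then mu_k += (1 / N_k) * sum_{x in M_k} (x - mu_k)."""
    d2 = ((batch[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # (b, K)
    labels = d2.argmin(axis=1)
    for k in np.unique(labels):
        members = batch[labels == k]
        counts[k] += len(members)
        centroids[k] += (members - centroids[k]).sum(axis=0) / counts[k]
    return centroids, counts


rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, size=(1_000, 2))
centroids = np.zeros((1, 2))  # K=1: every point belongs to centroid 0
counts = np.zeros(1)
for batch in np.array_split(data, 10):  # ten mini-batches of 100 points
    minibatch_step(centroids, counts, batch)
print(np.allclose(centroids[0], data.mean(axis=0)))  # True: exact running mean
```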
scikit-learn's MiniBatchKMeans adds early stopping through its max_no_improvement parameter: if a smoothed estimate of the mini-batch inertia does not improve for that many consecutive mini-batches, training stops. This prevents unnecessary computation when centroids have effectively converged.
```python
from sklearn.cluster import MiniBatchKMeans, KMeans
import time
import numpy as np

# Simulate large dataset: 1 million points, 20 dimensions
# (pure Gaussian noise - this demo measures speed, not cluster quality)
rng = np.random.default_rng(42)
X_large = rng.standard_normal((1_000_000, 20))

# Standard K-means touches all 1M rows on every iteration;
# Mini-batch K-means updates centroids from batches of 10K rows
t0 = time.time()
mbkm = MiniBatchKMeans(
    n_clusters=10,
    batch_size=10_000,        # points per mini-batch (here 1% of the data)
    n_init=5,                 # candidate initializations; only the best one is run
    max_iter=100,             # max passes (epochs) over the full dataset
    max_no_improvement=10,    # stop after 10 mini-batches without smoothed-inertia improvement
    reassignment_ratio=0.01,  # centers with counts < 1% of the max count may be reassigned
    random_state=42,
    verbose=0
)
mbkm.fit(X_large)
elapsed = time.time() - t0
print(f"Mini-batch K-means on 1M × 20: {elapsed:.2f}s")
print(f"Inertia: {mbkm.inertia_:.2f}")

# Standard K-means comparison (on a subset - full would be too slow)
t0 = time.time()
km_subset = KMeans(n_clusters=10, n_init=5, random_state=42)
km_subset.fit(X_large[:100_000])  # 100K subset
elapsed_subset = time.time() - t0
print(f"\nStandard K-means on 100K subset: {elapsed_subset:.2f}s")
# Compare the two timings: mini-batch saw 10x more rows; the exact
# ratio depends on hardware, data, and library version
```
Online Clustering with partial_fit()
For streaming data, Mini-batch K-means can be updated incrementally as new batches arrive. This enables online clustering where the model evolves as new data comes in:
```python
from sklearn.cluster import MiniBatchKMeans
import numpy as np

# Initialize model
mbkm_online = MiniBatchKMeans(
    n_clusters=6,
    batch_size=1_000,
    n_init=3,
    random_state=42
)

# Simulate streaming data - 100 batches of 1000 points each
for batch_idx in range(100):
    # In production: fetch from Kafka, Kinesis, or database
    batch = np.random.randn(1_000, 20)
    mbkm_online.partial_fit(batch)  # incremental update
    if (batch_idx + 1) % 10 == 0:
        # After partial_fit, inertia_ is computed on the most recent batch only
        print(f"Batch {batch_idx+1}: inertia = {mbkm_online.inertia_:.2f}")

# Score new points after training
X_new = np.random.randn(100, 20)
labels_new = mbkm_online.predict(X_new)
```
:::tip Mini-Batch in Production
For user segmentation systems that must reflect today's behavior, not last month's, Mini-batch K-means with partial_fit() enables live segment updating. Retrain the full model weekly with historical data; update centroids daily with new behavioral batches. The segments drift gradually rather than jumping discontinuously.
:::
Mini-batch K-means is often 3–10x faster than standard K-means, with inertia typically only slightly higher (on the order of 1–2%), though both figures depend on the data and batch size. At 10M+ rows it is often the most practical single-machine option. At 100M+ rows, consider distributed K-means implementations (Spark MLlib).
K-Means for Image Color Quantization
A beautiful application of K-means that demonstrates the algorithm's core idea concretely: reduce the number of colors in an image from millions to $K$ representative colors by clustering pixels in RGB space.
Each pixel is a point in 3D space (R, G, B). K-means finds $K$ cluster centers - the representative colors. Each pixel is replaced by the color of its nearest cluster center. The result is an image that uses only $K$ distinct colors, with the palette chosen (up to local optima) to minimize the total squared color error.
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.utils import shuffle

# Create a synthetic colorful image for demonstration
np.random.seed(42)
h, w = 100, 150
img = np.zeros((h, w, 3), dtype=np.float32)

# Add colored disks; the boolean mask marks the pixels inside each disk
yy, xx = np.mgrid[0:h, 0:w]
for _ in range(8):
    cx = np.random.randint(10, w - 10)
    cy = np.random.randint(10, h - 10)
    r = np.random.randint(15, 40)
    color = np.random.rand(3)
    img[(yy - cy) ** 2 + (xx - cx) ** 2 < r ** 2] = color
img = np.clip(img + 0.05 * np.random.rand(h, w, 3), 0, 1)

# Reshape to 2D: (n_pixels, 3) - each row is one pixel's RGB values
pixels = img.reshape(-1, 3)

# Subsample for fitting (fast approximation for large images)
pixels_sample = shuffle(pixels, n_samples=min(len(pixels), 5000),
                        random_state=42)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(img)
axes[0].set_title("Original")

# Quantize to K colors and measure compression
for ax, K in zip(axes[1:], [4, 16, 64]):
    km = KMeans(n_clusters=K, n_init=3, random_state=42)
    km.fit(pixels_sample)
    # Replace each pixel with its nearest centroid color
    quantized_pixels = km.cluster_centers_[km.predict(pixels)]
    quantized_img = quantized_pixels.reshape(img.shape)
    # Compression ratio: original uses 3 bytes per pixel; quantized uses
    # 1 byte per pixel (palette index, valid for K <= 256) + K * 3 bytes (palette)
    ratio = len(pixels) * 3 / (K * 3 + len(pixels))
    print(f"K={K:3d}: compression ratio ≈ {ratio:.1f}x")
    ax.imshow(quantized_img)
    ax.set_title(f"K={K}")

for ax in axes:
    ax.axis("off")
plt.tight_layout()
plt.show()
```
This is the perfect mental model for K-means: centroids are representative colors, the assignment step finds the nearest color for each pixel, and the update step computes the average RGB of each color cluster. The algorithm is identical to user segmentation - only the data domain changes.
K-Means as EM for Gaussian Mixtures
K-means is a degenerate special case of the EM algorithm for Gaussian Mixture Models. Understanding this connection illuminates why K-means makes the assumptions it does.
In GMM-EM:
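- E-step: compute soft cluster membership probabilities: each point $x_i$ has a fractional weight $\gamma_{ik}$ in each cluster $k$, where $\gamma_{ik} = \dfrac{\pi_k\, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \pi_j\, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}$
- M-step: update mixture weights $\pi_k$, means $\mu_k$, and full covariance matrices $\Sigma_k$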
K-means is what you get when you impose two restrictions:
- Hard assignments: replace soft weights with hard 0/1 indicators (each point goes 100% to its nearest cluster)
- Isotropic, equal covariances: constrain all $\Sigma_k = \sigma^2 I$, with the same $\sigma^2$ for every cluster
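Under these restrictions, the E-step becomes "assign each point to its nearest centroid". With equal isotropic covariances, the component under which a point has the highest likelihood is the one with the nearest mean. The M-step becomes "compute the mean of each cluster." The objective becomes the negative log-likelihood of the isotropic GMM. This equals WCSS up to the positive factor $1/(2\sigma^2)$ (where $\sigma^2$ is the shared variance) and an additive constant, so minimizing one minimizes the other.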
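The WCSS equivalence can be written out directly. Assume equal mixture weights $\pi_k = 1/K$, a shared covariance $\Sigma_k = \sigma^2 I$, and hard assignments $z_i \in \{1, \dots, K\}$. The complete-data negative log-likelihood is then:

$$
\begin{aligned}
-\log p(X, z \mid \mu)
&= -\sum_{i=1}^{n} \Big[\log \tfrac{1}{K} + \log \mathcal{N}(x_i \mid \mu_{z_i}, \sigma^2 I)\Big] \\
&= \frac{1}{2\sigma^2} \sum_{i=1}^{n} \|x_i - \mu_{z_i}\|^2 \;+\; \frac{nd}{2}\log(2\pi\sigma^2) \;+\; n\log K \\
&= \frac{1}{2\sigma^2}\,\mathrm{WCSS} + \text{const}.
\end{aligned}
$$

The factor $1/(2\sigma^2)$ is positive, and the constant depends only on $n$, $d$, $K$, and $\sigma^2$. Minimizing this likelihood over $z$ and $\mu$ is therefore the same problem as minimizing WCSS.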
This connection explains K-means's core assumption - that clusters are spherical with equal radii - and why Gaussian Mixture Models are the natural next step when you need elliptical clusters, soft assignments, or principled model selection via BIC.
Production Patterns
Pattern 1: Full Customer Segmentation Pipeline
```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.metrics import silhouette_score, davies_bouldin_score

# User behavioral features - in production: fetch from data warehouse.
# These synthetic columns are independent random draws with no planted
# segments, so expect modest silhouette scores on this toy data.
np.random.seed(42)
n_users = 5_000
user_features = pd.DataFrame({
    'days_since_last_purchase': np.random.exponential(30, n_users),
    'avg_order_value': np.random.lognormal(4, 1, n_users),
    'purchase_frequency': np.random.poisson(3, n_users),
    'browse_to_purchase_ratio': np.random.beta(2, 5, n_users),
    'days_active_last_30': np.random.randint(0, 31, n_users),  # 0..30 inclusive
    'support_tickets': np.random.poisson(0.5, n_users),
    'mobile_fraction': np.random.beta(3, 2, n_users),
    'category_diversity': np.random.exponential(2, n_users),
})

# Step 1: Select K using silhouette scores
print("Evaluating K values...")
X = user_features.values
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
results = []
for k in range(3, 10):
    km = KMeans(n_clusters=k, n_init=10, random_state=42)
    labels = km.fit_predict(X_scaled)
    # silhouette_score is O(n^2); on millions of rows pass sample_size=...
    sil = silhouette_score(X_scaled, labels)
    db = davies_bouldin_score(X_scaled, labels)  # lower is better
    results.append({'k': k, 'silhouette': sil, 'davies_bouldin': db,
                    'inertia': km.inertia_})
    print(f"  K={k}: silhouette={sil:.3f}, DBI={db:.3f}")
results_df = pd.DataFrame(results)
best_k = int(results_df.loc[results_df['silhouette'].idxmax(), 'k'])
print(f"\nSelected K={best_k} by silhouette score")

# Step 2: Fit final model with more restarts
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('kmeans', KMeans(n_clusters=best_k, n_init=30, random_state=42))
])
pipeline.fit(user_features)
user_features['segment'] = pipeline.named_steps['kmeans'].labels_

# Step 3: Interpret clusters in original feature units
scaler_fitted = pipeline.named_steps['scaler']
km_fitted = pipeline.named_steps['kmeans']
centroids_original = scaler_fitted.inverse_transform(km_fitted.cluster_centers_)
segment_profiles = pd.DataFrame(
    centroids_original,
    columns=user_features.columns[:-1]  # every column except 'segment'
)
segment_profiles['size'] = [
    (user_features['segment'] == k).sum()
    for k in range(best_k)
]
print("\nSegment profiles (original feature units):")
print(segment_profiles.round(2))

# Step 4: Business naming based on profiles
# High avg_order_value + high frequency → "High-value loyalists"
# High days_since_last_purchase + low frequency → "At-risk churners"
# High browse_to_purchase_ratio + low purchase_frequency → "Window shoppers"
```
Pattern 2: Fit on Train, Predict on New Data
```python
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import numpy as np

# CRITICAL: fit scaler and K-means ONLY on training data
X_train = np.random.randn(10_000, 20)
X_new = np.random.randn(100, 20)  # new users arriving today
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit_transform on train only
km = KMeans(n_clusters=6, n_init=20, random_state=42)
km.fit(X_train_scaled)

# Save model artifacts (in production: serialize with joblib)
# joblib.dump({'scaler': scaler, 'kmeans': km}, 'segmentation_model.pkl')

# Assign new users to existing segments
X_new_scaled = scaler.transform(X_new)  # transform (not fit_transform!)
new_labels = km.predict(X_new_scaled)  # nearest centroid assignment
distances = km.transform(X_new_scaled)  # distance to each centroid

# Confidence: how much closer is the assigned cluster vs the next nearest?
assigned_dist = distances[np.arange(len(X_new)), new_labels]
second_best = np.sort(distances, axis=1)[:, 1]
confidence = 1 - assigned_dist / second_best  # 0 = ambiguous, 1 = clear
print(f"New user segments: {new_labels[:10]}")
print(f"Assignment confidence: {confidence[:10].round(2)}")
```
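Bundling the scaler and K-means into a single `Pipeline` makes the "fit on train, transform on new data" rule structural. The serialized object can only transform new data with the training-time statistics. This is a minimal sketch that assumes `joblib`, which scikit-learn depends on. The file name `segmentation_model.joblib` is hypothetical and is written to a temporary directory.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.cluster import KMeans
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_train = rng.normal(size=(2_000, 20))
X_new = rng.normal(size=(50, 20))  # stand-in for today's new users

# Scaler and K-means are fitted together, on training data only
model = make_pipeline(StandardScaler(),
                      KMeans(n_clusters=6, n_init=10, random_state=42))
model.fit(X_train)

# Persist the whole pipeline (hypothetical file name in a temp directory)
path = os.path.join(tempfile.mkdtemp(), "segmentation_model.joblib")
joblib.dump(model, path)

# Later, in the serving job: load and assign segments.
# Pipeline.predict applies scaler.transform (never fit) and then KMeans.predict.
restored = joblib.load(path)
new_labels = restored.predict(X_new)
print(new_labels[:10])
```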
:::danger Never Fit on Test Data
The most common K-means production mistake is calling fit_transform() on the test set or new data. This recomputes the mean and std using the test data, which changes the feature scaling and makes the segment labels incomparable with the training-time segments. Always fit on training data and transform on new data. This applies to the scaler, PCA, and any preprocessing step chained before K-means.
:::
Algorithm Comparison
| Property | K-Means | DBSCAN | GMM | HDBSCAN |
|---|---|---|---|---|
| Specify K | Required | Not required | Required | Not required |
| Cluster shape | Spherical (Voronoi) | Arbitrary | Elliptical | Arbitrary |
| Outlier handling | No - all points assigned | Yes - labeled noise | Soft (low probability) | Yes - soft scores |
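| Time complexity | $O(nKdI)$ | $O(n \log n)$ w/ spatial index ($O(n^2)$ worst case) | $O(nKd^2 I)$ (full covariance) | $O(n \log n)$ typical ($O(n^2)$ worst case) |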
| Soft assignments | No | No | Yes | Yes (probabilities) |
| Scalability | Excellent (Mini-batch) | Moderate | Moderate | Good |
| Interpretability | High (centroids) | Medium | Medium | Medium |
| Varying density | Fails | Fails | Soft handling | Handles well |
YouTube Resources
| Title | Channel | Why Watch |
|---|---|---|
| StatQuest: K-means Clustering | StatQuest with Josh Starmer | Best visual intuition for the algorithm and convergence; animated Voronoi boundaries |
| K-means and the EM Algorithm | Serrano.Academy | Explains the K-means / GMM-EM connection clearly with great visuals |
| K-Means Clustering from Scratch | AssemblyAI | Live coding from scratch in Python - great preparation for coding interviews |
| Clustering with Scikit-Learn | Python Programmer | Production sklearn patterns, pipelines, and evaluation metrics |
| K-Means++ Explained | ritvikmath | Mathematical walkthrough of D² sampling and the O(log K) approximation guarantee |
Common Mistakes
:::danger Never Fit on Test Data
Calling scaler.fit_transform(X_test) instead of scaler.transform(X_test) recomputes normalization statistics on the test set, making segments incomparable across sets. Always fit preprocessing on training data only, then transform all other data.
:::
:::danger Forgetting to Standardize
Running K-means on raw features with different scales will produce clusters dominated by the high-scale features, because squared Euclidean distance adds up raw differences across features. A user's income (in thousands of dollars) will swamp their number of purchases (single digits). Standardize with StandardScaler before clustering - this is not optional.
:::
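The scale effect is easy to reproduce. In this sketch with synthetic data and hypothetical features, two user groups differ only in purchase count. Income is pure noise on a much larger scale. The helper `agreement()` is a hypothetical metric that accounts for K-means' arbitrary label order.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n = 400
# Two groups that differ ONLY in number of purchases (single digits)
purchases = np.concatenate([rng.normal(2, 0.5, n), rng.normal(8, 0.5, n)])
# Income in dollars: same distribution for both groups, but a huge scale
income = rng.normal(60_000, 15_000, 2 * n)
X = np.column_stack([income, purchases])
truth = np.repeat([0, 1], n)


def agreement(labels, truth):
    """Fraction of points matching the true groups (cluster IDs are arbitrary)."""
    acc = np.mean(labels == truth)
    return max(acc, 1 - acc)


raw = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
X_std = StandardScaler().fit_transform(X)
scaled = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X_std)

print(f"raw features:          {agreement(raw, truth):.2f}")     # split driven by income
print(f"standardized features: {agreement(scaled, truth):.2f}")  # recovers the groups
```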
:::warning Trusting the Elbow Method Alone
The elbow method is ambiguous for most real datasets: the inertia curve often bends gradually, with no single clear elbow. Use it as a starting point, not a conclusion. Always cross-check with the silhouette score and domain knowledge about how many segments make business sense.
:::
:::warning Using K-Means on High-Dimensional Data
In high dimensions (roughly d > 50, as a rule of thumb), Euclidean distances lose discriminating power. The gap between a point's nearest and farthest neighbors shrinks relative to the distances themselves, so all pairwise distances become approximately equal (the curse of dimensionality). Apply PCA or UMAP first to reduce dimensionality to 10–30 dimensions before clustering. PCA's explained-variance ratios also rank the components, and the component loadings show which original features drive each one.
:::
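You can measure distance concentration directly with the relative contrast, (max distance - min distance) / min distance, from a query point. The sketch assumes i.i.d. uniform features, which is close to a worst case. Real data often has a much lower intrinsic dimension and concentrates less.

```python
import numpy as np


def relative_contrast(d, n_points=500, seed=0):
    """(max - min) / min of distances from a random query to uniform points in [0, 1]^d."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(size=(n_points, d))
    query = rng.uniform(size=d)
    dist = np.linalg.norm(X - query, axis=1)
    return (dist.max() - dist.min()) / dist.min()


for d in [2, 10, 100, 1000]:
    print(f"d={d:5d}: relative contrast = {relative_contrast(d):.3f}")
```

As $d$ grows, the contrast shrinks toward zero: the nearest and farthest points become nearly the same distance away, which leaves K-means' nearest-centroid rule little signal to work with.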
Interview Q&A
Q1: What does K-means minimize and why does it always converge?
K-means minimizes the within-cluster sum of squared distances (WCSS / inertia): $J = \sum_{k=1}^{K} \sum_{x_i \in C_k} \|x_i - \mu_k\|^2$. Convergence is guaranteed because neither step can increase $J$. The assignment step produces the optimal assignment given the current centroids. Sending each point to its nearest centroid minimizes every point's term, so no other assignment has lower WCSS. The update step sets each centroid to the mean of its points, which is the unique minimizer of the sum of squared distances within a cluster. Since $J$ is non-increasing and bounded below by zero, the sequence of objective values converges. Moreover, the number of possible assignments is finite (at most $K^n$). With a consistent tie-breaking rule, $J$ strictly decreases whenever the assignment changes, so no assignment can repeat and the algorithm terminates in finitely many iterations. Note: it converges to a local minimum, not necessarily the global one.
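The monotone-decrease argument can be checked numerically. This is a from-scratch sketch of Lloyd's algorithm using NumPy only, with random initialization. It keeps a `history` list (an illustrative addition) that records $J$ after every assignment step and every update step. The sequence should never increase:

```python
import numpy as np


def wcss(X, centroids, labels):
    """Within-cluster sum of squared distances J."""
    return float(((X - centroids[labels]) ** 2).sum())


def lloyd(X, K, max_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    # Random initialization: K distinct data points
    centroids = X[rng.choice(len(X), size=K, replace=False)].copy()
    labels = None
    history = []
    for _ in range(max_iter):
        # Assignment step: each point goes to its nearest centroid
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = d2.argmin(axis=1)
        history.append(wcss(X, centroids, new_labels))
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments unchanged: converged
        labels = new_labels
        # Update step: each centroid moves to the mean of its points
        for k in range(K):
            members = X[labels == k]
            if len(members) > 0:  # an empty cluster keeps its old centroid
                centroids[k] = members.mean(axis=0)
        history.append(wcss(X, centroids, labels))
    return centroids, labels, history


rng = np.random.default_rng(1)
X = np.vstack([rng.normal(c, 0.5, size=(100, 2)) for c in [(0, 0), (5, 5), (0, 5)]])
centroids, labels, history = lloyd(X, K=3)
print([round(j, 1) for j in history])
```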
Q2: Why is K-means++ better than random initialization, and what is the theoretical guarantee?
Random initialization can place multiple centroids in the same dense region, leaving other modes uncovered and leading to poor local optima. K-means++ seeds centroids using D² sampling: each new centroid is chosen with probability proportional to the squared distance from the nearest existing centroid. This deliberately spreads centroids across the data, making coverage of distinct modes much more likely. The theoretical guarantee (Arthur & Vassilvitskii 2007) is that K-means++ achieves an expected WCSS within an $O(\log K)$ factor of the global optimum before any Lloyd's iterations. Specifically, $\mathbb{E}[\mathrm{WCSS}] \le 8(\ln K + 2)\,\mathrm{WCSS}_{\mathrm{OPT}}$. Random initialization has no such guarantee and can be arbitrarily bad. In practice, K-means++ reduces the number of iterations needed and the frequency of poor solutions.
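D² sampling itself takes only a few lines. This NumPy sketch implements the plain seeding rule described above: one uniformly random first center, then each new center drawn with probability proportional to its squared distance from the nearest chosen center. scikit-learn's default `init='k-means++'` uses a greedy variant that evaluates several candidates per step, so its picks will not match this sketch draw for draw.

```python
import numpy as np


def kmeans_pp_init(X, K, seed=0):
    """Plain K-means++ (D^2) seeding; returns a (K, d) array of initial centers."""
    rng = np.random.default_rng(seed)
    n = len(X)
    centers = [X[rng.integers(n)]]  # first center: uniform at random
    # Squared distance from every point to its nearest chosen center
    d2 = ((X - centers[0]) ** 2).sum(axis=1)
    for _ in range(1, K):
        # D^2 sampling (assumes not every point coincides with a chosen center)
        probs = d2 / d2.sum()
        idx = rng.choice(n, p=probs)
        centers.append(X[idx])
        d2 = np.minimum(d2, ((X - X[idx]) ** 2).sum(axis=1))
    return np.array(centers)


rng = np.random.default_rng(1)
X = np.vstack([rng.normal(c, 0.3, size=(200, 2))
               for c in [(0, 0), (6, 0), (0, 6), (6, 6)]])
init = kmeans_pp_init(X, K=4)
print(init.round(2))
```

An already-chosen point has zero probability of being drawn again, so the seeds are always distinct. Passing the array as `KMeans(n_clusters=4, init=init, n_init=1)` runs Lloyd's iterations from these seeds.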
Q3: How do you choose K in production when you cannot look at plots?
Use the silhouette score over a range of candidate K values and pick the maximum programmatically. Loop K from 2 to a reasonable upper bound, compute the silhouette score for each, and select the highest. But in practice, the right K is often constrained by the downstream use case: a marketing team that can run 6 distinct campaigns should use K=6 regardless of the silhouette score. Frame K selection as an optimization problem for the downstream metric. Test K=4, 6, 8 in an A/B experiment and pick the K that maximizes revenue or engagement, not just a clustering metric. Document the choice: what K was selected, why, and what metric it optimized.
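A programmatic version of this selection loop might look like the sketch below. It assumes the business constraint supplies the upper bound (`k_max`). It uses `silhouette_score`'s `sample_size` argument so the $O(n^2)$ silhouette computation stays tractable on large data. The function name `choose_k` is hypothetical.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score


def choose_k(X, k_min=2, k_max=10, sample_size=2_000, seed=0):
    """Return (best K by sampled silhouette score, dict mapping K -> score)."""
    scores = {}
    for k in range(k_min, k_max + 1):
        labels = KMeans(n_clusters=k, n_init=5, random_state=seed).fit_predict(X)
        scores[k] = silhouette_score(
            X, labels, sample_size=min(sample_size, len(X)), random_state=seed
        )
    best_k = max(scores, key=scores.get)
    return best_k, scores


# Hypothetical cap: the marketing team can run at most 8 campaigns
X, _ = make_blobs(n_samples=5_000, centers=4, cluster_std=0.8, random_state=7)
best_k, scores = choose_k(X, k_max=8)
print(best_k, {k: round(s, 3) for k, s in scores.items()})
```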
Q4: What is the time complexity of K-means, and when should you use Mini-batch K-means?
Each iteration of standard K-means costs $O(nKd)$ for the assignment step, where $n$ is the number of samples, $K$ the number of clusters, and $d$ the number of dimensions. With $I$ iterations this becomes $O(nKdI)$. Mini-batch K-means processes a random batch of size $b$ each iteration, reducing the per-iteration cost to $O(bKd)$. Use Mini-batch K-means when $n$ is large enough that full passes over the data dominate run time, or when you need online/streaming updates via partial_fit(). The quality trade-off is small - typically less than 1–2% higher inertia. At $n \geq 10^7$, Mini-batch K-means is usually the only practical in-memory option; at larger scales, use distributed implementations like Spark MLlib's K-means.
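The streaming case maps onto `MiniBatchKMeans.partial_fit`, which updates the centroids one chunk at a time so the full dataset never has to sit in memory. In this sketch, `stream_batches` is a hypothetical stand-in for reading chunks from a warehouse or message queue. The chunk size of 1,024 is an illustrative choice.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
true_centers = rng.normal(0, 10, size=(5, 8))  # 5 hidden segments in 8-D


def stream_batches(n_batches=50, batch_size=1024):
    """Hypothetical data stream: each chunk is fresh points around the same centers."""
    for _ in range(n_batches):
        idx = rng.integers(0, len(true_centers), size=batch_size)
        yield true_centers[idx] + rng.normal(size=(batch_size, 8))


mbk = MiniBatchKMeans(n_clusters=5, n_init=3, random_state=0)
for chunk in stream_batches():
    mbk.partial_fit(chunk)  # one incremental centroid update per chunk

# New users are assigned with predict(), exactly as with batch K-means
X_new = true_centers[[0, 1, 2]] + rng.normal(size=(3, 8))
print(mbk.predict(X_new))
```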
Q5: K-means fails on the two moons dataset. Why, and what would you use instead?
K-means assigns points to their nearest centroid, partitioning space into Voronoi cells - convex regions separated by linear (hyperplane) boundaries. The two moons form non-convex, crescent-shaped clusters that cannot be separated by any linear boundary. No matter where you place two centroids, the Voronoi boundary will cut through one of the crescents rather than following the curved shape. DBSCAN identifies clusters by density connectivity, making it the natural choice - it can follow the crescent shapes because density is locally consistent along each crescent. Spectral clustering is another option: it constructs a graph Laplacian capturing local connectivity, then applies K-means in the Laplacian eigenspace where the moons become linearly separable.
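The failure shows up clearly in numbers if you compare each method's labels to the true moon membership with the adjusted Rand index (ARI, where 1.0 is a perfect match). The `eps=0.2` and `min_samples=5` values below are illustrative settings for this noise level, not general defaults:

```python
from sklearn.cluster import DBSCAN, KMeans
from sklearn.datasets import make_moons
from sklearn.metrics import adjusted_rand_score

X, y = make_moons(n_samples=500, noise=0.05, random_state=0)

km_labels = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X)
db_labels = DBSCAN(eps=0.2, min_samples=5).fit_predict(X)  # -1 marks noise points

kmeans_ari = adjusted_rand_score(y, km_labels)
dbscan_ari = adjusted_rand_score(y, db_labels)
print(f"K-means ARI: {kmeans_ari:.2f}   DBSCAN ARI: {dbscan_ari:.2f}")
```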
Q6: What is the relationship between K-means and the EM algorithm?
K-means is the hard-assignment limit of Expectation-Maximization (EM) for Gaussian Mixture Models. In GMM-EM, the E-step computes soft membership probabilities (each point has a fractional assignment to every cluster), and the M-step updates mixture weights, means, and full covariance matrices. K-means emerges when the GMM covariance matrices are constrained to be isotropic and equal ($\Sigma_k = \sigma^2 I$ for all $k$), and the soft assignments are replaced by hard 0/1 indicators, so each point goes 100% to its nearest centroid. Equivalently, this is the limit $\sigma^2 \to 0$, where the responsibilities collapse onto the nearest mean. The K-means WCSS objective then equals the negative log-likelihood of this simplified GMM, up to the positive factor $1/(2\sigma^2)$ and an additive constant. This explains why K-means implicitly assumes spherical, equal-variance clusters: those are the assumptions baked into the degenerate GMM limit.
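The hard-assignment limit can also be seen numerically. This sketch assumes the K-means restrictions: equal mixture weights and a shared covariance $\sigma^2 I$. It fixes three hypothetical means. As $\sigma^2$ shrinks, the EM responsibilities approach one-hot nearest-mean indicators:

```python
import numpy as np


def responsibilities(X, means, sigma2):
    """E-step for a GMM with equal weights and shared covariance sigma2 * I."""
    d2 = ((X[:, None, :] - means[None, :, :]) ** 2).sum(axis=2)
    logits = -d2 / (2 * sigma2)  # log N(x | mu_k, sigma2 I) up to a shared constant
    logits -= logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)


rng = np.random.default_rng(0)
X = rng.normal(scale=3.0, size=(200, 2))
means = np.array([[-2.0, 0.0], [2.0, 0.0], [0.0, 3.0]])
nearest = ((X[:, None, :] - means[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)

gaps = []
for sigma2 in [10.0, 1.0, 0.1, 0.01]:
    R = responsibilities(X, means, sigma2)
    # Mean weight NOT placed on the nearest mean (0 means fully hard assignments)
    gaps.append(float(np.mean(1 - R[np.arange(len(X)), nearest])))
    print(f"sigma^2={sigma2:5}: mean soft weight off the nearest mean = {gaps[-1]:.4f}")
```

With equal weights and a shared covariance, the most responsible component is always the nearest mean at every $\sigma^2$. Shrinking $\sigma^2$ only removes the remaining soft weight.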
Q7: How would you handle the case where K-means gives you an empty cluster?
An empty cluster occurs when no points are assigned to a centroid during the assignment step. This usually happens because the initial centroid was placed in a region far from all data, and better-placed centroids "stole" all its territory in later iterations. The standard fix is to reinitialize the empty centroid to the data point with the largest distance to its assigned centroid (the point that would benefit most from a new centroid), or to a uniformly random data point. Sklearn handles this automatically. In a from-scratch implementation, always check if mask.sum() == 0 after assignment and reinitialize affected centroids. In production, log when this happens: frequent empty clusters across multiple restarts are a signal that K is too large for the actual structure in the data.
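One way to implement the farthest-point fix in a from-scratch loop is sketched below. `reseed_empty_clusters` is a hypothetical helper. It moves each empty centroid onto a distinct point that is currently far from its own centroid, and leaves it to the next assignment step to pick that point up. A production version would also guard against emptying a singleton cluster in the process.

```python
import numpy as np


def reseed_empty_clusters(X, centroids, labels):
    """Move every empty centroid onto a distinct far-away data point (hypothetical helper)."""
    K = len(centroids)
    empty = np.flatnonzero(np.bincount(labels, minlength=K) == 0)
    if empty.size == 0:
        return centroids
    # Squared distance of each point to the centroid it is currently assigned to
    dist = ((X - centroids[labels]) ** 2).sum(axis=1)
    farthest = np.argsort(dist)[::-1][: empty.size]  # one distinct point per empty cluster
    centroids = centroids.copy()
    centroids[empty] = X[farthest]
    return centroids


def assign(X, centroids):
    """Assignment step: index of the nearest centroid for every point."""
    return ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2).argmin(axis=1)


rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, size=(100, 2)), rng.normal(6, 1, size=(100, 2))])
# Third centroid placed far from all data, so it captures no points
centroids = np.array([[0.0, 0.0], [6.0, 6.0], [100.0, 100.0]])
labels = assign(X, centroids)
print("before:", np.bincount(labels, minlength=3))

centroids = reseed_empty_clusters(X, centroids, labels)
labels = assign(X, centroids)
print("after: ", np.bincount(labels, minlength=3))
```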
Q8: How does the silhouette score work, and what does a score of 0.3 mean?
The silhouette score for each point $i$ is $s(i) = \dfrac{b(i) - a(i)}{\max\{a(i),\, b(i)\}}$. Here $a(i)$ is the mean distance from $i$ to the other points in its own cluster (cohesion), and $b(i)$ is the mean distance from $i$ to the points in the nearest other cluster (separation). The overall score is the mean of $s(i)$ over all points and ranges from -1 to +1. A score of 0.3 means the average point is moderately closer to its own cluster than to the next-nearest one, but the separation is not sharp: the clusters overlap somewhat. Commonly used interpretation thresholds:
- 0.7–1.0: strong, well-separated clusters
- 0.5–0.7: reasonable structure
- 0.25–0.5: weak structure that may be an artifact of the clustering
- below 0.25: no meaningful cluster structure (you may be forcing structure onto random data)

A silhouette score of 0.3 is typical for real-world behavioral data where user segments are genuinely fuzzy and overlapping.
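Computing the definition by hand is a good interview exercise and a useful check. This sketch builds $a(i)$ and $b(i)$ from a pairwise distance matrix and compares the result with scikit-learn's `silhouette_samples`. `silhouette_by_hand` is a hypothetical helper name. It uses the true blob labels so that no cluster is a singleton.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import pairwise_distances, silhouette_samples


def silhouette_by_hand(X, labels):
    """Per-point s(i) = (b - a) / max(a, b), following the definition above."""
    D = pairwise_distances(X)  # Euclidean distance matrix
    clusters = np.unique(labels)
    s = np.zeros(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        n_same = same.sum()
        if n_same == 1:
            continue  # singleton clusters get s(i) = 0 by convention
        a = D[i, same].sum() / (n_same - 1)  # D[i, i] = 0, so divide by n - 1 to exclude self
        b = min(D[i, labels == c].mean() for c in clusters if c != labels[i])
        s[i] = (b - a) / max(a, b)
    return s


X, y = make_blobs(n_samples=90, centers=3, cluster_std=1.5, random_state=0)
s = silhouette_by_hand(X, y)
print(np.allclose(s, silhouette_samples(X, y)))  # True
print(f"mean silhouette = {s.mean():.3f}")
```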
Q9: Describe a production failure mode you would watch for with K-means in a recommendation system.
The most dangerous production failure is segment drift over time: customer behavior evolves seasonally or due to product changes, but the K-means model was trained months ago. Centroids represent old behavior patterns that no longer match reality. New users get assigned to segments based on outdated profiles, leading to irrelevant recommendations. Solution: monitor segment distribution over time - if the fraction of users in each segment shifts dramatically (e.g., one segment doubles in size), it is a signal that behavior has drifted beyond what the model captures. Set up automated retraining pipelines: retrain weekly with a sliding window of recent behavioral data, use Mini-batch partial_fit() for daily updates, and alert when the average silhouette score drops below a threshold (indicating that the existing centroids no longer fit the data well).
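One concrete way to quantify "the fraction of users in each segment shifts dramatically" is the population stability index (PSI). PSI is a drift statistic borrowed from credit-risk model monitoring; it is a swapped-in technique, not something K-means itself provides. The sketch compares segment shares from one fixed model at training time against a later window. The segment proportions are made up for illustration. Thresholds such as PSI > 0.25 for a major shift are common rules of thumb rather than universal standards.

```python
import numpy as np


def segment_shares(labels, K):
    """Fraction of users assigned to each of K segments."""
    return np.bincount(labels, minlength=K) / len(labels)


def population_stability_index(expected, actual, eps=1e-6):
    """PSI = sum_k (a_k - e_k) * ln(a_k / e_k); eps guards against empty segments."""
    e = np.clip(expected, eps, None)
    a = np.clip(actual, eps, None)
    return float(np.sum((a - e) * np.log(a / e)))


K = 4
rng = np.random.default_rng(0)
# Hypothetical segment assignments from the SAME fitted model at different times
baseline = rng.choice(K, size=50_000, p=[0.4, 0.3, 0.2, 0.1])  # training window
stable = rng.choice(K, size=50_000, p=[0.4, 0.3, 0.2, 0.1])    # later week, no drift
drifted = rng.choice(K, size=50_000, p=[0.2, 0.3, 0.2, 0.3])   # segment 3 tripled

base_shares = segment_shares(baseline, K)
psi_stable = population_stability_index(base_shares, segment_shares(stable, K))
psi_drifted = population_stability_index(base_shares, segment_shares(drifted, K))
print(f"stable PSI = {psi_stable:.4f}, drifted PSI = {psi_drifted:.4f}")
```

Computing shares from a single fixed model matters here. After retraining, segment IDs can be permuted, so shares from two different models are not directly comparable.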
:::tip 🎮 Interactive Playground
Visualize this concept: Try the K-Means Clustering demo on the EngineersOfAI Playground - no code required.
:::
