Clustering Systems
Design production clustering systems that group similar items using hash-based and distance-based approaches for recommendations, search, and analytics.
Problem Statement
Design a Clustering System that groups millions of data points (users, documents, products, etc.) into meaningful clusters based on similarity, supporting real-time queries and incremental updates.
Functional Requirements
- Clustering algorithms: Support K-means, DBSCAN, hierarchical clustering
- Similarity metrics: Euclidean, cosine, Jaccard, custom distances
- Real-time assignment: Assign new points to clusters in <100ms
- Incremental updates: Add new data without full recomputation
- Cluster quality: Evaluate cluster cohesion and separation
- Scalability: Handle millions to billions of data points
- Query interface: Find nearest clusters, similar items, cluster statistics
- Visualization: Support for cluster visualization and exploration
Non-Functional Requirements
- Latency: p95 cluster assignment < 100ms
- Throughput: 10,000+ assignments/second
- Scalability: Support 100M+ data points
- Accuracy: High cluster quality (silhouette score > 0.5)
- Availability: 99.9% uptime
- Cost efficiency: Optimize compute and storage
- Freshness: Support near-real-time clustering updates
Understanding the Requirements
Clustering is everywhere in production ML:
Common Use Cases
| Company | Use Case | Clustering Method | Scale |
|---|---|---|---|
| Netflix | User segmentation | K-means on viewing patterns | 200M+ users |
| Spotify | Music recommendation | DBSCAN on audio features | 80M+ songs |
| Google | News clustering | Hierarchical on doc embeddings | Billions of articles |
| Amazon | Product categorization | K-means on product attributes | 300M+ products |
| Uber | Demand forecasting | Geospatial clustering | Real-time zones |
| Airbnb | Listing similarity | Locality-sensitive hashing | 7M+ listings |
Why Clustering Matters
- Data exploration: Understand data structure and patterns
- Dimensionality reduction: Group high-dimensional data
- Anomaly detection: Find outliers far from clusters
- Recommendation: “Users like you also liked…”
- Segmentation: Targeted marketing, personalization
- Data compression: Represent data by cluster centroids
The Hash-Based Grouping Connection
Just like the Group Anagrams problem:
| Group Anagrams | Clustering Systems | Speaker Diarization |
|---|---|---|
| Group strings by sorted chars | Group points by similarity | Group audio by speaker |
| Hash key: sorted string | Hash key: quantized vector | Hash key: voice embedding |
| O(1) lookup | LSH for fast similarity | Vector similarity |
| Exact matching | Approximate matching | Threshold-based matching |
All three use hash-based or similarity-based grouping to organize items efficiently.
High-Level Architecture
┌─────────────────────────────────────────────────────────────────┐
│ Clustering System │
└─────────────────────────────────────────────────────────────────┘
┌──────────────┐
│ Data Input │
│ (Features) │
└──────┬───────┘
│
┌──────────────────┼──────────────────┐
│ │ │
┌───────▼────────┐ ┌──────▼──────┐ ┌────────▼────────┐
│ Batch │ │ Streaming │ │ Real-time │
│ Clustering │ │ Updates │ │ Assignment │
│ │ │ │ │ │
│ - K-means │ │ - Mini-batch│ │ - Nearest │
│ - DBSCAN │ │ - Online │ │ cluster │
│ - Hierarchical │ │ updates │ │ - LSH lookup │
└───────┬────────┘ └──────┬──────┘ └────────┬────────┘
│ │ │
└──────────────────┼──────────────────┘
│
┌──────▼──────┐
│ Cluster │
│ Storage │
│ │
│ - Centroids │
│ - Metadata │
│ - Assignments│
└──────┬──────┘
│
┌──────▼──────┐
│ Query API │
│ │
│ - Find │
│ cluster │
│ - Find │
│ similar │
│ - Stats │
└─────────────┘
Key Components
- Clustering Engine: Core algorithms (K-means, DBSCAN, etc.)
- Feature Store: Pre-computed embeddings and features
- Index: Fast similarity search (Faiss, Annoy)
- Cluster Store: Centroids, assignments, metadata
- Update Service: Incremental clustering updates
- Query API: Real-time cluster assignment and search
Component Deep-Dives
1. Clustering Engine - K-Means Implementation
K-means is the most widely used clustering algorithm:
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
import logging
@dataclass
class ClusterMetrics:
"""Metrics for cluster quality."""
inertia: float # Sum of squared distances to centroids
silhouette_score: float # Cluster separation quality
n_iterations: int
converged: bool
class KMeansClustering:
"""
Production K-means clustering.
Similar to Group Anagrams:
- Anagrams: Group by exact match (sorted string)
- K-means: Group by approximate match (nearest centroid)
Both use hash-like keys for grouping:
- Anagrams: hash = sorted(string)
- K-means: hash = nearest_centroid_id
"""
def __init__(
self,
n_clusters: int = 8,
max_iters: int = 300,
tol: float = 1e-4,
init_method: str = "kmeans++",
random_state: Optional[int] = None
):
"""
Initialize K-means clusterer.
Args:
n_clusters: Number of clusters (k)
max_iters: Maximum iterations
tol: Convergence tolerance
init_method: "random" or "kmeans++"
random_state: Random seed
"""
self.n_clusters = n_clusters
self.max_iters = max_iters
self.tol = tol
self.init_method = init_method
self.random_state = random_state
self.centroids: Optional[np.ndarray] = None
self.labels: Optional[np.ndarray] = None
self.inertia: float = 0.0
self.logger = logging.getLogger(__name__)
if random_state is not None:
np.random.seed(random_state)
def fit(self, X: np.ndarray) -> 'KMeansClustering':
"""
Fit K-means to data.
Algorithm:
1. Initialize k centroids
2. Assign points to nearest centroid (like hash lookup)
3. Update centroids to mean of assigned points
4. Repeat until convergence
Args:
X: Data matrix of shape (n_samples, n_features)
Returns:
self
"""
n_samples, n_features = X.shape
if n_samples < self.n_clusters:
raise ValueError(
f"n_samples ({n_samples}) must be >= n_clusters ({self.n_clusters})"
)
# Initialize centroids
self.centroids = self._initialize_centroids(X)
# Iterative assignment and update
for iteration in range(self.max_iters):
# Assignment step: assign each point to nearest centroid
# (Like grouping strings by sorted key)
old_labels = self.labels
self.labels = self._assign_clusters(X)
# Update step: recompute centroids
old_centroids = self.centroids.copy()
self._update_centroids(X)
# Check convergence
centroid_shift = np.linalg.norm(self.centroids - old_centroids)
if centroid_shift < self.tol:
self.logger.info(f"Converged after {iteration + 1} iterations")
break
# Calculate final inertia
self.inertia = self._calculate_inertia(X)
return self
def _initialize_centroids(self, X: np.ndarray) -> np.ndarray:
"""
Initialize centroids.
K-means++ initialization:
- Choose first centroid randomly
- Choose subsequent centroids with probability proportional to distance²
- Spreads out initial centroids for better convergence
"""
n_samples = X.shape[0]
if self.init_method == "random":
# Random initialization
indices = np.random.choice(n_samples, self.n_clusters, replace=False)
return X[indices].copy()
elif self.init_method == "kmeans++":
# K-means++ initialization
centroids = []
# Choose first centroid randomly
first_idx = np.random.randint(n_samples)
centroids.append(X[first_idx])
# Choose remaining centroids
for _ in range(1, self.n_clusters):
# Calculate distances to nearest existing centroid
distances = np.min([
np.linalg.norm(X - c, axis=1) ** 2
for c in centroids
], axis=0)
# Choose next centroid with probability ∝ distance²
probabilities = distances / distances.sum()
next_idx = np.random.choice(n_samples, p=probabilities)
centroids.append(X[next_idx])
return np.array(centroids)
else:
raise ValueError(f"Unknown init_method: {self.init_method}")
def _assign_clusters(self, X: np.ndarray) -> np.ndarray:
"""
Assign each point to nearest centroid.
This is the "grouping" step (like anagram grouping).
Returns:
Array of cluster labels
"""
# Calculate distances to all centroids
# Shape: (n_samples, n_clusters)
distances = np.linalg.norm(
X[:, np.newaxis] - self.centroids,
axis=2
)
# Assign to nearest centroid
labels = np.argmin(distances, axis=1)
return labels
def _update_centroids(self, X: np.ndarray):
"""
Update centroids to mean of assigned points.
If a cluster is empty, reinitialize that centroid.
"""
for k in range(self.n_clusters):
# Get points assigned to cluster k
mask = self.labels == k
if mask.sum() > 0:
# Update to mean of assigned points
self.centroids[k] = X[mask].mean(axis=0)
else:
# Empty cluster - reinitialize randomly
self.logger.warning(f"Empty cluster {k}, reinitializing")
random_idx = np.random.randint(len(X))
self.centroids[k] = X[random_idx]
def _calculate_inertia(self, X: np.ndarray) -> float:
"""
Calculate inertia (within-cluster sum of squares).
Lower inertia = tighter clusters.
"""
inertia = 0.0
for k in range(self.n_clusters):
mask = self.labels == k
if mask.sum() > 0:
cluster_points = X[mask]
centroid = self.centroids[k]
# Sum of squared distances
inertia += np.sum((cluster_points - centroid) ** 2)
return inertia
def predict(self, X: np.ndarray) -> np.ndarray:
"""
Predict cluster labels for new data.
This is like finding anagrams of a new string:
- Hash the string (sort it)
- Look up in hash table
For K-means:
- Calculate distances to centroids
- Assign to nearest
Args:
X: Data matrix of shape (n_samples, n_features)
Returns:
Cluster labels
"""
if self.centroids is None:
raise ValueError("Model not fitted. Call fit() first.")
distances = np.linalg.norm(
X[:, np.newaxis] - self.centroids,
axis=2
)
return np.argmin(distances, axis=1)
def get_cluster_centers(self) -> np.ndarray:
"""Get cluster centroids."""
return self.centroids.copy()
def get_cluster_sizes(self) -> np.ndarray:
"""Get number of points in each cluster."""
return np.bincount(self.labels, minlength=self.n_clusters)
def calculate_silhouette_score(self, X: np.ndarray) -> float:
"""
Calculate silhouette score for cluster quality.
Score ranges from -1 to 1:
- 1: Perfect clustering
- 0: Overlapping clusters
- -1: Wrong clustering
"""
from sklearn.metrics import silhouette_score
if len(np.unique(self.labels)) < 2:
return 0.0
return silhouette_score(X, self.labels)
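To see the pieces working together, here is a minimal usage sketch. It assumes scikit-learn is available (for the synthetic make_blobs data and the silhouette score used inside the class); the parameter values are illustrative, not tuned.

# Minimal usage sketch for KMeansClustering on synthetic data.
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=5000, n_features=16, centers=8, random_state=0)

kmeans = KMeansClustering(n_clusters=8, init_method="kmeans++", random_state=0)
kmeans.fit(X)

print("Inertia:", kmeans.inertia)
print("Cluster sizes:", kmeans.get_cluster_sizes())
print("Silhouette:", kmeans.calculate_silhouette_score(X))

# Assigning new points is just a nearest-centroid lookup
new_labels = kmeans.predict(X[:10])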
2. DBSCAN - Density-Based Clustering
DBSCAN doesn’t require specifying k and finds arbitrary-shaped clusters:
import numpy as np
from typing import List, Optional

from sklearn.neighbors import NearestNeighbors
class DBSCANClustering:
"""
Density-Based Spatial Clustering (DBSCAN).
Advantages over K-means:
- No need to specify k
- Finds arbitrary-shaped clusters
- Handles noise/outliers
Good for:
- Geospatial data
- Data with varying density
- Anomaly detection
"""
def __init__(self, eps: float = 0.5, min_samples: int = 5):
"""
Initialize DBSCAN.
Args:
eps: Maximum distance for neighborhood
min_samples: Minimum points for core point
"""
self.eps = eps
self.min_samples = min_samples
self.labels: Optional[np.ndarray] = None
self.core_sample_indices: Optional[np.ndarray] = None
def fit(self, X: np.ndarray) -> 'DBSCANClustering':
"""
Fit DBSCAN to data.
Algorithm:
1. Find core points (points with >= min_samples neighbors within eps)
2. Form clusters by connecting core points
3. Assign border points to nearest cluster
4. Mark noise points as outliers (-1)
"""
n_samples = X.shape[0]
# Find neighbors for all points
nbrs = NearestNeighbors(radius=self.eps).fit(X)
neighborhoods = nbrs.radius_neighbors(X, return_distance=False)
# Initialize labels (-1 = unvisited)
labels = np.full(n_samples, -1, dtype=int)
# Find core points
core_samples = np.array([
len(neighbors) >= self.min_samples
for neighbors in neighborhoods
])
self.core_sample_indices = np.where(core_samples)[0]
# Assign clusters
cluster_id = 0
for idx in range(n_samples):
# Skip if already labeled or not a core point
if labels[idx] != -1 or not core_samples[idx]:
continue
# Start new cluster
self._expand_cluster(idx, neighborhoods, labels, cluster_id, core_samples)
cluster_id += 1
self.labels = labels
return self
def _expand_cluster(
self,
seed_idx: int,
neighborhoods: List[np.ndarray],
labels: np.ndarray,
cluster_id: int,
core_samples: np.ndarray
):
"""
Expand cluster from seed point using BFS.
Similar to connected component search in graphs.
"""
# Queue of points to process
queue = [seed_idx]
labels[seed_idx] = cluster_id
while queue:
current = queue.pop(0)
# Add neighbors to queue if core point
if core_samples[current]:
for neighbor in neighborhoods[current]:
if labels[neighbor] == -1:
labels[neighbor] = cluster_id
queue.append(neighbor)
def predict(self, X: np.ndarray, X_train: np.ndarray) -> np.ndarray:
"""
Predict cluster for new points.
Assign to nearest core point's cluster.
"""
if self.labels is None:
raise ValueError("Model not fitted")
# Find nearest core point for each new point
nbrs = NearestNeighbors(n_neighbors=1).fit(
X_train[self.core_sample_indices]
)
distances, indices = nbrs.kneighbors(X)
# Assign to nearest core point's cluster if within eps
labels = np.full(len(X), -1, dtype=int)
for i, (dist, idx) in enumerate(zip(distances, indices)):
if dist[0] <= self.eps:
core_idx = self.core_sample_indices[idx[0]]
labels[i] = self.labels[core_idx]
return labels
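A small usage sketch follows, using scikit-learn's make_moons purely as illustrative non-convex data; the eps and min_samples values are assumptions for this toy dataset, not recommendations.

# Minimal usage sketch for DBSCANClustering on non-convex data.
from sklearn.datasets import make_moons

X_train, _ = make_moons(n_samples=2000, noise=0.05, random_state=0)

dbscan = DBSCANClustering(eps=0.1, min_samples=5)
dbscan.fit(X_train)

n_clusters = len(set(dbscan.labels) - {-1})
n_noise = int((dbscan.labels == -1).sum())
print(f"Found {n_clusters} clusters, {n_noise} noise points")

# New points are assigned to the nearest core point's cluster (or -1 if too far)
X_new, _ = make_moons(n_samples=10, noise=0.05, random_state=1)
print(dbscan.predict(X_new, X_train))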
3. Hierarchical Clustering
Build a hierarchy of clusters (dendrogram):
import numpy as np
from typing import Optional

from scipy.cluster.hierarchy import linkage, fcluster
class HierarchicalClustering:
"""
Hierarchical (agglomerative) clustering.
Advantages:
- Creates hierarchy (dendrogram)
- No need to specify k upfront
- Deterministic
Disadvantages:
- O(N²) time and space
- Doesn't scale to millions of points
"""
def __init__(self, method: str = "ward", metric: str = "euclidean"):
"""
Initialize hierarchical clustering.
Args:
method: Linkage method ("ward", "average", "complete", "single")
metric: Distance metric
"""
self.method = method
self.metric = metric
self.linkage_matrix: Optional[np.ndarray] = None
self.labels: Optional[np.ndarray] = None
def fit(self, X: np.ndarray, n_clusters: int) -> 'HierarchicalClustering':
"""
Fit hierarchical clustering.
Args:
X: Data matrix
n_clusters: Number of clusters to create
"""
# Compute linkage matrix
self.linkage_matrix = linkage(X, method=self.method, metric=self.metric)
# Cut dendrogram to get clusters
self.labels = fcluster(
self.linkage_matrix,
n_clusters,
criterion='maxclust'
) - 1 # Convert to 0-indexed
return self
def predict(self, X: np.ndarray, X_train: np.ndarray, n_clusters: int) -> np.ndarray:
"""
Predict cluster for new points.
Assign to nearest training point's cluster.
"""
from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=1).fit(X_train)
_, indices = nbrs.kneighbors(X)
return self.labels[indices.flatten()]
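A brief usage sketch, again on synthetic blobs from scikit-learn; the dataset is kept small because of the O(N²) cost noted above.

# Minimal usage sketch for HierarchicalClustering (small datasets only).
import numpy as np
from sklearn.datasets import make_blobs

X_train, _ = make_blobs(n_samples=1500, n_features=8, centers=4, random_state=0)

hc = HierarchicalClustering(method="ward", metric="euclidean")
hc.fit(X_train, n_clusters=4)
print("Cluster sizes:", np.bincount(hc.labels))

# New points inherit the cluster of their nearest training point
X_new, _ = make_blobs(n_samples=5, n_features=8, centers=4, random_state=1)
print(hc.predict(X_new, X_train, n_clusters=4))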
4. Locality-Sensitive Hashing for Fast Clustering
For very large datasets, use LSH for approximate clustering:
import numpy as np
from typing import Dict, List, Set
class LSHClustering:
"""
Locality-Sensitive Hashing for fast approximate clustering.
Similar to Group Anagrams:
- Anagrams: Hash = sorted string (exact)
- LSH: Hash = quantized vector (approximate)
Both group similar items using hash keys.
"""
def __init__(
self,
n_hash_functions: int = 10,
n_hash_tables: int = 5,
hash_size: int = 8
):
"""
Initialize LSH clusterer.
Args:
n_hash_functions: Number of hash functions per table
n_hash_tables: Number of hash tables
hash_size: Size of hash (bits)
"""
self.n_hash_functions = n_hash_functions
self.n_hash_tables = n_hash_tables
self.hash_size = hash_size
# Hash tables: table_id -> {hash_key -> [point_ids]}
self.hash_tables: List[Dict[str, List[int]]] = [
{} for _ in range(n_hash_tables)
]
# Random projection vectors for hashing
self.projection_vectors: List[List[np.ndarray]] = []
def fit(self, X: np.ndarray) -> 'LSHClustering':
"""
Build LSH index.
Args:
X: Data matrix of shape (n_samples, n_features)
"""
n_samples, n_features = X.shape
# Generate random projection vectors
for table_idx in range(self.n_hash_tables):
table_projections = []
for _ in range(self.n_hash_functions):
# Random unit vector
random_vec = np.random.randn(n_features)
random_vec /= np.linalg.norm(random_vec)
table_projections.append(random_vec)
self.projection_vectors.append(table_projections)
# Insert all points into hash tables
for point_id, point in enumerate(X):
self._insert_point(point_id, point)
return self
def _hash_point(self, point: np.ndarray, table_idx: int) -> str:
"""
Hash a point using random projections.
Similar to sorting string in anagram problem:
- Anagrams: sorted chars create hash key
- LSH: projection signs create hash key
Returns:
Hash key (binary string)
"""
projections = self.projection_vectors[table_idx]
# Sign of dot product with each projection vector
hash_bits = [
'1' if np.dot(point, proj) > 0 else '0'
for proj in projections
]
return ''.join(hash_bits)
def _insert_point(self, point_id: int, point: np.ndarray):
"""Insert point into all hash tables."""
for table_idx in range(self.n_hash_tables):
hash_key = self._hash_point(point, table_idx)
if hash_key not in self.hash_tables[table_idx]:
self.hash_tables[table_idx][hash_key] = []
self.hash_tables[table_idx][hash_key].append(point_id)
def find_similar_points(
self,
query: np.ndarray,
k: int = 10
) -> List[int]:
"""
Find k similar points to query.
Args:
query: Query point
k: Number of similar points to return
Returns:
List of point IDs
"""
candidates = set()
# Look up in all hash tables
for table_idx in range(self.n_hash_tables):
hash_key = self._hash_point(query, table_idx)
# Get points with same hash
if hash_key in self.hash_tables[table_idx]:
candidates.update(self.hash_tables[table_idx][hash_key])
        # Return up to k candidates (bucket collisions are unranked; re-rank
        # with exact distances downstream if ordering matters)
        return list(candidates)[:k]
def get_clusters(self) -> List[Set[int]]:
"""
Extract clusters from hash tables.
Points in same hash bucket are in same cluster.
"""
# Aggregate across all tables
all_clusters = []
for table in self.hash_tables:
for hash_key, point_ids in table.items():
if len(point_ids) > 1:
all_clusters.append(set(point_ids))
# Merge overlapping clusters
merged = self._merge_clusters(all_clusters)
return merged
    def _merge_clusters(self, clusters: List[Set[int]]) -> List[Set[int]]:
        """Merge clusters that share any point (transitive closure)."""
        merged: List[Set[int]] = []
        for cluster in clusters:
            cluster = set(cluster)
            # Absorb every existing group that overlaps with this one
            overlapping = [group for group in merged if group & cluster]
            for group in overlapping:
                cluster |= group
                merged.remove(group)
            merged.append(cluster)
        return merged
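A minimal sketch of how the LSH index might be used: index a batch of embeddings, pull the candidate bucket for a query, then re-rank with exact distances. The re-ranking step is an assumption layered on top of the class above, not part of it.

# Usage sketch for LSHClustering: index points, fetch candidates, re-rank.
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(10000, 64))   # e.g. 64-d embeddings

lsh = LSHClustering(n_hash_functions=12, n_hash_tables=6)
lsh.fit(X)

query = X[42]
candidates = lsh.find_similar_points(query, k=200)

# Re-rank the small candidate set with exact distances
dists = np.linalg.norm(X[candidates] - query, axis=1)
top = [candidates[i] for i in np.argsort(dists)[:5]]
print("Top-5 neighbours of point 42:", top)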
Data Flow
Batch Clustering Pipeline
1. Data Collection
└─> Features from data lake/warehouse
└─> Embeddings from model inference
2. Feature Engineering
└─> Normalization/scaling
└─> Dimensionality reduction (PCA, UMAP)
└─> Feature selection
3. Clustering
└─> Run K-means/DBSCAN/Hierarchical
└─> Evaluate cluster quality
└─> Store centroids and assignments
4. Indexing
└─> Build fast similarity index (Faiss)
└─> Store in cache (Redis)
└─> Expose via API
5. Monitoring
└─> Track cluster drift
└─> Alert on quality degradation
└─> Trigger retraining
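The monitoring step can be sketched as a simple comparison of successive runs. The helper names and thresholds below are illustrative assumptions (and assume the same k and stable centroid ordering between runs, as with mini-batch updates).

# Sketch: cluster-drift check between two clustering runs.
import numpy as np

def centroid_drift(old_centroids: np.ndarray, new_centroids: np.ndarray) -> float:
    """Mean distance each centroid moved since the previous run."""
    return float(np.linalg.norm(new_centroids - old_centroids, axis=1).mean())

def should_retrain(old_centroids, new_centroids,
                   old_silhouette, new_silhouette,
                   drift_threshold=0.25, quality_drop=0.1) -> bool:
    drifted = centroid_drift(old_centroids, new_centroids) > drift_threshold
    degraded = (old_silhouette - new_silhouette) > quality_drop
    return drifted or degraded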
Real-Time Assignment Flow
1. New point arrives
└─> Feature extraction
2. Normalize features
└─> Apply same scaling as training
3. Find nearest cluster
└─> LSH lookup (approximate)
└─> Or Faiss search (exact)
4. Return cluster ID + metadata
└─> Cluster centroid
└─> Similar points
└─> Confidence score
5. Optional: Update cluster
└─> Online learning
└─> Mini-batch update
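For the Faiss path in step 3, assignment can be served by an exact flat index over the centroids. A sketch, assuming the faiss-cpu package is installed; the function names are illustrative.

# Sketch: real-time nearest-centroid assignment with a Faiss flat index.
import numpy as np
import faiss  # assumes faiss-cpu is installed

def build_centroid_index(centroids: np.ndarray) -> faiss.IndexFlatL2:
    index = faiss.IndexFlatL2(centroids.shape[1])   # exact L2 search
    index.add(centroids.astype(np.float32))
    return index

def assign(index: faiss.IndexFlatL2, points: np.ndarray) -> np.ndarray:
    """Return the nearest centroid id for each incoming point."""
    _, ids = index.search(points.astype(np.float32), 1)
    return ids.ravel()

# e.g. index = build_centroid_index(kmeans.get_cluster_centers())
#      cluster_ids = assign(index, new_points)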
Scaling Strategies
Horizontal Scaling - Distributed K-Means
from pyspark.ml.clustering import KMeans as SparkKMeans
from pyspark.sql import SparkSession
class DistributedKMeans:
"""
Distributed K-means using Spark.
For datasets too large for single machine.
"""
def __init__(self, n_clusters: int = 8):
self.n_clusters = n_clusters
self.spark = SparkSession.builder.appName("Clustering").getOrCreate()
self.model = None
def fit(self, data_path: str):
"""
Fit K-means on distributed data.
Args:
data_path: Path to data (S3, HDFS, etc.)
"""
# Load data
df = self.spark.read.parquet(data_path)
# Create K-means model
kmeans = SparkKMeans(
k=self.n_clusters,
seed=42,
featuresCol="features"
)
# Fit (distributed across cluster)
self.model = kmeans.fit(df)
return self
def predict(self, data_path: str, output_path: str):
"""Predict clusters for new data."""
df = self.spark.read.parquet(data_path)
predictions = self.model.transform(df)
predictions.write.parquet(output_path)
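A hedged usage sketch for the Spark path: the S3 paths and column names below are placeholders, and the raw columns are assembled into the features vector column that SparkKMeans expects.

# Sketch: preparing a "features" column before calling DistributedKMeans.
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler

spark = SparkSession.builder.appName("ClusteringPrep").getOrCreate()

raw = spark.read.parquet("s3://my-bucket/user_features/")          # placeholder path
assembler = VectorAssembler(
    inputCols=["watch_hours", "genres_liked", "avg_session_len"],  # placeholder columns
    outputCol="features",
)
assembler.transform(raw).write.parquet("s3://my-bucket/user_features_vec/")

# Then:
# dkm = DistributedKMeans(n_clusters=50)
# dkm.fit("s3://my-bucket/user_features_vec/")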
Mini-Batch K-Means for Streaming
import numpy as np
from typing import Optional

class MiniBatchKMeans:
"""
Mini-batch K-means for streaming data.
Updates clusters incrementally as new data arrives.
"""
def __init__(self, n_clusters: int = 8, batch_size: int = 100):
self.n_clusters = n_clusters
self.batch_size = batch_size
self.centroids: Optional[np.ndarray] = None
self.counts = np.zeros(n_clusters) # Points per cluster
def partial_fit(self, X: np.ndarray) -> 'MiniBatchKMeans':
"""
Update clusters with mini-batch.
Algorithm:
1. Assign batch points to nearest centroid
2. Update centroids with learning rate
3. Use exponential moving average
Args:
X: Mini-batch of data
"""
        if self.centroids is None:
            # Initialize on first batch (requires at least n_clusters points)
            if len(X) < self.n_clusters:
                raise ValueError("First batch must contain at least n_clusters points")
            self.centroids = X[:self.n_clusters].astype(float).copy()
# Assign points to clusters
labels = self._assign_clusters(X)
# Update centroids
for k in range(self.n_clusters):
mask = labels == k
n_k = mask.sum()
if n_k > 0:
# Exponential moving average
learning_rate = n_k / (self.counts[k] + n_k)
self.centroids[k] = (
(1 - learning_rate) * self.centroids[k] +
learning_rate * X[mask].mean(axis=0)
)
self.counts[k] += n_k
return self
def _assign_clusters(self, X: np.ndarray) -> np.ndarray:
"""Assign points to nearest centroid."""
distances = np.linalg.norm(
X[:, np.newaxis] - self.centroids,
axis=2
)
return np.argmin(distances, axis=1)
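A quick sketch of feeding a stream of mini-batches into the class above; random data stands in for a real event stream.

# Sketch: streaming mini-batches through MiniBatchKMeans.
import numpy as np

rng = np.random.default_rng(0)
mbk = MiniBatchKMeans(n_clusters=8, batch_size=256)

for _ in range(100):                       # stand-in for a real event stream
    batch = rng.normal(size=(256, 32))     # 256 events with 32-d features
    mbk.partial_fit(batch)

print("Points seen per cluster:", mbk.counts)
print("Centroid shape:", mbk.centroids.shape)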
Implementation: Complete System
import redis
import json
from typing import Dict, List, Optional
import numpy as np
class ProductionClusteringSystem:
"""
Complete production clustering system.
Features:
- Multiple clustering algorithms
- Fast similarity search
- Incremental updates
- Caching
- Monitoring
"""
def __init__(
self,
algorithm: str = "kmeans",
n_clusters: int = 10,
cache_enabled: bool = True
):
self.algorithm = algorithm
self.n_clusters = n_clusters
# Choose clustering algorithm
if algorithm == "kmeans":
self.clusterer = KMeansClustering(n_clusters=n_clusters)
elif algorithm == "dbscan":
self.clusterer = DBSCANClustering()
elif algorithm == "lsh":
self.clusterer = LSHClustering()
else:
raise ValueError(f"Unknown algorithm: {algorithm}")
# Cache for fast lookups
self.cache_enabled = cache_enabled
if cache_enabled:
self.cache = redis.Redis(host='localhost', port=6379, db=0)
# Training data (for incremental updates)
self.X_train: Optional[np.ndarray] = None
# Metrics
self.request_count = 0
self.cache_hits = 0
def fit(self, X: np.ndarray) -> 'ProductionClusteringSystem':
"""Fit clustering model."""
self.X_train = X.copy()
self.clusterer.fit(X)
# Cache centroids
if self.cache_enabled and hasattr(self.clusterer, 'centroids'):
self._cache_centroids()
return self
def predict(self, X: np.ndarray) -> np.ndarray:
"""Predict cluster for new points."""
self.request_count += len(X)
# Try cache first
if self.cache_enabled:
cached = self._try_cache(X)
if cached is not None:
self.cache_hits += len(cached)
return cached
        # Predict (DBSCAN needs the training data for the nearest-core lookup;
        # LSHClustering exposes find_similar_points rather than predict)
        if self.algorithm == "dbscan":
            labels = self.clusterer.predict(X, self.X_train)
        elif self.algorithm == "lsh":
            raise NotImplementedError("Use find_similar() with the LSH clusterer")
        else:
            labels = self.clusterer.predict(X)
# Cache results
if self.cache_enabled:
self._cache_predictions(X, labels)
return labels
def find_similar(
self,
query: np.ndarray,
k: int = 10
) -> List[int]:
"""
Find k similar points to query.
Returns indices of similar points in training data.
"""
# Get query's cluster
cluster_id = self.predict(query.reshape(1, -1))[0]
# Find points in same cluster
if hasattr(self.clusterer, 'labels'):
same_cluster = np.where(self.clusterer.labels == cluster_id)[0]
if len(same_cluster) > k:
# Calculate distances within cluster
distances = np.linalg.norm(
self.X_train[same_cluster] - query,
axis=1
)
# Return k nearest
nearest_indices = np.argsort(distances)[:k]
return same_cluster[nearest_indices].tolist()
return same_cluster.tolist()
return []
def get_cluster_info(self, cluster_id: int) -> Dict:
"""Get information about a cluster."""
if not hasattr(self.clusterer, 'labels'):
return {}
mask = self.clusterer.labels == cluster_id
cluster_points = self.X_train[mask]
return {
"cluster_id": cluster_id,
"size": int(mask.sum()),
"centroid": (
self.clusterer.centroids[cluster_id].tolist()
if hasattr(self.clusterer, 'centroids')
else None
),
"mean": cluster_points.mean(axis=0).tolist(),
"std": cluster_points.std(axis=0).tolist(),
}
def _cache_centroids(self):
"""Cache cluster centroids in Redis."""
centroids = self.clusterer.get_cluster_centers()
for i, centroid in enumerate(centroids):
key = f"centroid:{i}"
self.cache.set(key, json.dumps(centroid.tolist()))
def _try_cache(self, X: np.ndarray) -> Optional[np.ndarray]:
"""Try to get predictions from cache."""
# Simple caching by rounding features
# In production: use better hashing
return None
def _cache_predictions(self, X: np.ndarray, labels: np.ndarray):
"""Cache predictions."""
# Implement caching strategy
pass
def get_metrics(self) -> Dict:
"""Get system metrics."""
return {
"algorithm": self.algorithm,
"n_clusters": self.n_clusters,
"request_count": self.request_count,
"cache_hit_rate": (
self.cache_hits / self.request_count
if self.request_count > 0 else 0.0
),
"training_samples": (
len(self.X_train) if self.X_train is not None else 0
),
}
# Example usage
if __name__ == "__main__":
# Generate sample data
from sklearn.datasets import make_blobs
X, y_true = make_blobs(
n_samples=10000,
n_features=10,
centers=5,
random_state=42
)
# Create clustering system
system = ProductionClusteringSystem(
algorithm="kmeans",
n_clusters=5
)
# Fit
system.fit(X[:8000]) # Train on 80%
# Predict
labels = system.predict(X[8000:]) # Test on 20%
print(f"Predicted {len(labels)} samples")
print(f"Metrics: {system.get_metrics()}")
# Find similar points
query = X[8000]
similar = system.find_similar(query, k=5)
print(f"Similar points to query: {similar}")
# Get cluster info
info = system.get_cluster_info(0)
print(f"Cluster 0 info: {info}")
Real-World Case Study: Spotify’s Music Clustering
Spotify’s Approach
Spotify clusters 80M+ songs for recommendation:
Architecture:
- Feature extraction:
  - Audio features: tempo, key, loudness, etc.
  - Collaborative filtering: user listening patterns
  - NLP: song metadata, lyrics
- Hierarchical clustering:
  - Genre-level clusters (rock, pop, etc.)
  - Sub-genre clusters (indie rock, classic rock)
  - Micro-clusters for precise recommendations
- Real-time assignment:
  - New songs assigned via nearest centroid
  - Updated daily with mini-batch K-means
  - LSH for fast similarity search
- Hybrid approach:
  - DBSCAN for discovering new genres
  - K-means for stable clusters
  - Hierarchical for taxonomy
Results:
- 80M+ songs clustered
- <50ms latency for song similarity
- +25% engagement from better recommendations
- Daily updates for new releases
Key Lessons
- Multiple algorithms work together - no one-size-fits-all
- Feature engineering matters most - better features > better algorithm
- Hierarchical structure helps - multi-level clustering
- Incremental updates essential - can’t recluster daily
- LSH enables scale - exact search doesn’t scale to 80M
Cost Analysis
Cost Breakdown (1M data points, daily clustering)
| Component | Single Machine | Distributed (10 nodes) | Cost Difference |
|---|---|---|---|
| Training (daily) | 2 hours @ $2/hr = $4/day | 15 min @ $20/hr = $5/day | +$1/day |
| Storage | 10GB @ $0.10/GB/month (~$0.03/day) | 10GB @ $0.10/GB/month (~$0.03/day) | Same |
| Queries (10K/sec) | $500/day | $100/day | -$400/day |
| Total | ~$504/day | ~$105/day | ~-79% |
Optimization strategies:
- Mini-batch K-means:
  - Incremental updates vs full retraining
  - Savings: 80% compute cost
- LSH for queries:
  - Approximate vs exact search
  - Savings: 90% query latency
- Caching:
  - Cache frequent queries
  - Hit rate 30% = 30% cost reduction
- Dimensionality reduction (see the sketch below):
  - PCA to 50D from 1000D
  - Savings: 95% storage, 80% compute
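A sketch of that PCA step using scikit-learn; the shapes and random data are illustrative placeholders for real 1000-d features.

# Sketch: PCA from 1000-d raw features down to 50-d before clustering.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X_raw = rng.normal(size=(20_000, 1000))         # stand-in for real 1000-d features

pca = PCA(n_components=50, random_state=0)
X_reduced = pca.fit_transform(X_raw)            # shape: (20000, 50)

print("Explained variance kept:", pca.explained_variance_ratio_.sum())
# Cluster on X_reduced; keep the fitted pca to project new points consistently:
# new_reduced = pca.transform(new_raw)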
Key Takeaways
✅ Clustering is everywhere: Recommendations, search, segmentation, anomaly detection
✅ K-means is workhorse: Simple, fast, scales well
✅ DBSCAN for arbitrary shapes: No need to specify k, handles outliers
✅ LSH enables scale: Hash-based approximate clustering for billions of points
✅ Mini-batch for streaming: Incremental updates without full retraining
✅ Same pattern as anagrams: Hash-based grouping (exact or approximate)
✅ Feature engineering crucial: Better features » better algorithm
✅ Multiple algorithms better: Hierarchical structure with different methods
✅ Monitoring essential: Track cluster drift and quality over time
✅ Hybrid approaches work: Combine multiple algorithms for best results
Connection to Thematic Link: Grouping Similar Items with Hash-Based Approaches
All three topics use hash-based or similarity-based grouping:
DSA (Group Anagrams):
- Hash key: sorted string (exact match)
- Grouping: O(1) hash table lookup
- Result: exact anagram groups
ML System Design (Clustering Systems):
- Hash key: quantized vector or nearest centroid
- Grouping: approximate similarity
- Result: data point clusters
Speech Tech (Speaker Diarization):
- Hash key: voice embedding
- Grouping: similarity threshold
- Result: speaker clusters
Universal Pattern
Hash-Based Grouping:
# Generic pattern for all three
from collections import defaultdict

def group_items(items, hash_function):
groups = defaultdict(list)
for item in items:
key = hash_function(item) # Create hash key
groups[key].append(item) # Group by key
return list(groups.values())
Applications:
- Anagrams: hash_function = sorted
- Clustering: hash_function = nearest_centroid
- Diarization: hash_function = voice_embedding
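For concreteness, here is the generic pattern run on the anagram case (clustering and diarization substitute a centroid lookup or a voice-embedding hash for hash_function):

# The group_items helper applied to anagrams.
# sorted() alone returns a list (unhashable), so we join into a string key.
words = ["eat", "tea", "tan", "ate", "nat", "bat"]
groups = group_items(words, hash_function=lambda s: "".join(sorted(s)))
print(groups)   # [['eat', 'tea', 'ate'], ['tan', 'nat'], ['bat']]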
Originally published at: arunbaby.com/ml-system-design/0015-clustering-systems
If you found this helpful, consider sharing it with others who might benefit.