
Design production clustering systems that group similar items using hash-based and distance-based approaches for recommendations, search, and analytics.

Problem Statement

Design a Clustering System that groups millions of data points (users, documents, products, etc.) into meaningful clusters based on similarity, supporting real-time queries and incremental updates.

Functional Requirements

  1. Clustering algorithms: Support K-means, DBSCAN, hierarchical clustering
  2. Similarity metrics: Euclidean, cosine, Jaccard, custom distances
  3. Real-time assignment: Assign new points to clusters in <100ms
  4. Incremental updates: Add new data without full recomputation
  5. Cluster quality: Evaluate cluster cohesion and separation
  6. Scalability: Handle millions to billions of data points
  7. Query interface: Find nearest clusters, similar items, cluster statistics
  8. Visualization: Support for cluster visualization and exploration

Non-Functional Requirements

  1. Latency: p95 cluster assignment < 100ms
  2. Throughput: 10,000+ assignments/second
  3. Scalability: Support 100M+ data points
  4. Accuracy: High cluster quality (silhouette score > 0.5)
  5. Availability: 99.9% uptime
  6. Cost efficiency: Optimize compute and storage
  7. Freshness: Support near-real-time clustering updates

Understanding the Requirements

Clustering is everywhere in production ML:

Common Use Cases

| Company | Use Case | Clustering Method | Scale |
|---|---|---|---|
| Netflix | User segmentation | K-means on viewing patterns | 200M+ users |
| Spotify | Music recommendation | DBSCAN on audio features | 80M+ songs |
| Google | News clustering | Hierarchical on doc embeddings | Billions of articles |
| Amazon | Product categorization | K-means on product attributes | 300M+ products |
| Uber | Demand forecasting | Geospatial clustering | Real-time zones |
| Airbnb | Listing similarity | Locality-sensitive hashing | 7M+ listings |

Why Clustering Matters

  1. Data exploration: Understand data structure and patterns
  2. Dimensionality reduction: Group high-dimensional data
  3. Anomaly detection: Find outliers far from clusters
  4. Recommendation: “Users like you also liked…”
  5. Segmentation: Targeted marketing, personalization
  6. Data compression: Represent data by cluster centroids

The Hash-Based Grouping Connection

Just like the Group Anagrams problem:

| Group Anagrams | Clustering Systems | Speaker Diarization |
|---|---|---|
| Group strings by sorted chars | Group points by similarity | Group audio by speaker |
| Hash key: sorted string | Hash key: quantized vector | Hash key: voice embedding |
| O(1) lookup | LSH for fast similarity | Vector similarity |
| Exact matching | Approximate matching | Threshold-based matching |

All three use hash-based or similarity-based grouping to organize items efficiently.

High-Level Architecture

┌─────────────────────────────────────────────────────────────────┐
│                      Clustering System                           │
└─────────────────────────────────────────────────────────────────┘

                    ┌──────────────┐
                    │   Data Input │
                    │   (Features) │
                    └──────┬───────┘
                           │
        ┌──────────────────┼──────────────────┐
        │                  │                  │
┌───────▼────────┐ ┌──────▼──────┐ ┌────────▼────────┐
│ Batch          │ │ Streaming   │ │ Real-time       │
│ Clustering     │ │ Updates     │ │ Assignment      │
│                │ │             │ │                 │
│ - K-means      │ │ - Mini-batch│ │ - Nearest       │
│ - DBSCAN       │ │ - Online    │ │   cluster       │
│ - Hierarchical │ │   updates   │ │ - LSH lookup    │
└───────┬────────┘ └──────┬──────┘ └────────┬────────┘
        │                  │                  │
        └──────────────────┼──────────────────┘
                           │
                    ┌──────▼──────┐
                    │  Cluster    │
                    │  Storage    │
                    │             │
                    │ - Centroids │
                    │ - Metadata  │
                    │ - Assignments│
                    └──────┬──────┘
                           │
                    ┌──────▼──────┐
                    │   Query API │
                    │             │
                    │ - Find      │
                    │   cluster   │
                    │ - Find      │
                    │   similar   │
                    │ - Stats     │
                    └─────────────┘

Key Components

  1. Clustering Engine: Core algorithms (K-means, DBSCAN, etc.)
  2. Feature Store: Pre-computed embeddings and features
  3. Index: Fast similarity search (Faiss, Annoy)
  4. Cluster Store: Centroids, assignments, metadata
  5. Update Service: Incremental clustering updates
  6. Query API: Real-time cluster assignment and search
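
To make these boundaries concrete, the sketch below shows one way the query-facing surface could look. It is illustrative only; the names (ClusterIndex, ClusterQueryAPI) and method signatures are assumptions, not a specific library.

from typing import Dict, List, Protocol
import numpy as np

class ClusterIndex(Protocol):
    """Anything that can map a vector to a cluster and to nearby points."""
    def nearest_cluster(self, vector: np.ndarray) -> int: ...
    def nearest_points(self, vector: np.ndarray, k: int) -> List[int]: ...

class ClusterQueryAPI:
    """Thin facade over the similarity index and the cluster store."""

    def __init__(self, index: ClusterIndex, cluster_store: Dict[int, Dict]):
        self.index = index
        self.cluster_store = cluster_store  # cluster_id -> metadata (size, centroid, ...)

    def assign(self, vector: np.ndarray) -> int:
        # Real-time assignment: delegate to the index (centroid or LSH lookup)
        return self.index.nearest_cluster(vector)

    def find_similar(self, vector: np.ndarray, k: int = 10) -> List[int]:
        return self.index.nearest_points(vector, k)

    def cluster_stats(self, cluster_id: int) -> Dict:
        return self.cluster_store.get(cluster_id, {})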

Component Deep-Dives

1. Clustering Engine - K-Means Implementation

K-means is the most widely used clustering algorithm:

import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
import logging

@dataclass
class ClusterMetrics:
    """Metrics for cluster quality."""
    inertia: float  # Sum of squared distances to centroids
    silhouette_score: float  # Cluster separation quality
    n_iterations: int
    converged: bool

class KMeansClustering:
    """
    Production K-means clustering.
    
    Similar to Group Anagrams:
    - Anagrams: Group by exact match (sorted string)
    - K-means: Group by approximate match (nearest centroid)
    
    Both use hash-like keys for grouping:
    - Anagrams: hash = sorted(string)
    - K-means: hash = nearest_centroid_id
    """
    
    def __init__(
        self,
        n_clusters: int = 8,
        max_iters: int = 300,
        tol: float = 1e-4,
        init_method: str = "kmeans++",
        random_state: Optional[int] = None
    ):
        """
        Initialize K-means clusterer.
        
        Args:
            n_clusters: Number of clusters (k)
            max_iters: Maximum iterations
            tol: Convergence tolerance
            init_method: "random" or "kmeans++"
            random_state: Random seed
        """
        self.n_clusters = n_clusters
        self.max_iters = max_iters
        self.tol = tol
        self.init_method = init_method
        self.random_state = random_state
        
        self.centroids: Optional[np.ndarray] = None
        self.labels: Optional[np.ndarray] = None
        self.inertia: float = 0.0
        
        self.logger = logging.getLogger(__name__)
        
        if random_state is not None:
            np.random.seed(random_state)
    
    def fit(self, X: np.ndarray) -> 'KMeansClustering':
        """
        Fit K-means to data.
        
        Algorithm:
        1. Initialize k centroids
        2. Assign points to nearest centroid (like hash lookup)
        3. Update centroids to mean of assigned points
        4. Repeat until convergence
        
        Args:
            X: Data matrix of shape (n_samples, n_features)
            
        Returns:
            self
        """
        n_samples, n_features = X.shape
        
        if n_samples < self.n_clusters:
            raise ValueError(
                f"n_samples ({n_samples}) must be >= n_clusters ({self.n_clusters})"
            )
        
        # Initialize centroids
        self.centroids = self._initialize_centroids(X)
        
        # Iterative assignment and update
        for iteration in range(self.max_iters):
            # Assignment step: assign each point to nearest centroid
            # (Like grouping strings by sorted key)
            self.labels = self._assign_clusters(X)
            
            # Update step: recompute centroids
            old_centroids = self.centroids.copy()
            self._update_centroids(X)
            
            # Check convergence
            centroid_shift = np.linalg.norm(self.centroids - old_centroids)
            
            if centroid_shift < self.tol:
                self.logger.info(f"Converged after {iteration + 1} iterations")
                break
        
        # Calculate final inertia
        self.inertia = self._calculate_inertia(X)
        
        return self
    
    def _initialize_centroids(self, X: np.ndarray) -> np.ndarray:
        """
        Initialize centroids.
        
        K-means++ initialization:
        - Choose first centroid randomly
        - Choose subsequent centroids with probability proportional to distance²
        - Spreads out initial centroids for better convergence
        """
        n_samples = X.shape[0]
        
        if self.init_method == "random":
            # Random initialization
            indices = np.random.choice(n_samples, self.n_clusters, replace=False)
            return X[indices].copy()
        
        elif self.init_method == "kmeans++":
            # K-means++ initialization
            centroids = []
            
            # Choose first centroid randomly
            first_idx = np.random.randint(n_samples)
            centroids.append(X[first_idx])
            
            # Choose remaining centroids
            for _ in range(1, self.n_clusters):
                # Calculate distances to nearest existing centroid
                distances = np.min([
                    np.linalg.norm(X - c, axis=1) ** 2
                    for c in centroids
                ], axis=0)
                
                # Choose next centroid with probability ∝ distance²
                probabilities = distances / distances.sum()
                next_idx = np.random.choice(n_samples, p=probabilities)
                centroids.append(X[next_idx])
            
            return np.array(centroids)
        
        else:
            raise ValueError(f"Unknown init_method: {self.init_method}")
    
    def _assign_clusters(self, X: np.ndarray) -> np.ndarray:
        """
        Assign each point to nearest centroid.
        
        This is the "grouping" step (like anagram grouping).
        
        Returns:
            Array of cluster labels
        """
        # Calculate distances to all centroids
        # Shape: (n_samples, n_clusters)
        distances = np.linalg.norm(
            X[:, np.newaxis] - self.centroids,
            axis=2
        )
        
        # Assign to nearest centroid
        labels = np.argmin(distances, axis=1)
        
        return labels
    
    def _update_centroids(self, X: np.ndarray):
        """
        Update centroids to mean of assigned points.
        
        If a cluster is empty, reinitialize that centroid.
        """
        for k in range(self.n_clusters):
            # Get points assigned to cluster k
            mask = self.labels == k
            
            if mask.sum() > 0:
                # Update to mean of assigned points
                self.centroids[k] = X[mask].mean(axis=0)
            else:
                # Empty cluster - reinitialize randomly
                self.logger.warning(f"Empty cluster {k}, reinitializing")
                random_idx = np.random.randint(len(X))
                self.centroids[k] = X[random_idx]
    
    def _calculate_inertia(self, X: np.ndarray) -> float:
        """
        Calculate inertia (within-cluster sum of squares).
        
        Lower inertia = tighter clusters.
        """
        inertia = 0.0
        
        for k in range(self.n_clusters):
            mask = self.labels == k
            if mask.sum() > 0:
                cluster_points = X[mask]
                centroid = self.centroids[k]
                
                # Sum of squared distances
                inertia += np.sum((cluster_points - centroid) ** 2)
        
        return inertia
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Predict cluster labels for new data.
        
        This is like finding anagrams of a new string:
        - Hash the string (sort it)
        - Look up in hash table
        
        For K-means:
        - Calculate distances to centroids
        - Assign to nearest
        
        Args:
            X: Data matrix of shape (n_samples, n_features)
            
        Returns:
            Cluster labels
        """
        if self.centroids is None:
            raise ValueError("Model not fitted. Call fit() first.")
        
        distances = np.linalg.norm(
            X[:, np.newaxis] - self.centroids,
            axis=2
        )
        
        return np.argmin(distances, axis=1)
    
    def get_cluster_centers(self) -> np.ndarray:
        """Get cluster centroids."""
        return self.centroids.copy()
    
    def get_cluster_sizes(self) -> np.ndarray:
        """Get number of points in each cluster."""
        return np.bincount(self.labels, minlength=self.n_clusters)
    
    def calculate_silhouette_score(self, X: np.ndarray) -> float:
        """
        Calculate silhouette score for cluster quality.
        
        Score ranges from -1 to 1:
        - 1: Perfect clustering
        - 0: Overlapping clusters
        - -1: Wrong clustering
        """
        from sklearn.metrics import silhouette_score
        
        if len(np.unique(self.labels)) < 2:
            return 0.0
        
        return silhouette_score(X, self.labels)
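
A short usage sketch for the class above, on synthetic data from scikit-learn's make_blobs (the parameter values are arbitrary):

from sklearn.datasets import make_blobs

# Synthetic data: 5,000 points in 16 dimensions around 8 centers
X, _ = make_blobs(n_samples=5000, n_features=16, centers=8, random_state=0)

km = KMeansClustering(n_clusters=8, init_method="kmeans++", random_state=0)
km.fit(X)

print("inertia:", km.inertia)
print("cluster sizes:", km.get_cluster_sizes())
print("silhouette:", km.calculate_silhouette_score(X))

# Assign previously unseen points to the learned clusters
new_labels = km.predict(X[:10])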

2. DBSCAN - Density-Based Clustering

DBSCAN doesn’t require specifying k and finds arbitrary-shaped clusters:

from sklearn.neighbors import NearestNeighbors

class DBSCANClustering:
    """
    Density-Based Spatial Clustering (DBSCAN).
    
    Advantages over K-means:
    - No need to specify k
    - Finds arbitrary-shaped clusters
    - Handles noise/outliers
    
    Good for:
    - Geospatial data
    - Data with varying density
    - Anomaly detection
    """
    
    def __init__(self, eps: float = 0.5, min_samples: int = 5):
        """
        Initialize DBSCAN.
        
        Args:
            eps: Maximum distance for neighborhood
            min_samples: Minimum points for core point
        """
        self.eps = eps
        self.min_samples = min_samples
        
        self.labels: Optional[np.ndarray] = None
        self.core_sample_indices: Optional[np.ndarray] = None
    
    def fit(self, X: np.ndarray) -> 'DBSCANClustering':
        """
        Fit DBSCAN to data.
        
        Algorithm:
        1. Find core points (points with >= min_samples neighbors within eps)
        2. Form clusters by connecting core points
        3. Assign border points to nearest cluster
        4. Mark noise points as outliers (-1)
        """
        n_samples = X.shape[0]
        
        # Find neighbors for all points
        nbrs = NearestNeighbors(radius=self.eps).fit(X)
        neighborhoods = nbrs.radius_neighbors(X, return_distance=False)
        
        # Initialize labels (-1 = unvisited)
        labels = np.full(n_samples, -1, dtype=int)
        
        # Find core points
        core_samples = np.array([
            len(neighbors) >= self.min_samples
            for neighbors in neighborhoods
        ])
        
        self.core_sample_indices = np.where(core_samples)[0]
        
        # Assign clusters
        cluster_id = 0
        
        for idx in range(n_samples):
            # Skip if already labeled or not a core point
            if labels[idx] != -1 or not core_samples[idx]:
                continue
            
            # Start new cluster
            self._expand_cluster(idx, neighborhoods, labels, cluster_id, core_samples)
            cluster_id += 1
        
        self.labels = labels
        return self
    
    def _expand_cluster(
        self,
        seed_idx: int,
        neighborhoods: List[np.ndarray],
        labels: np.ndarray,
        cluster_id: int,
        core_samples: np.ndarray
    ):
        """
        Expand cluster from seed point using BFS.
        
        Similar to connected component search in graphs.
        """
        # Queue of points to process
        queue = [seed_idx]
        labels[seed_idx] = cluster_id
        
        while queue:
            current = queue.pop(0)
            
            # Add neighbors to queue if core point
            if core_samples[current]:
                for neighbor in neighborhoods[current]:
                    if labels[neighbor] == -1:
                        labels[neighbor] = cluster_id
                        queue.append(neighbor)
    
    def predict(self, X: np.ndarray, X_train: np.ndarray) -> np.ndarray:
        """
        Predict cluster for new points.
        
        Assign to nearest core point's cluster.
        """
        if self.labels is None:
            raise ValueError("Model not fitted")
        
        # Find nearest core point for each new point
        nbrs = NearestNeighbors(n_neighbors=1).fit(
            X_train[self.core_sample_indices]
        )
        
        distances, indices = nbrs.kneighbors(X)
        
        # Assign to nearest core point's cluster if within eps
        labels = np.full(len(X), -1, dtype=int)
        
        for i, (dist, idx) in enumerate(zip(distances, indices)):
            if dist[0] <= self.eps:
                core_idx = self.core_sample_indices[idx[0]]
                labels[i] = self.labels[core_idx]
        
        return labels
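
One practical question the class above leaves open is how to choose eps. A common heuristic is the k-distance curve: compute each point's distance to its min_samples-th nearest neighbor, sort the distances, and pick eps near the "elbow". A minimal sketch (the elbow selection itself is left to visual inspection):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def k_distance_curve(X: np.ndarray, min_samples: int = 5) -> np.ndarray:
    """Sorted distance of every point to its min_samples-th nearest neighbor."""
    nbrs = NearestNeighbors(n_neighbors=min_samples).fit(X)
    distances, _ = nbrs.kneighbors(X)   # shape: (n_samples, min_samples)
    return np.sort(distances[:, -1])

# curve = k_distance_curve(X, min_samples=5)
# Plot `curve`, pick eps near the knee, then:
# dbscan = DBSCANClustering(eps=chosen_eps, min_samples=5).fit(X)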

3. Hierarchical Clustering

Build a hierarchy of clusters (dendrogram):

from scipy.cluster.hierarchy import linkage, fcluster

class HierarchicalClustering:
    """
    Hierarchical (agglomerative) clustering.
    
    Advantages:
    - Creates hierarchy (dendrogram)
    - No need to specify k upfront
    - Deterministic
    
    Disadvantages:
    - O(N²) time and space
    - Doesn't scale to millions of points
    """
    
    def __init__(self, method: str = "ward", metric: str = "euclidean"):
        """
        Initialize hierarchical clustering.
        
        Args:
            method: Linkage method ("ward", "average", "complete", "single")
            metric: Distance metric
        """
        self.method = method
        self.metric = metric
        
        self.linkage_matrix: Optional[np.ndarray] = None
        self.labels: Optional[np.ndarray] = None
    
    def fit(self, X: np.ndarray, n_clusters: int) -> 'HierarchicalClustering':
        """
        Fit hierarchical clustering.
        
        Args:
            X: Data matrix
            n_clusters: Number of clusters to create
        """
        # Compute linkage matrix
        self.linkage_matrix = linkage(X, method=self.method, metric=self.metric)
        
        # Cut dendrogram to get clusters
        self.labels = fcluster(
            self.linkage_matrix,
            n_clusters,
            criterion='maxclust'
        ) - 1  # Convert to 0-indexed
        
        return self
    
    def predict(self, X: np.ndarray, X_train: np.ndarray, n_clusters: int) -> np.ndarray:
        """
        Predict cluster for new points.
        
        Assign to nearest training point's cluster.
        """
        from sklearn.neighbors import NearestNeighbors
        
        nbrs = NearestNeighbors(n_neighbors=1).fit(X_train)
        _, indices = nbrs.kneighbors(X)
        
        return self.labels[indices.flatten()]
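
A usage sketch for the class above. Because linkage is O(N²), it is typically run on a sample; the dendrogram call assumes matplotlib is available, and X_sample is a placeholder for such a sample:

from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt

hc = HierarchicalClustering(method="ward", metric="euclidean")
hc.fit(X_sample, n_clusters=10)   # X_sample: at most a few thousand points

# Inspect the hierarchy: show only the last 30 merges to keep the plot readable
plt.figure(figsize=(10, 4))
dendrogram(hc.linkage_matrix, truncate_mode="lastp", p=30)
plt.show()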

4. Locality-Sensitive Hashing for Fast Clustering

For very large datasets, use LSH for approximate clustering:

from typing import Dict, List, Set

class LSHClustering:
    """
    Locality-Sensitive Hashing for fast approximate clustering.
    
    Similar to Group Anagrams:
    - Anagrams: Hash = sorted string (exact)
    - LSH: Hash = quantized vector (approximate)
    
    Both group similar items using hash keys.
    """
    
    def __init__(
        self,
        n_hash_functions: int = 10,
        n_hash_tables: int = 5,
        hash_size: int = 8
    ):
        """
        Initialize LSH clusterer.
        
        Args:
            n_hash_functions: Number of hash functions per table
            n_hash_tables: Number of hash tables
            hash_size: Size of hash (bits)
        """
        self.n_hash_functions = n_hash_functions
        self.n_hash_tables = n_hash_tables
        self.hash_size = hash_size
        
        # Hash tables: table_id -> {hash_key -> [point_ids]}
        self.hash_tables: List[Dict[str, List[int]]] = [
            {} for _ in range(n_hash_tables)
        ]
        
        # Random projection vectors for hashing
        self.projection_vectors: List[List[np.ndarray]] = []
    
    def fit(self, X: np.ndarray) -> 'LSHClustering':
        """
        Build LSH index.
        
        Args:
            X: Data matrix of shape (n_samples, n_features)
        """
        n_samples, n_features = X.shape
        
        # Generate random projection vectors
        for table_idx in range(self.n_hash_tables):
            table_projections = []
            
            for _ in range(self.n_hash_functions):
                # Random unit vector
                random_vec = np.random.randn(n_features)
                random_vec /= np.linalg.norm(random_vec)
                table_projections.append(random_vec)
            
            self.projection_vectors.append(table_projections)
        
        # Insert all points into hash tables
        for point_id, point in enumerate(X):
            self._insert_point(point_id, point)
        
        return self
    
    def _hash_point(self, point: np.ndarray, table_idx: int) -> str:
        """
        Hash a point using random projections.
        
        Similar to sorting string in anagram problem:
        - Anagrams: sorted chars create hash key
        - LSH: projection signs create hash key
        
        Returns:
            Hash key (binary string)
        """
        projections = self.projection_vectors[table_idx]
        
        # Sign of dot product with each projection vector
        hash_bits = [
            '1' if np.dot(point, proj) > 0 else '0'
            for proj in projections
        ]
        
        return ''.join(hash_bits)
    
    def _insert_point(self, point_id: int, point: np.ndarray):
        """Insert point into all hash tables."""
        for table_idx in range(self.n_hash_tables):
            hash_key = self._hash_point(point, table_idx)
            
            if hash_key not in self.hash_tables[table_idx]:
                self.hash_tables[table_idx][hash_key] = []
            
            self.hash_tables[table_idx][hash_key].append(point_id)
    
    def find_similar_points(
        self,
        query: np.ndarray,
        k: int = 10
    ) -> List[int]:
        """
        Find k similar points to query.
        
        Args:
            query: Query point
            k: Number of similar points to return
            
        Returns:
            List of point IDs
        """
        candidates = set()
        
        # Look up in all hash tables
        for table_idx in range(self.n_hash_tables):
            hash_key = self._hash_point(query, table_idx)
            
            # Get points with same hash
            if hash_key in self.hash_tables[table_idx]:
                candidates.update(self.hash_tables[table_idx][hash_key])
        
        # Return up to k candidates (unranked; re-rank by exact distance if order matters)
        return list(candidates)[:k]
    
    def get_clusters(self) -> List[Set[int]]:
        """
        Extract clusters from hash tables.
        
        Points in same hash bucket are in same cluster.
        """
        # Aggregate across all tables
        all_clusters = []
        
        for table in self.hash_tables:
            for hash_key, point_ids in table.items():
                if len(point_ids) > 1:
                    all_clusters.append(set(point_ids))
        
        # Merge overlapping clusters
        merged = self._merge_clusters(all_clusters)
        
        return merged
    
    def _merge_clusters(self, clusters: List[Set[int]]) -> List[Set[int]]:
        """Merge all transitively overlapping clusters (not just adjacent ones)."""
        merged: List[Set[int]] = []
        
        for cluster in clusters:
            cluster = set(cluster)
            
            # Absorb every existing group that overlaps this cluster
            overlapping = [group for group in merged if group & cluster]
            for group in overlapping:
                cluster |= group
                merged.remove(group)
            
            merged.append(cluster)
        
        return merged
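
Because hash buckets only yield candidates, a common pattern is to re-rank the LSH candidates with exact distances, which stays cheap because the candidate set is small. A sketch:

import numpy as np

lsh = LSHClustering(n_hash_functions=12, n_hash_tables=8).fit(X)

query = X[0]
candidate_ids = lsh.find_similar_points(query, k=100)   # coarse, bucket-based

# Exact re-ranking over the candidates only (|candidates| << n_samples)
cand = np.array(candidate_ids)
dists = np.linalg.norm(X[cand] - query, axis=1)
top_10 = cand[np.argsort(dists)[:10]]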

Data Flow

Batch Clustering Pipeline

1. Data Collection
   └─> Features from data lake/warehouse
   └─> Embeddings from model inference

2. Feature Engineering
   └─> Normalization/scaling
   └─> Dimensionality reduction (PCA, UMAP)
   └─> Feature selection

3. Clustering
   └─> Run K-means/DBSCAN/Hierarchical
   └─> Evaluate cluster quality (see the metrics sketch after this pipeline)
   └─> Store centroids and assignments

4. Indexing
   └─> Build fast similarity index (Faiss)
   └─> Store in cache (Redis)
   └─> Expose via API

5. Monitoring
   └─> Track cluster drift
   └─> Alert on quality degradation
   └─> Trigger retraining
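
For step 3, "evaluate cluster quality" usually means computing a few complementary internal metrics; a minimal sketch using scikit-learn (which metrics to alert on is a per-system choice):

from sklearn.metrics import (
    silhouette_score,
    davies_bouldin_score,
    calinski_harabasz_score,
)

def evaluate_clusters(X, labels) -> dict:
    """Internal quality metrics for a fitted clustering."""
    return {
        "silhouette": silhouette_score(X, labels),                 # [-1, 1], higher is better
        "davies_bouldin": davies_bouldin_score(X, labels),         # lower is better
        "calinski_harabasz": calinski_harabasz_score(X, labels),   # higher is better
    }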

Real-Time Assignment Flow

1. New point arrives
   └─> Feature extraction

2. Normalize features
   └─> Apply same scaling as training

3. Find nearest cluster (see the index sketch after this flow)
   └─> LSH lookup (approximate)
   └─> Or Faiss search (exact)

4. Return cluster ID + metadata
   └─> Cluster centroid
   └─> Similar points
   └─> Confidence score

5. Optional: Update cluster
   └─> Online learning
   └─> Mini-batch update
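
For step 3 of this flow, a vector index over the centroids keeps assignment well under the 100ms budget. A sketch using Faiss (assumes the faiss-cpu package is installed; an exact flat index is usually sufficient because the number of centroids is small):

import numpy as np
import faiss

def build_centroid_index(centroids: np.ndarray) -> faiss.IndexFlatL2:
    """Exact L2 index over the k cluster centroids."""
    index = faiss.IndexFlatL2(centroids.shape[1])
    index.add(centroids.astype(np.float32))
    return index

def assign_points(index: faiss.IndexFlatL2, points: np.ndarray):
    """Return (cluster_ids, distances) for each incoming point."""
    distances, cluster_ids = index.search(points.astype(np.float32), 1)
    return cluster_ids.ravel(), distances.ravel()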

Scaling Strategies

Horizontal Scaling - Distributed K-Means

from pyspark.ml.clustering import KMeans as SparkKMeans
from pyspark.sql import SparkSession

class DistributedKMeans:
    """
    Distributed K-means using Spark.
    
    For datasets too large for single machine.
    """
    
    def __init__(self, n_clusters: int = 8):
        self.n_clusters = n_clusters
        self.spark = SparkSession.builder.appName("Clustering").getOrCreate()
        self.model = None
    
    def fit(self, data_path: str):
        """
        Fit K-means on distributed data.
        
        Args:
            data_path: Path to data (S3, HDFS, etc.)
        """
        # Load data
        df = self.spark.read.parquet(data_path)
        
        # Create K-means model
        kmeans = SparkKMeans(
            k=self.n_clusters,
            seed=42,
            featuresCol="features"
        )
        
        # Fit (distributed across cluster)
        self.model = kmeans.fit(df)
        
        return self
    
    def predict(self, data_path: str, output_path: str):
        """Predict clusters for new data."""
        df = self.spark.read.parquet(data_path)
        predictions = self.model.transform(df)
        predictions.write.parquet(output_path)
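
One assumption hidden in fit() above is that the Parquet data already contains a "features" vector column. If the table has raw numeric columns instead, they have to be assembled first; a sketch (the column names are hypothetical):

from pyspark.ml.feature import VectorAssembler

# Combine raw numeric columns into the single "features" vector column Spark ML expects
assembler = VectorAssembler(
    inputCols=["feat_1", "feat_2", "feat_3"],   # hypothetical raw feature columns
    outputCol="features",
)
df_with_features = assembler.transform(df)       # df: the DataFrame loaded from Parquet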

Mini-Batch K-Means for Streaming

class MiniBatchKMeans:
    """
    Mini-batch K-means for streaming data.
    
    Updates clusters incrementally as new data arrives.
    """
    
    def __init__(self, n_clusters: int = 8, batch_size: int = 100):
        self.n_clusters = n_clusters
        self.batch_size = batch_size
        
        self.centroids: Optional[np.ndarray] = None
        self.counts = np.zeros(n_clusters)  # Points per cluster
    
    def partial_fit(self, X: np.ndarray) -> 'MiniBatchKMeans':
        """
        Update clusters with mini-batch.
        
        Algorithm:
        1. Assign batch points to nearest centroid
        2. Update centroids with learning rate
        3. Weight updates by per-cluster counts (a running mean over all points seen)
        
        Args:
            X: Mini-batch of data
        """
        if self.centroids is None:
            # Initialize on first batch
            self.centroids = X[:self.n_clusters].copy()
        
        # Assign points to clusters
        labels = self._assign_clusters(X)
        
        # Update centroids
        for k in range(self.n_clusters):
            mask = labels == k
            n_k = mask.sum()
            
            if n_k > 0:
                # Count-weighted running mean (per-cluster learning rate)
                learning_rate = n_k / (self.counts[k] + n_k)
                self.centroids[k] = (
                    (1 - learning_rate) * self.centroids[k] +
                    learning_rate * X[mask].mean(axis=0)
                )
                self.counts[k] += n_k
        
        return self
    
    def _assign_clusters(self, X: np.ndarray) -> np.ndarray:
        """Assign points to nearest centroid."""
        distances = np.linalg.norm(
            X[:, np.newaxis] - self.centroids,
            axis=2
        )
        return np.argmin(distances, axis=1)
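
A usage sketch for the streaming case; feature_stream and new_points are hypothetical stand-ins for a batch source and incoming data:

mbk = MiniBatchKMeans(n_clusters=8, batch_size=256)

# Keep updating the same model as batches arrive
for batch in feature_stream():   # yields arrays of shape (batch_size, n_features)
    mbk.partial_fit(batch)

# Assign incoming points with the current centroids
# (a public predict() wrapper around _assign_clusters would be a natural addition)
labels = mbk._assign_clusters(new_points)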

Implementation: Complete System

import redis
import json
from typing import Dict, List, Optional
import numpy as np

class ProductionClusteringSystem:
    """
    Complete production clustering system.
    
    Features:
    - Multiple clustering algorithms
    - Fast similarity search
    - Incremental updates
    - Caching
    - Monitoring
    """
    
    def __init__(
        self,
        algorithm: str = "kmeans",
        n_clusters: int = 10,
        cache_enabled: bool = True
    ):
        self.algorithm = algorithm
        self.n_clusters = n_clusters
        
        # Choose clustering algorithm
        if algorithm == "kmeans":
            self.clusterer = KMeansClustering(n_clusters=n_clusters)
        elif algorithm == "dbscan":
            self.clusterer = DBSCANClustering()
        elif algorithm == "lsh":
            self.clusterer = LSHClustering()
        else:
            raise ValueError(f"Unknown algorithm: {algorithm}")
        
        # Cache for fast lookups
        self.cache_enabled = cache_enabled
        if cache_enabled:
            self.cache = redis.Redis(host='localhost', port=6379, db=0)
        
        # Training data (for incremental updates)
        self.X_train: Optional[np.ndarray] = None
        
        # Metrics
        self.request_count = 0
        self.cache_hits = 0
    
    def fit(self, X: np.ndarray) -> 'ProductionClusteringSystem':
        """Fit clustering model."""
        self.X_train = X.copy()
        self.clusterer.fit(X)
        
        # Cache centroids
        if self.cache_enabled and hasattr(self.clusterer, 'centroids'):
            self._cache_centroids()
        
        return self
    
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict cluster for new points."""
        self.request_count += len(X)
        
        # Try cache first
        if self.cache_enabled:
            cached = self._try_cache(X)
            if cached is not None:
                self.cache_hits += len(cached)
                return cached
        
        # Predict (DBSCAN needs the training points to locate core samples;
        # the LSH clusterer exposes find_similar_points() instead of predict())
        if self.algorithm == "dbscan":
            labels = self.clusterer.predict(X, self.X_train)
        else:
            labels = self.clusterer.predict(X)
        
        # Cache results
        if self.cache_enabled:
            self._cache_predictions(X, labels)
        
        return labels
    
    def find_similar(
        self,
        query: np.ndarray,
        k: int = 10
    ) -> List[int]:
        """
        Find k similar points to query.
        
        Returns indices of similar points in training data.
        """
        # Get query's cluster
        cluster_id = self.predict(query.reshape(1, -1))[0]
        
        # Find points in same cluster
        if hasattr(self.clusterer, 'labels'):
            same_cluster = np.where(self.clusterer.labels == cluster_id)[0]
            
            if len(same_cluster) > k:
                # Calculate distances within cluster
                distances = np.linalg.norm(
                    self.X_train[same_cluster] - query,
                    axis=1
                )
                
                # Return k nearest
                nearest_indices = np.argsort(distances)[:k]
                return same_cluster[nearest_indices].tolist()
            
            return same_cluster.tolist()
        
        return []
    
    def get_cluster_info(self, cluster_id: int) -> Dict:
        """Get information about a cluster."""
        if not hasattr(self.clusterer, 'labels'):
            return {}
        
        mask = self.clusterer.labels == cluster_id
        cluster_points = self.X_train[mask]
        
        return {
            "cluster_id": cluster_id,
            "size": int(mask.sum()),
            "centroid": (
                self.clusterer.centroids[cluster_id].tolist()
                if hasattr(self.clusterer, 'centroids')
                else None
            ),
            "mean": cluster_points.mean(axis=0).tolist(),
            "std": cluster_points.std(axis=0).tolist(),
        }
    
    def _cache_centroids(self):
        """Cache cluster centroids in Redis."""
        centroids = self.clusterer.get_cluster_centers()
        
        for i, centroid in enumerate(centroids):
            key = f"centroid:{i}"
            self.cache.set(key, json.dumps(centroid.tolist()))
    
    def _try_cache(self, X: np.ndarray) -> Optional[np.ndarray]:
        """Try to get predictions from cache."""
        # Simple caching by rounding features
        # In production: use better hashing
        return None
    
    def _cache_predictions(self, X: np.ndarray, labels: np.ndarray):
        """Cache predictions."""
        # Implement caching strategy
        pass
    
    def get_metrics(self) -> Dict:
        """Get system metrics."""
        return {
            "algorithm": self.algorithm,
            "n_clusters": self.n_clusters,
            "request_count": self.request_count,
            "cache_hit_rate": (
                self.cache_hits / self.request_count
                if self.request_count > 0 else 0.0
            ),
            "training_samples": (
                len(self.X_train) if self.X_train is not None else 0
            ),
        }


# Example usage
if __name__ == "__main__":
    # Generate sample data
    from sklearn.datasets import make_blobs
    
    X, y_true = make_blobs(
        n_samples=10000,
        n_features=10,
        centers=5,
        random_state=42
    )
    
    # Create clustering system (disable the Redis cache so the demo runs standalone)
    system = ProductionClusteringSystem(
        algorithm="kmeans",
        n_clusters=5,
        cache_enabled=False
    )
    
    # Fit
    system.fit(X[:8000])  # Train on 80%
    
    # Predict
    labels = system.predict(X[8000:])  # Test on 20%
    
    print(f"Predicted {len(labels)} samples")
    print(f"Metrics: {system.get_metrics()}")
    
    # Find similar points
    query = X[8000]
    similar = system.find_similar(query, k=5)
    print(f"Similar points to query: {similar}")
    
    # Get cluster info
    info = system.get_cluster_info(0)
    print(f"Cluster 0 info: {info}")

Real-World Case Study: Spotify’s Music Clustering

Spotify’s Approach

Spotify clusters 80M+ songs for recommendation:

Architecture:

  1. Feature extraction:
    • Audio features: tempo, key, loudness, etc.
    • Collaborative filtering: user listening patterns
    • NLP: song metadata, lyrics
  2. Hierarchical clustering:
    • Genre-level clusters (rock, pop, etc.)
    • Sub-genre clusters (indie rock, classic rock)
    • Micro-clusters for precise recommendations
  3. Real-time assignment:
    • New songs assigned via nearest centroid
    • Updated daily with mini-batch K-means
    • LSH for fast similarity search
  4. Hybrid approach:
    • DBSCAN for discovering new genres
    • K-means for stable clusters
    • Hierarchical for taxonomy

Results:

  • 80M+ songs clustered
  • <50ms latency for song similarity
  • +25% engagement from better recommendations
  • Daily updates for new releases

Key Lessons

  1. Multiple algorithms work together - no one-size-fits-all
  2. Feature engineering matters most - better features > better algorithm
  3. Hierarchical structure helps - multi-level clustering
  4. Incremental updates essential - can’t recluster daily
  5. LSH enables scale - exact search doesn’t scale to 80M

Cost Analysis

Cost Breakdown (1M data points, daily clustering)

| Component | Single Machine | Distributed (10 nodes) | Cost Difference |
|---|---|---|---|
| Training (daily) | 2 hours @ $2/hr ≈ $4/day | 15 min @ $20/hr ≈ $5/day | +$1/day |
| Storage | 10GB @ $0.10/GB/month | 10GB @ $0.10/GB/month | Same |
| Queries (10K/sec) | $500/day | $100/day | -$400/day |
| Total | ≈ $504/day | ≈ $105/day | ≈ -79% |

Optimization strategies:

  1. Mini-batch K-means:
    • Incremental updates vs full retraining
    • Savings: 80% compute cost
  2. LSH for queries:
    • Approximate vs exact search
    • Savings: 90% query latency
  3. Caching:
    • Cache frequent queries
    • Hit rate 30% = 30% cost reduction
  4. Dimensionality reduction:
    • PCA to 50D from 1000D (see the sketch after this list)
    • Savings: 95% storage, 80% compute
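
The last item is essentially a one-liner with scikit-learn; the operational detail that matters is fitting the projection once on training features and reusing the same transform for every query-time point (X_train_features and query_features are placeholders):

from sklearn.decomposition import PCA

# Fit once on training features: e.g. 1000-dim embeddings -> 50 dims
pca = PCA(n_components=50, random_state=0)
X_reduced = pca.fit_transform(X_train_features)

# Query time: apply the SAME fitted transform before cluster assignment
q_reduced = pca.transform(query_features.reshape(1, -1))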

Key Takeaways

Clustering is everywhere: Recommendations, search, segmentation, anomaly detection

K-means is the workhorse: Simple, fast, scales well

DBSCAN for arbitrary shapes: No need to specify k, handles outliers

LSH enables scale: Hash-based approximate clustering for billions of points

Mini-batch for streaming: Incremental updates without full retraining

Same pattern as anagrams: Hash-based grouping (exact or approximate)

Feature engineering crucial: Better features » better algorithm

Multiple algorithms better: Hierarchical structure with different methods

Monitoring essential: Track cluster drift and quality over time

Hybrid approaches work: Combine multiple algorithms for best results

All three topics use hash-based or similarity-based grouping:

DSA (Group Anagrams):

  • Hash key: sorted string (exact match)
  • Grouping: O(1) hash table lookup
  • Result: exact anagram groups

ML System Design (Clustering Systems):

  • Hash key: quantized vector or nearest centroid
  • Grouping: approximate similarity
  • Result: data point clusters

Speech Tech (Speaker Diarization):

  • Hash key: voice embedding
  • Grouping: similarity threshold
  • Result: speaker clusters

Universal Pattern

Hash-Based Grouping:

# Generic pattern for all three
from collections import defaultdict

def group_items(items, hash_function):
    groups = defaultdict(list)
    
    for item in items:
        key = hash_function(item)  # Create hash key
        groups[key].append(item)    # Group by key
    
    return list(groups.values())

Applications:

  • Anagrams: hash_function = sorted
  • Clustering: hash_function = nearest_centroid
  • Diarization: hash_function = voice_embedding

Originally published at: arunbaby.com/ml-system-design/0015-clustering-systems

If you found this helpful, consider sharing it with others who might benefit.