Distributed ML Systems
Design distributed ML systems that scale to billions of predictions: Master replication, sharding, consensus, and fault tolerance for production ML.
Problem Statement
Design a distributed machine learning system that can:
- Handle billions of predictions per day across multiple regions
- Train models on terabytes of data across multiple machines
- Serve models with low latency (<100ms) and high availability (99.99%)
- Handle failures gracefully without data loss or service disruption
- Scale horizontally by adding more machines
Why Distributed Systems?
The fundamental constraint: A single machine can't handle:
Data:
- Training data: 10TB+ (won't fit in RAM)
- Model size: 100GB+ (large language models, embeddings)
- Inference load: 100,000 requests/sec (CPU melts 🔥)
Computation:
- Training time: Days/weeks on a single GPU
- Inference: Can't serve millions of users from one server
Geography:
- Users worldwide: Tokyo, London, New York, São Paulo
- Latency: Can't serve Tokyo users from Virginia (150ms+ RTT)
Reliability:
- Single machine fails → entire service down ❌
- Need redundancy and fault tolerance
Real-World Scale Examples
| Company | Scale | Challenge |
|---|---|---|
| Google Search | 8.5B searches/day | Distributed indexing + serving |
| Netflix | 200M users, 1B hours/day | Personalization at scale |
| Uber | 19M trips/day | Real-time matching + prediction |
| Meta | 3B users | Social graph + recommendation |
Common pattern: All use distributed ML systems!
Understanding Distributed Systems Fundamentals
What Makes Systems "Distributed"?
Definition: Multiple computers working together as one system.
Simple analogy: Restaurant kitchen
- Single machine: One chef makes everything (slow, bottleneck)
- Distributed: Multiple chefs, each specializing (fast, parallel)
But coordination is hard:
- How do chefs know what to cook?
- What if a chef is sick?
- How to avoid making duplicate orders?
These are distributed systems problems!
The CAP Theorem
CAP Theorem states: You can only have 2 of 3:
- Consistency (C): All nodes see same data at same time
- Availability (A): System always responds (even if some nodes down)
- Partition Tolerance (P): System works despite network failures
In practice: Network partitions happen, so you must have P. Real choice: Consistency (CP) vs Availability (AP)
Example scenarios:
Scenario: Network split between US and EU data centers
CP System (Choose Consistency):
- Reject writes until partition healed
- Data stays consistent
- But users in EU can't use system! ❌
AP System (Choose Availability):
- Accept writes in both regions
- Users happy! ✅
- But data may conflict later (eventual consistency)
For ML systems:
- Training: CP (want consistent data)
- Serving: AP (availability critical for user experience)
Key Concepts for Junior Engineers
1. Horizontal vs Vertical Scaling
Vertical Scaling (Scale UP):
1 machine → Bigger machine
4 CPU → 64 CPU
16GB RAM → 512GB RAM
Pros: Simple, no code changes
Cons: Expensive, limited (can't buy infinite RAM), single point of failure
Horizontal Scaling (Scale OUT):
1 machine → 10 machines
Pros: Cheaper, scales out almost without limit, fault-tolerant
Cons: Complex (distributed systems problems!)
ML systems need horizontal scaling because:
- Data too big for one machine
- Training too slow on one machine
- Serving load too high for one machine
2. Replication vs Sharding
Replication: Same data on multiple machines
Machine 1: [A, B, C, D]
Machine 2: [A, B, C, D] ← Same data!
Machine 3: [A, B, C, D]
Use case: High availability, load distribution
Example: Model weights replicated to 100 servers
Sharding: Different data on each machine
Machine 1: [A, B]
Machine 2: [C, D] ← Different data!
Machine 3: [E, F]
Use case: Data too big for one machine
Example: Training data split across 10 machines
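A minimal Python sketch of the difference (the Node class and the modulo-based shard choice are illustrative assumptions, not a production design):
class Node:
    """A toy storage node holding an in-memory dict."""
    def __init__(self):
        self.data = {}

def replicate_write(nodes, key, value):
    # Replication: every node stores a full copy of the item
    for node in nodes:
        node.data[key] = value

def shard_write(nodes, key, value):
    # Sharding: exactly one node (picked by key) stores the item
    nodes[hash(key) % len(nodes)].data[key] = value

nodes = [Node() for _ in range(3)]
replicate_write(nodes, "model_v1", "weights")   # present on all 3 nodes
shard_write(nodes, "user_42", {"clicks": 7})    # present on exactly 1 node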
3. Synchronous vs Asynchronous
Synchronous: Wait for response before continuing
result = call_other_service() # Block here
process(result) # Wait until call returns
- Pros: Simple, consistent
- Cons: Slow (latency adds up)
Asynchronous: Don't wait, continue immediately
future = call_other_service_async() # Don't block
do_other_work() # Continue immediately
result = future.result() # Get result when needed (blocks only now)
- Pros: Fast, better resource usage
- Cons: Complex, harder to debug
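A small runnable sketch of the two styles using Python's concurrent.futures (note that these futures expose .result(), not .get()); the service call and its 0.5s delay are illustrative:
import time
from concurrent.futures import ThreadPoolExecutor

def call_other_service(x):
    time.sleep(0.5)  # simulate network latency
    return x * 2

# Synchronous: the caller blocks for the full 0.5s
sync_result = call_other_service(21)
# Asynchronous: submit the call, keep working, collect the result later
with ThreadPoolExecutor() as executor:
    future = executor.submit(call_other_service, 21)
    # ... do other useful work while the call is in flight ...
    async_result = future.result()  # block only when the value is needed
assert sync_result == async_result == 42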
Architecture Patterns
Pattern 1: Master-Worker (for Training)
Use case: Distributed model training
+-----------------------------------------------+
|                  MASTER NODE                  |
|  - Coordinates workers                        |
|  - Aggregates gradients                       |
|  - Updates global model                       |
+---------+------------+------------+-----------+
          |            |            |
     +----v----+  +----v----+  +----v----+
     | Worker 1|  | Worker 2|  | Worker 3|
     |  GPU 1  |  |  GPU 2  |  |  GPU 3  |
     | Batch 1 |  | Batch 2 |  | Batch 3 |
     +---------+  +---------+  +---------+
How it works:
- Master distributes data batches to workers
- Each worker computes gradients on its batch
- Workers send gradients back to master
- Master averages gradients, updates model
- Master broadcasts updated model to workers
- Repeat
Python implementation:
class MasterNode:
"""
Master node for distributed training
Coordinates multiple worker nodes
"""
def __init__(self, model, workers):
self.model = model
self.workers = workers
self.global_step = 0
def train_step(self, data_batches):
"""
One distributed training step
1. Send model to workers
2. Workers compute gradients
3. Aggregate gradients
4. Update model
"""
# Distribute work to workers
futures = []
for worker, batch in zip(self.workers, data_batches):
# Send model and data to worker
future = worker.compute_gradients_async(
self.model.state_dict(),
batch
)
futures.append(future)
# Wait for all workers (synchronous)
        gradients = [future.result() for future in futures]
# Aggregate gradients (averaging)
avg_gradients = self._average_gradients(gradients)
# Update model
self.model.update(avg_gradients)
self.global_step += 1
return self.model
def _average_gradients(self, gradients_list):
"""Average gradients from all workers"""
avg_grads = {}
for param_name in gradients_list[0].keys():
# Average this parameter's gradients
param_grads = [g[param_name] for g in gradients_list]
avg_grads[param_name] = sum(param_grads) / len(param_grads)
return avg_grads
class WorkerNode:
"""
Worker node that computes gradients
"""
def __init__(self, worker_id, device='cuda'):
self.worker_id = worker_id
self.device = device
def compute_gradients_async(self, model_state, batch):
"""
Compute gradients on local batch
Returns: Future that will contain gradients
"""
import concurrent.futures
executor = concurrent.futures.ThreadPoolExecutor()
future = executor.submit(
self._compute_gradients,
model_state,
batch
)
return future
def _compute_gradients(self, model_state, batch):
"""Actually compute gradients"""
import torch
        # load_model() is assumed to be a user-provided factory that builds
        # the same architecture as the master's model
        model = load_model()
        model.load_state_dict(model_state)
        model.to(self.device)
        # Forward + backward (assumes the forward pass returns the loss for a batch)
        loss = model(batch)
        loss.backward()
# Extract gradients
gradients = {
name: param.grad.cpu()
for name, param in model.named_parameters()
}
return gradients
Challenges:
- Straggler problem: Slowest worker delays everyone
- Solution: Asynchronous updates, backup tasks
- Communication overhead: Sending gradients is expensive
- Solution: Gradient compression (see the sketch below), local updates
- Fault tolerance: What if worker crashes?
- Solution: Checkpoint frequently, redistribute work
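One of the mitigations above, gradient compression, can be as simple as top-k sparsification: send only the largest-magnitude gradient entries. A minimal sketch assuming PyTorch tensors (the 1% ratio is an arbitrary illustrative choice):
import torch

def compress_topk(grad: torch.Tensor, ratio: float = 0.01):
    """Keep only the largest-magnitude entries; send (indices, values, shape)."""
    flat = grad.flatten()
    k = max(1, int(flat.numel() * ratio))
    _, indices = torch.topk(flat.abs(), k)
    return indices, flat[indices], grad.shape

def decompress_topk(indices, values, shape):
    """Rebuild a sparse approximation of the original gradient."""
    flat = torch.zeros(shape).flatten()
    flat[indices] = values
    return flat.reshape(shape)

grad = torch.randn(1024, 1024)
idx, vals, shape = compress_topk(grad)        # ~10K floats instead of ~1M
restored = decompress_topk(idx, vals, shape)  # sparse approximation of grad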
Pattern 2: Load Balancer + Replicas (for Serving)
Use case: Serving ML predictions at scale
                 +----------------+
  Requests --->  | Load Balancer  |
                 | (Round Robin)  |
                 +-------+--------+
                         |
         +---------------+---------------+
         v               v               v
   +-----------+   +-----------+   +-----------+
   | Replica 1 |   | Replica 2 |   | Replica 3 |
   |   Model   |   |   Model   |   |   Model   |
   |  + Cache  |   |  + Cache  |   |  + Cache  |
   +-----------+   +-----------+   +-----------+
Benefits:
- High availability: If one replica dies, others handle load
- Load distribution: 10K req/sec across 10 replicas = 1K each
- Zero-downtime deploys: Update replicas one at a time (see the rolling-update sketch after the implementation below)
Implementation:
class LoadBalancer:
"""
Simple round-robin load balancer
Distributes requests across healthy replicas
"""
def __init__(self, replicas):
self.replicas = replicas
self.current_index = 0
self.health_checker = HealthChecker(replicas)
self.health_checker.start()
def route_request(self, request):
"""
Route request to healthy replica
Uses round-robin for simplicity
"""
# Get healthy replicas
healthy = self.health_checker.get_healthy_replicas()
if not healthy:
raise Exception("No healthy replicas available!")
# Round-robin selection
replica = healthy[self.current_index % len(healthy)]
self.current_index += 1
# Forward request
try:
response = replica.predict(request)
return response
except Exception as e:
# Retry with different replica
return self._retry_request(request, exclude=[replica])
def _retry_request(self, request, exclude=None):
"""Retry failed request on different replica"""
exclude = exclude or []
healthy = [
r for r in self.health_checker.get_healthy_replicas()
if r not in exclude
]
if not healthy:
raise Exception("All replicas failed")
return healthy[0].predict(request)
class HealthChecker:
"""
Continuously monitor replica health
Marks unhealthy replicas so LB doesn't route to them
"""
def __init__(self, replicas, check_interval=10):
self.replicas = replicas
self.check_interval = check_interval
self.health_status = {r: True for r in replicas}
self.running = False
def start(self):
"""Start health checking in background"""
import threading
self.running = True
self.thread = threading.Thread(
target=self._health_check_loop,
daemon=True
)
self.thread.start()
def _health_check_loop(self):
"""Continuously check replica health"""
import time
while self.running:
for replica in self.replicas:
is_healthy = replica.health_check()
self.health_status[replica] = is_healthy
if not is_healthy:
print(f"β οΈ Replica {replica.id} unhealthy!")
time.sleep(self.check_interval)
def get_healthy_replicas(self):
"""Get list of currently healthy replicas"""
return [
replica for replica in self.replicas
if self.health_status[replica]
]
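The zero-downtime deploy mentioned in the benefits list can reuse the health checks above. A minimal rolling-update sketch; it assumes each replica object also exposes a load_model(path) method, which is not part of the classes above:
import time

def rolling_update(replicas, new_model_path, settle_seconds=5):
    """Update replicas one at a time so the others keep serving traffic."""
    for replica in replicas:
        replica.load_model(new_model_path)  # swap in the new model on this replica only
        time.sleep(settle_seconds)          # give it time to warm up / fill caches
        if not replica.health_check():      # same health_check() the HealthChecker uses
            raise RuntimeError(f"Replica {replica.id} failed after update - aborting rollout")
    print("All replicas updated with zero downtime")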
Pattern 3: Pub-Sub for Async Communication
Use case: Model updates, feature updates, async tasks
                  +---------------+
                  |  Message Bus  |
                  |    (Kafka)    |
                  +-------+-------+
                          |
         +----------------+----------------+
         v                v                v
 +--------------+  +--------------+  +--------------+
 | Subscriber 1 |  | Subscriber 2 |  | Subscriber 3 |
 | Update model |  | Update cache |  | Log metrics  |
 +--------------+  +--------------+  +--------------+
When to use:
- Model deployment: Notify all servers to reload model
- Feature updates: Broadcast new feature values
- Logging: Send metrics/logs asynchronously
- Training triggers: Data arrives → trigger training job
Implementation:
class PubSubSystem:
"""
Publish-Subscribe system for async communication
Publishers send messages, subscribers receive them
"""
def __init__(self):
self.subscribers = {} # topic -> [subscribers]
def subscribe(self, topic, callback):
"""
Subscribe to a topic
Args:
topic: Topic name (e.g., 'model.updated')
callback: Function to call when message received
"""
if topic not in self.subscribers:
self.subscribers[topic] = []
self.subscribers[topic].append(callback)
print(f"β Subscribed to {topic}")
def publish(self, topic, message):
"""
Publish message to topic
All subscribers will receive it asynchronously
"""
if topic not in self.subscribers:
return
for callback in self.subscribers[topic]:
# Call asynchronously (non-blocking)
import threading
thread = threading.Thread(
target=callback,
args=(message,)
)
thread.start()
print(f"π’ Published to {topic}: {message}")
# Example usage
import time
pubsub = PubSubSystem()
# Subscriber 1: Model server that reloads on updates
def reload_model(message):
print(f"π Reloading model: {message['model_version']}")
# Load new model...
pubsub.subscribe('model.updated', reload_model)
# Subscriber 2: Cache that invalidates on updates
def invalidate_cache(message):
print(f"ποΈ Invalidating cache for: {message['model_version']}")
# Clear cache...
pubsub.subscribe('model.updated', invalidate_cache)
# Publisher: Training job publishes when done
def training_complete(model_path, version):
pubsub.publish('model.updated', {
'model_path': model_path,
'model_version': version,
'timestamp': time.time()
})
# Trigger
training_complete('s3://models/v123', 'v123')
# Both subscribers receive message asynchronously!
Handling Failures
Key principle: In distributed systems, failures are normal, not exceptional!
Types of Failures
- Machine failure: Server crashes
- Network partition: Network splits, canβt communicate
- Slow nodes: "Stragglers" delay entire system
- Corrupted data: Silent data corruption
- Cascading failures: One failure triggers others
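Before the specific strategies below, the most basic defense against transient failures (timeouts, dropped packets, brief node restarts) is a bounded retry with exponential backoff and jitter. A minimal sketch; the constants are illustrative:
import random
import time

def call_with_retries(func, *args, max_attempts=4, base_delay=0.5, **kwargs):
    """Retry a flaky call with exponentially growing, jittered delays."""
    for attempt in range(1, max_attempts + 1):
        try:
            return func(*args, **kwargs)
        except Exception:
            if attempt == max_attempts:
                raise                                  # out of attempts: surface the error
            delay = base_delay * (2 ** (attempt - 1))  # 0.5s, 1s, 2s, ...
            time.sleep(delay + random.uniform(0, delay))  # jitter avoids retry storms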
Fault Tolerance Strategies
1. Replication (Multiple Copies)
class ReplicatedStorage:
"""
Store data on multiple nodes
If one fails, others have copy
"""
def __init__(self, nodes, replication_factor=3):
self.nodes = nodes
self.replication_factor = replication_factor
def write(self, key, value):
"""
Write to multiple nodes
Succeeds if majority succeed (quorum)
"""
        # _pick_nodes is assumed to choose replication_factor nodes for this key
        # (e.g., by hashing - see the partitioning strategies below)
        target_nodes = self._pick_nodes(key, self.replication_factor)
# Write to all (parallel)
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(node.write, key, value)
for node in target_nodes
]
# Wait for majority
successes = sum(1 for f in futures if f.result())
# Require majority for success (quorum)
quorum = (self.replication_factor // 2) + 1
if successes >= quorum:
return True
else:
raise Exception(f"Write failed: only {successes}/{self.replication_factor} succeeded")
def read(self, key):
"""
Read from multiple nodes, return most recent
Handles node failures gracefully
"""
target_nodes = self._pick_nodes(key, self.replication_factor)
# Read from all
values = []
for node in target_nodes:
try:
value = node.read(key)
values.append(value)
except Exception:
# Node failed, skip it
continue
if not values:
raise Exception("All replicas failed!")
# Return most recent (highest version)
return max(values, key=lambda v: v['version'])
2. Checkpointing (Save Progress)
class CheckpointedTraining:
"""
Save training progress periodically
If crash, resume from last checkpoint
"""
def __init__(self, model, checkpoint_dir, checkpoint_every=1000):
self.model = model
self.checkpoint_dir = checkpoint_dir
self.checkpoint_every = checkpoint_every
self.global_step = 0
def train(self, data_loader):
"""Train with checkpointing"""
# Try to resume from checkpoint
self.global_step = self._load_checkpoint()
for batch in data_loader:
            # Skip batches already processed before the last crash
            if batch.id < self.global_step:
                continue
# Training step
loss = self.model.train_step(batch)
self.global_step += 1
# Checkpoint periodically
if self.global_step % self.checkpoint_every == 0:
self._save_checkpoint()
print(f"β Checkpoint saved at step {self.global_step}")
def _save_checkpoint(self):
"""Save model + training state"""
        import time
        import torch
checkpoint = {
'model_state': self.model.state_dict(),
'global_step': self.global_step,
'timestamp': time.time()
}
path = f"{self.checkpoint_dir}/ckpt-{self.global_step}.pt"
torch.save(checkpoint, path)
def _load_checkpoint(self):
"""Load latest checkpoint if exists"""
import glob
import torch
checkpoints = glob.glob(f"{self.checkpoint_dir}/ckpt-*.pt")
if not checkpoints:
return 0
# Load latest
latest = max(checkpoints, key=lambda p: int(p.split('-')[1].split('.')[0]))
checkpoint = torch.load(latest)
self.model.load_state_dict(checkpoint['model_state'])
print(f"β Resumed from step {checkpoint['global_step']}")
return checkpoint['global_step']
3. Circuit Breaker (Prevent Cascading Failures)
class CircuitBreaker:
"""
Prevent cascading failures
If service keeps failing, stop calling it (open circuit)
Give it time to recover, then try again
"""
def __init__(self, failure_threshold=5, timeout=60):
self.failure_threshold = failure_threshold
self.timeout = timeout
self.failures = 0
self.state = 'closed' # closed, open, half_open
self.last_failure_time = 0
def call(self, func, *args, **kwargs):
"""
Call function with circuit breaker protection
"""
import time
# Check if circuit is open
if self.state == 'open':
# Check if timeout passed
if time.time() - self.last_failure_time > self.timeout:
self.state = 'half_open'
print("π Circuit half-open, trying again...")
else:
raise Exception("Circuit breaker OPEN - service unavailable")
# Try the call
try:
result = func(*args, **kwargs)
# Success! Reset if we were half-open
if self.state == 'half_open':
self.state = 'closed'
self.failures = 0
print("β Circuit closed - service recovered")
return result
except Exception as e:
# Failure
self.failures += 1
self.last_failure_time = time.time()
# Open circuit if too many failures
if self.failures >= self.failure_threshold:
self.state = 'open'
print(f"β οΈ Circuit breaker OPEN after {self.failures} failures")
raise e
# Example usage
import time
circuit_breaker = CircuitBreaker(failure_threshold=3, timeout=30)
def call_unreliable_service(data):
"""This service sometimes fails"""
import random
if random.random() < 0.5:
raise Exception("Service failed!")
return "Success"
# Try calling with circuit breaker
for i in range(10):
try:
result = circuit_breaker.call(call_unreliable_service, "data")
print(f"Request {i}: {result}")
except Exception as e:
print(f"Request {i}: {e}")
time.sleep(1)
Consistency Models
Strong Consistency
Guarantee: All reads see the most recent write
class StronglyConsistentStore:
"""
Every read returns the latest write
Achieved by: Single master, synchronous replication
"""
def __init__(self):
self.master = {} # Single source of truth
self.replicas = [{}, {}] # Read replicas
self.version = 0
def write(self, key, value):
"""
Write to master, then synchronously replicate
Slow but consistent!
"""
# Update version
self.version += 1
# Write to master
self.master[key] = {'value': value, 'version': self.version}
# Synchronously replicate to all replicas
for replica in self.replicas:
replica[key] = {'value': value, 'version': self.version}
# Only return after all replicas updated
print(f"β Write {key}={value} replicated to all")
def read(self, key):
"""
Read from master (always latest)
"""
return self.master.get(key, {}).get('value')
Pros: Simple to reason about
Cons: Slow (synchronous replication), single point of failure
Eventual Consistency
Guarantee: Reads eventually see the latest write (but not immediately)
class EventuallyConsistentStore:
"""
Reads may see stale data temporarily
Achieved by: Asynchronous replication
"""
def __init__(self):
self.replicas = [{}, {}, {}]
self.version = 0
def write(self, key, value):
"""
Write to one replica, asynchronously propagate
Fast but eventually consistent
"""
self.version += 1
# Write to first replica immediately
self.replicas[0][key] = {'value': value, 'version': self.version}
# Asynchronously replicate to others
import threading
for replica in self.replicas[1:]:
thread = threading.Thread(
target=self._async_replicate,
args=(replica, key, value, self.version)
)
thread.start()
# Return immediately (don't wait for replication)
return "OK"
def _async_replicate(self, replica, key, value, version):
"""Replicate asynchronously"""
import time
time.sleep(0.1) # Simulate network delay
replica[key] = {'value': value, 'version': version}
def read(self, key):
"""
Read from random replica
May return stale data if replication not complete!
"""
import random
replica = random.choice(self.replicas)
return replica.get(key, {}).get('value')
Pros: Fast, highly available
Cons: Can read stale data temporarily
For ML systems:
- Model weights: Eventual consistency OK (small staleness acceptable)
- Feature store: Strong consistency for critical features
- Predictions: No consistency needed (stateless)
Consensus Algorithms
Problem: How do multiple nodes agree on a value when some might fail?
Example: Leader election - which node should be the master?
Understanding the Challenge
Scenario: 3 nodes need to elect a leader
Node A thinks: "I should be leader!"
Node B thinks: "No, I should be leader!"
Node C crashes before voting
Challenge:
- Network delays mean messages arrive out of order
- Nodes might fail mid-process
- Must guarantee exactly ONE leader elected
This is the consensus problem!
Raft Algorithm (Simplified)
Raft is easier to understand than Paxos, achieving the same goal.
Key concepts:
- States: Each node is in one of three states:
- Follower: Accepts commands from leader
- Candidate: Trying to become leader
- Leader: Sends commands to followers
- Terms: Time divided into terms (like presidencies)
- Each term has at most one leader
- Term number increases after each election
- Election process:
class RaftNode:
"""
Simplified Raft consensus node
Real implementation is more complex!
"""
def __init__(self, node_id, peers):
self.node_id = node_id
self.peers = peers
self.state = 'follower'
self.current_term = 0
self.voted_for = None
import random, time
self.election_timeout = random.uniform(150, 300) # ms
self.last_heartbeat = time.time()
def start_election(self):
"""
Become candidate and request votes
Called when election timeout expires without hearing from leader
"""
# Increment term
self.current_term += 1
self.state = 'candidate'
self.voted_for = self.node_id # Vote for self
print(f"Node {self.node_id}: Starting election for term {self.current_term}")
# Request votes from all peers
votes_received = 1 # Self vote
for peer in self.peers:
if self._request_vote(peer):
votes_received += 1
# Check if won election (majority)
majority = (len(self.peers) + 1) // 2 + 1
if votes_received >= majority:
self._become_leader()
else:
# Lost election, revert to follower
self.state = 'follower'
def _request_vote(self, peer):
"""
Request vote from peer
Peer grants vote if:
- Haven't voted in this term yet
- Candidate's log is at least as up-to-date
"""
request = {
'term': self.current_term,
'candidate_id': self.node_id
}
response = peer.handle_vote_request(request)
return response.get('vote_granted', False)
def _become_leader(self):
"""
Become leader for this term
Start sending heartbeats to maintain leadership
"""
self.state = 'leader'
print(f"Node {self.node_id}: Became leader for term {self.current_term}")
        # Send heartbeats to all followers
        # (in a real system this loop runs on a background thread; here it blocks for simplicity)
        self._send_heartbeats()
def _send_heartbeats(self):
"""
Send periodic heartbeats to prevent new elections
Leader must send heartbeats < election_timeout
"""
import time
while self.state == 'leader':
for peer in self.peers:
peer.receive_heartbeat({
'term': self.current_term,
'leader_id': self.node_id
})
time.sleep(0.05) # 50ms heartbeat interval
def receive_heartbeat(self, message):
"""
Receive heartbeat from leader
Reset election timeout
"""
        # Check term: a heartbeat from a current or newer leader resets our election timer
        import time
        if message['term'] >= self.current_term:
            self.current_term = message['term']
            self.state = 'follower'
            self.last_heartbeat = time.time()
# Reset election timeout
return {'success': True}
def handle_vote_request(self, request):
"""
Handle vote request from candidate
Grant vote if haven't voted in this term yet
"""
# Check term
if request['term'] < self.current_term:
return {'vote_granted': False}
        # A newer term resets whom we've voted for
        if request['term'] > self.current_term:
            self.current_term = request['term']
            self.voted_for = None
        # Grant vote if we haven't voted for a different candidate this term
        if self.voted_for is None or self.voted_for == request['candidate_id']:
            self.voted_for = request['candidate_id']
            return {'vote_granted': True}
        return {'vote_granted': False}
Why this works:
- Split votes: If multiple candidates, may get no majority → retry
- Random timeouts: Reduces likelihood of split votes
- Term numbers: Ensures old messages ignored
- Majority requirement: Ensures at most one leader per term
Use in ML systems:
- Distributed training: Elect master node
- Model serving: Elect coordinator for A/B test assignments
- Feature store: Elect primary for writes
Data Partitioning Strategies
Problem: Training data is 10TB. Can't fit on one machine!
Solution: Partition (shard) across multiple machines.
Strategy 1: Range Partitioning
Idea: Split data by key ranges
User IDs: 0 - 1,000,000
Partition 1: Users 0 - 250,000
Partition 2: Users 250,001 - 500,000
Partition 3: Users 500,001 - 750,000
Partition 4: Users 750,001 - 1,000,000
Pros: Simple, range queries efficient
Cons: Hotspots if data skewed
Example:
class RangePartitioner:
"""
Partition data by key ranges
"""
def __init__(self, partitions):
self.partitions = partitions # [(0, 250000, node1), (250001, 500000, node2), ...]
def get_partition(self, key):
"""
Find which partition handles this key
"""
for start, end, node in self.partitions:
if start <= key <= end:
return node
raise ValueError(f"Key {key} not in any partition")
def write(self, key, value):
"""Write to appropriate partition"""
node = self.get_partition(key)
node.write(key, value)
def read(self, key):
"""Read from appropriate partition"""
node = self.get_partition(key)
return node.read(key)
# Usage
partitioner = RangePartitioner([
(0, 250000, node1),
(250001, 500000, node2),
(500001, 750000, node3),
(750001, 1000000, node4)
])
# Write user data (keyed by user_id)
partitioner.write(123456, {'name': 'Alice', ...})
# Read user data
user_data = partitioner.read(123456)
Hotspot problem:
If most users have IDs 0-100,000:
Partition 1: Overloaded!
Partitions 2-4: Idle
Unbalanced load!
Strategy 2: Hash Partitioning
Idea: Hash key, use hash to determine partition
key → hash(key) → partition
Example:
user_id = 123456
hash(123456) = 42
partition = 42 % 4 = 2
→ Send to Partition 2
Pros: Even distribution (no hotspots)
Cons: Range queries impossible
class HashPartitioner:
"""
Partition data by hash of key
"""
def __init__(self, nodes):
self.nodes = nodes
self.num_nodes = len(nodes)
def get_partition(self, key):
"""
Hash key to determine partition
"""
# Hash key
hash_value = hash(key)
# Modulo to get partition index
partition_idx = hash_value % self.num_nodes
return self.nodes[partition_idx]
def write(self, key, value):
node = self.get_partition(key)
node.write(key, value)
def read(self, key):
node = self.get_partition(key)
return node.read(key)
# Usage
partitioner = HashPartitioner([node1, node2, node3, node4])
# Even distribution! (in CPython, hash() of a small int is the int itself)
partitioner.write(1, 'data1')       # 1 % 4 = 1 -> node2
partitioner.write(2, 'data2')       # 2 % 4 = 2 -> node3
partitioner.write(3, 'data3')       # 3 % 4 = 3 -> node4
partitioner.write(123456, 'data')   # 123456 % 4 = 0 -> node1
Problem with adding/removing nodes:
With 4 nodes: hash(key) % 4 = 2 → node2
Add node5 (now 5 nodes): hash(key) % 5 = 4 → node5
All keys need remapping!
Expensive!
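A quick sanity check of how disruptive mod-N resharding is - counting how many of 100,000 integer keys land on a different node when going from 4 to 5 nodes (for small ints, hash(k) == k). Compare this with consistent hashing below, which moves only ~1/N of keys:
moved = sum(1 for key in range(100_000) if key % 4 != key % 5)
print(f"{moved / 100_000:.0%} of keys change nodes")  # roughly 80%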
Strategy 3: Consistent Hashing
Idea: Minimize remapping when adding/removing nodes
How it works:
- Hash both keys and nodes to the same space (e.g., 0-360°)
- Place nodes on circle
- Key goes to next node clockwise
Circle (0-360°):
        0°
        |
   Node B (45°)
        |
   Node C (120°)
        |
   Node D (200°)
        |
   Node A (290°)
        |
      360° (= 0°)
Key x hashes to 100° → Goes to Node C (next clockwise at 120°)
Key y hashes to 250° → Goes to Node A (next clockwise at 290°)
Add Node E at 160°:
- Only keys between 120° and 160° move from C to E
- All other keys unchanged!
import bisect
class ConsistentHashRing:
"""
Consistent hashing for minimal remapping
"""
def __init__(self, nodes, virtual_nodes=150):
self.virtual_nodes = virtual_nodes
self.ring = []
self.node_map = {}
for node in nodes:
self._add_node(node)
def _add_node(self, node):
"""
Add node to ring with multiple virtual nodes
Virtual nodes for better distribution
"""
for i in range(self.virtual_nodes):
# Hash node + replica number
virtual_key = f"{node.id}-{i}"
hash_value = hash(virtual_key) % (2**32)
# Insert into sorted ring
bisect.insort(self.ring, hash_value)
self.node_map[hash_value] = node
def get_node(self, key):
"""
Find node for key
O(log N) lookup using binary search
"""
# Hash key
hash_value = hash(key) % (2**32)
# Find next node clockwise
idx = bisect.bisect_right(self.ring, hash_value)
if idx == len(self.ring):
idx = 0 # Wrap around
ring_position = self.ring[idx]
return self.node_map[ring_position]
def add_node(self, node):
"""
Add new node
Only ~1/N keys need remapping!
"""
self._add_node(node)
print(f"Added {node.id}, only ~{100/len(self.ring)*self.virtual_nodes:.1f}% keys remapped")
def remove_node(self, node):
"""Remove node from ring"""
for i in range(self.virtual_nodes):
virtual_key = f"{node.id}-{i}"
hash_value = hash(virtual_key) % (2**32)
idx = self.ring.index(hash_value)
del self.ring[idx]
del self.node_map[hash_value]
# Usage
ring = ConsistentHashRing([node1, node2, node3, node4])
# Keys distributed evenly
key1_node = ring.get_node('user_123')
key2_node = ring.get_node('user_456')
# Add node - minimal disruption!
ring.add_node(node5)
Use in ML:
- Feature store: Partition features by entity ID
- Training data: Distribute examples across workers
- Model serving: Distribute prediction requests
Real-World Case Study: Netflix Recommendation System
Architecture
             +---------------------------+
             |   Global Load Balancer    |
             |         (GeoDNS)          |
             +-------------+-------------+
                           |
            +--------------+--------------+
            v              v              v
        +--------+     +--------+     +--------+
        |   US   |     |   EU   |     |  APAC  |      <- Regional clusters
        | Region |     | Region |     | Region |
        +---+----+     +---+----+     +---+----+
            |              |              |
            +--------------+--------------+
                           |
                           v
        +----------------------------------------+
        |       Cassandra (User Profiles)        |      <- Distributed database,
        |       Replicated across regions        |         replicated across regions
        +------------------+---------------------+
                           |
                           v
        +----------------------------------------+
        |         Recommendation Service         |      <- 1000s of instances,
        |            (Load balanced)             |         load balanced
        +------------------+---------------------+
                           |
                     +-----+-----+
                     v           v
                 +-------+   +-------+
                 | Cache |   | Model |      <- Redis cache + model replicas
                 | Redis |   | Serve |
                 +-------+   +-------+
Key Distributed Systems Principles Used
- Geographic distribution: Users routed to nearest region (low latency)
- Replication: User data replicated across 3 regions (high availability)
- Caching: Hot recommendations cached (reduce compute)
- Load balancing: Requests distributed across 1000s of servers
- Eventual consistency: Viewing history can be slightly stale
- Partitioning: Users partitioned by user_id (horizontal scaling)
Numbers
- 200M+ users
- 1B+ recommendation requests/day
- 3 regions (US, EU, APAC)
- 1000s of servers per region
- < 100ms p99 latency for recommendations
How they handle failure:
- Region failure: Route traffic to other regions
- Server failure: Load balancer removes from pool
- Cache miss: Fall back to model inference
- Database failure: Serve stale data from replica
Key Takeaways
✅ Horizontal scaling - Add machines, not bigger machines
✅ Replication - Multiple copies for availability
✅ Sharding - Split data for scalability
✅ Load balancing - Distribute requests evenly
✅ Fault tolerance - Design for failure, not perfection
✅ Async communication - Pub-sub for decoupling
✅ Consistency trade-offs - CP vs AP based on use case
Core principles:
- Failures are normal - design for them
- Network is unreliable - use retries, timeouts
- Consistency costs performance - choose wisely
- Monitoring is essential - you can't fix what you can't see
Originally published at: arunbaby.com/ml-system-design/0012-distributed-systems
If you found this helpful, consider sharing it with others who might benefit.