Caching Strategies for ML Systems
Design efficient caching layers for ML systems to reduce latency, save compute costs, and improve user experience at scale.
Introduction
Caching temporarily stores computed results so that future requests can be served without recomputing them. In ML systems, where inference and feature computation are expensive, caching is critical.
Why caching matters:
- Latency reduction: milliseconds instead of seconds for predictions
- Cost savings: avoid repeated, expensive model inference
- Scalability: handle more requests with the same resources
- Availability: serve cached results if the model service is down
Common caching scenarios in ML:
- Model predictions (feature → prediction)
- Feature computations (raw data → engineered features)
- Embeddings (entity → vector representation)
- Model artifacts (model weights, config)
- Training data (preprocessed datasets)
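The first scenario, caching predictions keyed by their input features, is often enough to start with. A minimal sketch using Python's built-in functools.lru_cache (run_model here is a hypothetical stand-in for your actual model call; any hashable feature representation works as the key):
from functools import lru_cache

def run_model(features: tuple) -> float:
    """Hypothetical single-example inference; replace with your model call."""
    # e.g. return model.predict([list(features)])[0]
    return sum(features) * 0.1

@lru_cache(maxsize=10_000)  # built-in LRU cache keyed by the (hashable) feature tuple
def predict_one(features: tuple) -> float:
    return run_model(features)

# Repeated identical requests are served from memory after the first call
predict_one((1, 2, 3))
predict_one((1, 2, 3))              # cache hit
print(predict_one.cache_info())     # CacheInfo(hits=1, misses=1, maxsize=10000, currsize=1)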
Cache Hierarchy
┌────────────────────────────────────────────────────┐
│ Client/Browser │
│ (Local Storage, Cookies) │
└──────────────────────┬─────────────────────────────┘
│
▼
┌────────────────────────────────────────────────────┐
│ CDN Cache │
│ (CloudFlare, Akamai, CloudFront) │
└──────────────────────┬─────────────────────────────┘
│
▼
┌────────────────────────────────────────────────────┐
│ Application Cache │
│ (Redis, Memcached, Local) │
└──────────────────────┬─────────────────────────────┘
│
▼
┌────────────────────────────────────────────────────┐
│ ML Model Service │
│ (TensorFlow Serving, etc.) │
└──────────────────────┬─────────────────────────────┘
│
▼
┌────────────────────────────────────────────────────┐
│ Database │
│ (PostgreSQL, MongoDB, etc.) │
└────────────────────────────────────────────────────┘
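Each layer absorbs traffic before it reaches the more expensive layer below it. Conceptually, a request falls through the hierarchy like this (a minimal sketch; cdn_get, app_cache_get, app_cache_put, and model_predict are placeholders for the real integrations):
def handle_request(key, cdn_get, app_cache_get, app_cache_put, model_predict):
    """Resolve a request by falling through the cache hierarchy."""
    # 1. Edge / CDN cache (cheapest, closest to the user)
    result = cdn_get(key)
    if result is not None:
        return result
    # 2. Application cache (Redis, Memcached, or local memory)
    result = app_cache_get(key)
    if result is not None:
        return result
    # 3. Model service / database (most expensive) -- populate the cache on the way back
    result = model_predict(key)
    app_cache_put(key, result)
    return result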
Cache Eviction Policies
LRU (Least Recently Used)
The most common eviction policy in ML systems: drop the item that has gone unused the longest.
from collections import OrderedDict
class LRUCache:
"""
LRU Cache implementation
Evicts least recently used items when capacity is reached
"""
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.capacity = capacity
def get(self, key):
"""
Get value and mark as recently used
Time: O(1)
"""
if key not in self.cache:
return None
# Move to end (most recent)
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key, value):
"""
Put key-value pair
Time: O(1)
"""
if key in self.cache:
# Update and move to end
self.cache.move_to_end(key)
self.cache[key] = value
# Evict if over capacity
if len(self.cache) > self.capacity:
# Remove first item (least recently used)
self.cache.popitem(last=False)
def stats(self):
"""Get cache statistics"""
return {
'size': len(self.cache),
'capacity': self.capacity,
'utilization': len(self.cache) / self.capacity
}
# Usage
cache = LRUCache(capacity=1000)
# Cache predictions
def get_prediction_cached(features, model):
cache_key = hash(tuple(features))
# Check cache
cached_result = cache.get(cache_key)
if cached_result is not None:
return cached_result
# Compute prediction
prediction = model.predict([features])[0]
# Cache result
cache.put(cache_key, prediction)
return prediction
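One caveat with hash(tuple(features)): Python's built-in hash is only meant for in-process use (string hashing is even randomized per process unless PYTHONHASHSEED is fixed), so it works for this local LRU cache but should not be used for keys shared across processes or stored in Redis. A sketch of a stable, content-based key using hashlib:
import hashlib
import json

def stable_cache_key(model_id: str, features) -> str:
    """Deterministic cache key that is safe to share across processes."""
    # JSON with sorted keys gives a canonical byte representation of the input
    payload = json.dumps({"model": model_id, "features": list(features)}, sort_keys=True)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()
    return f"prediction:{model_id}:{digest}"

print(stable_cache_key("churn_model_v3", [0.2, 1.5, 3.0]))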
LFU (Least Frequently Used)
Works well for skewed access patterns, where a small set of "hot" items receives most of the traffic.
from collections import defaultdict
import heapq
class LFUCache:
"""
LFU Cache - evicts least frequently used items
Better for "hot" items that are accessed repeatedly
"""
def __init__(self, capacity: int):
self.capacity = capacity
self.cache = {} # key -> (value, frequency)
self.freq_map = defaultdict(set) # frequency -> set of keys
self.min_freq = 0
self.access_count = 0
def get(self, key):
"""Get value and increment frequency"""
if key not in self.cache:
return None
value, freq = self.cache[key]
# Update frequency
self.freq_map[freq].remove(key)
if not self.freq_map[freq] and freq == self.min_freq:
self.min_freq += 1
new_freq = freq + 1
self.freq_map[new_freq].add(key)
self.cache[key] = (value, new_freq)
return value
def put(self, key, value):
"""Put key-value pair"""
if self.capacity == 0:
return
if key in self.cache:
# Update existing key
_, freq = self.cache[key]
self.cache[key] = (value, freq)
self.get(key) # Update frequency
return
# Evict if at capacity
if len(self.cache) >= self.capacity:
# Remove item with minimum frequency
evict_key = next(iter(self.freq_map[self.min_freq]))
self.freq_map[self.min_freq].remove(evict_key)
del self.cache[evict_key]
# Add new key
self.cache[key] = (value, 1)
self.freq_map[1].add(key)
self.min_freq = 1
def get_top_k(self, k: int):
"""Get top k most frequently accessed items"""
items = [(freq, key) for key, (val, freq) in self.cache.items()]
return heapq.nlargest(k, items)
# Usage for embeddings (frequently accessed)
embedding_cache = LFUCache(capacity=10000)
def get_embedding_cached(entity_id, embedding_model):
cached_emb = embedding_cache.get(entity_id)
if cached_emb is not None:
return cached_emb
embedding = embedding_model.encode(entity_id)
embedding_cache.put(entity_id, embedding)
return embedding
TTL (Time-To-Live) Cache
Suited to time-sensitive data that becomes stale after a known interval.
import time
class TTLCache:
"""
TTL Cache - items expire after specified time
Perfect for:
- User sessions
- Real-time features (stock prices, weather)
- Model predictions that become stale
"""
def __init__(self, default_ttl_seconds=3600):
self.cache = {} # key -> (value, expiration_time)
self.default_ttl = default_ttl_seconds
def get(self, key):
"""Get value if not expired"""
if key not in self.cache:
return None
value, expiration = self.cache[key]
# Check expiration
if time.time() > expiration:
del self.cache[key]
return None
return value
def put(self, key, value, ttl=None):
"""Put key-value pair with TTL"""
if ttl is None:
ttl = self.default_ttl
expiration = time.time() + ttl
self.cache[key] = (value, expiration)
def cleanup(self):
"""Remove expired entries"""
current_time = time.time()
expired_keys = [
k for k, (v, exp) in self.cache.items()
if current_time > exp
]
for key in expired_keys:
del self.cache[key]
return len(expired_keys)
# Usage for time-sensitive predictions
prediction_cache = TTLCache(default_ttl_seconds=300) # 5 minutes
def predict_stock_price(symbol, model):
"""Predictions expire quickly for real-time data"""
cached = prediction_cache.get(symbol)
if cached is not None:
return cached
prediction = model.predict(symbol)
prediction_cache.put(symbol, prediction, ttl=60) # 1 minute TTL
return prediction
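If you would rather not maintain the TTL bookkeeping yourself, the cachetools library provides the same size-plus-TTL semantics as a decorator. A minimal sketch (fetch_weather is a hypothetical expensive call):
from cachetools import TTLCache, cached

# Bounded cache: at most 1000 entries, each expiring 300 seconds after insertion
weather_cache = TTLCache(maxsize=1000, ttl=300)

@cached(cache=weather_cache)
def get_weather_features(city: str) -> dict:
    # Hypothetical expensive call to an external API or feature pipeline
    return fetch_weather(city)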
Distributed Caching
Redis-Based Cache
import redis
import json
import pickle
import hashlib
class RedisMLCache:
"""
Redis-based cache for ML predictions
Features:
- Distributed across multiple servers
- Persistence
- TTL support
- Pub/sub for cache invalidation
"""
def __init__(self, host='localhost', port=6379, db=0):
self.redis_client = redis.Redis(
host=host,
port=port,
db=db,
decode_responses=False
)
self.hits = 0
self.misses = 0
def _serialize(self, obj):
"""Serialize Python object"""
return pickle.dumps(obj)
def _deserialize(self, data):
"""Deserialize to Python object"""
if data is None:
return None
return pickle.loads(data)
def _make_key(self, prefix, *args):
"""Generate cache key"""
        # Join arguments into a namespaced key (callers hash large inputs first)
key_str = f"{prefix}:{':'.join(map(str, args))}"
return key_str
def get_prediction(self, model_id, features):
"""
Get cached prediction
Args:
model_id: Model identifier
features: Feature vector (hashable)
Returns:
Cached prediction or None
"""
# Create cache key
feature_hash = hashlib.md5(
str(features).encode()
).hexdigest()
key = self._make_key('prediction', model_id, feature_hash)
# Get from Redis
cached = self.redis_client.get(key)
if cached is not None:
self.hits += 1
return self._deserialize(cached)
self.misses += 1
return None
def set_prediction(self, model_id, features, prediction, ttl=3600):
"""Cache prediction with TTL"""
feature_hash = hashlib.md5(
str(features).encode()
).hexdigest()
key = self._make_key('prediction', model_id, feature_hash)
# Serialize and store
value = self._serialize(prediction)
self.redis_client.setex(key, ttl, value)
def get_embedding(self, entity_id):
"""Get cached embedding"""
key = self._make_key('embedding', entity_id)
cached = self.redis_client.get(key)
if cached:
self.hits += 1
# Embeddings stored as JSON arrays
return json.loads(cached)
self.misses += 1
return None
def set_embedding(self, entity_id, embedding, ttl=None):
"""Cache embedding"""
key = self._make_key('embedding', entity_id)
value = json.dumps(embedding.tolist() if hasattr(embedding, 'tolist') else embedding)
if ttl:
self.redis_client.setex(key, ttl, value)
else:
self.redis_client.set(key, value)
def invalidate_model(self, model_id):
"""Invalidate all predictions for a model (SCAN + DEL)"""
pattern = self._make_key('prediction', model_id, '*')
cursor = 0
total_deleted = 0
while True:
cursor, keys = self.redis_client.scan(cursor=cursor, match=pattern, count=1000)
if keys:
total_deleted += self.redis_client.delete(*keys)
if cursor == 0:
break
return total_deleted
def get_stats(self):
"""Get cache statistics"""
total_requests = self.hits + self.misses
hit_rate = self.hits / total_requests if total_requests > 0 else 0
return {
'hits': self.hits,
'misses': self.misses,
'hit_rate': hit_rate,
'total_keys': self.redis_client.dbsize()
}
# Usage
cache = RedisMLCache(host='localhost', port=6379)
def predict_with_cache(features, model, model_id):
"""Predict with Redis caching"""
# Check cache
cached = cache.get_prediction(model_id, features)
if cached is not None:
return cached
# Compute prediction
prediction = model.predict([features])[0]
# Cache result
cache.set_prediction(model_id, features, prediction, ttl=3600)
return prediction
# Check cache performance
stats = cache.get_stats()
print(f"Cache hit rate: {stats['hit_rate']:.2%}")
Multi-Level Cache
class MultiLevelCache:
"""
Multi-level caching with L1 (local) and L2 (Redis)
Pattern:
1. Check L1 (in-memory, fastest)
2. If miss, check L2 (Redis, shared)
3. If miss, compute and populate both levels
"""
def __init__(self, l1_capacity=1000, redis_host='localhost'):
# L1: Local LRU cache
self.l1 = LRUCache(capacity=l1_capacity)
# L2: Redis cache
self.l2 = RedisMLCache(host=redis_host)
self.l1_hits = 0
self.l2_hits = 0
self.misses = 0
def get(self, key):
"""Get value from multi-level cache"""
# Try L1
value = self.l1.get(key)
if value is not None:
self.l1_hits += 1
return value
# Try L2
value = self.l2.redis_client.get(key)
if value is not None:
self.l2_hits += 1
# Populate L1
value = self.l2._deserialize(value)
self.l1.put(key, value)
return value
# Miss
self.misses += 1
return None
def put(self, key, value, ttl=3600):
"""Put value in both cache levels"""
# Store in L1
self.l1.put(key, value)
# Store in L2
self.l2.redis_client.setex(
key,
ttl,
self.l2._serialize(value)
)
def get_stats(self):
"""Get multi-level cache statistics"""
total = self.l1_hits + self.l2_hits + self.misses
return {
'l1_hits': self.l1_hits,
'l2_hits': self.l2_hits,
'misses': self.misses,
'total_requests': total,
'l1_hit_rate': self.l1_hits / total if total > 0 else 0,
'l2_hit_rate': self.l2_hits / total if total > 0 else 0,
'overall_hit_rate': (self.l1_hits + self.l2_hits) / total if total > 0 else 0
}
# Usage
ml_cache = MultiLevelCache(l1_capacity=1000, redis_host='localhost')
def get_user_embedding(user_id, embedding_model):
"""Get user embedding with multi-level caching"""
key = f"user_emb:{user_id}"
# Try cache
embedding = ml_cache.get(key)
if embedding is not None:
return embedding
# Compute
embedding = embedding_model.encode(user_id)
# Cache
ml_cache.put(key, embedding, ttl=3600)
return embedding
Cache Warming Strategies
Proactive Cache Warming
import threading
import time
from queue import Queue
class CacheWarmer:
"""
Proactively warm cache before requests arrive
Strategies:
1. Popular items (based on historical data)
2. Scheduled warmup (daily, hourly)
3. Predictive warmup (ML-based)
"""
def __init__(self, cache, compute_fn):
self.cache = cache
self.compute_fn = compute_fn
self.warmup_queue = Queue()
self.is_running = False
def warm_popular_items(self, items, priority='high'):
"""Warm cache with popular items"""
print(f"Warming {len(items)} popular items...")
for item in items:
key, args = item
# Check if already cached
if self.cache.get(key) is not None:
continue
# Compute and cache
try:
result = self.compute_fn(*args)
self.cache.put(key, result)
except Exception as e:
print(f"Error warming {key}: {e}")
def warm_on_schedule(self, items, interval_seconds=3600):
"""Periodically warm cache"""
def warmup_worker():
while self.is_running:
self.warm_popular_items(items)
time.sleep(interval_seconds)
self.is_running = True
worker = threading.Thread(target=warmup_worker, daemon=True)
worker.start()
def stop(self):
"""Stop scheduled warmup"""
self.is_running = False
# Usage
def compute_recommendation(user_id, model):
"""Expensive recommendation computation"""
return model.recommend(user_id, n=10)
cache = LRUCache(capacity=10000)
warmer = CacheWarmer(cache, compute_recommendation)
# Warm cache with top 1000 users
popular_users = get_top_1000_active_users()
items = [
(f"rec:{user_id}", (user_id, recommendation_model))
for user_id in popular_users
]
warmer.warm_popular_items(items)
# Or schedule periodic warmup
warmer.warm_on_schedule(items, interval_seconds=3600)
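Which items count as "popular" usually comes straight from request logs. A minimal sketch that derives the warm-up list from recent traffic (request_log and recommendation_model are assumed to exist in your system):
from collections import Counter

def top_requested_users(request_log, k=1000):
    """Return the k most frequently requested user IDs from a request log."""
    counts = Counter(request_log)   # user_id -> request count
    return [user_id for user_id, _ in counts.most_common(k)]

popular_users = top_requested_users(request_log, k=1000)
items = [
    (f"rec:{user_id}", (user_id, recommendation_model))
    for user_id in popular_users
]
warmer.warm_popular_items(items)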
Cache Invalidation
Push-Based Invalidation
import redis
import threading
class CacheInvalidator:
"""
Cache invalidation using Redis Pub/Sub
Pattern:
- When model updates, publish invalidation message
- All cache instances subscribe and clear relevant entries
"""
def __init__(self, redis_host='localhost'):
self.redis_pub = redis.Redis(host=redis_host)
self.redis_sub = redis.Redis(host=redis_host)
self.cache = {}
self.invalidation_count = 0
def subscribe_to_invalidations(self, channel='cache:invalidate'):
"""Subscribe to invalidation messages"""
pubsub = self.redis_sub.pubsub()
pubsub.subscribe(channel)
def listen():
for message in pubsub.listen():
if message['type'] == 'message':
self._handle_invalidation(message['data'])
# Start listener thread
listener = threading.Thread(target=listen, daemon=True)
listener.start()
def _handle_invalidation(self, message):
"""Handle invalidation message"""
# Message format: "model_id:v2"
invalidation_key = message.decode('utf-8')
# Remove matching cache entries
keys_to_remove = [
k for k in self.cache.keys()
if k.startswith(invalidation_key)
]
for key in keys_to_remove:
del self.cache[key]
self.invalidation_count += len(keys_to_remove)
print(f"Invalidated {len(keys_to_remove)} cache entries")
def invalidate_model(self, model_id):
"""Publish invalidation message"""
message = f"{model_id}:v"
self.redis_pub.publish('cache:invalidate', message)
# Usage
invalidator = CacheInvalidator()
invalidator.subscribe_to_invalidations()
# When model is updated
def update_model(model_id, new_model):
"""Update model and invalidate cache"""
# Deploy new model
deploy_model(new_model)
# Invalidate all predictions for this model
invalidator.invalidate_model(model_id)
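An alternative that avoids pub/sub entirely is to bake the model version into the cache key: deploying a new version means new keys, and stale entries simply age out via TTL instead of being deleted. A minimal sketch:
def versioned_prediction_key(model_id: str, model_version: str, feature_hash: str) -> str:
    """Cache key that changes automatically when the model version changes."""
    return f"prediction:{model_id}:{model_version}:{feature_hash}"

# After bumping v7 -> v8, reads go to the new keys; old v7 entries are never
# read again and are reclaimed by their TTL, with no invalidation traffic.
old_key = versioned_prediction_key("ranker", "v7", "ab12cd")
new_key = versioned_prediction_key("ranker", "v8", "ab12cd")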
Feature Store Caching
class FeatureStoreCache:
"""
Caching layer for feature store
Features:
- Cache precomputed features
- Batch feature retrieval
- Freshness guarantees
"""
def __init__(self, redis_client, ttl=3600):
self.redis = redis_client
self.ttl = ttl
def get_features(self, entity_ids, feature_names):
"""
Get features for multiple entities (batch)
Args:
entity_ids: List of entity IDs
feature_names: List of feature names
Returns:
Dict of entity_id -> feature_dict
"""
results = {}
cache_misses = []
# Try cache first
for entity_id in entity_ids:
cache_key = f"features:{entity_id}"
cached = self.redis.get(cache_key)
if cached:
# Parse cached features
features = json.loads(cached)
# Filter to requested features
filtered = {
fname: features[fname]
for fname in feature_names
if fname in features
}
if len(filtered) == len(feature_names):
results[entity_id] = filtered
else:
cache_misses.append(entity_id)
else:
cache_misses.append(entity_id)
# Compute missing features
if cache_misses:
computed = self._compute_features(cache_misses, feature_names)
# Cache computed features
for entity_id, features in computed.items():
self._cache_features(entity_id, features)
results[entity_id] = features
return results
def _compute_features(self, entity_ids, feature_names):
"""Compute features from feature store"""
# Call actual feature store
return compute_features_batch(entity_ids, feature_names)
def _cache_features(self, entity_id, features):
"""Cache features for entity"""
cache_key = f"features:{entity_id}"
self.redis.setex(
cache_key,
self.ttl,
json.dumps(features)
)
def invalidate_entity(self, entity_id):
"""Invalidate features for entity"""
cache_key = f"features:{entity_id}"
self.redis.delete(cache_key)
# Usage
feature_cache = FeatureStoreCache(redis_client, ttl=300)
# Get features for batch of users
user_ids = [123, 456, 789]
feature_names = ['age', 'location', 'purchase_count']
features = feature_cache.get_features(user_ids, feature_names)
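The lookup loop above issues one GET per entity; for large batches a redis-py pipeline (or MGET) reduces this to a single round trip. A sketch of the cache-lookup step that could replace the per-entity GETs in get_features:
def _get_cached_batch(self, entity_ids):
    """Fetch cached feature blobs for many entities in one round trip."""
    pipe = self.redis.pipeline()
    for entity_id in entity_ids:
        pipe.get(f"features:{entity_id}")
    raw_values = pipe.execute()   # single network round trip
    cached = {}
    for entity_id, raw in zip(entity_ids, raw_values):
        if raw is not None:
            cached[entity_id] = json.loads(raw)
    return cached                 # entities missing here are cache misses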
Connection to Linked Lists (Day 10 DSA)
Cache implementations heavily use linked list concepts:
class DoublyLinkedNode:
"""Node for doubly-linked list (used in LRU)"""
def __init__(self, key, value):
self.key = key
self.value = value
self.prev = None
self.next = None
class ProductionLRUCache:
"""
Production LRU cache using doubly-linked list
Connection to Day 10 DSA:
- Uses linked list for maintaining order
- Pointer manipulation similar to reversal
- O(1) operations through careful pointer management
"""
def __init__(self, capacity: int):
self.capacity = capacity
self.cache = {}
# Dummy head and tail
self.head = DoublyLinkedNode(0, 0)
self.tail = DoublyLinkedNode(0, 0)
self.head.next = self.tail
self.tail.prev = self.head
def _add_node(self, node):
"""Add node right after head"""
node.prev = self.head
node.next = self.head.next
self.head.next.prev = node
self.head.next = node
def _remove_node(self, node):
"""Remove node from list"""
prev_node = node.prev
next_node = node.next
prev_node.next = next_node
next_node.prev = prev_node
def _move_to_head(self, node):
"""Move node to head (most recently used)"""
self._remove_node(node)
self._add_node(node)
def _pop_tail(self):
"""Remove least recently used (tail.prev)"""
res = self.tail.prev
self._remove_node(res)
return res
def get(self, key):
"""Get value"""
node = self.cache.get(key)
if not node:
return -1
self._move_to_head(node)
return node.value
def put(self, key, value):
"""Put key-value"""
node = self.cache.get(key)
if node:
node.value = value
self._move_to_head(node)
else:
new_node = DoublyLinkedNode(key, value)
self.cache[key] = new_node
self._add_node(new_node)
if len(self.cache) > self.capacity:
tail = self._pop_tail()
del self.cache[tail.key]
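A quick usage check of the pointer-based implementation (note that, following the common convention for this exercise, it returns -1 on a miss rather than None):
lru = ProductionLRUCache(capacity=2)
lru.put("a", 1)
lru.put("b", 2)
lru.get("a")            # "a" becomes most recently used
lru.put("c", 3)         # evicts "b", the least recently used key
print(lru.get("b"))     # -1 (miss)
print(lru.get("a"))     # 1  (hit)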
Understanding Cache Performance
Cache Hit Rate Analysis
class CachePerformanceAnalyzer:
"""
Analyze and optimize cache performance
Key metrics:
- Hit rate: % of requests served from cache
- Miss rate: % of requests requiring computation
- Latency reduction: Time saved by caching
- Memory efficiency: Cache size vs hit rate
"""
def __init__(self):
self.total_requests = 0
self.cache_hits = 0
self.cache_misses = 0
self.hit_latencies = []
self.miss_latencies = []
def record_hit(self, latency_ms):
"""Record cache hit"""
self.cache_hits += 1
self.total_requests += 1
self.hit_latencies.append(latency_ms)
def record_miss(self, latency_ms):
"""Record cache miss"""
self.cache_misses += 1
self.total_requests += 1
self.miss_latencies.append(latency_ms)
def get_metrics(self):
"""Calculate performance metrics"""
if self.total_requests == 0:
return {}
hit_rate = self.cache_hits / self.total_requests
miss_rate = self.cache_misses / self.total_requests
avg_hit_latency = (
sum(self.hit_latencies) / len(self.hit_latencies)
if self.hit_latencies else 0
)
avg_miss_latency = (
sum(self.miss_latencies) / len(self.miss_latencies)
if self.miss_latencies else 0
)
# Calculate latency reduction
avg_latency_with_cache = (
hit_rate * avg_hit_latency + miss_rate * avg_miss_latency
)
latency_reduction = (
(avg_miss_latency - avg_latency_with_cache) / avg_miss_latency
if avg_miss_latency > 0 else 0
)
return {
'total_requests': self.total_requests,
'cache_hits': self.cache_hits,
'cache_misses': self.cache_misses,
'hit_rate': hit_rate,
'miss_rate': miss_rate,
'avg_hit_latency_ms': avg_hit_latency,
'avg_miss_latency_ms': avg_miss_latency,
'avg_overall_latency_ms': avg_latency_with_cache,
'latency_reduction_pct': latency_reduction * 100
}
def print_report(self):
"""Print performance report"""
metrics = self.get_metrics()
print("\n" + "="*60)
print("CACHE PERFORMANCE REPORT")
print("="*60)
print(f"Total Requests: {metrics['total_requests']:,}")
print(f"Cache Hits: {metrics['cache_hits']:,}")
print(f"Cache Misses: {metrics['cache_misses']:,}")
print(f"Hit Rate: {metrics['hit_rate']:.2%}")
print(f"Miss Rate: {metrics['miss_rate']:.2%}")
print(f"\nLatency Analysis:")
print(f" Cache Hit: {metrics['avg_hit_latency_ms']:.2f} ms")
print(f" Cache Miss: {metrics['avg_miss_latency_ms']:.2f} ms")
print(f" Overall Average: {metrics['avg_overall_latency_ms']:.2f} ms")
print(f" Latency Reduction: {metrics['latency_reduction_pct']:.1f}%")
print("="*60)
# Usage example
analyzer = CachePerformanceAnalyzer()
# Simulate requests
import random
import time
cache = LRUCache(capacity=100)
for i in range(1000):
key = f"key_{random.randint(1, 150)}"
# Check cache
start = time.perf_counter()
value = cache.get(key)
if value is not None:
# Cache hit (fast)
latency = (time.perf_counter() - start) * 1000
analyzer.record_hit(latency)
else:
# Cache miss (slow - simulate computation)
time.sleep(0.001) # 1ms computation
latency = (time.perf_counter() - start) * 1000
analyzer.record_miss(latency)
# Store in cache
cache.put(key, f"value_{key}")
analyzer.print_report()
Cache Size Optimization
class CacheSizeOptimizer:
"""
Find optimal cache size for given workload
Trade-off: Larger cache = higher hit rate but more memory
"""
def __init__(self, workload):
"""
Args:
workload: List of access patterns (keys)
"""
self.workload = workload
def find_optimal_size(self, max_size=10000, step=100):
"""
Test different cache sizes
Returns optimal size based on diminishing returns
"""
results = []
print("Testing cache sizes...")
print(f"{'Size':<10} {'Hit Rate':<12} {'Marginal Gain':<15}")
print("-" * 40)
prev_hit_rate = 0
for size in range(step, max_size + 1, step):
hit_rate = self._simulate_cache(size)
marginal_gain = hit_rate - prev_hit_rate
results.append({
'size': size,
'hit_rate': hit_rate,
'marginal_gain': marginal_gain
})
print(f"{size:<10} {hit_rate:<12.2%} {marginal_gain:<15.4%}")
prev_hit_rate = hit_rate
# Stop if marginal gain is too small
if marginal_gain < 0.001: # 0.1% gain
print(f"\nDiminishing returns detected at size {size}")
break
return results
def _simulate_cache(self, size):
"""Simulate cache with given size"""
cache = LRUCache(capacity=size)
hits = 0
for key in self.workload:
if cache.get(key) is not None:
hits += 1
else:
cache.put(key, True)
return hits / len(self.workload)
# Generate workload (Zipf distribution - realistic for many applications)
import numpy as np
def generate_zipf_workload(n_items=1000, n_requests=10000, alpha=1.5):
"""
Generate Zipf-distributed workload
Zipf law: Some items are accessed much more frequently
(80/20 rule, power law distribution)
"""
# Zipf distribution
probabilities = np.array([1.0 / (i ** alpha) for i in range(1, n_items + 1)])
probabilities /= probabilities.sum()
# Generate requests
workload = np.random.choice(
[f"key_{i}" for i in range(n_items)],
size=n_requests,
p=probabilities
)
return workload.tolist()
# Find optimal cache size
workload = generate_zipf_workload(n_items=1000, n_requests=10000)
optimizer = CacheSizeOptimizer(workload)
results = optimizer.find_optimal_size(max_size=500, step=50)
# Plot results
import matplotlib.pyplot as plt
sizes = [r['size'] for r in results]
hit_rates = [r['hit_rate'] for r in results]
plt.figure(figsize=(10, 6))
plt.plot(sizes, hit_rates, marker='o')
plt.xlabel('Cache Size')
plt.ylabel('Hit Rate')
plt.title('Cache Size vs Hit Rate')
plt.grid(True)
plt.savefig('cache_size_optimization.png')
Advanced Caching Patterns
Write-Through vs Write-Back Cache
class WriteThroughCache:
"""
Write-through cache: Write to cache and database simultaneously
Pros:
- Data consistency
- Simple to implement
Cons:
- Slower writes
- Every write hits database
"""
def __init__(self, cache, database):
self.cache = cache
self.database = database
def get(self, key):
"""Read with cache"""
# Try cache first
value = self.cache.get(key)
if value is not None:
return value
# Cache miss: read from database
value = self.database.get(key)
if value is not None:
self.cache.put(key, value)
return value
def put(self, key, value):
"""Write to both cache and database"""
# Write to database first
self.database.put(key, value)
# Then update cache
self.cache.put(key, value)
class WriteBackCache:
"""
Write-back cache: Write to cache only, flush to database later
Pros:
- Fast writes
- Batching possible
Cons:
- Risk of data loss
- More complex
"""
def __init__(self, cache, database, flush_interval=5):
self.cache = cache
self.database = database
self.flush_interval = flush_interval
self.dirty_keys = set()
self.last_flush = time.time()
def get(self, key):
"""Read with cache"""
value = self.cache.get(key)
if value is not None:
return value
value = self.database.get(key)
if value is not None:
self.cache.put(key, value)
return value
def put(self, key, value):
"""Write to cache only"""
self.cache.put(key, value)
self.dirty_keys.add(key)
# Check if we need to flush
if time.time() - self.last_flush > self.flush_interval:
self.flush()
def flush(self):
"""Flush dirty keys to database"""
if not self.dirty_keys:
return
print(f"Flushing {len(self.dirty_keys)} dirty keys...")
for key in self.dirty_keys:
value = self.cache.get(key)
if value is not None:
self.database.put(key, value)
self.dirty_keys.clear()
self.last_flush = time.time()
# Example database simulation
class SimpleDatabase:
def __init__(self):
self.data = {}
self.read_count = 0
self.write_count = 0
def get(self, key):
self.read_count += 1
time.sleep(0.001) # Simulate latency
return self.data.get(key)
def put(self, key, value):
self.write_count += 1
time.sleep(0.001) # Simulate latency
self.data[key] = value
# Compare write-through vs write-back
db1 = SimpleDatabase()
cache1 = LRUCache(capacity=100)
write_through = WriteThroughCache(cache1, db1)
db2 = SimpleDatabase()
cache2 = LRUCache(capacity=100)
write_back = WriteBackCache(cache2, db2)
# Benchmark writes
import time
# Write-through
start = time.time()
for i in range(100):
write_through.put(f"key_{i}", f"value_{i}")
wt_time = time.time() - start
# Write-back
start = time.time()
for i in range(100):
write_back.put(f"key_{i}", f"value_{i}")
write_back.flush() # Final flush
wb_time = time.time() - start
print(f"Write-through: {wt_time:.3f}s, DB writes: {db1.write_count}")
print(f"Write-back: {wb_time:.3f}s, DB writes: {db2.write_count}")
Cache Aside Pattern
class CacheAsidePattern:
"""
Cache-aside (lazy loading): Application manages cache
Most common pattern for ML systems
Flow:
1. Check cache
2. If miss, query database
3. Store in cache
4. Return result
"""
def __init__(self, cache, database):
self.cache = cache
self.database = database
self.stats = {
'reads': 0,
'cache_hits': 0,
'cache_misses': 0,
'writes': 0
}
def get(self, key):
"""
Get with cache-aside pattern
Application is responsible for loading cache
"""
self.stats['reads'] += 1
# Try cache first
value = self.cache.get(key)
if value is not None:
self.stats['cache_hits'] += 1
return value
# Cache miss: load from database
self.stats['cache_misses'] += 1
value = self.database.get(key)
if value is not None:
# Populate cache for next time
self.cache.put(key, value)
return value
def put(self, key, value):
"""
Write to database, invalidate cache
Simple approach: Just write to DB and remove from cache
Next read will repopulate
"""
self.stats['writes'] += 1
# Write to database
self.database.put(key, value)
# Invalidate cache entry
# (Could also update cache here - depends on use case)
if key in self.cache.cache:
del self.cache.cache[key]
def get_stats(self):
"""Get cache statistics"""
hit_rate = (
self.stats['cache_hits'] / self.stats['reads']
if self.stats['reads'] > 0 else 0
)
return {
**self.stats,
'hit_rate': hit_rate
}
# Usage for ML predictions
class MLPredictionService:
"""
ML prediction service with cache-aside pattern
"""
def __init__(self, model, cache_capacity=1000):
self.model = model
self.cache = LRUCache(capacity=cache_capacity)
        # Fake database for persisted predictions; SimpleDatabase (defined above)
        # provides the get/put interface that CacheAsidePattern expects
        # (a plain dict has no .put() method)
        self.prediction_db = SimpleDatabase()
self.pattern = CacheAsidePattern(
self.cache,
self.prediction_db
)
def predict(self, features):
"""
Predict with caching
Args:
features: Feature vector (tuple for hashability)
Returns:
Prediction
"""
# Create cache key from features
cache_key = hash(features)
# Try cache-aside pattern
cached_prediction = self.pattern.get(cache_key)
if cached_prediction is not None:
return cached_prediction
# Compute prediction (expensive)
prediction = self.model.predict([features])[0]
# Store in database and cache
self.pattern.put(cache_key, prediction)
return prediction
def get_cache_stats(self):
"""Get caching statistics"""
return self.pattern.get_stats()
# Example usage
from sklearn.ensemble import RandomForestClassifier
import numpy as np
# Train simple model
X_train = np.random.randn(100, 5)
y_train = (X_train.sum(axis=1) > 0).astype(int)
model = RandomForestClassifier(n_estimators=10)
model.fit(X_train, y_train)
# Create prediction service
service = MLPredictionService(model, cache_capacity=100)
# Make predictions (some repeated)
for _ in range(1000):
# Generate features (with some repetition)
features = tuple(np.random.randint(0, 10, size=5))
prediction = service.predict(features)
print("Cache statistics:")
print(service.get_cache_stats())
Cache Stampede Prevention
Problem: Thundering Herd
import threading

class CacheStampedeProtection:
"""
Prevent cache stampede (thundering herd)
Problem:
- Cache entry expires
- Many requests try to regenerate simultaneously
- Database/model gets overwhelmed
Solution:
- Use locking to ensure only one request regenerates
- Others wait for that request to complete
"""
def __init__(self, cache, compute_fn):
self.cache = cache
self.compute_fn = compute_fn
# Lock for each key
self.locks = {}
self.master_lock = threading.Lock()
def get(self, key):
"""
Get with stampede protection
Uses double-check locking pattern
"""
# First check: Try cache (no lock)
value = self.cache.get(key)
if value is not None:
return value
# Get or create lock for this key
with self.master_lock:
if key not in self.locks:
self.locks[key] = threading.Lock()
key_lock = self.locks[key]
# Acquire key-specific lock
with key_lock:
# Second check: Try cache again (another thread might have filled it)
value = self.cache.get(key)
if value is not None:
return value
# Compute value (only one thread does this)
print(f"Computing value for {key} (thread: {threading.current_thread().name})")
value = self.compute_fn(key)
# Store in cache
self.cache.put(key, value)
return value
# Demo: Simulate stampede
import threading
import time
def expensive_computation(key):
"""Simulate expensive computation"""
time.sleep(0.1) # 100ms
return f"computed_value_for_{key}"
cache = LRUCache(capacity=100)
protector = CacheStampedeProtection(cache, expensive_computation)
# Simulate stampede: 10 threads requesting same key
def make_request(key, results, index):
start = time.time()
result = protector.get(key)
duration = time.time() - start
results[index] = duration
results = [0] * 10
threads = []
# Clear cache to force computation
cache = LRUCache(capacity=100)
protector.cache = cache
print("Simulating cache stampede for key 'popular_item'...")
start_time = time.time()
for i in range(10):
t = threading.Thread(
target=make_request,
args=('popular_item', results, i),
name=f"Thread-{i}"
)
threads.append(t)
t.start()
for t in threads:
t.join()
total_time = time.time() - start_time
print(f"\nTotal time: {total_time:.3f}s")
print(f"Average request time: {sum(results)/len(results):.3f}s")
print(f"Max request time: {max(results):.3f}s")
print(f"Min request time: {min(results):.3f}s")
print("\nWith protection, only one thread computed (others waited)")
Probabilistic Early Expiration
class ProbabilisticCache:
"""
Probabilistic early expiration to prevent stampede
Idea: Refresh cache before expiration with increasing probability
This spreads out refresh operations
"""
def __init__(self, cache, compute_fn, ttl=60, beta=1.0):
"""
Args:
ttl: Time to live in seconds
beta: Controls early expiration probability
"""
self.cache = cache
self.compute_fn = compute_fn
self.ttl = ttl
self.beta = beta
        # Track insertion times and per-key compute cost (used for early expiration)
        self.insertion_times = {}
        self.compute_times = {}
    def get(self, key):
        """
        Get with probabilistic early expiration (XFetch)
        Refresh early if:
            age - compute_time * beta * ln(rand()) >= ttl
        Since ln(rand()) is negative, the subtracted term is positive and
        occasionally large, so refreshes are spread out before expiration.
        """
        import math
        import random
        # Check cache
        value = self.cache.get(key)
        if value is not None and key in self.insertion_times:
            # Calculate age and recall how long this key took to compute
            age = time.time() - self.insertion_times[key]
            compute_time = self.compute_times.get(key, 0.0)
            # XFetch algorithm: refresh early with rising probability near expiry
            if age - compute_time * self.beta * math.log(random.random()) >= self.ttl:
                # Refresh early
                print(f"Early refresh for {key} (age: {age:.1f}s)")
                value = self._refresh(key)
            return value
        # Cache miss or expired
        return self._refresh(key)
    def _refresh(self, key):
        """Refresh cache entry and record how long the computation took"""
        start = time.time()
        value = self.compute_fn(key)
        self.compute_times[key] = time.time() - start
        self.cache.put(key, value)
        self.insertion_times[key] = time.time()
        return value
# Demo
def compute_value(key):
time.sleep(0.01)
return f"value_{key}_{time.time()}"
pcache = ProbabilisticCache(
LRUCache(capacity=100),
compute_value,
ttl=5, # 5 second TTL
beta=1.0
)
# Access same key multiple times
for i in range(20):
value = pcache.get('test_key')
time.sleep(0.3) # 300ms between requests
Distributed Cache Challenges
Cache Consistency
import json
import pickle
import threading

class DistributedCacheCoordinator:
"""
Coordinate cache across multiple instances
Challenges:
1. Keeping caches in sync
2. Handling partial failures
3. Eventual consistency
"""
def __init__(self, redis_client, instance_id):
self.redis = redis_client
self.instance_id = instance_id
# Local L1 cache
self.local_cache = LRUCache(capacity=1000)
# Subscribe to invalidation messages
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('cache:invalidate')
# Start listener thread
self.listener_thread = threading.Thread(
target=self._listen_for_invalidations,
daemon=True
)
self.listener_thread.start()
def get(self, key):
"""
Get from multi-level cache
L1 (local) -> L2 (Redis) -> Compute
"""
# Try local cache
value = self.local_cache.get(key)
if value is not None:
return value
# Try Redis
value = self.redis.get(key)
if value is not None:
value = pickle.loads(value)
# Populate local cache
self.local_cache.put(key, value)
return value
return None
def put(self, key, value, ttl=3600):
"""
Put in both levels and notify others
"""
# Store in local cache
self.local_cache.put(key, value)
# Store in Redis
self.redis.setex(key, ttl, pickle.dumps(value))
# Notify other instances to invalidate their L1
self.redis.publish(
'cache:invalidate',
json.dumps({
'key': key,
'source_instance': self.instance_id
})
)
def _listen_for_invalidations(self):
"""Listen for invalidation messages"""
for message in self.pubsub.listen():
if message['type'] == 'message':
data = json.loads(message['data'])
# Don't invalidate if we sent the message
if data['source_instance'] != self.instance_id:
key = data['key']
# Invalidate local cache
if key in self.local_cache.cache:
del self.local_cache.cache[key]
print(f"Invalidated {key} from local cache")
# Usage across multiple instances
# Instance 1
coordinator1 = DistributedCacheCoordinator(redis_client, instance_id='instance1')
# Instance 2
coordinator2 = DistributedCacheCoordinator(redis_client, instance_id='instance2')
# Instance 1 writes
coordinator1.put('shared_key', 'value_from_instance1')
# Instance 2 reads (will get from Redis)
value = coordinator2.get('shared_key')
Key Takeaways
✅ Multiple eviction policies - LRU, LFU, TTL for different use cases
✅ Distributed caching - Redis for shared cache across services
✅ Multi-level caching - L1 (local) + L2 (distributed) for optimal performance
✅ Cache warming - Proactive population of hot items
✅ Invalidation strategies - Push-based and pull-based
✅ Linked list connection - Understanding pointer manipulation helps with cache implementation
✅ Monitor cache metrics - Hit rate, latency, memory usage
Originally published at: arunbaby.com/ml-system-design/0010-caching-strategies
If you found this helpful, consider sharing it with others who might benefit.