Model Serving Architecture
Design production-grade model serving systems that deliver predictions at scale with low latency and high reliability.
Introduction
Model serving is the process of deploying ML models to production and making predictions available to end users or downstream systems.
Why it’s critical:
- Bridge training and production: Trained models are useless without serving
- Performance matters: Latency directly impacts user experience
- Scale requirements: Handle millions of requests per second
- Reliability: Downtime = lost revenue
Key challenges:
- Low latency (< 100ms for many applications)
- High throughput (handle traffic spikes)
- Model versioning and rollback
- A/B testing and gradual rollouts
- Monitoring and debugging
Model Serving Architecture Overview
┌─────────────────────────────────────────────────────────┐
│                   Client Applications                   │
│             (Web, Mobile, Backend Services)              │
└────────────────────────────┬────────────────────────────┘
                             │ HTTP/gRPC requests
                             ▼
┌─────────────────────────────────────────────────────────┐
│                      Load Balancer                      │
│              (nginx, ALB, GCP Load Balancer)              │
└────────────────────────────┬────────────────────────────┘
                             │
              ┌──────────────┼──────────────┐
              ▼              ▼              ▼
         ┌─────────┐    ┌─────────┐    ┌─────────┐
         │ Serving │    │ Serving │    │ Serving │
         │ Instance│    │ Instance│    │ Instance│
         │    1    │    │    2    │    │    N    │
         └────┬────┘    └────┬────┘    └────┬────┘
              │              │              │
              ▼              ▼              ▼
            ┌────────────────────────────────┐
            │        Model Repository        │
            │   (S3, GCS, Model Registry)    │
            └────────────────────────────────┘
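Each serving instance pulls its model artifact from the model repository at startup. A minimal sketch of that step, assuming boto3 is installed and using hypothetical bucket and key names:

import boto3
import joblib

def fetch_model_from_repository(bucket="ml-models", key="churn/v1/model.pkl",
                                local_path="/tmp/model.pkl"):
    """Download a model artifact from S3 (the 'Model Repository' box above)."""
    s3 = boto3.client("s3")
    s3.download_file(bucket, key, local_path)
    return joblib.load(local_path)

model = fetch_model_from_repository()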
Serving Patterns
Pattern 1: REST API Serving
Best for: Web applications, microservices
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
import joblib
from typing import List
import time
app = FastAPI()
# Load model on startup
model = None
@app.on_event("startup")
async def load_model():
"""Load model when server starts"""
global model
model = joblib.load('model.pkl')
print("Model loaded successfully")
class PredictionRequest(BaseModel):
"""Request schema"""
features: List[float]
class PredictionResponse(BaseModel):
"""Response schema"""
prediction: float
confidence: float
model_version: str
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""
Make prediction
Returns: Prediction with confidence
"""
try:
# Convert to numpy array
features = np.array([request.features])
# Make prediction
prediction = model.predict(features)[0]
# Get confidence (if available)
if hasattr(model, 'predict_proba'):
proba = model.predict_proba(features)[0]
confidence = float(np.max(proba))
else:
confidence = 1.0
return PredictionResponse(
prediction=float(prediction),
confidence=confidence,
model_version="v1.0"
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
"""Health check endpoint"""
if model is None:
raise HTTPException(status_code=503, detail="Model not loaded")
return {"status": "healthy", "model_loaded": True}
@app.get("/ready")
async def readiness_check():
"""Readiness probe endpoint"""
# Optionally include lightweight self-test
return {"ready": model is not None}
# Run with: uvicorn app:app --host 0.0.0.0 --port 8000
Usage:
curl -X POST "http://localhost:8000/predict" \
-H "Content-Type: application/json" \
-d '{"features": [1.0, 2.0, 3.0, 4.0]}'
Pattern 2: gRPC Serving
Best for: High-performance, low-latency applications
# prediction.proto
"""
syntax = "proto3";
service PredictionService {
rpc Predict (PredictRequest) returns (PredictResponse);
}
message PredictRequest {
repeated float features = 1;
}
message PredictResponse {
float prediction = 1;
float confidence = 2;
}
"""
# server.py
import grpc
from concurrent import futures
import prediction_pb2
import prediction_pb2_grpc
import numpy as np
import joblib
class PredictionServicer(prediction_pb2_grpc.PredictionServiceServicer):
"""gRPC Prediction Service"""
def __init__(self):
self.model = joblib.load('model.pkl')
def Predict(self, request, context):
"""Handle prediction request"""
try:
# Convert features
features = np.array([list(request.features)])
# Predict
prediction = self.model.predict(features)[0]
# Get confidence
if hasattr(self.model, 'predict_proba'):
proba = self.model.predict_proba(features)[0]
confidence = float(np.max(proba))
else:
confidence = 1.0
return prediction_pb2.PredictResponse(
prediction=float(prediction),
confidence=confidence
)
except Exception as e:
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(e))
return prediction_pb2.PredictResponse()
def serve():
"""Start gRPC server"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
prediction_pb2_grpc.add_PredictionServiceServicer_to_server(
PredictionServicer(), server
)
server.add_insecure_port('[::]:50051')
server.start()
print("gRPC server started on port 50051")
server.wait_for_termination()
if __name__ == '__main__':
serve()
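A minimal client sketch, assuming the prediction_pb2/prediction_pb2_grpc stubs were generated from the proto above (e.g. with python -m grpc_tools.protoc):

# client.py
import grpc
import prediction_pb2
import prediction_pb2_grpc

def predict(features):
    """Call the gRPC prediction service on localhost:50051."""
    with grpc.insecure_channel('localhost:50051') as channel:
        stub = prediction_pb2_grpc.PredictionServiceStub(channel)
        response = stub.Predict(prediction_pb2.PredictRequest(features=features))
        return response.prediction, response.confidence

print(predict([1.0, 2.0, 3.0, 4.0]))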
Performance comparison (illustrative numbers; actual figures depend on model size, payload, and hardware):

Metric            REST API       gRPC
──────────────────────────────────────────────────────────
Latency (p50)     15 ms          5 ms
Latency (p99)     50 ms          20 ms
Throughput        5K rps         15K rps
Payload format    JSON           Protocol Buffers (smaller)
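A quick way to see the payload-size difference for the request above, assuming the generated prediction_pb2 module from Pattern 2:

import json
import prediction_pb2

features = [1.0, 2.0, 3.0, 4.0]
json_bytes = json.dumps({"features": features}).encode("utf-8")
proto_bytes = prediction_pb2.PredictRequest(features=features).SerializeToString()
print(f"JSON payload: {len(json_bytes)} bytes, protobuf payload: {len(proto_bytes)} bytes")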
Pattern 3: Batch Serving
Best for: Offline predictions, large-scale inference
import pandas as pd
import numpy as np
from multiprocessing import Pool
import joblib
class BatchPredictor:
"""
Batch prediction system
Efficient for processing large datasets
"""
def __init__(self, model_path, batch_size=1000, n_workers=4):
self.model = joblib.load(model_path)
self.batch_size = batch_size
self.n_workers = n_workers
def predict_batch(self, features_df: pd.DataFrame) -> np.ndarray:
"""
Predict on large dataset
Args:
features_df: DataFrame with features
Returns:
Array of predictions
"""
n_samples = len(features_df)
n_batches = (n_samples + self.batch_size - 1) // self.batch_size
predictions = []
for i in range(n_batches):
start_idx = i * self.batch_size
end_idx = min((i + 1) * self.batch_size, n_samples)
batch = features_df.iloc[start_idx:end_idx].values
batch_pred = self.model.predict(batch)
predictions.extend(batch_pred)
if (i + 1) % 10 == 0:
print(f"Processed {end_idx}/{n_samples} samples")
return np.array(predictions)
def predict_parallel(self, features_df: pd.DataFrame) -> np.ndarray:
"""
Parallel batch prediction
Splits data across multiple processes
"""
        # Split data into chunks (guard against chunk_size == 0 for small inputs)
        chunk_size = max(1, len(features_df) // self.n_workers)
        chunks = [
            features_df.iloc[i:i + chunk_size]
            for i in range(0, len(features_df), chunk_size)
        ]
# Process in parallel
with Pool(self.n_workers) as pool:
results = pool.map(self._predict_chunk, chunks)
# Combine results
return np.concatenate(results)
def _predict_chunk(self, chunk_df):
"""Predict on single chunk"""
return self.model.predict(chunk_df.values)
# Usage
predictor = BatchPredictor('model.pkl', batch_size=10000, n_workers=8)
# Load large dataset
data = pd.read_parquet('features.parquet')
# Predict
predictions = predictor.predict_parallel(data)
# Save results
results_df = data.copy()
results_df['prediction'] = predictions
results_df.to_parquet('predictions.parquet')
Model Loading Strategies
Strategy 1: Eager Loading
class EagerModelServer:
"""
Load model on server startup
Pros: Fast predictions, simple
Cons: High startup time, high memory
"""
def __init__(self, model_path):
print("Loading model...")
self.model = joblib.load(model_path)
print("Model loaded!")
def predict(self, features):
"""Make prediction (fast)"""
return self.model.predict(features)
Strategy 2: Lazy Loading
class LazyModelServer:
"""
Load model on first request
Pros: Fast startup
Cons: First request is slow
"""
def __init__(self, model_path):
self.model_path = model_path
self.model = None
def predict(self, features):
"""Load model if needed, then predict"""
if self.model is None:
print("Loading model on first request...")
self.model = joblib.load(self.model_path)
return self.model.predict(features)
Strategy 3: Model Caching with Expiration
from datetime import datetime, timedelta
import threading
class CachedModelServer:
"""
Load model with cache expiration
Automatically reloads model periodically
"""
def __init__(self, model_path, cache_ttl_minutes=60):
self.model_path = model_path
self.cache_ttl = timedelta(minutes=cache_ttl_minutes)
self.model = None
self.last_loaded = None
self.lock = threading.Lock()
def _load_model(self):
"""Load model with lock"""
with self.lock:
print(f"Loading model from {self.model_path}")
self.model = joblib.load(self.model_path)
self.last_loaded = datetime.now()
def predict(self, features):
"""Predict with cache check"""
# Check if model needs refresh
if (self.model is None or
datetime.now() - self.last_loaded > self.cache_ttl):
self._load_model()
return self.model.predict(features)
Model Versioning & A/B Testing
Multi-Model Serving
from enum import Enum
from typing import Any, Dict
import random
class ModelVersion(Enum):
V1 = "v1"
V2 = "v2"
V3 = "v3"
class MultiModelServer:
"""
Serve multiple model versions
Supports A/B testing and gradual rollouts
"""
def __init__(self):
        self.models: Dict[str, Any] = {}  # version -> loaded model
self.traffic_split = {} # version → weight
def load_model(self, version: ModelVersion, model_path: str):
"""Load a specific model version"""
print(f"Loading {version.value} from {model_path}")
self.models[version.value] = joblib.load(model_path)
def set_traffic_split(self, split: Dict[str, float]):
"""
Set traffic distribution
Args:
split: Dict mapping version to weight
e.g., {"v1": 0.9, "v2": 0.1}
"""
# Validate weights sum to 1
total = sum(split.values())
assert abs(total - 1.0) < 1e-6, f"Weights must sum to 1, got {total}"
self.traffic_split = split
def select_model(self, user_id: str = None) -> str:
"""
Select model version based on traffic split
Args:
user_id: Optional user ID for deterministic routing
Returns:
Selected model version
"""
if user_id:
# Deterministic selection (consistent for same user)
import hashlib
hash_val = int(hashlib.md5(user_id.encode()).hexdigest(), 16)
rand_val = (hash_val % 10000) / 10000.0
else:
# Random selection
rand_val = random.random()
# Select based on cumulative weights
cumulative = 0
for version, weight in self.traffic_split.items():
cumulative += weight
if rand_val < cumulative:
return version
# Fallback to first version
return list(self.traffic_split.keys())[0]
def predict(self, features, user_id: str = None):
"""
Make prediction with version selection
Returns: (prediction, version_used)
"""
version = self.select_model(user_id)
model = self.models[version]
prediction = model.predict(features)
return prediction, version
# Usage
server = MultiModelServer()
# Load models
server.load_model(ModelVersion.V1, 'model_v1.pkl')
server.load_model(ModelVersion.V2, 'model_v2.pkl')
# Start with 90% v1, 10% v2
server.set_traffic_split({"v1": 0.9, "v2": 0.1})
# Make predictions
features = [[1, 2, 3, 4]]
prediction, version = server.predict(features, user_id="user_123")
print(f"Prediction: {prediction}, Version: {version}")
# Gradually increase v2 traffic
server.set_traffic_split({"v1": 0.5, "v2": 0.5})
Optimization Techniques
1. Model Quantization
import torch
import torch.quantization
def quantize_model(model, example_input):
"""
Quantize PyTorch model to INT8
Reduces model size by ~4x, speeds up inference
Args:
model: PyTorch model
example_input: Sample input for calibration
Returns:
Quantized model
"""
    # Set model to eval mode
    model.eval()

    # Eager-mode static quantization needs explicit quant/dequant boundaries,
    # so wrap the model (QuantWrapper adds QuantStub/DeQuantStub around it)
    wrapped = torch.quantization.QuantWrapper(model)

    # Specify quantization configuration
    wrapped.qconfig = torch.quantization.get_default_qconfig('fbgemm')

    # Prepare for quantization (inserts calibration observers)
    model_prepared = torch.quantization.prepare(wrapped)

    # Calibrate with example data
    with torch.no_grad():
        model_prepared(example_input)

    # Convert to quantized model
    model_quantized = torch.quantization.convert(model_prepared)
    return model_quantized
# Example
model = torch.nn.Sequential(
torch.nn.Linear(10, 50),
torch.nn.ReLU(),
torch.nn.Linear(50, 2)
)
example_input = torch.randn(1, 10)
quantized_model = quantize_model(model, example_input)
# Quantized model is ~4x smaller and often faster on CPU

def get_model_size(model):
    """Rough model size in MB: serialize the state dict and measure the bytes."""
    import io
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

print(f"Original size: {get_model_size(model):.2f} MB")
print(f"Quantized size: {get_model_size(quantized_model):.2f} MB")
2. Batch Inference
import asyncio
from collections import deque

import numpy as np
class BatchingPredictor:
"""
Batch multiple requests for efficient inference
Collects requests and processes them in batches
"""
def __init__(self, model, max_batch_size=32, max_wait_ms=10):
self.model = model
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.queue = deque()
self.processing = False
async def predict(self, features):
"""
Add request to batch queue
Returns: Future that resolves with prediction
"""
future = asyncio.Future()
self.queue.append((features, future))
# Start batch processing if not already running
if not self.processing:
asyncio.create_task(self._process_batch())
return await future
async def _process_batch(self):
"""Process accumulated requests as batch"""
self.processing = True
# Wait for batch to fill or timeout
await asyncio.sleep(self.max_wait_ms / 1000.0)
if not self.queue:
self.processing = False
return
# Collect batch
batch = []
futures = []
while self.queue and len(batch) < self.max_batch_size:
features, future = self.queue.popleft()
batch.append(features)
futures.append(future)
# Run batch inference
batch_array = np.array(batch)
predictions = self.model.predict(batch_array)
# Resolve futures
for future, pred in zip(futures, predictions):
future.set_result(pred)
self.processing = False
# Process remaining queue
if self.queue:
asyncio.create_task(self._process_batch())
# Usage
predictor = BatchingPredictor(model, max_batch_size=32, max_wait_ms=10)
async def handle_request(features):
prediction = await predictor.predict(features)
return prediction
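To see the batching in action, fire several concurrent requests inside one event loop (model here is any object with a predict method, e.g. the scikit-learn model loaded earlier):

async def main():
    # 100 concurrent requests; the predictor groups them into batches of <= 32
    tasks = [handle_request([1.0, 2.0, 3.0, 4.0]) for _ in range(100)]
    results = await asyncio.gather(*tasks)
    print(f"Got {len(results)} predictions")

asyncio.run(main())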
Monitoring & Observability
Prediction Logging
import json
import logging
import time
from dataclasses import dataclass, asdict
from datetime import datetime

import numpy as np
@dataclass
class PredictionLog:
"""Log entry for each prediction"""
timestamp: str
model_version: str
features: list
prediction: float
confidence: float
latency_ms: float
user_id: str = None
class MonitoredModelServer:
"""
Model server with comprehensive monitoring
"""
def __init__(self, model, model_version):
self.model = model
self.model_version = model_version
# Setup logging
self.logger = logging.getLogger('model_server')
self.logger.setLevel(logging.INFO)
# Metrics
self.prediction_count = 0
self.latencies = []
self.error_count = 0
def predict(self, features, user_id=None):
"""
Make prediction with logging
Returns: (prediction, confidence, metadata)
"""
start_time = time.time()
try:
# Make prediction
prediction = self.model.predict([features])[0]
# Get confidence
if hasattr(self.model, 'predict_proba'):
proba = self.model.predict_proba([features])[0]
confidence = float(np.max(proba))
else:
confidence = 1.0
# Calculate latency
latency_ms = (time.time() - start_time) * 1000
# Log prediction
log_entry = PredictionLog(
timestamp=datetime.now().isoformat(),
model_version=self.model_version,
features=features,
prediction=float(prediction),
confidence=confidence,
latency_ms=latency_ms,
user_id=user_id
)
self.logger.info(json.dumps(asdict(log_entry)))
# Update metrics
self.prediction_count += 1
self.latencies.append(latency_ms)
return prediction, confidence, {'latency_ms': latency_ms}
except Exception as e:
self.error_count += 1
self.logger.error(f"Prediction failed: {str(e)}")
raise
def get_metrics(self):
"""Get serving metrics"""
if not self.latencies:
return {}
return {
'prediction_count': self.prediction_count,
'error_count': self.error_count,
'error_rate': self.error_count / max(self.prediction_count, 1),
'latency_p50': np.percentile(self.latencies, 50),
'latency_p95': np.percentile(self.latencies, 95),
'latency_p99': np.percentile(self.latencies, 99),
}
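If you export metrics to a monitoring system instead of keeping them in memory, a sketch using prometheus_client (assuming it is installed) could look like this:

from prometheus_client import Counter, Histogram, start_http_server

PREDICTIONS = Counter('predictions_total', 'Total predictions served', ['model_version'])
ERRORS = Counter('prediction_errors_total', 'Total failed predictions', ['model_version'])
LATENCY = Histogram('prediction_latency_seconds', 'Prediction latency', ['model_version'])

def monitored_predict(server: MonitoredModelServer, features):
    """Wrap MonitoredModelServer.predict with Prometheus metrics."""
    with LATENCY.labels(server.model_version).time():
        try:
            result = server.predict(features)
            PREDICTIONS.labels(server.model_version).inc()
            return result
        except Exception:
            ERRORS.labels(server.model_version).inc()
            raise

start_http_server(9090)  # expose metrics at http://localhost:9090/metrics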
Connection to BST Validation (Day 8 DSA)
Model serving systems validate inputs and outputs much like BST range checking: every value must fall within its allowed [min, max] bounds.
class PredictionBoundsValidator:
"""
Validate predictions fall within expected ranges
Similar to BST validation with min/max bounds
"""
def __init__(self):
self.bounds = {} # feature → (min, max)
def set_bounds(self, feature_name, min_val, max_val):
"""Set validation bounds"""
self.bounds[feature_name] = (min_val, max_val)
def validate_input(self, features):
"""
Validate input features
Like BST range checking: each value must be in [min, max]
"""
violations = []
for feature_name, value in features.items():
if feature_name in self.bounds:
min_val, max_val = self.bounds[feature_name]
# Range check (like BST validation)
if value < min_val or value > max_val:
violations.append({
'feature': feature_name,
'value': value,
'bounds': (min_val, max_val)
})
return len(violations) == 0, violations
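Usage sketch with hypothetical bounds:

validator = PredictionBoundsValidator()
validator.set_bounds('user_age', 0, 120)
validator.set_bounds('user_income', 0, 1_000_000)

is_valid, violations = validator.validate_input({'user_age': 150, 'user_income': 50_000})
if not is_valid:
    print(f"Rejecting request: {violations}")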
Advanced Serving Patterns
1. Shadow Mode Deployment
import asyncio

import numpy as np

class ShadowModeServer:
"""
Run new model in shadow mode
New model receives traffic but doesn't affect users
Predictions are logged for comparison
"""
def __init__(self, production_model, shadow_model):
self.production_model = production_model
self.shadow_model = shadow_model
self.comparison_logs = []
def predict(self, features):
"""
Make predictions with both models
Returns: Production prediction (shadow runs async)
"""
        # Production prediction (synchronous)
        prod_prediction = self.production_model.predict(features)

        # Shadow prediction runs in the background so it never blocks the
        # caller. asyncio.create_task requires a running event loop, so call
        # predict() from an async context (e.g. an async web handler).
        asyncio.create_task(self._shadow_predict(features, prod_prediction))

        return prod_prediction
async def _shadow_predict(self, features, prod_prediction):
"""Run shadow model and log comparison"""
try:
shadow_prediction = self.shadow_model.predict(features)
# Log comparison
self.comparison_logs.append({
'features': features,
'production': prod_prediction,
'shadow': shadow_prediction,
'difference': abs(prod_prediction - shadow_prediction)
})
except Exception as e:
print(f"Shadow prediction failed: {e}")
def get_shadow_metrics(self):
"""Analyze shadow model performance"""
if not self.comparison_logs:
return {}
differences = [log['difference'] for log in self.comparison_logs]
return {
'num_predictions': len(self.comparison_logs),
'mean_difference': np.mean(differences),
'max_difference': np.max(differences),
'agreement_rate': sum(1 for d in differences if d < 0.01) / len(differences)
}
# Usage
shadow_server = ShadowModeServer(
production_model=model_v1,
shadow_model=model_v2
)
# Normal serving
prediction = shadow_server.predict(features)
# Analyze shadow performance
metrics = shadow_server.get_shadow_metrics()
print(f"Shadow agreement rate: {metrics['agreement_rate']:.2%}")
2. Canary Deployment
import random
import time

import numpy as np

class CanaryDeployment:
"""
Gradual rollout with automated rollback
Monitors metrics and automatically rolls back if issues detected
"""
def __init__(self, stable_model, canary_model):
self.stable_model = stable_model
self.canary_model = canary_model
self.canary_percentage = 0.0
self.metrics = {
'stable': {'errors': 0, 'predictions': 0, 'latencies': []},
'canary': {'errors': 0, 'predictions': 0, 'latencies': []}
}
def set_canary_percentage(self, percentage):
"""Set canary traffic percentage"""
assert 0 <= percentage <= 100
self.canary_percentage = percentage
print(f"Canary traffic: {percentage}%")
def predict(self, features, user_id=None):
"""
Predict with canary logic
Routes percentage of traffic to canary
"""
# Determine which model to use
use_canary = random.random() < (self.canary_percentage / 100)
model_name = 'canary' if use_canary else 'stable'
model = self.canary_model if use_canary else self.stable_model
# Make prediction with metrics
start_time = time.time()
try:
prediction = model.predict(features)
latency = time.time() - start_time
# Record metrics
self.metrics[model_name]['predictions'] += 1
self.metrics[model_name]['latencies'].append(latency)
return prediction, model_name
except Exception as e:
# Record error
self.metrics[model_name]['errors'] += 1
raise
def check_health(self):
"""
Check canary health
Returns: (is_healthy, should_rollback, reason)
"""
canary_metrics = self.metrics['canary']
stable_metrics = self.metrics['stable']
if canary_metrics['predictions'] < 100:
# Not enough data yet
return True, False, "Insufficient data"
# Calculate error rates
canary_error_rate = canary_metrics['errors'] / canary_metrics['predictions']
stable_error_rate = stable_metrics['errors'] / max(stable_metrics['predictions'], 1)
# Check if error rate is significantly higher
if canary_error_rate > stable_error_rate * 2:
return False, True, f"Error rate too high: {canary_error_rate:.2%}"
# Check latency
canary_p95 = np.percentile(canary_metrics['latencies'], 95)
stable_p95 = np.percentile(stable_metrics['latencies'], 95)
if canary_p95 > stable_p95 * 1.5:
return False, True, f"Latency too high: {canary_p95:.1f}ms"
return True, False, "Healthy"
def auto_rollout(self, target_percentage=100, step=10, check_interval=60):
"""
Automatically increase canary traffic
Rolls back if health checks fail
"""
current = 0
while current < target_percentage:
# Increase canary traffic
current = min(current + step, target_percentage)
self.set_canary_percentage(current)
# Wait and check health
time.sleep(check_interval)
is_healthy, should_rollback, reason = self.check_health()
if should_rollback:
print(f"❌ Rollback triggered: {reason}")
self.set_canary_percentage(0) # Rollback to stable
return False
print(f"✓ Health check passed at {current}%")
print(f"🎉 Canary rollout complete!")
return True
# Usage
canary = CanaryDeployment(stable_model=model_v1, canary_model=model_v2)
# Start with 5% traffic
canary.set_canary_percentage(5)
# Automatic gradual rollout
success = canary.auto_rollout(target_percentage=100, step=10, check_interval=300)
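A self-contained smoke test, assuming scikit-learn is installed and using DummyClassifier instances as stand-ins for the stable and canary models:

import numpy as np
from sklearn.dummy import DummyClassifier

X, y = np.random.rand(200, 4), np.random.randint(0, 2, 200)
model_v1 = DummyClassifier(strategy="most_frequent").fit(X, y)
model_v2 = DummyClassifier(strategy="uniform").fit(X, y)

canary = CanaryDeployment(stable_model=model_v1, canary_model=model_v2)
canary.set_canary_percentage(20)

for features in np.random.rand(1000, 1, 4):  # 1000 single-row requests
    canary.predict(features)

print(canary.check_health())  # (is_healthy, should_rollback, reason)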
3. Multi-Armed Bandit Serving
import random

import numpy as np

class BanditModelServer:
"""
Multi-armed bandit for model selection
Dynamically allocates traffic based on performance
"""
def __init__(self, models: dict):
"""
Args:
models: Dict of {model_name: model}
"""
self.models = models
self.rewards = {name: [] for name in models.keys()}
self.counts = {name: 0 for name in models.keys()}
self.epsilon = 0.1 # Exploration rate
def select_model(self):
"""
Select model using epsilon-greedy strategy
Returns: model_name
"""
# Explore: random selection
if random.random() < self.epsilon:
return random.choice(list(self.models.keys()))
# Exploit: select best performing model
avg_rewards = {
name: np.mean(rewards) if rewards else 0
for name, rewards in self.rewards.items()
}
return max(avg_rewards, key=avg_rewards.get)
def predict(self, features, true_label=None):
"""
Make prediction and optionally update rewards
Args:
features: Input features
true_label: Optional ground truth for reward
Returns: (prediction, model_used)
"""
# Select model
model_name = self.select_model()
model = self.models[model_name]
# Make prediction
prediction = model.predict(features)
self.counts[model_name] += 1
# Update reward if ground truth available
if true_label is not None:
reward = 1.0 if prediction == true_label else 0.0
self.rewards[model_name].append(reward)
return prediction, model_name
def get_model_stats(self):
"""Get statistics for each model"""
stats = {}
for name in self.models.keys():
if self.rewards[name]:
stats[name] = {
'count': self.counts[name],
'avg_reward': np.mean(self.rewards[name]),
'selection_rate': self.counts[name] / sum(self.counts.values())
}
else:
stats[name] = {
'count': self.counts[name],
'avg_reward': 0,
'selection_rate': 0
}
return stats
# Usage
bandit = BanditModelServer({
'model_a': model_a,
'model_b': model_b,
'model_c': model_c
})
# Serve with automatic optimization
for features, label in data_stream:
prediction, model_used = bandit.predict(features, true_label=label)
# Check which model performs best
stats = bandit.get_model_stats()
for name, stat in stats.items():
print(f"{name}: {stat['avg_reward']:.2%} accuracy, {stat['selection_rate']:.1%} traffic")
Infrastructure & Deployment
Containerized Serving with Docker
# Dockerfile for model serving
FROM python:3.9-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy model and code
COPY model.pkl .
COPY serve.py .
# Expose port
EXPOSE 8000
# Health check (python:3.9-slim ships without curl, so probe with Python)
HEALTHCHECK --interval=30s --timeout=3s \
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
# Run server
CMD ["uvicorn", "serve:app", "--host", "0.0.0.0", "--port", "8000"]
# docker-compose.yml
version: '3.8'
services:
model-server:
build: .
ports:
- "8000:8000"
environment:
- MODEL_PATH=/app/model.pkl
- LOG_LEVEL=INFO
volumes:
- ./models:/app/models
deploy:
replicas: 3
resources:
limits:
cpus: '2'
memory: 4G
    healthcheck:
      # The slim Python image has no curl, so use a Python-based probe
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3
load-balancer:
image: nginx:alpine
ports:
- "80:80"
volumes:
- ./nginx.conf:/etc/nginx/nginx.conf
depends_on:
- model-server
Kubernetes Deployment
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
name: model-serving
spec:
replicas: 5
selector:
matchLabels:
app: model-serving
template:
metadata:
labels:
app: model-serving
version: v1
spec:
containers:
- name: model-server
image: your-registry/model-serving:v1
ports:
- containerPort: 8000
env:
- name: MODEL_VERSION
value: "v1.0"
resources:
requests:
memory: "2Gi"
cpu: "1000m"
limits:
memory: "4Gi"
cpu: "2000m"
livenessProbe:
httpGet:
path: /health
port: 8000
initialDelaySeconds: 30
periodSeconds: 10
readinessProbe:
httpGet:
path: /ready
port: 8000
initialDelaySeconds: 5
periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
name: model-serving-service
spec:
selector:
app: model-serving
ports:
- protocol: TCP
port: 80
targetPort: 8000
type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
name: model-serving-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: model-serving
minReplicas: 3
maxReplicas: 20
metrics:
- type: Resource
resource:
name: cpu
target:
type: Utilization
averageUtilization: 70
- type: Resource
resource:
name: memory
target:
type: Utilization
averageUtilization: 80
Feature Store Integration
class ModelServerWithFeatureStore:
"""
Model server integrated with feature store
Fetches features on-demand for prediction
"""
def __init__(self, model, feature_store):
self.model = model
self.feature_store = feature_store
def predict_from_entity_id(self, entity_id: str):
"""
Make prediction given entity ID
Fetches features from feature store
Args:
entity_id: ID to fetch features for
Returns: Prediction
"""
# Fetch features from feature store
features = self.feature_store.get_online_features(
entity_id=entity_id,
feature_names=[
'user_age',
'user_income',
'user_num_purchases_30d',
'user_avg_purchase_amount'
]
)
# Convert to array
feature_array = [
features['user_age'],
features['user_income'],
features['user_num_purchases_30d'],
features['user_avg_purchase_amount']
]
# Make prediction
prediction = self.model.predict([feature_array])[0]
return {
'entity_id': entity_id,
'prediction': float(prediction),
'features_used': features
}
# Usage with caching
from functools import lru_cache
class CachedFeatureStore:
"""Feature store with caching"""
def __init__(self, backend):
self.backend = backend
    @lru_cache(maxsize=10000)
    def get_online_features(self, entity_id, feature_names):
        """Cached feature retrieval.

        Note: lru_cache only accepts hashable arguments, so pass
        feature_names as a tuple (not a list).
        """
        return self.backend.get_features(entity_id, feature_names)
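For local testing, a minimal in-memory stand-in for the feature store backend (hypothetical entity IDs and feature values; model is the trained model loaded earlier):

class InMemoryFeatureStore:
    """Toy feature store: looks features up in a dict keyed by entity_id."""
    def __init__(self, table):
        self.table = table

    def get_online_features(self, entity_id, feature_names):
        row = self.table[entity_id]
        return {name: row[name] for name in feature_names}

store = InMemoryFeatureStore({
    "user_123": {
        "user_age": 34,
        "user_income": 72_000,
        "user_num_purchases_30d": 5,
        "user_avg_purchase_amount": 41.2,
    }
})
server = ModelServerWithFeatureStore(model=model, feature_store=store)
print(server.predict_from_entity_id("user_123"))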
Cost Optimization
1. Request Batching for Cost Reduction
import asyncio

class CostOptimizedServer:
"""
Optimize costs by batching and caching
Reduces number of model invocations
"""
def __init__(self, model, batch_wait_ms=50, batch_size=32):
self.model = model
self.batch_wait_ms = batch_wait_ms
self.batch_size = batch_size
self.pending_requests = []
self.cache = {}
self.stats = {
'cache_hits': 0,
'cache_misses': 0,
'batches_processed': 0,
'cost_saved': 0
}
async def predict_with_caching(self, features, cache_key=None):
"""
Predict with caching
Args:
features: Input features
cache_key: Optional cache key
Returns: Prediction
"""
# Check cache
if cache_key and cache_key in self.cache:
self.stats['cache_hits'] += 1
return self.cache[cache_key]
self.stats['cache_misses'] += 1
# Add to batch
future = asyncio.Future()
self.pending_requests.append((features, future, cache_key))
# Trigger batch processing if needed
if len(self.pending_requests) >= self.batch_size:
await self._process_batch()
return await future
async def _process_batch(self):
"""Process accumulated requests as batch"""
if not self.pending_requests:
return
# Extract batch
batch_features = [req[0] for req in self.pending_requests]
futures = [req[1] for req in self.pending_requests]
cache_keys = [req[2] for req in self.pending_requests]
# Run batch inference
predictions = self.model.predict(batch_features)
self.stats['batches_processed'] += 1
# Distribute results
for pred, future, cache_key in zip(predictions, futures, cache_keys):
# Cache result
if cache_key:
self.cache[cache_key] = pred
# Resolve future
future.set_result(pred)
# Clear requests
self.pending_requests = []
# Calculate cost savings (batching is cheaper)
cost_per_single_request = 0.001 # $0.001 per request
cost_per_batch = 0.010 # $0.01 per batch
savings = (len(predictions) * cost_per_single_request) - cost_per_batch
self.stats['cost_saved'] += savings
def get_cost_stats(self):
"""Get cost optimization statistics"""
total_requests = self.stats['cache_hits'] + self.stats['cache_misses']
return {
'total_requests': total_requests,
'cache_hit_rate': self.stats['cache_hits'] / max(total_requests, 1),
'batches_processed': self.stats['batches_processed'],
'avg_batch_size': total_requests / max(self.stats['batches_processed'], 1),
'estimated_cost_saved': self.stats['cost_saved']
}
2. Model Compression for Cheaper Hosting
import torch
def compress_model_for_deployment(model, sample_input):
"""
Compress model for cheaper hosting
Techniques:
- Quantization (INT8)
- Pruning
- Knowledge distillation
Returns: Compressed model
"""
    import copy
    import torch.nn.utils.prune as prune

    model.eval()
    original = copy.deepcopy(model)  # untouched copy for the accuracy check

    # 1. Pruning: zero out the 30% smallest weights of each Linear layer.
    #    Prune the float model first - quantized modules can't be pruned.
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=0.3)
            prune.remove(module, 'weight')  # make the pruning permanent

    # 2. Dynamic quantization: INT8 weights for Linear layers
    model_quantized = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )

    # 3. Verify the compressed model still agrees with the original
    with torch.no_grad():
        original_output = original(sample_input)
        compressed_output = model_quantized(sample_input)
        diff = torch.abs(original_output - compressed_output).mean()
        print(f"Compression error: {diff:.4f}")

    return model_quantized
# Compare size and hosting cost (get_model_size as defined in the quantization
# section; the $/MB figure is illustrative)
compressed_model = compress_model_for_deployment(model, sample_input)
original_size_mb = get_model_size(model)
compressed_size_mb = get_model_size(compressed_model)
print(f"Size reduction: {original_size_mb:.1f}MB → {compressed_size_mb:.1f}MB")
print(f"Cost savings: ~${(original_size_mb - compressed_size_mb) * 0.10:.2f}/month")
Troubleshooting & Debugging
Prediction Debugging
import numpy as np

class DebuggableModelServer:
"""
Model server with debugging capabilities
Helps diagnose prediction issues
"""
def __init__(self, model):
self.model = model
def predict_with_debug(self, features, debug=False):
"""
Make prediction with optional debug info
Returns: (prediction, debug_info)
"""
debug_info = {}
if debug:
# Record input stats
debug_info['input_stats'] = {
'mean': np.mean(features),
'std': np.std(features),
'min': np.min(features),
'max': np.max(features),
'nan_count': np.isnan(features).sum()
}
# Check for anomalies
debug_info['anomalies'] = self._detect_anomalies(features)
# Make prediction
prediction = self.model.predict([features])[0]
if debug:
# Record prediction confidence
if hasattr(self.model, 'predict_proba'):
proba = self.model.predict_proba([features])[0]
debug_info['confidence'] = float(np.max(proba))
debug_info['class_probabilities'] = proba.tolist()
return prediction, debug_info
def _detect_anomalies(self, features):
"""Detect input anomalies"""
anomalies = []
# Check for NaN
if np.any(np.isnan(features)):
anomalies.append("Contains NaN values")
# Check for extreme values
z_scores = np.abs((features - np.mean(features)) / (np.std(features) + 1e-8))
if np.any(z_scores > 3):
anomalies.append("Contains outliers (z-score > 3)")
return anomalies
def explain_prediction(self, features):
"""
Explain prediction using SHAP or similar
Returns: Feature importance
"""
# Simplified explanation (in practice, use SHAP)
if hasattr(self.model, 'feature_importances_'):
importances = self.model.feature_importances_
return {
f'feature_{i}': {'value': features[i], 'importance': imp}
for i, imp in enumerate(importances)
}
return {}
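For real explanations, SHAP gives per-feature attributions for an individual prediction. A sketch, assuming the shap package is installed and the model is tree-based:

import shap

def explain_with_shap(model, features):
    """Return per-feature SHAP values for a single prediction."""
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(np.array([features]))
    return shap_values

# For non-tree models, shap.KernelExplainer works (more slowly) on any predict function.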
Key Takeaways
✅ Multiple serving patterns - REST, gRPC, batch for different needs
✅ Model versioning essential - Support A/B testing and rollbacks
✅ Optimize for latency - Quantization, batching, caching
✅ Monitor everything - Latency, errors, prediction distribution
✅ Validate inputs/outputs - Catch issues early
✅ Scale horizontally - Add more serving instances
✅ Connection to validation - Like BST range checking
Originally published at: arunbaby.com/ml-system-design/0008-model-serving-architecture