Transfer Learning Systems
“Standing on the shoulders of giants isn’t just a metaphor—it’s an engineering requirement.”
1. Problem Statement
In the modern ML landscape, training a model from scratch is rarely the answer.
- Cost: Training a GPT-3-scale model from scratch is estimated to cost ~$4M-$12M in compute alone.
- Data: You rarely have the billions of tokens needed for pre-training.
- Efficiency: Why relearn “grammar” or “edges” when open-source models already know them?
The System Design Problem: Design a platform that allows enterprise customers to build custom classifiers (e.g., “Legal Document Classifier”, “Spam Detector”, “Code Reviewer”) using Transfer Learning, while minimizing:
- Training Time: Max 1 hour per model.
- Inference Latency: <50ms.
- Storage Cost: We cannot store 100 copies of a 100GB model (10TB) just for 100 customers.
2. Understanding the Requirements
2.1 The Math of Efficiency
Transfer Learning (TL) works by taking a Teacher Model (e.g., BERT, ResNet) trained on a massive generic dataset (Wikipedia, ImageNet) and adapting it to a specific task.
There are three main flavors, each with distinct system implications:
- Feature Extraction (Frozen Backbone):
- Run input through the Backbone (BERT).
- Get the vector (embedding).
- Train a tiny Logistic Regression/MLP on top.
- System: Very cheap training. Sharing the backbone at inference is easy (sketched in code after this list).
- Full Fine-Tuning:
- Unfreeze all weights. Update everything.
- System: Expensive training. Result is a completely new 500MB file. Hard to multi-tenant.
- Parameter-Efficient Fine-Tuning (PEFT/Adapters):
- Inject tiny trainable layers (LoRA adapters) into the frozen backbone.
- System: The “Holy Grail”. Only store 10MB diffs.
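To make the first flavor concrete, here is a minimal PyTorch sketch of Feature Extraction, assuming the Hugging Face transformers library and a bert-base-uncased backbone (hidden size 768); the classify helper and the two-label head are illustrative.

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
backbone = AutoModel.from_pretrained("bert-base-uncased")

# Freeze the backbone: no gradients, no optimizer state, shareable across tenants.
for p in backbone.parameters():
    p.requires_grad = False

# The only trainable part: a tiny per-customer head (a few thousand parameters).
head = nn.Linear(768, 2)   # e.g., spam / not-spam

def classify(texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():                                   # backbone is inference-only
        cls = backbone(**batch).last_hidden_state[:, 0]     # [CLS] embedding per text
    return head(cls)                                        # cheap, trainable layer

print(classify(["Win a free iPhone!!!", "Quarterly report attached."]).shape)  # (2, 2)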
2.2 Scale Constraints
- Base Models: BERT-Large (340M params), ViT (Vision Transformer).
- Throughput: 10,000 requests/sec aggregate.
- Tenancy: 1,000+ distinct customer models active simultaneously.
3. High-Level Architecture
We need a layered architecture that separates the “Heavy Lifting” (Base Models) from the “Specific Logic” (Heads/Adapters).
[Request: {text: "Sue him!", model_id: "client_A_legal"}]
|
v
[Load Balancer / Gateway]
|
v
[Model Serving Layer (Ray Serve / TorchServe)]
|
+----+--------------------------------+
| GPU Worker (Shared Inference) |
| |
| [ Base Model (BERT) - Frozen ] | <-- Loaded Once (VRAM: 2GB)
| | |
| v |
| [ Adapter Controller ] |
| / | \ |
| [LoRA A] [LoRA B] [LoRA C] | <-- Swapped Dynamically
| (Client A)(Client B)(Client C) | (VRAM: 10MB each)
+-------------------------------------+
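To ground the diagram, here is a minimal sketch of the shared-inference worker in pure PyTorch. It is framework-agnostic (no Ray Serve / TorchServe specifics), the backbone is a stand-in nn.Module, and for brevity the adapter is applied once at the output rather than injected into every layer as in Section 7.

import torch
import torch.nn as nn

class SharedBackboneWorker:
    def __init__(self, backbone: nn.Module):
        self.backbone = backbone.eval()   # loaded once, shared by every tenant
        self.adapters = {}                # model_id -> {"A", "B", "head"} tensors

    def register(self, model_id, weights):
        self.adapters[model_id] = weights           # ~10MB each: cheap to keep many

    @torch.no_grad()
    def handle(self, model_id: str, x: torch.Tensor) -> torch.Tensor:
        h = self.backbone(x)                        # heavy, shared computation
        a = self.adapters[model_id]                 # tenant-specific logic
        h = h + (h @ a["A"]) @ a["B"]               # low-rank correction
        return h @ a["head"]                        # tenant's classifier head

# One worker, one backbone, many tenants.
worker = SharedBackboneWorker(nn.Sequential(nn.Linear(16, 768), nn.ReLU()))
worker.register("client_A_legal", {"A": torch.randn(768, 8),
                                   "B": torch.zeros(8, 768),
                                   "head": torch.randn(768, 3)})
logits = worker.handle("client_A_legal", torch.randn(2, 16))   # shape (2, 3)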
4. Component Deep-Dives
4.1 The Model Registry
This isn’t just an S3 bucket. It’s a versioned graph.
- Parent: bert-base-uncased (SHA256: a8d2...)
- Child: client_A_v1 (Delta Weights + Config) -> refers to Parent a8d2...
When a worker starts, it pulls the Parent. When a request comes for client_A, it hot-loads the Child weights.
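A sketch of what a single registry record might look like; the field names and the resolve helper are illustrative, not a specific registry product's schema.

from dataclasses import dataclass

@dataclass(frozen=True)
class RegistryEntry:
    model_id: str        # e.g., "client_A_v1"
    parent_sha256: str   # pins the exact frozen backbone, e.g., "a8d2..."
    artifact_uri: str    # object-store path to the 5-10MB delta weights
    adapter_config: dict # rank, target modules, tokenizer hash, metrics, ...

def resolve(entry: RegistryEntry, parents: dict):
    """A worker resolves a child to (backbone, adapter artifact) before serving."""
    backbone = parents[entry.parent_sha256]   # pulled once, cached on the node
    return backbone, entry.artifact_uri       # adapter weights are hot-loaded later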
4.2 The Adapter Controller
This is a specialized software component (often custom C++/CUDA kernels).
- Function: Matrix Multiplication routing.
- Instead of Y = W * X (standard linear), it computes Y = W * X + (A * B) * X, where A and B are the low-rank adapter matrices (their product has the same shape as W but far fewer free parameters).
- Optimization: Because A and B are small, we can batch requests from different clients together!
  - Request 1 (Client A) and Request 2 (Client B) both run through the BERT backbone together (batch size 2).
  - At the specific layer, split the computation or use specialized CUDA kernels (as in multi-adapter serving systems such as LoRAX or vLLM) to apply per-row adapters.
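Below is a pure-PyTorch sketch of the per-row math this controller implements for a mixed-tenant batch; real systems fuse this into custom CUDA kernels, and gathering each request's A/B matrices into stacked tensors is assumed to have happened upstream.

import torch

def mixed_batch_lora(x, W, A_stack, B_stack):
    """
    x:       (batch, in_dim)        - requests from different clients
    W:       (in_dim, out_dim)      - shared, frozen weight
    A_stack: (batch, in_dim, rank)  - each row's client-specific A
    B_stack: (batch, rank, out_dim) - each row's client-specific B
    """
    base = x @ W                                                    # one shared GEMM
    delta = torch.bmm(torch.bmm(x.unsqueeze(1), A_stack), B_stack)  # per-row low-rank path
    return base + delta.squeeze(1)

# Two requests from two different tenants travel through the same layer together.
x = torch.randn(2, 768)
W = torch.randn(768, 768)
A = torch.randn(2, 768, 8)   # client A's and client B's adapters, stacked
B = torch.zeros(2, 8, 768)
y = mixed_batch_lora(x, W, A, B)   # shape (2, 768)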
5. Data Flow
- Ingestion: User uploads 1,000 labeled examples (Instruction Tuning data).
- Preprocessing: Data is tokenized using the Base Model’s tokenizer. (Crucial: Adapters can’t change the tokenizer).
- Training (ephemeral; sketched in code after this list):
- Spin up a spot GPU instance.
- Load Base Model (Frozen).
- Attach new Adapter layers.
- Train for 5 epochs (takes ~10 mins).
- Extract only the Adapter weights.
- Save to Registry (Size: 5MB).
- Inference:
- Worker loads Adapter weights into Host RAM.
- On request, moves weights to GPU Cache (if not present).
- Executes forward pass.
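A sketch of the ephemeral training step (step 3), assuming the Hugging Face transformers and peft libraries; the hyperparameters, label count, and output directory are illustrative.

from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3)                  # frozen backbone + new head

lora_cfg = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16,
                      target_modules=["query", "value"])
model = get_peft_model(base, lora_cfg)                  # injects adapters, freezes the rest
model.print_trainable_parameters()                      # well under 1% of total params

# ... standard training loop over the customer's 1,000 labeled examples (~5 epochs) ...

model.save_pretrained("client_A_v1/")   # writes ONLY the adapter + head (a few MB);
                                        # this artifact is what goes into the Registry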
6. Scaling Strategies
6.1 Multi-Tenancy (The “Cold Start” problem)
If a user hasn’t sent a request in 24 hours, we offload their adapter from GPU VRAM. When they return:
- Model Load Time: Loading a 500MB fine-tuned model takes 5-10 seconds. (Too slow.)
- Adapter Load Time: Loading 5MB of LoRA weights takes ~50 milliseconds. (Acceptable.)
Conclusion: PEFT/adapters give LLM serving a "serverless" feel (the GPU-side adapter cache is sketched below).
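A sketch of that adapter cache, using a simple LRU policy; the capacity, the cold_store interface (a dict-like host-RAM/object-store layer), and the "cuda" device string are illustrative.

from collections import OrderedDict
import torch

class AdapterCache:
    """Keeps the hottest N adapters in VRAM; everything else stays in host RAM."""
    def __init__(self, capacity: int = 100):
        self.capacity = capacity
        self.gpu = OrderedDict()                 # model_id -> adapter tensors on GPU

    def get(self, model_id: str, cold_store: dict) -> dict:
        if model_id in self.gpu:
            self.gpu.move_to_end(model_id)       # cache hit: mark as recently used
            return self.gpu[model_id]
        # Cache miss: ~5MB copy, tens of milliseconds, vs. seconds for a full model.
        weights = {name: t.to("cuda", non_blocking=True)
                   for name, t in cold_store[model_id].items()}
        self.gpu[model_id] = weights
        if len(self.gpu) > self.capacity:
            self.gpu.popitem(last=False)         # evict the least recently used adapter
        return weights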
6.2 Caching Strategy
CACHE heavily on:
- Embeddings: If using Feature Extraction, cache the output of the backbone. If the same email text is classified by 5 different distinct classifiers (Spam, Urgent, Sales, HR, Legal), compute the BERT embedding once, then run the 5 lightweight heads.
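A sketch of that pattern: one cached backbone pass per unique text, then all of the lightweight heads. The in-process dict stands in for a shared cache such as Redis, and the embed/heads callables are illustrative.

import hashlib
import torch

_cache = {}   # in production: Redis/Memcached keyed by (backbone_version, text_hash)

def cached_embedding(text: str, embed) -> torch.Tensor:
    key = hashlib.sha256(text.encode("utf-8")).hexdigest()
    if key not in _cache:
        _cache[key] = embed(text)        # expensive backbone forward pass, done once
    return _cache[key]

def classify_all(text: str, embed, heads: dict) -> dict:
    h = cached_embedding(text, embed)
    # Five cheap linear heads instead of five full BERT passes.
    return {name: head(h) for name, head in heads.items()}

# Usage with stand-in components:
embed = lambda text: torch.randn(768)
heads = {name: torch.nn.Linear(768, 2) for name in ["spam", "urgent", "sales", "hr", "legal"]}
scores = classify_all("Please call me back about the invoice.", embed, heads)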
7. Implementation: LoRA Injection Conceptual Code
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, original_layer, rank=8):
        super().__init__()
        self.original_layer = original_layer
        for p in self.original_layer.parameters():
            p.requires_grad = False          # Frozen: the backbone path never trains
        self.rank = rank

        # Low-Rank Adaptation matrices
        in_dim = original_layer.in_features
        out_dim = original_layer.out_features

        # A: Gaussian init, B: zero init (so the adapter starts as an exact no-op)
        self.lora_A = nn.Parameter(torch.randn(in_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_dim))

    def forward(self, x):
        # Path 1: frozen backbone (the original weights don't change)
        base_out = self.original_layer(x)

        # Path 2: trainable low-rank branch
        # lora_A @ lora_B is a low-rank matrix that approximates the weight-update delta
        adapter_out = (x @ self.lora_A) @ self.lora_B
        return base_out + adapter_out

# Usage in the system:
# to switch clients, we only swap lora_A and lora_B (the backbone stays resident in VRAM)
def switch_context(model, new_client_weights):
    model.layer1.lora_A.data = new_client_weights['params_A']
    model.layer1.lora_B.data = new_client_weights['params_B']
    # (production code iterates over every LoRA-wrapped layer, not just layer1)
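A sketch of how the LoRALayer above might be injected into an existing backbone by wrapping its nn.Linear modules; the tiny Sequential model is a stand-in for BERT's attention projections.

import torch.nn as nn

def inject_lora(module: nn.Module, rank: int = 8):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, LoRALayer(child, rank=rank))   # wrap in place
        else:
            inject_lora(child, rank)                             # recurse into sub-blocks

backbone = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 768))
inject_lora(backbone)

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(trainable)   # only the small lora_A / lora_B matrices remain trainable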
8. Monitoring & Metrics
In a Transfer Learning system, typical metrics (Latency, Error Rate) aren’t enough. We need Drift Detection.
- Base Model Drift: Does the underlying pre-training data (Wikipedia 2020) still represent the world? (e.g., “Covid” meaning change).
- Task Drift: Is the customer’s definition of “Spam” changing?
- Correlation: If the Base Model is updated/patched (e.g., security fix), it might break all 1,000 child adapters. We need rigorous regression testing of the Base Model before upgrades.
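A sketch of one simple drift signal for the cases above: compare the backbone embeddings of live traffic against a reference window captured at adapter-training time. The cosine-distance metric and the alert threshold are illustrative; production systems often use PSI or KS tests per dimension.

import torch
import torch.nn.functional as F

def drift_score(reference: torch.Tensor, live: torch.Tensor) -> float:
    """reference, live: (n, hidden_dim) backbone embeddings from two time windows."""
    ref_centroid = reference.mean(dim=0)
    live_centroid = live.mean(dim=0)
    return float(1.0 - F.cosine_similarity(ref_centroid, live_centroid, dim=0))

reference = torch.randn(10_000, 768)          # captured when the adapter was trained
live = torch.randn(2_000, 768) + 0.5          # simulated shift in this week's traffic
if drift_score(reference, live) > 0.05:       # alert threshold (illustrative)
    print("Embedding drift detected: re-validate the affected child adapters")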
9. Failure Modes
- Catastrophic Forgetting: Less relevant here since the base is frozen, but crucial with full fine-tuning.
- Tokenizer Mismatch: A user uploads data full of emojis; the BERT tokenizer maps them to [UNK], and the classifier performs poorly.
  - Mitigation: An automated data-validation step in the pipeline that warns users about the % of [UNK] tokens (sketched after this list).
- Noisy Neighbors: One client sends a batch-size-128 request, starving GPU compute for the other 50 clients sharing the backbone.
- Mitigation: Strict semaphore/queue management on the GPU worker.
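A sketch of the [UNK]-rate validation check, assuming the Hugging Face transformers tokenizer that matches the frozen backbone; the 2% threshold is illustrative.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def unk_rate(texts) -> float:
    total, unk = 0, 0
    for text in texts:
        tokens = tokenizer.tokenize(text)
        total += len(tokens)
        unk += sum(1 for t in tokens if t == tokenizer.unk_token)
    return unk / max(total, 1)

rate = unk_rate(["Ship it 🚀🚀🔥", "Please review the attached contract."])
if rate > 0.02:
    print(f"Warning: {rate:.1%} of tokens map to [UNK]; the classifier may underperform")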
10. Real-World Case Study: Hugging Face Inference API
Hugging Face hosts >100,000 models. They don’t have 100,000 GPUs constantly hot. They use Shared Backbones aggressively.
- If you request bert-base-finetuned-squad, they might route you to a generic BERT fleet and apply the task-specific weights (if the architecture permits) or use rapid model swapping.
- For LLMs (Llama-2), LoRA Exchange (LoRAX)-style servers allow one GPU to serve hundreds of specialized adapters.
11. Cost Analysis
Scenario: 10 distinct specialized models (7B params each).
Option A: Dedicated Instances (Full Fine-Tune)
- 10 x A10 GPUs ($1.50/hr).
- Cost: $15/hr.
- Utilization: Low (each model sits idle mostly).
Option B: Adapter Serving
- 1 x A10 GPU ($1.50/hr).
- Base Model (14GB VRAM).
- 10 Adapters (200MB VRAM).
- Total VRAM: ~15GB. Fits on one card.
- Cost: $1.50/hr.
- Savings: 90%.
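A quick back-of-the-envelope check of the two options, using the prices and sizes assumed above (an A10 has 24GB of VRAM; activation memory is ignored for simplicity).

N_MODELS     = 10
A10_PER_HOUR = 1.50     # $/hr, as assumed above
BASE_VRAM_GB = 14       # 7B params in fp16
ADAPTER_GB   = 0.02     # ~20MB per adapter

dedicated_cost = N_MODELS * A10_PER_HOUR                 # $15.00/hr, mostly idle
shared_vram    = BASE_VRAM_GB + N_MODELS * ADAPTER_GB    # ~14.2GB -> fits on one 24GB A10
shared_cost    = 1 * A10_PER_HOUR                        # $1.50/hr

print(f"dedicated=${dedicated_cost:.2f}/hr  shared=${shared_cost:.2f}/hr  "
      f"savings={1 - shared_cost / dedicated_cost:.0%}")   # 90%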
12. Key Takeaways
- Don’t Retrain, Adapt: Training new weights from scratch is rarely justified in modern engineering. Use Transfer Learning.
- Freeze the Backbone: This enables caching, storage savings, and multi-tenant serving.
- PEFT is King: Techniques like LoRA aren’t just “research hacks”; they are fundamental cloud infrastructure enablers that separate compute (Backbone) from logic (Adapter).
- System Design mirrors Model Design: The mathematical layers of a Neural Net (Frozen vs Trainable) directly dictate the microservices architecture (Shared Fleet vs Dedicated).
Originally published at: arunbaby.com/ml-system-design/0046-transfer-learning
If you found this helpful, consider sharing it with others who might benefit.