Graph-based Recommendation Systems
“Leveraging the connection structure to predict what users will love.”
1. Why Graph-based Recommendations?
Traditional recommender systems use user-item matrices. Graph-based systems model the entire interaction network.
Example: Social Media
Users: Alice, Bob, Charlie
Items: Post1, Post2, Post3
Graph:
Alice --likes--> Post1 <--likes-- Bob
  |
  +--follows--> Charlie
                   |
                created
                   |
                 Post2
Advantages of Graphs:
- Richer Context: Capture multi-hop relationships (friend-of-friend recommendations).
- Heterogeneous: Mix users, items, tags, locations in one graph.
- Explainability: “We recommend this post because your friend Bob liked it.”
- Cold Start: New users can benefit from their social connections.
2. Graph Representation
Homogeneous Graph
All nodes and edges are the same type. Example: Friendship network (all nodes are users, all edges are “friends”).
Heterogeneous Graph
Multiple node/edge types. Example: E-commerce
- Nodes: Users, Products, Brands, Categories
- Edges: User –bought–> Product, Product –belongs_to–> Category
Bipartite Graph
Two types of nodes with edges only between different types. Example: User-Item interactions.
Adjacency Matrix:
\[
A_{ij} =
\begin{cases}
1 & \text{if user } i \text{ interacted with item } j \\
0 & \text{otherwise}
\end{cases}
\]
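As a minimal sketch (with a hypothetical handful of interactions), this matrix is best stored sparsely, since most user-item pairs never interact:

import torch

# Hypothetical interaction data: user i interacted with item j
user_ids = torch.tensor([0, 0, 1, 2])   # rows (users)
item_ids = torch.tensor([1, 2, 2, 0])   # columns (items)
num_users, num_items = 3, 3

# Sparse bipartite adjacency: A[i, j] = 1 if user i interacted with item j
indices = torch.stack([user_ids, item_ids])
values = torch.ones(indices.size(1))
A = torch.sparse_coo_tensor(indices, values, size=(num_users, num_items))

print(A.to_dense())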
3. Traditional Graph-Based Approaches
Approach 1: Collaborative Filtering on Graphs
Idea: If users A and B both liked items X and Y, recommend to A what B liked but A hasn’t seen.
Graph Random Walk:
- Start at user node.
- Walk to liked items.
- Walk to other users who liked those items.
- Walk to items those users liked.
- Recommend items with highest visit frequency.
def personalized_pagerank(graph, user_node, damping=0.85, iterations=100):
    scores = {node: 0.0 for node in graph.nodes}
    scores[user_node] = 1.0
    for _ in range(iterations):
        new_scores = {node: 0.0 for node in graph.nodes}
        for node in graph.nodes:
            neighbors = list(graph.neighbors(node))
            if not neighbors:
                continue  # dangling node: no mass to push
            share = damping * scores[node] / len(neighbors)
            for neighbor in neighbors:
                new_scores[neighbor] += share
        # Teleport: the restart mass always returns to the query user
        new_scores[user_node] += 1 - damping
        scores = new_scores
    return scores
# Recommend top-K items with highest scores
Time Complexity: \(O(I \cdot E)\) where I is iterations and E is number of edges.
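To turn the scores into recommendations (per the comment above), a small helper can rank unseen item nodes; `item_nodes` and `seen` are hypothetical inputs holding the set of item nodes and the user's past interactions:

def top_k_items(scores, item_nodes, seen, k=10):
    # Rank unseen item nodes by their personalized PageRank score
    candidates = [(n, s) for n, s in scores.items()
                  if n in item_nodes and n not in seen]
    return sorted(candidates, key=lambda p: p[1], reverse=True)[:k]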
Approach 2: Node2Vec
Idea: Learn node embeddings by treating random walks as “sentences” and applying Skip-Gram (Word2Vec).
Algorithm:
- Generate random walks starting from each node.
- Treat walks as sentences: `[UserA, Item1, UserB, Item3, ...]`.
- Train Skip-Gram to predict context nodes given the target node.
from node2vec import Node2Vec

# Precompute random walks over the graph
node2vec = Node2Vec(graph, dimensions=128, walk_length=80, num_walks=10, workers=4)

# Train Skip-Gram on the walks
model = node2vec.fit(window=10, min_count=1, batch_words=4)

# Get embeddings
user_embedding = model.wv['UserA']
item_embedding = model.wv['Item1']

# Recommend by cosine similarity
recommended_items = model.wv.most_similar('UserA', topn=10)
Pros:
- Simple and effective.
- Works on any graph.
Cons:
- Doesn’t use node features (only structure).
- Expensive for large graphs (millions of walks).
4. Graph Neural Networks (GNNs)
Core Idea: Aggregate information from neighbors to update node representations.
Message Passing Framework
General Form: \[ h_v^{(k+1)} = \text{UPDATE}\left(h_v^{(k)}, \text{AGGREGATE}\left(\{h_u^{(k)} : u \in \mathcal{N}(v)\}\right)\right) \]
- \(h_v^{(k)}\): Representation of node \(v\) at layer \(k\).
- \(\mathcal{N}(v)\): Neighbors of \(v\).
After K layers: Node \(v\) has aggregated information from \(K\)-hop neighbors.
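A minimal dense-tensor sketch of one message-passing round, assuming mean aggregation and a linear-plus-ReLU UPDATE (no GNN library involved):

import torch
import torch.nn as nn
import torch.nn.functional as F

def message_passing_step(h, adj, update):
    # h: [num_nodes, dim] node states; adj: [num_nodes, num_nodes] 0/1 adjacency
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)    # avoid divide-by-zero
    agg = adj @ h / deg                                # AGGREGATE: mean over neighbors
    return F.relu(update(torch.cat([h, agg], dim=1)))  # UPDATE

h = torch.randn(5, 16)
adj = (torch.rand(5, 5) > 0.5).float()
update = nn.Linear(32, 16)
h_next = message_passing_step(h, adj, update)  # [5, 16]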
Graph Convolutional Network (GCN)
Update Rule: \[ H^{(k+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(k)} W^{(k)}\right) \]
- \(\tilde{A} = A + I\) (adjacency matrix + self-loops).
- \(\tilde{D}\): Degree matrix of \(\tilde{A}\).
- \(W^{(k)}\): Learnable weight matrix.
Implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNRecommender(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=128):
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.conv1 = GCNConv(embedding_dim, 256)
        self.conv2 = GCNConv(256, 128)

    def forward(self, edge_index):
        # edge_index: [2, num_edges] (source and target nodes)
        # Initialize embeddings: users first, then items
        x = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        # Message passing
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x  # [num_nodes, 128]

# Predict an interaction score from the learned embeddings
embeddings = model(edge_index)
user_emb = embeddings[user_id]
item_emb = embeddings[item_id]
score = torch.dot(user_emb, item_emb)
GraphSAGE (Sampling and Aggregation)
Problem: GCN needs the full adjacency matrix (doesn’t scale to billions of edges).
Solution: Sample a fixed number of neighbors.
Algorithm:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphSAGE(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim * 2, hidden_dim)  # Concat self + aggregated

    def forward(self, x, edge_index, num_samples=10):
        # x: [num_nodes, in_dim]; edge_index: [2, num_edges]
        aggregated = []
        for node in range(x.size(0)):
            # Sample up to num_samples neighbors
            neighbors = edge_index[1][edge_index[0] == node]
            if len(neighbors) > num_samples:
                neighbors = neighbors[torch.randperm(len(neighbors))[:num_samples]]
            if len(neighbors) == 0:
                agg = torch.zeros(x.size(1))  # isolated node: nothing to aggregate
            else:
                agg = x[neighbors].mean(dim=0)  # Aggregate (mean pooling)
            aggregated.append(agg)
        aggregated = torch.stack(aggregated)
        # Concat self + aggregated, then transform
        combined = torch.cat([x, aggregated], dim=1)
        return F.relu(self.linear(combined))
Benefit: \(O(K \cdot S \cdot D)\) where K = layers, S = samples per node, D = embedding dim (independent of graph size!).
5. Pinterest’s PinSage
PinSage was, at the time of its publication, one of the largest-scale GNNs in production, trained on a graph of 3B nodes and 18B edges.
Key Innovations:
- Importance-based Sampling: Sample the neighbors with the highest visit frequency from short random walks (a sketch follows this list).
- Hard Negative Mining: For each positive interaction, sample negative items similar to the positive (harder to distinguish).
- Multi-GPU Training: Distribute graph across GPUs.
- MapReduce Inference: Precompute embeddings offline using Spark.
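A sketch of the importance-based sampling idea, assuming the graph is a plain dict of adjacency lists: simulate short random walks from a node and keep the top-T most-visited neighbors (PinSage also reuses the normalized visit counts as aggregation weights):

import random
from collections import Counter

def important_neighbors(graph, node, num_walks=100, walk_length=3, top_t=10):
    # graph: dict mapping node -> list of neighbors (hypothetical structure)
    visits = Counter()
    for _ in range(num_walks):
        cur = node
        for _ in range(walk_length):
            nbrs = graph.get(cur, [])
            if not nbrs:
                break
            cur = random.choice(nbrs)
            visits[cur] += 1
    visits.pop(node, None)  # don't count the start node itself
    return visits.most_common(top_t)  # [(neighbor, visit_count), ...]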
Architecture:
Input: Pin features (image, text) + graph structure
        |
        v
3 layers of GraphSAGE (neighbor sampling)
        |
        v
Pin embedding (256-dim)
        |
        v
Cosine similarity → recommendations
Results:
- Offline: +40% recall@100 over baseline.
- Online: +20% engagement (repins, clicks).
6. LinkedIn’s Skills Graph
LinkedIn models Users, Jobs, Skills, Companies as a heterogeneous graph.
Example Query: “Recommend jobs for a user with skills in Python, ML.”
Solution: Meta-Path-Based Random Walk
- Meta-path: User –has_skill–> Skill <–requires– Job
- Walk: `[UserA, Python, Job1, ML, UserB, Scala, Job2]`
- Recommend jobs that appear frequently in walks starting from UserA (see the sketch after this list).
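A sketch of a meta-path-constrained walk, assuming a typed adjacency list of the form `graph[node] -> [(neighbor, node_type), ...]`:

import random

def meta_path_walk(graph, start, meta_path, walk_length=20):
    # meta_path like ['user', 'skill', 'job', 'skill'], cycled to constrain
    # the node type allowed at each hop
    walk, cur = [start], start
    for i in range(walk_length):
        wanted = meta_path[(i + 1) % len(meta_path)]
        nbrs = [n for n, t in graph.get(cur, []) if t == wanted]
        if not nbrs:
            break
        cur = random.choice(nbrs)
        walk.append(cur)
    return walk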
Heterogeneous GNN:
class HeteroGNN(nn.Module):
    def __init__(self):
        super().__init__()
        # One convolution per node type (schematic; production heterogeneous
        # GNNs use relation-specific message passing, e.g. R-GCN or HGT)
        self.user_conv = GCNConv(128, 256)
        self.job_conv = GCNConv(128, 256)
        self.skill_conv = GCNConv(128, 256)

    def forward(self, user_features, job_features, skill_features, edges):
        # edges: {('user', 'has_skill', 'skill'): edge_index, ...}
        user_emb = self.user_conv(user_features, edges[('user', 'has_skill', 'skill')])
        job_emb = self.job_conv(job_features, edges[('job', 'requires', 'skill')])
        skill_emb = self.skill_conv(skill_features, edges[('user', 'has_skill', 'skill')])
        return user_emb, job_emb, skill_emb
Deep Dive: Training at Scale (Billion-Edge Graphs)
Challenge: Graph doesn’t fit in GPU memory.
Solution 1: Mini-Batch Training with Neighbor Sampling
Cluster-GCN:
- Partition the graph into clusters (e.g., with METIS).
- Sample a batch of clusters.
- Train the GNN on the subgraph induced by those clusters.
Benefit: Each mini-batch is a small, densely connected subgraph (see the sketch below).
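A minimal Cluster-GCN-style training step under assumptions: `clusters` is a list of node-ID tensors produced by a partitioner, and the loss shown is a stand-in for the real ranking loss:

import random
import torch
from torch_geometric.utils import subgraph

def cluster_gcn_step(model, optimizer, x, edge_index, clusters, batch_clusters=2):
    # clusters: list of LongTensors of node IDs (e.g., from METIS partitioning)
    batch = random.sample(clusters, batch_clusters)
    nodes = torch.cat(batch)
    # Induced subgraph over the sampled clusters, with node IDs relabeled
    sub_edge_index, _ = subgraph(nodes, edge_index, relabel_nodes=True,
                                 num_nodes=x.size(0))
    optimizer.zero_grad()
    out = model(x[nodes], sub_edge_index)
    loss = out.pow(2).mean()  # placeholder: substitute your task loss here
    loss.backward()
    optimizer.step()
    return loss.item()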
Solution 2: Distributed Training
DistDGL (Distributed Deep Graph Library):
- Graph Store: Distributed across multiple machines (sharded by node ID).
- Sampling: Each worker samples locally and fetches remote neighbors via RPC.
- Aggregation: Use MPI All-Reduce to aggregate gradients.
Scalability: Trains on graphs with 100B+ edges (Alibaba’s product graph).
Deep Dive: Cold Start with Side Information
Problem: New user has no interactions.
Solution: Use Content Features
class HybridGNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_encoder = BERTModel()   # placeholder: encodes user bio, item description
        self.image_encoder = ResNet()     # placeholder: encodes profile pic, item image
        self.gnn = GraphSAGE()            # dimensions omitted in this sketch

    def forward(self, text, image, edge_index):
        text_emb = self.text_encoder(text)
        image_emb = self.image_encoder(image)
        # Initial embedding = concat(text, image)
        x_init = torch.cat([text_emb, image_emb], dim=1)
        # Message passing refines the content-based embedding
        x_final = self.gnn(x_init, edge_index)
        return x_final
For new users: Use \(x_{\text{init}}\) directly (no graph info yet).
Deep Dive: Temporal Graphs (Dynamic Recommendations)
Problem: User preferences change over time.
Solution: Temporal GNN
class TemporalGNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(input_size=128, hidden_size=256)
        self.gnn = GraphSAGE()  # dimensions omitted in this sketch

    def forward(self, snapshots):
        # snapshots: list of (features, edge_index) at successive timestamps
        h_t = None
        for features, edge_index in snapshots:
            x = self.gnn(features, edge_index)      # spatial aggregation
            x, h_t = self.gru(x.unsqueeze(0), h_t)  # temporal aggregation
        return x  # final embedding incorporates temporal dynamics
Use Case: Reddit recommending trending posts (graph changes every minute).
Deep Dive: Knowledge Graph Embeddings (TransE, DistMult)
Knowledge Graph: Entities and Relations. Example:
(Python, is_a, Programming Language)
(TensorFlow, used_for, Deep Learning)
(Alice, knows, Python)
TransE: Embed entities and relations in the same space. \[ h + r \approx t \] where \(h\) = head entity, \(r\) = relation, \(t\) = tail entity.
Loss (margin-based ranking): \[ \mathcal{L} = \sum_{(h, r, t) \in \mathcal{T}} \max(0, \gamma + d(h + r, t) - d(h' + r, t')) \] where \((h', r, t')\) is a negative sample (a triple with a corrupted head or tail).
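A compact sketch of this loss with L2 distance; the entity/relation counts and the negative-sampling step are assumed to come from elsewhere:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransE(nn.Module):
    def __init__(self, num_entities, num_relations, dim=64, margin=1.0):
        super().__init__()
        self.ent = nn.Embedding(num_entities, dim)
        self.rel = nn.Embedding(num_relations, dim)
        self.margin = margin

    def forward(self, h, r, t, h_neg, t_neg):
        # d(h + r, t): L2 distance; positives should score closer than negatives
        d_pos = (self.ent(h) + self.rel(r) - self.ent(t)).norm(p=2, dim=1)
        d_neg = (self.ent(h_neg) + self.rel(r) - self.ent(t_neg)).norm(p=2, dim=1)
        return F.relu(self.margin + d_pos - d_neg).mean()  # hinge loss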
Application: Amazon’s product knowledge graph for recommendations.
Deep Dive: Graph Augmentation for Robustness
Problem: Sparse graphs lead to poor embeddings.
Solutions:
- Edge Dropout: Randomly remove edges during training (forces model to not rely on single edges).
- Node Mixup: Interpolate between node features: \(x_{\text{mix}} = \lambda x_i + (1 - \lambda) x_j\) (sketch after the code below).
- Virtual Nodes: Add a global node connected to all nodes (helps with long-range dependencies).
def graph_augmentation(edge_index, drop_rate=0.1):
    # Edge dropout: keep each edge with probability 1 - drop_rate
    num_edges = edge_index.size(1)
    mask = torch.rand(num_edges) > drop_rate
    return edge_index[:, mask]
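And a sketch of node mixup from the list above; `alpha` controls how aggressive the interpolation is:

import torch

def node_mixup(x, alpha=0.2):
    # Interpolate each node's features with a randomly paired node's features
    lam = torch.distributions.Beta(alpha, alpha).sample()
    perm = torch.randperm(x.size(0))
    return lam * x + (1 - lam) * x[perm]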
Deep Dive: Explainability with GNN (GNNExplainer)
Problem: Why did the model recommend Item X to User Y?
GNNExplainer: Find the minimal subgraph that most influences the prediction.
Algorithm:
- Given a node \(v\) and its prediction \(y\), find a small subgraph \(G_S\).
- Maximize the mutual information \(MI(Y, G_S) = H(Y) - H(Y \mid G = G_S)\).
- Optimize via gradient descent over a learnable edge mask (see the sketch below).
Output: “We recommended this movie because you liked these 3 similar movies (highlighted subgraph).”
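A rough sketch of the edge-mask optimization, under the assumption that the trained model accepts per-edge weights (as PyG convolutions such as GCNConv do via `edge_weight`):

import torch
import torch.nn.functional as F

def explain_prediction(model, x, edge_index, node_id, target_class,
                       steps=200, beta=0.05):
    # Learnable soft mask over edges; sigmoid keeps weights in (0, 1)
    mask = torch.nn.Parameter(torch.randn(edge_index.size(1)))
    opt = torch.optim.Adam([mask], lr=0.01)
    for _ in range(steps):
        opt.zero_grad()
        w = torch.sigmoid(mask)
        logits = model(x, edge_index, edge_weight=w)
        nll = F.cross_entropy(logits[node_id].unsqueeze(0),
                              torch.tensor([target_class]))
        loss = nll + beta * w.sum()  # fit the prediction, but keep the mask sparse
        loss.backward()
        opt.step()
    return torch.sigmoid(mask).detach()  # high-weight edges explain the prediction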
Deep Dive: Negative Sampling Strategies
Problem: For each positive interaction (User –likes–> Item), we need negatives.
Strategies:
- Random: Sample random items (easy negatives, model learns quickly but not well).
- Popularity-based: Sample popular items (harder, but can bias toward popular items).
- Hard Negatives: Sample items similar to the positive item (e.g., using k-NN on item embeddings).
Dynamic Hard Negative Mining:
# During training
positive_items = batch['items']
positive_embs = item_embeddings[positive_items]
# Find K nearest items in embedding space
hard_negatives = faiss_index.search(positive_embs, K)
loss = bpr_loss(user_emb, positive_embs, item_embeddings[hard_negatives])
Deep Dive: Fairness in Graph-based Recommendations
Problem: Graph structure can encode bias (e.g., popular items get more exposure).
Metrics:
- Exposure Fairness: Items with equal quality should get equal exposure.
- Demographic Parity: Recommendations should be similar across demographic groups.
Debiasing:
- Re-weighting: Upweight interactions with underrepresented items.
- Adversarial Training: Train a discriminator to predict user demographics from embeddings; train the recommender to minimize recommendation loss while fooling the discriminator (driving its accuracy down).
class FairGNN(nn.Module):
    def __init__(self, gnn, discriminator, lambda_fair=0.5):
        super().__init__()
        self.gnn = gnn                      # any GNN encoder
        self.discriminator = discriminator  # predicts demographics from embeddings
        self.lambda_fair = lambda_fair      # trade-off between accuracy and fairness

    def forward(self, x, edge_index, true_demographics):
        emb = self.gnn(x, edge_index)
        # Recommendation loss
        rec_loss = self.recommendation_loss(emb)
        # Fairness loss (fool the discriminator)
        demographics_pred = self.discriminator(emb)
        fair_loss = -self.discriminator_loss(demographics_pred, true_demographics)
        total_loss = rec_loss + self.lambda_fair * fair_loss
        return total_loss
Deep Dive: Graph-based Bandits (Exploration vs. Exploitation)
Problem: Should we recommend popular items (exploitation) or explore new items?
LinUCB with Graphs: \[ \text{Score}(item) = \theta^T x_{item} + \alpha \sqrt{x_{item}^T A^{-1} x_{item}} \] where the second term is the uncertainty (exploration bonus).
Graph Extension: Use GNN to compute \(x_{\text{item}}\) (includes neighborhood information).
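A small numpy sketch of the scoring rule; `theta`, `A_inv`, and `x_item` are the learned weight vector, inverse design matrix, and (GNN-derived) item feature vector:

import numpy as np

def linucb_score(theta, A_inv, x_item, alpha=1.0):
    # Exploitation: predicted reward; exploration: uncertainty bonus
    exploit = theta @ x_item
    explore = alpha * np.sqrt(x_item @ A_inv @ x_item)
    return exploit + explore

# After observing reward r for the chosen item's features x:
#   A <- A + np.outer(x, x); b <- b + r * x
#   A_inv = np.linalg.inv(A); theta = A_inv @ b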
Implementation: Full GNN Recommender
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data

class GraphRecommender(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=128, hidden_dim=256):
        super().__init__()
        self.num_users = num_users
        # Initial embeddings
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        # GNN layers
        self.conv1 = SAGEConv(embedding_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, embedding_dim)

    def forward(self, edge_index):
        # Concat user and item embeddings (users occupy node IDs 0..num_users-1)
        x = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        # Message passing
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return x

    def predict(self, user_emb, item_emb):
        # Dot product between user and item embeddings
        return (user_emb * item_emb).sum(dim=1)
# Training
def train(model, data, optimizer, num_epochs=100):
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()
        # Forward pass
        embeddings = model(data.edge_index)
        # BPR loss (Bayesian Personalized Ranking):
        # push positive scores above sampled-negative scores
        user_embs = embeddings[data.pos_edges[0]]
        pos_item_embs = embeddings[data.pos_edges[1]]
        neg_item_embs = embeddings[data.neg_edges]
        pos_scores = model.predict(user_embs, pos_item_embs)
        neg_scores = model.predict(user_embs, neg_item_embs)
        loss = -F.logsigmoid(pos_scores - neg_scores).mean()  # stable form of -log(sigmoid(x))
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Inference
@torch.no_grad()
def recommend(model, data, user_id, top_k=10):
    model.eval()
    embeddings = model(data.edge_index)
    user_emb = embeddings[user_id]
    item_embs = embeddings[model.num_users:]  # all item embeddings
    scores = model.predict(user_emb.unsqueeze(0), item_embs)
    top_items = scores.argsort(descending=True)[:top_k]
    return top_items  # item indices (offset by num_users in the full node ID space)
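A hypothetical wiring of the pieces above; the attribute names (`pos_edges`, `neg_edges`) match what `train()` expects, and the tensors themselves would come from your interaction log and negative sampler:

# edge_index, pos_edges, and neg_item_ids are assumed inputs
data = Data(edge_index=edge_index)  # [2, num_edges] bipartite interactions
data.pos_edges = pos_edges          # [2, num_pos]: (user, item-node) index pairs
data.neg_edges = neg_item_ids       # [num_pos]: sampled negative item-node IDs

model = GraphRecommender(num_users=1000, num_items=5000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train(model, data, optimizer, num_epochs=100)
print(recommend(model, data, user_id=42, top_k=10))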
Top Interview Questions
Q1: How do you handle graphs that don’t fit in memory? Answer: Use neighbor sampling (GraphSAGE) to limit the number of neighbors aggregated. Use distributed training (DistDGL) to shard the graph across machines. For inference, precompute embeddings offline.
Q2: GNNs vs. Matrix Factorization: when to use which? Answer:
- Matrix Factorization: Simpler, faster, works well if you only care about direct user-item interactions.
- GNNs: Better when you have rich graph structure (social connections, item similarities, multi-hop relationships).
Q3: How do you evaluate graph-based recommenders? Answer:
- Offline: Recall@K, NDCG@K, Hit Rate (see the sketch after this answer).
- Online: A/B test (CTR, engagement).
- Graph-specific: Coverage (% of items recommended), Diversity (how different are recommended items).
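For example, Recall@K is just the hit fraction over a user's held-out relevant items:

def recall_at_k(recommended, relevant, k=10):
    # Fraction of the user's held-out relevant items appearing in the top-K
    hits = len(set(recommended[:k]) & set(relevant))
    return hits / max(len(relevant), 1)

# recall_at_k([3, 7, 1, 9], relevant=[7, 2], k=3) -> 0.5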
Q4: How do you handle new users/items (cold start)? Answer: Use content features (text, images) in addition to graph structure. For new items with no interactions, compute initial embedding from content. As interactions occur, refine embedding via GNN.
Key Takeaways
- Graphs Capture Structure: Use connections (social, similarity) for better recommendations.
- GNNs are SOTA: Message passing aggregates multi-hop information.
- Scalability Challenges: Use sampling (GraphSAGE) and distributed training (DistDGL).
- Real-World Systems: Pinterest (PinSage), LinkedIn (Skills Graph), Alibaba (Product Graph).
- Hybrid Approaches: Combine graph structure + content features for cold start robustness.
Summary
| Aspect | Insight |
|---|---|
| Core Idea | Aggregate neighbor information to learn node embeddings |
| Key Architectures | GCN, GraphSAGE, GAT, PinSage |
| Challenges | Scalability, cold start, fairness |
| Applications | Social media, e-commerce, job recommendations |
Originally published at: arunbaby.com/ml-system-design/0030-graph-based-recommendations