Graph Algorithms and GNNs
Reading time: ~45 min · Interview relevance: Very High · Target roles: ML Engineer, Graph ML Engineer, RecSys Engineer
The Day the Recommendation Engine Died
It was a Tuesday morning at a major e-commerce company. The real-time recommendation system - responsible for 30% of revenue - had been serving stale results for six hours. The on-call engineer traced the issue to a graph database query that was supposed to find "users who bought items similar to items bought by users similar to you." The query was timing out.
The underlying algorithm was a naive nested-loop approach to finding second-degree neighbors in a graph of 200 million users and 50 million products. Nobody had thought about the graph structure when they wrote it. In a dense neighborhood - a popular product category like consumer electronics - the query was visiting billions of nodes before timing out.
The fix took 10 minutes once the engineer understood graph algorithms: switch from exhaustive BFS to a sampling-based approach, add a sparse CSR matrix representation, and cap neighborhood expansion at two hops with a maximum fan-out of 25. Revenue recovered within the hour. But the lesson was permanent: if you work on systems that involve relationships between entities, you need to understand graphs deeply.
This describes a growing share of ML. Knowledge graphs power entity linking in LLM-based systems. Social graphs drive feed ranking at every major platform. Graph representations are central to protein structure prediction - AlphaFold 2 uses attention over pairwise representations of amino-acid residues. Dependency graphs define how PyTorch's autograd executes your neural network's backward pass. Heterogeneous graphs model the user-item-context relationships that recommendation systems depend on.
Graph Neural Networks have moved from research curiosity to production staple in under a decade. Pinterest's PinSage (GraphSAGE-based) serves billions of recommendations daily. Uber's fraud detection uses GNNs over transaction graphs. Airbnb's search ranking uses graph attention over user-listing interaction graphs. Every one of these systems requires you to understand both the classical graph algorithms that underpin data processing and the neural architectures that learn over graph structure.
This lesson builds both layers. We start with representations and classical algorithms - the foundations you need regardless of whether you use GNNs. Then we build the GNN framework from first principles, implement GCN and GAT layers from scratch, and work through production patterns for graphs at scale. Every concept connects to a real ML use case.
Why This Exists
Before graph-structured ML, the dominant approach was to flatten graph data into feature vectors. For a social network user, you might engineer features like "number of friends," "average friend activity," and "number of mutual friends with active users." This worked - barely - but it threw away almost all the structural information.
The problem is that relationships are not just counts. A user connected to 10 very active users in a tight-knit community is fundamentally different from a user connected to 10 isolated users, even if both users have "10 friends." The community structure, the network topology, the multi-hop reachability - all of this is invisible to feature engineering.
Convolutional Neural Networks solved the analogous problem for images by exploiting spatial locality: nearby pixels are related, and a convolution kernel can learn to aggregate local structure. The insight behind GNNs is to apply the same idea to graphs: nearby nodes (neighbors in the graph) are related, and we can learn to aggregate their features. The challenge is that graphs have irregular structure - nodes can have varying numbers of neighbors, there is no canonical ordering of neighbors, and the graph might not fit in GPU memory.
The message passing framework, formalized in the MPNN paper (Gilmer et al., 2017), unified dozens of earlier GNN architectures into a single conceptual model: at each layer, every node gathers messages from its neighbors, aggregates them, and updates its own representation. Iterate for L layers, and each node's representation captures information from its L-hop neighborhood. Stack two GCN layers, and a node "sees" everything within two hops. Stack four, and it sees four hops. This is powerful - and as we will see, it also breaks down quickly due to oversmoothing.
Classical graph algorithms remain essential even in the deep learning era. They define how you preprocess graph data, sample subgraphs for mini-batch training, find important nodes to prioritize, and build graph-structured features that complement learned representations.
Historical Context
Graph theory begins with Euler's solution to the Seven Bridges of Königsberg problem in 1736, which proved that no walk through the city can cross each of its seven bridges exactly once. This founded the field, but algorithmic graph theory did not emerge at scale until the 20th century.
Dijkstra published his shortest path algorithm in 1959, motivated by finding efficient routes between cities. He noted that it took him only about 20 minutes to design, while sitting in an Amsterdam café, and that he worked it out without pencil or paper. The algorithm remains one of the most elegant in computer science: a greedy approach that extends the known shortest-path tree one node at a time, always picking the closest unvisited node.
Breadth-first search and depth-first search were formalized in the 1950s-1960s. Tarjan's strongly connected components algorithm (1972) showed how a single recursive DFS, keeping visited nodes on a stack and tracking their low-link values, finds all strongly connected components in linear time - a pattern that appears constantly in compiler design and dependency resolution. Kruskal's MST algorithm (1956) and Prim's (1957) were early examples of provably correct greedy algorithms; their correctness rests on the cut property of spanning trees, a special case of the fact that greedy selection is optimal on matroids.
PageRank was developed by Larry Page and Sergey Brin at Stanford in 1996-1998 as the core of the original Google search algorithm. The insight was to model web surfing as a Markov chain over the link graph. A page's importance is a function of how many important pages link to it - a recursive definition that converges to a stable distribution via power iteration.
Graph neural networks have a shorter history. Gori et al. (2005) and Scarselli et al. (2009) introduced the concept, but it was Kipf and Welling's GCN paper (2016) that made the architecture practical and accessible. Hamilton et al.'s GraphSAGE (2017) made inductive learning on large graphs tractable. Velickovic et al.'s GAT (2017) added attention. Since then: Graph Transformers (Kreuzer et al., 2021), heterogeneous GNNs (HAN, HGT), and graph foundation models are active frontiers.
Graph Representations
The Core Question: How Do You Store a Graph?
A graph $G = (V, E)$ has vertices $V$ and edges $E$. For a graph with $n = |V|$ nodes and $m = |E|$ edges, your choice of representation determines the time and space complexity of every algorithm you run.
Adjacency Matrix: An $n \times n$ matrix $A$ where $A_{ij} = 1$ if edge $(i, j) \in E$. Space: $O(n^2)$. Edge lookup: $O(1)$. Iterating over neighbors: $O(n)$. Works well for dense graphs. For ML, the adjacency matrix is the natural input to spectral GNNs: the graph Laplacian $L = D - A$ (where $D$ is the degree matrix) encodes graph structure in a form amenable to eigendecomposition.
Adjacency List: Each node stores a list of its neighbors. Space: $O(n + m)$. Edge lookup: $O(\deg(v))$. Iterating over neighbors: $O(\deg(v))$. Works well for sparse graphs. Most real-world graphs are sparse: a social network with 1 billion users averages maybe 100-500 friends per person, giving $m$ on the order of $10^{11}$ rather than $n^2 = 10^{18}$.
Sparse Matrix Formats - CSR and CSC:
For large graphs in ML systems, you need compressed sparse formats. The Compressed Sparse Row (CSR) format stores three arrays:
- data: non-zero values (edge weights, or just 1s for unweighted graphs)
- indices: column indices of the non-zero values
- indptr: pointer array of length $n + 1$, where indptr[i] is the start of row $i$'s entries
To find the neighbors of node $i$: indices[indptr[i]:indptr[i+1]]. Space: $O(n + m)$. This is the layout behind scipy.sparse.csr_matrix. PyTorch Geometric instead stores graphs in the COO (coordinate list) format as edge_index - a [2, m] tensor of (source, destination) pairs.
import numpy as np
from scipy.sparse import csr_matrix
# Build a simple graph: 0->1, 0->2, 1->3, 2->3
rows = [0, 0, 1, 2]
cols = [1, 2, 3, 3]
data_vals = [1, 1, 1, 1]
n_nodes = 4
# CSR format
adj = csr_matrix((data_vals, (rows, cols)), shape=(n_nodes, n_nodes))
# Neighbors of node 0 - O(deg) time:
node = 0
neighbors = adj.indices[adj.indptr[node]:adj.indptr[node+1]]
print(f"Neighbors of node 0: {neighbors}") # [1 2]
# COO format (PyG's edge_index)
import torch
edge_index = torch.tensor([rows, cols], dtype=torch.long)
print(f"edge_index shape: {edge_index.shape}") # [2, 4]
# Convert dense -> sparse for large graphs
dense_adj = np.array([
[0, 1, 1, 0],
[0, 0, 0, 1],
[0, 0, 0, 1],
[0, 0, 0, 0],
])
sparse_adj = csr_matrix(dense_adj)
print(f"Density: {sparse_adj.nnz / (n_nodes ** 2):.2%}") # 25.00%
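The prefix-sum construction behind indptr is hidden when scipy builds the arrays for you. As a minimal pure-Python sketch (edges_to_csr and csr_neighbors are our own illustrative names, not scipy functions), the same three steps - count out-degrees, prefix-sum them into indptr, then place destinations into indices - rebuild the arrays for the example graph above:

```python
from typing import List, Tuple

def edges_to_csr(n_nodes: int,
                 edges: List[Tuple[int, int]]) -> Tuple[List[int], List[int]]:
    """Build CSR (indptr, indices) for an unweighted directed graph."""
    # 1) Count the out-degree of every source node.
    counts = [0] * n_nodes
    for src, _ in edges:
        counts[src] += 1
    # 2) Prefix-sum the counts: row i occupies indices[indptr[i]:indptr[i+1]].
    indptr = [0] * (n_nodes + 1)
    for i in range(n_nodes):
        indptr[i + 1] = indptr[i] + counts[i]
    # 3) Place each destination into its row using a per-row write cursor.
    cursor = indptr[:-1]  # list slicing makes a copy
    indices = [0] * len(edges)
    for src, dst in edges:
        indices[cursor[src]] = dst
        cursor[src] += 1
    return indptr, indices

def csr_neighbors(indptr: List[int], indices: List[int], node: int) -> List[int]:
    """O(deg) neighbor lookup, the same slice as with scipy's arrays."""
    return indices[indptr[node]:indptr[node + 1]]

indptr, indices = edges_to_csr(4, [(0, 1), (0, 2), (1, 3), (2, 3)])
print(indptr)                             # [0, 2, 3, 4, 4]
print(indices)                            # [1, 2, 3, 3]
print(csr_neighbors(indptr, indices, 0))  # [1, 2]
```

Node 3 has no out-edges, so indptr[3] == indptr[4] and its slice is empty.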
BFS and DFS
Breadth-First Search
BFS explores a graph layer by layer - first all nodes at distance 1, then distance 2, and so on. It finds shortest paths in unweighted graphs and is the basis for computing node neighborhoods in GNN preprocessing.
Time complexity: $O(|V| + |E|)$. Space: $O(|V|)$.
from collections import deque
from typing import Dict, List, Set
def bfs(graph: Dict[int, List[int]], start: int) -> Dict[int, int]:
    """
    BFS returning distances from start node.
    graph: adjacency list {node: [neighbors]}
    Returns: {node: distance}
    """
    distances = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for neighbor in graph.get(node, []):
            if neighbor not in distances:
                distances[neighbor] = distances[node] + 1
                queue.append(neighbor)
    return distances

def bfs_k_hop_neighborhood(graph: Dict[int, List[int]],
                           start: int, k: int) -> Set[int]:
    """
    Find all nodes within k hops of start.
    Used in GNN subgraph sampling - defines the receptive field
    of a k-layer GNN centered on node 'start'.
    """
    visited = {start}
    current_layer = {start}
    for hop in range(k):
        next_layer = set()
        for node in current_layer:
            for neighbor in graph.get(node, []):
                if neighbor not in visited:
                    next_layer.add(neighbor)
                    visited.add(neighbor)
        current_layer = next_layer
        if not current_layer:
            break
    return visited
# Example: 2-hop neighborhood for GNN message passing
graph = {
0: [1, 2],
1: [0, 3, 4],
2: [0, 4],
3: [1],
4: [1, 2]
}
neighborhood = bfs_k_hop_neighborhood(graph, start=0, k=2)
print(f"2-hop neighborhood of node 0: {neighborhood}")
# {0, 1, 2, 3, 4} - the full graph, since every node is within 2 hops of node 0
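The functions above return distances and neighborhoods, not the paths themselves. Recording a parent pointer when each node is first discovered is enough to recover a shortest path. A minimal sketch (bfs_shortest_path is an illustrative helper, not a library function), run on the same 5-node graph:

```python
from collections import deque
from typing import Dict, List, Optional

def bfs_shortest_path(graph: Dict[int, List[int]],
                      start: int, goal: int) -> Optional[List[int]]:
    """Path with the fewest edges from start to goal, or None if unreachable."""
    parent: Dict[int, Optional[int]] = {start: None}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            # Walk parent pointers back to start, then reverse.
            path = []
            cur: Optional[int] = node
            while cur is not None:
                path.append(cur)
                cur = parent[cur]
            return path[::-1]
        for neighbor in graph.get(node, []):
            if neighbor not in parent:  # first discovery = shortest distance
                parent[neighbor] = node
                queue.append(neighbor)
    return None

graph = {0: [1, 2], 1: [0, 3, 4], 2: [0, 4], 3: [1], 4: [1, 2]}
print(bfs_shortest_path(graph, 3, 2))  # [3, 1, 0, 2]
```

Nodes 3 and 2 are three hops apart, so the graph's diameter is 3 even though every node lies within two hops of node 0.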
Depth-First Search and Topological Sort
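DFS is the backbone of topological sorting: ordering a DAG so that every node comes after everything it depends on. PyTorch's autograd needs exactly this kind of ordering. The forward pass runs each operation after its inputs, and the backward pass runs operations in the reverse order, so each operation has received its full gradient before it passes gradients on to its inputs.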
def dfs_topological_sort(graph: Dict) -> List:
    """
    Topological sort via DFS post-order.
    graph: DAG as adjacency list mapping each node to the nodes it
    depends on (an entry u: [v, ...] means "u depends on v").
    Returns nodes in topological order (dependencies before dependents).
    Autograd-style backward passes process operations in the reverse
    of this order: the loss tensor's grad_fn references a DAG of
    operations, and gradients flow from the loss back to the inputs.
    """
    visited = set()
    order = []

    def dfs(node):
        visited.add(node)
        for dep in graph.get(node, []):
            if dep not in visited:
                dfs(dep)
        # Post-order: a node is appended only AFTER all of its dependencies.
        # With "depends on" edges no reversal is needed; with edges pointing
        # from dependency to dependent you would reverse this list instead.
        order.append(node)

    for node in graph:
        if node not in visited:
            dfs(node)
    return order
# Computation graph: a = x*w, b = a+bias, loss = mse(b, y)
# Edges represent "depends on" relationships
computation_graph = {
'x': [],
'w': [],
'bias': [],
'y': [],
'a': ['x', 'w'], # a depends on x, w
'b': ['a', 'bias'], # b depends on a, bias
'loss': ['b', 'y'], # loss depends on b, y
}
order = dfs_topological_sort(computation_graph)
print("Execution order:", order)
# Forward (dependencies first): ['x', 'w', 'bias', 'y', 'a', 'b', 'loss']
# Backward (reversed): loss, b, a, y, bias, w, x
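DFS is not the only route to a topological order. Kahn's algorithm counts each node's unfinished dependencies and releases a node once that count reaches zero, which is closer in spirit to dependency-counting schedulers such as autograd engines (this is an illustrative sketch, not PyTorch's implementation). It uses the same "depends on" convention as computation_graph above and detects cycles for free:

```python
from collections import deque
from typing import Dict, List

def kahn_topological_sort(depends_on: Dict[str, List[str]]) -> List[str]:
    """Order nodes so every dependency precedes its dependents.
    depends_on[u] lists the nodes u depends on; every node must be a key."""
    # Number of dependencies each node is still waiting for.
    pending = {node: len(deps) for node, deps in depends_on.items()}
    # Reverse edges: which nodes are waiting on each node.
    dependents: Dict[str, List[str]] = {node: [] for node in depends_on}
    for node, deps in depends_on.items():
        for dep in deps:
            dependents[dep].append(node)
    ready = deque(node for node, count in pending.items() if count == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for child in dependents[node]:
            pending[child] -= 1
            if pending[child] == 0:  # all of child's inputs are done
                ready.append(child)
    if len(order) != len(depends_on):
        raise ValueError("cycle detected: no valid topological order")
    return order

computation_graph = {
    'x': [], 'w': [], 'bias': [], 'y': [],
    'a': ['x', 'w'], 'b': ['a', 'bias'], 'loss': ['b', 'y'],
}
print(kahn_topological_sort(computation_graph))
# ['x', 'w', 'bias', 'y', 'a', 'b', 'loss']
```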
Shortest Path Algorithms
Dijkstra's Algorithm
Dijkstra finds shortest paths from a single source in graphs with non-negative edge weights. In ML, this appears in knowledge graph path finding, routing in recommendation graphs, and computing graph-based features like "minimum relationship distance" between two entities.
Time complexity: $O((|V| + |E|) \log |V|)$ with a binary heap.
import heapq
from typing import Tuple, Optional
def dijkstra(
    graph: Dict[int, List[Tuple[int, float]]],
    source: int,
    target: Optional[int] = None
) -> Tuple[Dict[int, float], Dict[int, Optional[int]]]:
    """
    Dijkstra's algorithm for single-source shortest paths.
    graph: {node: [(neighbor, weight), ...]}
    Returns: (distances, predecessors) for path reconstruction
    ML use case: Find shortest path in knowledge graph,
    e.g., "Einstein -> bornIn -> Germany -> partOf -> Europe"
    Edge weights = 1 - confidence_score (lower = more confident).
    """
    distances: Dict[int, float] = {source: 0.0}
    predecessors: Dict[int, Optional[int]] = {source: None}
    heap = [(0.0, source)]
    while heap:
        dist_u, u = heapq.heappop(heap)
        if target is not None and u == target:
            break
        # Skip stale entries (we process each node once)
        if dist_u > distances.get(u, float('inf')):
            continue
        for v, weight in graph.get(u, []):
            new_dist = dist_u + weight
            if new_dist < distances.get(v, float('inf')):
                distances[v] = new_dist
                predecessors[v] = u
                heapq.heappush(heap, (new_dist, v))
    return distances, predecessors

def reconstruct_path(predecessors: Dict, source: int, target: int) -> List:
    """Reconstruct shortest path from predecessors dict ([] if unreachable)."""
    if target not in predecessors:
        return []
    path = []
    current: Optional[int] = target
    while current is not None:
        path.append(current)
        current = predecessors.get(current)
    return path[::-1]
# Knowledge graph: entities as nodes, relations as edges
# Nodes: 0=Einstein, 1=Ulm, 2=Berlin, 3=Germany, 4=Europe
# Edge weight = 1 - relation_confidence (lower = more confident path)
knowledge_graph = {
    0: [(1, 0.1), (2, 0.4)],  # Einstein -> Ulm (0.9 conf), Einstein -> Berlin (0.6 conf)
    1: [(3, 0.2)],            # Ulm -> Germany (0.8 conf)
    2: [(3, 0.1), (4, 0.5)],  # Berlin -> Germany (0.9 conf), Berlin -> Europe (0.5 conf)
    3: [(4, 0.3)],            # Germany -> Europe (0.7 conf)
    4: []
}
dists, preds = dijkstra(knowledge_graph, source=0, target=4)
path = reconstruct_path(preds, 0, 4)
print(f"Lowest-cost path: {path}")  # [0, 1, 3, 4]: Einstein -> Ulm -> Germany -> Europe
print(f"Path cost: {dists[4]:.2f}")  # 0.60
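Summing 1 - confidence is a heuristic. If edge confidences are treated as independent probabilities that multiply along a path, the most reliable chain maximizes their product, and the standard trick is to use -log(confidence) as the edge weight: minimizing a sum of -log terms maximizes the product. The two choices can disagree. A minimal sketch on a hypothetical 3-node graph (the helper best_path and the confidence values are illustrative assumptions):

```python
import heapq
import math
from typing import Callable, Dict, List, Tuple

def best_path(graph: Dict[int, List[Tuple[int, float]]], src: int, dst: int,
              cost: Callable[[float], float]) -> List[int]:
    """Dijkstra where each edge stores a raw confidence and weight = cost(conf)."""
    dist = {src: 0.0}
    prev = {src: None}
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if u == dst:
            break
        if d > dist[u]:
            continue  # stale heap entry
        for v, conf in graph.get(u, []):
            nd = d + cost(conf)
            if nd < dist.get(v, math.inf):
                dist[v], prev[v] = nd, u
                heapq.heappush(heap, (nd, v))
    if dst not in prev:
        return []  # unreachable
    path, cur = [], dst
    while cur is not None:
        path.append(cur)
        cur = prev[cur]
    return path[::-1]

# Direct edge with confidence 0.40 vs. a two-hop route with 0.69 per hop.
toy = {0: [(2, 0.40), (1, 0.69)], 1: [(2, 0.69)], 2: []}
additive = best_path(toy, 0, 2, lambda c: 1 - c)         # minimizes sum of (1 - c)
product = best_path(toy, 0, 2, lambda c: -math.log(c))   # maximizes product of c
print(additive)  # [0, 2]
print(product)   # [0, 1, 2]
```

The additive cost prefers the direct edge (0.60 vs 0.62), while the log cost prefers the two-hop chain, whose joint confidence 0.69 x 0.69 ≈ 0.476 beats 0.40.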
Floyd-Warshall for All-Pairs Shortest Paths
Floyd-Warshall finds all-pairs shortest paths in $O(|V|^3)$ time. Practical for small graphs like neural architecture search DAGs (hundreds of nodes) and for computing graph kernel features.
def floyd_warshall(n: int, edges: List[Tuple[int, int, float]]) -> np.ndarray:
    """
    All-pairs shortest paths. O(n^3) time, O(n^2) space.
    Used for: graph distance features, structural equivalence checks,
    graph kernel computation in kernel-based graph classification.
    """
    INF = float('inf')
    dist = [[INF] * n for _ in range(n)]
    for i in range(n):
        dist[i][i] = 0
    for u, v, w in edges:
        dist[u][v] = min(dist[u][v], w)  # keep the cheapest of parallel edges
    for k in range(n):  # intermediate node
        for i in range(n):  # source
            for j in range(n):  # destination
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
    return np.array(dist)
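The function above returns distances only. Tracking a next-hop matrix alongside dist recovers the actual route between any pair. A minimal pure-Python sketch (floyd_warshall_with_paths and fw_path are illustrative names):

```python
import math
from typing import List, Optional, Tuple

def floyd_warshall_with_paths(n: int, edges: List[Tuple[int, int, float]]):
    """Returns (dist, nxt); nxt[i][j] is the first hop on a shortest i -> j path."""
    dist = [[math.inf] * n for _ in range(n)]
    nxt: List[List[Optional[int]]] = [[None] * n for _ in range(n)]
    for i in range(n):
        dist[i][i] = 0.0
        nxt[i][i] = i
    for u, v, w in edges:
        if w < dist[u][v]:
            dist[u][v] = w
            nxt[u][v] = v
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if dist[i][k] + dist[k][j] < dist[i][j]:
                    dist[i][j] = dist[i][k] + dist[k][j]
                    nxt[i][j] = nxt[i][k]  # go toward k first
    return dist, nxt

def fw_path(nxt: List[List[Optional[int]]], i: int, j: int) -> List[int]:
    if nxt[i][j] is None:
        return []  # j is unreachable from i
    path = [i]
    while i != j:
        i = nxt[i][j]
        path.append(i)
    return path

edges = [(0, 1, 1.0), (1, 2, 2.0), (0, 2, 5.0), (2, 3, 1.0)]
dist, nxt = floyd_warshall_with_paths(4, edges)
print(dist[0][3], fw_path(nxt, 0, 3))  # 4.0 [0, 1, 2, 3]
```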
PageRank
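PageRank models a random web surfer who, with probability $d$ (the damping factor, typically 0.85), follows a random outgoing link, and with probability $1 - d$ teleports to a uniformly random page. The stationary distribution of this Markov chain is the PageRank vector.
The recurrence, for $N$ pages, where $B_u$ is the set of pages linking to $u$ and $L(v)$ is the out-degree of $v$:
$$PR(u) = \frac{1 - d}{N} + d \sum_{v \in B_u} \frac{PR(v)}{L(v)}$$
In matrix form:
$$\mathbf{r} = d\,M\mathbf{r} + \frac{1 - d}{N}\mathbf{1}$$
where $M$ is the column-stochastic transition matrix ($M_{uv} = 1/L(v)$ if $v$ links to $u$). We solve via power iteration: repeatedly apply the update until $\mathbf{r}$ stops changing.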
def pagerank(
    adj_matrix: np.ndarray,
    damping: float = 0.85,
    max_iter: int = 100,
    tol: float = 1e-6
) -> np.ndarray:
    """
    PageRank via power iteration.
    adj_matrix[i,j] = 1 if there is a link from j to i.
    ML applications:
    - Node importance in citation graphs
    - Authority scores in knowledge graphs
    - Ranking nodes in GNN preprocessing (priority sampling)
    - TextRank: word/sentence importance for extractive summarization
    - HITS algorithm variant for hub/authority detection
    """
    n = adj_matrix.shape[0]
    # Column-stochastic: normalize columns so each sums to 1
    col_sums = adj_matrix.sum(axis=0)
    # Handle dangling nodes (out-degree 0): distribute uniformly
    dangling_mask = (col_sums == 0)
    col_sums[dangling_mask] = 1
    M = adj_matrix / col_sums[np.newaxis, :]  # column-stochastic transition
    M[:, dangling_mask] = 1.0 / n  # dangling nodes -> uniform jump
    # Power iteration
    r = np.ones(n) / n
    for iteration in range(max_iter):
        r_new = damping * (M @ r) + (1 - damping) / n
        delta = np.abs(r_new - r).sum()
        r = r_new
        if delta < tol:
            print(f"Converged at iteration {iteration + 1}, delta={delta:.2e}")
            break
    return r
# Test on a small web graph
adj = np.array([
[0, 0, 1, 1], # node 0 receives links from 2, 3
[1, 0, 0, 0], # node 1 receives link from 0
[1, 1, 0, 0], # node 2 receives links from 0, 1
[0, 1, 1, 0], # node 3 receives links from 1, 2
])
scores = pagerank(adj)
print("PageRank scores:", np.round(scores, 4))
ranked = np.argsort(scores)[::-1]
print(f"Node ranking (most important first): {ranked}")
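The dense version materializes an n x n matrix, which the production notes below warn against for large graphs. The same power iteration can run directly over out-link lists at $O(|V| + |E|)$ per iteration. A minimal pure-Python sketch (pagerank_sparse is an illustrative helper; it assumes every link target also appears as a key), using the same 4-node web graph written as out-links:

```python
from typing import Dict, List

def pagerank_sparse(out_links: Dict[int, List[int]], damping: float = 0.85,
                    max_iter: int = 100, tol: float = 1e-6) -> Dict[int, float]:
    nodes = list(out_links)
    n = len(nodes)
    r = {v: 1.0 / n for v in nodes}
    for _ in range(max_iter):
        # Rank held by dangling nodes (no out-links) is spread uniformly,
        # matching the uniform-jump columns in the dense version.
        dangling = sum(r[v] for v in nodes if not out_links[v])
        base = (1 - damping) / n + damping * dangling / n
        new = {v: base for v in nodes}
        for v in nodes:
            targets = out_links[v]
            if targets:
                share = damping * r[v] / len(targets)
                for u in targets:
                    new[u] += share
        delta = sum(abs(new[v] - r[v]) for v in nodes)  # L1 change, as above
        r = new
        if delta < tol:
            break
    return r

# Same graph as the dense example: 0->1, 0->2, 1->2, 1->3, 2->0, 2->3, 3->0
web = {0: [1, 2], 1: [2, 3], 2: [0, 3], 3: [0]}
scores = pagerank_sparse(web)
ranking = sorted(scores, key=scores.get, reverse=True)
print(ranking)  # [0, 2, 3, 1]
```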
Minimum Spanning Tree
Minimum spanning tree algorithms find the cheapest subgraph that connects every node. In ML they are used for hierarchical clustering (the $k$ single-linkage clusters are exactly the components left after deleting the $k - 1$ heaviest MST edges), graph coarsening, and feature extraction.
Kruskal's algorithm: Sort edges by weight, greedily add edges that don't create cycles. Uses Union-Find for cycle detection. Time: $O(|E| \log |E|)$.
Prim's algorithm: Start from any node, greedily grow the tree by adding the cheapest edge connecting the tree to a new node. Time: $O(|E| \log |V|)$ with a binary heap.
class UnionFind:
    """Union-Find (Disjoint Set Union) for Kruskal's MST."""
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # already connected - adding edge creates cycle
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True

def kruskal_mst(n: int, edges: List[Tuple[float, int, int]]) -> List:
    """
    Kruskal's MST algorithm.
    edges: [(weight, u, v), ...]
    Returns list of MST edges.
    ML use: single-linkage hierarchical clustering, graph coarsening
    """
    edges_sorted = sorted(edges)
    uf = UnionFind(n)
    mst = []
    for weight, u, v in edges_sorted:
        if uf.union(u, v):  # no cycle created
            mst.append((weight, u, v))
            if len(mst) == n - 1:
                break  # MST complete
    return mst
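Prim's algorithm, described above but not implemented, grows a single tree from one node using a heap of candidate edges. A minimal lazy-deletion sketch (prim_mst is an illustrative name; it assumes a connected, undirected graph and takes edges in the same (weight, u, v) form as kruskal_mst):

```python
import heapq
from typing import Dict, List, Tuple

def prim_mst(n: int, edges: List[Tuple[float, int, int]]) -> List[Tuple[float, int, int]]:
    # Undirected adjacency list: node -> [(weight, neighbor), ...]
    adj: Dict[int, List[Tuple[float, int]]] = {i: [] for i in range(n)}
    for w, u, v in edges:
        adj[u].append((w, v))
        adj[v].append((w, u))
    in_tree = [False] * n
    in_tree[0] = True  # start the tree from node 0
    heap = [(w, 0, v) for w, v in adj[0]]
    heapq.heapify(heap)
    mst = []
    while heap and len(mst) < n - 1:
        w, u, v = heapq.heappop(heap)
        if in_tree[v]:
            continue  # v already joined the tree: this edge would form a cycle
        in_tree[v] = True
        mst.append((w, u, v))
        for w2, x in adj[v]:
            if not in_tree[x]:
                heapq.heappush(heap, (w2, v, x))
    return mst

edges = [(1.0, 0, 1), (4.0, 0, 2), (2.0, 1, 2), (5.0, 1, 3), (3.0, 2, 3)]
tree = prim_mst(4, edges)
print(tree)                       # [(1.0, 0, 1), (2.0, 1, 2), (3.0, 2, 3)]
print(sum(w for w, _, _ in tree))  # 6.0
```

On a connected graph Prim and Kruskal return spanning trees of equal total weight; the trees themselves can differ when weights tie.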
GNN Message Passing Framework
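The message passing neural network (MPNN) framework unifies virtually all GNN architectures. At layer $l$, each node $v$ updates its representation through three operations - message, aggregate, and update:
$$\mathbf{m}_{u \to v}^{(l)} = \text{MESSAGE}^{(l)}\big(\mathbf{h}_v^{(l)}, \mathbf{h}_u^{(l)}, \mathbf{e}_{uv}\big)$$
$$\mathbf{m}_v^{(l)} = \text{AGGREGATE}^{(l)}\big(\{\mathbf{m}_{u \to v}^{(l)} : u \in \mathcal{N}(v)\}\big)$$
$$\mathbf{h}_v^{(l+1)} = \text{UPDATE}^{(l)}\big(\mathbf{h}_v^{(l)}, \mathbf{m}_v^{(l)}\big)$$
where $\mathcal{N}(v)$ is the neighbor set of $v$. After $L$ layers, each node's representation captures information from its $L$-hop neighborhood.
The READOUT operation produces a graph-level representation for graph classification or regression:
$$\mathbf{h}_G = \text{READOUT}\big(\{\mathbf{h}_v^{(L)} : v \in V\}\big)$$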
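The three operations map directly onto three plug-in functions. This toy sketch uses scalar node features and hand-picked rather than learned functions (message = neighbor's feature, aggregate = sum, update = self + aggregate; all names are illustrative). On the path graph 0-1-2-3 it shows the $L$-hop receptive field: after two layers, the signal placed on node 0 has reached node 2 but not node 3.

```python
from typing import Callable, Dict, List

Graph = Dict[int, List[int]]

def mpnn_layer(graph: Graph, h: Dict[int, float],
               message: Callable[[float, float], float],
               aggregate: Callable[[List[float]], float],
               update: Callable[[float, float], float]) -> Dict[int, float]:
    """One round of message passing: every node reads only its neighbors."""
    new_h = {}
    for v in graph:
        msgs = [message(h[v], h[u]) for u in graph[v]]
        new_h[v] = update(h[v], aggregate(msgs))
    return new_h

def readout(h: Dict[int, float]) -> float:
    return sum(h.values())  # sum pooling over all nodes

path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
h = {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0}  # signal only on node 0
for _ in range(2):  # L = 2 layers
    h = mpnn_layer(path, h,
                   message=lambda h_v, h_u: h_u,
                   aggregate=sum,
                   update=lambda h_v, m: h_v + m)
print(h)           # {0: 2.0, 1: 2.0, 2: 1.0, 3: 0.0}
print(readout(h))  # 5.0
```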
Graph Convolutional Networks (GCN)
Kipf and Welling (2016) derived GCN from spectral graph theory. The key insight: graph convolution in the spectral domain simplifies, via a first-order Chebyshev polynomial approximation, to a simple neighborhood aggregation rule.
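The GCN layer:
$$H^{(l+1)} = \sigma\big(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\big)$$
where $\tilde{A} = A + I$ (self-loops), $\tilde{D}$ is the degree matrix of $\tilde{A}$, and $W^{(l)}$ is a learnable weight matrix.
The term $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ is symmetric normalization - it scales the message from $u$ to $v$ by $1/\sqrt{\tilde{d}_u \tilde{d}_v}$, preventing high-degree hub nodes from dominating.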
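To see the normalization numerically, take a star graph: hub 0 with leaves 1, 2 and 3. After adding self-loops the hub has degree 4 and each leaf degree 2, so a hub-leaf message is scaled by 1/sqrt(4 x 2) ≈ 0.354. A minimal pure-Python sketch (gcn_edge_weights and aggregate are illustrative helpers, not PyG functions) computes these weights and aggregates an all-ones feature:

```python
import math
from typing import Dict, List, Set, Tuple

def gcn_edge_weights(n: int,
                     undirected_edges: List[Tuple[int, int]]) -> Dict[Tuple[int, int], float]:
    """Non-zero entries of D~^{-1/2} (A + I) D~^{-1/2}, keyed by (v, u)."""
    neighbors: Dict[int, Set[int]] = {i: {i} for i in range(n)}  # A + I: self-loops
    for u, v in undirected_edges:
        neighbors[u].add(v)
        neighbors[v].add(u)
    deg = {i: len(neighbors[i]) for i in range(n)}  # degrees in A + I
    return {(v, u): 1.0 / math.sqrt(deg[v] * deg[u])
            for v in range(n) for u in neighbors[v]}

def aggregate(weights: Dict[Tuple[int, int], float], x: List[float]) -> List[float]:
    """out[v] = sum over u in N(v) + {v} of weight(v, u) * x[u]."""
    out = [0.0] * len(x)
    for (v, u), w in weights.items():
        out[v] += w * x[u]
    return out

weights = gcn_edge_weights(4, [(0, 1), (0, 2), (0, 3)])
print(round(weights[(0, 1)], 4))  # 0.3536
out = aggregate(weights, [1.0, 1.0, 1.0, 1.0])
print([round(v, 4) for v in out])  # [1.3107, 0.8536, 0.8536, 0.8536]
```

Without normalization the same aggregation would give 4 for the hub and 2 for each leaf, growing linearly with degree.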
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
    """
    Graph Convolutional Network layer from scratch.
    Kipf & Welling, ICLR 2017.
    H_out = sigma( D_tilde^{-1/2} A_tilde D_tilde^{-1/2} H W )
    This layer computes the expression inside sigma; the caller applies sigma.
    """
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.W = nn.Linear(in_features, out_features, bias=True)
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.zeros_(self.W.bias)

    def normalize_adjacency(self, A: torch.Tensor) -> torch.Tensor:
        """Compute D_tilde^{-1/2} (A + I) D_tilde^{-1/2}."""
        n = A.shape[0]
        A_tilde = A + torch.eye(n, device=A.device)
        D_tilde = A_tilde.sum(dim=1)
        D_inv_sqrt = torch.diag(D_tilde.pow(-0.5))
        return D_inv_sqrt @ A_tilde @ D_inv_sqrt

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        H: [n_nodes, in_features]
        A: [n_nodes, n_nodes] binary adjacency matrix
        Returns: [n_nodes, out_features]
        """
        A_norm = self.normalize_adjacency(A)
        H_agg = A_norm @ H  # aggregate: weighted sum of neighbor features
        return self.W(H_agg)  # transform: linear projection

class GCN(nn.Module):
    """Two-layer GCN for node classification."""
    def __init__(self, n_features: int, n_hidden: int, n_classes: int,
                 dropout: float = 0.5):
        super().__init__()
        self.conv1 = GCNLayer(n_features, n_hidden)
        self.conv2 = GCNLayer(n_hidden, n_classes)
        self.dropout = dropout

    def forward(self, X: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        H = F.relu(self.conv1(X, A))
        H = F.dropout(H, p=self.dropout, training=self.training)
        H = self.conv2(H, A)
        return F.log_softmax(H, dim=1)
# Quick test
torch.manual_seed(42)
n_nodes, n_features, n_classes = 100, 16, 7
A = torch.randint(0, 2, (n_nodes, n_nodes)).float()
A = ((A + A.T) > 0).float()  # symmetrize -> undirected
A.fill_diagonal_(0)  # drop random self-loops; GCNLayer adds them back
X = torch.randn(n_nodes, n_features)
labels = torch.randint(0, n_classes, (n_nodes,))
model = GCN(n_features=n_features, n_hidden=64, n_classes=n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    logits = model(X, A)
    loss = F.nll_loss(logits, labels)
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        acc = (logits.argmax(dim=1) == labels).float().mean()
        print(f"Epoch {epoch:3d}: loss={loss.item():.4f}, acc={acc.item():.4f}")
PyTorch Geometric - Cora Dataset
In production, you use PyG rather than implementing sparse message passing from scratch. PyG stores graphs as a sparse COO edge_index, provides mini-batch neighbor and cluster sampling, and runs message passing with efficient scatter kernels on CUDA.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
# Cora: 2708 papers, 5429 citation links, 7 topic classes
dataset = Planetoid(root='/tmp/Cora', name='Cora',
transform=NormalizeFeatures())
data = dataset[0]
print(f"Nodes: {data.num_nodes}") # 2708
print(f"Edges: {data.num_edges}") # 10556 (bidirected)
print(f"Features per node: {data.num_node_features}") # 1433 (bag-of-words)
print(f"Classes: {dataset.num_classes}") # 7
class PygGCN(torch.nn.Module):
    def __init__(self, in_ch, hidden_ch, out_ch):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PygGCN(dataset.num_node_features, 64, dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()  # evaluation needs no autograd graph
def test():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    return [(pred[m] == data.y[m]).float().mean().item()
            for m in [data.train_mask, data.val_mask, data.test_mask]]

for epoch in range(1, 201):
    loss = train()
    if epoch % 50 == 0:
        tr, val, te = test()
        print(f"Epoch {epoch}: loss={loss:.4f}, "
              f"train={tr:.4f}, val={val:.4f}, test={te:.4f}")
# Expected: roughly 81% test accuracy (varies slightly with the random seed)
GraphSAGE - Inductive Learning
GCN is transductive: it requires the full graph at training time and cannot embed nodes not seen during training. GraphSAGE (Hamilton et al., 2017) fixes this by learning an aggregation function that generalizes to unseen nodes.
Key differences from GCN:
- Samples a fixed-size neighborhood rather than using all neighbors (controls computation cost)
- Concatenates current node representation with aggregated neighbor representation
- The model learns a function over features and structure that generalizes inductively
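Fixed-size sampling is what keeps the cost bounded. A minimal sketch of one simple variant - uniform sampling without replacement, capped at a per-hop fan-out - is shown below on a hub node with 100 neighbors (sample_neighborhood is an illustrative helper, not GraphSAGE's reference code or PyG's NeighborLoader):

```python
import random
from typing import Dict, List, Set

def sample_neighborhood(graph: Dict[int, List[int]], seeds: List[int],
                        fanouts: List[int], rng: random.Random) -> List[Set[int]]:
    """Return the sampled node set per hop: [seeds, hop-1 nodes, hop-2 nodes, ...]."""
    layers = [set(seeds)]
    frontier = set(seeds)
    for fanout in fanouts:
        next_frontier: Set[int] = set()
        for node in frontier:
            neighbors = graph.get(node, [])
            if len(neighbors) <= fanout:
                next_frontier.update(neighbors)  # low degree: keep every neighbor
            else:
                next_frontier.update(rng.sample(neighbors, fanout))
        layers.append(next_frontier)
        frontier = next_frontier
    return layers

# Hub node 0 connected to 100 leaves; each leaf links back only to the hub.
hub_graph = {0: list(range(1, 101))}
hub_graph.update({leaf: [0] for leaf in range(1, 101)})
layers = sample_neighborhood(hub_graph, seeds=[0], fanouts=[5, 3],
                             rng=random.Random(0))
print([len(layer) for layer in layers])  # [1, 5, 1]
```

However high the hub's degree, hop 1 touches at most 5 nodes; with fan-outs [25, 10] the per-seed bound becomes 1 + 25 + 250 nodes.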
class SAGELayer(nn.Module):
    """
    GraphSAGE layer with mean aggregation.
    Hamilton et al., NeurIPS 2017.
    Inductive: learns to aggregate neighbor features.
    Generalizes to unseen nodes by running the same aggregation function
    on their (sampled) neighborhoods.
    """
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Input dim = 2 * in_features because we concatenate self + neighbors
        self.W = nn.Linear(in_features * 2, out_features)

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        H: [n_nodes, in_features]
        A: [n_nodes, n_nodes] adjacency (raw, not normalized)
        """
        # Mean aggregation: average neighbor features
        degree = A.sum(dim=1, keepdim=True).clamp(min=1)
        H_neighbors = (A @ H) / degree  # [n, in_features]
        # Concatenate self + aggregated neighbors, then transform
        H_concat = torch.cat([H, H_neighbors], dim=1)  # [n, 2*in_features]
        out = self.W(H_concat)
        # L2 normalize output (standard in GraphSAGE)
        return F.normalize(out, p=2, dim=1)
Graph Attention Networks (GAT)
GAT (Velickovic et al., 2017) replaces GCN's fixed symmetric normalization with learned attention. Different neighbors contribute different amounts based on feature compatibility.
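The attention coefficient between node $i$ and neighbor $j$ (where $\mathcal{N}(i)$ includes $i$ itself):
$$e_{ij} = \text{LeakyReLU}\big(\mathbf{a}^\top [W\mathbf{h}_i \,\|\, W\mathbf{h}_j]\big), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$
The update:
$$\mathbf{h}_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\Big)$$
Multi-head attention runs $K$ independent attention functions and concatenates (or averages) their outputs, improving stability and representational capacity.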
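The two equations can be checked by hand for a single node and head. In this sketch the projected features Wh and the attention vector a = [a_src || a_dst] are made-up illustrative values, and the helper names are our own:

```python
import math
from typing import Dict, List

def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x

def dot(p: List[float], q: List[float]) -> float:
    return sum(x * y for x, y in zip(p, q))

def gat_attention(i: int, neighborhood: List[int], Wh: Dict[int, List[float]],
                  a_src: List[float], a_dst: List[float]) -> Dict[int, float]:
    """alpha_ij = softmax over j of LeakyReLU(a_src . Wh_i + a_dst . Wh_j)."""
    e = {j: leaky_relu(dot(a_src, Wh[i]) + dot(a_dst, Wh[j])) for j in neighborhood}
    m = max(e.values())  # subtract the max for a numerically stable softmax
    exp_e = {j: math.exp(v - m) for j, v in e.items()}
    total = sum(exp_e.values())
    return {j: v / total for j, v in exp_e.items()}

Wh = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
# Node 0 attends over itself and its neighbors 1 and 2.
alpha = gat_attention(0, [0, 1, 2], Wh, a_src=[0.5, 0.5], a_dst=[1.0, -1.0])
print({j: round(w, 3) for j, w in alpha.items()})  # {0: 0.637, 1: 0.129, 2: 0.234}
# Updated node 0 before the nonlinearity: attention-weighted sum of Wh_j
h0 = [sum(alpha[j] * Wh[j][k] for j in alpha) for k in range(2)]
```

The logits here are 1.5, -0.1 (the LeakyReLU slope applied to -0.5) and 0.5, so node 0 puts most of its weight on itself.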
class GATLayer(nn.Module):
    """
    Graph Attention Network layer.
    Velickovic et al., ICLR 2018.
    Multi-head attention with concat or average.
    """
    def __init__(self, in_features: int, out_features: int,
                 n_heads: int = 4, dropout: float = 0.6,
                 concat: bool = True):
        super().__init__()
        self.n_heads = n_heads
        self.out_features = out_features
        self.concat = concat
        # Shared linear transform applied before attention
        self.W = nn.Linear(in_features, out_features * n_heads, bias=False)
        # Attention vector: one per head, size 2*out_features (src + dst)
        self.a = nn.Parameter(torch.empty(n_heads, 2 * out_features))
        nn.init.xavier_uniform_(self.a)
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        """
        H: [n_nodes, in_features]
        A: [n_nodes, n_nodes] adjacency
        Returns [n_nodes, n_heads * out_features] if concat
                [n_nodes, out_features] if average
        """
        n = H.shape[0]
        Wh = self.W(H).view(n, self.n_heads, self.out_features)  # [n, H, F]
        # Every node also attends to itself (self-loops); this also prevents
        # an all -inf softmax row (NaNs) for isolated nodes.
        no_edge = (A + torch.eye(n, device=A.device)) == 0
        head_outputs = []
        for h_idx in range(self.n_heads):
            Wh_h = Wh[:, h_idx, :]  # [n, F]
            # Attention logits: e[i,j] = LReLU(a_src . Wh_i + a_dst . Wh_j)
            a_src = (Wh_h * self.a[h_idx, :self.out_features]).sum(dim=1)
            a_dst = (Wh_h * self.a[h_idx, self.out_features:]).sum(dim=1)
            e = self.leaky_relu(a_src.unsqueeze(1) + a_dst.unsqueeze(0))
            # Mask out non-edges (set to -inf before softmax)
            e = e.masked_fill(no_edge, float('-inf'))
            alpha = torch.softmax(e, dim=1)
            alpha = self.attn_drop(alpha)
            h_new = alpha @ Wh_h  # [n, F]
            head_outputs.append(h_new)
        if self.concat:
            out = torch.cat(head_outputs, dim=1)  # [n, n_heads * F]
        else:
            out = torch.stack(head_outputs).mean(0)  # [n, F]
        return F.elu(out)
GNN Architecture Comparison
Heterogeneous Graphs for Recommendation Systems
Real recommendation systems have multiple node types (users, items, categories, brands) and multiple edge types (clicked, purchased, viewed, belongs_to). This is a heterogeneous graph.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv
from torch_geometric.data import HeteroData
def build_recsys_hetero_graph() -> HeteroData:
    """
    Heterogeneous graph for e-commerce recommendation.
    Node types: user, item, category
    Edge types:
      user -[purchased]-> item
      user -[viewed]-> item
      item -[belongs_to]-> category
    plus a reverse edge type for each, so every node type receives messages.
    """
    data = HeteroData()
    n_users, n_items, n_cats = 1000, 5000, 50
    data['user'].x = torch.randn(n_users, 64)
    data['item'].x = torch.randn(n_items, 128)
    data['category'].x = torch.randn(n_cats, 32)
    data['user', 'purchased', 'item'].edge_index = torch.stack([
        torch.randint(0, n_users, (10000,)),
        torch.randint(0, n_items, (10000,))
    ])
    data['user', 'viewed', 'item'].edge_index = torch.stack([
        torch.randint(0, n_users, (50000,)),
        torch.randint(0, n_items, (50000,))
    ])
    data['item', 'belongs_to', 'category'].edge_index = torch.stack([
        torch.randint(0, n_items, (5000,)),
        torch.randint(0, n_cats, (5000,))
    ])
    # Reverse edges: without them 'user' has no incoming edge type and would
    # be missing from the HeteroConv output.
    data['item', 'rev_purchased', 'user'].edge_index = \
        data['user', 'purchased', 'item'].edge_index.flip(0)
    data['item', 'rev_viewed', 'user'].edge_index = \
        data['user', 'viewed', 'item'].edge_index.flip(0)
    data['category', 'rev_belongs_to', 'item'].edge_index = \
        data['item', 'belongs_to', 'category'].edge_index.flip(0)
    return data

class HeteroRecsysGNN(nn.Module):
    """
    Heterogeneous GNN: each edge type has its own SAGE convolution.
    Outputs per-node embeddings in a shared hidden space.
    """
    def __init__(self, hidden: int = 64):
        super().__init__()
        # Project all node types to shared dimension
        self.proj = nn.ModuleDict({
            'user': nn.Linear(64, hidden),
            'item': nn.Linear(128, hidden),
            'category': nn.Linear(32, hidden),
        })
        self.conv = HeteroConv({
            ('user', 'purchased', 'item'): SAGEConv(hidden, hidden),
            ('user', 'viewed', 'item'): SAGEConv(hidden, hidden),
            ('item', 'belongs_to', 'category'): SAGEConv(hidden, hidden),
            # Reverse relations carry messages back to users and items
            ('item', 'rev_purchased', 'user'): SAGEConv(hidden, hidden),
            ('item', 'rev_viewed', 'user'): SAGEConv(hidden, hidden),
            ('category', 'rev_belongs_to', 'item'): SAGEConv(hidden, hidden),
        }, aggr='sum')

    def forward(self, data: HeteroData):
        x_dict = {
            nt: F.relu(self.proj[nt](data[nt].x))
            for nt in ['user', 'item', 'category']
        }
        # HeteroConv only returns node types that receive at least one edge type
        x_dict = self.conv(x_dict, data.edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        return x_dict

# Score user-item affinity for ranking
def compute_scores(user_embs: torch.Tensor,
                   item_embs: torch.Tensor) -> torch.Tensor:
    """Dot product similarity for recommendation ranking."""
    return torch.matmul(user_embs, item_embs.T)
Graph Sampling for Large Graphs
Full-batch GNN training is intractable when graphs exceed GPU memory. Two sampling strategies widely used in production are neighbor sampling and cluster-based partitioning:
from torch_geometric.loader import NeighborLoader, ClusterData, ClusterLoader
def create_minibatch_loaders(data, batch_size: int = 1024):
    """
    NeighborLoader: GraphSAGE-style mini-batch training.
    num_neighbors=[25, 10] means:
      - Sample up to 25 neighbors for the 1-hop expansion
      - Sample up to 10 neighbors of each of those for the 2-hop expansion
    Nodes per target: at most 1 + 25 + 25 * 10 = 276
    (vs full graph with potentially millions of nodes)
    """
    train_loader = NeighborLoader(
        data,
        num_neighbors=[25, 10],
        batch_size=batch_size,
        input_nodes=data.train_mask,
        shuffle=True,
    )
    return train_loader

def create_cluster_loaders(data, num_parts: int = 1500):
    """
    ClusterLoader: Cluster-GCN style training.
    Partition graph into num_parts dense clusters using METIS.
    Each mini-batch = a few complete clusters.
    Key advantage: METIS keeps cross-cluster edges few, so the edges
    dropped between clusters in different batches carry little information.
    Used for ogbn-products (2.4M nodes), Reddit (232k nodes).
    """
    cluster_data = ClusterData(data, num_parts=num_parts, recursive=False)
    return ClusterLoader(cluster_data, batch_size=20, shuffle=True)
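Why partition quality matters: edges whose endpoints land in different clusters are the ones Cluster-GCN-style batching can drop. A minimal sketch (cut_edge_fraction is an illustrative helper; METIS itself is not reimplemented here) compares a community-aligned partition of two triangles joined by a bridge with an arbitrary alternating one:

```python
from typing import List, Tuple

def cut_edge_fraction(edges: List[Tuple[int, int]], part: List[int]) -> float:
    """Fraction of edges whose endpoints are assigned to different parts."""
    cut = sum(1 for u, v in edges if part[u] != part[v])
    return cut / len(edges)

# Two triangles {0,1,2} and {3,4,5} joined by the bridge (2, 3).
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
aligned = [0, 0, 0, 1, 1, 1]      # one cluster per triangle
alternating = [0, 1, 0, 1, 0, 1]  # ignores the community structure
print(round(cut_edge_fraction(edges, aligned), 3))      # 0.143 (1 of 7 edges cut)
print(round(cut_edge_fraction(edges, alternating), 3))  # 0.714 (5 of 7 edges cut)
```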
Production Engineering Notes
Sparse Message Passing
Dense matrix multiplication for message passing is catastrophically inefficient for sparse graphs. A million-node graph with average degree 50 is 99.995% zeros. Use scatter operations on edge lists instead:
def sparse_message_passing(
    x: torch.Tensor,  # [n_nodes, features]
    edge_index: torch.Tensor,  # [2, n_edges] COO format
    edge_weight: torch.Tensor,  # [n_edges]
    n_nodes: int
) -> torch.Tensor:
    """
    Efficient scatter-gather message passing.
    This gather-then-scatter pattern is what PyG's MessagePassing class
    performs internally. Cost is O(E * features) instead of the
    O(N^2 * features) of a dense matmul, so on sparse graphs it can be
    orders of magnitude faster.
    """
    src, dst = edge_index
    # Gather: collect source node features for each edge
    messages = x[src] * edge_weight.unsqueeze(-1)  # [n_edges, features]
    # Scatter: aggregate messages at destination nodes
    out = torch.zeros(n_nodes, x.shape[1], device=x.device, dtype=x.dtype)
    out.scatter_add_(0, dst.unsqueeze(-1).expand_as(messages), messages)
    return out
Oversmoothing Detection
Beyond a few layers, adding more GNN layers usually hurts. Node representations converge toward the same vector because the normalized adjacency acts as a low-pass filter: each layer averages a node with its neighbors, and repeated averaging erases the differences between nodes.
import torch
import torch.nn.functional as F

def mean_average_distance(embeddings: torch.Tensor,
                          edge_index: torch.Tensor) -> float:
    """
    Measure oversmoothing: compute average cosine distance between
    connected node pairs. Lower MAD = more smoothed = worse.
    Rough guide: a well-trained 2-layer GCN on Cora has MAD ~0.5;
    after 8 layers it drops below 0.1 (embeddings nearly identical).
    """
    src, dst = edge_index
    h_src = F.normalize(embeddings[src], dim=1)
    h_dst = F.normalize(embeddings[dst], dim=1)
    cosine_sim = (h_src * h_dst).sum(dim=1)
    return (1 - cosine_sim).mean().item()
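To see the low-pass effect in isolation, the sketch below repeatedly applies the normalized adjacency $\tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$ to random features on a toy 10-node ring and tracks a self-contained copy of the edge-wise MAD measure. It uses no learned weights or nonlinearities (a parameter-free propagation, not a trained GCN), and the ring, feature size, and depths are arbitrary assumptions; it does not reproduce the Cora numbers quoted above.

```python
import torch
import torch.nn.functional as F


def normalized_adjacency(edge_index, n_nodes):
    # Dense D^{-1/2} (A + I) D^{-1/2}: fine for a toy graph, never for a large one
    A = torch.zeros(n_nodes, n_nodes)
    A[edge_index[0], edge_index[1]] = 1.0
    A = A + torch.eye(n_nodes)
    deg_inv_sqrt = A.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt.unsqueeze(1) * A * deg_inv_sqrt.unsqueeze(0)


def edge_mad(h, edge_index):
    # Mean cosine distance between the endpoints of every edge
    src, dst = edge_index
    cos = (F.normalize(h[src], dim=1) * F.normalize(h[dst], dim=1)).sum(dim=1)
    return (1 - cos).mean().item()


# Undirected 10-node ring, both directions of every edge stored
n = 10
ring = torch.arange(n)
edge_index = torch.stack([torch.cat([ring, (ring + 1) % n]),
                          torch.cat([(ring + 1) % n, ring])])

torch.manual_seed(0)
h = torch.randn(n, 16)
A_hat = normalized_adjacency(edge_index, n)

mad_by_depth = {}
for k in range(1, 65):
    h = A_hat @ h  # one propagation step, no weights and no nonlinearity
    if k in (1, 2, 4, 8, 16, 32, 64):
        mad_by_depth[k] = edge_mad(h, edge_index)
print(mad_by_depth)
```

Every row of $\hat{A}^k X$ converges toward the same direction (scaled by $\sqrt{\tilde{d}_i}$), so the cosine distance along edges collapses toward zero as $k$ grows. Learned weights and nonlinearities slow this down but do not remove the averaging.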
The 2-4 Layer Rule for GNNs
Many standard GNN benchmarks peak at 2-layer networks. Stacking more layers usually degrades performance due to oversmoothing: repeated averaging drives all node representations toward the same vector. If you need long-range dependencies, use skip connections (ResGCN), jumping knowledge networks (e.g., concatenating all layer outputs), or switch to a Graph Transformer.
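A minimal sketch of the two cheapest remedies, a residual connection in every layer and jumping-knowledge concatenation before the classifier, in plain PyTorch. The class name `JKResGCN`, the dense `a_hat` propagation matrix, and the layer sizes are hypothetical choices for illustration; this is not the ResGCN or JK-Net reference implementation.

```python
import torch
import torch.nn as nn


class JKResGCN(nn.Module):
    """Hypothetical GCN-style stack with residual connections and jumping-knowledge concat."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)])
        # Jumping knowledge (concat mode): the classifier reads every layer's output
        self.classifier = nn.Linear(hidden_dim * num_layers, out_dim)

    def forward(self, x: torch.Tensor, a_hat: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.input_proj(x))
        layer_outputs = []
        for layer in self.layers:
            # Residual: keep h and add the propagated update A_hat (h W)
            h = h + torch.relu(a_hat @ layer(h))
            layer_outputs.append(h)
        return self.classifier(torch.cat(layer_outputs, dim=1))


torch.manual_seed(0)
n_nodes = 10
a_hat = torch.full((n_nodes, n_nodes), 1.0 / n_nodes)  # stand-in: uniform averaging
model = JKResGCN(in_dim=16, hidden_dim=32, out_dim=3, num_layers=8)
logits = model(torch.randn(n_nodes, 16), a_hat)
print(logits.shape)
```

Because the classifier sees every layer's output, an 8-layer stack can still fall back on the early, less-smoothed representations, and the residual path keeps per-node information alive even when `a_hat` averages everything.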
Never Use Dense Adjacency for Graphs Larger Than ~10k Nodes
The adjacency matrix for a 1M-node graph is 1M x 1M x 4 bytes = 4TB. Even for 10k nodes: 10k x 10k x 4 = 400MB just for the adjacency. Use sparse COO edge_index format (PyG default) or CSR for all production-scale graphs. Also, full-batch training on graphs with hundreds of thousands of nodes or more can easily exhaust GPU memory, because every node's activations at every layer are kept for backprop; use NeighborLoader or ClusterLoader instead.
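The same arithmetic, extended to the sparse formats, for the million-node, average-degree-50 graph from the sparse message passing note. Assumptions: float32 dense entries, int64 indices (PyG stores `edge_index` as `torch.long`), 50 stored directed edge entries per node, and no edge weights or features counted. The helper names are hypothetical.

```python
def dense_adjacency_bytes(n_nodes: int, bytes_per_entry: int = 4) -> int:
    # n x n float32 matrix
    return n_nodes * n_nodes * bytes_per_entry


def coo_edge_index_bytes(n_edges: int, bytes_per_index: int = 8) -> int:
    # [2, n_edges] int64 tensor (source row + destination row)
    return 2 * n_edges * bytes_per_index


def csr_bytes(n_nodes: int, n_edges: int, bytes_per_index: int = 8) -> int:
    # rowptr has n_nodes + 1 entries, col has n_edges entries
    return (n_nodes + 1) * bytes_per_index + n_edges * bytes_per_index


n, avg_degree = 1_000_000, 50
n_edges = n * avg_degree  # stored (directed) edge entries
print(f"dense: {dense_adjacency_bytes(n) / 1e12:.1f} TB")
print(f"COO  : {coo_edge_index_bytes(n_edges) / 1e9:.2f} GB")
print(f"CSR  : {csr_bytes(n, n_edges) / 1e9:.2f} GB")
```

This prints 4.0 TB for the dense matrix versus roughly 0.8 GB for COO and 0.4 GB for CSR. CSR saves space by replacing the per-edge source row with one offset per node.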
Interview Questions and Answers
Q1: Explain how GCN performs graph convolution. Why does it add self-loops and use symmetric normalization?
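GCN approximates spectral graph convolution using a first-order Chebyshev polynomial on the graph Laplacian. Each layer computes $H^{(l+1)} = \sigma\big(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} H^{(l)} W^{(l)}\big)$: the normalized adjacency $\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ aggregates neighbor features, and the learned weight matrix $W^{(l)}$ transforms the aggregated representation.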
Self-loops ($\tilde{A} = A + I$) are added so each node includes its own features during aggregation. Without self-loops, a node's layer-$(l+1)$ representation depends only on its neighbors' representations at layer $l$, throwing away the node's current representation. With self-loops, each node is its own neighbor, so its existing features persist.
Symmetric normalization prevents high-degree hub nodes from dominating. Without it, a hub node with 1000 neighbors sums 1000 neighbor vectors while a leaf sums only 1, so feature magnitudes scale with degree, which destabilizes training and gradient flow. Symmetric normalization scales the message from node $j$ to node $i$ by $1/\sqrt{\tilde{d}_i \tilde{d}_j}$, where $\tilde{d}_i$ is node $i$'s degree in $\tilde{A}$, balancing the aggregation regardless of degree.
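A sketch of one GCN layer computed edge by edge, assuming an undirected graph stored with both edge directions: add self-loops, give each edge the weight $1/\sqrt{\tilde{d}_i \tilde{d}_j}$, then aggregate with a scatter-add. The function names are hypothetical, and the nonlinearity $\sigma$ is omitted so the result can be compared directly with the dense formula.

```python
import torch


def gcn_norm_edges(edge_index, n_nodes):
    # Add self-loops: A_tilde = A + I
    loops = torch.arange(n_nodes)
    src = torch.cat([edge_index[0], loops])
    dst = torch.cat([edge_index[1], loops])
    # Degree in A_tilde (in-degree; equals out-degree for a symmetric graph)
    deg = torch.zeros(n_nodes).index_add_(0, dst, torch.ones(dst.numel()))
    deg_inv_sqrt = deg.pow(-0.5)
    # Per-edge weight 1 / sqrt(d_i * d_j)
    weight = deg_inv_sqrt[src] * deg_inv_sqrt[dst]
    return src, dst, weight


def gcn_layer(x, edge_index, W):
    # One GCN layer without the nonlinearity: A_hat @ x @ W
    n = x.shape[0]
    src, dst, weight = gcn_norm_edges(edge_index, n)
    h = x @ W
    out = torch.zeros(n, W.shape[1], dtype=x.dtype)
    return out.index_add_(0, dst, h[src] * weight.unsqueeze(-1))


# Star graph: hub 0 connected to leaves 1..4, both directions stored
edge_index = torch.tensor([[0, 0, 0, 0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0, 0, 0, 0]])
torch.manual_seed(0)
x = torch.randn(5, 4)
W = torch.randn(4, 2)
out = gcn_layer(x, edge_index, W)
_, _, weight = gcn_norm_edges(edge_index, 5)
print(weight)
```

On the star graph the hub has $\tilde{d} = 5$ and each leaf $\tilde{d} = 2$, so hub-leaf messages are scaled by $1/\sqrt{10} \approx 0.32$ and the hub's own self-loop by $1/5$: the hub no longer sums four unscaled vectors.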
Q2: What is the difference between transductive and inductive GNNs? Which production use cases require inductive learning?
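Transductive GNNs (like GCN in its original full-batch setup) require the full graph at training time, including test nodes (only their labels are hidden). They are trained and evaluated on one fixed graph, so they are not designed to embed nodes that appear after training. Shallow embedding methods such as DeepWalk and node2vec make this explicit: they learn a fixed vector per node and simply have no embedding for an unseen node. This is fine for static datasets like citation networks where the graph does not change.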
Inductive GNNs (GraphSAGE, GAT) learn an aggregation function over node features and neighborhood structure rather than a per-node embedding. Given any node's features and neighborhood, the model computes an embedding by running the learned aggregation, so it generalizes to completely new nodes.
Production use cases requiring inductive learning: (1) recommendation systems where new users and items appear daily - Pinterest PinSage is the canonical example; (2) fraud detection where new accounts appear constantly; (3) molecular property prediction where you train on known molecules and predict for new ones; (4) code analysis for new repositories not in the training set.
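To make "inductive" concrete, here is a minimal GraphSAGE-style layer with a mean aggregator, $h_v = \mathrm{ReLU}(W_{\text{self}} x_v + W_{\text{neigh}} \cdot \mathrm{mean}_{u \in \mathcal{N}(v)} x_u)$, in plain PyTorch. The class name and sizes are hypothetical and training is skipped; the point is that the same weights embed a node that did not exist when the layer was built, using only that node's features and its neighbors' features.

```python
import torch
import torch.nn as nn


class SAGEMeanLayer(nn.Module):
    """Hypothetical minimal GraphSAGE-mean layer (not PyG's SAGEConv)."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.w_self = nn.Linear(in_dim, out_dim)
        self.w_neigh = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        n = x.shape[0]
        # Mean of incoming neighbor features (nodes without neighbors get zeros)
        neigh_sum = torch.zeros_like(x).index_add_(0, dst, x[src])
        deg = torch.zeros(n, dtype=x.dtype).index_add_(0, dst, torch.ones(dst.numel(), dtype=x.dtype))
        neigh_mean = neigh_sum / deg.clamp(min=1).unsqueeze(-1)
        return torch.relu(self.w_self(x) + self.w_neigh(neigh_mean))


torch.manual_seed(0)
layer = SAGEMeanLayer(in_dim=6, out_dim=8)  # pretend these weights were already trained

# Original graph: undirected path 0-1-2-3
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
x = torch.randn(4, 6)

# A new node 4 arrives, linked to nodes 0 and 1; the weights stay unchanged
x_new = torch.cat([x, torch.randn(1, 6)])
new_edges = torch.tensor([[4, 0, 4, 1],
                          [0, 4, 1, 4]])
edge_index_new = torch.cat([edge_index, new_edges], dim=1)

with torch.no_grad():
    emb = layer(x_new, edge_index_new)
    # Node 4 embedded from its local neighborhood alone (relabelled as 0, 1, 2)
    local = layer(x_new[[4, 0, 1]], torch.tensor([[1, 2], [0, 0]]))
print(emb.shape)
```

The two computations agree for node 4, which is why inductive models can serve fresh users, items, or accounts by fetching only a small neighborhood at inference time.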
Q3: How does attention work in GAT and what advantage does it have over GCN's fixed normalization?
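GAT computes a scalar attention score $e_{ij}$ for each edge $(i, j)$. The score is the output of a learned single-layer feedforward network (parameterized by vector $\mathbf{a}$) applied to the concatenation of node $i$'s and node $j$'s transformed features: $e_{ij} = \mathrm{LeakyReLU}\big(\mathbf{a}^\top [W h_i \,\|\, W h_j]\big)$. Scores are normalized via softmax over each node's neighborhood: $\alpha_{ij} = \exp(e_{ij}) / \sum_{k \in \mathcal{N}(i)} \exp(e_{ik})$.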
Compared to GCN's fixed normalization: (1) attention weights are learned end-to-end, adapting to the task; (2) different neighbors contribute different amounts based on feature relevance, not just graph topology; (3) multi-head attention provides multiple "views" of the neighborhood, improving stability; (4) attention weights offer some interpretability: in citation graphs, GAT can learn to weight same-topic papers more heavily, and you can visualize which neighbors mattered for each prediction.
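A single-head sketch of the attention computation above in plain PyTorch, assuming edges are stored as `edge_index[0] = j` (source) and `edge_index[1] = i` (destination), with self-loops included so every node attends to itself. Names and shapes are illustrative; real GAT layers add multiple heads, dropout on $\alpha$, and an output nonlinearity.

```python
import torch
import torch.nn.functional as F


def gat_attention(x, edge_index, W, a, negative_slope=0.2):
    src, dst = edge_index  # message flows j = src -> i = dst
    h = x @ W              # transformed features W h
    # e_ij = LeakyReLU(a^T [W h_i || W h_j]) for every edge
    e = F.leaky_relu(torch.cat([h[dst], h[src]], dim=1) @ a, negative_slope)
    n = x.shape[0]
    # Softmax over each destination's incoming edges; subtract the per-node max for stability
    e_max = torch.full((n,), float("-inf")).scatter_reduce(0, dst, e, reduce="amax")
    exp_e = torch.exp(e - e_max[dst])
    denom = torch.zeros(n).index_add_(0, dst, exp_e)
    alpha = exp_e / denom[dst]
    # Aggregate: h_i' = sum_j alpha_ij W h_j
    out = torch.zeros(n, h.shape[1]).index_add_(0, dst, alpha.unsqueeze(-1) * h[src])
    return alpha, out


torch.manual_seed(0)
# Star 0-1, 0-2, 0-3 in both directions, plus a self-loop on every node
edge_index = torch.tensor([[0, 1, 0, 2, 0, 3, 0, 1, 2, 3],
                           [1, 0, 2, 0, 3, 0, 0, 1, 2, 3]])
x = torch.randn(4, 5)
W = torch.randn(5, 8)
a = torch.randn(16)  # 2 * out_dim
alpha, out = gat_attention(x, edge_index, W, a)
print(alpha)
```

Unlike the fixed $1/\sqrt{\tilde{d}_i \tilde{d}_j}$ weights of GCN, each `alpha` value depends on the features of both endpoints, so the hub can weight its three leaves differently.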
Q4: You have a graph with 10 million nodes and 500 million edges. How do you train a GNN on it?
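Full-batch training is impractical. A dense normalized adjacency would need 10M x 10M = $10^{14}$ entries (400 TB in float32); even in sparse form, the 500M edges occupy about 8 GB as an int64 COO edge_index, before counting every node's activations at every layer, which must be kept for backprop. Three practical approaches: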
NeighborLoader (GraphSAGE-style): For each mini-batch of target nodes, sample k-hop neighborhoods with a fixed fan-out per hop. With 2 layers and fan-out [25, 10], each target node expands to at most 1 + 25 + 25 x 10 = 276 nodes, so the GPU only processes small subgraphs. Used at Pinterest (PinSage, which selects neighbors by random-walk visit counts rather than uniformly), Twitter, and Alibaba.
ClusterLoader (Cluster-GCN): Partition the graph into dense clusters using METIS. Each mini-batch processes several complete clusters. Intra-cluster edges (the large majority) and edges between clusters in the same batch are preserved; only edges to clusters outside the batch are dropped. The GPU sees dense subgraphs with high edge utilization. Used for ogbn-products (2.4M nodes).
For billion-scale (Meta, LinkedIn): distributed training with the graph sharded across machines, remote node embeddings fetched via a parameter server or embedding lookup service.
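The fan-out bound can be checked with a tiny uniform neighbor sampler over a CSR graph (`rowptr`, `col`). This is a hypothetical sketch of the idea, not PyG's NeighborLoader implementation; it deduplicates nodes across hops, so 1 + 25 + 250 = 276 is an upper bound rather than the usual count.

```python
import random


def sample_fanout(rowptr, col, targets, fanouts, seed=0):
    """Hypothetical uniform fan-out sampler over a CSR graph.

    col[rowptr[v]:rowptr[v + 1]] lists node v's neighbors.
    Returns the newly discovered nodes per hop and the set of all sampled nodes.
    """
    rng = random.Random(seed)
    frontier = list(targets)
    seen = set(frontier)
    hops = [frontier]
    for fanout in fanouts:
        next_frontier = []
        for v in frontier:
            neighbors = col[rowptr[v]:rowptr[v + 1]]
            # Keep every neighbor if there are at most `fanout`, else sample without replacement
            picked = neighbors if len(neighbors) <= fanout else rng.sample(neighbors, fanout)
            for u in picked:
                if u not in seen:
                    seen.add(u)
                    next_frontier.append(u)
        hops.append(next_frontier)
        frontier = next_frontier
    return hops, seen


# Toy "dense neighborhood": complete graph on 300 nodes, stored as CSR
n = 300
rowptr, col = [0], []
for v in range(n):
    col.extend(u for u in range(n) if u != v)
    rowptr.append(len(col))

hops, seen = sample_fanout(rowptr, col, targets=[0], fanouts=[25, 10])
print([len(h) for h in hops], len(seen))
```

Even though every node here has 299 neighbors, the sampled computation graph for one target stays bounded by the fan-out product, which is the property that keeps per-batch GPU memory predictable.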
Q5: Explain PageRank's connection to eigenvectors and Markov chains. What does the damping factor do?
PageRank defines a random walk Markov chain over the web graph. The transition matrix $M$ is column-stochastic: $M_{ij}$ = probability of jumping from page $j$ to page $i$ (following one of $j$'s outgoing links uniformly, so $M_{ij} = 1/\mathrm{outdeg}(j)$ when $j$ links to $i$). The stationary distribution $\pi$ satisfies $\pi = M\pi$; it is the dominant eigenvector of $M$ (eigenvalue 1).
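Power iteration converges to this eigenvector because the damped chain is ergodic (irreducible and aperiodic): from any page you can reach any other page, and teleportation ensures this. The convergence rate depends on $|\lambda_2 / \lambda_1| = |\lambda_2|$, i.e., on the spectral gap $1 - |\lambda_2|$. For the Google matrix with damping factor $d$, the L1 error shrinks by at least a factor of $d$ per iteration.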
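The damping factor $d$ (commonly 0.85) does two things: it models a random surfer who follows a link with probability $d$ and otherwise gets bored and teleports to a random page, AND it guarantees ergodicity (the teleportation component connects every node to every other node, ensuring a unique stationary distribution exists). Without damping, a group of pages that link only to each other (a spider trap) absorbs all probability mass, and a disconnected or periodic graph has no unique stationary distribution for power iteration to converge to. Dangling nodes (pages with no outgoing links) are a separate problem: their columns of $M$ sum to zero, so probability mass leaks away; the usual fix is to treat them as linking uniformly to every page.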
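A pure-Python power iteration on the damped chain, assuming the graph is given as a list of out-link lists and that dangling pages (no out-links) are treated as linking uniformly to every page, one common convention. The function name, tolerance, and toy graph are assumptions for illustration.

```python
def pagerank(out_links, d=0.85, tol=1e-12, max_iter=1000):
    """Power iteration on the damped PageRank chain; out_links[j] lists pages j links to."""
    n = len(out_links)
    pi = [1.0 / n] * n
    for it in range(max_iter):
        # Mass held by dangling pages is spread uniformly over all pages
        dangling = sum(pi[j] for j in range(n) if not out_links[j])
        # Teleportation plus redistributed dangling mass, the same for every page
        new = [(1 - d) / n + d * dangling / n] * n
        # Follow links: page j splits d * pi[j] evenly across its out-links
        for j, targets in enumerate(out_links):
            for i in targets:
                new[i] += d * pi[j] / len(targets)
        delta = sum(abs(a - b) for a, b in zip(new, pi))  # L1 change
        pi = new
        if delta < tol:
            break
    return pi, it + 1


# 0 -> {1, 2}, 1 -> {2}, 2 -> {0}; page 3 is dangling and has no in-links
pi, iterations = pagerank([[1, 2], [2], [0], []])
print([round(p, 4) for p in pi], iterations)
```

Page 3 has no in-links, so its score comes only from teleportation and its own redistributed dangling mass, and it ends up with the lowest rank. The scores always sum to 1 because every step conserves probability mass.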
Q6: How does PyTorch autograd use topological sort, and what would happen if a computation graph had a cycle?
PyTorch builds a computation graph dynamically during the forward pass. Each non-leaf tensor records its grad_fn, the backward node for the operation that produced it, and each grad_fn links to the backward nodes of its inputs through next_functions. This forms a directed acyclic graph (DAG) where edges represent data flow.
When .backward() is called, PyTorch traverses this DAG in reverse topological order. Topological ordering ensures that each operation's backward method runs only after the gradients from all of its downstream consumers have been accumulated. This implements the chain rule correctly: to compute $\partial L / \partial x$, you first need $\partial L / \partial y$ for every tensor $y$ computed from $x$.
If the computation graph had a cycle (say tensor $a$ depended on tensor $b$, which depended on $a$), topological sort would fail, since a cycle has no valid topological ordering. This is why standard RNNs are "unrolled through time": the recurrence looks cyclic, but by creating a distinct tensor $h_t$ for each step $t = 1, \ldots, T$, the computation becomes a DAG with $T$ time steps. Gradient checkpointing recomputes activations during backward to reduce memory, but the DAG structure must still remain acyclic.
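A toy version of the same mechanism in pure Python: Kahn's algorithm orders the nodes (raising on a cycle), and a reverse pass applies the chain rule, accumulating each node's gradient from all of its consumers before propagating further. The graph representation and names are hypothetical; PyTorch's C++ engine uses its own dependency-counting scheduler rather than code like this.

```python
from collections import deque


def topo_order(nodes, inputs):
    """Kahn's algorithm. inputs[v] lists the nodes v was computed from (edges u -> v)."""
    indeg = {v: 0 for v in nodes}
    consumers = {v: [] for v in nodes}
    for v in nodes:
        for u in inputs[v]:
            indeg[v] += 1
            consumers[u].append(v)
    queue = deque(v for v in nodes if indeg[v] == 0)
    order = []
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in consumers[u]:
            indeg[v] -= 1
            if indeg[v] == 0:
                queue.append(v)
    if len(order) != len(nodes):
        raise ValueError("cycle detected: no valid topological order")
    return order


def backward(order, inputs, local_grad, output):
    """Reverse-mode pass; local_grad[(v, u)] is the partial derivative dv/du."""
    grads = {v: 0.0 for v in order}
    grads[output] = 1.0
    # Reverse topological order: every consumer of u is processed before u itself
    for v in reversed(order):
        for u in inputs[v]:
            grads[u] += local_grad[(v, u)] * grads[v]
    return grads


# L = y + x with y = x ** 2, evaluated at x = 3, so dL/dx = 2x + 1 = 7
x_val = 3.0
nodes = ["x", "y", "L"]
inputs = {"x": [], "y": ["x"], "L": ["y", "x"]}
local_grad = {("y", "x"): 2 * x_val, ("L", "y"): 1.0, ("L", "x"): 1.0}
order = topo_order(nodes, inputs)
grads = backward(order, inputs, local_grad, "L")
print(order, grads["x"])
```

Node `x` feeds both `y` and `L`, so its gradient has to collect contributions from both paths; visiting nodes in reverse topological order is exactly what guarantees both have arrived before `x` is used.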
Summary
Graph algorithms underpin a surprising fraction of production ML systems. The path from Euler's bridges to Pinterest's billion-scale GNN recommendations runs through BFS/DFS, Dijkstra, PageRank, and the message passing framework.
Understanding graph representations - adjacency matrices, CSR format, COO edge_index - determines whether your system scales to 10 million nodes or crashes on 10 thousand. GCN, GraphSAGE, and GAT represent three complementary philosophies: spectral simplicity, inductive scalability, and learned attention.
In practice, GraphSAGE with neighbor sampling is the workhorse of production graph ML. GAT adds interpretability where understanding which neighbors matter is valuable. Graph Transformers with global attention are best when the graph is small enough to afford attention whose cost is quadratic in the number of nodes. Oversmoothing limits useful depth to 2-4 layers, and mini-batch training via subgraph sampling is essential at scale.
