GraphSAGE and Inductive Learning

Reading time: 54 min | Interview relevance: Very High - PinSage is a canonical system design case study; inductive learning is the key distinction for production GNNs | Target roles: ML Engineer, AI Engineer, Research Engineer, Graph ML Engineer


The Real Interview Moment

"Pinterest has 200 million users and 100 billion pins. How would you apply a GNN to recommend pins to users in real time?"

This is a classic system design interview question at companies like Pinterest, Airbnb, Twitter, and Spotify that operate recommendation systems at scale on graph-structured data. The naive answer - "train GCN on the full graph" - fails on three counts: (1) the full graph does not fit in GPU memory, (2) GCN cannot compute embeddings for new pins added after training, (3) re-running the full forward pass for every new pin is computationally infeasible.

The correct answer requires understanding three things: why transductive GNNs fail at scale, how GraphSAGE enables inductive inference and mini-batch training, and how Pinterest's PinSage system deployed these ideas to a 3-billion-node graph with 150ms latency requirements.

Hamilton, Ying, and Leskovec (2017) developed GraphSAGE (Graph SAmple and aggreGAtE) to solve exactly this problem. It is now the foundation of graph-based recommendation systems at Pinterest, Airbnb, and many other companies.


Why Transductive GNNs Fail at Scale

GCN and standard GAT are typically trained transductively: the model is fit to one fixed graph, and every node that will ever need an embedding must be present in that graph at training time.

In GCN, the node feature matrix $H^{(0)} \in \mathbb{R}^{n \times d}$ is an input, and the model produces embeddings $H^{(L)} \in \mathbb{R}^{n \times d_L}$ for all $n$ nodes simultaneously. The embedding for node $v$ depends on all nodes in its $L$-hop neighborhood. During training, the entire graph must be in memory.

Problem 1 - Memory: for Pinterest's graph (3B nodes, 18B edges), storing the edges at even one byte each requires 18GB; a realistic edge list with two 64-bit node IDs per edge is closer to 288GB. Node features might be 512-dimensional floats - 3B × 512 × 4 bytes ≈ 6TB. This does not fit on a single GPU, and spreading it across a cluster requires distributed training infrastructure.
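The arithmetic behind these figures is easy to check. The sketch below is a back-of-the-envelope estimate only; the 512-dimensional float32 features and the int64 edge list layout are assumptions for illustration, not Pinterest's actual storage format.

```python
# Back-of-the-envelope memory estimate for a 3B-node, 18B-edge graph.
NUM_NODES = 3_000_000_000
NUM_EDGES = 18_000_000_000
FEATURE_DIM = 512            # assumed feature width
BYTES_PER_FLOAT32 = 4
BYTES_PER_INT64 = 8


def to_gb(num_bytes: int) -> float:
    """Convert bytes to decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_bytes / 1e9


one_byte_per_edge_gb = to_gb(NUM_EDGES * 1)
# COO edge list: one (source, destination) pair of int64 IDs per edge.
coo_edge_list_gb = to_gb(NUM_EDGES * 2 * BYTES_PER_INT64)
feature_matrix_gb = to_gb(NUM_NODES * FEATURE_DIM * BYTES_PER_FLOAT32)

print(f"1 byte per edge:        {one_byte_per_edge_gb:,.0f} GB")
print(f"COO edge list (int64):  {coo_edge_list_gb:,.0f} GB")
print(f"float32 feature matrix: {feature_matrix_gb:,.0f} GB")
```

The feature matrix dominates: roughly 6TB of features against a few hundred GB of edges, which is why feature storage and fetching tend to drive the system design.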

Problem 2 - New nodes: GCN learns embeddings for a fixed set of nodes. A new pin added after training has no embedding vector. To get an embedding for it, you must either: (a) retrain the entire GCN (takes hours or days), or (b) use some ad-hoc approximation (e.g., average of neighbor embeddings) that is disconnected from the trained model. Neither is acceptable for a recommendation system that adds thousands of new pins per second.

Problem 3 - Computational graph explosion: computing one node's embedding with an $L$-layer GCN requires its entire $L$-hop neighborhood to be in memory. For a graph with average degree 50 and $L=2$ layers, the second hop alone contributes up to $50^2 = 2500$ nodes per target node. For a batch of 1024 target nodes, that is up to 2.56M nodes before deduplication, which on a graph of a few million nodes is effectively the whole graph.

GraphSAGE solves all three problems by changing what the model learns.


The Fundamental Shift: Learn Functions, Not Embeddings

The key insight in GraphSAGE: instead of learning embedding vectors for specific nodes, learn the aggregation functions that compute any node's embedding from its neighborhood features.

Transductive GCN learns: $\Theta = \{H^{(0)}, W^{(1)}, W^{(2)}, \ldots\}$ where, in the featureless setting, $H^{(0)} \in \mathbb{R}^{n \times d}$ is the node embedding table - a direct model parameter indexed by node ID. This table has $n$ rows. New nodes have no row. (When $H^{(0)}$ holds input features instead, the model is still tied to the fixed normalized adjacency of the training graph.)

GraphSAGE learns: $\Theta = \{W^{(1)}, W^{(2)}, \ldots, \text{AGGREGATE}_1, \text{AGGREGATE}_2, \ldots\}$. The aggregation functions take neighbor feature vectors as input and produce an aggregated representation. For a new node $v'$ with known features $\mathbf{h}_{v'}^{(0)}$ and neighbors $\mathcal{N}(v')$:

$$\mathbf{h}_{v'}^{(1)} = \sigma\!\left(W^{(1)} \cdot \left[\mathbf{h}_{v'}^{(0)} \| \text{AGGREGATE}_1\!\left(\{\mathbf{h}_u^{(0)}: u \in \mathcal{N}(v')\}\right)\right]\right)$$

No retraining required. The aggregation function generalizes to any node whose neighbors have features.


GraphSAGE: Sample and Aggregate

The Forward Algorithm

For each node $v$, GraphSAGE computes its embedding layer by layer:

Layer $k$ forward pass for node $v$:

Step 1 - Sample neighbors: instead of using all of $\mathcal{N}(v)$ (which may be thousands of nodes), uniformly sample a fixed-size subset:

$$\mathcal{S}_v^{(k)} \sim \text{Uniform}\!\left(\mathcal{N}(v), S_k\right)$$

where $S_k$ is the number of neighbors to sample at layer $k$. Typical values: $S_1 = 25$, $S_2 = 10$ for a 2-layer model.

Step 2 - Aggregate neighbor representations:

$$\mathbf{h}_{\mathcal{S}}^{(k)} = \text{AGGREGATE}_k\!\left(\left\{\mathbf{h}_u^{(k-1)} : u \in \mathcal{S}_v^{(k)}\right\}\right)$$

Step 3 - Concatenate self and aggregate, then transform:

$$\mathbf{h}_v^{(k)} = \sigma\!\left(W^{(k)} \cdot \left[\mathbf{h}_v^{(k-1)} \| \mathbf{h}_{\mathcal{S}}^{(k)}\right]\right)$$

The concatenation $[\mathbf{h}_v^{(k-1)} \| \mathbf{h}_{\mathcal{S}}^{(k)}]$ is a key feature of GraphSAGE - it preserves the distinction between the node's own representation and what it aggregated from neighbors. In GCN, self and neighbors are merged in the normalization step; in GraphSAGE, they are kept separate until the linear transformation.

Step 4 - L2 normalize:

$$\mathbf{h}_v^{(k)} = \frac{\mathbf{h}_v^{(k)}}{\left\|\mathbf{h}_v^{(k)}\right\|_2}$$

L2 normalization is a critical feature of GraphSAGE for recommendation tasks. After normalization, the dot product of two node embeddings equals their cosine similarity - directly interpretable as relevance between items. Approximate nearest neighbor (ANN) search in cosine similarity space powers production recommendation at Pinterest.
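The four steps can be sketched for a single node in plain PyTorch. This is a minimal illustration rather than the reference implementation: the adjacency dictionary, the helper name `sage_layer`, and the random weight matrix (standing in for a trained $W^{(k)}$) are all assumptions made for the example.

```python
import random

import torch
import torch.nn.functional as F


def sage_layer(node, features, adjacency, weight, sample_size, rng):
    """One GraphSAGE layer for a single node, using a mean aggregator.

    features:  (num_nodes, d_in) tensor holding h^(k-1) for every node
    adjacency: dict mapping node id -> list of neighbor ids
    weight:    (d_out, 2 * d_in) matrix playing the role of W^(k)
    """
    neighbors = adjacency[node]
    # Step 1: uniformly sample a fixed-size neighbor set (all of them if fewer exist).
    sampled = rng.sample(neighbors, min(sample_size, len(neighbors)))
    # Step 2: aggregate the sampled neighbors' representations.
    h_neigh = features[sampled].mean(dim=0)
    # Step 3: concatenate self and aggregate, then transform.
    h = torch.relu(weight @ torch.cat([features[node], h_neigh]))
    # Step 4: L2-normalize the result.
    return F.normalize(h, p=2, dim=0)


torch.manual_seed(0)
rng = random.Random(0)
features = torch.randn(5, 4)                       # 5 nodes, 4 input features
adjacency = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
weight = torch.randn(8, 2 * 4)                     # random stand-in for trained W^(1)

h0 = sage_layer(0, features, adjacency, weight, sample_size=3, rng=rng)
print(h0.shape, h0.norm().item())
```

Nothing in `sage_layer` is indexed by node identity except the feature lookup, so the same call works for a node that was added after `weight` was trained, as long as its features and neighbor list are available.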


Three Aggregator Functions

GraphSAGE was designed to support interchangeable aggregators. Hamilton et al. tested three:

1. Mean Aggregator

$$\mathbf{h}_{\mathcal{S}}^{(k)} = \text{MEAN}\!\left(\left\{\mathbf{h}_u^{(k-1)} : u \in \mathcal{S}_v^{(k)}\right\}\right)$$

The simplest aggregator. Permutation-invariant (the order of neighbors does not matter). No additional parameters. Fast and effective - often the best choice in practice.

The mean aggregator is closely related to the GCN update rule: if the concatenation is replaced by a single mean over the node and its sampled neighbors, the layer becomes an inductive variant of GCN. GraphSAGE with the mean aggregator can therefore be seen as a generalization of GCN.
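A short derivation makes the relationship concrete. Writing $W^{(k)} = [W_{\text{self}} \;\; W_{\text{neigh}}]$ (a block split introduced here for illustration) and letting $s = |\mathcal{S}_v^{(k)}|$, the concatenated form separates into two terms, while the GCN-style mean over the node and its neighbors ties those two terms to a single matrix $W$:

$$
\begin{aligned}
\text{GraphSAGE (mean):}\quad & W^{(k)}\left[\mathbf{h}_v^{(k-1)} \| \mathbf{h}_{\mathcal{S}}^{(k)}\right] = W_{\text{self}}\,\mathbf{h}_v^{(k-1)} + W_{\text{neigh}}\,\mathbf{h}_{\mathcal{S}}^{(k)} \\
\text{GCN-style mean:}\quad & W \cdot \text{MEAN}\!\left(\{\mathbf{h}_v^{(k-1)}\} \cup \{\mathbf{h}_u^{(k-1)} : u \in \mathcal{S}_v^{(k)}\}\right) = \frac{1}{s+1}\,W\,\mathbf{h}_v^{(k-1)} + \frac{s}{s+1}\,W\,\mathbf{h}_{\mathcal{S}}^{(k)}
\end{aligned}
$$

The GCN-style rule is the special case $W_{\text{self}} = \tfrac{1}{s+1} W$ and $W_{\text{neigh}} = \tfrac{s}{s+1} W$, so the concatenated form can represent it but is also free to weight self and neighborhood information independently.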

2. Max-Pooling Aggregator

First transform each neighbor's representation, then take element-wise maximum:

$$\mathbf{h}_{\mathcal{S}}^{(k)} = \max\!\left(\left\{\sigma\!\left(W_{\text{pool}} \mathbf{h}_u^{(k-1)} + \mathbf{b}\right) : u \in \mathcal{S}_v^{(k)}\right\}\right)$$

The element-wise max selects, for each feature, the strongest activation among the sampled neighbors. This captures the "strongest signal" from the neighborhood rather than the average. For classification tasks where a single highly-relevant neighbor is more informative than the average of many mediocre neighbors, max-pooling can outperform mean.

The pooling transformation $W_{\text{pool}}$ adds learnable parameters and is separate from the main weight matrix - it transforms each neighbor independently (often into a wider space) before the max is taken.

3. LSTM Aggregator

Pass the sampled neighbors (in a random order) through an LSTM:

$$\mathbf{h}_{\mathcal{S}}^{(k)} = \text{LSTM}\!\left(\left[\mathbf{h}_{u_{\pi(1)}}^{(k-1)}, \mathbf{h}_{u_{\pi(2)}}^{(k-1)}, \ldots, \mathbf{h}_{u_{\pi(S)}}^{(k-1)}\right]\right)$$

where $\pi$ is a random permutation of the sampled neighbor indices.

The LSTM has more expressive power than mean or max, but with a critical caveat: LSTMs are not permutation-invariant. The same set of neighbors in different orders produces different representations. GraphSAGE handles this by randomly shuffling the neighbors each forward pass, effectively averaging over orderings during training. This adds stochasticity but also acts as implicit regularization.

In practice, the LSTM aggregator rarely outperforms max-pooling enough to justify its complexity and sequential computation cost.
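The permutation-invariance difference is easy to observe directly. The sketch below feeds the same six neighbor vectors in two orders through a mean, an element-wise max, and an untrained `nn.LSTM`; the sizes and the reversed ordering are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
neighbors = torch.randn(6, 4)               # 6 sampled neighbors, 4 features each
perm = torch.tensor([5, 4, 3, 2, 1, 0])     # reversed order: a fixed, non-trivial permutation
shuffled = neighbors[perm]

# Mean and element-wise max ignore neighbor order.
mean_same = torch.allclose(neighbors.mean(dim=0), shuffled.mean(dim=0), atol=1e-6)
max_same = torch.equal(neighbors.max(dim=0).values, shuffled.max(dim=0).values)

# An LSTM reads neighbors as a sequence, so its final hidden state depends on order.
lstm = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
with torch.no_grad():
    _, (h_a, _) = lstm(neighbors.unsqueeze(0))
    _, (h_b, _) = lstm(shuffled.unsqueeze(0))
lstm_same = torch.allclose(h_a, h_b, atol=1e-6)

print(f"mean invariant: {mean_same}, max invariant: {max_same}, LSTM invariant: {lstm_same}")
```

The mean and max checks should pass while the LSTM check should fail, which is the reason the LSTM aggregator needs a fresh random shuffle on every forward pass.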

Comparison:

| Aggregator | Permutation invariant | Parameters | Best for |
|---|---|---|---|
| Mean | Yes | None | Most tasks, fast, good default |
| Max-Pooling | Yes | $W_{\text{pool}}, \mathbf{b}$ | Single-salient-neighbor tasks |
| LSTM | No (use random shuffle) | LSTM params | Expressive but rarely best |

Unsupervised GraphSAGE: Random Walk Loss

For tasks without node labels, GraphSAGE can be trained with an unsupervised objective based on random walks. The loss is:

$$\mathcal{L} = -\log\!\left(\sigma\!\left(\mathbf{z}_u^\top \mathbf{z}_v\right)\right) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)}\!\left[\log\!\left(\sigma\!\left(-\mathbf{z}_u^\top \mathbf{z}_{v_n}\right)\right)\right]$$

where:

  • $(u, v)$ - positive pair: nodes that co-occur in a length-$\ell$ random walk starting from $u$. Co-occurrence in random walks indicates structural proximity.
  • $v_n$ - negative samples drawn from $P_n(v) \propto d_v^{3/4}$ (degree-weighted noise distribution, same as Word2Vec negative sampling)
  • $Q$ - number of negative samples per positive pair (typically 5–20)
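The degree-weighted noise distribution can be sketched with the standard library alone. The degree table and the helper name `sample_negatives` are hypothetical; a real pipeline would read degrees from the graph.

```python
import random


def sample_negatives(degrees, num_samples, rng, power=0.75):
    """Draw node ids with probability proportional to degree ** power."""
    nodes = list(degrees)
    weights = [degrees[n] ** power for n in nodes]
    return rng.choices(nodes, weights=weights, k=num_samples)


degrees = {"a": 1, "b": 16, "c": 81}       # toy degree table (hypothetical)
rng = random.Random(0)
negatives = sample_negatives(degrees, num_samples=10_000, rng=rng)

# Empirical share of each node among the sampled negatives.
share = {n: negatives.count(n) / len(negatives) for n in degrees}
print(share)
```

With these degrees the weights are 1, 8, and 27, so node `c` should receive about 75% of the draws. Sampling proportionally to raw degree would give it about 83%, so the 3/4 power flattens the distribution slightly in favor of low-degree nodes.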

Interpretation: maximize the dot product (cosine similarity after L2 normalization) between nearby nodes, minimize it between distant nodes. This encourages embeddings where structurally similar nodes are close in the embedding space - the graph's connectivity structure is the training signal.

This is essentially the skip-gram objective with negative sampling used by DeepWalk and Node2Vec, but with GNN-generated embeddings instead of lookup-table embeddings. The advantage: the GNN can generalize to new nodes (using their features + neighbors), while DeepWalk cannot.
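Generating the positive pairs is the part that is usually left implicit. A minimal sketch, assuming an in-memory adjacency dictionary and uniform walks (the function name and toy graph are illustrative):

```python
import random


def random_walk_pairs(adjacency, walk_length, walks_per_node, rng):
    """Collect (start, visited) positive pairs from short uniform random walks."""
    pairs = []
    for start in adjacency:
        for _ in range(walks_per_node):
            current = start
            for _ in range(walk_length):
                neighbors = adjacency[current]
                if not neighbors:            # dead end: stop this walk early
                    break
                current = rng.choice(neighbors)
                if current != start:         # skip trivial self-pairs
                    pairs.append((start, current))
    return pairs


# Toy graph: two triangles (0-1-2 and 3-4-5) joined by the edge 2-3.
adjacency = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4, 5], 4: [3, 5], 5: [3, 4]}
pairs = random_walk_pairs(adjacency, walk_length=2, walks_per_node=3, rng=random.Random(0))
print(len(pairs), pairs[:5])
```

With walks of length 2, nodes 0 and 1 can never be paired with nodes 4 and 5, because they are three hops apart. The walk length therefore sets how local the notion of "nearby" is.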

Two-stage training: many production systems first train GraphSAGE unsupervised (to capture general graph structure), then fine-tune supervised on labeled data (to capture task-specific patterns). The unsupervised pre-training initializes embeddings in a meaningful space, speeding up supervised convergence.


Mini-Batch Training with NeighborLoader

The computational graph expansion problem: to compute a 2-layer GraphSAGE embedding for a target node, you need its 2-hop neighborhood. Without sampling, for a graph with mean degree $\bar{d} = 50$:

  • 1-hop: 50 nodes
  • 2-hop: 50 × 50 = 2500 nodes
  • For a batch of 1024 targets: up to 2.56M 2-hop nodes before deduplication (effectively the whole graph for anything up to a few million nodes)

With sampling ($S_1 = 25$, $S_2 = 10$):

  • 1-hop: 25 nodes per target
  • 2-hop: 10 nodes per 1-hop node = 250 per target
  • For 1024 targets: at most 1024 × 275 = 281,600 nodes - fixed, predictable, controllable

This is the key scalability insight: fixing the fanout at each layer caps the per-target computation graph at $S_1 + S_1 S_2$ nodes regardless of node degree or graph size, so the cost of a batch grows linearly with batch size. The cap is still multiplicative in depth, which is why fanouts are kept small at later hops.
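These counts generalize to any fanout list. The helper names below are made up for this sketch; both functions return upper bounds because they ignore overlap between neighborhoods.

```python
from math import prod


def full_neighborhood_size(mean_degree: int, num_layers: int) -> int:
    """Upper bound on nodes touched per target without sampling: d + d^2 + ... + d^L."""
    return sum(mean_degree ** hop for hop in range(1, num_layers + 1))


def sampled_neighborhood_size(fanouts: list[int]) -> int:
    """Upper bound with sampling: S1 + S1*S2 + ... for per-hop fanouts [S1, S2, ...]."""
    return sum(prod(fanouts[:hop]) for hop in range(1, len(fanouts) + 1))


batch_size = 1024
print(full_neighborhood_size(50, 2))                        # 50 + 2500
print(sampled_neighborhood_size([25, 10]))                  # 25 + 250
print(batch_size * sampled_neighborhood_size([25, 10]))     # per-batch cap
print(sampled_neighborhood_size([25, 25]))                  # 25 + 625
print(sampled_neighborhood_size([15, 10, 5]))               # 15 + 150 + 750
```

Comparing `[25, 10]` with `[25, 25]` shows how quickly the second-hop fanout dominates: 275 versus 650 sampled nodes per target.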

The sampling introduces stochasticity - different runs see different neighborhood samples - which acts as implicit regularization and often helps generalization.


Full Code: GraphSAGE, NeighborLoader, Unsupervised Training

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader

# ─── LOAD REDDIT (LARGE GRAPH EXAMPLE) ────────────────────────────────────────

# Reddit: 232,965 nodes, ~114M edges, 602 features, 41 subreddit classes
# Too large for full-graph GCN training on most hardware - ideal for GraphSAGE
dataset_reddit = Reddit(root='/tmp/reddit')
data_reddit = dataset_reddit[0]

print("Reddit dataset:")
print(f"  Nodes: {data_reddit.num_nodes:,}")        # 232,965
print(f"  Edges: {data_reddit.num_edges:,}")        # ~114.6M
print(f"  Features: {data_reddit.num_features}")    # 602
print(f"  Classes: {dataset_reddit.num_classes}")   # 41
```

```python
# ─── NEIGHBORLOADER: MINI-BATCH WITH NEIGHBOR SAMPLING ────────────────────────

# num_neighbors[i] is the fanout at hop i+1, counted outward from the seed nodes:
#   num_neighbors[0] = 25 -> up to 25 direct neighbors per seed node
#   num_neighbors[1] = 10 -> up to 10 neighbors for each of those (2-hop)
# Per seed node the sampled subgraph has at most 25 + 25×10 = 275 other nodes.
# The list length should match the number of GNN layers (2 here).
train_loader = NeighborLoader(
    data_reddit,
    num_neighbors=[25, 10],
    batch_size=1024,
    input_nodes=data_reddit.train_mask,
    shuffle=True,
    num_workers=4,
)

val_loader = NeighborLoader(
    data_reddit,
    num_neighbors=[25, 10],
    batch_size=2048,
    input_nodes=data_reddit.val_mask,
    shuffle=False,
    num_workers=2,
)
```

```python
# ─── GRAPHSAGE MODEL ──────────────────────────────────────────────────────────

class GraphSAGE(nn.Module):
    """
    GraphSAGE with SAGEConv (mean aggregation).

    SAGEConv computes W_self · h_v^(k-1) + W_neigh · MEAN({h_u^(k-1) : u ∈ S_v^k}),
    which is the same linear map as W^k · [h_v^(k-1) ‖ MEAN(...)].
    The nonlinearity σ is applied in forward() below, not inside SAGEConv.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int = 2,   # assumed >= 2
        dropout: float = 0.5,
    ):
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()

        # Input → hidden
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        # Hidden → hidden (optional middle layers)
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        # Hidden → output
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:  # all layers except last
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x  # (num_nodes, out_channels)

    def embed(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        normalize: bool = True,
    ) -> torch.Tensor:
        """
        Get penultimate-layer embeddings, optionally L2 normalized.
        L2 normalization enables cosine similarity via dot product.
        """
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        if normalize:
            x = F.normalize(x, p=2, dim=1)
        return x


model = GraphSAGE(
    in_channels=dataset_reddit.num_features,
    hidden_channels=256,
    out_channels=dataset_reddit.num_classes,
    num_layers=2,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```

```python
# ─── MINI-BATCH TRAINING LOOP ─────────────────────────────────────────────────

def train_minibatch(model, loader, optimizer):
    """
    Training with NeighborLoader mini-batches.

    Key: batch.batch_size tells you how many nodes are "seed" (target) nodes.
    The first batch.batch_size nodes are the targets; remaining are neighbors
    needed for message passing but not for loss computation.
    """
    model.train()
    total_loss = 0.0
    total_nodes = 0

    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)

        # Only compute loss for seed (target) nodes
        # batch.batch_size: number of target nodes in this batch
        loss = F.cross_entropy(
            out[:batch.batch_size],
            batch.y[:batch.batch_size],
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.batch_size
        total_nodes += batch.batch_size

    return total_loss / total_nodes


@torch.no_grad()
def evaluate_minibatch(model, loader):
    model.eval()
    total_correct = total_nodes = 0
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        pred = out[:batch.batch_size].argmax(dim=1)
        total_correct += (pred == batch.y[:batch.batch_size]).sum().item()
        total_nodes += batch.batch_size
    return total_correct / total_nodes


print("\nTraining GraphSAGE on Reddit (mini-batch)...")
for epoch in range(10):
    loss = train_minibatch(model, train_loader, optimizer)
    val_acc = evaluate_minibatch(model, val_loader)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

# Expected: ~93-94% accuracy on Reddit with 2-layer GraphSAGE
```

```python
# ─── INDUCTIVE INFERENCE: NEW NODE EMBEDDINGS ─────────────────────────────────

def inductive_embed_new_node(
    model: GraphSAGE,
    new_node_features: torch.Tensor,
    neighbor_features: torch.Tensor,
    neighbor_edge_index: torch.Tensor,
) -> torch.Tensor:
    """
    Compute embedding for a new node without retraining.

    The new node's features and its sampled neighbor features
    are passed through the trained aggregation functions.
    This is the inductive capability of GraphSAGE.

    Node indexing inside the mini-graph: the new node is index 0 and the
    sampled neighbors are indices 1..num_neighbors. model.embed() skips the
    last conv layer, so for the 2-layer model above only 1-hop neighbors
    are needed here.

    Args:
        new_node_features: (1, feature_dim) - new node's feature vector
        neighbor_features: (num_neighbors, feature_dim) - sampled neighbor features
        neighbor_edge_index: (2, num_neighbors) - edges from neighbors
            (row 0: sources 1..num_neighbors) to the new node (row 1: all 0)

    Returns:
        embedding: (1, hidden_dim) - L2-normalized embedding
    """
    # Combine new node with its sampled neighbors into a mini-graph
    all_features = torch.cat([new_node_features, neighbor_features], dim=0)

    model.eval()
    with torch.no_grad():
        embedding = model.embed(all_features, neighbor_edge_index, normalize=True)

    # Return only the new node's embedding (index 0)
    return embedding[0:1]

# Example: embed a new pin at Pinterest (helper names are placeholders)
# new_pin_features = extract_visual_text_features(new_pin)
# sampled_neighbors = sample_neighbors_from_graph(new_pin_id, k=25)
# new_embedding = inductive_embed_new_node(model, new_pin_features, ...)
# recommendation = ann_search(new_embedding, all_embeddings, top_k=50)
```

```python
# ─── UNSUPERVISED GRAPHSAGE ───────────────────────────────────────────────────

class UnsupervisedSAGE(nn.Module):
    """
    GraphSAGE with contrastive (random walk) unsupervised training.

    Positive pairs: nodes co-occurring in short random walks (structurally nearby)
    Negative pairs: randomly sampled nodes (structurally distant)

    Loss: -log(σ(z_u·z_v)) - Q·E[log(σ(-z_u·z_vn))]
    """
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.normalize(x, p=2, dim=1)  # L2 normalize for cosine similarity

    def contrastive_loss(
        self,
        z: torch.Tensor,               # (N, d) - all node embeddings
        pos_edge_index: torch.Tensor,  # (2, pos_edges) - positive pairs
        neg_edge_index: torch.Tensor,  # (2, neg_edges) - negative pairs
    ) -> torch.Tensor:
        """
        L = -log(σ(z_u·z_v)) - Q·E[log(σ(-z_u·z_vn))]

        Positive pairs: co-occurring in random walks → embeddings close
        Negative pairs: random nodes → embeddings distant
        """
        # Positive loss: push nearby nodes closer
        pos_src, pos_dst = pos_edge_index
        pos_scores = (z[pos_src] * z[pos_dst]).sum(dim=1)  # dot product
        pos_loss = -F.logsigmoid(pos_scores).mean()

        # Negative loss: push random nodes apart
        neg_src, neg_dst = neg_edge_index
        neg_scores = (z[neg_src] * z[neg_dst]).sum(dim=1)
        neg_loss = -F.logsigmoid(-neg_scores).mean()

        # Q = negatives per positive; weighting by Q matches the formula above
        q = neg_scores.numel() / max(pos_scores.numel(), 1)
        return pos_loss + q * neg_loss

    def generate_negative_samples(
        self,
        num_nodes: int,
        pos_edge_index: torch.Tensor,
        Q: int = 5,
    ) -> torch.Tensor:
        """
        Generate Q negative pairs per positive pair.

        Each negative keeps the source node u of its positive pair and draws a
        random destination. Uniform sampling is a simplification of the
        degree^(3/4) noise distribution; random destinations are unlikely to
        be true neighbors in a sparse graph.
        """
        neg_src = pos_edge_index[0].repeat_interleave(Q)
        neg_dst = torch.randint(
            0, num_nodes, (neg_src.numel(),), device=neg_src.device
        )
        return torch.stack([neg_src, neg_dst], dim=0)
```

```python
# ─── COMPARISON: MANUAL AGGREGATORS ──────────────────────────────────────────

class MeanAggregator(nn.Module):
    """GraphSAGE mean aggregator - permutation invariant, no extra params."""
    def aggregate(self, neighbor_embeds: torch.Tensor) -> torch.Tensor:
        return neighbor_embeds.mean(dim=0)


class MaxPoolingAggregator(nn.Module):
    """GraphSAGE max-pooling aggregator - captures strongest signal."""
    def __init__(self, in_dim: int, pool_dim: int):
        super().__init__()
        self.pool = nn.Linear(in_dim, pool_dim)

    def aggregate(self, neighbor_embeds: torch.Tensor) -> torch.Tensor:
        transformed = F.relu(self.pool(neighbor_embeds))  # (k, pool_dim)
        return transformed.max(dim=0)[0]  # (pool_dim,) - elementwise max


class LSTMAggregator(nn.Module):
    """
    GraphSAGE LSTM aggregator - expressive but NOT permutation invariant.
    Random shuffle of neighbors during training averages over orderings.
    """
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden_dim, batch_first=True)

    def aggregate(self, neighbor_embeds: torch.Tensor) -> torch.Tensor:
        # Random permutation of neighbors (reduces order bias)
        perm = torch.randperm(neighbor_embeds.shape[0], device=neighbor_embeds.device)
        shuffled = neighbor_embeds[perm].unsqueeze(0)  # (1, k, dim)
        _, (h_n, _) = self.lstm(shuffled)
        return h_n[-1, 0]  # (hidden_dim,) - final hidden state of the last layer
```

PinSage: GraphSAGE at Pinterest Scale

PinSage (Ying et al., 2018, KDD 2018) is the most influential production deployment of GNNs. It built a recommendation system on a graph with:

  • 3 billion nodes (pins and boards)
  • 18 billion edges
  • Embedding dimension: 256
  • Latency target: 150ms end-to-end for real-time recommendations

Key Engineering Contributions

1. Importance-based neighbor sampling via random walks

Uniform sampling treats all neighbors equally. PinSage observed that nodes vary enormously in relevance - a pin in 10,000 boards is connected to 10,000 board nodes, but only a small subset are relevant for the pin's representation.

PinSage runs $T$ short random walks (length $\ell = 2$, $T = 200$) from each target node and counts visit frequencies. Neighbors visited frequently are ranked higher - they are more "structurally important" in the local graph topology:

$$\text{importance}(u \mid v) \propto \text{count}(u \text{ visited from } v \text{ in } T \text{ walks})$$

This focuses computation on the most relevant neighbors and dramatically reduces noise from loosely-connected nodes. In ablation studies, importance-based sampling improved recommendation quality by 8–12% over uniform sampling.
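The visit-count ranking can be sketched in a few lines. This is an illustrative toy rather than PinSage's implementation: the graph, the function name, and the choice to normalize counts into weights are assumptions made for the example.

```python
import random
from collections import Counter


def importance_neighbors(adjacency, target, num_walks, walk_length, top_k, rng):
    """Rank nodes by random-walk visit count from `target`.

    Returns the top_k nodes with visit counts normalized to sum to 1.
    """
    visits = Counter()
    for _ in range(num_walks):
        current = target
        for _ in range(walk_length):
            current = rng.choice(adjacency[current])
            if current != target:
                visits[current] += 1
    top = visits.most_common(top_k)
    total = sum(count for _, count in top)
    return [(node, count / total) for node, count in top]


# Toy graph (hypothetical): node 0 is the target; node 2 is the best-connected neighbor.
adjacency = {
    0: [1, 2, 3],
    1: [0, 2],
    2: [0, 1, 3],
    3: [0, 2],
}
ranked = importance_neighbors(
    adjacency, target=0, num_walks=200, walk_length=2, top_k=2, rng=random.Random(0)
)
print(ranked)
```

Node 2 is reachable both directly and through nodes 1 and 3, so it collects the most visits and ranks first. The normalized counts can also serve as aggregation weights, so that frequently visited neighbors contribute more to the pooled representation.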

2. Curriculum training with hard negatives

Easy negatives (random pins) are quickly separated - the model learns to distinguish them in the first few epochs and stops learning. Hard negatives (pins from the same category but wrong sub-topic) require finer-grained discrimination and provide meaningful gradient signal throughout training:

$$\mathcal{L} = \underbrace{-\log\sigma(\mathbf{q} \cdot \mathbf{p}^+)}_{\text{positive}} \; \underbrace{- \, \lambda_1 \log\sigma(-\mathbf{q} \cdot \mathbf{p}^-_{\text{random}})}_{\text{easy negative}} \; \underbrace{- \, \lambda_2 \log\sigma(-\mathbf{q} \cdot \mathbf{p}^-_{\text{hard}})}_{\text{hard negative}}$$

This logistic form mirrors the unsupervised GraphSAGE loss above; the PinSage paper itself formulates the objective as a max-margin ranking loss.

Hard negatives are generated by ranking items by relevance to the query (the paper uses Personalized PageRank scores) and sampling from a band that is related to the query but well below the top results: the paper samples from ranks 2000-5000. Harder negatives are phased in gradually across epochs, which is the curriculum. This curriculum was critical for achieving 97.3% precision@10 in offline evaluation.
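A rank-window sampler with a simple curriculum schedule might look like the sketch below. The window bounds, the toy item list, and the cap of six hard negatives are illustrative assumptions, not values from the production system.

```python
import random


def sample_hard_negatives(ranked_items, positives, window, num_samples, rng):
    """Sample negatives from a rank window [start, stop) of items sorted by relevance."""
    start, stop = window
    candidates = [item for item in ranked_items[start:stop] if item not in positives]
    return rng.sample(candidates, min(num_samples, len(candidates)))


def hard_negatives_for_epoch(epoch, max_hard=6):
    """Curriculum: no hard negatives in epoch 0, then one more per epoch up to a cap."""
    return min(epoch, max_hard)


ranked_items = list(range(100))     # stand-in: item ids already sorted by relevance
positives = {3, 25}                 # known positives must never be used as negatives
rng = random.Random(0)

for epoch in range(4):
    k = hard_negatives_for_epoch(epoch)
    negs = sample_hard_negatives(ranked_items, positives, window=(20, 50), num_samples=k, rng=rng)
    print(f"epoch {epoch}: {k} hard negatives -> {negs}")
```

Starting with only easy negatives lets the model find the coarse structure first; adding hard negatives later keeps the gradient informative once random negatives are already well separated.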

3. MapReduce for billion-scale embedding computation

Training on 3B nodes is done with a producer-consumer GPU pipeline. Inference (computing all 3B embeddings) uses MapReduce:

  • Stage 1: compute all layer-1 embeddings in parallel across the cluster
  • Stage 2: load layer-1 embeddings, compute layer-2 embeddings
  • Result: all 3B final embeddings are written to a distributed key-value store
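The two-stage idea is independent of MapReduce itself: compute each layer once for every node and store the result, instead of re-expanding each target's neighborhood. A single-machine sketch, assuming a mean aggregator, an in-memory adjacency dictionary, and random weights in place of trained ones:

```python
import torch


def mean_neighbors(h, adjacency):
    """Mean of neighbor rows for every node; isolated nodes get zeros."""
    out = torch.zeros_like(h)
    for node, neighbors in adjacency.items():
        if neighbors:
            out[node] = h[neighbors].mean(dim=0)
    return out


def sage_layer_all_nodes(h, adjacency, weight):
    """Apply one mean-aggregator GraphSAGE layer to every node at once."""
    stacked = torch.cat([h, mean_neighbors(h, adjacency)], dim=1)
    return torch.relu(stacked @ weight.T)


torch.manual_seed(0)
adjacency = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}
x = torch.randn(4, 8)                                 # input features for 4 nodes
w1, w2 = torch.randn(16, 16), torch.randn(4, 32)      # random stand-ins for trained weights

# Stage 1: layer-1 embeddings for all nodes, computed once and stored.
h1 = sage_layer_all_nodes(x, adjacency, w1)
# Stage 2: layer-2 embeddings read the stored layer-1 outputs instead of recomputing them.
h2 = sage_layer_all_nodes(h1, adjacency, w2)
print(h1.shape, h2.shape)
```

Each node's layer-1 embedding is computed exactly once, even though many nodes consume it in stage 2. Per-target neighborhood expansion would recompute it for every target that reaches the node.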

4. Vector index for real-time retrieval

Embeddings are loaded into a vector index (approximate nearest neighbor, ANN). The original PinSage paper used product quantization + inverted file index. Modern deployments use HNSW (Hierarchical Navigable Small World) graphs or Faiss IVFPQ, which provide millisecond-scale search over very large vector collections.

At query time (user requests recommendations):

  1. Retrieve user embedding (pre-computed or computed on-the-fly from recent interactions)
  2. Run ANN search over pin embeddings: top-1000 in ~5ms
  3. Re-rank with collaborative filtering signals: top-50 in ~10ms
  4. Serve with business logic filters: final recommendations in ~150ms total
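The retrieval step (step 2) reduces to a top-k search by dot product over normalized embeddings. The sketch below uses exact brute-force search as a stand-in for an ANN index; the pin IDs and 2-dimensional vectors are made up for illustration.

```python
import heapq
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k_by_dot(query, item_embeddings, k):
    """Exact top-k by dot product; an ANN index approximates this result at scale."""
    scored = (
        (sum(q * x for q, x in zip(query, emb)), item_id)
        for item_id, emb in item_embeddings.items()
    )
    return [item_id for _, item_id in heapq.nlargest(k, scored)]


items = {
    "pin_a": l2_normalize([1.0, 0.0]),
    "pin_b": l2_normalize([0.8, 0.6]),
    "pin_c": l2_normalize([0.0, 1.0]),
    "pin_d": l2_normalize([-1.0, 0.0]),
}
user = l2_normalize([0.9, 0.1])
print(top_k_by_dot(user, items, k=2))
```

Brute force costs one dot product per item per query, which is why billion-item catalogs need an index that trades a small amount of recall for a large reduction in work.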

5. New pin embedding at add time

When a new pin is uploaded, PinSage immediately computes its embedding:

  1. Extract visual and text features (image CNN embedding, title/description encoding)
  2. Sample neighbors from the existing graph (boards containing similar pins, users who pinned similar items)
  3. Apply the trained GraphSAGE aggregation functions
  4. Index the resulting embedding into the ANN vector store

Total time: 50–100ms per new pin. No retraining. This is the defining advantage of inductive learning at scale.


Comparison: GCN vs GAT vs GraphSAGE

| Property | GCN | GAT | GraphSAGE |
|---|---|---|---|
| Learning target | Node embeddings (transductive) | Node embeddings (transductive or inductive) | Aggregation functions (inductive) |
| New nodes | No (requires retraining) | Yes (with features) | Yes (with features) |
| Full graph at training | Required | Required (standard) | Not required (mini-batch) |
| Neighbor weighting | Fixed (degree-based) | Learned (attention) | Fixed (mean/max) or learned |
| Scalability | Poor (full graph) | Poor (full graph) | Excellent (mini-batch + sampling) |
| Memory complexity | $O(n \cdot d + \lvert E \rvert)$ | $O(n \cdot d + \lvert E \rvert)$ | $O(b \cdot d \cdot \prod_k S_k)$ per mini-batch of $b$ targets |
| Interpretability | Low | High (attention) | Low-medium |
| Production deployments | Research, small graphs | Research, medium graphs | Pinterest, Airbnb, Uber |

When Inductive Matters in Production

E-commerce recommendations: new products are added continuously. The recommendation model must produce embeddings for new products immediately upon listing - before any user interaction data exists. GraphSAGE using product features (images, descriptions, category) and their connections to existing products provides instant embeddings.

Evolving citation graphs: new papers are published daily. A citation-based recommendation system for researchers must update paper embeddings as new citations accrue without retraining.

Fraud detection: new accounts are created constantly, including fraudulent ones. A GNN that can compute embeddings for new accounts from their initial transaction graph and feature profile (even before they have many transactions) enables real-time fraud detection at account creation.

Knowledge graphs: new entities are constantly added (companies are founded, people gain new roles, events occur). Inductive GNNs can embed new entities using their relational context without graph-wide retraining.


YouTube Resources

| Resource | Creator | Focus |
|---|---|---|
| GraphSAGE Paper Explained | Aleksa Gordić | Full derivation, aggregators, PyG code |
| PinSage: Pinterest's GNN for Recommendations | Rex Ying | KDD 2018 talk, billion-scale deployment |
| Inductive Representation Learning | Stanford CS224W | Inductive vs transductive GNNs |
| Graph Mini-Batch Training with PyG | Antonio Longa | NeighborLoader and mini-batch tutorial |
| Scalable Graph Learning with Neighbor Sampling | NeurIPS Tutorial | Large-scale GNN training strategies |

Common Mistakes

:::danger Common Mistake 1: Using full-graph GCN for large production graphs
Full-graph GCN training loads all node features and the full adjacency into GPU memory. For graphs with millions of nodes, this is infeasible. Use GraphSAGE with NeighborLoader for graphs above ~100K nodes. The sampling adds stochasticity but also regularization - GraphSAGE typically matches or exceeds GCN accuracy on large datasets.
:::

:::danger Common Mistake 2: Indexing by batch.batch_size incorrectly in NeighborLoader
NeighborLoader batches include the target nodes plus their sampled neighborhood (needed for message passing). The first batch.batch_size nodes are the targets; the remaining are context nodes. Using out (the full forward pass output) for loss computation also trains on context nodes, whose own neighborhoods are truncated and which may belong to the validation or test split - producing noisy gradients and label leakage. Always use out[:batch.batch_size] and batch.y[:batch.batch_size] for loss and evaluation.
:::

:::warning Common Mistake 3: Not L2 normalizing GraphSAGE embeddings for retrieval tasks
GraphSAGE's L2 normalization step is not cosmetic - it converts the dot product between embeddings into cosine similarity. Cosine similarity is bounded in $[-1, 1]$, scale-invariant, and enables efficient approximate nearest neighbor search. Without normalization, dot products are dominated by embedding magnitude rather than direction, degrading retrieval quality. Always normalize embeddings when using GraphSAGE for recommendation, link prediction, or retrieval.
:::
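A two-candidate example shows how magnitude can override direction. The vectors are made up for illustration: one candidate points almost exactly along the query but has a small norm, the other is 45 degrees off but has a large norm.

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


query = [1.0, 0.0]
aligned_small = [0.9, 0.1]      # nearly the same direction as the query, small norm
off_axis_large = [5.0, 5.0]     # 45 degrees away from the query, large norm

raw = {
    "aligned_small": dot(query, aligned_small),
    "off_axis_large": dot(query, off_axis_large),
}
cos = {
    "aligned_small": dot(l2_normalize(query), l2_normalize(aligned_small)),
    "off_axis_large": dot(l2_normalize(query), l2_normalize(off_axis_large)),
}
print("raw dot products:", raw)
print("after L2 normalization:", cos)
```

On raw dot products the large off-axis vector wins; after normalization the well-aligned vector wins. In retrieval, the raw scores would push high-norm items to the top regardless of how relevant they are.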

:::warning Common Mistake 4: Using the same neighborhood sample sizes for all layers
The number of sampled nodes multiplies from hop to hop, so the fanout at later hops dominates the cost, while nodes farther from the target usually contribute less to its embedding. Using $[25, 10]$ for a 2-layer model is appropriate: the first hop samples 25 neighbors of each target, and the second hop samples 10 neighbors for each of those (at most 275 sampled nodes per target). Using $[25, 25]$ raises that to 650 per target, more than doubling the computation with diminishing returns. Some production systems use $[15, 10, 5]$ for 3-layer models.
:::


Interview Q&A

Q1: What is the core difference between GraphSAGE and GCN in terms of what is learned?

GCN, as typically used, is fit to one fixed graph: its propagation rule is defined over that graph's normalized adjacency, and when nodes have no input features the model also learns an $n \times d$ embedding table with one directly optimized row per node. If a new node joins, there is no row for it and the normalization changes, so in practice retraining is required. GraphSAGE learns aggregation functions: parameterized transformations that compute any node's embedding from its own and its neighbors' features. The model has parameters for the aggregation function ($W^{(k)}$ matrices and potentially the aggregator's internal parameters), but not for individual nodes. For a new node $v'$ with known features and neighborhood, the trained aggregation functions produce an embedding without any parameter update. This is the inductive capability: the function generalizes to nodes not seen during training.

Q2: Why does GraphSAGE concatenate the node's own embedding with the aggregated neighbor embedding? What does GCN do differently?

GraphSAGE: $\mathbf{h}_v = \sigma(W[\mathbf{h}_v^{\text{self}} \| \text{AGG}(\mathcal{N}(v))])$ - concatenation preserves the distinction between the node's own representation and what it received from neighbors. The weight matrix $W$ can learn to weight self vs. neighbor information differently; the first half of $W$ controls self-contribution, the second half controls neighbor contribution. GCN: $\mathbf{h}_v = \sigma(W \cdot \text{MEAN}(\mathcal{N}(v) \cup \{v\}))$ (up to degree normalization) - self and neighbors are pooled together before the linear transformation, losing the distinction. The GCN approach is slightly simpler but less expressive. In practice, concatenation often outperforms pooling when the node's own features are important, which is most of the time.

Q3: Explain how neighbor sampling solves the computational graph explosion problem.

Without sampling, computing a $k$-layer GNN for a node $v$ requires all nodes in its $k$-hop neighborhood. For mean degree $\bar{d}$, the outermost hop alone contains up to $\bar{d}^k$ nodes per target node. For $\bar{d}=50, k=2$: 2500 nodes per target. For a batch of 1024, up to 2.56M nodes, which on many graphs is effectively the entire graph. Neighbor sampling fixes this: at layer $l$, sample at most $S_l$ neighbors per node. For $S_1=25, S_2=10$: 25 + 25×10 = 275 sampled nodes per target, about 282K for a batch of 1024. The computation graph size is bounded and predictable regardless of graph size or degree distribution. The sampling is stochastic - different runs see different neighborhood samples - which acts as data augmentation and regularization. Empirically, the accuracy loss from sampling (vs full neighborhood) is small for most tasks at typical sample sizes.

Q4: What is PinSage's random-walk neighbor sampling and why is it better than uniform sampling?

PinSage runs $T=200$ short random walks (length $\ell=2$) from each target node and counts how many times each neighbor is visited. Nodes visited frequently (high visit count) are more "structurally important" - they are central to the local graph structure around the target. Uniform sampling treats all 10,000 board-neighbors of a popular pin equally, including boards where the pin is an outlier. Importance-based sampling focuses computation on the most topologically relevant neighbors. In PinSage's experiments, this improved offline recommendation precision by 8–12% over uniform sampling and reduced noise from low-relevance connections. It is also consistent with PageRank's intuition: nodes that are frequently visited by random walks from a source are the most proximally relevant to that source.

Q5: How does PinSage serve embeddings for 3 billion nodes in real time?

PinSage decouples training from serving. Offline: compute all 3B embeddings using MapReduce in two passes (layer-1 aggregations → layer-2 aggregations). This takes hours on a Hadoop cluster but only needs to run when the model is updated (roughly weekly). The embeddings are written to a distributed key-value store. Separately: load all embeddings into an ANN index (product quantization + inverted file, or HNSW) for sub-millisecond nearest-neighbor search. Online: at query time, retrieve the user's embedding (pre-computed or computed from recent interactions), run ANN search (~5ms), re-rank, and serve. New pins: at upload time, compute embedding using the trained GraphSAGE functions (50–100ms), write to the KV store, and index - no retraining. The full pipeline achieves the 150ms end-to-end latency requirement by pre-computing heavy computations offline and serving only lightweight retrieval online.

Q6: Compare the three GraphSAGE aggregators. When would you choose each?

Mean aggregator: no additional parameters, permutation invariant, computes the average of neighbor feature vectors. Best default choice - fast, effective, and equivalent to GCN's aggregation on many tasks. Max-pooling aggregator: applies a learned transformation to each neighbor before taking element-wise max. Captures the "most activated" feature from any neighbor in the sample. Better when a single highly-relevant neighbor (a category-defining exemplar) is more informative than the average of many mediocre neighbors. Good for tasks where the most extreme values matter. LSTM aggregator: passes randomly-shuffled neighbors through an LSTM. Most expressive but not permutation-invariant - requires random shuffling during training to reduce order bias. Rarely outperforms max-pooling enough to justify the sequential computation cost and training instability. Use only if mean and max-pooling have been exhausted and you have the compute budget for the LSTM's sequential processing.
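To make the mean vs. max contrast concrete, here is a stdlib-only sketch on three toy neighbor vectors. It omits the learned per-neighbor transformation that precedes the max in the pooling aggregator, and the function names are illustrative assumptions:

```python
def mean_aggregate(neighbor_feats):
    """Element-wise mean over neighbor feature vectors."""
    n = len(neighbor_feats)
    return [sum(col) / n for col in zip(*neighbor_feats)]

def max_aggregate(neighbor_feats):
    """Element-wise max over neighbor feature vectors
    (the learned transform applied before the max is omitted in this sketch)."""
    return [max(col) for col in zip(*neighbor_feats)]

neighbors = [
    [0.1, 0.2, 0.0],
    [0.2, 0.1, 0.1],
    [0.0, 0.3, 0.9],  # one strongly activated exemplar in the last dimension
]

print(mean_aggregate(neighbors))  # the 0.9 is diluted to roughly a third
print(max_aggregate(neighbors))   # [0.2, 0.3, 0.9]

# Both aggregators are permutation invariant: neighbor order does not matter.
print(max_aggregate(list(reversed(neighbors))) == max_aggregate(neighbors))  # True
```

The last dimension shows the trade-off: the mean dilutes the single strong exemplar, while the max keeps it at full strength.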


Historical Context: From Node2Vec to GraphSAGE

The pre-GNN approach to node embeddings was DeepWalk (Perozzi et al., 2014) and Node2Vec (Grover & Leskovec, 2016). These methods ran random walks on the graph and trained Word2Vec-style skip-gram models to produce node embeddings. The resulting embeddings were effective but fundamentally transductive: every node received its own embedding vector, and new nodes required rerunning the entire training procedure.

The connection to GNNs became clear when Kipf & Welling (2017) derived GCN as a first-order approximation of spectral graph convolutions and applied it to semi-supervised node classification. GCN was expressive and learned from node features, but its standard formulation shared the transductive limitation - the feature aggregation could in principle generalize, but the model was trained on a fixed graph.

Hamilton, Ying, and Leskovec (2017) at Stanford published GraphSAGE as a direct response to transductive limitations. The paper's title - "Inductive Representation Learning on Large Graphs" - made the contribution explicit. The key insight was to learn an aggregation function rather than node-specific parameters. Once the aggregation function is learned, it applies to any node anywhere on any graph.

The PinSage deployment (Ying et al., 2018) validated GraphSAGE at production scale. Pinterest's graph had 3 billion nodes (pins and boards) and 18 billion edges - three orders of magnitude larger than anything in prior academic work. The paper contributed importance-based random walk sampling, curriculum training, and a MapReduce inference pipeline - engineering innovations that made billion-scale GNN training feasible for the first time.

Deep Graph Infomax (Velickovic et al., 2019) extended the inductive paradigm to unsupervised learning. Rather than requiring node labels, DGI maximizes mutual information between node-level and graph-level representations using a contrastive objective - the same principle as GraphSAGE's unsupervised random walk loss but with a stronger theoretical grounding in information maximization.


Cluster-GCN and GraphSAINT: Mini-Batch Alternatives

GraphSAGE's neighbor sampling is not the only approach to scalable GNN training. Two alternatives address specific limitations.

Cluster-GCN (Chiang et al., 2019) partitions the graph into clusters using METIS (a graph partitioning algorithm), then trains on one cluster per batch. Within-cluster edges are preserved; between-cluster edges are dropped for that batch. This eliminates the computational graph explosion because the cluster provides a natural boundary.

Tradeoff: METIS requires a graph partitioning precomputation step that is $O(|E|)$ but has a large constant. Also, the partition structure must align with the graph's natural community structure for best performance. For citation networks and social graphs with strong communities, Cluster-GCN matches or exceeds GraphSAGE accuracy with lower computation. For irregular graphs (random geometric graphs, power grids), partitions are artificial and the between-cluster edge dropout hurts performance.
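The batching step of Cluster-GCN can be sketched in plain Python. The partition below is assigned by hand for illustration (a real pipeline would obtain it from METIS), and `cluster_batch` is a hypothetical helper name:

```python
def cluster_batch(edges, cluster_of, cluster_id):
    """Build one training batch: the cluster's nodes, its within-cluster edges,
    and a count of the between-cluster edges dropped for this batch."""
    nodes = sorted(n for n, c in cluster_of.items() if c == cluster_id)
    kept = [(u, v) for u, v in edges
            if cluster_of[u] == cluster_id and cluster_of[v] == cluster_id]
    dropped = sum(1 for u, v in edges
                  if (cluster_of[u] == cluster_id) != (cluster_of[v] == cluster_id))
    return nodes, kept, dropped

# Two triangles joined by a single bridge edge (2, 3)
edges = [(0, 1), (1, 2), (2, 0), (2, 3), (3, 4), (4, 5), (5, 3)]
cluster_of = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}  # hand-assigned partition

nodes, kept, dropped = cluster_batch(edges, cluster_of, cluster_id=0)
print(nodes)    # [0, 1, 2]
print(kept)     # [(0, 1), (1, 2), (2, 0)]
print(dropped)  # 1
```

When the partition follows real community structure, as in this toy graph, few edges are dropped; on graphs without communities the dropped count grows and so does the bias.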

GraphSAINT (Zeng et al., 2020) samples subgraphs rather than per-layer neighborhoods. It builds a training subgraph by sampling nodes via random walk or edge sampling, then runs a full GNN forward/backward pass on that subgraph. Each subgraph is small enough to fit in GPU memory. Normalization factors are precomputed to ensure unbiased gradient estimates despite the sampling.

# GraphSAINT-style subgraph sampling with PyG
# Assumes `data`, `model`, and `optimizer` are already defined.
import torch.nn.functional as F
from torch_geometric.loader import GraphSAINTRandomWalkSampler

loader = GraphSAINTRandomWalkSampler(
    data,
    batch_size=6000,       # random-walk roots per subgraph
    walk_length=2,         # random walk length for subgraph construction
    num_steps=5,           # number of batches per epoch
    sample_coverage=100,   # samples for normalization factor estimation
    save_dir='data/saint_cache',
)

model.train()
for subgraph in loader:
    optimizer.zero_grad()
    out = model(subgraph.x, subgraph.edge_index)
    # Note: subgraph.node_norm and subgraph.edge_norm provide sampling corrections
    loss = F.cross_entropy(
        out[subgraph.train_mask],
        subgraph.y[subgraph.train_mask],
    )
    # subgraph.node_norm can be used as sample_weight for unbiased estimation
    loss.backward()
    optimizer.step()

GraphSAINT produces lower variance gradient estimates than NeighborLoader (because the entire subgraph is used for forward/backward), but requires precomputing normalization factors - a one-time cost that may be expensive on very large graphs.


GraphSAGE for Heterogeneous Graphs

Real-world graphs are often heterogeneous: multiple types of nodes and edges. An e-commerce graph may have user nodes, product nodes, category nodes, and edges like "purchased", "viewed", "belongs_to". Standard GraphSAGE treats all nodes and edges as the same type.

PyTorch Geometric handles this with HeteroData and to_hetero():

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData
from torch_geometric.nn import to_hetero, SAGEConv

# Build heterogeneous graph
# (feature and edge-index tensors are assumed to be defined elsewhere)
data = HeteroData()
data['user'].x = user_features        # (num_users, user_feat_dim)
data['product'].x = product_features  # (num_products, prod_feat_dim)
data['category'].x = category_features  # assumed name; every node type needs features
data['user', 'buys', 'product'].edge_index = purchase_edges
data['user', 'views', 'product'].edge_index = view_edges
data['product', 'belongs_to', 'category'].edge_index = category_edges

# Add reverse edges so source-only node types (here 'user') also receive messages
data = T.ToUndirected()(data)

# Convert homogeneous SAGEConv model to heterogeneous
class HomoSAGE(torch.nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden)  # lazy init for hetero
        self.conv2 = SAGEConv((-1, -1), hidden)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

homo_model = HomoSAGE(hidden=64)
# to_hetero wraps each layer with type-specific parameters
model = to_hetero(homo_model, data.metadata(), aggr='sum')

# Forward: produces embeddings for each node type
out_dict = model(data.x_dict, data.edge_index_dict)
user_embeddings = out_dict['user']        # (num_users, 64)
product_embeddings = out_dict['product']  # (num_products, 64)

The to_hetero() transformation creates separate weight matrices for each edge type, allowing the aggregation function to specialize per relationship. This is similar in spirit to Relational GCN (R-GCN), which also learns relation-specific weights, but with GraphSAGE's inductive capability.


Production Deployment: Embedding Serving Architecture

For systems like PinSage, the gap between research prototype and production deployment is primarily engineering. The model training pipeline and the serving pipeline are completely separate.

Two serving paths:

  1. Pre-computed (existing nodes): embeddings for all known nodes are generated offline via MapReduce (two passes over the graph: layer-1 then layer-2 aggregations). Stored in a KV store keyed by node ID. At query time, a single KV lookup retrieves the embedding in under 1ms.
  2. Online (new nodes): when a new pin is uploaded, its embedding is computed in real time by fetching its neighborhood from the graph database, running the trained aggregation functions (50–100ms on CPU), and inserting the result into the KV store and ANN index.

The ANN index is refreshed weekly when the full embedding batch is recomputed. New node embeddings are inserted incrementally without rebuilding the index - most ANN libraries support online insertion with minor accuracy degradation.
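The two serving paths can be sketched with an in-memory stand-in. Everything here is a simplifying assumption for illustration: a Python dict replaces the distributed KV store, brute-force cosine search replaces the ANN index, and `toy_embed` stands in for the trained aggregation functions.

```python
import math

def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)

def toy_embed(features, neighbor_vecs):
    """Stand-in for the trained model: average own features with the neighbor mean."""
    if not neighbor_vecs:
        return list(features)
    mean_neigh = [sum(col) / len(neighbor_vecs) for col in zip(*neighbor_vecs)]
    return [(f + m) / 2 for f, m in zip(features, mean_neigh)]

class EmbeddingStore:
    def __init__(self, embed_fn):
        self._vectors = {}          # stands in for the KV store and the ANN index
        self._embed_fn = embed_fn

    def load_batch(self, embeddings):
        """Path 1: bulk-load embeddings produced by the offline job."""
        for node_id, vec in embeddings.items():
            self._vectors[node_id] = l2_normalize(vec)

    def get(self, node_id, features=None, neighbor_ids=()):
        """Look up a known node, or embed a new node online and insert it (path 2)."""
        if node_id in self._vectors:
            return self._vectors[node_id]
        if features is None:
            raise KeyError(node_id)
        neighbor_vecs = [self._vectors[n] for n in neighbor_ids if n in self._vectors]
        vec = l2_normalize(self._embed_fn(features, neighbor_vecs))
        self._vectors[node_id] = vec
        return vec

    def nearest(self, query, k=5, exclude=()):
        """Brute-force cosine search; a production system would query an ANN index."""
        q = l2_normalize(query)
        scored = sorted(
            ((sum(a * b for a, b in zip(q, vec)), node_id)
             for node_id, vec in self._vectors.items() if node_id not in exclude),
            reverse=True,
        )
        return [node_id for _, node_id in scored[:k]]

store = EmbeddingStore(toy_embed)
store.load_batch({"pin_1": [2.0, 0.0], "pin_2": [0.0, 3.0]})
new_vec = store.get("pin_new", features=[1.0, 0.0], neighbor_ids=["pin_1"])
print(new_vec)                                           # [1.0, 0.0]
print(store.nearest(new_vec, k=1, exclude={"pin_new"}))  # ['pin_1']
```

The new pin becomes searchable immediately after `get` inserts it, without touching the offline batch or retraining anything.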


Key Takeaways

GraphSAGE is the production-grade GNN. Where GCN requires the entire graph in memory and cannot embed new nodes, GraphSAGE handles million-node graphs on commodity hardware and embeds unseen nodes in milliseconds using learned aggregation functions.

The core design choices: (1) learn aggregation functions, not node embeddings - this is what makes it inductive; (2) concatenate self and neighbor before transformation - preserves the distinction; (3) L2 normalize final embeddings - enables cosine similarity and ANN retrieval; (4) sample neighbors at each layer - keeps computation bounded regardless of graph size.

Default setup: 2 layers, neighbor samples $[25, 10]$, mean aggregation, NeighborLoader, Adam with $lr=0.01$. Add max-pooling aggregator if a single highly-relevant neighbor is more important than the average. Add importance-based sampling (PinSage style) for power-law degree distributions common in social and recommendation graphs. The unsupervised loss (random walk negative sampling) is your friend when labels are scarce - GraphSAGE embeddings trained unsupervised are competitive with GCN trained with labels on many downstream tasks.


GraphSAGE vs Cluster-GCN vs GraphSAINT: Choosing a Mini-Batch Strategy

All three methods solve the same problem - scalable GNN training - using different sampling strategies. The choice depends on graph structure and computational constraints.

| Method | Sampling Unit | Memory | Gradient Bias | Best For |
|---|---|---|---|---|
| GraphSAGE (NeighborLoader) | Per-node neighborhood | $O(S^K \cdot B)$ | Unbiased (node) | Power-law degree, irregular graphs |
| Cluster-GCN | Graph partition | $O(\text{cluster size})$ | Biased (drops between-cluster edges) | Graphs with strong community structure |
| GraphSAINT | Random walk subgraph | $O(\text{subgraph size})$ | Unbiased (with normalization) | Dense graphs, full-message accuracy needed |

The rule of thumb: start with NeighborLoader (GraphSAGE). It is the easiest to set up, has strong PyG support, and works on any graph. If you have community structure (modularity > 0.3) and can afford the METIS precomputation, try Cluster-GCN - it often achieves slightly higher accuracy by preserving within-cluster message passing. If gradient noise from NeighborLoader is hurting convergence on a dense graph, switch to GraphSAINT.


Temporal Graphs: Extending GraphSAGE to Dynamic Settings

Production graphs are not static. New nodes and edges arrive continuously - new users sign up, new products are listed, new transactions occur. GraphSAGE handles new nodes naturally (compute from neighborhood), but the neighborhood itself changes over time.

Temporal GraphSAGE (a practical extension, not a formal paper) addresses this by treating time as an edge feature or by restricting neighbor sampling to temporally proximal edges:

import torch
from torch_geometric.data import TemporalData

# Temporal edge data: each edge has a timestamp
temporal_data = TemporalData(
    src=torch.tensor([0, 1, 2, 0, 3]),
    dst=torch.tensor([1, 2, 3, 3, 4]),
    t=torch.tensor([100, 110, 120, 130, 140]),  # Unix timestamps
    msg=torch.randn(5, 16),  # optional edge features (e.g., transaction amount)
)

# When computing embedding for node v at time T,
# only sample neighbors with edge timestamp < T.
# This prevents future information leakage (a common bug in temporal GNNs).
def sample_temporal_neighbors(node_id, current_time, edge_index,
                              edge_timestamps, num_samples=25):
    """Sample neighbors using only edges that occurred before current_time."""
    dst_mask = (edge_index[1] == node_id) & (edge_timestamps < current_time)
    valid_srcs = edge_index[0][dst_mask]
    if len(valid_srcs) == 0:
        return torch.tensor([], dtype=torch.long)
    idx = torch.randperm(len(valid_srcs))[:num_samples]
    return valid_srcs[idx]

For full temporal GNN support, PyG provides torch_geometric.nn.models.TGNMemory, the memory module of the Temporal Graph Network (TGN, Rossi et al., 2020). TGN combines neighborhood aggregation with this memory module, which stores each node's interaction history. TGN is a strong baseline for temporal link prediction (predicting which user will interact with which product in the next hour).

The critical rule for all temporal GNNs: never allow future edges to influence past predictions. Leakage through temporal neighbors inflates test metrics to near-perfect levels - a common mistake that produces models that fail completely in production.
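One simple way to enforce this rule at the dataset level is a strict chronological split. The sketch below is stdlib-only and uses the same toy timestamps as the example above; `chronological_split` is a hypothetical helper, not a PyG function.

```python
def chronological_split(edges, train_frac=0.8):
    """Split (src, dst, t) edges by time so every train edge strictly precedes
    every test edge. Edges that tie with the cutoff timestamp go to test."""
    if not edges:
        return [], []
    ordered = sorted(edges, key=lambda e: e[2])
    cut = min(int(len(ordered) * train_frac), len(ordered) - 1)
    cutoff = ordered[cut][2]
    train = [e for e in ordered if e[2] < cutoff]
    test = [e for e in ordered if e[2] >= cutoff]
    return train, test

edges = [(0, 1, 100), (1, 2, 110), (2, 3, 120), (0, 3, 130), (3, 4, 140)]
train, test = chronological_split(edges, train_frac=0.8)
print(train)  # [(0, 1, 100), (1, 2, 110), (2, 3, 120), (0, 3, 130)]
print(test)   # [(3, 4, 140)]

# Leakage check: no training edge may be as late as the earliest test edge.
assert max(t for _, _, t in train) < min(t for _, _, t in test)
```

Splitting on a timestamp cutoff rather than on a row index keeps edges with identical timestamps on the same side of the boundary.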


Advanced Interview Q&A

Q7: How would you detect and handle the over-smoothing problem in a 5-layer GraphSAGE model?

Over-smoothing - where deeper layers produce indistinguishable embeddings for all nodes - is less severe in GraphSAGE than in GCN because of two design choices: (1) concatenation of self-embedding with aggregated neighbor embedding preserves the node's identity at each layer; (2) L2 normalization after each layer prevents embedding collapse. Nevertheless, with 5 layers, the receptive field covers 25-to-10-to-10-to-10-to-10 sampled neighborhoods - a very large fraction of the graph. To detect: measure the cosine similarity between pairs of embeddings at each layer (a random sample of pairs is enough on large graphs). If layer-5 embeddings have average cosine similarity > 0.95, you have over-smoothing. To handle: (a) reduce to 2–3 layers; (b) add jumping knowledge connections (concatenate all layer outputs before the final prediction head, allowing each node to choose the most useful depth); (c) add a skip connection from the input features to the output layer; (d) increase dropout to reduce effective depth. The MAD (Mean Average Distance) metric: compute $\text{MAD} = \frac{1}{|V|}\sum_{v} \frac{1}{|\mathcal{N}(v)|}\sum_{u \in \mathcal{N}(v)} d(h_v, h_u)$ for each layer, where $d$ is a distance between embeddings such as cosine distance. If MAD decreases monotonically toward zero, the model is over-smoothing.
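A stdlib-only sketch of the MAD computation, assuming cosine distance for $d$ and averaging over nodes that have at least one neighbor. The function names and the toy embeddings are illustrative assumptions:

```python
import math

def cosine_distance(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return 1.0 - dot / (norm_a * norm_b)

def neighbor_mad(embeddings, adj):
    """Mean over nodes of the mean cosine distance to each node's neighbors."""
    per_node = []
    for v, neighbors in adj.items():
        if not neighbors:
            continue  # isolated nodes are skipped in this sketch
        dists = [cosine_distance(embeddings[v], embeddings[u]) for u in neighbors]
        per_node.append(sum(dists) / len(dists))
    return sum(per_node) / len(per_node) if per_node else 0.0

adj = {0: [1], 1: [0]}
shallow = {0: [1.0, 0.0], 1: [0.0, 1.0]}  # neighbors still distinguishable
deep = {0: [1.0, 1.0], 1: [2.0, 2.0]}     # same direction: collapsed

print(neighbor_mad(shallow, adj))  # 1.0
print(neighbor_mad(deep, adj))     # close to 0, the over-smoothing signature
```

In practice this would be evaluated on the output of each layer for a fixed sample of nodes, and the per-layer values compared.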

Q8: Explain how GraphSAGE's unsupervised loss enables transfer learning across graphs.

The unsupervised random walk loss $\mathcal{L} = -\log\sigma(z_u^T z_v) - Q\,\mathbb{E}_{v_n}[\log\sigma(-z_u^T z_{v_n})]$ trains the model to produce high dot products for node pairs that co-occur on short random walks (structurally proximate) and low dot products for random pairs (negative samples). This objective does not depend on node labels - it only requires graph structure and features. As a result, the trained aggregation functions learn to encode structural and feature proximity into embedding space, independent of any downstream task. For transfer learning: train GraphSAGE on a large unlabeled graph (e.g., the full Twitter graph), then fine-tune the final linear layer on a smaller labeled subgraph. The aggregation functions (conv layers) transfer well because they encode general graph structure principles. This is analogous to BERT pretraining on unlabeled text and fine-tuning on labeled classification data. For molecular graphs, GraphSAGE pretrained on large unlabeled molecular databases can outperform supervised-only training when labeled molecules are scarce.
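The loss for a single positive pair can be written out in a few lines of plain Python. This sketch estimates the expectation with the mean over the drawn negatives and defaults $Q$ to the number of negatives; the function names are illustrative assumptions:

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def unsupervised_loss(z_u, z_v, negatives, Q=None):
    """-log sigma(z_u . z_v) - Q * mean over negatives of log sigma(-z_u . z_n)."""
    Q = len(negatives) if Q is None else Q
    positive_term = -log_sigmoid(dot(z_u, z_v))
    negative_term = -Q * sum(log_sigmoid(-dot(z_u, z_n)) for z_n in negatives) / len(negatives)
    return positive_term + negative_term

z_u = [1.0, 0.0]
negatives = [[0.0, 1.0], [0.0, -1.0]]

aligned = unsupervised_loss(z_u, [1.0, 0.0], negatives)     # walk co-occurring pair, similar
orthogonal = unsupervised_loss(z_u, [0.0, 1.0], negatives)  # walk co-occurring pair, dissimilar
print(aligned < orthogonal)  # True: pulling the positive pair together lowers the loss
```

No label appears anywhere in the computation, which is why the same objective can be used for pretraining on any graph with features.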


GraphSAGE Debugging and Performance Tuning

Production GraphSAGE deployments fail in predictable ways. This systematic checklist addresses the most common failure modes.

| Symptom | Likely Cause | Fix |
|---|---|---|
| Low accuracy on new nodes | Inductive generalization failure | Check if new nodes have OOD features; retrain with more diverse data |
| Slow NeighborLoader | Too many neighbors per layer | Reduce from [25,10] to [15,5]; increase batch size to compensate |
| OOM during mini-batch training | Neighborhood too large | Reduce num_neighbors or batch_size |
| Retrieval quality poor | Embeddings not L2 normalized | Add F.normalize(out, p=2, dim=-1) before indexing |
| Unsupervised loss not converging | Too few negative samples | Increase Q from 5 to 20; use harder negatives |
| Accuracy drops for high-degree nodes | Sampling underrepresents neighborhood | Use importance sampling (PinSage) weighted by visit count |

import time
from torch_geometric.loader import NeighborLoader

def profile_neighborloader(data, num_neighbors, batch_size=512):
    """
    Profile NeighborLoader to identify bottlenecks before production deployment.
    Reports: batch time, avg nodes per batch, avg edges per batch.
    """
    loader = NeighborLoader(
        data,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        input_nodes=data.train_mask,
        shuffle=True,
        num_workers=4,
    )

    times, node_counts, edge_counts = [], [], []
    t0 = time.perf_counter()
    for i, batch in enumerate(loader):
        if i >= 10:
            break
        # Simulate forward pass cost
        _ = batch.x.mean(dim=0)
        # Elapsed time covers sampling + loading + the simulated forward pass
        times.append(time.perf_counter() - t0)
        node_counts.append(batch.num_nodes)
        edge_counts.append(batch.num_edges)
        t0 = time.perf_counter()

    print(f"Avg batch time: {sum(times)/len(times)*1000:.1f}ms")
    print(f"Avg nodes per batch: {sum(node_counts)/len(node_counts):.0f}")
    print(f"Avg edges per batch: {sum(edge_counts)/len(edge_counts):.0f}")
    print(f"Nodes/target ratio: {sum(node_counts)/len(node_counts)/batch_size:.1f}x overhead")

# The nodes/target ratio shows how many context nodes are loaded per target node.
# For [25,10] with 2 layers: expect ~80-120x overhead (upper bound 1 + 25 + 250 = 276x).
# If overhead is >200x, reduce num_neighbors or batch_size.

GraphSAGE for Fraud Detection: A Production Walkthrough

Fraud detection is among the most commercially valuable applications of GraphSAGE outside of recommendation. The graph: accounts are nodes, transactions are edges. Features: account age, historical transaction volume, IP geolocation entropy, device fingerprint. Labels: fraud / not-fraud (highly imbalanced: ~0.1% fraud).

The inductive capability is critical: fraudsters create new accounts continuously. GCN cannot embed new accounts; GraphSAGE embeds them from their initial transactions within seconds of their first activity.

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
from sklearn.metrics import roc_auc_score
import numpy as np

class FraudSAGE(torch.nn.Module):
    """
    GraphSAGE for transaction fraud detection.
    Uses 3 layers to capture 3-hop behavioral patterns (account → shared device →
    other accounts on same device → those accounts' transaction histories).
    """
    def __init__(self, in_channels, hidden=128, dropout=0.3):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        self.conv3 = SAGEConv(hidden, hidden // 2)
        self.head = torch.nn.Linear(hidden // 2, 1)
        self.dropout = dropout

    def forward(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.conv2(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.conv3(h, edge_index))
        return self.head(h).squeeze(-1)  # one logit per node

# Class imbalance: use weighted loss
# Fraud rate ~0.1% → pos_weight = 999 (ratio of negative to positive)
pos_weight = torch.tensor([999.0])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# `data` is assumed to be a PyG Data object with x, y, edge_index, and train_mask
model = FraudSAGE(in_channels=data.num_features)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

loader = NeighborLoader(
    data,
    num_neighbors=[15, 10, 5],  # 3 layers with decreasing samples
    batch_size=1024,
    input_nodes=data.train_mask,
)

for epoch in range(50):
    model.train()
    for batch in loader:
        out = model(batch.x, batch.edge_index)
        # Only the first batch.batch_size nodes are seed (target) nodes
        loss = criterion(
            out[:batch.batch_size],
            batch.y[:batch.batch_size].float()
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Embed new account in real time (inductive inference)
def embed_new_account(model, new_node_features, neighbor_features, neighbor_edge_index):
    """
    Compute fraud score for a brand-new account with no history.
    neighbor_features: features of the new account's known neighbors (shared devices, etc.)
    neighbor_edge_index: edges connecting new node (index 0) to its neighbors.
        Messages flow source → target, so each edge must point from a neighbor
        (row 0) to the new node (row 1 == 0) for the new node to aggregate them.
    """
    model.eval()
    x = torch.cat([new_node_features.unsqueeze(0), neighbor_features], dim=0)
    with torch.no_grad():
        logit = model(x, neighbor_edge_index)
    fraud_prob = torch.sigmoid(logit[0]).item()
    return fraud_prob  # risk score for real-time decision

Real-world results from GraphSAGE fraud systems: 15–25% lift in AUC over feature-only logistic regression, 30–40% reduction in false negatives on new account fraud (the hardest case), and real-time embedding latency under 50ms for accounts with up to 500 neighbors. The 3-hop receptive field captures "money mule" patterns - chains of accounts that each transfer funds one hop at a time to obscure the trail.


© 2026 EngineersOfAI. All rights reserved.