Message Passing Neural Networks
The Real Interview Moment
"We have GCN, GraphSAGE, GAT, Interaction Networks, and five other GNN variants on our shortlist. Our team cannot agree which to use for molecular property prediction. How do you even compare them systematically?"
This exact tension was felt at Google Brain and DeepMind in 2017. The field had produced a proliferating zoo of graph neural network architectures, each with its own paper, its own notation, and its own claimed advantage. Some operated on node features, others on edge features. Some used attention, others simple summation. Some claimed better expressivity, others better scalability. There was no principled basis for comparison - just empirical ablations on specific datasets that rarely generalized.
Justin Gilmer and colleagues sat down to read every major GNN paper carefully. A pattern emerged. Strip away the notation differences, and almost every architecture was doing the same three operations: compute messages from neighbors, aggregate those messages, and update the node's hidden state. The differences were only in how each step was implemented. The team wrote a general template - the Message Passing Neural Network (MPNN) framework - and showed that Interaction Networks, DTNN, gated graph neural networks, neural molecular fingerprints, and spectral methods such as GCN were all special cases. Architectures published afterwards, including GraphSAGE and GAT, fit the same template.
This unification was more than taxonomic convenience. It allowed the team to ask the right question: given this framework, what is the fundamental limit of what any MPNN can express? The answer came from a surprising direction - a 1968 graph theory algorithm called the Weisfeiler-Lehman (WL) isomorphism test. Any MPNN is at most as powerful as 1-WL. This ceiling explained why so many GNNs failed on certain graph classification tasks and launched a new research agenda: how do you build GNNs provably more expressive than 1-WL?
Understanding MPNN is not optional for any engineer working on graph-structured data. It is the lingua franca of the field. Every time you read a GNN paper and see "message function," "aggregation," and "update function" in the methods section, you are seeing the MPNN vocabulary. When you debug a GNN that cannot distinguish two structurally different molecules, you are hitting the 1-WL ceiling. When you read about SE(3)-equivariant networks for protein structure prediction, you are seeing the response to that ceiling.
Why This Exists - The GNN Zoo Problem
By 2017, the GNN literature had become fragmented. Consider just a few architectures that existed:
- Spectral GCN (Bruna et al., 2014) - graph convolution in the Fourier domain
- ChebNet (Defferrard et al., 2016) - polynomial filters on the graph Laplacian
- GCN (Kipf and Welling, 2017) - simplified spectral convolution, layer-wise propagation
- GraphSAGE (Hamilton et al., 2017) - sample-and-aggregate, inductive learning
- GAT (Velickovic et al., 2018) - attention-weighted aggregation
- Interaction Networks (Battaglia et al., 2016) - physics simulation on graphs
- DTNN (Schütt et al., 2017) - deep tensor neural networks for quantum chemistry
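- Neural fingerprints (Duvenaud et al., 2015) - a differentiable, learned analogue of circular fingerprints
- Molecular Graph Convolutions (Kearnes et al., 2016) - atom and atom-pair features updated jointly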
Each was described in its own notation. Comparing them required careful reading of multiple papers. Practitioners could not easily decide which to use for a new task because the assumptions and tradeoffs were buried in architectural choices that looked superficially different.
The MPNN framework provides a common language. By formalizing the three-step structure - message, aggregate, update - it becomes straightforward to compare architectures, combine ideas across papers, and reason precisely about what cannot be achieved within the framework.
Historical Context
The paper Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017, ICML) was motivated by a concrete application: predicting quantum chemical properties of small organic molecules from the QM9 dataset. QM9 contains 134K molecules with 13 quantum properties computed by density functional theory (DFT) - atomization energy, dipole moment, HOMO-LUMO gap, polarizability, and more. DFT calculations take hours per molecule. If a GNN could predict these properties accurately from molecular graph structure alone, it would accelerate drug discovery and materials science enormously.
To build the best model, the team needed to understand the design space of all possible GNNs for this task. That forced them to think systematically. The result was the MPNN framework, and the MPNN model they built set a new state-of-the-art on QM9, predicting 11 of 13 properties within chemical accuracy.
The deeper contribution - recognizing the connection to the Weisfeiler-Lehman test - came in How Powerful are Graph Neural Networks? (Xu et al., 2019, ICLR), which proved formally that any MPNN is at most as powerful as 1-WL and introduced Graph Isomorphism Networks (GIN) as the maximally expressive MPNN.
The MPNN Framework - Formal Definition
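An MPNN operates on a graph $G = (V, E)$ where each node $v$ has a feature vector $x_v$ and each edge $(v, w)$ has a feature vector $e_{vw}$. The hidden state of node $v$ at step $t$ is written $h_v^{(t)}$, initialized from $x_v$.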
The framework has two phases.
Phase 1: Message Passing
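Runs for $T$ steps. At each step $t$, every node sends and receives messages via three functions:
Message function - computes a message from neighbor $w$ to node $v$:
$$m_{vw}^{(t)} = M_t\left(h_v^{(t)}, h_w^{(t)}, e_{vw}\right)$$
Aggregation function - combines all incoming messages for node $v$:
$$m_v^{(t+1)} = \sum_{w \in N(v)} m_{vw}^{(t)}$$
Gilmer et al. use a sum; more generally any permutation-invariant operator such as mean or max can take its place.
Update function - updates the node's hidden state:
$$h_v^{(t+1)} = U_t\left(h_v^{(t)}, m_v^{(t+1)}\right)$$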
Phase 2: Readout
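After $T$ steps, a readout function $R$ produces a graph-level representation for graph-level tasks:
$$\hat{y} = R\left(\left\{ h_v^{(T)} \mid v \in G \right\}\right)$$
$R$ must be invariant to node ordering. For node-level tasks the final $h_v^{(T)}$ are used directly. For edge-level tasks, edge representations are computed from endpoint representations.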
:::note What "message passing" physically means
Information flows along edges. In step 1, each node generates messages to send to its neighbors. In step 2, each node $v$ collects all incoming messages and combines them. In step 3, node $v$ updates its own state. After $T$ steps, each node has aggregated information from its full $T$-hop neighborhood.
:::
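The two phases can be sketched in a few lines of NumPy. This is a minimal illustration, not the Gilmer et al. model: the single-layer `message` and `update` functions, their random weights, and the three-node path graph are all assumptions chosen to keep the example small.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: a path 0 - 1 - 2, stored as directed edges (source w, target v).
edges = np.array([[0, 1], [1, 0], [1, 2], [2, 1]])
num_nodes, d, d_e = 3, 4, 2

h = rng.normal(size=(num_nodes, d))       # hidden states h_v^(t)
e = rng.normal(size=(len(edges), d_e))    # edge features e_vw

# Hypothetical single-layer M_t and U_t; any learnable functions would do.
W_msg = rng.normal(size=(2 * d + d_e, d))
W_upd = rng.normal(size=(2 * d, d))

def message(h_v, h_w, e_vw):
    return np.tanh(np.concatenate([h_v, h_w, e_vw], axis=-1) @ W_msg)

def update(h_v, m_v):
    return np.tanh(np.concatenate([h_v, m_v], axis=-1) @ W_upd)

def mpnn_step(h, edges, e):
    src, dst = edges[:, 0], edges[:, 1]
    msgs = message(h[dst], h[src], e)     # one message per directed edge
    m = np.zeros_like(h)
    np.add.at(m, dst, msgs)               # sum aggregation at each target node
    return update(h, m)

h_next = mpnn_step(h, edges, e)           # Phase 1, one step
graph_repr = h_next.sum(axis=0)           # Phase 2, simplest invariant readout
```

Because aggregation and readout are both sums, relabeling the nodes permutes the rows of `h_next` but leaves `graph_repr` unchanged, which is the property every readout function has to preserve.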
GCN as a Special Case of MPNN
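Recall the GCN update rule:
$$h_v^{(t+1)} = \sigma\left(\sum_{w \in N(v) \cup \{v\}} \frac{1}{\sqrt{\tilde{d}_v \tilde{d}_w}} W^{(t)} h_w^{(t)}\right)$$
where $\tilde{d}_v$ is the degree of $v$ after adding a self-loop.
Mapping to MPNN:
Message function:
$$M_t\left(h_v^{(t)}, h_w^{(t)}, e_{vw}\right) = \frac{1}{\sqrt{\tilde{d}_v \tilde{d}_w}} W^{(t)} h_w^{(t)}$$
The message from $w$ to $v$ is the neighbor's feature multiplied by weight matrix $W^{(t)}$ and scaled by the symmetric normalization factor. GCN does not use edge features.
Aggregation:
$$m_v^{(t+1)} = \sum_{w \in N(v) \cup \{v\}} m_{vw}^{(t)}$$
Simple summation. The normalization is baked into the message function.
Update function:
$$h_v^{(t+1)} = \sigma\left(m_v^{(t+1)}\right)$$
The self-loop ($w = v$) is included in the aggregation, so the update applies only the activation.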
Readout: mean or sum over all final node embeddings.
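A quick numerical check of this mapping: the sketch below computes one GCN layer twice, once in the usual matrix form and once as explicit message, sum, and update steps, and confirms the two agree. The four-node graph, the random features, and the row-vector convention (`h @ W`) are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy undirected graph on 4 nodes.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))     # node states, one row per node
W = rng.normal(size=(3, 2))     # layer weight (row-vector convention: h @ W)

A_tilde = A + np.eye(4)         # add self-loops
deg = A_tilde.sum(axis=1)       # degrees including the self-loop

def relu(z):
    return np.maximum(z, 0.0)

# Matrix form: sigma(D^-1/2 (A + I) D^-1/2 H W)
D_inv_sqrt = np.diag(deg ** -0.5)
H_matrix = relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)

# MPNN form: message -> sum -> update, one node at a time.
H_mpnn = np.zeros((4, 2))
for v in range(4):
    m_v = np.zeros(2)
    for w in np.nonzero(A_tilde[v])[0]:                 # neighbors plus v itself
        m_v += (H[w] @ W) / np.sqrt(deg[v] * deg[w])    # message, then sum
    H_mpnn[v] = relu(m_v)                               # update = activation only

print(np.allclose(H_matrix, H_mpnn))
```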
GraphSAGE as a Special Case of MPNN
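GraphSAGE uses a concatenation-based update:
$$h_v^{(t+1)} = \sigma\left(W^{(t)} \left[h_v^{(t)} \,\Vert\, \mathrm{AGG}\left(\left\{h_w^{(t)} : w \in N(v)\right\}\right)\right]\right)$$
Message function:
$$m_{vw}^{(t)} = h_w^{(t)}$$
The message is simply the neighbor's current hidden state - no edge features, no weighting.
Aggregation:
$$m_v^{(t+1)} = \mathrm{AGG}\left(\left\{m_{vw}^{(t)} : w \in N(v)\right\}\right), \qquad \mathrm{AGG} \in \{\mathrm{MEAN}, \mathrm{MAX}, \mathrm{LSTM}\}$$
GraphSAGE proposes three aggregators. MEAN = element-wise mean of neighbor features. MAX = element-wise max-pooling. LSTM = process a random permutation of neighbors through an LSTM (order-dependent in theory, works in practice).
Update function:
$$h_v^{(t+1)} = \sigma\left(W^{(t)} \left[h_v^{(t)} \,\Vert\, m_v^{(t+1)}\right]\right)$$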
Concatenate the node's own features with the aggregated message, then apply a linear transform and activation. This explicit concatenation of self and neighborhood is GraphSAGE's defining characteristic.
GAT as a Special Case of MPNN
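Graph Attention Networks use learned attention weights:
$$h_v^{(t+1)} = \sigma\left(\sum_{w \in N(v)} \alpha_{vw} W h_w^{(t)}\right)$$
where $\alpha_{vw} = \mathrm{softmax}_w\left(\mathrm{LeakyReLU}\left(a^\top \left[W h_v^{(t)} \,\Vert\, W h_w^{(t)}\right]\right)\right)$.
Message function:
$$m_{vw}^{(t)} = \alpha_{vw} W h_w^{(t)}$$
The coefficient $\alpha_{vw}$ depends on both endpoint states, making this a data-dependent rather than topology-fixed scaling. (The softmax normalizes over all of $v$'s neighbors, so strictly $\alpha_{vw}$ is computed per neighborhood rather than per edge.)
Aggregation: $m_v^{(t+1)} = \sum_{w \in N(v)} m_{vw}^{(t)}$ - weighted sum with attention weights already incorporated in messages.
Update: $h_v^{(t+1)} = \sigma\left(m_v^{(t+1)}\right)$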
Key insight in MPNN terms: GCN vs GAT differ only in the message function. GCN uses fixed normalization from graph structure; GAT uses learned data-dependent weights from node features.
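The Unification Diagram
[Diagram: GCN, GraphSAGE, GAT, and the Gilmer et al. MPNN shown as instances of the same message, aggregate, update template.]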
Edge Features in MPNN
One area where MPNN shines over standard GCN and GAT is edge feature support. Molecular graphs have rich edge features - bond type (single/double/triple/aromatic), bond length, bond angle. Ignoring these discards crucial chemical information.
In the MPNN framework, edge features are first-class citizens of the message function:
$$m_{vw}^{(t)} = M_t\left(h_v^{(t)}, h_w^{(t)}, e_{vw}\right)$$
Gilmer et al. used a neural network that produces a matrix from edge features:
$$M\left(h_v, h_w, e_{vw}\right) = A(e_{vw}) \, h_w$$
where the $d \times d$ matrix $A(e_{vw})$ is produced by a small MLP applied to the edge feature vector. The edge feature gates information flow along that edge.
For directional molecular graphs this becomes directional message passing - the message from $w$ to $v$ differs from the message from $v$ to $w$, parameterized by shared edge features but different endpoint states.
MPNN for Molecular Property Prediction - QM9
The QM9 dataset is the canonical benchmark. Each molecule is a graph:
- Nodes = atoms (features: atom type, formal charge, number of hydrogens, aromaticity, hybridization)
- Edges = bonds (features: bond type - single/double/triple/aromatic)
- Labels = 13 quantum chemical properties (DFT-computed)
The original MPNN paper used a GRU-based update (treating the $T$ steps as a recurrent computation over the graph) and a "set2set" attention readout. This achieved chemical accuracy on 11 of 13 QM9 properties.
The Expressivity Ceiling - MPNN and 1-WL
The Weisfeiler-Lehman Test
The 1-dimensional Weisfeiler-Lehman (1-WL) graph isomorphism algorithm:
- Initialize each node with a label (e.g., its degree or atom type)
- At each iteration: update each node's label by hashing its current label together with the multiset of its neighbors' labels
- Repeat until labels stabilize
- If the multisets of final labels for two graphs differ, they are non-isomorphic
The 1-WL test is powerful but incomplete - there exist non-isomorphic graph pairs that 1-WL assigns identical label distributions to.
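The refinement loop above can be sketched in plain Python. This is an illustrative version under a few assumptions: graphs are adjacency dictionaries, every node starts with the same label, a fixed number of rounds replaces the "until stable" check, and a shared palette dictionary plays the role of the hash so that labels are comparable across graphs. The helper names are hypothetical.

```python
from collections import Counter

def wl_color_histograms(graphs, rounds=3):
    """Run 1-WL colour refinement jointly on several graphs.

    graphs: list of adjacency dicts {node: [neighbours]}.
    Returns one Counter of final colours per graph.
    """
    # All nodes start with the same colour (no node features).
    colors = [{v: 0 for v in g} for g in graphs]
    for _ in range(rounds):
        palette = {}          # shared across graphs so colours are comparable
        new_colors = []
        for g, col in zip(graphs, colors):
            new = {}
            for v in g:
                # Current colour plus the sorted multiset of neighbour colours.
                signature = (col[v], tuple(sorted(col[w] for w in g[v])))
                new[v] = palette.setdefault(signature, len(palette))
            new_colors.append(new)
        colors = new_colors
    return [Counter(col.values()) for col in colors]

def cycle(n, offset=0):
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n]
            for i in range(n)}

hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

h_hex, h_tri, h_path = wl_color_histograms([hexagon, two_triangles, path])
print(h_hex == h_tri)    # True: 1-WL cannot separate these two graphs
print(h_hex == h_path)   # False: the path's endpoints get a different colour
```

The hexagon and the pair of triangles are non-isomorphic, yet every node in both sees "one colour, two neighbours of that colour" at every round, so the histograms never diverge.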
The MPNN-WL Connection
Xu et al. (2019) proved:
Theorem (informal): No MPNN can distinguish two graphs that 1-WL cannot. An MPNN whose aggregation, update, and readout functions are all injective (for example sum aggregation followed by an MLP, as in GIN) is exactly as expressive as 1-WL.
Why? Because MPNN update rules are structurally identical to 1-WL label refinement. Each node's new state is a function of its current state and the multiset of its neighbors' states - precisely what 1-WL does. If 1-WL cannot distinguish two graphs (assigns them identical label multisets at every round), then no MPNN can, because the computation is isomorphic.
A Concrete Failure Case
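The canonical 1-WL failure involves strongly regular graphs, such as the 4x4 rook's graph (16 nodes in a grid, each connected to all others in the same row and column) versus the Shrikhande graph. Both have 16 nodes, all with degree 6. A simpler example for intuition:
Consider the 6-cycle versus two disjoint triangles. Both have 6 nodes, all with degree 2, and 1-WL assigns identical labels to all nodes in both at every round. The same holds for any two regular graphs with the same number of nodes and the same degree when all nodes start with the same label, for example the Petersen graph and the pentagonal prism (both 10 nodes, all with degree 3).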
For molecular chemistry: two non-isomorphic molecules with the same local atom environment at every neighborhood radius cannot be distinguished. An MPNN trained on atom type + bond type features would produce identical fingerprints for both and assign the same predicted properties - even if those properties differ experimentally.
Beyond 1-WL - Breaking the Ceiling
k-GNNs (Morris et al., 2019)
k-GNNs operate on $k$-tuples of nodes instead of individual nodes. A $k$-GNN is exactly as powerful as the $k$-WL test:
| Level | Power | Cost |
|---|---|---|
| 1-GNN (standard MPNN) | 1-WL | $O(n)$ memory, $O(\lvert E \rvert)$ compute |
| 2-GNN | 2-WL - counts triangles, 2-paths | $O(n^2)$ memory |
| 3-GNN | 3-WL - distinguishes all strongly regular graphs | $O(n^3)$ memory |
For a graph with $n$ nodes, a 3-GNN needs $n^3$ tuple representations. This quickly becomes impractical on real molecular datasets.
Structural Encodings (Practical Alternative)
Augment node features with precomputed structural information:
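- Random walk encodings: append the diagonal of $(D^{-1}A)^k$ (return probability in $k$ steps) - encodes triangle counts and cycle structure
- Laplacian eigenvectors: use spectral information as positional encodings (eigenvectors are defined only up to sign and basis choice, so they are usually combined with random sign flipping or a sign-invariant encoder)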
- Ring size features: smallest ring containing each atom - crucial for aromatic chemistry
- Distance encodings: shortest path distances between all node pairs
These break the 1-WL ceiling without exponential cost. The tradeoff: expensive preprocessing and problem-specific design.
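As a concrete case, random-walk return probabilities separate the 6-cycle from two disjoint triangles, a pair that 1-WL cannot tell apart. The sketch below is a minimal NumPy version; the function names are hypothetical and it assumes a graph with no isolated nodes (otherwise the row normalization divides by zero).

```python
import numpy as np

def random_walk_encoding(A, k_max):
    """Return an [n, k_max] array whose column k-1 is diag((D^-1 A)^k)."""
    P = A / A.sum(axis=1, keepdims=True)   # row-stochastic transition matrix
    P_k = np.eye(len(A))
    cols = []
    for _ in range(k_max):
        P_k = P_k @ P
        cols.append(np.diag(P_k))          # probability of returning in k steps
    return np.stack(cols, axis=1)

def cycle_adjacency(n):
    A = np.zeros((n, n))
    for i in range(n):
        A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1.0
    return A

hexagon = cycle_adjacency(6)
two_triangles = np.block([
    [cycle_adjacency(3), np.zeros((3, 3))],
    [np.zeros((3, 3)), cycle_adjacency(3)],
])

rw_hex = random_walk_encoding(hexagon, k_max=3)
rw_tri = random_walk_encoding(two_triangles, k_max=3)
print(rw_hex[0])   # [0.  0.5 0. ]   no closed 3-step walk on a hexagon
print(rw_tri[0])   # [0.   0.5  0.25] a triangle can be traversed in 3 steps
```

Appending these columns to the node features gives an MPNN different inputs for the two graphs, so it can separate them even though its message passing is unchanged.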
Equivariant Message Passing for 3D Geometry
The 3D Molecule Problem
Molecular graphs in 2D (connectivity only) miss 3D geometry - bond angles, dihedral angles, and atomic coordinates determine molecular properties like binding affinity and reactivity. Incorporating 3D geometry requires SE(3)-equivariance: if the molecule is rotated or translated in space, predicted scalar properties must be unchanged, predicted vector properties such as forces must rotate with the molecule, and predicted coordinates must rotate and translate with it.
EGNN - E(n)-Equivariant Graph Neural Network
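Satorras et al. (2021) extended MPNN to maintain 3D coordinates $r_v$ alongside node embeddings $h_v$:
$$m_{vw} = \phi_e\left(h_v^{(t)}, h_w^{(t)}, \lVert r_v^{(t)} - r_w^{(t)} \rVert^2, e_{vw}\right)$$
$$r_v^{(t+1)} = r_v^{(t)} + C \sum_{w \neq v} \left(r_v^{(t)} - r_w^{(t)}\right) \phi_x(m_{vw})$$
$$h_v^{(t+1)} = \phi_h\left(h_v^{(t)}, \sum_{w \neq v} m_{vw}\right)$$
where $\phi_e$, $\phi_x$, $\phi_h$ are MLPs and $C$ is a normalization constant.
The key: use squared distances $\lVert r_v - r_w \rVert^2$ (rotation-invariant scalars) in the message function, and update coordinates by weighted displacement vectors (rotation-equivariant). This gives SE(3)-equivariance without the computational cost of spherical harmonics.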
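The equivariance claim is easy to test numerically. The sketch below is a simplified EGNN-style layer in NumPy, not the reference implementation: it assumes a fully connected graph, no edge features, and single-layer stand-ins with random weights (`W_e`, `w_x`, `W_h`) in place of the learned MLPs. It then applies a random orthogonal transform and translation to the input and checks the outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # hidden size (arbitrary for this sketch)

# Hypothetical stand-ins for the learned networks phi_e, phi_x, phi_h.
W_e = rng.normal(size=(2 * d + 1, d)) * 0.1
w_x = rng.normal(size=(d,)) * 0.1
W_h = rng.normal(size=(2 * d, d)) * 0.1

def egnn_layer(h, r):
    """One simplified EGNN step on a fully connected graph.

    h: [n, d] invariant node features, r: [n, 3] coordinates.
    """
    n = len(h)
    diff = r[:, None, :] - r[None, :, :]               # r_v - r_w, [n, n, 3]
    dist2 = (diff ** 2).sum(-1, keepdims=True)         # invariant squared distances
    pair = np.concatenate([
        np.repeat(h[:, None, :], n, axis=1),           # h_v
        np.repeat(h[None, :, :], n, axis=0),           # h_w
        dist2,
    ], axis=-1)
    mask = 1.0 - np.eye(n)[:, :, None]                 # drop self-pairs
    m = np.tanh(pair @ W_e) * mask                     # messages m_vw, [n, n, d]
    # Coordinates move along displacement vectors, scaled by an invariant scalar.
    scale = (m @ w_x)[:, :, None]                      # [n, n, 1]
    r_new = r + (diff * scale).sum(axis=1) / (n - 1)
    # Features are updated from the summed messages, as in a standard MPNN.
    h_new = np.tanh(np.concatenate([h, m.sum(axis=1)], axis=-1) @ W_h)
    return h_new, r_new

h = rng.normal(size=(5, d))
r = rng.normal(size=(5, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))           # random orthogonal matrix
t = rng.normal(size=(3,))

h1, r1 = egnn_layer(h, r)
h2, r2 = egnn_layer(h, r @ Q.T + t)
# Features unchanged; coordinates transformed exactly like the input.
print(np.allclose(h1, h2), np.allclose(r1 @ Q.T + t, r2))
```

The check passes for any weights, trained or not, because the only geometric quantities the layer touches are squared distances and displacement vectors.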
DimeNet - Directional Message Passing Using Bond Angles
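DimeNet (Klicpera et al., 2020) defines messages on directed bonds and incorporates the angle between consecutive bonds:
$$m_{ji}^{(t+1)} = f_{\text{update}}\left(m_{ji}^{(t)}, \sum_{k \in N(j) \setminus \{i\}} f_{\text{int}}\left(m_{kj}^{(t)}, e_{\text{RBF}}^{(ji)}, a_{\text{SBF}}^{(kj, ji)}\right)\right)$$
Messages flow along directed edges ($j \to i$) and the angle between consecutive bonds $k \to j$ and $j \to i$ is used in the message function. Here $e_{\text{RBF}}^{(ji)}$ is a radial basis expansion of the bond length and $a_{\text{SBF}}^{(kj, ji)}$ is a spherical basis expansion of distance and angle. This allows distinction of configurations with the same connectivity and distances but different bond angles - a significant improvement for quantum chemistry.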
PyTorch Geometric Implementation
The MessagePassing Base Class
PyG provides MessagePassing as a base class. Override message(), optionally aggregate(), and update(), then call self.propagate() to execute the loop.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax as pyg_softmax


class GenericMPNNLayer(MessagePassing):
    """
    Generic MPNN layer with edge feature support.
    Message:     MLP(h_v, h_u, e_vu)
    Aggregation: sum
    Update:      MLP(h_v, aggregated_msg) + LayerNorm
    """

    def __init__(self, node_dim: int, edge_dim: int, out_dim: int):
        super().__init__(aggr='add')  # 'add' = sum aggregation
        # Do not assign to self.node_dim: MessagePassing already uses that
        # attribute for the axis it propagates along (default -2).
        self.in_dim = node_dim
        self.edge_dim = edge_dim
        self.out_dim = out_dim

        # M_t: message network
        self.message_net = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        # U_t: update network
        self.update_net = nn.Sequential(
            nn.Linear(node_dim + out_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index, edge_attr):
        """
        x:          [N, node_dim]
        edge_index: [2, E] COO format
        edge_attr:  [E, edge_dim]
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """
        x_i: target node features [E, node_dim]
        x_j: source node features [E, node_dim]
        edge_attr: edge features  [E, edge_dim]
        Returns message m_{j→i}   [E, out_dim]
        """
        msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.message_net(msg_input)

    def update(self, aggr_out, x):
        """
        aggr_out: aggregated messages   [N, out_dim]
        x:        original node features [N, node_dim]
        """
        update_input = torch.cat([x, aggr_out], dim=-1)
        h_new = self.update_net(update_input)
        return self.norm(h_new)
```
Edge-Conditioned Message Passing (Gilmer et al. style)
```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing


class EdgeConditionedMP(MessagePassing):
    """
    Message: m_vw = A(e_vw) @ h_w
             where A(e_vw) is an (out_dim x node_dim) matrix from an edge MLP.
    Update:  GRU-cell - treats message passing as a recurrent process.
    """

    def __init__(self, node_dim: int, edge_dim: int, out_dim: int):
        super().__init__(aggr='add')
        # self.node_dim is reserved by MessagePassing (propagation axis),
        # so the feature size is stored under a different name.
        self.in_dim = node_dim
        self.out_dim = out_dim

        # Edge network: edge_dim → out_dim × node_dim (flattened matrix)
        self.edge_net = nn.Sequential(
            nn.Linear(edge_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim * node_dim),
        )
        # GRU update cell: input = aggregated message, hidden = node state
        self.gru = nn.GRUCell(out_dim, node_dim)
        self.norm = nn.LayerNorm(node_dim)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        x_j:       [E, node_dim] source node features
        edge_attr: [E, edge_dim] bond features
        Returns:   [E, out_dim]  m_{j→v} = A(e_vj) @ h_j
        """
        A = self.edge_net(edge_attr)                  # [E, out_dim * node_dim]
        A = A.view(-1, self.out_dim, self.in_dim)     # [E, out_dim, node_dim]
        # batched matrix-vector product
        msg = torch.bmm(A, x_j.unsqueeze(-1)).squeeze(-1)  # [E, out_dim]
        return msg

    def update(self, aggr_out, x):
        """GRU treats current node state as hidden state."""
        h_new = self.gru(aggr_out, x)                 # [N, node_dim]
        return self.norm(h_new)
```
Full MPNN for QM9
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool
from torch_geometric.utils import softmax as pyg_softmax

# EdgeConditionedMP is the layer defined in the previous section.


class MPNNModel(nn.Module):
    """
    MPNN in the style of Gilmer et al. (2017) for QM9.
    T message-passing steps with weight-tied layers.
    Attention-weighted readout (a simplified stand-in for set2set).
    """

    def __init__(
        self,
        node_in_dim: int = 11,   # QM9 atom feature dimension in PyG
        edge_in_dim: int = 4,    # QM9 bond feature dimension in PyG
        hidden_dim: int = 64,
        n_mp_steps: int = 6,     # T - depth of message passing
        out_dim: int = 1,        # number of target properties
    ):
        super().__init__()
        self.n_mp_steps = n_mp_steps

        # Initial node embedding
        self.node_encoder = nn.Linear(node_in_dim, hidden_dim)
        # Weight-tied message passing layer (same weights at every step)
        self.mp_layer = EdgeConditionedMP(hidden_dim, edge_in_dim, hidden_dim)
        # Attention-weighted readout
        self.readout_attn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
        self.readout_proj = nn.Linear(hidden_dim, hidden_dim)
        # Output MLP
        self.output_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, out_dim),
        )

    def forward(self, data):
        x, edge_index, edge_attr, batch = (
            data.x.float(),
            data.edge_index,
            data.edge_attr.float(),
            data.batch,
        )
        # Initial embedding
        h = F.relu(self.node_encoder(x))              # [N, hidden_dim]
        # T rounds of message passing (weight-tied)
        for _ in range(self.n_mp_steps):
            h = self.mp_layer(h, edge_index, edge_attr)
        # Attention-weighted readout per graph
        attn = self.readout_attn(h)                   # [N, 1]
        attn_weights = pyg_softmax(attn, batch)       # [N, 1] - softmax within each graph
        h_proj = self.readout_proj(h)                 # [N, hidden_dim]
        graph_repr = global_add_pool(attn_weights * h_proj, batch)  # [B, hidden_dim]
        return self.output_mlp(graph_repr)            # [B, out_dim]


def train_mpnn_qm9(target_property: int = 7, epochs: int = 100):
    """
    Train MPNN on QM9.
    PyG target indices: 0=mu, 1=alpha, 2=HOMO, 3=LUMO, 4=gap, 5=R2, 6=ZPVE,
    7=U0, 8=U, 9=H, 10=G, 11=Cv, 12-15=atomization energies (U0, U, H, G),
    16-18=rotational constants A, B, C
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Shuffle before slicing so the train/validation split is random.
    dataset = QM9(root='data/QM9').shuffle()

    # Normalize target. Statistics are taken over the whole dataset for
    # simplicity; compute them on the training split for a strict protocol.
    y_all = dataset.data.y[:, target_property]
    y_mean, y_std = y_all.mean().item(), y_all.std().item()

    train_data = dataset[:110000]
    val_data = dataset[110000:120000]
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)

    model = MPNNModel(hidden_dim=64, n_mp_steps=6).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=10, factor=0.5, min_lr=1e-5
    )

    best_val_mae = float('inf')
    for epoch in range(epochs):
        # ── train ──
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            y_norm = (batch.y[:, target_property] - y_mean) / y_std
            pred = model(batch).squeeze(-1)
            loss = F.mse_loss(pred, y_norm)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        # ── validate ──
        model.eval()
        abs_err = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                y_true = batch.y[:, target_property]
                # Undo the normalization so the MAE is in the target's units.
                pred = model(batch).squeeze(-1) * y_std + y_mean
                abs_err += (pred - y_true).abs().sum().item()
        # Per-molecule MAE (not a mean of per-batch means).
        val_mae = abs_err / len(val_data)

        scheduler.step(val_mae)
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            torch.save(model.state_dict(), 'best_mpnn_qm9.pt')
        if epoch % 10 == 0:
            avg_loss = train_loss / len(train_loader)
            print(f"Epoch {epoch:3d} loss={avg_loss:.4f} "
                  f"val_MAE={val_mae:.4f} best={best_val_mae:.4f}")

    return model, best_val_mae
```
Over-Squashing - The Second Fundamental Bottleneck
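Beyond the 1-WL expressivity ceiling, MPNNs face over-squashing: when information must travel $k$ hops to reach a target node, it passes through $k$ aggregations. Each aggregation compresses potentially many neighbors into a fixed-dimensional vector. After $k$ steps, distant nodes' information is exponentially attenuated.
Formally, the Jacobian sensitivity of node $v$'s embedding to node $u$'s initial features decays with distance:
$$\left\lVert \frac{\partial h_v^{(k)}}{\partial x_u} \right\rVert \le c^k \left(\hat{A}^k\right)_{vu}, \qquad \hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$$
for GCN-style normalization, where $c$ bounds the Lipschitz constants of the message and update functions. High-degree bottleneck nodes crush information flow.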
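To see how quickly the topological part of this bound shrinks, the sketch below evaluates the entry of the $k$-th power of the GCN-normalized adjacency that links the root of a complete binary tree to one leaf at depth $k$. The tree, the helper names, and the choice to ignore the constant factor are assumptions made for illustration.

```python
import numpy as np

def normalized_adjacency(A):
    """GCN-style D^-1/2 (A + I) D^-1/2."""
    A_tilde = A + np.eye(len(A))
    deg = A_tilde.sum(axis=1)
    return A_tilde / np.sqrt(np.outer(deg, deg))

def binary_tree_adjacency(depth):
    """Complete binary tree; node 0 is the root, children of i are 2i+1 and 2i+2."""
    n = 2 ** (depth + 1) - 1
    A = np.zeros((n, n))
    for child in range(1, n):
        parent = (child - 1) // 2
        A[child, parent] = A[parent, child] = 1.0
    return A

bounds = []
for depth in range(1, 7):
    A_hat = normalized_adjacency(binary_tree_adjacency(depth))
    leaf = 2 ** depth - 1                      # first node of the deepest level
    # Topological factor of the bound on d h_root^(depth) / d x_leaf.
    bound = np.linalg.matrix_power(A_hat, depth)[0, leaf]
    bounds.append(bound)
    print(f"depth={depth}  leaves={2 ** depth:3d}  (A_hat^k)[root, leaf]={bound:.2e}")
```

Each extra level doubles the number of leaves feeding into the root while the per-leaf factor drops by a factor of four, so the root's fixed-size vector carries a rapidly shrinking share of any single leaf's signal.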
Solutions:
- Graph rewiring: add virtual long-range edges (Gasteiger et al., 2021)
- Higher-dimensional messages: use vector or tensor node features that carry more information per hop
- Graph Transformers: global attention that bypasses the structural bottleneck entirely
- Expander graph rewiring: add edges to give the graph better spectral gap (faster mixing)
:::warning Over-squashing is frequently misdiagnosed
Practitioners see poor GNN performance on tasks requiring long-range dependencies and blame the expressivity ceiling. Often the actual cause is over-squashing. More message passing steps do not fix over-squashing - they can worsen it while simultaneously causing over-smoothing.
:::
Production: Drug Discovery at Schrödinger
Schrödinger Inc. uses MPNN-based models as a core component of their computational drug discovery pipeline. Their FEP+ (Free Energy Perturbation) workflow traditionally required expensive molecular dynamics simulations (hours per compound). MPNN-based property predictors can screen millions of compounds in minutes.
The workflow:
- Virtual screening: MPNN predicts binding affinity for a library of 10M+ compounds in minutes
- Hit expansion: top-scoring compounds clustered, analogs generated by generative model
- FEP+ validation: promising candidates validated with physics-based simulation
- Experimental testing: highest-confidence compounds synthesized and tested in assay
Key engineering decisions:
- Ensemble of 5 MPNNs (different random seeds) - variance = uncertainty estimate
- Active learning loop: experimental results feed back to retrain on high-information compounds
- Scaffold-aware splitting: test set held out by chemical scaffold to measure true generalization, not interpolation
:::tip Chemical accuracy benchmark
Target for quantum chemical property prediction: "chemical accuracy" = 1 kcal/mol = 0.043 eV. State-of-the-art DimeNet++ achieves ~0.025 eV MAE on QM9 HOMO energy. A basic MPNN achieves ~0.045 eV. If your model exceeds 0.1 eV, something is wrong with your training setup.
:::
YouTube Resources
| Title | Channel | Focus |
|---|---|---|
| Message Passing Neural Networks (CS224W Lecture 8) | Stanford Online | MPNN framework · WL test · expressivity theory |
| How Powerful are Graph Neural Networks? | Yannic Kilcher | GIN paper · 1-WL equivalence proof walkthrough |
| Graph Neural Networks for Molecules | Valence Labs | MPNN for drug discovery · practical considerations |
| Equivariant Graph Neural Networks | Prof. Erik Bekkers | SE(3)-equivariance · EGNN · geometric deep learning |
| DimeNet: Directional Message Passing | Johannes Klicpera | Bond angle encoding · directional MPNN for QM9 |
Common Pitfalls
:::danger Using MEAN aggregation when structural expressiveness matters
MEAN and MAX aggregations are not injective over multisets - two different neighborhood configurations can produce the same aggregated value. For tasks requiring structural discrimination (graph classification, isomorphism detection), use SUM aggregation with an MLP update (GIN-style). MEAN discards neighborhood size: a node with 2 identical neighbors looks identical to a node with 20 identical neighbors.
:::
:::danger Ignoring edge features for molecular tasks
GCN and basic GAT have no edge feature support. Using them for molecular property prediction discards bond type, bond order, and bond length - critical chemical information. The difference between a C=C double bond and a C-C single bond changes reactivity completely. Always use a proper MPNN with edge features for chemistry.
:::
:::warning Too many message passing steps create over-smoothing
After many rounds of aggregation, all node embeddings converge toward a similar vector (proportional to the graph's dominant eigenvector). This is over-smoothing. For most molecular tasks (where relevant patterns are local), 3-6 steps is optimal. Adding more steps past this threshold typically hurts performance.
:::
:::warning Weight-tied vs per-layer weights
Gilmer et al. used the same message/update weights at every step (weight-tied). This reduces parameter count and generalizes better with small training sets. Per-layer weights are better when the information content at different hop distances is qualitatively different. Start with weight-tied layers for molecular tasks.
:::
Interview Q&A
Q: Prove that GCN and GAT are both special cases of the MPNN framework. What is the key difference between them in MPNN terms?
A: Both are MPNNs that differ only in the message function. GCN's message from node $w$ to node $v$ is $\frac{1}{\sqrt{\tilde{d}_v \tilde{d}_w}} W h_w$ - a linear transform of the neighbor's features scaled by a fixed topology-derived normalization constant. GAT's message is $\alpha_{vw} W h_w$ where $\alpha_{vw} = \mathrm{softmax}_w\left(\mathrm{LeakyReLU}\left(a^\top \left[W h_v \,\Vert\, W h_w\right]\right)\right)$ - the coefficient is data-dependent, learned from node features. Both use sum aggregation and a nonlinear activation in the update. The difference: GCN uses fixed graph-structural weights; GAT uses learned content-based attention weights. Neither natively supports edge features. For molecular graphs with rich bond features, you need an MPNN like Gilmer et al.'s model with an edge-conditioned message function.
Q: What is the Weisfeiler-Lehman expressivity limit? Give a concrete example of two graphs that no MPNN can distinguish.
A: 1-WL iteratively refines node labels by hashing each node's label with the multiset of its neighbors' labels. Any MPNN is structurally equivalent to 1-WL refinement - it maps (node state, multiset of neighbor states) to a new node state. Xu et al. (2019) proved no MPNN can exceed 1-WL expressiveness. A concrete failure: the 6-cycle (hexagon, every node degree 2) vs two disjoint triangles (every node degree 2). 1-WL assigns identical labels to all nodes in both graphs at every round because every node's 1-hop neighborhood looks the same in both graphs. A GNN initialized on degree features will produce identical embeddings for both graphs and cannot classify them differently. For chemistry: non-isomorphic molecules whose atoms all have identical local atom environments at every radius are indistinguishable by any standard MPNN.
Q: What is the difference between over-smoothing and over-squashing, and how do you fix each?
A: They are distinct failure modes of message passing, and fixing one can aggravate the other. Over-smoothing: with too many layers, node representations converge toward a common vector through repeated averaging. All nodes look identical; the model loses discriminative power. Fix: fewer layers, residual/skip connections, or PairNorm. Over-squashing: when information must travel far across the graph, it is compressed through a narrow sequence of aggregations. The Jacobian decays exponentially with graph distance, so long-range information is lost. Fix: graph rewiring (add virtual edges), graph transformers with global attention, or expander graph construction. They often co-occur - adding more layers to solve over-squashing worsens over-smoothing. The right tool is graph rewiring, not more layers.
Q: How does equivariant message passing (EGNN) differ from standard MPNN, and why does it matter for protein structure?
A: Standard MPNN operates on node feature scalars and uses topology. Equivariant GNNs extend this to 3D coordinates as geometric objects. EGNN updates node positions $r_v \in \mathbb{R}^3$ alongside scalar features $h_v$. Messages use $\lVert r_v - r_w \rVert^2$ (a rotation-invariant scalar) as input. Coordinate updates use displacement vectors $r_v - r_w$ weighted by a scalar from the message network - making the update rotation-equivariant. SE(3)-equivariance means: rotate/translate the input → the output coordinates rotate/translate identically, and scalar predictions are unchanged. For proteins, this is essential: energy, force, and binding affinity are geometric properties of 3D structure. A non-equivariant model must learn rotation invariance from data, requiring vastly more training examples and still generalizing poorly to unseen orientations.
Q: How would you implement an MPNN for edge-level tasks like predicting whether a chemical bond will break?
A: After $T$ rounds of node-level message passing, compute edge representations from endpoint node states: $z_{vw} = f\left(h_v^{(T)}, h_w^{(T)}, e_{vw}\right)$ where $f$ is an MLP. The final node states encode each atom's full $T$-hop neighborhood context, so the edge representation captures both endpoints' chemical environments. For directed tasks (predicting which direction a reaction proceeds), use an asymmetric function: $f(h_v, h_w)$ differs from $f(h_w, h_v)$ by argument order. Apply binary cross-entropy on top of $z_{vw}$ for bond-breaking classification. Ensure the training data preserves edge directionality and that the loss is computed per-edge rather than per-graph.
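A sketch of such an edge head in NumPy, under stated assumptions: the final node states are random stand-ins rather than the output of a trained MPNN, the two-layer edge function `f` has random weights, the labels are made up, and bond breaking is treated as an undirected property, so the score is symmetrized by averaging both argument orders (a choice made for this example, not something the answer above prescribes).

```python
import numpy as np

rng = np.random.default_rng(0)
n_atoms, d, hidden = 4, 6, 8

# Stand-in for the final node states h_v^(T) produced by message passing.
h = rng.normal(size=(n_atoms, d))
bonds = np.array([[0, 1], [1, 2], [2, 3]])       # one row per bond (v, w)
labels = np.array([0.0, 1.0, 0.0])               # 1 = bond breaks (made-up labels)

# Hypothetical two-layer edge MLP f.
W1 = rng.normal(size=(2 * d, hidden))
w2 = rng.normal(size=(hidden,))

def f(h_v, h_w):
    """Asymmetric edge score: argument order matters."""
    z = np.maximum(np.concatenate([h_v, h_w], axis=-1) @ W1, 0.0)
    return z @ w2

def bond_logits(h, bonds):
    """Symmetrized score for undirected bonds."""
    h_v, h_w = h[bonds[:, 0]], h[bonds[:, 1]]
    return 0.5 * (f(h_v, h_w) + f(h_w, h_v))

logits = bond_logits(h, bonds)
p_break = 1.0 / (1.0 + np.exp(-logits))
# Binary cross-entropy averaged over edges, not over graphs.
eps = 1e-12
bce = -np.mean(labels * np.log(p_break + eps)
               + (1 - labels) * np.log(1 - p_break + eps))
```

For a directed task, calling `f(h_v, h_w)` directly keeps the asymmetry; the averaging step is only appropriate when the two orientations of a bond should receive the same prediction.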
Q: Why is sum aggregation strictly more expressive than mean or max aggregation for MPNN?
A: Sum aggregation can be injective over multisets: for features drawn from a countable set, there is a feature transformation under which different neighbor multisets always produce different sums (Zaheer et al., 2017; Xu et al., 2019). In GIN the previous layer's MLP learns that transformation. Mean aggregation discards neighborhood size - a node with neighbors $\{a, b\}$ produces the same mean as a node with neighbors $\{a, a, b, b\}$. Max aggregation discards multiplicity - $\{a, b\}$ produces the same max as $\{a, a, b\}$. For structural discrimination tasks, this means GCN (sum after normalization, which is approximately mean) and some GraphSAGE variants are strictly less expressive than GIN with sum. In practice, for tasks where only the presence/absence of neighbor types matters (not counts), mean works well and is more robust to high-degree nodes. Match the aggregation function to the task's symmetry requirements.
GIN - The Maximally Expressive MPNN
If any MPNN can be at most as powerful as 1-WL, what is the most powerful MPNN? Xu et al. (2019) answered this with the Graph Isomorphism Network (GIN):
$$h_v^{(t+1)} = \mathrm{MLP}^{(t)}\left(\left(1 + \epsilon^{(t)}\right) h_v^{(t)} + \sum_{w \in N(v)} h_w^{(t)}\right)$$
Two design choices make GIN maximally expressive:
- Sum aggregation: injective over multisets - with suitably encoded features (one-hot inputs, or the previous layer's MLP output), different neighbor multisets produce different sums
- MLP update: a universal function approximator - can represent any continuous function
The $(1 + \epsilon)$ term ensures the central node's features are distinguished from neighbors - equivalent to adding a self-loop with a learned weight. The scalar $\epsilon$ is either learned or fixed to 0 (both work in practice).
Why other aggregations are less expressive:
| Aggregator | Injective over multisets? | Example failure |
|---|---|---|
| SUM + MLP | Yes - maximally expressive | - |
| MEAN | No - loses neighborhood size | $\{a, b\}$ and $\{a, a, b, b\}$ → same mean |
| MAX | No - loses multiplicity | $\{a, b\}$ and $\{a, a, b\}$ → same max |
| GCN (normalized sum) | No - approximately mean | same as MEAN failures |
GIN with SUM + MLP is the theoretically optimal choice for structural discrimination. For tasks where you care about neighbor types but not counts (common in social networks), MEAN works fine. For tasks requiring structural isomorphism detection, use GIN.
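The collisions are easy to reproduce. The sketch below uses one-hot neighbor features for two hypothetical types `a` and `b` and compares the three aggregators on small multisets; the `aggregate` helper is written for this example only.

```python
def one_hot(t, num_types=2):
    return [1.0 if i == t else 0.0 for i in range(num_types)]

def aggregate(neighbour_types, how):
    feats = [one_hot(t) for t in neighbour_types]
    cols = list(zip(*feats))              # one tuple per feature dimension
    if how == "sum":
        return [sum(c) for c in cols]
    if how == "mean":
        return [sum(c) / len(c) for c in cols]
    if how == "max":
        return [max(c) for c in cols]
    raise ValueError(how)

a, b = 0, 1
print(aggregate([a, b], "mean") == aggregate([a, a, b, b], "mean"))  # True
print(aggregate([a, b], "max") == aggregate([a, a, b], "max"))       # True
print(aggregate([a, b], "sum") == aggregate([a, a, b, b], "sum"))    # False
print(aggregate([a, b], "sum") == aggregate([a, a, b], "sum"))       # False
```

With one-hot inputs the sum is simply a count per neighbor type, which is why it keeps both neighborhood size and multiplicity.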
MPNN Variants Comparison Table
| Architecture | Message Function | Aggregation | Update | Edge Features | Expressiveness |
|---|---|---|---|---|---|
| GCN | $\frac{1}{\sqrt{\tilde{d}_v \tilde{d}_w}} W h_w$ | SUM | $\sigma(m_v)$ | No | Strictly less than 1-WL |
| GraphSAGE | $h_w$ | MEAN/MAX/LSTM | $\sigma(W [h_v \,\Vert\, m_v])$ | No | Less than 1-WL (MEAN) |
| GAT | $\alpha_{vw} W h_w$ | SUM | $\sigma(m_v)$ | No | Strictly less than 1-WL |
| MPNN (Gilmer) | $A(e_{vw}) h_w$ | SUM | GRU cell | Yes | Up to 1-WL |
| GIN | $h_w$ | SUM | MLP($(1+\epsilon) h_v + m_v$) | No | Exactly 1-WL |
| EGNN | $\phi_e(h_v, h_w, \lVert r_v - r_w \rVert^2)$ | SUM | $\phi_h(h_v, m_v)$ + coord update | Dist | Up to 1-WL + SE(3) equivariant |
| DimeNet | Directed bond + angle | SUM | MLP | Angle + dist | Up to 1-WL + angular info |
Choosing Your MPNN Architecture
The right MPNN depends on your task. Use this decision guide:
Practical rules of thumb:
- For molecular property prediction with bond features: use edge-conditioned MPNN (Gilmer et al.) with 3–6 steps
- For citation networks, social graphs: GCN or GAT with 2 layers is usually optimal
- For structural graph classification: GIN with SUM aggregation
- For 3D molecular geometry (protein structure, force fields): EGNN or DimeNet
- When in doubt: start with a 2-layer GCN baseline, then systematically upgrade
Production Engineering Notes
Batching Graphs in PyTorch Geometric
Unlike images (fixed size) or text (padded to same length), graphs come in different sizes. PyG handles this with a disjoint union batching approach:
```python
from torch_geometric.loader import DataLoader

# PyG automatically creates a single large disconnected graph from the batch.
# Node indices are offset: graph 0 uses nodes 0..n0-1, graph 1 uses n0..n0+n1-1, etc.
# The `batch` tensor maps each node to its graph index: batch[i] = graph index of node i.
# `dataset` and `model` are assumed to be defined earlier.
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    # data.x:          [total_nodes_in_batch, node_features]
    # data.edge_index: [2, total_edges_in_batch] - offsets applied automatically
    # data.batch:      [total_nodes_in_batch] - graph membership
    out = model(data)
```
This is why global_add_pool(x, batch) works - it sums over all nodes belonging to the same graph.
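To make the disjoint-union idea concrete, here is a minimal pure-Python sketch of the bookkeeping. The functions `collate` and `add_pool` and the dictionary layout are hypothetical simplifications for illustration; they are not PyG's implementation, which operates on tensors.

```python
def collate(graphs):
    """Merge graphs into one disconnected graph.

    Each graph is a dict with 'x' (list of scalar node features) and
    'edges' (list of (src, dst) pairs using local node indices).
    """
    x, edges, batch = [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        x.extend(g["x"])
        # Shift local indices so they point into the merged node list.
        edges.extend((s + offset, d + offset) for s, d in g["edges"])
        batch.extend([graph_id] * len(g["x"]))
        offset += len(g["x"])
    return x, edges, batch


def add_pool(x, batch):
    """Sum node values per graph, the idea behind global_add_pool."""
    out = [0.0] * (max(batch) + 1)
    for value, graph_id in zip(x, batch):
        out[graph_id] += value
    return out


g0 = {"x": [1.0, 2.0, 3.0], "edges": [(0, 1), (1, 2)]}
g1 = {"x": [10.0, 20.0], "edges": [(0, 1)]}
x, edges, batch = collate([g0, g1])
pooled = add_pool(x, batch)
```

Because no edge crosses between the two components, message passing on the merged graph gives the same result as running each graph separately, and the `batch` vector is all the pooling step needs to recover per-graph outputs.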
Gradient Clipping for GNN Training
GNNs can have unstable gradients, especially with many layers or GRU-based updates. Always clip gradients:
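```python
import torch

# `model`, `loss`, and `optimizer` are assumed to be defined earlier.
optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their global L2 norm is at most 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```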
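A clip norm of 1.0 is a common default for GNNs on molecular tasks. If training diverges, log the value returned by torch.nn.utils.clip_grad_norm_, which is the total gradient norm before clipping. Calling it with max_norm=float("inf") measures the norm without changing the gradients.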
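Conceptually, norm-based clipping computes one global L2 norm over all gradients and rescales them by a common factor when that norm exceeds the threshold. The sketch below shows the arithmetic on a flat list of floats. `clip_by_global_norm` is a hypothetical helper written for this example, and the small constant in the denominator is an assumption made here for numerical safety.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients so their joint L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm


grads = [3.0, 4.0]  # joint norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))

# An infinite threshold measures the norm without changing the gradients.
unchanged, measured = clip_by_global_norm(grads, max_norm=float("inf"))
```

The direction of the gradient is preserved and only its length changes, which is why clipping tends to stabilize training without biasing the update direction.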
Monitoring GNN Training Health
Three signals to watch during training:
- Gradient norm (measured before clipping): should stabilize, ideally below the clip threshold of 1.0; if it keeps exploding, reduce the learning rate or add clipping
- Node embedding variance: sample 100 node embeddings per epoch and compute pairwise cosine similarity - if they converge toward 1.0, you have over-smoothing
- Validation loss curve: GNNs on molecular tasks often show a "double descent" - initial improvement, then plateau, then improvement again as the model learns long-range patterns
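The embedding-similarity check from the list above can be sketched in a few lines. This is a minimal pure-Python version with a hypothetical helper `mean_pairwise_cosine`; in a real training loop the sampled embeddings would come from the model's last hidden layer.

```python
import math
from itertools import combinations


def mean_pairwise_cosine(embeddings):
    """Average cosine similarity over all distinct pairs of embeddings."""
    def cosine(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        norm_u = math.sqrt(sum(a * a for a in u))
        norm_v = math.sqrt(sum(b * b for b in v))
        return dot / (norm_u * norm_v)

    pairs = list(combinations(embeddings, 2))
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)


healthy = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]    # spread-out directions
collapsed = [[1.0, 1.0], [2.0, 2.0], [0.5, 0.5]]   # all point the same way

healthy_score = mean_pairwise_cosine(healthy)
collapsed_score = mean_pairwise_cosine(collapsed)
```

A score drifting toward 1.0 across epochs, as in the `collapsed` case, is the over-smoothing signature; the absolute threshold worth alerting on is task-dependent.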
Summary
The MPNN framework shows that GCN, GraphSAGE, GAT, and nearly every graph neural network published before 2019 are all special cases of a three-step computation: message, aggregate, update. The differences between architectures reduce to implementation choices within these three functions.
The framework has a hard theoretical ceiling: no MPNN can be more expressive than the 1-WL (Weisfeiler-Lehman) test. GIN achieves this ceiling exactly - it is the maximally expressive MPNN. Breaking beyond 1-WL requires either higher-order GNNs (whose cost grows as $O(n^k)$ for $k$-WL-style models) or structural encodings (practical, problem-specific). The 1-WL limitation is not a bug - it is an inherent property of the message-passing paradigm.
For 3D molecular tasks, equivariant extensions (EGNN, DimeNet) incorporate geometric information while respecting the symmetries of physical space, achieving chemical accuracy on quantum property prediction benchmarks. For production drug discovery, MPNN-based screening pipelines have reduced candidate selection from weeks of simulation to minutes of inference.
Understanding MPNN at this depth - the unified framework, the expressivity analysis, the equivariant extensions, and the practical engineering tradeoffs - is what separates practitioners who can use GNNs from practitioners who can design them.
Graph Transformers - Going Beyond MPNN
The 1-WL expressivity ceiling and over-squashing have motivated a new class of models: Graph Transformers that use global self-attention over all node pairs, bypassing the MPNN locality constraint entirely.
How Graph Transformers Work
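A Graph Transformer treats the graph as a set of nodes and applies Transformer-style attention between all pairs:
$$\alpha_{ij} = \text{softmax}_j\left(\frac{(W_Q h_i)^\top (W_K h_j)}{\sqrt{d}}\right), \qquad h_i' = \sum_{j} \alpha_{ij} W_V h_j$$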
This is identical to standard Transformer attention. The graph structure is incorporated through structural positional encodings added to node features - otherwise the model is permutation equivariant but graph-unaware.
Key Graph Transformer Models
Graphormer (Ying et al., Microsoft, 2021) won the OGB-LSC quantum chemistry challenge. It adds:
- Centrality encoding: a learnable embedding indexed by each node's degree, added to the input node features
- Spatial encoding: shortest path distance between node pairs as an attention bias
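- Edge encoding: average of edge features along the shortest path as an additional attention bias
$$A_{ij} = \frac{(W_Q h_i)^\top (W_K h_j)}{\sqrt{d}} + b_{\phi(i,j)} + c_{ij}$$
where $b_{\phi(i,j)}$ is a learnable bias depending only on the shortest path distance $\phi(i,j)$ between $i$ and $j$, and $c_{ij}$ is the edge-encoding term.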
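The spatial encoding can be illustrated with a small sketch: compute hop distances by breadth-first search, look up a bias per distance, and add it to the content scores before the softmax. The function names and the values in `bias_table` are hypothetical, chosen only for this example; in a trained model the biases are learned. The sketch assumes a connected graph, so every node is reachable.

```python
import math
from collections import deque


def shortest_path_lengths(adj, source):
    """BFS hop distances from source; adj maps node -> list of neighbors."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nxt in adj[node]:
            if nxt not in dist:
                dist[nxt] = dist[node] + 1
                queue.append(nxt)
    return dist


def biased_attention_row(scores, spd, bias_table):
    """Softmax over content scores plus a bias looked up by hop distance."""
    logits = [s + bias_table[d] for s, d in zip(scores, spd)]
    top = max(logits)
    exps = [math.exp(value - top) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Path graph 0 - 1 - 2 - 3
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
dist = shortest_path_lengths(adj, source=0)
spd = [dist[j] for j in range(4)]

bias_table = [0.0, -1.0, -2.0, -3.0]   # hypothetical stand-ins for learned biases
content_scores = [0.0, 0.0, 0.0, 0.0]  # identical content scores, for clarity
weights = biased_attention_row(content_scores, spd, bias_table)
```

With identical content scores, the attention weights from node 0 decay with hop distance purely because of the bias, which is how graph structure enters an otherwise structure-blind attention layer.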
SAN (Spectral Attention Network, 2021): uses Laplacian eigenvectors as positional encodings, giving each node a unique structural fingerprint that encodes its global position in the graph.
GPS (General Powerful Scalable, 2022): combines local MPNN layers with global self-attention. Each GPS layer has two components: an MPNN layer for local structure and a Transformer layer for global context. This hybridization gives local + global information flow; the local MPNN component costs only $O(|E|)$, and the global attention can be replaced by a linear-attention variant when full $O(N^2)$ attention is too expensive.
When to Use Graph Transformers vs MPNN
| Criterion | MPNN | Graph Transformer |
|---|---|---|
| Graph size | Any (with mini-batching) | Typically less than 1K nodes (due to $O(N^2)$ attention) |
| Long-range dependencies | Poor (over-squashing) | Excellent (global attention) |
| Expressiveness | Up to 1-WL | Beyond 1-WL (with structural encodings) |
| Inference speed | Fast | Slower ($O(N^2)$ per layer) |
| Molecular tasks (small graphs) | Excellent | Excellent - often better |
| Social networks (large graphs) | Scalable | Needs approximation (e.g., Performer, Linformer) |
For molecular property prediction where graphs are small (10-50 atoms), Graph Transformers with structural encodings have reported state-of-the-art or near state-of-the-art results on QM9 and similar benchmarks. For large-graph tasks, MPNNs with structural encodings are still the practical choice.
Scalable MPNN Training
Mini-Batch Training on Large Graphs
The challenge: backpropagation through $K$ MPNN layers requires materializing the $K$-hop computation subgraph for each training node - which can be enormous for high-degree nodes in social graphs.
Neighbor sampling (GraphSAGE): sample a fixed number of neighbors at each hop. $K$ layers with fan-out $s$ gives at most $1 + s + \dots + s^K = O(s^K)$ nodes per target, making batch size predictable.
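A minimal sketch of neighbor sampling, assuming $K$ layers and a fan-out of $s$ neighbors per hop, shows why the computation subgraph stays bounded even around a hub node. The helpers `sample_computation_nodes` and `max_sampled_nodes` and the toy graph are hypothetical and written for this example; they are not the GraphSAGE or PyG implementation.

```python
import random


def sample_computation_nodes(adj, target, num_layers, fanout, rng):
    """Collect nodes reached by sampling at most `fanout` neighbors per hop."""
    visited = {target}
    frontier = [target]
    for _ in range(num_layers):
        next_frontier = []
        for node in frontier:
            neighbors = adj[node]
            k = min(fanout, len(neighbors))
            next_frontier.extend(rng.sample(neighbors, k))
        visited.update(next_frontier)
        frontier = next_frontier
    return visited


def max_sampled_nodes(num_layers, fanout):
    """Upper bound 1 + s + s^2 + ... + s^K on nodes per target."""
    return sum(fanout ** k for k in range(num_layers + 1))


# Toy hub graph: node 0 links to 50 others; each of those links back to 0
# and to the next node in a ring.
adj = {0: list(range(1, 51))}
for i in range(1, 51):
    adj[i] = [0, i % 50 + 1]

rng = random.Random(0)
nodes = sample_computation_nodes(adj, target=0, num_layers=2, fanout=5, rng=rng)
bound = max_sampled_nodes(num_layers=2, fanout=5)  # 1 + 5 + 25 = 31
```

Without sampling, the 2-hop neighborhood of node 0 would cover all 51 nodes; with a fan-out of 5 the sampled set can never exceed the bound, whatever the hub's degree.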
Layer-wise sampling (LADIES, FastGCN): instead of sampling per-node, sample a fixed set of nodes per layer. Faster but introduces approximation error.
Subgraph sampling (GraphSAINT): sample a subgraph (using a random-walk, random-edge, or random-node sampler) and train on it as if it were the full graph. Normalization coefficients estimated from the sampling statistics correct for sampling bias. Designed to scale to very large graphs.
Cluster-GCN: partition the graph into clusters (using METIS or KMeans on node features), assign each cluster to a mini-batch. Efficient because edges within a cluster are dense; cross-cluster edges are discarded during that mini-batch (a principled approximation).
```python
# GraphSAINT random walk sampler for large graphs
from torch_geometric.loader import GraphSAINTRandomWalkSampler

# `data`, `model`, `criterion`, and `optimizer` are assumed to be defined earlier.
loader = GraphSAINTRandomWalkSampler(
    data,
    batch_size=6000,      # number of random-walk root nodes per subgraph
    walk_length=2,        # random walk steps
    num_steps=5,          # mini-batches per epoch
    sample_coverage=100,  # samples per node used to estimate normalization
)

for subgraph in loader:
    optimizer.zero_grad()
    out = model(subgraph.x, subgraph.edge_index)
    loss = criterion(out[subgraph.train_mask], subgraph.y[subgraph.train_mask])
    loss.backward()
    optimizer.step()
```
:::tip Which sampler to use
For citation graphs (Cora, CiteSeer, ogbn-arxiv): neighbor sampling with GraphSAGE. For large social or co-purchase graphs (Reddit, ogbn-products): GraphSAINT or Cluster-GCN. For molecular graphs: full batch - molecules are small enough to fit in memory. For heterogeneous graphs: use PyG's NeighborLoader with node type-specific fan-outs.
:::
