
Message Passing Neural Networks

The Real Interview Moment

"We have GCN, GraphSAGE, GAT, Interaction Networks, and five other GNN variants on our shortlist. Our team cannot agree which to use for molecular property prediction. How do you even compare them systematically?"

This exact tension was felt at Google Brain and DeepMind in 2017. The field had produced a proliferating zoo of graph neural network architectures, each with its own paper, its own notation, and its own claimed advantage. Some operated on node features, others on edge features. Some used attention, others simple summation. Some claimed better expressivity, others better scalability. There was no principled basis for comparison - just empirical ablations on specific datasets that rarely generalized.

Justin Gilmer and colleagues sat down to read every major GNN paper carefully. A pattern emerged. Strip away the notation differences, and almost every architecture was doing the same three operations: compute messages from neighbors, aggregate those messages, and update the node's hidden state. The differences were only in how each step was implemented. The team wrote a general template - the Message Passing Neural Network (MPNN) framework - and showed that GCN, Gated Graph Neural Networks, Interaction Networks, DTNN, and several others were all special cases. Architectures published shortly afterwards, such as GraphSAGE and GAT, fit the same template.

This unification was more than taxonomic convenience. It made a sharper question possible: given this framework, what is the fundamental limit of what any MPNN can express? The answer, worked out two years later by Xu et al. and Morris et al., came from a surprising direction - a 1968 graph theory algorithm called the Weisfeiler-Lehman (WL) isomorphism test. Any MPNN is at most as powerful as 1-WL. This ceiling explains why many GNNs fail on certain graph classification tasks, and it launched a new research agenda: how do you build GNNs provably more expressive than 1-WL?

Understanding MPNN is not optional for any engineer working on graph-structured data. It is the lingua franca of the field. Every time you read a GNN paper and see "message function," "aggregation," and "update function" in the methods section, you are seeing the MPNN vocabulary. When you debug a GNN that cannot distinguish two structurally different molecules, you are hitting the 1-WL ceiling. When you read about SE(3)-equivariant networks for protein structure prediction, you are seeing the response to that ceiling.


Why This Exists - The GNN Zoo Problem

By 2017, the GNN literature had become fragmented. Consider just a few architectures that existed:

  • Spectral GCN (Bruna et al., 2014) - graph convolution in the Fourier domain
  • ChebNet (Defferrard et al., 2016) - polynomial filters on the graph Laplacian
  • GCN (Kipf and Welling, 2017) - simplified spectral convolution, layer-wise propagation
  • GraphSAGE (Hamilton et al., 2017) - sample-and-aggregate, inductive learning
  • GAT (Velickovic et al., 2018) - attention-weighted aggregation
  • Interaction Networks (Battaglia et al., 2016) - physics simulation on graphs
  • DTNN (Schütt et al., 2017) - deep tensor neural networks for quantum chemistry
  • Molecular Graph Convolutions (Duvenaud et al., 2015) - neural circular fingerprints

Each was described in its own notation. Comparing them required careful reading of multiple papers. Practitioners could not easily decide which to use for a new task because the assumptions and tradeoffs were buried in architectural choices that looked superficially different.

The MPNN framework provides a common language. By formalizing the three-step structure - message, aggregate, update - it becomes straightforward to compare architectures, combine ideas across papers, and reason precisely about what cannot be achieved within the framework.


Historical Context

The paper Neural Message Passing for Quantum Chemistry (Gilmer et al., 2017, ICML) was motivated by a concrete application: predicting quantum chemical properties of small organic molecules from the QM9 dataset. QM9 contains about 134K molecules with 13 quantum properties computed by density functional theory (DFT) - atomization energy, dipole moment, HOMO-LUMO gap, polarizability, and more. DFT calculations take hours per molecule. If a GNN could predict these properties accurately from molecular graph structure alone, it would accelerate drug discovery and materials science enormously.

To build the best model, the team needed to understand the design space of all possible GNNs for this task. That forced them to think systematically. The result was the MPNN framework, and the MPNN model they built set a new state-of-the-art on QM9, predicting 11 of 13 properties within chemical accuracy.

The deeper theoretical result - the connection to the Weisfeiler-Lehman test - came later, in How Powerful are Graph Neural Networks? (Xu et al., 2019, ICLR) and, concurrently, Weisfeiler and Leman Go Neural (Morris et al., 2019, AAAI). Xu et al. proved formally that any MPNN is at most as powerful as 1-WL and introduced Graph Isomorphism Networks (GIN) as a maximally expressive MPNN.


The MPNN Framework - Formal Definition

An MPNN operates on a graph $G = (V, E)$ where each node $v$ has a feature vector $h_v^0 = x_v$ and each edge $(v, w)$ has a feature vector $e_{vw}$.

The framework has two phases.

Phase 1: Message Passing

Runs for $T$ steps. At each step $t$, every node sends and receives messages via three functions:

Message function $M_t$ - computes a message from neighbor $w$ to node $v$:

$$m_{vw}^{t+1} = M_t(h_v^t, h_w^t, e_{vw})$$

Aggregation function $\text{AGG}$ - combines all incoming messages for node $v$. It must be permutation-invariant (for example sum, mean, or max), because neighbors have no natural order:

$$m_v^{t+1} = \text{AGG}\left(\{m_{vw}^{t+1} \mid w \in \mathcal{N}(v)\}\right)$$

Update function $U_t$ - updates the node's hidden state:

$$h_v^{t+1} = U_t(h_v^t,\, m_v^{t+1})$$

Phase 2: Readout

After $T$ steps, a permutation-invariant readout function $R$ produces a graph-level representation for graph-level tasks:

$$\hat{y} = R\left(\{h_v^T \mid v \in G\}\right)$$

For node-level tasks the final $h_v^T$ are used directly. For edge-level tasks, edge representations are computed from endpoint representations.

:::note What "message passing" physically means
Information flows along edges. In the message step, each neighbor $w$ of node $v$ produces a message for $v$. In the aggregation step, $v$ collects all incoming messages and combines them. In the update step, $v$ updates its own state. After $T$ steps, each node has aggregated information from its full $T$-hop neighborhood.
:::
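The three functions can be made concrete with a minimal NumPy sketch of one message-passing step followed by a sum readout. It assumes one specific choice for each function (a one-layer ReLU network over $[h_v \| h_w \| e_{vw}]$ as $M_t$, sum as $\text{AGG}$, a tanh layer as $U_t$); the names `mpnn_step` and `sum_readout`, and the convention that edge `k` carries a message from `src[k]` (= $w$) to `dst[k]` (= $v$), are illustrative rather than taken from any library.

```python
import numpy as np


def mpnn_step(h, src, dst, e, W_msg, W_upd):
    """One MPNN step with sum aggregation.

    h: [N, d] node states, e: [E, d_e] edge features.
    Edge k carries a message from node src[k] (= w) to node dst[k] (= v).
    """
    # Message function M: m_vw = ReLU([h_v || h_w || e_vw] @ W_msg)
    msg_in = np.concatenate([h[dst], h[src], e], axis=1)
    m = np.maximum(msg_in @ W_msg, 0.0)
    # Aggregation: m_v = sum of incoming messages (unbuffered scatter-add)
    agg = np.zeros((h.shape[0], m.shape[1]))
    np.add.at(agg, dst, m)
    # Update function U: h_v' = tanh([h_v || m_v] @ W_upd)
    return np.tanh(np.concatenate([h, agg], axis=1) @ W_upd)


def sum_readout(h):
    """Permutation-invariant readout R: sum over all node states."""
    return h.sum(axis=0)


rng = np.random.default_rng(0)
d, d_e, d_msg = 4, 2, 8
# A 4-node path 0-1-2-3, stored with both edge directions
src = np.array([0, 1, 1, 2, 2, 3])
dst = np.array([1, 0, 2, 1, 3, 2])
h = rng.normal(size=(4, d))
e = rng.normal(size=(len(src), d_e))
W_msg = rng.normal(size=(2 * d + d_e, d_msg))
W_upd = rng.normal(size=(d + d_msg, d))

h_t = h
for _ in range(3):  # T = 3 steps, same weights at every step
    h_t = mpnn_step(h_t, src, dst, e, W_msg, W_upd)
y_hat = sum_readout(h_t)
```

Because aggregation and readout are both sums, relabeling the nodes (and remapping edge endpoints consistently) only permutes the rows of `h_t` and leaves `y_hat` unchanged.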


GCN as a Special Case of MPNN

Recall the GCN update rule:

$$h_v^{(k+1)} = \sigma\left(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{d_v\, d_u}}\, W^{(k)} h_u^{(k)}\right)$$

where $d_v$ is the degree of node $v$ counting its self-loop.

Mapping to MPNN:

Message function: $M(h_u, h_v, e) = \frac{1}{\sqrt{d_u\, d_v}}\, W h_u$

The message from $u$ to $v$ is the neighbor's feature multiplied by weight matrix $W$ and scaled by the symmetric normalization factor. GCN does not use edge features.

Aggregation: $\text{AGG} = \sum$

Simple summation. The normalization is baked into the message function.

Update function: $U(h_v, m_v) = \sigma(m_v)$

The self-loop ($\tilde{A} = A + I$) is included in the aggregation, so the update applies only the activation.

Readout: mean or sum over all final node embeddings.
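As a sanity check on this mapping, the NumPy sketch below computes one GCN layer twice: once as explicit per-edge messages with sum aggregation, and once with the dense formula $\sigma(\hat{A} H W)$ where $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$. The function names are hypothetical and σ is assumed to be ReLU; the two outputs should agree to floating-point precision.

```python
import numpy as np


def gcn_as_mpnn(h, edges, W):
    """GCN layer written as message / sum-aggregate / update.

    edges: [E, 2] undirected edges, each listed once.
    """
    n = h.shape[0]
    loops = np.arange(n)
    # Both directions of every edge plus one self-loop per node (A~ = A + I)
    src = np.concatenate([edges[:, 0], edges[:, 1], loops])
    dst = np.concatenate([edges[:, 1], edges[:, 0], loops])
    deg = np.bincount(dst, minlength=n).astype(float)  # degrees in A~
    # Message: (1 / sqrt(d_u d_v)) * W h_u, with u = src and v = dst
    m = (h[src] @ W) / np.sqrt(deg[src] * deg[dst])[:, None]
    # Aggregation: plain sum over incoming messages
    agg = np.zeros((n, W.shape[1]))
    np.add.at(agg, dst, m)
    # Update: activation only (sigma = ReLU here)
    return np.maximum(agg, 0.0)


def gcn_dense(h, edges, W):
    """Reference: ReLU(D~^{-1/2} A~ D~^{-1/2} H W)."""
    n = h.shape[0]
    A = np.zeros((n, n))
    A[edges[:, 0], edges[:, 1]] = 1.0
    A[edges[:, 1], edges[:, 0]] = 1.0
    A_tilde = A + np.eye(n)
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    A_hat = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
    return np.maximum(A_hat @ h @ W, 0.0)


rng = np.random.default_rng(0)
edges = np.array([[0, 1], [1, 2], [2, 0], [2, 3]])
h = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 5))
out_mp = gcn_as_mpnn(h, edges, W)
out_dense = gcn_dense(h, edges, W)
```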


GraphSAGE as a Special Case of MPNN

GraphSAGE uses a concatenation-based update:

$$h_v^{(k)} = \sigma\left(W^{(k)} \cdot \left[h_v^{(k-1)} \,\big\|\, \text{AGG}\left(\{h_u^{(k-1)} \mid u \in \mathcal{N}(v)\}\right)\right]\right)$$

Message function: $M(h_u, h_v, e) = h_u$

The message is simply the neighbor's current hidden state - no edge features, no weighting.

Aggregation: $\text{AGG} = \text{MEAN}(\cdot)$ or $\text{MAX}(\cdot)$ or $\text{LSTM}(\cdot)$

GraphSAGE proposes three aggregators. MEAN = element-wise mean of neighbor features. MAX (the "pooling" aggregator) = element-wise max-pooling after passing each neighbor through a small neural network. LSTM = process a random permutation of neighbors through an LSTM (not permutation-invariant in theory, but works well in practice).

Update function: $U(h_v, m_v) = \sigma\left(W \cdot [h_v \,\|\, m_v]\right)$

Concatenate the node's own features with the aggregated message, then apply a linear transform and activation. This explicit concatenation of self and neighborhood is GraphSAGE's defining characteristic.
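A minimal NumPy sketch of a GraphSAGE-style layer with the MEAN aggregator follows. It assumes ReLU for σ and omits the neighbor sampling and output normalization described in the GraphSAGE paper; `sage_mean_layer` is a hypothetical name. The weight matrix in the example is chosen to copy only the aggregated half of the concatenation, which makes the neighbor mean visible in the output.

```python
import numpy as np


def sage_mean_layer(h, neighbors, W):
    """GraphSAGE-style layer with the MEAN aggregator.

    neighbors[v] lists the neighbors of node v; W has shape [2d, d_out].
    """
    n, d = h.shape
    agg = np.zeros((n, d))
    for v, nbrs in enumerate(neighbors):
        if nbrs:  # isolated nodes keep a zero aggregate
            agg[v] = h[nbrs].mean(axis=0)  # AGG = element-wise mean
    # Update: sigma([h_v || m_v] @ W); self and neighborhood stay separate
    return np.maximum(np.concatenate([h, agg], axis=1) @ W, 0.0)


h = np.array([[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]])
neighbors = [[1, 2], [0], [0]]
# Weights that keep only the aggregated half of [h_v || m_v]
W = np.vstack([np.zeros((2, 2)), np.eye(2)])
out = sage_mean_layer(h, neighbors, W)
```

Here node 0 receives the mean of `[0, 2]` and `[4, 4]`, which is `[2, 3]`.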


GAT as a Special Case of MPNN

Graph Attention Networks use learned attention weights:

$$h_v^{(k)} = \sigma\left(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \alpha_{vu}^{(k)}\, W^{(k)} h_u^{(k-1)}\right)$$

where $\alpha_{vu} = \text{softmax}_u\left(\text{LeakyReLU}\left(a^T[Wh_v \,\|\, Wh_u]\right)\right)$.

Message function: $M(h_u, h_v, e) = \alpha_{vu} \cdot W h_u$

The coefficient $\alpha_{vu}$ depends on both endpoint states, making this a data-dependent rather than topology-fixed scaling.

Aggregation: $\sum$ - weighted sum with attention weights already incorporated in messages.

Update: $U(h_v, m_v) = \sigma(m_v)$

Key insight in MPNN terms: GCN vs GAT differ only in the message function. GCN uses fixed normalization from graph structure; GAT uses learned data-dependent weights from node features.
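The attention computation can be sketched in NumPy for a single head. This is an illustrative sketch, not the reference GAT implementation: it loops over nodes for readability, assumes ELU for σ and a LeakyReLU slope of 0.2, and `gat_layer` is a hypothetical name. Each node's coefficients $\alpha_{vu}$ are normalized over $\mathcal{N}(v) \cup \{v\}$, so they sum to one.

```python
import numpy as np


def gat_layer(h, neighbors, W, a, negative_slope=0.2):
    """Single-head GAT layer. neighbors[v] lists the neighbors of v (no self)."""
    z = h @ W                                  # W h_u for every node
    out = np.zeros_like(z)
    alphas = []
    for v in range(len(neighbors)):
        nbrs = [v] + list(neighbors[v])        # N(v) plus v itself
        pairs = np.concatenate(
            [np.repeat(z[v:v + 1], len(nbrs), axis=0), z[nbrs]], axis=1)
        scores = pairs @ a                     # a^T [W h_v || W h_u]
        scores = np.where(scores > 0, scores, negative_slope * scores)  # LeakyReLU
        alpha = np.exp(scores - scores.max())
        alpha /= alpha.sum()                   # softmax over u
        out[v] = alpha @ z[nbrs]               # sum of messages alpha_vu W h_u
        alphas.append(alpha)
    return np.where(out > 0, out, np.expm1(out)), alphas  # sigma = ELU


rng = np.random.default_rng(0)
neighbors = [[1, 2], [0, 2], [0, 1, 3], [2]]
h = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 4))
a = rng.normal(size=8)                         # attention vector, size 2 * 4
out, alphas = gat_layer(h, neighbors, W, a)
```

Replacing `alpha` with the fixed weights $1/\sqrt{d_v d_u}$ turns this loop back into GCN, which is the single difference described above.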


The Unification Diagram


Edge Features in MPNN

One area where MPNN shines over standard GCN and GAT is edge feature support. Molecular graphs have rich edge features - bond type (single/double/triple/aromatic), bond length, bond angle. Ignoring these discards crucial chemical information.

In the MPNN framework, edge features are first-class citizens of the message function:

$$m_{vw} = M(h_v,\, h_w,\, e_{vw})$$

Gilmer et al. used a neural network that produces a matrix from edge features:

$$m_{vw} = A(e_{vw}) \cdot h_w$$

where $A(e_{vw}) \in \mathbb{R}^{d \times d}$ is produced by a small MLP applied to the edge feature vector. The edge feature gates information flow along that edge.

For directional molecular graphs this becomes directional message passing - the message from $v$ to $w$ differs from the message from $w$ to $v$, parameterized by shared edge features but different endpoint states.
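The edge-conditioned message can be sketched in a few lines of NumPy. The small MLP producing $A(e_{vw})$, its sizes, and the one-hot bond encoding are assumptions made for illustration. The example also shows the directional behavior: the two messages across the same bond share $A(e_{vw})$ but multiply different endpoint states.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_e, d_hidden = 3, 4, 16
# Hypothetical edge-network parameters: e_vw -> A(e_vw) in R^{d x d}
W1 = rng.normal(size=(d_e, d_hidden))
W2 = rng.normal(size=(d_hidden, d * d))


def edge_matrix(e_vw):
    """Small MLP mapping an edge feature vector to a d x d matrix."""
    return (np.maximum(e_vw @ W1, 0.0) @ W2).reshape(d, d)


def edge_message(h_w, e_vw):
    """Message to v from neighbor w: m_vw = A(e_vw) h_w."""
    return edge_matrix(e_vw) @ h_w


bond = np.array([0.0, 1.0, 0.0, 0.0])   # e.g., a one-hot "double bond"
h_v, h_w = rng.normal(size=d), rng.normal(size=d)
m_vw = edge_message(h_w, bond)           # v receives from w
m_wv = edge_message(h_v, bond)           # w receives from v: same A, other state
```

The message is linear in $h_w$ for a fixed bond, so all nonlinearity with respect to the edge lives inside the edge network.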


MPNN for Molecular Property Prediction - QM9

The QM9 dataset is the canonical benchmark. Each molecule is a graph:

  • Nodes = atoms (features: atom type, formal charge, number of hydrogens, aromaticity, hybridization)
  • Edges = bonds (features: bond type - single/double/triple/aromatic)
  • Labels = 13 quantum chemical properties (DFT-computed)

The original MPNN paper used a GRU-based update (treating the $T$ steps as a recurrent computation over the graph, with weights shared across steps) and a set2set readout (Vinyals et al., 2015). This achieved chemical accuracy on 11 of 13 QM9 properties.
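To make the graph construction concrete, here is a hypothetical featurizer for a tiny molecule, sketched in NumPy. The feature layout (one-hot element over H, C, N, O, F plus an attached-hydrogen count, and a one-hot bond type) is an assumption and is not the exact layout any particular library uses. Hydrogens are kept as explicit nodes, and every bond is stored in both directions so that messages can flow either way.

```python
import numpy as np

# Hypothetical layout:
# node = [one-hot over (H, C, N, O, F), number of attached hydrogens]
# edge = one-hot over (single, double, triple, aromatic)
ATOM_TYPES = ["H", "C", "N", "O", "F"]
BOND_TYPES = ["single", "double", "triple", "aromatic"]


def featurize(atoms, bonds):
    """atoms: list of element symbols; bonds: list of (i, j, bond_type)."""
    x = np.zeros((len(atoms), len(ATOM_TYPES) + 1))
    for i, sym in enumerate(atoms):
        x[i, ATOM_TYPES.index(sym)] = 1.0
    src, dst, edge_attr = [], [], []
    for i, j, kind in bonds:
        onehot = np.eye(len(BOND_TYPES))[BOND_TYPES.index(kind)]
        # Store every bond in both directions for message passing
        for a, b in ((i, j), (j, i)):
            src.append(a)
            dst.append(b)
            edge_attr.append(onehot)
        # Count explicit hydrogens attached to heavy atoms
        if atoms[j] == "H" and atoms[i] != "H":
            x[i, -1] += 1
        if atoms[i] == "H" and atoms[j] != "H":
            x[j, -1] += 1
    return x, np.array([src, dst]), np.array(edge_attr)


# Formaldehyde, H2C=O, with explicit hydrogens
atoms = ["C", "O", "H", "H"]
bonds = [(0, 1, "double"), (0, 2, "single"), (0, 3, "single")]
x, edge_index, edge_attr = featurize(atoms, bonds)
```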


The Expressivity Ceiling - MPNN and 1-WL

The Weisfeiler-Lehman Test

The 1-dimensional Weisfeiler-Lehman (1-WL) graph isomorphism algorithm:

  1. Initialize each node with a label (e.g., its degree or atom type)
  2. At each iteration: update each node's label by hashing its current label together with the multiset of its neighbors' labels
  3. Repeat until labels stabilize
  4. If the multisets of final labels for two graphs differ, they are non-isomorphic

The 1-WL test is powerful but incomplete - there exist non-isomorphic graph pairs that 1-WL assigns identical label distributions to.
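The refinement loop can be sketched in pure Python. Two details matter when comparing graphs: every node starts with the same label here (uniform initial labels), and the table that maps signatures to new labels is shared across both graphs in each round, so labels remain comparable. The function names are illustrative. The example pairs are the triangular prism versus $K_{3,3}$ (both 3-regular on 6 nodes), the heavy-atom graphs of decalin versus bicyclopentyl, and a 4-node path versus a 4-node star.

```python
from collections import Counter


def adjacency(n, edges):
    adj = {v: [] for v in range(n)}
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    return adj


def wl_histograms(graphs, rounds):
    """Run 1-WL jointly on several graphs with one shared relabeling table.

    All nodes start with label 0. Returns the final label histogram
    (a Counter) for each graph.
    """
    labels = [{v: 0 for v in adj} for adj in graphs]
    for _ in range(rounds):
        table = {}  # shared across graphs so labels stay comparable
        new_labels = []
        for adj, lab in zip(graphs, labels):
            new = {}
            for v in adj:
                signature = (lab[v], tuple(sorted(lab[u] for u in adj[v])))
                new[v] = table.setdefault(signature, len(table))
            new_labels.append(new)
        labels = new_labels
    return [Counter(lab.values()) for lab in labels]


def wl_indistinguishable(g1, g2):
    """True if 1-WL finds no difference (the graphs may still differ)."""
    h1, h2 = wl_histograms([g1, g2], rounds=max(len(g1), len(g2)))
    return h1 == h2


prism = adjacency(6, [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3),
                      (0, 3), (1, 4), (2, 5)])
k33 = adjacency(6, [(i, j) for i in range(3) for j in range(3, 6)])
# Heavy-atom graphs: 0 and 1 are the two atoms with three ring bonds
decalin = adjacency(10, [(0, 1), (0, 2), (2, 3), (3, 4), (4, 5), (5, 1),
                         (0, 6), (6, 7), (7, 8), (8, 9), (9, 1)])
bicyclopentyl = adjacency(10, [(0, 1), (0, 2), (2, 3), (3, 4), (4, 5), (5, 0),
                               (1, 6), (6, 7), (7, 8), (8, 9), (9, 1)])
path4 = adjacency(4, [(0, 1), (1, 2), (2, 3)])
star4 = adjacency(4, [(0, 1), (0, 2), (0, 3)])
```

Because each new label encodes the old one, comparing only the final histograms is enough: a difference at any earlier round would persist. A `True` result never proves isomorphism.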

The MPNN-WL Connection

Xu et al. (2019) proved:

Theorem (informal): Any MPNN whose aggregation, update, and readout functions are injective over multisets (for example, sum aggregation combined with MLP updates, as in GIN) is exactly as expressive as 1-WL. No MPNN can exceed 1-WL expressiveness.

$$\text{Expressivity}(\text{any MPNN}) \leq \text{1-WL}$$

Why? Because MPNN update rules are structurally identical to 1-WL label refinement. Each node's new state is a function of its current state and the multiset of its neighbors' states - precisely what 1-WL does. By induction over rounds, if 1-WL gives two nodes the same label after round $t$, any MPNN gives them the same hidden state after step $t$, because both states are computed from identical inputs. So if 1-WL cannot distinguish two graphs (assigns them identical label multisets at every round), no MPNN can either.

A Concrete Failure Case

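With uniform initial labels, 1-WL fails on every pair of non-isomorphic regular graphs with the same number of nodes and the same degree: every node sees the same multiset of neighbor labels, so refinement never splits the coloring. Strongly regular graphs are the canonical hard cases, because they defeat even higher-order WL tests. A simpler example for intuition:

Consider the triangular prism (two triangles joined by a perfect matching) versus the complete bipartite graph $K_{3,3}$. Both have 6 nodes, all with degree 3, yet the prism contains triangles and $K_{3,3}$ contains none. 1-WL assigns identical labels to all nodes in both graphs at every round.

For molecular chemistry: two non-isomorphic molecules with the same local atom environment at every neighborhood radius cannot be distinguished. A classic example is decalin (two fused six-membered rings) versus bicyclopentyl (two five-membered rings joined by a single bond), viewed as heavy-atom graphs. An MPNN trained on atom type + bond type features would produce identical fingerprints for both and assign the same predicted properties - even if those properties differ experimentally.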


Beyond 1-WL - Breaking the Ceiling

k-GNNs (Morris et al., 2019)

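k-GNNs operate on $k$-tuples (in Morris et al.'s construction, $k$-element subsets) of nodes instead of individual nodes, and their power matches the corresponding $k$-WL test:

| Level | Power | Cost |
|---|---|---|
| 1-GNN (standard MPNN) | 1-WL | $O(n)$ memory, $O(\lvert E \rvert \cdot d)$ message compute per layer |
| 2-GNN | 2-WL (conventions differ: the oblivious 2-WL equals 1-WL, while the folklore variant can count triangles) | $O(n^2)$ memory |
| 3-GNN | 3-WL - detects triangles and other small cycles, but still fails on some strongly regular pairs (e.g., the 4x4 rook's graph vs the Shrikhande graph) | $O(n^3)$ memory |

For $n = 100$ nodes, a 3-GNN needs $10^6$ tuple representations per graph. Small molecules remain tractable (QM9 molecules have at most 29 atoms, and Morris et al. report k-GNN results on QM9), but the cubic cost rules out larger graphs such as proteins.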

Structural Encodings (Practical Alternative)

Augment node features with precomputed structural information:

  • Random walk encodings: append the diagonal of $(D^{-1}A)^k$ for several $k$ (the probability that a random walk returns to its start after $k$ steps) - encodes triangle counts and cycle structure
  • Laplacian eigenvectors: use spectral information as positional encodings (eigenvectors are defined only up to sign, and up to basis choice for repeated eigenvalues, so models commonly randomize signs during training)
  • Ring size features: smallest ring containing each atom - crucial for aromatic chemistry
  • Distance encodings: shortest path distances between all node pairs

These break the 1-WL ceiling without the $O(n^k)$ cost of higher-order GNNs. The tradeoff: extra preprocessing and problem-specific design.
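Random walk encodings are cheap to compute for small graphs. The NumPy sketch below (function names hypothetical) returns the return probabilities $\text{diag}((D^{-1}A)^k)$ for $k = 1, \dots, k_{\max}$ and applies them to the triangular prism and $K_{3,3}$, two 3-regular graphs on 6 nodes that 1-WL cannot tell apart. At $k = 3$ each prism node has two closed walks around its triangle, each with probability $(1/3)^3$, giving $2/27$; the bipartite $K_{3,3}$ has no closed walks of odd length, so its value is $0$.

```python
import numpy as np


def random_walk_encoding(A, k_max):
    """Return [N, k_max] features: column k-1 holds diag((D^-1 A)^k).

    Assumes no isolated nodes (every row of A has a nonzero sum).
    """
    P = A / A.sum(axis=1, keepdims=True)   # random-walk transition matrix
    feats, Pk = [], np.eye(len(A))
    for _ in range(k_max):
        Pk = Pk @ P
        feats.append(np.diag(Pk))          # return probability after k steps
    return np.stack(feats, axis=1)


def adjacency_matrix(n, edges):
    A = np.zeros((n, n))
    for u, v in edges:
        A[u, v] = A[v, u] = 1.0
    return A


prism = adjacency_matrix(6, [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3),
                             (0, 3), (1, 4), (2, 5)])
k33 = adjacency_matrix(6, [(i, j) for i in range(3) for j in range(3, 6)])
rw_prism = random_walk_encoding(prism, 3)
rw_k33 = random_walk_encoding(k33, 3)
```

Appending these columns to the initial node features gives an MPNN a way to separate the two graphs that it lacks with uniform features.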


Equivariant Message Passing for 3D Geometry

The 3D Molecule Problem

Molecular graphs in 2D (connectivity only) miss 3D geometry - bond angles, dihedral angles, and atomic coordinates determine molecular properties like binding affinity and reactivity. Incorporating 3D geometry requires SE(3)-equivariance: if the molecule is rotated or translated in space, predicted scalar properties (such as energy) must be unchanged and predicted vector properties (such as forces) must rotate along with it.

EGNN - E(n)-Equivariant Graph Neural Network

Satorras et al. (2021) extended MPNN to maintain 3D coordinates alongside node embeddings:

$$
\begin{aligned}
m_{ij} &= \phi_e\left(h_i,\, h_j,\, \|x_i - x_j\|^2,\, a_{ij}\right) \\
x_i' &= x_i + \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)} (x_i - x_j)\, \phi_x(m_{ij}) \\
h_i' &= \phi_h\left(h_i,\, \sum_{j \in \mathcal{N}(i)} m_{ij}\right)
\end{aligned}
$$

The key: use squared distances $\|x_i - x_j\|^2$ (rotation- and translation-invariant scalars) in the message function, and update coordinates by weighted displacement vectors (rotation-equivariant). This gives E(n)-equivariance, which includes SE(3), without the computational cost of spherical harmonics.
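The following NumPy sketch implements one EGNN-style layer and checks its equivariance numerically. It simplifies the paper's setup: single tanh layers stand in for the MLPs $\phi_e$, $\phi_x$, $\phi_h$, edge attributes $a_{ij}$ are omitted, and all names and sizes are hypothetical. Rotating (or reflecting) and translating the input coordinates should leave the updated $h$ unchanged and move the updated coordinates by the same transform.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_m = 4, 8
# Hypothetical single-layer stand-ins for the MLPs phi_e, phi_x, phi_h
W_e = 0.5 * rng.normal(size=(2 * d + 1, d_m))
W_x = 0.5 * rng.normal(size=d_m)
W_h = 0.5 * rng.normal(size=(d + d_m, d))


def egnn_layer(h, x, neighbors):
    """One EGNN-style layer without edge attributes a_ij.

    h: [N, d] invariant features, x: [N, 3] coordinates,
    neighbors[i]: non-empty list of neighbors of node i.
    """
    h_new, x_new = np.empty_like(h), np.empty_like(x)
    for i, nbrs in enumerate(neighbors):
        m_sum = np.zeros(d_m)
        shift = np.zeros(x.shape[1])
        for j in nbrs:
            dist2 = np.sum((x[i] - x[j]) ** 2)              # invariant scalar
            m_ij = np.tanh(np.concatenate([h[i], h[j], [dist2]]) @ W_e)
            m_sum += m_ij
            shift += (x[i] - x[j]) * (m_ij @ W_x)           # equivariant vector
        x_new[i] = x[i] + shift / len(nbrs)
        h_new[i] = np.tanh(np.concatenate([h[i], m_sum]) @ W_h)
    return h_new, x_new


neighbors = [[1, 2], [0, 2, 3], [0, 1], [1]]
h = rng.normal(size=(4, d))
x = rng.normal(size=(4, 3))
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))   # random orthogonal matrix
t = rng.normal(size=3)                          # random translation

h1, x1 = egnn_layer(h, x, neighbors)
h2, x2 = egnn_layer(h, x @ Q.T + t, neighbors)  # transformed input
```

`h2` should match `h1`, and `x2` should match `x1 @ Q.T + t`. A $Q$ from a QR decomposition may be a reflection, which E(n)-equivariance also covers.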

DimeNet - Directional Message Passing Using Bond Angles

DimeNet (Klicpera et al., 2020) defines messages on directed bonds and incorporates the angle between consecutive bonds:

$$m_{ji}^{(t+1)} = U\left(m_{ji}^{(t)},\, \sum_{k \in \mathcal{N}(j) \setminus \{i\}} M\left(m_{kj}^{(t)},\, e_{\text{RBF}}(\|x_j - x_i\|),\, a_{\angle kji}\right)\right)$$

Messages flow along directed edges ($k \to j \to i$) and the angle $\angle kji$ between consecutive bonds is used in the message function. This allows distinction of configurations with the same connectivity and bond lengths but different bond angles - a significant improvement for quantum chemistry.
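DimeNet's extra ingredient is the set of bond angles for every pair of consecutive directed edges $k \to j \to i$. The sketch below computes only that geometric preprocessing step with NumPy; the radial and spherical basis expansions DimeNet applies to distances and angles are omitted, and `triplet_angles` is a hypothetical helper.

```python
import numpy as np


def triplet_angles(pos, edges):
    """Angles for every directed triplet k -> j -> i with k != i.

    pos: [N, 3] coordinates; edges: undirected bonds as (a, b) pairs.
    Returns a dict {(k, j, i): angle at atom j in radians}.
    """
    nbrs = {v: set() for v in range(len(pos))}
    for a, b in edges:
        nbrs[a].add(b)
        nbrs[b].add(a)
    angles = {}
    for j in nbrs:
        for i in nbrs[j]:              # directed edge j -> i
            for k in nbrs[j] - {i}:    # incoming edge k -> j, excluding i
                u, w = pos[k] - pos[j], pos[i] - pos[j]
                cos = u @ w / (np.linalg.norm(u) * np.linalg.norm(w))
                angles[(k, j, i)] = np.arccos(np.clip(cos, -1.0, 1.0))
    return angles


# Toy bent fragment: atom 0 bonded to atoms 1 and 2 at a right angle
pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
angles = triplet_angles(pos, [(0, 1), (0, 2)])
```

Each directed edge $j \to i$ then sums one message term per entry `(k, j, i)`, which is where the angle enters the update.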


PyTorch Geometric Implementation

The MessagePassing Base Class

PyG provides MessagePassing as a base class. Override message() and, optionally, aggregate() and update(), then call self.propagate() to execute the loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax as pyg_softmax


class GenericMPNNLayer(MessagePassing):
    """
    Generic MPNN layer with edge feature support.
    Message:     MLP(h_v, h_u, e_vu)
    Aggregation: sum
    Update:      MLP(h_v, aggregated_msg) + LayerNorm
    """

    def __init__(self, node_dim: int, edge_dim: int, out_dim: int):
        super().__init__(aggr='add')  # 'add' = sum aggregation
        # Do not assign to self.node_dim: MessagePassing already uses that
        # attribute as the axis along which node features are indexed.
        self.in_dim = node_dim
        self.edge_dim = edge_dim
        self.out_dim = out_dim

        # M_t: message network
        self.message_net = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

        # U_t: update network
        self.update_net = nn.Sequential(
            nn.Linear(node_dim + out_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index, edge_attr):
        """
        x:          [N, node_dim]
        edge_index: [2, E] COO format
        edge_attr:  [E, edge_dim]
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """
        x_i:       target node features [E, node_dim]
        x_j:       source node features [E, node_dim]
        edge_attr: edge features [E, edge_dim]
        Returns message m_{j→i} [E, out_dim]
        """
        msg_input = torch.cat([x_i, x_j, edge_attr], dim=-1)
        return self.message_net(msg_input)

    def update(self, aggr_out, x):
        """
        aggr_out: aggregated messages [N, out_dim]
        x:        original node features [N, node_dim]
        """
        update_input = torch.cat([x, aggr_out], dim=-1)
        h_new = self.update_net(update_input)
        return self.norm(h_new)
```

Edge-Conditioned Message Passing (Gilmer et al. style)

```python
class EdgeConditionedMP(MessagePassing):
    """
    Message: m_vw = A(e_vw) @ h_w
    where A(e_vw) is an (out_dim x node_dim) matrix from an edge MLP.
    Update: GRU cell - treats message passing as a recurrent process.
    """

    def __init__(self, node_dim: int, edge_dim: int, out_dim: int):
        super().__init__(aggr='add')
        # Stored under a different name: self.node_dim is reserved by MessagePassing.
        self.in_dim = node_dim
        self.out_dim = out_dim

        # Edge network: edge_dim → out_dim × node_dim (flattened matrix)
        self.edge_net = nn.Sequential(
            nn.Linear(edge_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_dim * node_dim),
        )

        # GRU update cell: input = aggregated message, hidden = node state
        self.gru = nn.GRUCell(out_dim, node_dim)
        self.norm = nn.LayerNorm(node_dim)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        x_j:       [E, node_dim] source node features
        edge_attr: [E, edge_dim] bond features
        Returns:   [E, out_dim]  m_{j→v} = A(e_vj) @ h_j
        """
        A = self.edge_net(edge_attr)                        # [E, out_dim * node_dim]
        A = A.view(-1, self.out_dim, self.in_dim)           # [E, out_dim, node_dim]
        # batched matrix-vector product
        msg = torch.bmm(A, x_j.unsqueeze(-1)).squeeze(-1)   # [E, out_dim]
        return msg

    def update(self, aggr_out, x):
        """GRU treats the current node state as its hidden state."""
        h_new = self.gru(aggr_out, x)  # [N, node_dim]
        return self.norm(h_new)
```

Full MPNN for QM9

```python
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_add_pool


class MPNNModel(nn.Module):
    """
    Full MPNN following Gilmer et al. (2017) for QM9.
    T message-passing steps with weight-tied layers.
    Attention-weighted readout (a simplified stand-in for set2set).
    """

    def __init__(
        self,
        node_in_dim: int = 11,  # QM9 atom feature dimension (PyG)
        edge_in_dim: int = 4,   # QM9 bond feature dimension (PyG)
        hidden_dim: int = 64,
        n_mp_steps: int = 6,    # T - depth of message passing
        out_dim: int = 1,       # number of target properties
    ):
        super().__init__()
        self.n_mp_steps = n_mp_steps

        # Initial node embedding
        self.node_encoder = nn.Linear(node_in_dim, hidden_dim)

        # Weight-tied message passing layer (same weights at every step)
        self.mp_layer = EdgeConditionedMP(hidden_dim, edge_in_dim, hidden_dim)

        # Attention-weighted readout
        self.readout_attn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
        self.readout_proj = nn.Linear(hidden_dim, hidden_dim)

        # Output MLP
        self.output_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, out_dim),
        )

    def forward(self, data):
        x, edge_index, edge_attr, batch = (
            data.x.float(),
            data.edge_index,
            data.edge_attr.float(),
            data.batch,
        )

        # Initial embedding
        h = F.relu(self.node_encoder(x))  # [N, hidden_dim]

        # T rounds of message passing (weight-tied)
        for _ in range(self.n_mp_steps):
            h = self.mp_layer(h, edge_index, edge_attr)

        # Attention-weighted readout per graph
        attn = self.readout_attn(h)                 # [N, 1]
        attn_weights = pyg_softmax(attn, batch)     # [N, 1] - softmax within each graph
        h_proj = self.readout_proj(h)               # [N, hidden_dim]
        graph_repr = global_add_pool(
            attn_weights * h_proj, batch
        )                                           # [B, hidden_dim]

        return self.output_mlp(graph_repr)          # [B, out_dim]


def train_mpnn_qm9(target_property: int = 7, epochs: int = 100):
    """
    Train MPNN on QM9.
    target_property indices (PyG QM9): 0=mu, 1=alpha, 2=HOMO, 3=LUMO,
    4=gap, 5=R2, 6=ZPVE, 7=U0, 8=U, 9=H, 10=G, 11=Cv,
    12-15=atomization energies (U0, U, H, G), 16-18=rotational constants (A, B, C)
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = QM9(root='data/QM9')

    # Normalize target (statistics over the full dataset, for simplicity)
    y_all = dataset.data.y[:, target_property]
    y_mean, y_std = y_all.mean().item(), y_all.std().item()

    # The dataset is stored in a fixed order, so shuffle before splitting
    torch.manual_seed(0)
    dataset = dataset.shuffle()
    train_data = dataset[:110000]
    val_data = dataset[110000:120000]

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)

    model = MPNNModel(hidden_dim=64, n_mp_steps=6).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=10, factor=0.5, min_lr=1e-5
    )

    best_val_mae = float('inf')

    for epoch in range(epochs):
        # ── train ──
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            y_norm = (batch.y[:, target_property] - y_mean) / y_std
            pred = model(batch).squeeze(-1)
            loss = F.mse_loss(pred, y_norm)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        # ── validate: exact MAE over all validation molecules ──
        model.eval()
        abs_err, n_seen = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                y_true = batch.y[:, target_property]
                pred = model(batch).squeeze(-1) * y_std + y_mean
                abs_err += (pred - y_true).abs().sum().item()
                n_seen += y_true.numel()

        val_mae = abs_err / n_seen
        scheduler.step(val_mae)

        if val_mae < best_val_mae:
            best_val_mae = val_mae
            torch.save(model.state_dict(), 'best_mpnn_qm9.pt')

        if epoch % 10 == 0:
            avg_loss = train_loss / len(train_loader)
            print(f"Epoch {epoch:3d}  loss={avg_loss:.4f}  "
                  f"val_MAE={val_mae:.4f}  best={best_val_mae:.4f}")

    return model, best_val_mae
```

Over-Squashing - The Second Fundamental Bottleneck

Beyond the 1-WL expressivity ceiling, MPNNs face over-squashing: when information must travel $k$ hops to reach a target node, it passes through $k$ aggregations. The number of nodes within $k$ hops can grow exponentially with $k$, yet each aggregation compresses all of them into a fixed-dimensional vector. As a result, information from distant nodes can be severely attenuated.

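Formally (Topping et al., 2022), for GCN-style layers the Jacobian sensitivity of node $v$'s embedding to node $u$'s initial features is bounded by powers of the normalized adjacency matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$:

$$\left\|\frac{\partial h_v^{(k)}}{\partial h_u^{(0)}}\right\| \leq C^k \, (\hat{A}^k)_{vu}$$

where $C$ bounds the weight norms and the activation's Lipschitz constant. The entry $(\hat{A}^k)_{vu}$ sums, over all length-$k$ walks between $u$ and $v$, the product of $1/\sqrt{d_a d_b}$ along each edge $(a, b)$, so paths through high-degree bottleneck nodes contribute very little.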
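The graph-dependent factor $(\hat{A}^k)_{vu}$ of this bound is easy to compute. The NumPy sketch below (hypothetical helper names; the constant $C$ is ignored) compares two nodes at distance 3 on a plain path with the same path after attaching five extra leaves to each middle node. Only one length-3 walk connects the endpoints, so the factor is a product of three edge weights: $\frac{1}{\sqrt{2 \cdot 3}} \cdot \frac{1}{3} \cdot \frac{1}{\sqrt{3 \cdot 2}} = \frac{1}{18}$ on the plain path, and $\frac{1}{4} \cdot \frac{1}{8} \cdot \frac{1}{4} = \frac{1}{128}$ once the middle nodes have degree 8 (counting self-loops).

```python
import numpy as np


def normalized_adjacency(n, edges):
    """A_hat = D~^{-1/2} (A + I) D~^{-1/2}."""
    A = np.eye(n)
    for u, v in edges:
        A[u, v] = A[v, u] = 1.0
    d_inv_sqrt = 1.0 / np.sqrt(A.sum(axis=1))
    return d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]


def sensitivity_factor(n, edges, u, v, k):
    """(A_hat^k)_{vu}: the graph-dependent factor in the Jacobian bound."""
    return np.linalg.matrix_power(normalized_adjacency(n, edges), k)[v, u]


path = [(0, 1), (1, 2), (2, 3)]
# Same path, but nodes 1 and 2 each get 5 extra leaf neighbors (nodes 4..13)
hubs = path + [(1, 4 + i) for i in range(5)] + [(2, 9 + i) for i in range(5)]
plain = sensitivity_factor(4, path, 0, 3, 3)
bottleneck = sensitivity_factor(14, hubs, 0, 3, 3)
```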

Solutions:

  • Graph rewiring: add edges that shortcut bottlenecks, e.g., diffusion-based rewiring (Klicpera et al., 2019) or curvature-based rewiring (Topping et al., 2022)
  • Higher-dimensional messages: use vector or tensor node features that carry more information per hop
  • Graph Transformers: global attention that bypasses the structural bottleneck entirely
  • Expander graph rewiring: add edges to give the graph a larger spectral gap (faster mixing)

:::warning Over-squashing is frequently misdiagnosed
Practitioners see poor GNN performance on tasks requiring long-range dependencies and blame the expressivity ceiling. Often the actual cause is over-squashing. Adding more message passing steps does not fix over-squashing - it can worsen it while simultaneously causing over-smoothing.
:::


Production: Drug Discovery at Schrödinger

Schrödinger Inc. uses MPNN-based models as a core component of their computational drug discovery pipeline. Their FEP+ (Free Energy Perturbation) workflow relies on expensive molecular dynamics simulations (hours per compound). MPNN-based property predictors can screen millions of compounds in minutes.

The workflow:

  1. Virtual screening: MPNN predicts binding affinity for a library of 10M+ compounds in minutes
  2. Hit expansion: top-scoring compounds clustered, analogs generated by generative model
  3. FEP+ validation: promising candidates validated with physics-based simulation
  4. Experimental testing: highest-confidence compounds synthesized and tested in assay

Key engineering decisions:

  • Ensemble of 5 MPNNs (different random seeds) - variance = uncertainty estimate
  • Active learning loop: experimental results feed back to retrain on high-information compounds
  • Scaffold-aware splitting: test set held out by chemical scaffold to measure true generalization, not interpolation
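Scaffold-aware splitting can be sketched in pure Python, assuming each compound already has a scaffold key (in practice this might be a Bemis-Murcko scaffold computed with a cheminformatics toolkit such as RDKit, not shown here). The scaffold strings below are placeholders, and filling the training set with the largest scaffold groups first is one common heuristic, assumed here for illustration.

```python
from collections import defaultdict


def scaffold_split(scaffolds, train_frac=0.8):
    """Split compound indices so that no scaffold appears in both sets.

    scaffolds[i] is a precomputed scaffold key for compound i. Larger
    scaffold groups are placed in train first; groups that would overflow
    the train budget go to test.
    """
    groups = defaultdict(list)
    for idx, scaffold in enumerate(scaffolds):
        groups[scaffold].append(idx)
    ordered = sorted(groups.values(), key=len, reverse=True)
    train, test = [], []
    budget = train_frac * len(scaffolds)
    for group in ordered:
        if len(train) + len(group) <= budget:
            train.extend(group)
        else:
            test.extend(group)
    return train, test


scaffolds = ["benzene", "benzene", "benzene", "pyridine", "pyridine",
             "indole", "indole", "furan", "thiophene", "thiophene"]
train, test = scaffold_split(scaffolds, train_frac=0.8)
```

Because whole scaffold groups move together, the test set measures generalization to chemotypes the model has never seen.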

:::tip Chemical accuracy benchmark
Target for quantum chemical property prediction: "chemical accuracy" = 1 kcal/mol ≈ 0.043 eV. DimeNet++ reports ~0.025 eV MAE on QM9 HOMO energy. A basic MPNN achieves ~0.045 eV. If your model's error exceeds 0.1 eV, something is likely wrong with your training setup.
:::


YouTube Resources

| Title | Channel | Focus |
|---|---|---|
| Message Passing Neural Networks (CS224W Lecture 8) | Stanford Online | MPNN framework · WL test · expressivity theory |
| How Powerful are Graph Neural Networks? | Yannic Kilcher | GIN paper · 1-WL equivalence proof walkthrough |
| Graph Neural Networks for Molecules | Valence Labs | MPNN for drug discovery · practical considerations |
| Equivariant Graph Neural Networks | Prof. Erik Bekkers | SE(3)-equivariance · EGNN · geometric deep learning |
| DimeNet: Directional Message Passing | Johannes Klicpera | Bond angle encoding · directional MPNN for QM9 |

Common Pitfalls

:::danger Using MEAN aggregation when structural expressiveness matters
MEAN and MAX aggregations are not injective over multisets - two different neighborhood configurations can produce the same aggregated value. For tasks requiring structural discrimination (graph classification, isomorphism detection), use SUM aggregation with an MLP update (GIN-style). MEAN discards neighborhood size: a node with 2 identical neighbors looks identical to a node with 20 identical neighbors.
:::
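The failure cases can be checked directly. The NumPy sketch below applies sum, mean, and max to four small neighborhoods built from two one-hot feature vectors and lists which pairs of distinct neighborhoods collide; the names are illustrative. With one-hot inputs, raw sums happen to be injective; for general features, injectivity of sum needs a learned per-element transform before summing, which GIN's MLPs provide.

```python
import numpy as np

AGGREGATORS = {
    "sum": lambda msgs: msgs.sum(axis=0),
    "mean": lambda msgs: msgs.mean(axis=0),
    "max": lambda msgs: msgs.max(axis=0),
}

a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])
neighborhoods = {
    "2 x a": np.stack([a, a]),
    "4 x a": np.stack([a, a, a, a]),
    "a, a, b": np.stack([a, a, b]),
    "a, b": np.stack([a, b]),
}


def collisions(name):
    """Pairs of distinct neighborhoods the aggregator maps to the same vector."""
    agg = AGGREGATORS[name]
    keys = list(neighborhoods)
    return [(p, q) for i, p in enumerate(keys) for q in keys[i + 1:]
            if np.allclose(agg(neighborhoods[p]), agg(neighborhoods[q]))]
```

Sum produces no collisions, mean merges `2 x a` with `4 x a` (it loses neighborhood size), and max additionally merges `a, a, b` with `a, b` (it loses multiplicity).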

:::danger Ignoring edge features for molecular tasks
GCN and basic GAT have no edge feature support. Using them for molecular property prediction discards bond type, bond order, and bond length - critical chemical information. The difference between a C=C double bond and a C-C single bond changes reactivity completely. Always use a proper MPNN with edge features for chemistry.
:::

:::warning Too many message passing steps creates over-smoothing
After many rounds of aggregation, all node embeddings converge toward a similar vector (proportional to the graph's dominant eigenvector). This is over-smoothing. For most molecular tasks (where relevant patterns are local), 3-6 steps typically work well. Adding more steps past this range often hurts performance.
:::
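Over-smoothing can be observed even without learned weights. The NumPy sketch below repeatedly applies the GCN propagation matrix $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ to random features on a small connected graph, a simplification that isolates the averaging effect; the helper name is hypothetical. Since $\sqrt{\tilde{d}}$ is the dominant eigenvector of $\hat{A}$, dividing each row by $\sqrt{\tilde{d}_v}$ should make all rows approach the same vector as $k$ grows.

```python
import numpy as np


def normalized_adjacency(n, edges):
    A = np.eye(n)  # self-loops: A~ = A + I
    for u, v in edges:
        A[u, v] = A[v, u] = 1.0
    deg = A.sum(axis=1)
    return A / np.sqrt(np.outer(deg, deg)), deg


# A 5-cycle with one chord (connected, contains a triangle)
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (0, 2)]
A_hat, deg = normalized_adjacency(5, edges)
h = np.random.default_rng(0).normal(size=(5, 3))

for k in (1, 10, 100):
    h_k = np.linalg.matrix_power(A_hat, k) @ h
    # Divide out the sqrt(degree) scaling; what remains is the shared vector
    rescaled = h_k / np.sqrt(deg)[:, None]
    spread = np.abs(rescaled - rescaled.mean(axis=0)).max()
    print(f"k={k:3d}  max deviation from the common vector: {spread:.2e}")
```

The deviation shrinks geometrically with $k$, at a rate set by the second-largest eigenvalue magnitude of $\hat{A}$; learned weights and nonlinearities change the details but not this averaging pressure.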

:::warning Weight-tied vs per-layer weights
Gilmer et al. used the same message/update weights at every step (weight-tied). This reduces parameter count and can generalize better with small training sets. Per-layer weights are better when the information content at different hop distances is qualitatively different. Start with weight-tied layers for molecular tasks.
:::


Interview Q&A

Q: Prove that GCN and GAT are both special cases of the MPNN framework. What is the key difference between them in MPNN terms?

A: Both are MPNNs that differ only in the message function. GCN's message from node $u$ to node $v$ is $M(h_u, h_v, e) = \frac{1}{\sqrt{d_u d_v}} W h_u$ - a linear transform of the neighbor's features scaled by a fixed topology-derived normalization constant. GAT's message is $M(h_u, h_v, e) = \alpha_{vu} W h_u$ where $\alpha_{vu} = \text{softmax}_u(\text{LeakyReLU}(a^T[Wh_v \| Wh_u]))$ - the coefficient is data-dependent, learned from node features. Both use sum aggregation and a nonlinear activation in the update. The difference: GCN uses fixed graph-structural weights; GAT uses learned content-based attention weights. Neither natively supports edge features. For molecular graphs with rich bond features, you need an MPNN like Gilmer et al.'s model with an edge-conditioned message function.

Q: What is the Weisfeiler-Lehman expressivity limit? Give a concrete example of two graphs that no MPNN can distinguish.

A: 1-WL iteratively refines node labels by hashing each node's label with the multiset of its neighbors' labels. Any MPNN is structurally equivalent to 1-WL refinement - it maps (node state, multiset of neighbor states) to a new node state. Xu et al. (2019) proved no MPNN can exceed 1-WL expressiveness. A concrete failure: the 6-cycle (hexagon, every node degree 2) vs two disjoint triangles (every node degree 2). 1-WL assigns identical labels to all nodes in both graphs at every round because every node's 1-hop neighborhood looks the same in both graphs. A GNN initialized on degree features will produce identical embeddings for both graphs and cannot classify them differently. For chemistry: non-isomorphic molecules whose atoms all have identical local atom environments at every radius are indistinguishable by any standard MPNN.

Q: What is the difference between over-smoothing and over-squashing, and how do you fix each?

A: They are distinct failure modes of message passing. Over-smoothing: with too many layers, node representations converge toward a common vector through repeated averaging. All nodes look identical; the model loses discriminative power. Fix: fewer layers, residual/skip connections, or PairNorm. Over-squashing: when information must travel far across the graph, it is compressed through a narrow sequence of aggregations. The Jacobian $\partial h_v^T / \partial h_u^0$ can decay exponentially with graph distance, so long-range information is lost. Fix: graph rewiring (add virtual edges), graph transformers with global attention, or expander graph construction. They often co-occur - adding more layers to solve over-squashing worsens over-smoothing. The right tool is graph rewiring, not more layers.

Q: How does equivariant message passing (EGNN) differ from standard MPNN, and why does it matter for protein structure?

A: Standard MPNN operates on node feature scalars and uses topology. Equivariant GNNs extend this to 3D coordinates as geometric objects. EGNN updates node positions $x_i \in \mathbb{R}^3$ alongside scalar features $h_i$. Messages use $\|x_i - x_j\|^2$ (a rotation-invariant scalar) as input. Coordinate updates use $(x_i - x_j)$ displacement vectors weighted by a scalar from the message network - making the update rotation-equivariant. SE(3)-equivariance means: rotate/translate the input → the output coordinates rotate/translate identically, and scalar predictions are unchanged. For proteins, this is essential: energy, force, and binding affinity are geometric properties of 3D structure. A non-equivariant model must learn rotation invariance from data, requiring vastly more training examples and still generalizing poorly to unseen orientations.

Q: How would you implement an MPNN for edge-level tasks like predicting whether a chemical bond will break?

A: After $T$ rounds of node-level message passing, compute edge representations from endpoint node states: $e_{ij}' = \phi_{\text{edge}}(h_i^T, h_j^T, e_{ij})$ where $\phi_{\text{edge}}$ is an MLP. The final node states $h_i^T$ encode each atom's full $T$-hop neighborhood context, so the edge representation captures both endpoints' chemical environments. For directed tasks (predicting which direction a reaction proceeds), use an asymmetric function: $e_{i \to j}' = \phi_{\text{edge}}(h_i^T, h_j^T, e_{ij})$ differs from $e_{j \to i}' = \phi_{\text{edge}}(h_j^T, h_i^T, e_{ij})$ by argument order. For an undirected question such as whether a bond breaks, make $\phi_{\text{edge}}$ symmetric instead, for example by feeding it $h_i^T + h_j^T$ and $h_i^T \odot h_j^T$, so both orderings give the same prediction. Apply binary cross-entropy on top of $e_{ij}'$ for bond-breaking classification. Ensure the training data preserves edge directionality and that the loss is computed per-edge rather than per-graph.
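An edge-level head on top of final node states can be sketched in NumPy. The parameters, sizes, and function names are hypothetical. The sketch contrasts an order-dependent head (for directed tasks) with an order-independent one built from $h_i + h_j$ and $h_i \odot h_j$ (for undirected bond breaking), and computes a per-edge binary cross-entropy.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_e, d_hid = 4, 3, 16
# Hypothetical parameters for two variants of phi_edge (one hidden layer each)
W_dir = 0.3 * rng.normal(size=(2 * d + d_e, d_hid))
W_sym = 0.3 * rng.normal(size=(2 * d + d_e, d_hid))
w_out = 0.3 * rng.normal(size=d_hid)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def directed_edge_prob(h_i, h_j, e_ij):
    """Asymmetric head: swapping i and j changes the input."""
    z = np.maximum(np.concatenate([h_i, h_j, e_ij]) @ W_dir, 0.0)
    return sigmoid(z @ w_out)


def bond_break_prob(h_i, h_j, e_ij):
    """Symmetric head: sum and elementwise product ignore endpoint order."""
    z = np.maximum(np.concatenate([h_i + h_j, h_i * h_j, e_ij]) @ W_sym, 0.0)
    return sigmoid(z @ w_out)


def bce(p, label):
    """Per-edge binary cross-entropy."""
    return -(label * np.log(p) + (1 - label) * np.log(1 - p))


h_i, h_j = rng.normal(size=d), rng.normal(size=d)   # final node states h^T
e_ij = np.array([1.0, 0.0, 0.0])                     # e.g., a single bond
p = bond_break_prob(h_i, h_j, e_ij)
loss = bce(p, label=1.0)
```

In a full model the loss would be averaged over all labeled edges in a batch, not pooled per graph.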

Q: Why is sum aggregation strictly more expressive than mean or max aggregation for MPNN?

A: For sum aggregation, there exists an encoding $f$ of neighbor features such that $\sum_{u \in \mathcal{N}(v)} f(h_u)$ is unique for every multiset of bounded size drawn from a countable feature space (Xu et al., 2019, building on Zaheer et al., 2017). Raw sums alone are not injective, since $\{1, 3\}$ and $\{2, 2\}$ have the same sum. With a suitable $f$, however, different neighbor multisets produce different aggregates. A universal function approximator applied after the sum can then map each aggregate to any target. Mean aggregation discards neighborhood size: a node with neighbors $\{a, a\}$ produces the same mean as a node with neighbors $\{a, a, a, a\}$. Max aggregation discards multiplicity: $\{a, a, b\}$ produces the same max as $\{a, b\}$. For structural discrimination tasks, GCN is therefore strictly less expressive than GIN with sum, because its degree-normalized sum behaves approximately like a mean. The same holds for mean- and max-based GraphSAGE variants. In practice, for tasks where only the presence or absence of neighbor types matters (not their counts), mean works well and is more robust to high-degree nodes. Match the aggregation function to the task's symmetry requirements.
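The failure cases above are easy to reproduce with scalar features and hypothetical values a = 1, b = 2. The last two checks show why the injectivity result needs an encoding $f$ applied before the sum. Raw scalar sums collide. Summing one-hot encodings, one valid choice of $f$ for a finite feature vocabulary, yields per-value counts, which identify the multiset exactly.

```python
import numpy as np

def agg(multiset, how):
    """Aggregate a multiset of scalar features with SUM, MEAN, or MAX."""
    v = np.array(multiset, dtype=float)
    return {"sum": v.sum(), "mean": v.mean(), "max": v.max()}[how]

def one_hot_sum(multiset, vocab_size):
    """Sum of one-hot encodings = count of each feature value: injective on multisets."""
    return np.eye(vocab_size)[multiset].sum(axis=0)

a, b = 1, 2
# MEAN ignores neighborhood size; SUM does not
print(agg([a, a], "mean") == agg([a, a, a, a], "mean"))  # True
print(agg([a, a], "sum") == agg([a, a, a, a], "sum"))    # False
# MAX ignores multiplicity; SUM does not
print(agg([a, a, b], "max") == agg([a, b], "max"))       # True
print(agg([a, a, b], "sum") == agg([a, b], "sum"))       # False
# Raw scalar SUM is not injective by itself...
print(agg([1, 3], "sum") == agg([2, 2], "sum"))          # True
# ...but summing an injective encoding f(x) (here one-hot) is
print(np.array_equal(one_hot_sum([1, 3], 4), one_hot_sum([2, 2], 4)))  # False
```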


GIN - The Maximally Expressive MPNN

If every MPNN is at most as powerful as 1-WL, what is the most powerful MPNN? Xu et al. (2019) answered this with the Graph Isomorphism Network (GIN):

$$h_v^{(k+1)} = \text{MLP}^{(k)}\left((1 + \varepsilon^{(k)}) \cdot h_v^{(k)} + \sum_{u \in \mathcal{N}(v)} h_u^{(k)}\right)$$

Two design choices make GIN maximally expressive:

  1. Sum aggregation: injective over multisets given a suitable encoding of the neighbor features (which the MLPs provide), so different neighbor multisets produce different sums
  2. MLP update: a universal function approximator that can approximate any continuous function

The term $(1 + \varepsilon) h_v^{(k)}$ keeps the central node's features distinguishable from its neighbors' features. It is equivalent to adding a self-loop with a learned weight. The scalar $\varepsilon$ is either learned or fixed to 0 (both work in practice).
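Below is a minimal NumPy sketch of one GIN layer on a dense adjacency matrix. It assumes a two-layer ReLU MLP with random weights and a fixed $\varepsilon = 0$, and the graph and widths are hypothetical. The final check confirms that the layer is permutation equivariant: relabeling the nodes only relabels the outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
D_IN, D_HID, D_OUT = 4, 16, 8  # illustrative widths
W1, b1 = rng.normal(size=(D_IN, D_HID)), np.zeros(D_HID)
W2, b2 = rng.normal(size=(D_HID, D_OUT)), np.zeros(D_OUT)

def gin_layer(adj, h, eps=0.0):
    """h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u) for every node at once."""
    agg = (1.0 + eps) * h + adj @ h          # self term + SUM over neighbors
    hidden = np.maximum(agg @ W1 + b1, 0.0)  # MLP layer 1 (ReLU)
    return hidden @ W2 + b2                  # MLP layer 2

# Hypothetical 4-node graph: triangle 0-1-2 plus a pendant node 3 attached to 2
adj = np.array([[0, 1, 1, 0],
                [1, 0, 1, 0],
                [1, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
h = rng.normal(size=(4, D_IN))
out = gin_layer(adj, h)

perm = np.array([3, 1, 0, 2])                # relabel the nodes
P = np.eye(4)[perm]                          # permutation matrix
out_perm = gin_layer(P @ adj @ P.T, P @ h)
print(np.allclose(out_perm, P @ out))        # True
```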

Why other aggregations are less expressive:

| Aggregator | Injective over multisets? | Example failure |
|---|---|---|
| SUM + MLP | Yes: maximally expressive | None |
| MEAN | No: loses neighborhood size | $\{a,a\}$ and $\{a,a,a,a\}$ → same mean |
| MAX | No: loses multiplicity | $\{a,a,b\}$ and $\{a,b\}$ → same max |
| GCN (normalized sum) | No: approximately a mean | Same failures as MEAN |

GIN with SUM + MLP is the most expressive MPNN choice for structural discrimination. For tasks where you care about neighbor types but not counts (common in social networks), MEAN works fine. For tasks requiring structural isomorphism detection, use GIN.


MPNN Variants Comparison Table

| Architecture | Message Function | Aggregation | Update | Edge Features | Expressiveness |
|---|---|---|---|---|---|
| GCN | $\frac{1}{\sqrt{d_u d_v}} W h_u$ | SUM | $\sigma(\cdot)$ | No | Strictly less than 1-WL |
| GraphSAGE | $h_u$ | MEAN/MAX/LSTM | $\sigma(W[h_v \Vert m_v])$ | No | Less than 1-WL (MEAN) |
| GAT | $\alpha_{vu} W h_u$ | SUM | $\sigma(\cdot)$ | No | Strictly less than 1-WL |
| MPNN (Gilmer) | $A(e_{vu}) h_u$ | SUM | GRU cell | Yes | Up to 1-WL |
| GIN | $h_u$ | SUM | $\text{MLP}((1+\varepsilon) h_v + m_v)$ | No | Exactly 1-WL |
| EGNN | $\phi_e(h_i, h_j, d_{ij}^2)$ | SUM | $\phi_h$ + coordinate update | Distances | Up to 1-WL + SE(3) equivariant |
| DimeNet | Directed bond + angle | SUM | MLP | Angles + distances | Up to 1-WL + angular info |

Choosing Your MPNN Architecture

The right MPNN depends on your task.

Practical rules of thumb:

  • For molecular property prediction with bond features: use edge-conditioned MPNN (Gilmer et al.) with $T = 3$ to $6$ message-passing steps
  • For citation networks, social graphs: GCN or GAT with 2 layers is usually optimal
  • For structural graph classification: GIN with SUM aggregation
  • For 3D molecular geometry (protein structure, force fields): EGNN or DimeNet
  • When in doubt: start with a 2-layer GCN baseline, then systematically upgrade

Production Engineering Notes

Batching Graphs in PyTorch Geometric

Unlike images (fixed size) or text (padded to same length), graphs come in different sizes. PyG handles this with a disjoint union batching approach:

```python
from torch_geometric.loader import DataLoader

# Assumes `dataset` is a PyG Dataset (or a list of Data objects) and `model` is your GNN
# PyG automatically creates a single large disconnected graph from the batch
# Node indices are offset: graph 0 uses nodes 0..n0-1, graph 1 uses n0..n0+n1-1, etc.
# The `batch` tensor maps each node to its graph index: batch[i] = graph index of node i

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for data in loader:
    # data.x: [total_nodes_in_batch, node_features]
    # data.edge_index: [2, total_edges_in_batch], offsets applied automatically
    # data.batch: [total_nodes_in_batch], graph membership
    out = model(data)
```

This is why global_add_pool(x, batch) works - it sums over all nodes belonging to the same graph.
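The NumPy sketch below makes the disjoint-union idea concrete by computing what batching and `global_add_pool` produce. It is an illustration, not PyG's implementation. The helpers `collate` and `add_pool` are hypothetical, and each toy graph is given as an `(x, edge_index)` pair of arrays.

```python
import numpy as np

def collate(graphs):
    """Merge (x, edge_index) graphs into one disconnected graph plus a batch vector."""
    xs, edges, batch, offset = [], [], [], 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)     # shift node ids into the merged range
        batch.append(np.full(x.shape[0], g))  # every node records its graph index
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

def add_pool(x, batch, num_graphs):
    """Sum node features per graph, like global_add_pool(x, batch)."""
    out = np.zeros((num_graphs, x.shape[1]))
    np.add.at(out, batch, x)                  # unbuffered scatter-add by graph index
    return out

g0 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))   # 3 nodes, path 0-1-2
g1 = (2 * np.ones((2, 2)), np.array([[0], [1]]))     # 2 nodes, edge 0->1
x, edge_index, batch = collate([g0, g1])
print(edge_index.tolist())             # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())                  # [0, 0, 0, 1, 1]
print(add_pool(x, batch, 2).tolist())  # [[3.0, 3.0], [4.0, 4.0]]
```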

Gradient Clipping for GNN Training

GNNs can have unstable gradients, especially with many layers or GRU-based updates. Clipping the gradient norm is a cheap safeguard:

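```python
import torch

# Inside the training loop (assumes `model`, `loss`, and `optimizer` are defined)
optimizer.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # log this
optimizer.step()
```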

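A clip norm of 1.0 is a common default for GNNs on molecular tasks. If training diverges, log the value returned by torch.nn.utils.clip_grad_norm_, which is the total gradient norm measured before clipping. Passing max_norm=float('inf') measures that norm without clipping anything.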
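For intuition, here is a NumPy sketch of global-norm clipping, which is what `clip_grad_norm_` does with its default L2 norm. It computes one norm over all parameter gradients together. If that norm exceeds `max_norm`, every gradient is scaled by the same factor, so the update direction is preserved. The helper `clip_by_global_norm` is hypothetical. The small epsilon in the denominator mirrors PyTorch's behavior.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale all gradients jointly so their combined L2 norm is at most max_norm.

    Returns the clipped gradients and the total norm measured before clipping.
    """
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    scale = min(1.0, max_norm / (total_norm + 1e-6))  # never scale gradients up
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]  # combined norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, round(float(norm_after), 4))  # 5.0 1.0
```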

Monitoring GNN Training Health

Three signals to watch during training:

  1. Gradient norm: should stabilize rather than trend upward; if it keeps growing, reduce the learning rate or add clipping
  2. Node embedding similarity: sample 100 node embeddings per epoch and compute their mean pairwise cosine similarity; if it converges toward 1.0, you have over-smoothing
  3. Validation loss curve: GNNs on molecular tasks can show a "double descent"-like pattern (initial improvement, then a plateau, then improvement again), possibly as the model learns longer-range patterns, so an early plateau does not necessarily mean convergence
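Signal 2 takes only a few lines to compute. The NumPy sketch below assumes you can extract the current node embeddings as a `[num_nodes, dim]` array. The helper name `mean_pairwise_cosine` is hypothetical, and the default sample size of 100 follows the text. The two synthetic inputs contrast diverse embeddings with nearly identical, over-smoothed ones.

```python
import numpy as np

def mean_pairwise_cosine(emb, sample_size=100, seed=0):
    """Average cosine similarity over distinct pairs of sampled node embeddings."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(emb), size=min(sample_size, len(emb)), replace=False)
    z = emb[idx] / (np.linalg.norm(emb[idx], axis=1, keepdims=True) + 1e-12)
    sim = z @ z.T
    k = len(idx)
    return float((sim.sum() - np.trace(sim)) / (k * (k - 1)))  # exclude self-pairs

rng = np.random.default_rng(1)
healthy = rng.normal(size=(500, 32))                                # diverse embeddings
smoothed = np.ones((500, 32)) + 1e-3 * rng.normal(size=(500, 32))  # nearly identical
print(mean_pairwise_cosine(healthy), mean_pairwise_cosine(smoothed))
```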

Summary

The MPNN framework shows that GCN, GraphSAGE, GAT, and nearly every graph neural network published before 2019 are all special cases of a three-step computation: message, aggregate, update. The differences between architectures reduce to implementation choices within these three functions.

The framework has a hard theoretical ceiling: no MPNN can be more expressive than the 1-dimensional Weisfeiler-Lehman (1-WL) test. GIN achieves this ceiling exactly, which makes it a maximally expressive MPNN. Going beyond 1-WL requires either higher-order GNNs, whose cost grows as $O(n^k)$ for $k$-th order models, or structural encodings, which are practical but problem-specific. The 1-WL limitation is not a bug; it is an inherent property of the message-passing paradigm.

For 3D molecular tasks, geometric extensions (EGNN, DimeNet) incorporate 3D information while respecting the symmetries of physical space. They reach chemical accuracy on many targets of quantum property prediction benchmarks such as QM9. For production drug discovery, MPNN-based screening pipelines can replace much slower physics-based simulation in candidate selection with fast model inference.

Understanding MPNN at this depth - the unified framework, the expressivity analysis, the equivariant extensions, and the practical engineering tradeoffs - is what separates practitioners who can use GNNs from practitioners who can design them.


Graph Transformers - Going Beyond MPNN

The 1-WL expressivity ceiling and over-squashing have motivated a new class of models: Graph Transformers that use global self-attention over all node pairs, bypassing the MPNN locality constraint entirely.

How Graph Transformers Work

A Graph Transformer treats the graph as a set of nodes and applies Transformer-style attention between all pairs:

$$\text{Attn}(v, u) = \frac{\exp(q_v^T k_u / \sqrt{d})}{\sum_{w} \exp(q_v^T k_w / \sqrt{d})}$$

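$$h_v' = \sum_u \text{Attn}(v, u) \cdot v_u$$

where $q_v$, $k_u$, and $v_u$ are the query, key, and value projections of the node features and $d$ is their dimension.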

This is standard Transformer attention applied to the set of nodes. Graph structure enters only through structural encodings, added either to the node features (positional encodings) or to the attention logits (attention biases). Without them the model is permutation-equivariant but blind to the edges.
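The two equations above amount to scaled dot-product attention over all nodes. The NumPy sketch below assumes single-head attention, with random matrices standing in for the learned query, key, and value projections. The optional `bias` argument is an `[n, n]` additive term on the logits, a hypothetical hook for injecting graph structure. Without it, the edges never enter the computation, and permuting the nodes simply permutes the output.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def graph_self_attention(h, W_q, W_k, W_v, bias=None):
    """Attn(v, u) = softmax_u(q_v . k_u / sqrt(d) [+ bias_vu]); h_v' = sum_u Attn(v, u) v_u."""
    q, k, v = h @ W_q, h @ W_k, h @ W_v
    logits = q @ k.T / np.sqrt(q.shape[1])  # [n, n]: every node attends to every node
    if bias is not None:
        logits = logits + bias              # optional structural bias per node pair
    attn = softmax(logits, axis=1)          # each row sums to 1
    return attn @ v, attn

rng = np.random.default_rng(0)
n, d = 5, 8
h = rng.normal(size=(n, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
out, attn = graph_self_attention(h, W_q, W_k, W_v)

perm = rng.permutation(n)
out_p, _ = graph_self_attention(h[perm], W_q, W_k, W_v)
print(np.allclose(attn.sum(axis=1), 1.0), np.allclose(out_p, out[perm]))  # True True
```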

Key Graph Transformer Models

Graphormer (Ying et al., Microsoft, 2021) won the OGB-LSC quantum chemistry challenge. It adds:

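  • Centrality encoding: learnable embeddings indexed by node degree (in-degree and out-degree for directed graphs), added to each node's input features
  • Spatial encoding: a learnable attention bias indexed by the shortest-path distance between the two nodes
  • Edge encoding: for each node pair, the average over the edges on their shortest path of each edge's features dotted with a learnable, position-specific weight vector, added as a second attention bias

$$\alpha_{ij} = \frac{(h_i W_Q)(h_j W_K)^T}{\sqrt{d}} + b_{\phi(i,j)} + c_{ij}$$

where $\phi(i,j)$ is the shortest-path distance between $i$ and $j$, $b_{\phi(i,j)}$ is a learnable bias depending only on that distance, and $c_{ij}$ is the edge encoding. Degree information enters through $h_i$ and $h_j$ rather than as a separate bias.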
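Below is a NumPy sketch of the spatial-encoding bias. It assumes an unweighted graph stored as adjacency lists. Breadth-first search from every node gives all-pairs shortest-path distances, and a learnable table (random here) maps each distance to one scalar bias per node pair. The names `spatial_table`, `max_dist`, and the helpers are hypothetical. The degree and edge-encoding terms are omitted for brevity. The resulting `[n, n]` matrix is what gets added to the attention logits.

```python
from collections import deque

import numpy as np

def shortest_path_distances(adj_list):
    """All-pairs hop distances via BFS from every node; -1 marks unreachable pairs."""
    n = len(adj_list)
    dist = np.full((n, n), -1, dtype=int)
    for s in range(n):
        dist[s, s] = 0
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for w in adj_list[u]:
                if dist[s, w] < 0:
                    dist[s, w] = dist[s, u] + 1
                    queue.append(w)
    return dist

max_dist = 4                                   # distances >= max_dist share one bucket
rng = np.random.default_rng(0)
spatial_table = rng.normal(size=max_dist + 2)  # one scalar per bucket, plus "unreachable"

def spatial_bias(dist):
    """b_phi(i,j): look up a learnable scalar by (clipped) shortest-path distance."""
    bucket = np.where(dist < 0, max_dist + 1, np.minimum(dist, max_dist))
    return spatial_table[bucket]

# Hypothetical path graph 0-1-2-3-4-5 plus an isolated node 6
adj_list = [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4], []]
dist = shortest_path_distances(adj_list)
bias = spatial_bias(dist)                      # add this to the attention logits
print(dist[0].tolist())                        # [0, 1, 2, 3, 4, 5, -1]
```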

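SAN (Spectral Attention Network, 2021): uses Laplacian eigenvectors, together with their eigenvalues, as positional encodings, giving each node a structural fingerprint of its global position in the graph. The fingerprint is not guaranteed to be unique: eigenvector signs are arbitrary, and structurally symmetric nodes can be indistinguishable.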
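Below is a NumPy sketch of Laplacian eigenvector positional encodings. It assumes the symmetric normalized Laplacian and keeps the $k$ eigenvectors that follow the trivial first one. The helper `laplacian_pe` is hypothetical. Because eigenvector signs are arbitrary, Laplacian encodings are commonly trained with random sign flips.

```python
import numpy as np

def laplacian_pe(adj, k):
    """First k non-trivial eigenvectors of L = I - D^{-1/2} A D^{-1/2} as node positions."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(lap)        # ascending eigenvalues, orthonormal columns
    return eigvecs[:, 1:k + 1], eigvals[1:k + 1]  # skip the trivial eigenvector (eigenvalue 0)

# Hypothetical 6-cycle
n = 6
adj = np.zeros((n, n))
for i in range(n):
    adj[i, (i + 1) % n] = adj[(i + 1) % n, i] = 1.0

pe, vals = laplacian_pe(adj, k=2)
print(pe.shape, np.round(vals, 3))  # (6, 2) [0.5 0.5]
```

On this 6-cycle the eigenvalue 0.5 is repeated. Any rotation of the two eigenvectors is therefore an equally valid encoding, which is one more reason these fingerprints are not unique.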

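GPS (General, Powerful, Scalable graph Transformer, 2022): combines local MPNN layers with global self-attention. Each GPS layer has two sub-layers that run side by side, with their outputs combined: an MPNN layer for local structure and a Transformer-style attention layer for global context. The local MPNN costs $O(|E| \cdot d)$. With full softmax attention, the global sub-layer still costs $O(n^2)$. GPS avoids this by allowing linear-complexity attention modules such as Performer in that slot, which is what lets the hybrid scale to larger graphs.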
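Below is a structural NumPy sketch of one GPS-style layer under simplifying assumptions. It uses a mean-aggregation MPNN and single-head full softmax attention, so this version is still $O(n^2)$. Each sub-layer has a residual connection, and their outputs are summed and passed through an MLP. The weight names are hypothetical, and the normalization and dropout used in practice are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
W_local, W_q, W_k, W_v, W_mlp = (rng.normal(scale=0.3, size=(d, d)) for _ in range(5))

def local_mpnn(adj, h):
    """Mean over neighbors, then a linear map (dense here; sparse versions cost O(|E| d))."""
    deg = np.maximum(adj.sum(axis=1, keepdims=True), 1.0)
    return np.tanh((adj @ h) / deg @ W_local)

def global_attention(h):
    """Full softmax attention over all node pairs: O(n^2) in this sketch."""
    logits = (h @ W_q) @ (h @ W_k).T / np.sqrt(d)
    attn = np.exp(logits - logits.max(axis=1, keepdims=True))
    attn /= attn.sum(axis=1, keepdims=True)
    return attn @ (h @ W_v)

def gps_layer(adj, h):
    h_local = h + local_mpnn(adj, h)      # local sub-layer with residual
    h_global = h + global_attention(h)    # global sub-layer with residual
    h_sum = h_local + h_global            # combine both views
    return h_sum + np.tanh(h_sum @ W_mlp) # position-wise MLP with residual

adj = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
h = rng.normal(size=(4, d))
print(gps_layer(adj, h).shape)  # (4, 8)
```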

When to Use Graph Transformers vs MPNN

| Criterion | MPNN | Graph Transformer |
|---|---|---|
| Graph size | Any (with mini-batching) | Typically fewer than 1K nodes (due to $O(n^2)$ attention) |
| Long-range dependencies | Poor (over-squashing) | Excellent (global attention) |
| Expressiveness | Up to 1-WL | Beyond 1-WL (with structural encodings) |
| Inference speed | Fast | Slower ($O(n^2)$ per layer) |
| Molecular tasks (small graphs) | Excellent | Excellent, often better |
| Social networks (large graphs) | Scalable | Needs approximation (e.g., Performer, Linformer) |

For molecular property prediction, where graphs are small (10 to 50 atoms), Graph Transformers with structural encodings are among the strongest models on QM9 and similar benchmarks. For large-graph tasks, MPNNs with structural encodings are still the practical choice.


Scalable MPNN Training

Mini-Batch Training on Large Graphs

The challenge: backpropagation through $K$ MPNN layers requires materializing the $K$-hop computation subgraph for each training node, which can be enormous for high-degree nodes in social graphs.

Neighbor sampling (GraphSAGE): sample a fixed number of neighbors at each hop. $K$ layers with fan-outs $[f_1, f_2, \ldots, f_K]$ give at most $\prod_k f_k$ nodes at the outermost hop per target, and at most $1 + f_1 + f_1 f_2 + \cdots + \prod_k f_k$ nodes in total, making batch size predictable.
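Below is a pure-Python sketch of fan-out neighbor sampling. It assumes the graph is stored as adjacency lists and keeps duplicate nodes across the frontier for simplicity. The helper `sample_computation_graph` is hypothetical, not GraphSAGE's or PyG's implementation. Each hop keeps at most $f_k$ sampled neighbors per frontier node, so the bounds above hold no matter how large the node degrees are.

```python
import random

def sample_computation_graph(adj_list, target, fanouts, seed=0):
    """Return the sampled node list per hop for one target node."""
    rng = random.Random(seed)
    hops = [[target]]
    for f in fanouts:                  # one fan-out per layer
        next_hop = []
        for u in hops[-1]:
            nbrs = adj_list[u]
            next_hop.extend(rng.sample(nbrs, min(f, len(nbrs))))  # without replacement
        hops.append(next_hop)
    return hops

# Hypothetical hub graph: node 0 has 1000 neighbors, each linked only back to 0
adj_list = {0: list(range(1, 1001))}
adj_list.update({i: [0] for i in range(1, 1001)})
hops = sample_computation_graph(adj_list, target=0, fanouts=[10, 5])
print([len(h) for h in hops])  # [1, 10, 10]
```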

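Layer-wise sampling (LADIES, FastGCN): instead of sampling neighbors per node, sample a fixed set of nodes per layer. This keeps each layer's size fixed instead of letting it grow multiplicatively with depth. The cost is additional approximation error, because sampled nodes in adjacent layers can be sparsely connected.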

Subgraph sampling (GraphSAINT): sample a connected subgraph (random walk, random edge, random node) and train on it as if it were the full graph. Subgraph normalization corrects for sampling bias. Scales to graphs with billions of edges.

Cluster-GCN: partition the graph into clusters with a graph-partitioning algorithm such as METIS, then build each mini-batch from one or a few clusters. This is efficient because most edges fall within clusters. Edges between clusters that land in different mini-batches are dropped for that step, an approximation whose error stays small when the partition cuts few edges.
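Below is a NumPy sketch of how a Cluster-GCN-style mini-batch is formed. It assumes the cluster labels are already given; in practice they come from a partitioner such as METIS, and here they are hard-coded. The batch is the subgraph induced by the selected clusters, with nodes relabeled from 0, and edges that leave the batch are counted as dropped. The helper `cluster_batch` is hypothetical.

```python
import numpy as np

def cluster_batch(edge_index, cluster, batch_clusters):
    """Induced subgraph on the chosen clusters; edges leaving the batch are dropped."""
    nodes = np.flatnonzero(np.isin(cluster, batch_clusters))
    inside = np.isin(edge_index, nodes)                # [2, E]: is each endpoint in the batch?
    keep = inside.all(axis=0)
    dropped = int((inside.any(axis=0) & ~keep).sum())  # cross-cluster edges lost this step
    remap = -np.ones(len(cluster), dtype=int)
    remap[nodes] = np.arange(len(nodes))               # relabel nodes to 0..len(nodes)-1
    return nodes, remap[edge_index[:, keep]], dropped

# Hypothetical 6-node graph: triangles {0,1,2} and {3,4,5} joined by the edge 2-3
und = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
edge_index = np.array(und + [(v, u) for u, v in und]).T  # both directions, shape [2, 14]
cluster = np.array([0, 0, 0, 1, 1, 1])
nodes, sub_edges, dropped = cluster_batch(edge_index, cluster, batch_clusters=[0])
print(nodes.tolist(), sub_edges.shape[1], dropped)  # [0, 1, 2] 6 2
```

Placing several clusters in the same mini-batch keeps the edges between them, so fewer edges are dropped per step.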

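```python
# GraphSAINT random walk sampler - scales to billion-edge graphs
import torch.nn.functional as F
from torch_geometric.loader import GraphSAINTRandomWalkSampler

# Assumes `data` is a PyG Data object with x, edge_index, y, and train_mask,
# and `model(x, edge_index)` returns per-node class logits
loader = GraphSAINTRandomWalkSampler(
    data,
    batch_size=6000,      # number of random-walk root nodes per subgraph
    walk_length=2,        # random walk steps from each root
    num_steps=5,          # mini-batches per epoch
    sample_coverage=100,  # pre-sampling rounds used to estimate normalization
)
for subgraph in loader:
    out = model(subgraph.x, subgraph.edge_index)
    # node_norm (set because sample_coverage > 0) reweights the loss to correct sampling bias
    loss = F.cross_entropy(out, subgraph.y, reduction="none")
    loss = (loss * subgraph.node_norm)[subgraph.train_mask].sum()
```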

:::tip Which sampler to use

For citation graphs (Cora, CiteSeer, ogbn-arxiv): neighbor sampling with GraphSAGE. For large product and social graphs (ogbn-products, Reddit): GraphSAINT or Cluster-GCN. For molecular graphs: no neighbor sampling; mini-batch whole molecules, which are small enough to fit in memory. For heterogeneous graphs: use PyG's NeighborLoader with per-edge-type fan-outs.

:::


© 2026 EngineersOfAI. All rights reserved.