
Graph Theory for GNNs - Message Passing, Expressiveness, and Over-Smoothing

Reading time: ~28 minutes | Level: Graph Theory → GNN Engineering

In an interview at a top AI lab, the interviewer asks: "Why can't standard GNNs distinguish two disjoint triangles from a single 6-cycle, when both graphs have 6 nodes, 6 edges, and the same degree sequence?"

If you don't know the Weisfeiler-Leman test, you cannot answer this. If you cannot answer it, you don't know the fundamental limits of the GNNs you're deploying.

This lesson connects graph theory to GNN design - not just how to call GCNConv, but why it works, where it fails, and what you can do about it.

What You Will Learn

  • The message passing neural network (MPNN) framework
  • GCN, GraphSAGE, GAT, and GIN: derivations and key differences
  • Weisfeiler-Leman (WL) test: the expressiveness bound on MPNNs
  • Over-smoothing: spectral analysis and mitigation strategies
  • Heterogeneous GNNs for knowledge graphs
  • PyTorch Geometric implementation patterns

Part 1 - The Message Passing Framework (MPNN)

Unifying formulation

All major GNN variants can be written as a message passing neural network (MPNN):

For each node $v$ at layer $l$:

$$m_v^{(l)} = \text{AGGREGATE}\left(\{h_u^{(l-1)} : u \in \mathcal{N}(v)\}\right)$$

$$h_v^{(l)} = \text{UPDATE}\left(h_v^{(l-1)}, m_v^{(l)}\right)$$

After $L$ layers, for graph-level tasks:

$$h_G = \text{READOUT}\left(\{h_v^{(L)} : v \in V\}\right)$$

The specific choices of AGGREGATE, UPDATE, and READOUT define different GNN architectures.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter


class MPNNLayer(nn.Module):
    """
    Generic message passing layer:
    message MLP on (source, target, edge) features -> sum aggregation -> MLP update.
    Specific GNNs differ in how they define the message, aggregation, and update steps.
    """

    def __init__(self, node_dim: int, out_dim: int, edge_dim: int = 0):
        super().__init__()
        self.node_dim = node_dim
        self.out_dim = out_dim

        # Message function: processes (source, dest, edge) features
        self.message_fn = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, out_dim),
            nn.ReLU()
        )

        # Update function: combines old node state with aggregated messages
        self.update_fn = nn.Sequential(
            nn.Linear(node_dim + out_dim, out_dim),
            nn.ReLU()
        )

    def forward(self, x, edge_index, edge_attr=None):
        """
        x: (n_nodes, node_dim)
        edge_index: (2, n_edges) - [source_indices, target_indices]
        edge_attr: (n_edges, edge_dim) or None (must match edge_dim from __init__)
        """
        src, dst = edge_index[0], edge_index[1]

        # Compute messages for each edge
        src_features = x[src]  # (n_edges, node_dim)
        dst_features = x[dst]  # (n_edges, node_dim)

        if edge_attr is not None:
            message_input = torch.cat([src_features, dst_features, edge_attr], dim=-1)
        else:
            message_input = torch.cat([src_features, dst_features], dim=-1)

        messages = self.message_fn(message_input)  # (n_edges, out_dim)

        # AGGREGATE: sum messages for each destination node
        agg_messages = scatter(messages, dst, dim=0, dim_size=x.size(0),
                               reduce='sum')  # (n_nodes, out_dim)

        # UPDATE: combine old node features with aggregated messages
        update_input = torch.cat([x, agg_messages], dim=-1)
        return self.update_fn(update_input)
```

Part 2 - GCN, GraphSAGE, GAT, and GIN

GCN (Graph Convolutional Network - Kipf & Welling, 2017)

$$h_v^{(l)} = \text{ReLU}\left(W^{(l)} \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{h_u^{(l-1)}}{\sqrt{d_u d_v}}\right)$$

Aggregation: weighted sum with symmetric normalization, where $d_u$ and $d_v$ are degrees counted after adding self-loops.
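To make the normalization concrete, here is a minimal pure-Python sketch of the aggregation inside one GCN layer, before $W^{(l)}$ and the ReLU are applied. It assumes scalar node features and an undirected edge list. `gcn_propagate` is a hypothetical helper for illustration, not a PyG function.

```python
import math


def gcn_propagate(n, edges, h):
    """Aggregation inside one GCN layer (before W and ReLU), for scalar features:
    h'_v = sum over u in N(v) ∪ {v} of h_u / sqrt(d_u * d_v), degrees counting the self-loop."""
    neighbors = {v: {v} for v in range(n)}            # start from self-loops
    for u, v in edges:
        neighbors[u].add(v)
        neighbors[v].add(u)
    deg = {v: len(neighbors[v]) for v in range(n)}    # d_v = |N(v)| + 1
    return [sum(h[u] / math.sqrt(deg[u] * deg[v]) for u in neighbors[v]) for v in range(n)]


# Star graph: hub 0 with leaves 1-3, all input features equal to 1
out = gcn_propagate(4, [(0, 1), (0, 2), (0, 3)], [1.0, 1.0, 1.0, 1.0])
print([round(x, 4) for x in out])   # [1.3107, 0.8536, 0.8536, 0.8536]
```

Every input is 1, yet the hub ends at about 1.31 and each leaf at about 0.85. The symmetric normalization is a degree-weighted sum, not a plain mean.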

```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)
```

GraphSAGE (Hamilton et al., 2017)

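GraphSAGE samples a fixed number of neighbors per node and combines them with a choice of aggregator (mean, max, or LSTM):

$$h_v^{(l)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(h_v^{(l-1)}, \text{AGG}(\{h_u^{(l-1)} : u \in \mathcal{N}(v)\})\right)\right)$$

In PyG, `SAGEConv` performs only the aggregation step. The fixed-size neighbor sampling is handled separately by a data loader such as `NeighborLoader`.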
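Below is a minimal pure-Python sketch of one GraphSAGE-style step. It performs uniform fixed-size neighbor sampling, then mean aggregation, then concatenation with the node's own features (the input to $W^{(l)}$ and $\sigma$). The helper names are hypothetical. The rule that a node with fewer than $k$ neighbors keeps all of them is an assumption of this sketch.

```python
import random


def sample_neighbors(adj, v, k, rng):
    """Uniformly sample up to k distinct neighbors of v.
    Assumption: nodes with fewer than k neighbors keep all of them."""
    nbrs = adj[v]
    return list(nbrs) if len(nbrs) <= k else rng.sample(nbrs, k)


def sage_mean_step(adj, h, k, rng):
    """Return CONCAT(h_v, mean of sampled neighbor features) for every node v."""
    out = {}
    for v in adj:
        sampled = sample_neighbors(adj, v, k, rng)
        dim = len(h[v])
        if sampled:
            mean = [sum(h[u][i] for u in sampled) / len(sampled) for i in range(dim)]
        else:
            mean = [0.0] * dim              # isolated node: zero message
        out[v] = h[v] + mean                # list concatenation plays the role of CONCAT
    return out


# Star graph: hub 0 with four leaves; the feature of node v is [v]
adj = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
h = {v: [float(v)] for v in adj}
out = sage_mean_step(adj, h, k=2, rng=random.Random(0))
print(out[0], out[1])
```

The hub's output depends on which two leaves were sampled. A fixed sample size bounds the work per node regardless of its degree.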

```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 aggr: str = 'mean'):
        """
        aggr: 'mean', 'max', or 'lstm' - different aggregation strategies
            mean: robust, most commonly used
            max:  keeps the most extreme value of each feature across the neighborhood
            lstm: sequential aggregation; NOT permutation-invariant, so Hamilton et al.
                  apply it to a random permutation of the neighbors
        """
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr=aggr)
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr=aggr)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
```

Graph Attention Network (GAT - Veličković et al., 2018)

GAT replaces the fixed normalization of GCN with learned attention weights:

$$\alpha_{vu} = \frac{\exp\left(\text{LeakyReLU}\left(a^T [W h_v \,\|\, W h_u]\right)\right)}{\sum_{k \in \mathcal{N}(v)} \exp\left(\text{LeakyReLU}\left(a^T [W h_v \,\|\, W h_k]\right)\right)}$$

$$h_v^{(l)} = \sigma\left(\sum_{u \in \mathcal{N}(v)} \alpha_{vu} W h_u\right)$$
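Below is a minimal pure-Python sketch of the attention coefficients for a single node and a single head, following the formula above. It assumes list-based vectors, a LeakyReLU negative slope of 0.2 (the value used in the GAT paper), and hypothetical helper names. In practice, $v$ itself is usually included among its neighbors through a self-loop.

```python
import math


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def matvec(W, x):
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def gat_attention(h, v, neighbors, W, a):
    """alpha_vu for every u in `neighbors` (single head).
    `a` has length 2 * out_dim and scores the concatenation [W h_v || W h_u]."""
    wh_v = matvec(W, h[v])
    scores = {}
    for u in neighbors:
        concat = wh_v + matvec(W, h[u])                       # [W h_v || W h_u]
        scores[u] = leaky_relu(sum(a_i * c_i for a_i, c_i in zip(a, concat)))
    m = max(scores.values())                                  # subtract max for numerical stability
    exp_scores = {u: math.exp(s - m) for u, s in scores.items()}
    z = sum(exp_scores.values())
    return {u: e / z for u, e in exp_scores.items()}          # softmax over the neighborhood


h = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [0.0, 1.0]}
W = [[1.0, 0.0], [0.0, 1.0]]          # identity projection, for readability
a = [0.0, 0.0, 1.0, -1.0]             # scores only the neighbor half of the concatenation
alpha = gat_attention(h, 0, [1, 2], W, a)
print(alpha)
```

Node 1 receives about 0.77 of the weight and node 2 about 0.23. Unlike GCN's fixed $1/\sqrt{d_u d_v}$, these weights change as $W$ and $a$ are trained.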

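```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads: int = 8, dropout: float = 0.6):
        super().__init__()
        self.dropout = dropout
        # Multi-head attention: `heads` independent attention mechanisms whose
        # outputs are concatenated (GATConv default concat=True)
        self.conv1 = GATConv(in_channels, hidden_channels // heads,
                             heads=heads, dropout=dropout)
        # The concatenated width is (hidden_channels // heads) * heads, which equals
        # hidden_channels only when hidden_channels is divisible by heads.
        # Single head for the final layer
        self.conv2 = GATConv((hidden_channels // heads) * heads, out_channels,
                             heads=1, dropout=dropout)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=-1)


# GAT attention weights can be extracted, e.g. for visualization
def get_attention_weights(model, x, edge_index):
    """Return the first layer's edge index and per-head attention weights."""
    model.eval()
    with torch.no_grad():
        _, (edge_index_out, alpha) = model.conv1(
            x, edge_index, return_attention_weights=True
        )
    # GATConv adds self-loops by default, so alpha is aligned with
    # edge_index_out, not with the input edge_index
    return edge_index_out, alpha  # alpha: (n_edges_with_self_loops, n_heads)
```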

GIN (Graph Isomorphism Network - Xu et al., 2019)

GIN is designed to be as expressive as the 1-WL test, which is the upper bound for message passing GNNs:

$$h_v^{(l)} = \text{MLP}^{(l)}\left((1 + \epsilon^{(l)})\, h_v^{(l-1)} + \sum_{u \in \mathcal{N}(v)} h_u^{(l-1)}\right)$$
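This scalar toy (a hypothetical helper, with no MLP) shows why the $(1+\epsilon)$ term matters: it lets the layer tell a node's own feature apart from its neighbors' features. With $\epsilon = 0$, a node with feature 1 and one neighbor with feature 2 produces the same MLP input as a node with feature 2 and one neighbor with feature 1.

```python
def gin_aggregate(h_v, neighbor_feats, eps):
    """Scalar version of the GIN MLP input: (1 + eps) * h_v + sum of neighbor features."""
    return (1 + eps) * h_v + sum(neighbor_feats)


# Node A: own feature 1.0, one neighbor with feature 2.0
# Node B: own feature 2.0, one neighbor with feature 1.0
for eps in (0.0, 0.5):
    print(eps, gin_aggregate(1.0, [2.0], eps), gin_aggregate(2.0, [1.0], eps))
# 0.0 3.0 3.0   <- collision: the MLP would see the same input for A and B
# 0.5 3.5 4.0
```

A nonzero $\epsilon$ breaks this tie, which is one motivation for making it a learnable parameter.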

```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool


class GIN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, n_layers: int = 3):
        super().__init__()
        self.convs = nn.ModuleList()

        for i in range(n_layers):
            in_ch = in_channels if i == 0 else hidden_channels
            mlp = nn.Sequential(
                nn.Linear(in_ch, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Linear(hidden_channels, hidden_channels)
            )
            # epsilon starts at 0 (GINConv default eps=0.); train_eps=True makes it learnable
            self.convs.append(GINConv(mlp, train_eps=True))

        self.classifier = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch=None):
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))

        # Global sum pooling for graph-level classification;
        # without `batch`, the output stays per node
        if batch is not None:
            x = global_add_pool(x, batch)

        return self.classifier(x)
```

Part 3 - GNN Expressiveness: The Weisfeiler-Leman Test

The 1-WL (Weisfeiler-Leman) graph isomorphism test

The WL test iteratively refines node colors to test graph isomorphism:

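1. Initialize: each node gets color = hash(initial_feature)
2. Repeat until the coloring stops changing:
   a. For each node v: new_color[v] = hash(color[v], SORT({color[u] : u ∈ N(v)}))
3. Compare the final color multisets of the two graphs (using the same hash for both): if they differ, the graphs are non-isomorphic; if they match, the graphs are only possibly isomorphic (the test is inconclusive)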
```python
import hashlib

import networkx as nx


def wl_graph_hash(G, node_features: dict = None, n_iterations: int = 3) -> str:
    """
    Compute a 1-WL graph hash: different hashes => non-isomorphic graphs;
    equal hashes => possibly isomorphic (WL is inconclusive).
    Truncated node hashes could in principle collide.
    """
    nodes = list(G.nodes())

    # Initialize colors
    if node_features:
        colors = {v: str(node_features[v]) for v in nodes}
    else:
        colors = {v: str(G.degree(v)) for v in nodes}  # Use degree as initial label

    for _ in range(n_iterations):
        new_colors = {}
        for v in nodes:
            # Aggregate: sort neighbor colors for permutation invariance
            neighbor_colors = sorted([colors[u] for u in G.neighbors(v)])
            combined = colors[v] + '|' + ','.join(neighbor_colors)
            new_colors[v] = hashlib.md5(combined.encode()).hexdigest()[:8]
        colors = new_colors

    # Graph-level representation: sorted multiset of node colors
    color_multiset = sorted(colors.values())
    graph_hash = hashlib.md5(str(color_multiset).encode()).hexdigest()
    return graph_hash


# Example: two graphs that WL distinguishes
# Triangle: 3 nodes, 3 edges, every node has degree 2
G_triangle = nx.cycle_graph(3)

# Path on 3 nodes: same number of nodes, but only 2 edges and degrees (1, 2, 1)
G_path = nx.path_graph(3)

print(f"Triangle WL hash: {wl_graph_hash(G_triangle)}")
print(f"Path hash:        {wl_graph_hash(G_path)}")
print(f"Same hash: {wl_graph_hash(G_triangle) == wl_graph_hash(G_path)}")
# False: WL distinguishes them (already from the initial degree labels)
```
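The `wl_graph_hash` function runs a fixed number of iterations on one graph at a time. The algorithm steps, by contrast, say to repeat until stable. The minimal stdlib sketch below refines colors until the partition stops splitting, and it runs on the disjoint union of the two graphs so that both share one color palette. The helper names are hypothetical. For production use, networkx also provides `nx.weisfeiler_lehman_graph_hash`.

```python
from collections import Counter


def wl_refine(adj, max_iter=100):
    """1-WL color refinement on {node: [neighbors]}; stops once no color class splits."""
    colors = {v: 0 for v in adj}                       # unlabeled: one initial color
    for _ in range(max_iter):
        signatures = {v: (colors[v], tuple(sorted(colors[u] for u in adj[v])))
                      for v in adj}
        # Relabel signatures to small integers, in sorted order, shared by all nodes
        palette = {sig: i for i, sig in enumerate(sorted(set(signatures.values())))}
        new_colors = {v: palette[signatures[v]] for v in adj}
        # Refinement only ever splits classes, so an unchanged class count means stable
        if len(set(new_colors.values())) == len(set(colors.values())):
            return new_colors
        colors = new_colors
    return colors


def wl_distinguishes(adj1, adj2):
    """Run 1-WL on the disjoint union so both graphs share one color palette."""
    union = {("g1", v): [("g1", u) for u in nbrs] for v, nbrs in adj1.items()}
    union.update({("g2", v): [("g2", u) for u in nbrs] for v, nbrs in adj2.items()})
    colors = wl_refine(union)
    hist1 = Counter(c for (g, _), c in colors.items() if g == "g1")
    hist2 = Counter(c for (g, _), c in colors.items() if g == "g2")
    return hist1 != hist2


triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path3 = {0: [1], 1: [0, 2], 2: [1]}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {v: [(v - 1) % 6, (v + 1) % 6] for v in range(6)}

print(wl_distinguishes(triangle, path3))         # True
print(wl_distinguishes(two_triangles, hexagon))  # False
```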

The WL-GNN equivalence theorem (Xu et al., 2019)

Theorem: Any MPNN is at most as powerful as the 1-WL test at distinguishing graphs. That is:

  • If 1-WL cannot distinguish two graphs, no MPNN can: every MPNN maps them to the same representation
  • If 1-WL distinguishes two graphs, a sufficiently expressive MPNN (such as GIN) can also distinguish them

GIN achieves exactly the power of 1-WL, the theoretical maximum for MPNNs, because it uses:

  • Sum aggregation (not mean or max)
  • An MLP update, which can approximate the injective functions the proof requires
```python
# A pair of graphs that the WL test (and therefore every MPNN) CANNOT distinguish
import networkx as nx

# Two non-isomorphic graphs with identical 1-WL signatures
# (distinguishing them requires a more powerful test, e.g. 3-WL)

# Graph 1: two disjoint triangles (2-regular)
G1 = nx.disjoint_union(nx.cycle_graph(3), nx.cycle_graph(3))

# Graph 2: one 6-cycle (2-regular)
G2 = nx.cycle_graph(6)

# Both have 6 nodes, 6 edges, and every node has degree 2
# WL initialization: every node gets the same color (degree 2)
# WL refinement: every node sees two neighbors of that same color, so no color ever splits
# WL cannot distinguish them!

hash_g1 = wl_graph_hash(G1)
hash_g2 = wl_graph_hash(G2)
print(f"Two-triangles WL hash: {hash_g1}")
print(f"6-cycle WL hash:       {hash_g2}")
print(f"Indistinguishable by WL (and MPNNs): {hash_g1 == hash_g2}")  # True

# Actual isomorphism check
print(f"Graphs are isomorphic: {nx.is_isomorphic(G1, G2)}")  # False, but WL can't tell
```
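The theorem says every MPNN must give two disjoint triangles and a 6-cycle the same representation. The sketch below is a quick sanity check, not a proof. It uses a pure-Python MPNN with random weights (a hypothetical toy network, not GIN or any PyG layer). With identical initial features, every node in both graphs performs exactly the same computation at every layer, so the graph readouts coincide.

```python
import math
import random


def sum_mpnn_readout(adj, n_layers=3, dim=4, seed=0):
    """Random-weight MPNN: sum aggregation, tanh update, sum readout.
    The same seed gives the same weights, so both graphs go through one network."""
    rng = random.Random(seed)
    layers = [
        ([[rng.gauss(0, 1) for _ in range(dim)] for _ in range(dim)],   # W_self
         [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(dim)])   # W_nbr
        for _ in range(n_layers)
    ]
    h = {v: [1.0] * dim for v in adj}      # unlabeled graphs: identical initial features
    for w_self, w_nbr in layers:
        new_h = {}
        for v in adj:
            m = [sum(h[u][j] for u in adj[v]) for j in range(dim)]      # AGGREGATE (sum)
            new_h[v] = [math.tanh(sum(w_self[i][j] * h[v][j] + w_nbr[i][j] * m[j]
                                      for j in range(dim)))
                        for i in range(dim)]                             # UPDATE
        h = new_h
    return [sum(h[v][i] for v in adj) for i in range(dim)]              # READOUT (sum)


two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
hexagon = {v: [(v - 1) % 6, (v + 1) % 6] for v in range(6)}
r1, r2 = sum_mpnn_readout(two_triangles), sum_mpnn_readout(hexagon)
print(max(abs(a - b) for a, b in zip(r1, r2)))   # 0.0
```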

Beyond WL: higher-order GNNs

To distinguish graphs that 1-WL cannot, use:

  • k-WL: Test equivalence by coloring $k$-tuples of nodes instead of individual nodes
  • k-GNN: Operate on subgraphs of size $k$ instead of individual nodes - more expensive but more powerful
  • NGNN (Nested GNN): Run a GNN inside each node's neighborhood subgraph
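One way to go beyond 1-WL is to color tuples of nodes. The sketch below implements the folklore 2-WL test (2-FWL), a pair-coloring variant of $k$-WL whose distinguishing power matches 3-WL. It is a pure-Python illustration under simple assumptions: nodes are numbered $0..n-1$, edges form an undirected list, and colors are kept as nested tuples so they stay comparable across graphs. It is not an optimized implementation. Unlike 1-WL, it separates two disjoint triangles from a 6-cycle, because an edge inside a triangle has a common neighbor and an edge of the 6-cycle does not.

```python
from collections import Counter
from itertools import product


def fwl2_signature(n, edges, iterations=2):
    """Folklore 2-WL on ordered node pairs; returns the multiset of final pair colors."""
    adj = set(edges) | {(v, u) for u, v in edges}
    pairs = list(product(range(n), repeat=2))
    # Initial color: isomorphism type of the ordered pair (same node? adjacent?)
    color = {(u, v): (u == v, (u, v) in adj) for u, v in pairs}
    for _ in range(iterations):
        # New color of (u, v): old color plus the multiset over w of
        # (color(u, w), color(w, v)); the right side uses the previous colors
        color = {
            (u, v): (color[u, v],
                     tuple(sorted(Counter((color[u, w], color[w, v]) for w in range(n)).items())))
            for u, v in pairs
        }
    return tuple(sorted(Counter(color.values()).items()))


two_triangles = [(0, 1), (1, 2), (2, 0), (3, 4), (4, 5), (5, 3)]
hexagon = [(i, (i + 1) % 6) for i in range(6)]
hexagon_relabelled = [(0, 2), (2, 4), (4, 1), (1, 3), (3, 5), (5, 0)]   # also a 6-cycle

print(fwl2_signature(6, two_triangles) == fwl2_signature(6, hexagon))        # False
print(fwl2_signature(6, hexagon) == fwl2_signature(6, hexagon_relabelled))   # True
```

The cost grows with the number of tuples ($n^2$ pairs here, each updated over $n$ choices of $w$), which is why higher-order tests are rarely run on large graphs.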

Part 4 - Over-Smoothing: Analysis and Solutions

Why over-smoothing occurs

From the spectral perspective (Lesson 04): GCN propagation applies a low-pass filter - high-frequency (discriminative) components decay exponentially with depth.

From the message passing perspective: after $L$ layers, each node aggregates information from all nodes within $L$ hops. For small-world graphs, $L = 5$ may already cover nearly the entire graph. Once every node sees the same global information, features converge.
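The message-passing view can be checked directly: after $L$ layers, a node's output can only depend on nodes within $L$ hops. This pure-Python sketch uses hypothetical helpers `hop_coverage` and `ring` to measure that fraction with breadth-first search. It runs on a 100-node ring, both with and without two shortcut edges, where the shortcuts serve as a toy stand-in for a small-world graph.

```python
from collections import deque


def hop_coverage(adj, source, max_hops):
    """Fraction of nodes within k hops of `source`, for k = 0..max_hops.
    After k message-passing layers, `source` can only depend on these nodes."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        v = queue.popleft()
        for u in adj[v]:
            if u not in dist:
                dist[u] = dist[v] + 1
                queue.append(u)
    n = len(adj)
    return [sum(1 for d in dist.values() if d <= k) / n for k in range(max_hops + 1)]


def ring(n, shortcuts=()):
    """Ring lattice on n nodes, plus optional extra (shortcut) edges."""
    adj = {v: {(v - 1) % n, (v + 1) % n} for v in range(n)}
    for u, v in shortcuts:
        adj[u].add(v)
        adj[v].add(u)
    return adj


plain = ring(100)
small_world = ring(100, shortcuts=[(0, 50), (25, 75)])   # hypothetical shortcut edges
print(hop_coverage(plain, 0, 30)[-1])        # 0.61
print(hop_coverage(small_world, 0, 30)[-1])  # 1.0
```

Two shortcuts are enough for 30 layers to reach every node, while the plain ring needs 50.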

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.spatial.distance import pdist


def measure_oversmoothing(model, data, n_layers_range: range) -> dict:
    """
    Measure feature diversity as a function of GNN depth.
    Diversity = mean pairwise cosine distance between node embeddings
    (higher diversity = less over-smoothing).
    Assumes `model.convs` is a list of conv layers, as in the GIN class above.
    """
    results = {}
    model.eval()

    with torch.no_grad():
        x = data.x
        for layer_idx, conv in enumerate(model.convs, start=1):
            # Features after this layer
            x = F.relu(conv(x, data.edge_index))
            if layer_idx in n_layers_range:
                # pdist(metric='cosine') returns cosine DISTANCES (1 - cosine similarity);
                # all-zero rows (dead ReLUs) may give NaN, so nanmean ignores them
                distances = pdist(x.cpu().numpy(), metric='cosine')
                results[layer_idx] = float(np.nanmean(distances))

    return results
```

Solutions to over-smoothing

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, JumpingKnowledge


# Solution 1: Jumping Knowledge Networks (JK-Net)
# Aggregate features from ALL layers, not just the last
class JKNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 n_layers: int = 6, mode: str = 'cat'):
        super().__init__()
        self.convs = nn.ModuleList([
            GCNConv(in_channels if i == 0 else hidden_channels, hidden_channels)
            for i in range(n_layers)
        ])
        # Aggregate representations from all layers
        # (channels and num_layers are needed by the 'lstm' mode)
        self.jk = JumpingKnowledge(mode=mode, channels=hidden_channels,
                                   num_layers=n_layers)
        # JK 'cat' mode concatenates all layer outputs; 'max'/'lstm' keep hidden_channels
        in_jk = hidden_channels * n_layers if mode == 'cat' else hidden_channels
        self.classifier = nn.Linear(in_jk, out_channels)

    def forward(self, x, edge_index):
        layer_outputs = []
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            layer_outputs.append(x)

        # JK aggregation: cat, max, or lstm
        x = self.jk(layer_outputs)
        return self.classifier(x)


# Solution 2: APPNP (Approximate Personalized Propagation of Neural Predictions)
from torch_geometric.nn import APPNP as PyGAPPNP


class APPNP_Model(nn.Module):
    """
    Separate feature transformation from propagation.
    Propagation uses Personalized PageRank: every step mixes back a fraction
    alpha of the un-propagated predictions, which counteracts over-smoothing.
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 K: int = 10, alpha: float = 0.1):
        super().__init__()
        # MLP for feature transformation (no propagation)
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_channels, out_channels)
        )
        # PPR propagation (K iterations with teleport probability alpha)
        self.prop = PyGAPPNP(K=K, alpha=alpha)

    def forward(self, x, edge_index):
        # 1. Transform features (no graph structure)
        h = self.mlp(x)

        # 2. Propagate with PPR (not plain repeated multiplication):
        #    Z <- (1 - alpha) * A_hat Z + alpha * h, so alpha is the weight on local predictions
        return F.log_softmax(self.prop(h, edge_index), dim=-1)


# Solution 3: DropEdge - randomly remove edges during training
class DropEdgeGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, p: float = 0.5):
        super().__init__()
        self.convs = nn.ModuleList([
            GCNConv(in_channels, hidden_channels),
            GCNConv(hidden_channels, hidden_channels),
            GCNConv(hidden_channels, out_channels)
        ])
        self.drop_edge_p = p

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            if self.training:
                # Randomly drop edges (slows the mixing that causes over-smoothing).
                # The mask lives on the same device as edge_index; note that the two
                # directions of an undirected edge are dropped independently here.
                mask = torch.rand(edge_index.size(1), device=edge_index.device) > self.drop_edge_p
                edge_index_drop = edge_index[:, mask]
            else:
                edge_index_drop = edge_index

            x = conv(x, edge_index_drop)
            if i < len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)

        return F.log_softmax(x, dim=-1)
```

Part 5 - GNNs for Knowledge Graphs (Heterogeneous)

Relational GCN (RGCN)

For knowledge graphs with multiple relation types:

$$h_v^{(l+1)} = \sigma\left(W_0^{(l)} h_v^{(l)} + \sum_{r \in \mathcal{R}} \sum_{u \in \mathcal{N}_r(v)} \frac{1}{|\mathcal{N}_r(v)|} W_r^{(l)} h_u^{(l)}\right)$$

```python
import torch
import torch.nn as nn
from torch_geometric.nn import RGCNConv


class RGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 n_relations: int, num_bases: int = 30):
        super().__init__()
        # num_bases: basis decomposition reduces the relation weights from
        # n_relations × in × out to num_bases × in × out + n_relations × num_bases parameters
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=n_relations, num_bases=num_bases)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=n_relations, num_bases=num_bases)

    def forward(self, x, edge_index, edge_type):
        """
        edge_type: (n_edges,) integer tensor - relation type per edge
        """
        x = torch.relu(self.conv1(x, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        return x
```
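The parameter saving from basis decomposition is easy to check by hand. Below is a minimal pure-Python sketch that counts only the relation-specific weights $W_r$; the self-connection $W_0$ and any bias are excluded. The sizes used (100 relations, 16-dimensional features, 30 bases) are hypothetical, and the helper names are for illustration only.

```python
def rgcn_relation_params(n_relations, in_dim, out_dim, num_bases=None):
    """Number of parameters in the relation weights W_r of one RGCN layer."""
    if num_bases is None:
        return n_relations * in_dim * out_dim            # one full matrix per relation
    # Basis decomposition: W_r = sum_b a_rb * V_b
    # -> num_bases shared matrices V_b plus n_relations * num_bases coefficients a_rb
    return num_bases * in_dim * out_dim + n_relations * num_bases


def basis_weight(coeffs_r, bases):
    """Build W_r = sum_b coeffs_r[b] * bases[b] from shared basis matrices."""
    rows, cols = len(bases[0]), len(bases[0][0])
    return [[sum(c * V[i][j] for c, V in zip(coeffs_r, bases)) for j in range(cols)]
            for i in range(rows)]


print(rgcn_relation_params(100, 16, 16))       # 25600
print(rgcn_relation_params(100, 16, 16, 30))   # 10680
print(basis_weight([2.0, -1.0], [[[1, 0], [0, 1]], [[0, 1], [1, 0]]]))  # [[2.0, -1.0], [-1.0, 2.0]]
```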

Part 6 - Complete GNN Training Pipeline

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures


def train_gcn_node_classification(
    dataset_name: str = 'Cora',
    hidden_channels: int = 64,
    n_epochs: int = 200,
    lr: float = 0.01,
    weight_decay: float = 5e-4
):
    """
    Full GCN training pipeline for node classification.
    Standard benchmark on Cora/CiteSeer/PubMed citation networks.
    """
    # Load dataset
    dataset = Planetoid(root='/tmp/Planetoid', name=dataset_name,
                        transform=NormalizeFeatures())
    data = dataset[0]

    print(f"Dataset: {dataset_name}")
    print(f"  Nodes: {data.num_nodes}")
    print(f"  Edges: {data.num_edges}")  # directed: each undirected edge counted twice
    print(f"  Node features: {data.num_node_features}")
    print(f"  Classes: {dataset.num_classes}")
    print(f"  Train nodes: {int(data.train_mask.sum())}")
    print(f"  Val nodes: {int(data.val_mask.sum())}")
    print(f"  Test nodes: {int(data.test_mask.sum())}")

    # Build model
    class GCNNodeClassifier(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

        def forward(self, x, edge_index):
            x = F.dropout(x, p=0.5, training=self.training)
            x = F.relu(self.conv1(x, edge_index))
            x = F.dropout(x, p=0.5, training=self.training)
            x = self.conv2(x, edge_index)
            return x  # raw logits; F.cross_entropy applies log-softmax

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCNNodeClassifier().to(device)
    data = data.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr,
                                 weight_decay=weight_decay)

    def train():
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate(mask):
        model.eval()
        with torch.no_grad():
            out = model(data.x, data.edge_index)
            pred = out.argmax(dim=1)
            correct = (pred[mask] == data.y[mask]).sum()
            acc = correct / mask.sum()
        return acc.item()

    # Training loop: report test accuracy at the epoch with the best validation accuracy
    best_val_acc = 0
    best_test_acc = 0

    for epoch in range(1, n_epochs + 1):
        loss = train()
        val_acc = evaluate(data.val_mask)
        test_acc = evaluate(data.test_mask)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        if epoch % 50 == 0:
            print(f"Epoch {epoch:3d}: loss={loss:.4f}, "
                  f"val_acc={val_acc:.4f}, test_acc={test_acc:.4f}")

    print(f"\nBest test accuracy: {best_test_acc:.4f}")
    return model, best_test_acc
```

Interview Questions

Q1: Explain the message passing framework and how GCN, GAT, and GIN differ within it.

The MPNN (Message Passing Neural Network) framework has three components:

  1. AGGREGATE: Combine messages from neighbors
  2. UPDATE: Update node state using own state + aggregated messages
  3. READOUT: Produce graph-level representation (for graph tasks)

GCN (Kipf & Welling):

  • AGGREGATE: Weighted sum $\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{d_u d_v}} h_u$
  • Fixed weights based on degree - no learning in the aggregation
  • Spectral motivation: 1st-order Chebyshev approximation

GAT (Veličković et al.):

  • AGGREGATE: Attention-weighted sum $\sum_{u} \alpha_{vu} h_u$
  • Attention weights $\alpha_{vu}$ are computed from the features of both endpoints by a learned scoring function
  • Advantage: different neighbors contribute different amounts based on content

GIN (Xu et al.):

  • AGGREGATE: Sum (not mean! Mean loses structural information)
  • UPDATE: Injective MLP on (node_features + summed_neighbors)
  • Motivation: Sum aggregation is necessary for WL-equivalent expressiveness

Key insight: GCN's degree normalization behaves like a weighted mean, so it loses information about neighborhood size: two nodes with the same average neighbor feature but different numbers of neighbors can look identical. GIN's sum aggregation preserves this count, achieving maximum expressiveness under the WL bound.

Q2: What is the Weisfeiler-Leman test and what does it mean for GNN expressiveness?

The 1-WL (Weisfeiler-Leman) test iteratively refines node colorings to test graph isomorphism:

  1. Initialize colors from initial node labels (or all the same if unlabeled)
  2. Each iteration: $c(v) \leftarrow \text{hash}\left(c(v), \text{sorted\_multiset}(\{c(u) : u \in \mathcal{N}(v)\})\right)$
  3. If color distributions differ between graphs → non-isomorphic
  4. If same colors after stabilization → potentially isomorphic (not guaranteed)

The WL-GNN theorem (Xu et al., 2019):

Any MPNN $f$ satisfies:

$$f(G_1) \neq f(G_2) \Rightarrow \text{WL distinguishes } G_1 \text{ and } G_2$$

and, equivalently (the contrapositive):

$$\text{WL cannot distinguish } G_1 \text{ and } G_2 \Rightarrow f(G_1) = f(G_2)$$

Interpretation: MPNNs are at most as powerful as 1-WL. They cannot distinguish any pair of graphs that 1-WL cannot distinguish.

GIN achieves the maximum: GIN with sum aggregation and injective MLP is as powerful as 1-WL. GCN (with mean aggregation) and GraphSAGE (with mean/max) are strictly less powerful.

Practical meaning: Consider two $d$-regular graphs with the same number of nodes and no distinguishing node features. WL assigns the same color to every node in both graphs, so it cannot distinguish them. Therefore no MPNN can either: any MPNN produces the same graph-level representation for both, even when they are non-isomorphic (for example, two disjoint triangles versus a 6-cycle).

Q3: What is over-smoothing in GNNs and what are three practical solutions?

Over-smoothing: As GNN depth increases, node representations converge toward the same vector (up to degree scaling), making nodes indistinguishable. Node classification accuracy often peaks at 2-4 layers and degrades with more layers.

Cause: Repeated application of the aggregation step acts as a low-pass graph filter. After $L$ steps, information from all nodes within $L$ hops is mixed - for small-world graphs, this covers nearly the entire graph after just a few layers.

Spectral view: The normalized adjacency with self-loops, $\hat{A} = \tilde{D}^{-1/2}(A + I)\tilde{D}^{-1/2}$, has eigenvalues in $(-1, 1]$. For a connected graph, the eigenvalue 1 is simple, with eigenvector proportional to $\tilde{D}^{1/2}\mathbf{1}$. Multiplying by $\hat{A}^L$ keeps that component while every other component shrinks as $|\lambda_i|^L$, so node features converge to degree-scaled copies of a single vector (the global mean, up to degree weighting).
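Here is a small numerical check of the spectral argument in pure Python. It builds $\hat{A}$ for a 5-node path and applies it repeatedly to a scalar feature vector, with no weights or nonlinearities (a linear GCN). It then tracks how far $x_i / \sqrt{\tilde{d}_i}$ is from constant; a spread of 0 means the features have collapsed onto the top eigenvector. The helper names and the feature vector are illustrative assumptions.

```python
import math


def normalized_adjacency(n, edges):
    """A_hat = D~^{-1/2} (A + I) D~^{-1/2} as a dense list of lists, plus degrees d~."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]   # self-loops
    for u, v in edges:
        a[u][v] = a[v][u] = 1.0
    deg = [sum(row) for row in a]                                         # d~ includes the self-loop
    a_hat = [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)]
    return a_hat, deg


def propagate(a_hat, x, steps):
    """Apply x <- A_hat x `steps` times (linear GCN: no weights, no nonlinearity)."""
    for _ in range(steps):
        x = [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a_hat]
    return x


def spread(x, deg):
    """Range of x_i / sqrt(d~_i); 0 means x is a multiple of the top eigenvector."""
    ratios = [x_i / math.sqrt(d_i) for x_i, d_i in zip(x, deg)]
    return max(ratios) - min(ratios)


a_hat, deg = normalized_adjacency(5, [(0, 1), (1, 2), (2, 3), (3, 4)])   # path graph
x0 = [3.0, 1.0, 4.0, 1.0, 5.0]
for steps in (1, 2, 4, 8, 16, 32, 64):
    print(steps, spread(propagate(a_hat, x0, steps), deg))
```

The spread shrinks roughly geometrically, at the rate of the second-largest eigenvalue magnitude, which is $5/6$ for this path graph.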

Three solutions:

  1. JK-Net (Jumping Knowledge Networks): Concatenate (or take max/LSTM over) representations from every layer, not just the last. Each layer captures a different radius neighborhood - the network learns which radius is useful per node.

  2. APPNP: Separate feature transformation (MLP) from graph propagation (personalized PageRank). Each propagation step computes $Z^{(k+1)} = (1-\alpha)\hat{A} Z^{(k)} + \alpha H$. Every node therefore keeps mixing in a fraction $\alpha$ (the teleport probability) of its own un-propagated prediction $H$, which stops the top eigenvector from dominating even as the number of steps grows.

  3. Residual connections (ResGCN): $h_v^{(l+1)} = \text{GCN}^{(l)}(h_v^{(l)}) + h_v^{(l)}$. Skip connections carry each layer's input forward unchanged, so earlier-layer information is not forced through repeated smoothing. This slows the convergence toward the global mean in deep models.
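To contrast APPNP with plain propagation, here is a pure-Python sketch of the PPR iteration $Z^{(k+1)} = (1-\alpha)\hat{A} Z^{(k)} + \alpha H$ on a 5-node path graph. Setting $\alpha = 0$ recovers plain repeated multiplication by $\hat{A}$. The graph, the features, and the helper names are assumptions made for illustration.

```python
import math


def normalized_adjacency(n, edges):
    """A_hat = D~^{-1/2} (A + I) D~^{-1/2}, plus the self-loop degrees d~."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for u, v in edges:
        a[u][v] = a[v][u] = 1.0
    deg = [sum(row) for row in a]
    return [[a[i][j] / math.sqrt(deg[i] * deg[j]) for j in range(n)] for i in range(n)], deg


def appnp_propagate(a_hat, h, steps, alpha):
    """Z <- (1 - alpha) * A_hat Z + alpha * H, starting from Z = H."""
    z = list(h)
    for _ in range(steps):
        az = [sum(a_ij * z_j for a_ij, z_j in zip(row, z)) for row in a_hat]
        z = [(1 - alpha) * az_i + alpha * h_i for az_i, h_i in zip(az, h)]
    return z


def spread(x, deg):
    """Range of x_i / sqrt(d~_i): 0 means collapse onto the top eigenvector."""
    ratios = [x_i / math.sqrt(d_i) for x_i, d_i in zip(x, deg)]
    return max(ratios) - min(ratios)


a_hat, deg = normalized_adjacency(5, [(0, 1), (1, 2), (2, 3), (3, 4)])
h = [3.0, 1.0, 4.0, 1.0, 5.0]                      # per-node predictions from the MLP
for alpha in (0.0, 0.1):
    print(alpha, spread(appnp_propagate(a_hat, h, steps=100, alpha=alpha), deg))
```

After 100 steps, the $\alpha = 0$ run has collapsed (spread near 0), while the $\alpha = 0.1$ run converges to a fixed point that still separates the nodes.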

When to use:

  • JK-Net: When optimal radius is unknown per node (heterogeneous graph structure)
  • APPNP: When you want deep propagation (many steps) without adding parameters; $\alpha$ gives an explicit knob for how local the result stays
  • Residual: Simple baseline improvement for deeper models

Q4: Why does mean aggregation lose information that sum aggregation preserves?

This is the key insight of GIN (Xu et al., 2019).

Mean aggregation computes: $m_v = \frac{1}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} h_u$

This averages over the multiset of neighbor features. Two nodes with the same average neighbor feature but different neighborhood sizes look identical.

Example:

  • Node A with 3 neighbors each having feature 1: mean = 1
  • Node B with 6 neighbors each having feature 1: mean = 1

Mean aggregation: A and B look the same. But they are structurally different - different degrees, different neighborhood sizes.

Sum aggregation computes: $m_v = \sum_{u \in \mathcal{N}(v)} h_u$

  • Node A: sum = 3
  • Node B: sum = 6

Sum aggregation distinguishes A and B.

Formal statement: Over a countable feature space, sum aggregation composed with a suitable feature map (the MLP in GIN) can represent injective functions over bounded-size multisets (Xu et al., 2019). Mean and max are not injective over multisets:

  • Mean: $\{1, 1\}$ and $\{1\}$ both have mean 1 (but different structures)
  • Max: $\{1, 1\}$ and $\{1\}$ both have max 1 (same problem)
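The multiset examples in this answer can be checked in a few lines of pure Python; the `aggregate` helper is hypothetical. The last pair is a reminder that a sum over raw scalars can also collide, as with $\{1, 3\}$ and $\{2, 2\}$. The injectivity result applies the sum after a learned feature map, which is the job of the MLP in GIN.

```python
def aggregate(values, how):
    """Aggregate a multiset (list) of scalar neighbor features."""
    if how == "sum":
        return sum(values)
    if how == "mean":
        return sum(values) / len(values)
    if how == "max":
        return max(values)
    raise ValueError(f"unknown aggregation: {how}")


pairs = {
    "3 vs 6 neighbors, all feature 1": ([1.0] * 3, [1.0] * 6),
    "{1, 1} vs {1}": ([1.0, 1.0], [1.0]),
    "{1, 3} vs {2, 2}": ([1.0, 3.0], [2.0, 2.0]),
}
for name, (a, b) in pairs.items():
    collisions = [how for how in ("sum", "mean", "max") if aggregate(a, how) == aggregate(b, how)]
    print(f"{name}: collide under {collisions}")
```

The first two pairs collide under mean and max but not under sum. The third pair collides under sum and mean but not under max.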

For graph-level tasks: Global mean pooling (average of all node features) also loses information - prefer sum pooling for expressive graph-level representations.

Q5: How does the attention mechanism in GAT differ from GCN, and when is attention beneficial?

GCN normalization: Edge weight = $1/\sqrt{d_u d_v}$ - purely structural, independent of node features. Every neighbor of the same degree contributes equally.

GAT attention: Edge weight = $\alpha_{vu} = \text{softmax}_{u \in \mathcal{N}(v)}\left(f(h_v, h_u)\right)$ - learned from the features of both endpoints. The network learns to weight neighbors differently based on their features.

Formally: $$\alpha_{vu} = \frac{\exp\left(\text{LeakyReLU}\left(a^T [W h_v \,\|\, W h_u]\right)\right)}{\sum_{k \in \mathcal{N}(v)} \exp\left(\text{LeakyReLU}\left(a^T [W h_v \,\|\, W h_k]\right)\right)}$$

When GAT helps:

  1. Heterogeneous neighborhoods: When neighbors vary in relevance. Example: in a citation network, paper A may have 50 neighbors but only 5 are truly related - GAT can learn to focus on those 5.

  2. Knowledge graphs with multiple relation types: Without relation-type information, GAT can still learn to prioritize informative neighbors.

  3. Noisy graphs: Edges that are "wrong" (noise, sparsification errors) can receive low attention weights, partially filtering them out.

When GAT doesn't help (or hurts):

  1. Regular or homogeneous graphs: If all neighbors are equally informative (like symmetric molecular graphs), attention adds overhead with no benefit.

  2. Sparse labeled training sets: The attention parameters are learned from the same few labels and can overfit. With very few labeled nodes per class, GCN is often as strong as GAT or stronger.

  3. Computational cost: Veličković et al. give the cost of a single GAT head as $O(|V|FF' + |E|F')$ for $F$ input and $F'$ output features, on par with GCN. The overhead comes from running $K$ heads and storing per-edge attention coefficients, which becomes significant on dense graphs.

Multi-head attention (standard in GAT): Use $K$ independent attention heads and concatenate their outputs (or, in the final layer, average them). Veličković et al. found that multiple heads stabilize learning, and different heads can weight neighbors differently.

Quick Reference

| Architecture | Aggregation | Expressiveness vs WL | Best For |
|---|---|---|---|
| GCN | Degree-normalized sum (mean-like) | Strictly less | Homophilic graphs, simple baseline |
| GraphSAGE | Mean/Max/LSTM | Strictly less | Inductive learning, large graphs |
| GAT | Learned attention | Strictly less | Heterogeneous neighborhoods |
| GIN | Sum + MLP | Equivalent to 1-WL | Graph classification, max expressiveness |

| Over-smoothing Solution | Key Idea | Overhead |
|---|---|---|
| JK-Net | Combine outputs of all layers | Keeps all $L$ layer outputs; $L\times$ wider classifier input in `cat` mode |
| APPNP | PPR propagation + α teleport | Minimal |
| Residual (ResGCN) | Skip connections | Minimal |
| DropEdge | Random edge dropout | None at test time |
| PairNorm | Normalize pairwise distances | Minimal |

Module 09 complete - next: Module 10: Time Series Mathematics →

:::tip 🎮 Interactive Playground

Visualize this concept: Try the GNN Message Passing demo on the EngineersOfAI Playground - no code required.

:::

© 2026 EngineersOfAI. All rights reserved.