
Graph Representation for ML

The Real Interview Moment

"Walk me through how you would represent a molecular graph in memory for a GNN," the interviewer asks. "Consider 50,000 molecules in a training set, each with up to 100 atoms."

The wrong answer: "I would use an adjacency matrix." A 100×100 dense matrix per molecule across 50,000 molecules is 500 million floats before you even add features. For a sparse molecular graph where each atom has 2–4 bonds, only 2–4 of the 100 entries in each row are non-zero, so more than 95% of that memory stores zeros. At float32, the dense matrices alone take 2 GB.
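A quick back-of-the-envelope check of these numbers. The sketch assumes the worst case from the question (every molecule has the full 100 atoms) and an average of 3 bonds per atom, both chosen for illustration, and compares padded dense float32 matrices against the COO edge list (`edge_index`, int64) layout that PyG uses:

```python
def dense_vs_coo_bytes(n_graphs, max_atoms, avg_atoms, avg_degree,
                       float_bytes=4, index_bytes=8):
    """Return (dense_bytes, coo_bytes) for a dataset of molecular graphs.

    dense: one padded max_atoms x max_atoms float32 adjacency per molecule.
    coo:   one int64 edge_index of shape [2, num_directed_edges] per
           molecule, with every undirected bond stored in both directions.
    """
    dense = n_graphs * max_atoms * max_atoms * float_bytes
    # Sum of degrees = number of directed edges (each bond appears twice).
    directed_edges = avg_atoms * avg_degree
    coo = n_graphs * 2 * directed_edges * index_bytes
    return dense, coo

# Worst case from the question: 100 atoms per molecule, ~3 bonds per atom.
dense, coo = dense_vs_coo_bytes(n_graphs=50_000, max_atoms=100,
                                avg_atoms=100, avg_degree=3)
nonzero_fraction = 100 * 3 / 100**2
print(f"dense adjacency: {dense / 1e9:.2f} GB")   # 2.00 GB
print(f"COO edge_index:  {coo / 1e9:.2f} GB")     # 0.24 GB
print(f"non-zero entries per dense matrix: {nonzero_fraction:.0%}")  # 3%
```

Even with 8-byte indices the edge list is roughly 8x smaller here, and the gap widens as `max_atoms` grows, because dense storage is quadratic in the number of atoms while COO is linear in the number of bonds.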

But before we even get to GNNs and efficient sparse storage, there is a deeper question: how do you represent graph nodes as feature vectors at all? The answer has evolved dramatically over the past decade - from hand-crafted structural features to random walk embeddings to end-to-end learned representations.

This lesson covers the full evolution. Understanding where DeepWalk, Node2Vec, and spectral embeddings fail explains exactly why GNNs were invented. Every architectural decision in modern GNNs is a direct response to a specific failure mode of these earlier methods.

Why This Exists - The Feature Engineering Problem on Graphs

Machine learning requires fixed-dimensional vector representations of inputs. For tabular data, this is trivial - each row is already a fixed-dimensional vector. For images, a pixel grid gives a natural vectorization. For text, tokenization and bag-of-words or embedding lookups work.

For graphs, the input is a set of entities connected by relationships. A node does not have a natural coordinate system. Two nodes in different parts of the same graph, or in completely different graphs, might have the same structural role but no obvious way to represent this equivalence as a vector.

The history of graph representation learning is the history of increasingly powerful answers to this question:

  1. Manual features (pre-2010): hand-engineer structural statistics (degree, clustering coefficient, motifs)
  2. Matrix factorization (2010–2014): factorize the adjacency matrix or its derivatives
  3. Random walk embeddings (2014–2017): DeepWalk, Node2Vec, LINE - skip-gram on graph walks
  4. Spectral embeddings (parallel track): eigenvectors of the graph Laplacian
  5. GNNs (2017–present): learn representations end-to-end with task supervision

Each generation fixed specific failures of the previous one, until GNNs unified the field.

Manual Node Features

Before any learning, compute structural statistics of each node and use them as features for downstream ML models.

Degree

The most basic feature: how many neighbors does a node have?

$$d_v = |\mathcal{N}(v)| = \sum_{u} A_{vu}$$

Degree encodes local importance. High-degree nodes are hubs. In fraud detection, unusually high degree (one account connected to thousands of others) is itself a signal. In citation networks, degree correlates with paper influence.

For directed graphs, distinguish in-degree (how many edges point to $v$) and out-degree (how many edges $v$ points to).

Clustering Coefficient

What fraction of a node's neighbors are also connected to each other?

$$C_v = \frac{|\{(u,w) \in E : u,w \in \mathcal{N}(v)\}|}{d_v(d_v - 1)/2}$$

$C_v = 1$: all neighbors form a clique. $C_v = 0$: no two neighbors are connected.

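In social networks, a high clustering coefficient indicates tight friend groups. In molecular graphs, clustering coefficients are usually close to zero: most chemical rings have five or six atoms and contain no triangles, so ring structure is better captured by cycle counts.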
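Both formulas can be computed directly from an adjacency matrix. The numerator of $C_v$ is the number of triangles through $v$, and $(A^3)_{vv}$ counts closed walks of length 3 from $v$, which traverse each such triangle twice. A small dense sketch, assuming a simple undirected graph without self-loops and defining $C_v = 0$ when $d_v < 2$ (a common convention); dense $A^3$ is $O(n^3)$, so this is for intuition only:

```python
import numpy as np

def degree_and_clustering(A):
    """Degree, per-node triangle count and local clustering coefficient
    for a simple undirected graph given as a symmetric 0/1 adjacency matrix."""
    A = np.asarray(A, dtype=float)
    deg = A.sum(axis=1)                      # d_v = sum_u A_vu
    # (A^3)_vv counts closed 3-walks from v; each triangle through v is
    # walked in two directions, so halve it.
    triangles = np.diag(A @ A @ A) / 2
    possible = deg * (deg - 1) / 2           # neighbor pairs: d_v(d_v - 1)/2
    clustering = np.divide(triangles, possible,
                           out=np.zeros_like(triangles), where=possible > 0)
    return deg, triangles, clustering

# Triangle 0-1-2 plus a pendant node 3 attached to node 0
A = [[0, 1, 1, 1],
     [1, 0, 1, 0],
     [1, 1, 0, 0],
     [1, 0, 0, 0]]
deg, tri, C = degree_and_clustering(A)
print(deg, tri, C)  # C_0 = 1/3: only one of node 0's three neighbor pairs is linked
```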

PageRank

PageRank (Page et al., 1999) assigns importance to nodes based on the importance of nodes that point to them - a recursive definition:

$$\text{PR}(v) = \frac{1-d}{n} + d \sum_{u \in \text{in-neighbors}(v)} \frac{\text{PR}(u)}{|\text{out-neighbors}(u)|}$$

where $d \approx 0.85$ is the damping factor (unrelated to the degree $d_v$ above) and $n$ is the number of nodes. The scores are computed iteratively until they converge.
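The recursive definition can be solved by power iteration. A minimal dense sketch, assuming `A[u, v] = 1` encodes an edge $u \to v$ and that "dangling" nodes with no out-edges spread their score uniformly over all nodes (one common convention that keeps the scores summing to 1); in practice a library routine such as `nx.pagerank` is the usual choice:

```python
import numpy as np

def pagerank(A, d=0.85, tol=1e-10, max_iter=1000):
    """Power iteration for PR(v) = (1-d)/n + d * sum_{u->v} PR(u)/outdeg(u)."""
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    out_deg = A.sum(axis=1)
    pr = np.full(n, 1.0 / n)
    for _ in range(max_iter):
        # Share of each node's score sent along each of its out-edges.
        share = np.divide(pr, out_deg, out=np.zeros(n), where=out_deg > 0)
        # Score held by dangling nodes, spread evenly over all nodes.
        dangling = pr[out_deg == 0].sum() / n
        new = (1 - d) / n + d * (A.T @ share + dangling)
        if np.abs(new - pr).sum() < tol:
            return new
        pr = new
    return pr

# Nodes 1, 2, 3 all link to node 0; node 0 links back to node 1.
A = np.array([[0, 1, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0]])
print(pagerank(A).round(3))
```

Node 0 ends up with the highest score, and node 1 comes second even though it has only one in-link, because that single link comes from the most important node.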

PageRank as a node feature captures global graph-based importance beyond local degree. Papers with high PageRank are not just well-cited - they are cited by well-cited papers.

Motif Counts

Count how many times a node participates in specific subgraph patterns:

  • Triangle count: how many triangles include this node?
  • 4-cycle count: how many 4-cycles include this node?
  • Graphlet counts: participation in each graphlet (small induced subgraph patterns)

These capture the local structural "role" of a node. Nodes at the center of star-shaped patterns play a different structural role than nodes in dense cliques.

Complete Manual Feature Pipeline

```python
import networkx as nx
import numpy as np

def compute_node_features(G):
    """
    Compute manual structural features for all nodes.
    Returns (feature matrix of shape [n_nodes, n_features], node list).
    """
    nodes = list(G.nodes())
    n = len(nodes)
    node_to_idx = {v: i for i, v in enumerate(nodes)}
    features = np.zeros((n, 6))

    # 1. Degree (normalized by max degree; guard against edgeless graphs)
    max_deg = max(dict(G.degree()).values()) or 1
    for v in nodes:
        features[node_to_idx[v], 0] = G.degree(v) / max_deg

    # 2. Clustering coefficient
    clustering = nx.clustering(G)
    for v in nodes:
        features[node_to_idx[v], 1] = clustering[v]

    # 3. PageRank, min-max scaled to [0, 1]
    #    (raw PageRank sums to 1, so values are tiny on large graphs)
    pr = nx.pagerank(G, alpha=0.85, max_iter=200)
    pr_vals = np.array([pr[v] for v in nodes])
    features[:, 2] = (pr_vals - pr_vals.min()) / (np.ptp(pr_vals) + 1e-8)

    # 4. Betweenness centrality (expensive: O(n*m) for unweighted graphs)
    bc = nx.betweenness_centrality(G, normalized=True)
    for v in nodes:
        features[node_to_idx[v], 3] = bc[v]

    # 5. Closeness centrality
    cc = nx.closeness_centrality(G)
    for v in nodes:
        features[node_to_idx[v], 4] = cc[v]

    # 6. Triangle count (normalized by the max count)
    triangles = nx.triangles(G)
    max_tri = max(triangles.values()) or 1
    for v in nodes:
        features[node_to_idx[v], 5] = triangles[v] / max_tri

    return features, nodes

# Example on Cora (converted to networkx)
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

dataset = Planetoid(root='/tmp/cora', name='Cora')
data = dataset[0]
G = to_networkx(data, to_undirected=True)

X_manual, nodes = compute_node_features(G)
print(f"Feature matrix shape: {X_manual.shape}")  # (2708, 6)
print("Feature columns: degree, clustering, pagerank, betweenness, closeness, triangles")
```

Limitations of Manual Features

Manual features have fundamental problems that motivated the next generation of methods:

  1. Feature selection is brittle: You must decide which structural features to compute before seeing task performance. Degree and clustering coefficient capture different things - which ones matter for fraud detection vs citation classification?

  2. Combinatorial explosion: The number of possible structural features is enormous (there are already 29 distinct connected graphlets on 3 to 5 nodes). Computing all of them is expensive.

  3. No joint optimization: Manual features are computed independently from the task. A GNN jointly optimizes representations with the loss function, learning which structural patterns actually matter for the specific prediction task.

  4. No node content: Structural features ignore node attributes entirely (the paper's text in citation networks, the atom's charge in molecular graphs).

Random Walk Embeddings: The Skip-Gram Era

The breakthrough insight of DeepWalk (2014): treat random walks on a graph like sentences, and apply Word2Vec.

Word2Vec Skip-Gram Review

Word2Vec learns word embeddings by predicting context words given a center word. For a sentence like "the quick brown fox jumps," the skip-gram model predicts "the", "quick", "brown", "jumps" given center word "fox."

The objective maximizes the log probability of context words:

$$\mathcal{L} = \sum_{t} \sum_{-c \leq j \leq c,\ j \neq 0} \log P(w_{t+j} \mid w_t)$$

where $P(w_O \mid w_I) = \frac{\exp(\mathbf{v}_{w_O}^\top \mathbf{v}_{w_I})}{\sum_{w} \exp(\mathbf{v}_w^\top \mathbf{v}_{w_I})}$ (a softmax over the whole vocabulary, approximated with negative sampling in practice).

Words that appear in similar contexts get similar embeddings. This captures semantic similarity from co-occurrence patterns.
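To make the sum over $-c \leq j \leq c$ concrete, here is a small sketch that enumerates the (center, context) pairs skip-gram trains on. With window $c = 3$ it reproduces the "fox" example. DeepWalk feeds random walks through exactly this kind of windowing; real Word2Vec implementations add refinements such as randomly shrinking the window, which this sketch omits:

```python
def skipgram_pairs(tokens, window):
    """All (center, context) training pairs for skip-gram with window size c."""
    pairs = []
    for t, center in enumerate(tokens):
        lo, hi = max(0, t - window), min(len(tokens), t + window + 1)
        for j in range(lo, hi):
            if j != t:  # skip the center itself (the j != 0 term in the sum)
                pairs.append((center, tokens[j]))
    return pairs

sentence = "the quick brown fox jumps".split()
pairs = skipgram_pairs(sentence, window=3)
print([ctx for center, ctx in pairs if center == "fox"])
# ['the', 'quick', 'brown', 'jumps']
```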

DeepWalk

Perozzi et al. (2014) made the analogy precise:

  1. Generate random walks of fixed length $T$ starting from each node
  2. Treat each walk as a "sentence" of node IDs
  3. Apply skip-gram (Word2Vec) to learn node embeddings

Nodes that co-occur frequently in random walks get similar embeddings. Since random walks tend to stay in local neighborhoods, nearby nodes in the graph get similar embeddings.

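The connection to matrix factorization: DeepWalk implicitly factorizes a shifted Pointwise Mutual Information (PMI) matrix of node co-occurrences in random walks. For an undirected graph with window size $T$ and $b$ negative samples, Qiu et al. (2018) showed that, as the walks grow long, this matrix approaches

$$\log\left(\frac{\text{vol}(G)}{bT}\left(\sum_{r=1}^{T} \left(D^{-1}A\right)^{r}\right) D^{-1}\right)$$

where $\text{vol}(G) = \sum_v d_v = 2|E|$, $D$ is the degree matrix, and the logarithm is applied entry-wise. Entry $(u, v)$ is the PMI of $u$ and $v$ co-occurring within $T$ steps of a walk, minus $\log b$: it is large when $u$ and $v$ co-occur much more often than random chance would predict.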
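For a small graph, the matrix inside that logarithm, $\frac{\text{vol}(G)}{bT}\left(\sum_{r=1}^{T}(D^{-1}A)^r\right)D^{-1}$, can be built densely and factorized directly. A sketch in the spirit of NetMF (Qiu et al., 2018): the truncated logarithm $\log\max(M, 1)$ and the embedding $U_d \Sigma_d^{1/2}$ from an SVD follow that paper's recipe, while the function names are illustrative. It assumes an undirected graph with no isolated nodes and costs $O(n^3)$, so it is for intuition only:

```python
import numpy as np

def netmf_matrix(A, T=10, b=1):
    """vol(G)/(b*T) * (sum_{r=1..T} (D^-1 A)^r) D^-1 for a dense undirected A."""
    A = np.asarray(A, dtype=float)
    deg = A.sum(axis=1)
    vol = deg.sum()                       # vol(G) = 2|E|
    P = A / deg[:, None]                  # random-walk transition matrix D^-1 A
    S, P_r = np.zeros_like(P), np.eye(len(A))
    for _ in range(T):
        P_r = P_r @ P                     # (D^-1 A)^r
        S += P_r
    return vol / (b * T) * S / deg[None, :]   # right-multiply by D^-1

def netmf_embedding(A, dim=2, T=10, b=1):
    """Factorize log(max(M, 1)) with an SVD and return U_d * sqrt(Sigma_d)."""
    M = netmf_matrix(A, T, b)
    logM = np.log(np.maximum(M, 1.0))     # truncated log keeps every entry finite
    U, s, _ = np.linalg.svd(logM)
    return U[:, :dim] * np.sqrt(s[:dim])

# Two 4-node cliques joined by a single bridge edge (3-4)
A = np.zeros((8, 8))
for block in (range(0, 4), range(4, 8)):
    for i in block:
        for j in block:
            if i != j:
                A[i, j] = 1
A[3, 4] = A[4, 3] = 1
print(netmf_embedding(A, dim=2, T=3).round(2))
```

For $T = 1$ the matrix reduces to $\text{vol}(G) / (b \, d_u d_v)$ on edges and 0 elsewhere, which is the LINE-style first-order case.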

```python
import numpy as np
import random
from collections import defaultdict
from gensim.models import Word2Vec

def random_walk(G_adj, start_node, walk_length):
    """Generate a single uniform random walk starting from start_node."""
    walk = [start_node]
    current = start_node
    for _ in range(walk_length - 1):
        neighbors = G_adj[current]
        if not neighbors:
            break  # dead end (isolated node)
        current = random.choice(neighbors)
        walk.append(current)
    return [str(n) for n in walk]  # Word2Vec expects string tokens

def generate_walks(adj_list, num_walks, walk_length):
    """Generate num_walks walks from every node, in a shuffled order each pass."""
    walks = []
    nodes = list(adj_list.keys())
    for _ in range(num_walks):
        random.shuffle(nodes)
        for node in nodes:
            walks.append(random_walk(adj_list, node, walk_length))
    return walks

def deepwalk(adj_list, embedding_dim=128, num_walks=10, walk_length=80,
             window_size=10, workers=4):
    """Full DeepWalk pipeline. Row i of the result belongs to nodes[i]."""
    print(f"Generating {num_walks} walks of length {walk_length}...")
    walks = generate_walks(adj_list, num_walks, walk_length)

    print(f"Training Word2Vec on {len(walks)} walks...")
    model = Word2Vec(
        sentences=walks,
        vector_size=embedding_dim,
        window=window_size,
        min_count=0,   # include all nodes even if rare
        sg=1,          # skip-gram (vs CBOW)
        workers=workers,
        epochs=5,
        negative=5,    # negative sampling
    )

    # Sorted IDs so that row i lines up with node i (needed for masks later)
    nodes = sorted(adj_list.keys())
    embeddings = np.zeros((len(nodes), embedding_dim))
    for i, node in enumerate(nodes):
        embeddings[i] = model.wv[str(node)]
    return embeddings, nodes

# Build adjacency list from PyG Cora
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/cora', name='Cora')
data = dataset[0]

adj_list = defaultdict(list)
edge_index = data.edge_index.numpy()
for src, dst in zip(edge_index[0], edge_index[1]):
    adj_list[int(src)].append(int(dst))  # PyG already stores both directions

embeddings, nodes = deepwalk(adj_list, embedding_dim=64, num_walks=10, walk_length=40)
print(f"Embeddings shape: {embeddings.shape}")  # (2708, 64)
```

Node2Vec - Biased Random Walks

Grover & Leskovec (2016) recognized that random walks make a specific choice: uniform sampling among neighbors. This biases the walk toward certain exploration strategies.

BFS-like walks explore the neighborhood of the current position - they capture local structural equivalence (nodes with similar local structure get similar embeddings).

DFS-like walks explore further away from the start - they capture community membership (nodes in the same community get similar embeddings).

Node2Vec introduces two parameters $p$ and $q$ that control this trade-off via a biased transition probability. When the walk is at node $v$ and came from node $u$, the unnormalized probability of transitioning to neighbor $x$ is:

$$\pi_{vx} = \alpha_{pq}(u, x) \cdot w_{vx}$$

where:

$$\alpha_{pq}(u, x) = \begin{cases} 1/p & \text{if } d_{ux} = 0 \quad (\text{return to } u) \\ 1 & \text{if } d_{ux} = 1 \quad (\text{neighbor of both } u \text{ and } v) \\ 1/q & \text{if } d_{ux} = 2 \quad (\text{further from } u) \end{cases}$$

$d_{ux}$ is the shortest path distance between $u$ and $x$, and $w_{vx}$ is the weight of edge $(v, x)$ (1 for unweighted graphs).

  • $p$ (return parameter): $p < 1$ encourages returning to $u$ (local exploration, BFS-like)
  • $q$ (in-out parameter): $q < 1$ encourages exploration away from $u$ (DFS-like); $q > 1$ keeps the walk close to $u$

```python
import random
import numpy as np
from gensim.models import Word2Vec

def node2vec_walk(G_adj, start_node, walk_length, p=1.0, q=1.0):
    """
    Biased random walk for Node2Vec (unweighted graph).
    p: return parameter. q: in-out parameter.
    """
    walk = [start_node]
    if walk_length == 1:
        return [str(start_node)]

    # Second node: uniform among neighbors (there is no previous node yet)
    neighbors = G_adj[start_node]
    if not neighbors:
        return [str(start_node)]

    prev = start_node
    current = random.choice(neighbors)
    walk.append(current)

    for _ in range(walk_length - 2):
        neighbors_curr = G_adj[current]
        if not neighbors_curr:
            break

        # Biased weights alpha_pq(prev, x): depend on the distance d(prev, x)
        prev_neighbors = set(G_adj[prev])
        weights = []
        for x in neighbors_curr:
            if x == prev:
                weights.append(1.0 / p)   # d = 0: return to prev
            elif x in prev_neighbors:
                weights.append(1.0)       # d = 1: common neighbor of prev and current
            else:
                weights.append(1.0 / q)   # d = 2: explore outward

        # random.choices normalizes the weights itself
        next_node = random.choices(neighbors_curr, weights=weights)[0]
        walk.append(next_node)
        prev, current = current, next_node

    return [str(n) for n in walk]

def node2vec(adj_list, embedding_dim=128, num_walks=10, walk_length=80,
             p=1.0, q=1.0, window_size=10):
    """Full Node2Vec pipeline. Row i belongs to node sorted(adj_list)[i]."""
    walks = []
    start_nodes = list(adj_list.keys())
    for _ in range(num_walks):
        random.shuffle(start_nodes)
        for node in start_nodes:
            walks.append(node2vec_walk(adj_list, node, walk_length, p=p, q=q))

    model = Word2Vec(
        sentences=walks,
        vector_size=embedding_dim,
        window=window_size,
        min_count=0,
        sg=1,
        workers=4,
        epochs=5,
        negative=5,
    )

    # Use sorted IDs, not the shuffled start_nodes, so rows align with node IDs
    nodes = sorted(adj_list.keys())
    embeddings = np.zeros((len(nodes), embedding_dim))
    for i, node in enumerate(nodes):
        embeddings[i] = model.wv[str(node)]
    return embeddings

# adj_list: the Cora adjacency list built in the DeepWalk example
# DFS-like (community / homophily): p=1, q=0.5
emb_dfs = node2vec(adj_list, p=1, q=0.5)

# BFS-like (structural equivalence): p=1, q=2
emb_bfs = node2vec(adj_list, p=1, q=2)
```

When to use each:

  • $q < 1$ (DFS-like): community detection, when you want nodes in the same cluster to be close
  • $q > 1$ (BFS-like): structural role detection, when you want nodes with the same structural role (e.g., all bridges) to be close regardless of which part of the graph they are in

LINE - Large-scale Information Network Embedding

Tang et al. (2015) argued that co-occurrence in random walks is an indirect proxy for two distinct notions of similarity. LINE explicitly optimizes both:

First-order proximity: directly connected nodes should have similar embeddings.

$$\mathcal{L}_1 = -\sum_{(u,v) \in E} w_{uv} \log P_1(u, v)$$

$$P_1(u, v) = \frac{1}{1 + \exp(-\mathbf{u}_u^\top \mathbf{u}_v)} \quad \text{(sigmoid)}$$

Second-order proximity: nodes with similar neighborhoods should have similar embeddings. Node $u$ defines a conditional distribution over its neighbors:

$$P_2(v \mid u) = \frac{\exp(\mathbf{u}_v'^\top \mathbf{u}_u)}{\sum_{k} \exp(\mathbf{u}_k'^\top \mathbf{u}_u)}$$

where $\mathbf{u}'_v$ is a separate "context" vector that LINE keeps for each node alongside its embedding $\mathbf{u}_v$.

The second-order loss minimizes KL divergence between the empirical neighborhood distribution and the predicted one.

LINE is trained with negative sampling and asynchronous SGD - important for web-scale graphs.
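In practice the first-order term is paired with negative sampling: each observed edge is pushed toward a high sigmoid score while $K$ randomly drawn "negative" nodes are pushed toward a low one. Without the negative term, $\mathcal{L}_1$ could be driven toward zero simply by making every embedding identical and very large. A minimal NumPy sketch of that loss; the negatives are assumed to be supplied by the caller (LINE draws them from a noise distribution proportional to $d_v^{3/4}$), and the function name is illustrative:

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    return -np.logaddexp(0.0, -x)

def line_first_order_loss(emb, edges, weights, neg_samples):
    """First-order LINE objective with negative sampling (a sketch).

    emb:         [n_nodes, dim] embedding matrix (row u is u_u)
    edges:       [n_edges, 2] array of (u, v) node pairs
    weights:     [n_edges] edge weights w_uv
    neg_samples: [n_edges, K] negative node ids drawn for each edge
    """
    u = emb[edges[:, 0]]                                    # [E, dim]
    v = emb[edges[:, 1]]                                    # [E, dim]
    neg = emb[neg_samples]                                  # [E, K, dim]
    pos_term = log_sigmoid(np.sum(u * v, axis=1))           # log P_1(u, v)
    neg_term = log_sigmoid(-np.einsum("ed,ekd->ek", u, neg)).sum(axis=1)
    return -np.sum(weights * (pos_term + neg_term))

rng = np.random.default_rng(0)
emb = rng.normal(scale=0.1, size=(4, 8))
edges = np.array([[0, 1], [1, 2], [2, 3]])
negatives = rng.integers(0, 4, size=(3, 2))                 # K = 2 per edge
print(line_first_order_loss(emb, edges, np.ones(3), negatives))
```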

Spectral Embeddings

Another approach: represent nodes using eigenvectors of the graph Laplacian. These capture global graph structure without any random walk.

Laplacian Eigenmaps

The eigenvectors of the $k$ smallest non-zero eigenvalues of the graph Laplacian $L = D - A$ give a smooth embedding of nodes into $\mathbb{R}^k$ (for a connected graph, only $\lambda_1$ is zero):

$$L \mathbf{u}_i = \lambda_i \mathbf{u}_i, \quad 0 = \lambda_1 \leq \lambda_2 \leq \ldots \leq \lambda_n$$

The embedding of node $v$ is $[\mathbf{u}_2(v), \mathbf{u}_3(v), \ldots, \mathbf{u}_{k+1}(v)]$.

The key property: the eigenvector of the smallest non-zero eigenvalue ($\mathbf{u}_2$, called the Fiedler vector) solves a continuous relaxation of the balanced minimum-cut (ratio cut) problem, so splitting nodes by the sign of their $\mathbf{u}_2$ entry gives a good, though not always optimal, bisection that cuts few edges. Using multiple eigenvectors extends this to multiple clusters.

```python
import numpy as np
from scipy.linalg import eigh
import matplotlib.pyplot as plt

def spectral_embedding(A, k=2):
    """
    Compute a k-dimensional spectral embedding from adjacency matrix A,
    using the symmetric normalized Laplacian.
    Returns (embedding of shape [n_nodes, k], the k eigenvalues used).
    """
    # Degree matrix
    degrees = A.sum(axis=1)
    D = np.diag(degrees)

    # Normalized Laplacian: L_sym = D^{-1/2} L D^{-1/2}, with L = D - A
    d_inv_sqrt = np.where(degrees > 0, 1.0 / np.sqrt(degrees), 0.0)
    D_inv_sqrt = np.diag(d_inv_sqrt)
    L = D - A
    L_sym = D_inv_sqrt @ L @ D_inv_sqrt

    # Full dense eigendecomposition, O(n^3); eigh returns eigenvalues ascending
    eigenvalues, eigenvectors = eigh(L_sym)

    # Skip eigenvector 0 (eigenvalue 0). For L_sym it is proportional to
    # D^{1/2} * 1, which is constant only on regular graphs. Use 1..k.
    embedding = eigenvectors[:, 1:k+1]
    return embedding, eigenvalues[1:k+1]

# Small example: two clusters connected by one bridge
A_clustered = np.zeros((8, 8))
# Cluster 1: nodes 0-3 (dense clique)
for i in range(4):
    for j in range(4):
        if i != j:
            A_clustered[i, j] = 1
# Cluster 2: nodes 4-7 (dense clique)
for i in range(4, 8):
    for j in range(4, 8):
        if i != j:
            A_clustered[i, j] = 1
# Bridge edge
A_clustered[3, 4] = A_clustered[4, 3] = 1

emb, evals = spectral_embedding(A_clustered, k=2)
print("Eigenvalues (Fiedler value, next):", evals)
print("Embedding shape:", emb.shape)

# In the 2D embedding the two clusters separate by the sign of the
# Fiedler vector (column 0 of emb)
colors = ['blue']*4 + ['red']*4
plt.figure(figsize=(6, 6))
plt.scatter(emb[:, 0], emb[:, 1], c=colors, s=100)
for i, (x, y) in enumerate(emb):
    plt.annotate(str(i), (x, y), textcoords="offset points", xytext=(5, 5))
plt.title("Spectral Embedding - 2 Clusters")
plt.xlabel("Fiedler vector (u_2)")
plt.ylabel("Next eigenvector (u_3)")
plt.savefig("spectral_embedding.png", dpi=150)
```

Advantages of spectral embeddings:

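  • Few hyperparameters: essentially just $k$ and the choice of Laplacian (unlike Node2Vec's $p$, $q$, walk length, and window size)
  • Global structure - eigenvectors capture graph-wide patterns, not just local neighborhoods
  • Deterministic and reproducible (up to the sign of each eigenvector, and the choice of basis when eigenvalues repeat)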

Disadvantages:

  • Full eigendecomposition is $O(n^3)$ - completely infeasible for large graphs
  • Cannot generalize to new nodes (adding a node requires recomputing all eigenvectors)
  • Does not use node features
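When only a handful of eigenvectors are needed, a sparse Lanczos solver avoids the full $O(n^3)$ decomposition done by the dense `eigh` call in the example above, although it leaves the other two limitations untouched. A sketch using `scipy.sparse.linalg.eigsh`, assuming a connected graph with no isolated nodes. The spectrum shift (asking for the largest eigenvalues of $2I - L_{\text{sym}}$, whose spectrum mirrors that of $L_{\text{sym}} \in [0, 2]$) is one common way to get fast convergence; shift-invert via the `sigma` argument is another:

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh

def sparse_spectral_embedding(A, k=2):
    """k-dim embedding from the normalized Laplacian via sparse Lanczos."""
    A = sp.csr_matrix(A, dtype=float)
    n = A.shape[0]
    deg = np.asarray(A.sum(axis=1)).ravel()
    D_inv_sqrt = sp.diags(1.0 / np.sqrt(deg))
    L_sym = sp.identity(n) - D_inv_sqrt @ A @ D_inv_sqrt
    # Smallest eigenvalues of L_sym = largest eigenvalues of 2I - L_sym.
    vals, vecs = eigsh(2 * sp.identity(n) - L_sym, k=k + 1, which="LA")
    lam = 2 - vals
    order = np.argsort(lam)               # ascending eigenvalues of L_sym
    keep = order[1:k + 1]                 # drop the trivial lambda = 0 vector
    return vecs[:, keep], lam[keep]

# Same two-clique graph as in the dense example
A = np.zeros((8, 8))
for block in (range(0, 4), range(4, 8)):
    for i in block:
        for j in block:
            if i != j:
                A[i, j] = 1
A[3, 4] = A[4, 3] = 1
emb, lam = sparse_spectral_embedding(A, k=2)
print(lam.round(3), np.sign(emb[:, 0]))   # Fiedler signs split {0..3} from {4..7}
```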

Comparison: Shallow Embedding Methods

| Method | Captures structure | Uses features | Inductive | Scalable | Parameters |
|---|---|---|---|---|---|
| Manual features | Partially (structural stats) | Partially (separate) | Yes | Mostly (betweenness is $O(nm)$) | 0 |
| DeepWalk | Yes (random walk proximity) | No | No | Yes (Word2Vec SGD) | $n \times d$ |
| Node2Vec | Yes (controllable BFS/DFS) | No | No | Yes | $n \times d$ |
| LINE | Yes (1st + 2nd order) | No | No | Yes | $n \times d$ |
| Spectral (Laplacian eigenmaps) | Yes (global) | No | No | No ($O(n^3)$) | 0 |
| GCN/GNN | Yes (learned) | Yes | Partially | Moderate | $L \times d^2$ |

The key observation: all shallow methods (rows 2–5) are transductive - they learn an embedding table indexed by node ID. If a new node is added after training, there is no embedding for it. This is the fundamental limitation that GNNs were designed to address.

Limitations of Shallow Methods

No Generalization to Unseen Nodes

A Node2Vec model trained on Monday's graph has no embeddings for users who joined on Tuesday. To get Tuesday's embeddings, you must re-run the entire pipeline: new random walks + retrain Word2Vec. For a system that adds millions of users daily (Facebook, TikTok), this is completely infeasible.

GNNs fix this by learning aggregation functions rather than embedding lookup tables. Given any new node's features and neighbors' features, the GNN computes its embedding on the fly.
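A minimal sketch of what "computing the embedding on the fly" means, assuming a GraphSAGE-style layer with mean aggregation. The weight matrices `W_self` and `W_neigh` are random stand-ins for trained parameters, and `embed_node` is a hypothetical helper:

```python
import numpy as np

rng = np.random.default_rng(0)
feat_dim, emb_dim = 4, 3
# Stand-ins for weights a GraphSAGE-style model would learn during training.
W_self = rng.normal(size=(feat_dim, emb_dim))
W_neigh = rng.normal(size=(feat_dim, emb_dim))

def embed_node(x_v, neighbor_feats):
    """h_v = ReLU(x_v W_self + mean(neighbor features) W_neigh).

    Only features are used, never a node ID, so a node that did not exist
    at training time can be embedded once its features and neighbors are known.
    """
    if len(neighbor_feats):
        agg = neighbor_feats.mean(axis=0)
    else:
        agg = np.zeros(feat_dim)          # isolated node: no neighbor signal
    return np.maximum(0.0, x_v @ W_self + agg @ W_neigh)

# A brand-new node (e.g. a user who joined on Tuesday) with two neighbors
x_new = rng.normal(size=feat_dim)
neighbors = rng.normal(size=(2, feat_dim))
print(embed_node(x_new, neighbors).shape)  # (3,)
```

Because the mean ignores neighbor order, the result is the same however the neighbors are listed; the same frozen weights serve every future node.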

No Feature Utilization

DeepWalk and Node2Vec only look at graph topology - they completely ignore node content. For citation networks, the text of the paper is a primary signal. For molecular graphs, the atom type and charge are essential. Shallow methods that ignore features leave massive information on the table.

A GNN starts with node features $X$ and learns how to aggregate them - topology and content are fused from the first layer.

Depth of Interactions

Random walk methods look at multi-hop co-occurrence, but only through the lens of path frequency. They cannot learn nonlinear combinations of multi-hop features. A GNN with 3 layers learns that "nodes connected to a node that is connected to a high-PageRank node" should have a specific representation - in a way that is end-to-end optimized for the task.

How to Featurize Edges and Graphs

Beyond node features, edges and entire graphs can also be featurized.

Edge Features

Edges can carry rich information that node features miss:

| Domain | Edge type | Features |
|---|---|---|
| Molecular | Chemical bond | Bond type (1/2/3/aromatic), bond length, stereochemistry |
| Transaction | Payment | Amount, currency, timestamp, merchant category |
| Citation | Reference | Section location, sentiment (positive/negative cite), self-citation flag |
| Social | Friendship | Duration, interaction frequency, mutual friends count |

In PyG, edge features are stored in data.edge_attr of shape [num_edges, edge_feature_dim]. GNN layers that support edge features (e.g., NNConv, GINEConv, or GATConv with edge_dim set) pass edge attributes into the message function.

```python
import torch
from torch_geometric.data import Data

# Heavy-atom graph of ethanol (CH3-CH2-OH): C, C, O (3 atoms)
# Bonds: C0-C1 (single), C1-O2 (single), each stored in both directions
x = torch.tensor([
    [6, 0, 1, 3, 0],  # C0: atomic_num=6, charge=0, degree=1, H_count=3, aromatic=0
    [6, 0, 2, 2, 0],  # C1: atomic_num=6, charge=0, degree=2, H_count=2, aromatic=0
    [8, 0, 1, 1, 0],  # O2: atomic_num=8, charge=0, degree=1, H_count=1, aromatic=0
], dtype=torch.float)

edge_index = torch.tensor([
    [0, 1, 1, 2],  # source
    [1, 0, 2, 1],  # destination
], dtype=torch.long)

# Bond features, one row per directed edge:
# [bond_type (1.0 single, 2.0 double, 3.0 triple, 1.5 aromatic), is_aromatic, is_in_ring]
edge_attr = torch.tensor([
    [1.0, 0, 0],  # C0 -> C1: single bond, not in ring
    [1.0, 0, 0],  # C1 -> C0
    [1.0, 0, 0],  # C1 -> O2: single bond, not in ring
    [1.0, 0, 0],  # O2 -> C1
], dtype=torch.float)

data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=torch.tensor([1]))
print(data)
```
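To see where `edge_attr` enters the computation, here is a NumPy sketch of one layer whose message is $\text{ReLU}([\mathbf{h}_u \| \mathbf{h}_v \| \mathbf{e}_{uv}] W)$, summed at the target node. It mirrors PyG's `edge_index`/`edge_attr` layout but is not the implementation of any particular PyG layer; `W` is a random stand-in for learned weights:

```python
import numpy as np

def edge_aware_message_passing(x, edge_index, edge_attr, W):
    """One layer: m_uv = ReLU([h_u || h_v || e_uv] W), summed at each target v.

    x:          [n_nodes, f_node] node features
    edge_index: [2, n_edges], row 0 = source u, row 1 = target v
    edge_attr:  [n_edges, f_edge] edge features
    W:          [2 * f_node + f_edge, f_out] weight matrix
    """
    src, dst = edge_index
    msg_in = np.concatenate([x[src], x[dst], edge_attr], axis=1)
    messages = np.maximum(0.0, msg_in @ W)          # one message per edge
    out = np.zeros((x.shape[0], W.shape[1]))
    np.add.at(out, dst, messages)                   # sum messages into targets
    return out

# Ethanol heavy atoms C0-C1-O2, both directions of each bond
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))                         # 5 atom features per node
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
edge_attr = np.array([[1.0, 0, 0]] * 4)             # single, non-aromatic, not in ring
W = rng.normal(size=(2 * 5 + 3, 8))
print(edge_aware_message_passing(x, edge_index, edge_attr, W).shape)  # (3, 8)
```

Changing a bond's features changes only the message on that edge, and therefore only the output of its target atom, which is how a model can tell a single bond from a double bond between the same pair of atoms.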

Graph-Level Features

For graph classification, you also want to represent global properties:

  • Simple pooling (baseline): mean/sum/max of all node embeddings after GNN layers
  • Graph statistics: number of nodes, edges, diameter, density, number of connected components
  • Global node (virtual): add a "master node" connected to all other nodes; its final embedding represents the whole graph
  • Hierarchical pooling: DiffPool and similar methods learn to cluster nodes and progressively coarsen the graph

```python
import torch
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool

# After the GNN forward pass:
# x:     [total_nodes, embedding_dim] - all nodes from a batch
# batch: [total_nodes] - graph ID for each node

def graph_readout(x, batch, method='mean'):
    """Aggregate node embeddings to graph level."""
    if method == 'mean':
        return global_mean_pool(x, batch)  # [batch_size, embedding_dim]
    elif method == 'sum':
        return global_add_pool(x, batch)
    elif method == 'max':
        return global_max_pool(x, batch)
    elif method == 'combined':
        # Concatenate multiple readouts for a richer representation
        mean = global_mean_pool(x, batch)
        add = global_add_pool(x, batch)
        mx = global_max_pool(x, batch)
        return torch.cat([mean, add, mx], dim=1)  # [batch_size, 3*embedding_dim]
    raise ValueError(f"Unknown readout method: {method}")
```
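The `batch` vector is what lets one tensor hold many graphs. A NumPy sketch of mean/sum/max readout keyed on `batch`, meant to illustrate the idea rather than reproduce PyG's implementation. The two example graphs are chosen so that mean and max give identical outputs while sum does not:

```python
import numpy as np

def pool_by_batch(x, batch, num_graphs):
    """Mean / sum / max readout, grouping node rows by their graph ID."""
    dim = x.shape[1]
    summed = np.zeros((num_graphs, dim))
    np.add.at(summed, batch, x)                         # sum per graph
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    mean = summed / np.maximum(counts, 1)               # avoid 0/0 for empty graphs
    maxed = np.full((num_graphs, dim), -np.inf)
    np.maximum.at(maxed, batch, x)                      # max per graph
    return mean, summed, maxed

# Graph 0 has nodes with features {1, 2}; graph 1 has {1, 1, 2, 2}.
x = np.array([[1.0], [2.0], [1.0], [1.0], [2.0], [2.0]])
batch = np.array([0, 0, 1, 1, 1, 1])
mean, summed, maxed = pool_by_batch(x, batch, num_graphs=2)
print(mean.ravel(), summed.ravel(), maxed.ravel())
# [1.5 1.5] [3. 6.] [2. 2.]
```

This is one reason the combined readout can help: sum keeps size information that mean and max discard.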

Bag-of-Nodes Baseline

Before running a GNN, always establish a strong baseline: ignore graph structure entirely and just aggregate node features:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

# MUTAG molecular dataset
dataset = TUDataset(root='/tmp/mutag', name='MUTAG')

class BagOfNodesBaseline(nn.Module):
    """
    Graph classification without any message passing.
    Just mean-pool node features and classify.
    This is the lower bound - a GNN should beat this.
    """
    def __init__(self, in_channels, hidden, out_channels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_channels),
        )

    def forward(self, x, edge_index, batch):
        # Completely ignore edge_index - no graph structure used
        # Just pool node features and classify
        pooled = global_mean_pool(x, batch)  # [batch_size, in_channels]
        return self.mlp(pooled)

model = BagOfNodesBaseline(dataset.num_features, 64, dataset.num_classes)
print("Bag-of-nodes baseline established.")
print("A GNN that doesn't beat this is not learning from graph structure.")
```

This baseline is important: if your GNN does not significantly outperform it, your task might not actually require graph structure, or your GNN is not learning from the edges.

The Message Passing Framework - Bridge to GNNs

All modern GNNs can be understood as instances of a general message passing framework (Gilmer et al., 2017 - MPNN):

$$\mathbf{m}_{uv}^{(k)} = M_k\!\left(\mathbf{h}_v^{(k-1)}, \mathbf{h}_u^{(k-1)}, \mathbf{e}_{uv}\right)$$

$$\mathbf{h}_v^{(k)} = U_k\!\left(\mathbf{h}_v^{(k-1)},\, \sum_{u \in \mathcal{N}(v)} \mathbf{m}_{uv}^{(k)}\right)$$

where:

  • $M_k$ is the message function - what node $u$ sends to node $v$ via edge $(u,v)$
  • $U_k$ is the update function - how $v$ integrates received messages with its own state
  • $\mathbf{e}_{uv}$ are edge features (optional)
  • The aggregation (sum here, but can be mean, max) is permutation-invariant

This unifies GCN, GAT, GraphSAGE, GIN, and many more:

  • GCN: $M_k(h_v, h_u, e) = \frac{1}{\sqrt{d_v d_u}} h_u$; $U_k(h_v, m) = \sigma(Wm)$
  • GAT: $M_k(h_v, h_u, e) = \alpha_{vu} W h_u$; $U_k(h_v, m) = \sigma(m)$
  • GraphSAGE: $M_k(h_v, h_u, e) = h_u$; $U_k(h_v, m) = \sigma(W[h_v \| \text{AGG}(m)])$
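To check that the GCN row really is an instance of this framework, here is a NumPy sketch that computes one layer edge by edge (normalized messages, then a scatter-sum, then $\sigma(Wm)$ with ReLU as $\sigma$). Its output can be compared against the dense matrix form $\text{ReLU}(D^{-1/2} A D^{-1/2} H W)$. Self-loops, which the full GCN of Kipf & Welling adds, are left out to keep the correspondence with the $M_k$/$U_k$ row exact:

```python
import numpy as np

def gcn_layer_message_passing(H, edge_index, W):
    """m_uv = h_u / sqrt(d_v d_u), summed over u in N(v), then ReLU(m W).

    edge_index: [2, n_edges] containing both directions of every undirected edge.
    """
    src, dst = edge_index
    n = H.shape[0]
    deg = np.bincount(dst, minlength=n).astype(float)
    norm = 1.0 / np.sqrt(deg[src] * deg[dst])     # 1/sqrt(d_u d_v) per edge
    messages = norm[:, None] * H[src]             # M_k(h_v, h_u, e)
    agg = np.zeros_like(H)
    np.add.at(agg, dst, messages)                 # permutation-invariant sum
    return np.maximum(0.0, agg @ W)               # U_k(h_v, m) = sigma(W m)

# Triangle 0-1-2 with a tail 2-3, both directions listed
edges = [(0, 1), (1, 2), (2, 0), (2, 3)]
edge_index = np.array([[u for u, v in edges] + [v for u, v in edges],
                       [v for u, v in edges] + [u for u, v in edges]])
rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))
print(gcn_layer_message_passing(H, edge_index, W))
```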

Evaluating Node Embeddings

After training any embedding method (Node2Vec, GCN, etc.), evaluate quality via downstream tasks:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, f1_score
import numpy as np

def evaluate_embeddings(embeddings, labels, train_mask, val_mask, test_mask):
    """
    Evaluate node embeddings via linear probing:
    fit logistic regression on train, evaluate on val/test.
    This isolates embedding quality from classifier complexity.
    Expects NumPy arrays (e.g. data.y.numpy(), data.train_mask.numpy())
    and embeddings whose row i belongs to node i.
    """
    X_train = embeddings[train_mask]
    y_train = labels[train_mask]
    X_val = embeddings[val_mask]
    y_val = labels[val_mask]
    X_test = embeddings[test_mask]
    y_test = labels[test_mask]

    # Standardize embeddings (fit on train only to avoid leakage)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)

    # Logistic regression
    clf = LogisticRegression(max_iter=1000, C=1.0)
    clf.fit(X_train, y_train)

    train_acc = accuracy_score(y_train, clf.predict(X_train))
    val_acc = accuracy_score(y_val, clf.predict(X_val))
    test_pred = clf.predict(X_test)
    test_acc = accuracy_score(y_test, test_pred)
    test_f1 = f1_score(y_test, test_pred, average='macro')

    print(f"Linear probe - Train: {train_acc:.3f} | Val: {val_acc:.3f} | Test: {test_acc:.3f} | F1: {test_f1:.3f}")
    return test_acc, test_f1

# Typical results on Cora:
# Manual features (degree, clustering, PageRank): ~55% test accuracy
# DeepWalk (128-dim, 10 walks, len 80): ~70-74% test accuracy
# Node2Vec (optimized p,q): ~73-76% test accuracy
# GCN (2-layer, 64-dim): ~81-83% test accuracy
# GAT (2-layer, 8 heads): ~82-85% test accuracy
```

YouTube Resources

| Video | Channel | Why Watch |
|---|---|---|
| Node2Vec: Scalable Feature Learning for Networks | Stanford CS224W | Original lecture by Jure Leskovec |
| DeepWalk and Node2Vec Explained | Aleksa Gordic | Clear animated walkthrough with code |
| Graph Representation Learning - Hamilton | Montreal AI | Author of the textbook covers shallow embeddings |
| Representation Learning on Graphs and Manifolds (ICLR 2019) | ICLR | Workshop overview connecting all methods |
| Knowledge Graph Embeddings - TransE and Beyond | Stanford CS224W | Extends node embeddings to relational graphs |

Production Engineering Notes

:::note Always Run DeepWalk/Node2Vec as Baselines Before training a full GNN, establish random walk embedding baselines. If Node2Vec with a logistic regression head achieves 78% accuracy and your GNN achieves 79%, the GNN may not be worth the added complexity. Always know what the simpler method achieves. :::

:::warning Node2Vec Hyperparameters Require Tuning Walk length, number of walks, window size, $p$, and $q$ all interact. A common mistake is using default parameters (usually tuned for social networks) on molecular or knowledge graphs. For molecular graphs, short walks (length 10–20) often outperform long walks (80+) because molecular graphs have small diameter. Grid search $p \in \{0.25, 0.5, 1, 2, 4\}$ and $q \in \{0.25, 0.5, 1, 2, 4\}$. :::

:::danger Shallow Embeddings Cannot Generalize to New Nodes A Node2Vec model trained on yesterday's graph has a fixed embedding table. New nodes added today have no representation. If your system needs real-time embeddings for new entities (new products, new users, new documents), shallow methods are architecturally incompatible - use a GNN with inductive inference (GraphSAGE, PinSage). Failing to account for this leads to systems that need full retraining every day, which quickly becomes infeasible at scale. :::

Interview Q&A

Q1: Compare DeepWalk and Node2Vec. What do the $p$/$q$ parameters control and when would you set $q < 1$ vs $q > 1$?

Both methods generate random walks on a graph and apply Word2Vec skip-gram to learn node embeddings from co-occurrence. DeepWalk uses uniform random walks - at each step, sample uniformly among all neighbors. Node2Vec uses biased walks controlled by $p$ (return parameter) and $q$ (in-out parameter). When the walk is at node $v$ coming from node $u$: returning to $u$ has weight $1/p$; moving to a common neighbor of $u$ and $v$ has weight 1; moving further from $u$ has weight $1/q$. Setting $q < 1$ encourages DFS-like outward exploration, making nodes in the same community similar. Setting $q > 1$ encourages BFS-like local exploration, making nodes with similar local structure (same degree, same motifs) similar regardless of global position. Use $q < 1$ for community detection tasks; use $q > 1$ for structural role detection (finding all "bridges" or all "hubs" regardless of where they appear).

Q2: What does DeepWalk implicitly factorize? What does this connection to matrix factorization reveal?

Levy & Goldberg (2014) showed that skip-gram with negative sampling implicitly factorizes a PMI matrix shifted by $\log b$, where $b$ is the number of negative samples. Qiu et al. (2018) derived the corresponding matrix for DeepWalk: with window size $T$, it is $\log\left(\frac{\text{vol}(G)}{bT}\left(\sum_{r=1}^{T}(D^{-1}A)^r\right)D^{-1}\right)$, i.e. the PMI of node co-occurrence within $T$ steps of a random walk, minus $\log b$. This reveals that DeepWalk is a specific matrix factorization whose target matrix is determined by the random walk distribution. The connection shows that shallow embedding methods (DeepWalk, LINE, spectral embeddings) are solving variations of the same problem: approximate factorization of different matrices derived from the graph. GNNs go beyond this by conditioning embeddings on node features and using task supervision rather than graph structure alone.

Q3: What are the fundamental limitations of spectral graph embeddings, and how do GNNs overcome them?

Spectral embeddings use the eigenvectors of the $k$ smallest non-trivial eigenvalues of the normalized Laplacian as node coordinates. Three fundamental limitations: (1) Scalability - a full eigendecomposition is $O(n^3)$, infeasible for graphs with more than ~10,000 nodes. (2) Inductivity - adding a new node to the graph requires recomputing all eigenvectors from scratch; there is no way to compute a new node's embedding from its features and neighbors. (3) Feature blindness - eigenvectors only use graph topology, completely ignoring node features (text, atom type, etc.). GNNs overcome all three: (1) Message passing costs $O(k(|E| \cdot d + |V| \cdot d^2))$ for $k$ layers - linear in the number of edges; (2) Aggregation functions generalize to new nodes immediately; (3) GNNs take node features as input and fuse them with structural information.

Q4: How do you featurize edges in a GNN, and why does it matter for molecular property prediction?

Edge features are stored as edge_attr in PyG - a tensor of shape [num_edges, edge_feature_dim]. In the message passing framework, edge features enter the message function: $\mathbf{m}_{uv} = M(\mathbf{h}_u, \mathbf{h}_v, \mathbf{e}_{uv})$. The simplest approach is concatenation: $\mathbf{m}_{uv} = \text{MLP}([\mathbf{h}_u \| \mathbf{h}_v \| \mathbf{e}_{uv}])$. For molecular graphs, bond type (single, double, triple, aromatic), bond length, and stereochemistry are critical - they determine the molecule's reactivity and 3D geometry. A model that treats C=C (double bond, planar, shorter) the same as C-C (single bond, rotatable, longer) loses fundamental chemical information. Empirically, adding bond features improves molecular property prediction by 3–8% on benchmarks like ogbg-molhiv.

Q5: Describe the bag-of-nodes baseline and explain when a GNN would fail to beat it.

The bag-of-nodes baseline computes a graph-level representation by simply averaging (or summing) node features, then classifying with an MLP - no message passing, no edge information. A GNN that fails to beat this baseline is not learning from graph structure. This happens when: (1) The task is determined entirely by the distribution of node features and graph structure is irrelevant (e.g., predicting the average atom mass of a molecule from its atom types). (2) The GNN is over-smoothing - after many layers, all embeddings are nearly identical, making the learned embeddings equivalent to a mean of initial features. (3) The training set is too small to learn meaningful structural patterns. Always run this baseline and ensure your GNN beats it by a meaningful margin before concluding that graph structure is being utilized.

Q6: Explain the Pointwise Mutual Information connection in random walk embeddings. Why is sum aggregation more expressive than mean?

The PMI of a node pair $(u,v)$ in random walks is $\text{PMI}(u,v) = \log \frac{P(u,v)}{P(u)P(v)}$, where $P(u,v)$ is the co-occurrence probability and $P(u)$, $P(v)$ are marginals. High PMI means $u$ and $v$ co-occur much more than chance - they are structurally close. Skip-gram training pushes co-occurring nodes' embeddings together and pushes apart nodes that do not co-occur; at its optimum (with enough dimensions) this is equivalent to factorizing the PMI matrix shifted by $\log b$ for $b$ negative samples. For expressivity: mean aggregation cannot distinguish a node whose neighbors have features $\{1, 2\}$ from a node whose neighbors have $\{1, 1, 2, 2\}$ - both give mean 1.5, and max gives 2 for both. Sum aggregation distinguishes them (3 vs 6) because it preserves how many neighbors carry each value. This is why GIN uses sum aggregation: it lets the model count the number of neighbors with each feature value, which mean and max cannot do.


© 2026 EngineersOfAI. All rights reserved.