scientific-skills/torch-geometric/references/link_prediction.md
Link prediction is the task of predicting missing or future edges in a graph. Common applications: social network friend suggestion, knowledge graph completion, drug-target interaction.
Use RandomLinkSplit to split edges into train/val/test while maintaining graph structure:
import torch_geometric.transforms as T

# Split edges into train/val/test while keeping the training graph leakage-free.
transform = T.RandomLinkSplit(
    num_val=0.1,                        # 10% of edges for validation
    num_test=0.1,                       # 10% of edges for test
    is_undirected=True,                 # Set True for undirected graphs (split both directions together)
    add_negative_train_samples=False,   # Generate negatives on-the-fly during training
    neg_sampling_ratio=1.0,             # 1 negative per positive edge (val/test)
)
train_data, val_data, test_data = transform(data)
After splitting, each split contains:
- edge_index: message-passing edges (train edges only — no data leakage)
- edge_label_index: supervision edges [2, num_supervision_edges] — the edges to predict
- edge_label: binary labels — 1 for positive (real) edges, 0 for negative (fake) edges

For the training split with add_negative_train_samples=False, only positive edges are in edge_label_index and negatives are sampled during training. Val/test splits always include both positive and negative edges.
The standard approach:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class LinkEncoder(torch.nn.Module):
    """Two-layer GCN that maps node features to embeddings for link scoring."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        # ReLU between layers; the last layer stays linear so the
        # dot-product decoder can emit unbounded logits.
        h = self.conv1(x, edge_index)
        return self.conv2(h.relu(), edge_index)
def decode(z, edge_label_index):
    """Dot-product decoder.

    For each candidate edge, returns the inner product of its endpoint
    embeddings — a raw logit (apply sigmoid for a probability).
    z: [num_nodes, dim]; edge_label_index: [2, num_edges].
    """
    endpoint_products = z[edge_label_index[0]] * z[edge_label_index[1]]
    return endpoint_products.sum(dim=-1)
from torch_geometric.utils import negative_sampling
# 128-dim hidden layer, 64-dim final embeddings.
model = LinkEncoder(data.num_features, 128, 64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train(train_data):
    """Run one optimization step: encode, sample negatives, score, backprop.

    Returns the scalar BCE-with-logits loss for this step.
    """
    model.train()
    optimizer.zero_grad()

    # Embeddings come from message-passing edges only (no supervision leakage).
    z = model(train_data.x, train_data.edge_index)

    # Fresh negatives every step — one per positive supervision edge.
    num_pos = train_data.edge_label_index.size(1)
    neg_edges = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=train_data.num_nodes,
        num_neg_samples=num_pos,
    )

    # Positives first, then negatives, with matching 1/0 labels.
    supervision_edges = torch.cat([train_data.edge_label_index, neg_edges], dim=1)
    labels = torch.cat([torch.ones(num_pos), torch.zeros(neg_edges.size(1))])

    logits = decode(z, supervision_edges)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test(data_split):
    """Return ROC-AUC of predicted edge probabilities on a held-out split."""
    # AUC is the standard metric for link prediction (rank-based, threshold-free).
    from sklearn.metrics import roc_auc_score

    model.eval()
    z = model(data_split.x, data_split.edge_index)
    scores = decode(z, data_split.edge_label_index).sigmoid()
    return roc_auc_score(data_split.edge_label.cpu(), scores.cpu())
PyG provides GAE and VGAE for unsupervised link prediction:
from torch_geometric.nn import GAE, VGAE, GCNConv
class Encoder(torch.nn.Module):
    """Two-layer GCN encoder for GAE: widens to 2*out_channels, then projects down."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)
        # For VGAE, also define conv_mu and conv_logstd

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
# GAE wraps your encoder and provides train/test methods
# (inner-product decoding, recon_loss, and AUC/AP evaluation).
model = GAE(Encoder(data.num_features, 64))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
    """Run one GAE optimization step; returns the reconstruction loss."""
    model.train()
    optimizer.zero_grad()
    embeddings = model.encode(train_data.x, train_data.edge_index)
    # recon_loss scores the given positive edges; per the PyG docs it samples
    # its own negatives when no neg_edge_index is passed.
    loss = model.recon_loss(embeddings, train_data.edge_label_index)
    # For VGAE, add KL divergence:
    # loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test(data_split):
    """Evaluate GAE/VGAE on a held-out split; returns (AUC, AP).

    Bug fix: GAE.test expects a *positive* and a *negative* edge-index
    tensor, each of shape [2, E]. The previous code passed
    edge_label_index[0] and edge_label_index[1], which are the source-node
    row and destination-node row of ALL supervision edges — not a
    positive/negative split. Select columns by edge_label instead.
    """
    model.eval()
    z = model.encode(data_split.x, data_split.edge_index)
    pos_edge_index = data_split.edge_label_index[:, data_split.edge_label == 1]
    neg_edge_index = data_split.edge_label_index[:, data_split.edge_label == 0]
    return model.test(z, pos_edge_index, neg_edge_index)
For VGAE, the encoder must return mu and logstd instead of a single embedding. Use the VGAE-specific encoder pattern:
class VariationalEncoder(torch.nn.Module):
    """GCN encoder emitting (mu, logstd), as required by VGAE's reparameterization."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index).relu()
        mu = self.conv_mu(hidden, edge_index)
        logstd = self.conv_logstd(hidden, edge_index)
        return mu, logstd
# VGAE calls the encoder, reparameterizes with (mu, logstd), and decodes by inner product.
model = VGAE(VariationalEncoder(data.num_features, 64))
For large graphs, use LinkNeighborLoader — it samples subgraphs around supervision edges:
from torch_geometric.loader import LinkNeighborLoader

# Mini-batch loader: each batch is a sampled subgraph around a set of supervision edges.
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],  # Sample neighbors per hop (first hop 20, second hop 10)
    edge_label_index=train_data.edge_label_index,  # Supervision edges to batch over
    edge_label=train_data.edge_label,
    batch_size=128,  # Number of supervision edges per batch
    neg_sampling_ratio=1.0,  # 1 negative sampled per positive, per batch
    shuffle=True,
)
for batch in train_loader:
    # batch.edge_label_index: supervision edges (pos + neg)
    # batch.edge_label: 1 for positive, 0 for negative
    # batch.edge_index: message-passing edges (from neighbor sampling)
    z = model(batch.x, batch.edge_index)
    pred = decode(z, batch.edge_label_index)
    loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label)
    # NOTE: snippet shows only the forward pass — a real training loop also
    # needs optimizer.zero_grad(), loss.backward(), and optimizer.step().
For heterogeneous graphs (e.g., user-item recommendation):
# Heterogeneous split: name the edge type to predict and its reverse so the
# reverse edges are split together (otherwise they would leak the answer).
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=False,
    edge_types=('user', 'rates', 'movie'),  # Which edge type to predict
    rev_edge_types=('movie', 'rev_rates', 'user'),  # Its reverse
)
train_data, val_data, test_data = transform(data)
# Supervision edges are in:
# train_data['user', 'rates', 'movie'].edge_label_index
# train_data['user', 'rates', 'movie'].edge_label
from sklearn.metrics import roc_auc_score, average_precision_score

# edge_label / pred are the binary labels and sigmoid scores from your eval loop.
auc = roc_auc_score(edge_label.cpu(), pred.cpu())
ap = average_precision_score(edge_label.cpu(), pred.cpu())  # AP complements AUC on imbalanced negatives
Common pitfalls:

- Data leakage through message-passing edges: RandomLinkSplit handles this correctly — edge_index in train_data only contains training edges.
- Forgetting is_undirected=True in RandomLinkSplit for undirected graphs — otherwise it will treat each direction independently and leak information.