scientific-skills/torch-geometric/references/scaling.md
Techniques for training GNNs on large graphs that don't fit in GPU memory, multi-GPU training, and performance optimization.
NeighborLoader (neighbor sampling) is the primary approach for large single-graph training. It recursively samples a fixed number of neighbors per hop, bounding the size of each computation graph.
```python
from torch_geometric.loader import NeighborLoader

loader = NeighborLoader(
    data,
    num_neighbors=[15, 10],       # Max neighbors per hop (hop 1: 15, hop 2: 10)
    batch_size=1024,              # Number of seed nodes per batch
    input_nodes=data.train_mask,  # Which nodes to sample from
    shuffle=True,
    num_workers=4,                # Parallel data loading
    replace=False,                # Sample without replacement
)
```
Key parameters:
- `num_neighbors`: list of max neighbors per hop. Length should match GNN depth. Use -1 to sample all neighbors for a hop.
- `input_nodes`: seed nodes — can be a mask, a tensor of indices, or a `('node_type', mask)` tuple for hetero graphs.
- `subgraph_type`: `"directional"` (default), `"bidirectional"` (add reverse edges), or `"induced"` (full induced subgraph).
- `disjoint`: if `True`, don't fuse neighborhoods across seed nodes (uses more memory but is sometimes needed).
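For instance, on a heterogeneous graph the same loader can seed on a single node type. A minimal sketch, assuming a `HeteroData` object named `hetero_data` with a `'paper'` node type that carries a `train_mask`:

```python
from torch_geometric.loader import NeighborLoader

loader_hetero = NeighborLoader(
    hetero_data,
    # A dict gives per-edge-type budgets; a plain list applies the same budget to every edge type.
    num_neighbors={edge_type: [15, 10] for edge_type in hetero_data.edge_types},
    input_nodes=('paper', hetero_data['paper'].train_mask),  # seed only on 'paper' nodes
    batch_size=1024,
    shuffle=True,
)
```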
Training pattern:

```python
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE

model = GraphSAGE(in_channels, hidden_channels, num_layers=2, out_channels=out_channels)

for batch in loader:
    batch = batch.to(device)
    out = model(batch.x, batch.edge_index)
    # CRITICAL: only the first batch_size nodes are seed nodes
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
```
Important details:
- The first `batch.batch_size` nodes in each batch are the seed nodes.
- `batch.n_id` maps local indices back to original node IDs.
- Keep `len(num_neighbors) == num_gnn_layers` for efficiency.
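As a sketch of how `batch.n_id` is commonly used (the name `infer_loader` is illustrative), predictions for each batch can be written back into a tensor indexed by original node IDs:

```python
import torch

# Seed on every node (input_nodes=None) to produce predictions for the whole graph.
infer_loader = NeighborLoader(data, num_neighbors=[15, 10], batch_size=4096,
                              input_nodes=None, shuffle=False)

model.eval()
all_pred = torch.empty(data.num_nodes, dtype=torch.long)
with torch.no_grad():
    for batch in infer_loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)[:batch.batch_size]
        seed_ids = batch.n_id[:batch.batch_size].cpu()  # original IDs of the seed nodes
        all_pred[seed_ids] = out.argmax(dim=-1).cpu()
```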
LinkNeighborLoader samples subgraphs around supervision edges:

```python
from torch_geometric.loader import LinkNeighborLoader

loader = LinkNeighborLoader(
    train_data,
    num_neighbors=[20, 10],
    edge_label_index=train_data.edge_label_index,
    edge_label=train_data.edge_label,
    batch_size=256,
    neg_sampling_ratio=1.0,
    shuffle=True,
)
```
HGTLoader samples a fixed number of nodes per node type per hop, following the HGT paper:
```python
from torch_geometric.loader import HGTLoader

loader = HGTLoader(
    data,
    num_samples=[512] * 2,  # Nodes per type per hop
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
)
```
ClusterData / ClusterLoader (Cluster-GCN) partitions the graph into clusters and trains on full subgraphs. This works well for deeper GNNs, since messages flow freely within each cluster:
```python
from torch_geometric.loader import ClusterData, ClusterLoader

cluster_data = ClusterData(data, num_parts=1500)
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True)

for batch in loader:
    # batch is a full subgraph — no slicing needed
    out = model(batch.x, batch.edge_index)
    loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
```
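METIS partitioning is a one-off cost that can be significant on large graphs; ClusterData accepts a `save_dir` argument so the partition is computed once and reused (the path below is illustrative):

```python
# Cache the METIS partition on disk so repeated runs skip re-partitioning.
cluster_data = ClusterData(data, num_parts=1500, save_dir='./partitions')
loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=4)
```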
GraphSAINT samples subgraphs via random walks, nodes, or edges, with importance-based normalization:
```python
from torch_geometric.loader import GraphSAINTRandomWalkSampler

loader = GraphSAINTRandomWalkSampler(
    data, batch_size=6000, walk_length=2, num_steps=5,
)
```
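The importance-based normalization is only estimated when `sample_coverage > 0`; the resulting `node_norm` can then be used to de-bias the loss. A rough sketch, following the pattern of PyG's GraphSAINT example:

```python
import torch.nn.functional as F

loader = GraphSAINTRandomWalkSampler(
    data, batch_size=6000, walk_length=2, num_steps=5,
    sample_coverage=100,  # enables node_norm / edge_norm estimation
)

for batch in loader:
    batch = batch.to(device)
    out = model(batch.x, batch.edge_index)
    loss = F.cross_entropy(out, batch.y, reduction='none')
    loss = (loss * batch.node_norm)[batch.train_mask].sum()  # re-weight per-node losses
```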
ShaDowKHopSampler extracts K-hop induced subgraphs around seed nodes — it decouples GNN depth from neighborhood scope:
```python
from torch_geometric.loader import ShaDowKHopSampler

loader = ShaDowKHopSampler(
    data, depth=2, num_neighbors=5, batch_size=64,
    node_idx=data.train_mask,  # seed nodes (ShaDowKHopSampler uses node_idx, not input_nodes)
)
```
Standard PyTorch DDP works with PyG. Each GPU gets a partition of the seed nodes:
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphSAGE


def run(rank, world_size, dataset):
    # Initialize process group
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    data = dataset[0]

    # Split training nodes across GPUs
    train_idx = data.train_mask.nonzero().view(-1)
    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]

    loader = NeighborLoader(
        data,
        input_nodes=train_idx,
        num_neighbors=[25, 10],
        batch_size=1024,
        num_workers=4,
        shuffle=True,
    )

    # Wrap model in DDP
    model = GraphSAGE(...).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        model.train()
        for batch in loader:
            batch = batch.to(rank)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)[:batch.batch_size]
            loss = F.cross_entropy(out, batch.y[:batch.batch_size])
            loss.backward()
            optimizer.step()

        # Synchronize before evaluation
        dist.barrier()
        if rank == 0:
            # Evaluate on rank 0 only
            ...

    dist.destroy_process_group()


# Launch
if __name__ == '__main__':
    dataset = Reddit('./data/Reddit')
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)
```
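One way to fill in the rank-0 evaluation stub above is a separate validation loader used only on rank 0; a minimal sketch, assuming `data.val_mask` exists:

```python
# Inside run(), in place of the rank-0 evaluation stub:
if rank == 0:
    val_loader = NeighborLoader(data, input_nodes=data.val_mask,
                                num_neighbors=[25, 10], batch_size=1024)
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(rank)
            pred = model(batch.x, batch.edge_index)[:batch.batch_size].argmax(dim=-1)
            correct += int((pred == batch.y[:batch.batch_size]).sum())
            total += batch.batch_size
    print(f'Epoch {epoch}: val acc {correct / total:.4f}')
    model.train()
```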
Key points:
- Launch with `mp.spawn()` — data auto-moves to shared memory.
- Use `dist.barrier()` to synchronize ranks before evaluation.
- Call `dist.destroy_process_group()` when training finishes.

PyG provides Lightning wrappers for minimal boilerplate:
```python
import lightning as L
from torch_geometric.data import LightningNodeData

datamodule = LightningNodeData(
    data,
    input_train_nodes=data.train_mask,
    input_val_nodes=data.val_mask,
    input_test_nodes=data.test_mask,
    loader='neighbor',
    num_neighbors=[25, 10],
    batch_size=1024,
)

# Use with any Lightning Trainer
trainer = L.Trainer(devices=4, accelerator='gpu', strategy='ddp')
trainer.fit(model, datamodule)
```
Also available: LightningLinkData for link prediction, LightningDataset for graph-level tasks.
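The `model` passed to `trainer.fit()` must be a `LightningModule`; a minimal hypothetical sketch for node classification, slicing to the seed nodes as before:

```python
import lightning as L
import torch
import torch.nn.functional as F
from torch_geometric.nn import GraphSAGE

class NodeClassifier(L.LightningModule):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.gnn = GraphSAGE(in_channels, hidden_channels, num_layers=2,
                             out_channels=out_channels)

    def training_step(self, batch, batch_idx):
        out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
        loss = F.cross_entropy(out, batch.y[:batch.batch_size])
        self.log('train_loss', loss, batch_size=batch.batch_size)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)
```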
PyG supports torch.compile for faster execution:
```python
model = GCN(...)
model = torch.compile(model)

# Works with standard training loops
out = model(data.x, data.edge_index)
```
What works:
Limitations:
- Use `torch.compile(model, dynamic=True)` if batch graph sizes vary significantly.
- Use `num_workers=4` (or more) in data loaders for CPU-side parallelism.
- Use `pin_memory=True` in loaders for faster CPU-to-GPU transfer.
- Use `SparseTensor` from `torch_sparse` instead of `edge_index` for faster message passing on some layers (sketched after the mixed-precision example below).
- Use `torch_geometric.profile` to measure time and memory of individual layers.

Mixed-precision training with `torch.amp` reduces memory use and speeds up training:

```python
from torch.amp import autocast, GradScaler

scaler = GradScaler()

with autocast('cuda'):
    out = model(batch.x, batch.edge_index)
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```
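As referenced in the tips above, switching to a SparseTensor adjacency is a one-line transform; a sketch assuming `torch_sparse` is installed:

```python
import torch_geometric.transforms as T

# Stores the graph as a transposed SparseTensor in data.adj_t (and drops edge_index by default).
data = T.ToSparseTensor()(data)

# Many conv layers (e.g. GCNConv, SAGEConv) accept adj_t in place of edge_index.
out = model(data.x, data.adj_t)
```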
- Start with `num_neighbors=[15, 10]` for 2-layer GNNs.
- Only the first `batch_size` outputs matter — don't compute metrics on sampled-only nodes.