docs/source/advanced/sparse_tensor.rst
The :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes.
For example, consider the message passing layer

.. math::

    \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),

that can be implemented as:

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    class MyConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            return MLP(x_j - x_i)

Under the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces code that looks as follows:

.. code-block:: python

    from torch_geometric.utils import scatter

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
    x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

    msg = MLP(x_j - x_i)  # Compute message for each edge

    # Aggregate messages based on target node indices
    out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

While the gather-scatter formulation generalizes to a lot of useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs.

Luckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`.
In some cases, GNNs can also be implemented as a simple sparse matrix multiplication.
As a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages.
For example, the :class:`~torch_geometric.nn.conv.GINConv` layer

.. math::

    \mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

is equivalent to computing

.. math::

    \mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),

where :math:`\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`.
This formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations.
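
To make this equivalence concrete, here is a minimal sketch on a hypothetical toy graph (the graph, feature sizes, and variable names are chosen only for illustration), verifying that the scatter-based neighborhood sum matches a sparse-dense matrix multiplication with :math:`\mathbf{A}`:

.. code-block:: python

    import torch
    from torch_geometric.utils import scatter

    # Hypothetical toy graph with 3 nodes and 4 directed edges:
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    x = torch.randn(3, 8)

    # Neighborhood sum via gather-scatter:
    out1 = scatter(x[edge_index[0]], edge_index[1], dim=0, dim_size=3, reduce='sum')

    # The same aggregation as a sparse-dense matrix multiplication A @ X,
    # where A[i, j] = 1 if and only if the edge j -> i exists:
    A = torch.sparse_coo_tensor(
        indices=torch.stack([edge_index[1], edge_index[0]], dim=0),
        values=torch.ones(edge_index.size(1)),
        size=(3, 3),
    )
    out2 = torch.sparse.mm(A, x)

    assert torch.allclose(out1, out2)
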
In :pyg:`PyG` >= 1.6.0, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a lower memory footprint and a faster execution time.
As a result, we introduce the :class:`SparseTensor` class (from the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `"Design Principles for Sparse Matrix Multiplication on the GPU" <https://arxiv.org/abs/1803.08601>`_ paper.

Using the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices:

.. code-block:: python

    from torch_sparse import SparseTensor

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                       sparse_sizes=(num_nodes, num_nodes))
    # value is optional and can be None

    # Obtain different representations (COO, CSR, CSC):
    row, col, value = adj.coo()
    rowptr, col, value = adj.csr()
    colptr, row, value = adj.csc()

    adj = adj[:100, :100]  # Slicing, indexing and masking support
    adj = adj.set_diag()   # Add diagonal entries
    adj_t = adj.t()        # Transpose
    out = adj.matmul(x)    # Sparse-dense matrix multiplication
    adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

    # Creating SparseTensor instances:
    adj = SparseTensor.from_dense(mat)
    adj = SparseTensor.eye(100, 100)
    adj = SparseTensor.from_scipy(mat)

Our :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating messages.
However, when holding a directed graph in :class:`SparseTensor`, you need to make sure to input the transposed sparse matrix to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`:

.. code-block:: python

    conv = GCNConv(16, 32)

    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

    conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))

    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

To leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` function (which fuses the :func:`~torch_geometric.nn.conv.message_passing.message` and :func:`~torch_geometric.nn.conv.message_passing.aggregate` functions into a single computation step), which gets called whenever it is implemented and receives a :class:`SparseTensor` as input for :obj:`edge_index`.
With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows:

.. code-block:: python

    import torch_sparse

    class GINConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            out = self.propagate(edge_index, x=x)
            return MLP((1 + eps) * x + out)

        def message(self, x_j):
            return x_j

        def message_and_aggregate(self, adj_t, x):
            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

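As a quick sanity check, both execution paths should yield the same result, since :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` is only invoked when a :class:`SparseTensor` is passed to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`.
A minimal sketch (assuming :obj:`x`, :obj:`edge_index` and the :obj:`adj` built from them above, and that :obj:`MLP` and :obj:`eps` are defined):

.. code-block:: python

    conv = GINConv()

    out1 = conv(x, edge_index)  # Gather-scatter path via message() and aggregate()
    out2 = conv(x, adj.t())     # Fused sparse-matmul path via message_and_aggregate()
    assert torch.allclose(out1, out2)
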
Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.
To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
    data = dataset[0]
    >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)

    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
            self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

        def forward(self, x, adj_t):
            x = self.conv1(x, adj_t)
            x = F.relu(x)
            x = self.conv2(x, adj_t)
            return F.log_softmax(x, dim=1)

    model = GNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(data):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.adj_t)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        return float(loss)

    for epoch in range(1, 201):
        loss = train(data)

All code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`.
As an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations.
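
A minimal sketch to observe this (assuming the :obj:`model` and :obj:`data` from the example above, moved to a CUDA device; the check is only illustrative):

.. code-block:: python

    model = model.cuda()
    data = data.to('cuda')

    # Two identical forward passes over the SparseTensor-based graph are expected
    # to produce bitwise-identical results, since no atomic operations are involved:
    out1 = model(data.x, data.adj_t)
    out2 = model(data.x, data.adj_t)
    assert torch.equal(out1, out2)
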
Notably, the GNN layer execution slightly changes in case GNNs incorporate one-dimensional edge weights :obj:`edge_weight` or multi-dimensional edge features :obj:`edge_attr` into their message passing formulation.
In particular, it is now expected that these attributes are directly added as values to the :class:`SparseTensor` object.
Instead of calling the GNN as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    out = conv(x, edge_index, edge_attr)

we now execute our GNN operator as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
    out = conv(x, adj.t())


.. note::

    Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.
    You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:

    .. code-block:: python

        row, col, edge_attr = adj_t.t().coo()
        edge_index = torch.stack([row, col], dim=0)

Please let us know what you think of :class:`SparseTensor`, how we can improve it, and if you encounter any unexpected behavior.