Back to Pytorch Geometric

Memory-Efficient Aggregations

docs/source/advanced/sparse_tensor.rst

2.7.08.8 KB
Original Source

Memory-Efficient Aggregations

The :class:~torch_geometric.nn.conv.message_passing.MessagePassing interface of :pyg:PyG relies on a gather-scatter scheme to aggregate messages from neighboring nodes. For example, consider the message passing layer

.. math::

\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),

that can be implemented as:

.. code-block:: python

from torch_geometric.nn import MessagePassing

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        return MLP(x_j - x_i)

Under the hood, the :class:~torch_geometric.nn.conv.message_passing.MessagePassing implementation produces a code that looks as follows:

.. code-block:: python

from torch_geometric.utils import scatter

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

msg = MLP(x_j - x_i)  # Compute message for each edge

# Aggregate messages based on target node indices
out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

While the gather-scatter formulation generalizes to a lot of useful GNN implementations, it has the disadvantage of explicitly materalizing :obj:x_j and :obj:x_i, resulting in a high memory footprint on large and dense graphs.

Luckily, not all GNNs need to be implemented by explicitly materalizing :obj:x_j and/or :obj:x_i. In some cases, GNNs can also be implemented as a simple-sparse matrix multiplication. As a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:x_i or multi-dimensional edge features when computing messages. For example, the :class:~torch_geometric.nn.conv.GINConv layer

.. math::

\mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

is equivalent to computing

.. math::

\mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),

where :math:\mathbf{A} denotes a sparse adjacency matrix of shape :obj:[num_nodes, num_nodes]. This formulation allows to leverage dedicated and fast sparse-matrix multiplication implementations.

In :pyg:null PyG >= 1.6.0, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a lower memory footprint and a faster execution time. As a result, we introduce the :class:SparseTensor class (from the :obj:torch_sparse package), which implements fast forward and backward passes for sparse-matrix multiplication based on the "Design Principles for Sparse Matrix Multiplication on the GPU" <https://arxiv.org/abs/1803.08601>_ paper.

Using the :class:SparseTensor class is straightforward and similar to the way :obj:scipy treats sparse matrices:

.. code-block:: python

from torch_sparse import SparseTensor

adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                   sparse_sizes=(num_nodes, num_nodes))
# value is optional and can be None

# Obtain different representations (COO, CSR, CSC):
row,    col, value = adj.coo()
rowptr, col, value = adj.csr()
colptr, row, value = adj.csc()

adj = adj[:100, :100]  # Slicing, indexing and masking support
adj = adj.set_diag()   # Add diagonal entries
adj_t = adj.t()        # Transpose
out = adj.matmul(x)    # Sparse-dense matrix multiplication
adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

# Creating SparseTensor instances:
adj = SparseTensor.from_dense(mat)
adj = SparseTensor.eye(100, 100)
adj = SparseTensor.from_scipy(mat)

Our :class:~torch_geometric.nn.conv.message_passing.MessagePassing interface can handle both :obj:torch.Tensor and :class:SparseTensor as input for propagating messages. However, when holding a directed graph in :class:SparseTensor, you need to make sure to input the transposed sparse matrix to :func:~torch_geometric.nn.conv.message_passing.MessagePassing.propagate:

.. code-block:: python

conv = GCNConv(16, 32)
out1 = conv(x, edge_index)
out2 = conv(x, adj.t())
assert torch.allclose(out1, out2)

conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))
out1 = conv(x, edge_index)
out2 = conv(x, adj.t())
assert torch.allclose(out1, out2)

To leverage sparse-matrix multiplications, the :class:~torch_geometric.nn.conv.message_passing.MessagePassing interface introduces the :func:~torch_geometric.nn.conv.message_passing.message_and_aggregate function (which fuses the :func:~torch_geometric.nn.conv.message_passing.message and :func:~torch_geometric.nn.conv.message_passing.aggregate functions into a single computation step), which gets called whenever it is implemented and receives a :class:SparseTensor as input for :obj:edge_index. With it, the :class:~torch_geometric.nn.conv.GINConv layer can now be implemented as follows:

.. code-block:: python

import torch_sparse

class GINConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x)
        return MLP((1 + eps) x + out)

    def message(self, x_j):
        return x_j

    def message_and_aggregate(self, adj_t, x):
        return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

Playing around with the new :class:SparseTensor format is straightforward since all of our GNNs work with it out-of-the-box. To convert the :obj:edge_index format to the newly introduced :class:SparseTensor format, you can make use of the :class:torch_geometric.transforms.ToSparseTensor transform:

.. code-block:: python

import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
data = dataset[0]
>>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

    def forward(self, x, adj_t):
        x = self.conv1(x, adj_t)
        x = F.relu(x)
        x = self.conv2(x, adj_t)
        return F.log_softmax(x, dim=1)

model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)
    loss = F.nll_loss(out, data.y)
    loss.backward()
    optimizer.step()
    return float(loss)

for epoch in range(1, 201):
    loss = train(data)

All code remains the same as before, except for the :obj:data transform via :obj:T.ToSparseTensor(). As an additional advantage, :class:~torch_geometric.nn.conv.message_passing.MessagePassing implementations that utilize the :class:SparseTensor class are deterministic on the GPU since aggregations no longer rely on atomic operations.

Notably, the GNN layer execution slightly changes in case GNNs incorporate single or multi-dimensional edge information :obj:edge_weight or :obj:edge_attr into their message passing formulation, respectively. In particular, it is now expected that these attributes are directly added as values to the :class:SparseTensor object. Instead of calling the GNN as

.. code-block:: python

conv = GMMConv(16, 32, dim=3)
out = conv(x, edge_index, edge_attr)

we now execute our GNN operator as

.. code-block:: python

conv = GMMConv(16, 32, dim=3)
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
out = conv(x, adj.t())

.. note::

Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.
You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:

.. code-block:: python

    row, col, edge_attr = adj_t.t().coo()
    edge_index = torch.stack([row, col], dim=0)

Please let us know what you think of :class:SparseTensor, how we can improve it, and whenever you encounter any unexpected behavior.