Explaining Graph Neural Networks


Interpreting GNN models is crucial for many use cases. :pyg:`PyG` (2.3 and beyond) provides the :class:`torch_geometric.explain` package for first-class GNN explainability support that currently includes

#. a flexible interface to generate a variety of explanations via the :class:`~torch_geometric.explain.Explainer` class,

#. several underlying explanation algorithms including, e.g., :class:`~torch_geometric.explain.algorithm.GNNExplainer`, :class:`~torch_geometric.explain.algorithm.PGExplainer` and :class:`~torch_geometric.explain.algorithm.CaptumExplainer`,

#. support to visualize explanations via the :class:`~torch_geometric.explain.Explanation` or the :class:`~torch_geometric.explain.HeteroExplanation` class,

#. and metrics to evaluate explanations via the :class:`~torch_geometric.explain.metric` package.

.. warning::

    The explanation APIs discussed here may change in the future as we continuously work to improve their ease-of-use and generalizability.

Explainer Interface

The :class:`torch_geometric.explain.Explainer` class is designed to handle all explainability parameters (see the :class:`~torch_geometric.explain.config.ExplainerConfig` class for more details):

#. which algorithm from the :class:`torch_geometric.explain.algorithm` module to use (e.g., :class:`~torch_geometric.explain.algorithm.GNNExplainer`)

#. the type of explanation to compute, i.e., :obj:`explanation_type="phenomenon"` to explain the underlying phenomenon of a dataset, and :obj:`explanation_type="model"` to explain the prediction of a GNN model (see the `"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" <https://arxiv.org/abs/2206.09677>`_ paper for more details)

#. the different types of masks for nodes and edges (e.g., :obj:`mask="object"` or :obj:`mask="attributes"`)

#. any postprocessing of the masks (e.g., :obj:`threshold_type="topk"` or :obj:`threshold_type="hard"`)

This class lets the user compare different explainability methods and switch between different types of masks, while the high-level framework stays the same. The :class:`~torch_geometric.explain.Explainer` generates an :class:`~torch_geometric.explain.Explanation` or :class:`~torch_geometric.explain.HeteroExplanation` object which contains the final information about which nodes, edges and features are crucial to explain a GNN model.
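
The effect of the mask postprocessing step can be illustrated without any :pyg:`PyG` dependency. The following is a conceptual sketch on plain Python lists, not the library's implementation. The helper names :obj:`hard_threshold` and :obj:`topk_threshold` are hypothetical, and the sketch assumes that :obj:`"hard"` binarizes the mask at a cut-off value while :obj:`"topk"` keeps the :obj:`k` largest scores and zeroes out the rest.

.. code-block:: python

    def hard_threshold(mask, value):
        """Binarize a mask: 1.0 where the score exceeds `value`, else 0.0."""
        return [1.0 if score > value else 0.0 for score in mask]


    def topk_threshold(mask, k):
        """Keep the `k` largest scores and zero out all others."""
        if k >= len(mask):
            return list(mask)
        # Indices of the `k` highest scores (ties are broken by position).
        order = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
        keep = set(order[:k])
        return [score if i in keep else 0.0 for i, score in enumerate(mask)]


    edge_mask = [0.9, 0.1, 0.4, 0.7]  # Illustrative importance scores.
    print(hard_threshold(edge_mask, 0.5))  # [1.0, 0.0, 0.0, 1.0]
    print(topk_threshold(edge_mask, 2))    # [0.9, 0.0, 0.0, 0.7]

Under these assumptions, the hard variant discards the relative ordering of the surviving entries, whereas the top-k variant preserves their soft scores.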

.. note::

    You can read more about the :class:`torch_geometric.explain` package in this `blog post <https://medium.com/@pytorch_geometric/graph-machine-learning-explainability-with-pyg-ff13cffc23c2>`__.

Examples

In what follows, we discuss a few use-cases with corresponding code examples.

Explaining node classification on a homogeneous graph


Assume we have a GNN :obj:`model` that does node classification on a homogeneous graph.
We can use the :class:`torch_geometric.explain.algorithm.GNNExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`.
We configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains (1) a :obj:`node_mask` (indicating which nodes and features are crucial for prediction), and (2) an :obj:`edge_mask` (indicating which edges are crucial for prediction).

.. code-block:: python

    from torch_geometric.data import Data
    from torch_geometric.explain import Explainer, GNNExplainer

    data = Data(...)  # A homogeneous graph data object.

    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='log_probs',  # Model returns log probabilities.
        ),
    )

    # Generate explanation for the node at index `10`:
    explanation = explainer(data.x, data.edge_index, index=10)
    print(explanation.edge_mask)
    print(explanation.node_mask)

Finally, we can visualize both feature importance and the crucial subgraph of the explanation:

.. code-block:: python

    explanation.visualize_feature_importance(top_k=10)

    explanation.visualize_graph()
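
To build intuition for what a feature importance ranking represents, the sketch below works on a plain nested list that stands in for a :obj:`node_mask` of shape :obj:`[num_nodes, num_features]`. It assumes that features are ranked by their mask values summed over all nodes; the helper :obj:`top_features` is hypothetical and not part of the :pyg:`PyG` API.

.. code-block:: python

    def top_features(node_mask, top_k):
        """Rank feature columns by their total mask value across all nodes."""
        num_features = len(node_mask[0])
        # Sum each feature column over all node rows.
        totals = [sum(row[j] for row in node_mask) for j in range(num_features)]
        ranked = sorted(range(num_features), key=lambda j: totals[j], reverse=True)
        return [(j, totals[j]) for j in ranked[:top_k]]


    # Three nodes with three features each (illustrative values only).
    node_mask = [
        [0.5, 0.0, 0.25],
        [0.25, 0.0, 0.5],
        [0.5, 0.25, 0.0],
    ]
    print(top_features(node_mask, top_k=2))  # [(0, 1.25), (2, 0.75)]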

To evaluate the explanation from the :class:`~torch_geometric.explain.algorithm.GNNExplainer`, we can utilize the :class:`torch_geometric.explain.metric` module.
For example, to compute the :meth:`~torch_geometric.explain.metric.unfaithfulness` of an explanation, run:

.. code-block:: python

    from torch_geometric.explain import unfaithfulness

    metric = unfaithfulness(explainer, explanation)
    print(metric)
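
The idea behind such a score can be sketched in plain Python. The sketch assumes the metric compares the model's class probabilities on the full input with those on the masked input, and maps a KL divergence :obj:`kl` between the two to :obj:`1 - exp(-kl)`. The function names, the direction of the divergence, and the example probabilities are illustrative assumptions and do not reproduce the library code.

.. code-block:: python

    import math


    def kl_divergence(p, q):
        """KL(p || q) for two discrete distributions given as lists."""
        return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


    def unfaithfulness_score(full_probs, masked_probs):
        """Map the divergence to [0, 1): 0 means identical predictions."""
        return 1.0 - math.exp(-kl_divergence(full_probs, masked_probs))


    full = [0.7, 0.2, 0.1]    # Prediction on the original input.
    masked = [0.1, 0.2, 0.7]  # Prediction after applying the masks.
    print(unfaithfulness_score(full, full))        # 0.0
    print(unfaithfulness_score(full, masked) > 0)  # True

A value close to zero then indicates that the masked input leads to nearly the same prediction as the original one.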

Explaining node classification on a heterogeneous graph

Assume we have a heterogeneous GNN :obj:`model` that does node classification on a heterogeneous graph. We can use the :class:`IntegratedGradients` attribution method from :captum:`null` `Captum <https://captum.ai/docs/extension/integrated_gradients>`__ via the :class:`torch_geometric.explain.algorithm.CaptumExplainer` algorithm to generate a :class:`~torch_geometric.explain.HeteroExplanation`.

.. note:: :class:`~torch_geometric.explain.algorithm.CaptumExplainer` is a wrapper around the :captum:`null` `Captum <https://captum.ai>`__ library with support for most attribution methods to explain any homogeneous or heterogeneous :pyg:`PyG` model.

We configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.HeteroExplanation` object contains (1) a :obj:`node_mask` for each node type (indicating which nodes and features for each node type are crucial for prediction), and (2) an :obj:`edge_mask` for each edge type (indicating which edges for each edge type are crucial for prediction).

.. code-block:: python

    import torch

    from torch_geometric.data import HeteroData
    from torch_geometric.explain import Explainer, CaptumExplainer

    hetero_data = HeteroData(...)  # A heterogeneous graph data object.

    explainer = Explainer(
        model,  # It is assumed that model outputs a single tensor.
        algorithm=CaptumExplainer('IntegratedGradients'),
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',  # Node classification, as described above.
            return_type='probs',  # Model returns probabilities.
        ),
    )

    # Generate batch-wise heterogeneous explanations for
    # the nodes at index `1` and `3`:
    hetero_explanation = explainer(
        hetero_data.x_dict,
        hetero_data.edge_index_dict,
        index=torch.tensor([1, 3]),
    )
    print(hetero_explanation.edge_mask_dict)
    print(hetero_explanation.node_mask_dict)
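
Because the masks are returned as one entry per type, a common follow-up is to rank the edges within each edge type. The sketch below uses a plain dictionary of lists as a stand-in for :obj:`edge_mask_dict`. The edge type names, the mask values and the helper :obj:`top_edges_per_type` are hypothetical and only illustrate the per-type layout.

.. code-block:: python

    def top_edges_per_type(edge_mask_dict, k):
        """Return the positions of the `k` highest-scoring edges per edge type."""
        result = {}
        for edge_type, mask in edge_mask_dict.items():
            ranked = sorted(range(len(mask)), key=lambda i: mask[i], reverse=True)
            result[edge_type] = ranked[:k]
        return result


    # Stand-in for `hetero_explanation.edge_mask_dict` (illustrative values).
    edge_mask_dict = {
        ('paper', 'cites', 'paper'): [0.1, 0.8, 0.3],
        ('author', 'writes', 'paper'): [0.6, 0.2],
    }
    print(top_edges_per_type(edge_mask_dict, k=1))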

Explaining graph regression on a homogeneous graph


Assume we have a GNN :obj:`model` that does graph regression on a homogeneous graph.
We can use the :class:`torch_geometric.explain.algorithm.PGExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`.
We configure the :class:`~torch_geometric.explain.Explainer` to use an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains an :obj:`edge_mask` (indicating which edges are crucial for prediction).
Importantly, passing a :obj:`node_mask_type` to the :class:`~torch_geometric.explain.Explainer` will throw an error since :class:`~torch_geometric.explain.algorithm.PGExplainer` cannot explain the importance of nodes:

.. code-block:: python

    from torch_geometric.explain import Explainer, PGExplainer
    from torch_geometric.loader import DataLoader

    dataset = ...
    loader = DataLoader(dataset, batch_size=1, shuffle=True)

    explainer = Explainer(
        model=model,
        algorithm=PGExplainer(epochs=30, lr=0.003),
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=dict(
            mode='regression',
            task_level='graph',
            return_type='raw',
        ),
        # Include only the top 10 most important edges:
        threshold_config=dict(threshold_type='topk', value=10),
    )

    # PGExplainer needs to be trained separately since it is a parametric
    # explainer, i.e., it uses a neural network to generate explanations:
    for epoch in range(30):
        for batch in loader:
            loss = explainer.algorithm.train(
                epoch, model, batch.x, batch.edge_index, target=batch.target)

    # Generate the explanation for a particular graph. The `target` has to
    # be passed explicitly when explaining a phenomenon:
    explanation = explainer(dataset[0].x, dataset[0].edge_index,
                            target=dataset[0].target)
    print(explanation.edge_mask)

Since this feature is still undergoing heavy development, please feel free to reach out to the :pyg:`PyG` core team either on :github:`null` `GitHub <https://github.com/pyg-team/pytorch_geometric/discussions>`_ or :slack:`null` `Slack <https://data.pyg.org/slack.html>`_ if you have any questions, comments or concerns.