Explainable GNNs (XGNN)

Interpretability for Graph Neural Networks

SuNaAI Lab

Technical Guide Series


Chapter 1: Explaining Graph Predictions

Making Graph Neural Networks interpretable and trustworthy

Graph Neural Networks excel at learning from relational data, but understanding why they make a specific prediction is challenging. Explainability methods for GNNs must handle both node-level predictions (e.g., labeling a user in a social network) and graph-level predictions (e.g., classifying a whole molecule).

The XGNN Challenge

Graph structures add complexity to interpretability. We need to identify which subgraphs, nodes, or edges are most important for predictions, while respecting the relational structure.

Use Cases

  • Molecular property prediction in drug discovery
  • Social network analysis and recommendation
  • Traffic prediction and urban planning
  • Knowledge graph reasoning
  • Fraud detection in financial networks
  • Brain connectivity analysis

Chapter 2: GNN Architecture Overview

Understanding the building blocks of Graph Neural Networks

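Graph Neural Networks learn node representations by aggregating information from each node's neighborhood. Because a node's neighbors have no natural ordering, the aggregation function must be permutation invariant: sum, mean, and max all give the same result regardless of the order in which neighbors are visited.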

Message Passing in GNNs

General Message Passing
# At each layer l:
# 1. Create a message from each neighbor u of node v
m_uv = f(h_v^(l), h_u^(l), e_uv)   for u in N(v)

# 2. Aggregate messages with a permutation-invariant function
m_v = AGG({m_uv : u in N(v)})

# 3. Update node representation
h_v^(l+1) = UPDATE(h_v^(l), m_v)

Where:
- h_v^(l): features of node v at layer l
- e_uv: features of edge (u, v)
- N(v): neighbors of node v
- f, UPDATE: learnable functions (e.g., small MLPs)
- AGG: aggregation function (mean, sum, max, attention)
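The three steps can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not a library API: the helper name `message_passing_layer`, a linear message function that ignores edge features e_uv, sum aggregation, and a ReLU update are all choices made for brevity.

```python
import numpy as np

def message_passing_layer(h, edges, W_msg, W_self):
    """One round of message passing (hypothetical helper, illustrative only).

    h      : (num_nodes, d_in) node features h_v^(l)
    edges  : directed (u, v) pairs; u sends a message to v
    W_msg  : (d_in, d_out) weights of the message function f (edge features ignored)
    W_self : (d_in, d_out) weights applied to the node's own state in UPDATE
    """
    agg = np.zeros((h.shape[0], W_msg.shape[1]))
    # Steps 1 and 2: message m_uv = h_u @ W_msg, summed at the receiving node v
    for u, v in edges:
        agg[v] += h[u] @ W_msg
    # Step 3: UPDATE combines the node's own state with the aggregated message (ReLU)
    return np.maximum(0.0, h @ W_self + agg)

# Undirected path graph 0 - 1 - 2, stored as edges in both directions
h = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
I = np.eye(2)
out = message_passing_layer(h, edges, I, I)
# Sum aggregation is permutation invariant: reversing the edge order changes nothing
out_shuffled = message_passing_layer(h, edges[::-1], I, I)
```

Node 1 receives messages from both endpoints, so its output reflects two neighbors while the endpoints each reflect one.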

Popular GNN Architectures

GCN

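Graph Convolutional Network: a first-order approximation of spectral graph convolutions

Formula: H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l)), where Ã = A + I adds self-loops and D̃ is the degree matrix of Ã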
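A dense NumPy sketch of the GCN propagation rule, using the standard formulation with self-loops (Ã = A + I, D̃ its degree matrix) and assuming ReLU for σ. Real implementations use sparse operations; `gcn_layer` is a hypothetical helper name.

```python
import numpy as np

def gcn_layer(A, H, W):
    """Dense GCN layer: ReLU(D~^(-1/2) A~ D~^(-1/2) H W) with A~ = A + I."""
    A_tilde = A + np.eye(A.shape[0])            # add self-loops
    deg = A_tilde.sum(axis=1)                   # node degrees in A~
    D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))    # D~^(-1/2)
    A_hat = D_inv_sqrt @ A_tilde @ D_inv_sqrt   # symmetric normalization
    return np.maximum(0.0, A_hat @ H @ W)

# Path graph 0 - 1 - 2 with a single constant feature
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H = np.ones((3, 1))
W = np.ones((1, 1))
out = gcn_layer(A, H, W)
```

With constant features each output equals a row sum of the normalized adjacency: 1/3 + 2/√6 for the middle node and 1/2 + 1/√6 for each end. Without normalization the sums would be 3 and 2, so the normalization dampens the effect of node degree.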

GAT

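Graph Attention Network: learns a separate weight for each neighbor

Formula: h_i^(l+1) = σ(Σ_{j∈N(i)} α_ij W h_j^(l)), where α_ij = softmax_j(LeakyReLU(a^T [W h_i^(l) || W h_j^(l)])) and N(i) includes i itself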
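A single-head GAT layer in NumPy, showing how attention coefficients are computed: a raw score e_ij = LeakyReLU(a^T [W h_i || W h_j]) for each neighbor j (self-loop included), normalized with a softmax over the neighborhood. The helper names, ELU as σ, and the 0.2 LeakyReLU slope are assumptions for this sketch.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(h, neighbors, W, a):
    """Single-head GAT layer (illustrative sketch).

    h         : (N, d_in) node features
    neighbors : dict i -> list of neighbor indices j, including i itself (self-loop)
    W         : (d_in, d_out) shared linear transform
    a         : (2 * d_out,) attention vector
    Returns the new features and alpha[i][j], the attention node i pays to j.
    """
    z = h @ W                                       # W h_j for every node
    out = np.zeros((h.shape[0], W.shape[1]))
    alpha = {}
    for i, nbrs in neighbors.items():
        # Raw scores e_ij = LeakyReLU(a^T [W h_i || W h_j])
        e = np.array([leaky_relu(a @ np.concatenate([z[i], z[j]])) for j in nbrs])
        w = np.exp(e - e.max())
        w /= w.sum()                                # softmax over N(i)
        alpha[i] = dict(zip(nbrs, w))
        agg = w @ z[nbrs]                           # sum_j alpha_ij W h_j
        out[i] = np.where(agg > 0, agg, np.expm1(agg))  # ELU as sigma
    return out, alpha

rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))
W = rng.normal(size=(4, 2))
a = rng.normal(size=4)
neighbors = {0: [0, 1], 1: [0, 1, 2], 2: [1, 2]}    # path graph plus self-loops
out, alpha = gat_layer(h, neighbors, W, a)
```

The `alpha` dictionary is exactly what attention-based explanations inspect. With an all-zero attention vector every score is equal and attention becomes uniform over each neighborhood.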

GraphSAGE

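Samples a fixed-size set of neighbors for each node and aggregates their features (e.g., mean, LSTM, or pooling aggregators)

Formula: h_v^(l+1) = σ(W^(l) · CONCAT(h_v^(l), AGG({h_u^(l) : u ∈ S(v)}))), where S(v) is the sampled neighborhood of v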
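A NumPy sketch of one GraphSAGE layer with the mean aggregator, sampling a fixed number of neighbors S(v) uniformly per node and concatenating the node's own features with the neighbor mean. The helper name `sage_layer`, ReLU as σ, sampling with replacement for low-degree nodes, and the omission of GraphSAGE's output normalization are all assumptions; it also assumes every node has at least one neighbor.

```python
import numpy as np

def sage_layer(h, adj_list, W, num_samples, rng):
    """GraphSAGE-mean layer with uniform neighbor sampling (illustrative sketch).

    h           : (N, d_in) node features
    adj_list    : list of neighbor lists (every node assumed to have >= 1 neighbor)
    W           : (2 * d_in, d_out) applied to [h_v || mean of sampled neighbors]
    num_samples : fixed neighborhood size |S(v)|
    """
    out = np.zeros((h.shape[0], W.shape[1]))
    for v, nbrs in enumerate(adj_list):
        # Sample with replacement only when v has fewer neighbors than num_samples
        replace = len(nbrs) < num_samples
        sampled = rng.choice(nbrs, size=num_samples, replace=replace)
        h_neigh = h[sampled].mean(axis=0)
        out[v] = np.maximum(0.0, np.concatenate([h[v], h_neigh]) @ W)
    return out

rng = np.random.default_rng(0)
adj_list = [[1, 2, 3], [0], [0], [0]]   # star centered at node 0
h = rng.normal(size=(4, 3))
W = rng.normal(size=(6, 2))
out = sage_layer(h, adj_list, W, num_samples=2, rng=rng)
```

Fixing the sample size bounds the cost per node regardless of its true degree, which is what lets GraphSAGE scale to large graphs.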

GIN

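Graph Isomorphism Network: as expressive as the 1-WL graph isomorphism test, the upper bound for message-passing GNNs

Formula: h_v^(l+1) = MLP^(l)((1 + ε^(l)) h_v^(l) + Σ_{u∈N(v)} h_u^(l))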
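The sum in the GIN update matters for expressiveness. The NumPy sketch below (hypothetical `gin_layer`, a two-layer ReLU MLP, ε = 0) compares a node with two identical neighbors to a node with one. A mean aggregator would give both the same neighbor summary, while the sum keeps them apart.

```python
import numpy as np

def gin_layer(h, A, eps, W1, W2):
    """GIN update: MLP((1 + eps) * h_v + sum_{u in N(v)} h_u) with a 2-layer ReLU MLP."""
    x = (1.0 + eps) * h + A @ h          # A @ h sums neighbor features
    return np.maximum(0.0, x @ W1) @ W2

W1 = W2 = np.eye(1)
# Center of a 3-node star (2 neighbors) vs. an endpoint of a single edge (1 neighbor),
# with every node carrying the same feature 1.0
A_star = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
A_edge = np.array([[0, 1], [1, 0]], dtype=float)
center_star = gin_layer(np.ones((3, 1)), A_star, 0.0, W1, W2)[0]
center_edge = gin_layer(np.ones((2, 1)), A_edge, 0.0, W1, W2)[0]
```

The star center gets 3 and the edge endpoint gets 2. With mean aggregation both would see a neighbor mean of 1 and produce the same representation.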

Chapter 3: XGNN Methods

Techniques for explaining GNN predictions

Method Categories

1. Gradient-Based Methods

Use gradients to identify important nodes/edges for predictions.

Examples: Saliency maps, Grad-CAM, Integrated Gradients
Pros: Fast (one or a few backward passes)
Cons: Require a differentiable model; provide only local, first-order importance and can be noisy
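The idea behind gradient saliency can be shown with the derivative of one node's score with respect to each edge weight. In practice you would compute these gradients with autograd, for example `torch.autograd.grad`. This dependency-free sketch uses finite differences instead, on an assumed toy linear model score = (A + I) X w. The helper names are hypothetical.

```python
import numpy as np

def predict(A, X, w):
    """Toy linear model (assumed for illustration): score_v = ((A + I) X w)_v."""
    return (A + np.eye(A.shape[0])) @ X @ w

def edge_saliency(A, X, w, target, h=1e-5):
    """Estimate d score[target] / d A[u, v] for every existing edge by finite differences."""
    base = predict(A, X, w)[target]
    saliency = {}
    for u, v in zip(*np.nonzero(A)):
        A_pert = A.copy()
        A_pert[u, v] += h                      # nudge one edge weight
        saliency[(int(u), int(v))] = (predict(A_pert, X, w)[target] - base) / h
    return saliency

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)    # path graph 0 - 1 - 2
X = np.array([[1.0], [2.0], [3.0]])
w = np.array([1.0])
sal = edge_saliency(A, X, w, target=1)
```

Only the entries in node 1's row, (1, 0) and (1, 2), get nonzero saliency. Their values equal the neighbors' contributions, 1 and 3. This is the "local importance" limitation: the gradient describes the effect of an infinitesimal change around the current input, not the effect of deleting an edge outright.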

2. Perturbation-Based Methods

Remove or mask parts of the graph and observe prediction changes.

Examples: GNNExplainer, PGExplainer
Pros: Model-agnostic, interpretable
Cons: Computationally expensive
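The simplest perturbation method is occlusion: remove one edge at a time and measure how much the predicted class probability drops. This sketch assumes a toy one-layer GCN with softmax output; the helper names are hypothetical. It needs one forward pass per edge, which is why perturbation methods get expensive on large graphs.

```python
import numpy as np

def gcn_prob(A, X, W):
    """Toy 1-layer GCN with softmax output (assumed model for illustration)."""
    A_t = A + np.eye(len(A))
    d = A_t.sum(axis=1)
    A_hat = A_t / np.sqrt(np.outer(d, d))      # D~^(-1/2) A~ D~^(-1/2)
    logits = A_hat @ X @ W
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def occlusion_scores(A, X, W, node):
    """Importance of each undirected edge = drop in the predicted class probability when it is removed."""
    p = gcn_prob(A, X, W)[node]
    cls = p.argmax()
    scores = {}
    for u, v in zip(*np.nonzero(np.triu(A))):  # each undirected edge once
        A_pert = A.copy()
        A_pert[u, v] = A_pert[v, u] = 0.0
        scores[(int(u), int(v))] = p[cls] - gcn_prob(A_pert, X, W)[node, cls]
    return scores

rng = np.random.default_rng(0)
A = np.zeros((4, 4))
for u, v in [(0, 1), (1, 2), (2, 3)]:          # path graph 0 - 1 - 2 - 3
    A[u, v] = A[v, u] = 1.0
X = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))
scores = occlusion_scores(A, X, W, node=0)
```

With a single layer, node 0 only sees its 1-hop neighborhood, so removing the distant edge (2, 3) leaves its prediction unchanged and scores exactly 0.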

3. Attention-Based Methods

Leverage built-in attention mechanisms in attention-based GNNs.

Examples: GAT attention weights, Multi-head attention
Pros: No additional computation needed
Cons: Only for attention-based models

4. Surrogate Methods

Train simpler interpretable models to approximate GNN behavior.

Examples: GraphLIME, PGM-Explainer, tree-based global surrogates
Pros: Highly interpretable
Cons: Approximation may be inaccurate
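As a minimal illustration of the surrogate idea, the sketch below fits a linear model to a GNN's node scores. The inputs are each node's own features plus the mean of its neighbors' features. The coefficients serve as the interpretable explanation, and R² measures how faithful the surrogate is. The "GNN" here is an assumed stand-in function, and the helper names are hypothetical.

```python
import numpy as np

def surrogate_features(A, X):
    """Interpretable inputs for node v: [x_v, mean of neighbor features, 1]."""
    deg = np.maximum(A.sum(axis=1, keepdims=True), 1.0)   # avoid division by zero
    return np.hstack([X, (A @ X) / deg, np.ones((len(X), 1))])

def fit_linear_surrogate(A, X, gnn_scores):
    """Least-squares fit of the GNN's scores; returns coefficients and R^2 (surrogate fidelity)."""
    Z = surrogate_features(A, X)
    beta, *_ = np.linalg.lstsq(Z, gnn_scores, rcond=None)
    resid = gnn_scores - Z @ beta
    r2 = 1.0 - (resid @ resid) / np.sum((gnn_scores - gnn_scores.mean()) ** 2)
    return beta, r2

rng = np.random.default_rng(1)
A = np.triu((rng.random((30, 30)) < 0.2).astype(float), 1)
A = A + A.T                                    # random undirected graph
X = rng.normal(size=(30, 2))
gnn_scores = np.tanh(X[:, 0] + (A @ X)[:, 1])  # stand-in for a trained GNN's output
beta, r2 = fit_linear_surrogate(A, X, gnn_scores)
```

Here the stand-in model is nonlinear and uses sum rather than mean aggregation, so R² stays below 1. This is the "approximation may be inaccurate" caveat in concrete form.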

Chapter 4: Attention-Based Explanations

Using attention to understand GNN decisions

For GAT and other attention-based GNNs, the attention weights themselves can provide explanations. These weights indicate how much each neighbor influences a node's representation.

Attention-Based Explanation
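import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GATWithAttention(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        # 4 heads with concatenated outputs -> hidden_dim * 4 features
        self.gat1 = GATConv(num_features, hidden_dim, heads=4)
        self.gat2 = GATConv(hidden_dim * 4, num_classes, heads=1)

    def forward(self, x, edge_index):
        # First layer; also returns (edge_index, alpha) for the explanation
        x, attention_weights = self.gat1(x, edge_index, return_attention_weights=True)
        x = F.elu(x)
        x = F.dropout(x, training=self.training)

        # Second layer
        x = self.gat2(x, edge_index)

        return F.log_softmax(x, dim=1), attention_weights

# Extract attention explanations
# (assumes `model` is a trained GATWithAttention and x, edge_index describe the graph)
model.eval()
with torch.no_grad():
    log_probs, (att_edge_index, attention) = model(x, edge_index)

# attention: shape [num_edges_with_self_loops, 4], one weight per edge and head.
# att_edge_index includes the self-loops GATConv adds; row 0 = source, row 1 = target.
i = 0  # node to explain
incoming = att_edge_index[1] == i
node_i_neighbors = att_edge_index[0][incoming]
node_i_attention = attention[incoming]  # shape [in-degree of i + 1 self-loop, 4]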

Visualizing Attention

Attention Heatmaps

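Attention weights can be visualized as heatmaps showing which neighbors each node attends to most. A higher weight means that neighbor contributes more to the node's aggregated representation at that layer. It does not by itself prove the edge is decisive for the final prediction, so attention-based explanations should be cross-checked with other methods.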
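To draw such a heatmap, the per-edge output (edge_index, alpha) from GATConv with `return_attention_weights=True` can be scattered into a dense matrix. This sketch assumes the tensors were converted with `.cpu().numpy()`, that heads are averaged, and that there are no duplicate edges. `attention_matrix` is a hypothetical helper, and the weights below are made up so that each target's incoming weights sum to 1 per head.

```python
import numpy as np

def attention_matrix(edge_index, alpha, num_nodes):
    """Dense [target, source] matrix of head-averaged attention weights.

    edge_index : (2, E) array, row 0 = source j, row 1 = target i
    alpha      : (E, heads) attention weights
    """
    M = np.zeros((num_nodes, num_nodes))
    M[edge_index[1], edge_index[0]] = alpha.mean(axis=1)   # average over heads
    return M

# Example: 3-node path with self-loops and 2 heads (illustrative weights)
edge_index = np.array([[0, 1, 1, 2, 0, 1, 2],    # sources j
                       [1, 0, 2, 1, 0, 1, 2]])   # targets i
alpha = np.array([[0.50, 0.20],
                  [0.40, 0.50],
                  [0.30, 0.50],
                  [0.25, 0.30],
                  [0.60, 0.50],
                  [0.25, 0.50],
                  [0.70, 0.50]])
M = attention_matrix(edge_index, alpha, num_nodes=3)
```

Row i of `M` shows how node i distributes its attention, so each row sums to 1. The matrix can be passed directly to a heatmap function such as matplotlib's `imshow`.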

Chapter 5: Important Subgraph Extraction

Identifying causal subgraphs for predictions

Many XGNN methods identify important subgraphs: small, usually connected parts of the input graph that are crucial for the prediction. These provide interpretable, local explanations.

GNNExplainer

GNNExplainer Implementation
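from torch_geometric.explain import Explainer, GNNExplainer

# Trained GNN model (placeholder; assumed to be a node classifier whose
# forward(x, edge_index) returns log-probabilities)
model = trained_gnn_model()

# Create explainer (PyG >= 2.3 `torch_geometric.explain` API)
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=100),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Explain a specific node
node_idx = 100  # Node to explain
explanation = explainer(x, edge_index, index=node_idx)

# explanation contains:
# - node_mask: importance of each node feature, shape [num_nodes, num_features]
# - edge_mask: importance of each edge, shape [num_edges]
subgraph = explanation.get_explanation_subgraph()  # nodes/edges with non-zero mask

# Visualize
explanation.visualize_graph('subgraph.png')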

Optimization Objective

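GNNExplainer finds a subgraph G_S that maximizes the mutual information with the model's prediction Y:

max_G_S MI(Y, (G_S, X_S)) = H(Y) - H(Y | G = G_S, X = X_S)

For a trained model H(Y) is constant, so this is equivalent to minimizing the conditional entropy H(Y | G = G_S, X = X_S): the subgraph should make the original prediction as certain as possible. GNNExplainer relaxes G_S to a continuous edge mask and adds penalties on mask size and mask entropy. These keep the explanation small and close to a discrete subgraph.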
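The relaxed objective can be written out directly. This NumPy sketch evaluates a GNNExplainer-style loss for a given edge mask on an assumed toy one-layer model. The prediction term -log P(class | masked graph) stands in for the conditional entropy, and the size and entropy penalties use illustrative coefficient values. It only evaluates the loss; GNNExplainer itself minimizes this kind of loss over the mask by gradient descent. The helper name is hypothetical.

```python
import numpy as np

def explainer_loss(edge_mask, A, X, W, node, target_class, size_coef=0.005, ent_coef=1.0):
    """GNNExplainer-style loss for explaining one node's prediction (lower is better).

    edge_mask : (N, N) values in [0, 1] that scale the adjacency matrix A
    The prediction term stands in for H(Y | G = G_S, X = X_S); the size and
    entropy terms keep the mask small and close to 0/1.
    """
    logits = ((X + (A * edge_mask) @ X) @ W)[node]   # toy 1-layer sum-aggregation model
    z = logits - logits.max()
    log_p = z - np.log(np.exp(z).sum())               # log-softmax
    pred_loss = -log_p[target_class]
    m = np.clip(edge_mask[A > 0], 1e-12, 1 - 1e-12)   # mask values on existing edges
    size_loss = size_coef * m.sum()
    ent_loss = ent_coef * np.mean(-m * np.log(m) - (1 - m) * np.log(1 - m))
    return pred_loss + size_loss + ent_loss

# Node 0 has neighbors 1 and 2; only node 1's features support class 0
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
X = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 0.0]])
W = np.eye(2)
keep_01 = np.zeros((3, 3))
keep_01[0, 1] = keep_01[1, 0] = 1.0
keep_02 = np.zeros((3, 3))
keep_02[0, 2] = keep_02[2, 0] = 1.0
full = np.ones((3, 3))
loss_01 = explainer_loss(keep_01, A, X, W, node=0, target_class=0)
loss_02 = explainer_loss(keep_02, A, X, W, node=0, target_class=0)
loss_full = explainer_loss(full, A, X, W, node=0, target_class=0)
```

The mask that keeps the decisive edge 0-1 beats the one that keeps 0-2. It also beats the full graph, which makes the same prediction but pays a larger size penalty.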

Chapter 6: Counterfactual Explanations

"What would happen if we changed the graph?"

Counterfactual explanations show minimal changes needed to flip a prediction. For graphs, this means adding/removing edges or changing node features.

Generating Counterfactuals

Counterfactual Generation
# Find minimal changes to flip prediction
# NOTE: `explainability.CounterfactualGNN` and the `model(graph, node_idx=...)` call are
# placeholder interfaces, not a published package; plug in a concrete method such as
# CF-GNNExplainer.
from explainability import CounterfactualGNN

explainer = CounterfactualGNN(model,
                              perturbation=('edge', 0.1))

i = 0  # node to explain

# Original prediction for node i
original_pred = model(graph, node_idx=i)
original_class = original_pred.argmax().item()

# Generate counterfactual graph and the list of edits applied
cf_graph, changes = explainer.explain(graph, node_idx=i)

cf_pred = model(cf_graph, node_idx=i)
cf_class = cf_pred.argmax().item()

print(f"Original: Class {original_class}")
print(f"Counterfactual: Class {cf_class}")
print(f"Changes: {changes}")

Properties of Good Counterfactuals

🎯 Validity

The counterfactual should actually change the prediction to the desired class.

📏 Sparsity

Minimal changes are more interpretable and actionable.

✅ Feasibility

Changes should be realistic and implementable in the real world.
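A simple way to get valid and sparse counterfactuals is greedy edge deletion. At each step, remove the incident edge that most reduces the margin of the original class, and stop as soon as the prediction flips. This is a minimal sketch on an assumed toy one-layer model, not CF-GNNExplainer or any published algorithm, and the helper names are hypothetical. Feasibility still has to be judged with domain knowledge.

```python
import numpy as np

def node_logits(A, X, W, node):
    """Toy 1-layer sum-aggregation classifier (assumed stand-in for a trained GNN)."""
    return ((X + A @ X) @ W)[node]

def greedy_counterfactual(A, X, W, node, max_deletions=5):
    """Greedily delete edges incident to `node` until its predicted class changes.

    Returns (counterfactual adjacency, removed edges), or (None, removed edges)
    if no flip happens within the deletion budget.
    """
    original = int(node_logits(A, X, W, node).argmax())
    A_cf, removed = A.copy(), []
    for _ in range(max_deletions):
        best, best_margin = None, np.inf
        for u in np.nonzero(A_cf[node])[0]:
            A_try = A_cf.copy()
            A_try[node, u] = A_try[u, node] = 0.0      # undirected edge deletion
            logits = node_logits(A_try, X, W, node)
            margin = logits[original] - np.delete(logits, original).max()
            if margin < best_margin:
                best, best_margin = int(u), margin
        if best is None:                                # no incident edges left
            break
        A_cf[node, best] = A_cf[best, node] = 0.0
        removed.append((node, best))
        if int(node_logits(A_cf, X, W, node).argmax()) != original:
            return A_cf, removed                        # validity: prediction flipped
    return None, removed

# Star graph: node 0 connected to nodes 1, 2, 3
A = np.zeros((4, 4))
for u in (1, 2, 3):
    A[0, u] = A[u, 0] = 1.0
X = np.array([[0.0, 1.0], [3.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
W = np.eye(2)
A_cf, removed = greedy_counterfactual(A, X, W, node=0)
```

Node 0 starts in class 0 with logits [5, 3]. Deleting the single edge 0-1 gives [2, 3] and flips it to class 1, so the counterfactual is both valid and one edit long. Greedy search does not guarantee the globally minimal set of changes.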

Chapter 7: Evaluating XGNN Methods

Metrics for explanation quality

Evaluation Metrics

1. Fidelity

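How closely the model's prediction on the explanation subgraph matches its prediction on the full graph.

Fidelity = 1 - |P_m(G_s) - P_m(G)|
Where G_s is the important subgraph and P_m(·) is the model's probability for the class it predicts on the full graph G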
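A sketch of this fidelity score, where P_m is the model's probability for the class it predicts on the full graph. The model is an assumed toy one-layer classifier, G_s is given as a 0/1 edge mask, and the helper names are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def node_probs(A, X, W, node):
    """Class probabilities from a toy 1-layer sum-aggregation model (assumed)."""
    return softmax(((X + A @ X) @ W)[node])

def fidelity(A, X, W, node, edge_mask):
    """Fidelity = 1 - |P_m(G_s) - P_m(G)| for the class predicted on the full graph.
    edge_mask is a 0/1 matrix selecting the edges of the explanation subgraph G_s."""
    p_full = node_probs(A, X, W, node)
    cls = p_full.argmax()
    p_sub = node_probs(A * edge_mask, X, W, node)
    return 1.0 - abs(p_sub[cls] - p_full[cls])

# Node 0 has two neighbors; only node 1's features support the predicted class 0
A = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
X = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 0.0]])
W = np.eye(2)
keep_01 = np.zeros((3, 3))
keep_01[0, 1] = keep_01[1, 0] = 1.0
keep_02 = np.zeros((3, 3))
keep_02[0, 2] = keep_02[2, 0] = 1.0
fid_good = fidelity(A, X, W, 0, keep_01)   # explanation keeps the decisive edge
fid_bad = fidelity(A, X, W, 0, keep_02)    # explanation keeps an irrelevant edge
```

Keeping edge 0-1 reproduces the full-graph prediction exactly (fidelity 1). Keeping only 0-2 drops the class-0 probability to 0.5, which lowers fidelity to about 0.51.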

2. Sparsity

Explanations should identify small, focused subgraphs.

Sparsity = 1 - |G_s| / |G|
Where |·| counts the edges (or nodes) of a graph
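A sketch of edge-based sparsity for a possibly soft edge mask. An edge counts as part of G_s when its mask value exceeds a threshold, and the 0.5 default is an assumption. Counting directed entries of a symmetric adjacency matrix counts each undirected edge twice in both numerator and denominator, so the ratio is unaffected. `sparsity` is a hypothetical helper.

```python
import numpy as np

def sparsity(A, edge_mask, threshold=0.5):
    """Sparsity = 1 - |G_s| / |G|, with |.| = number of edges.
    Edges of G_s are those whose (possibly soft) mask value exceeds `threshold`."""
    kept = np.count_nonzero((A > 0) & (edge_mask > threshold))
    return 1.0 - kept / np.count_nonzero(A)

A = np.zeros((4, 4))
for u in (1, 2, 3):
    A[0, u] = A[u, 0] = 1.0              # star: 3 undirected edges
edge_mask = np.zeros((4, 4))
edge_mask[0, 1] = edge_mask[1, 0] = 0.9  # explanation keeps only edge 0-1
s = sparsity(A, edge_mask)
```

Keeping one of three edges gives a sparsity of 2/3. A mask that keeps everything scores 0.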

3. Stability

Explanations should be consistent across similar inputs.

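Stability = 1 - D(explain(G_1), explain(G_2))
Where G_1 and G_2 are similar inputs (e.g., G_2 is a slightly perturbed G_1) and D is a distance between explanations, scaled to [0, 1]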
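One concrete choice of D, assumed here for illustration: the Jaccard distance between the top-k edge sets of the two explanations, which lies in [0, 1]. Explanations are given as hypothetical {edge: score} dictionaries, and the helper names are hypothetical too.

```python
def top_k_edges(edge_scores, k):
    """Set of the k highest-scoring edges from a dict {edge: score}."""
    return set(sorted(edge_scores, key=edge_scores.get, reverse=True)[:k])

def stability(scores_1, scores_2, k=3):
    """Stability = 1 - D, with D the Jaccard distance between the top-k edge sets."""
    s1, s2 = top_k_edges(scores_1, k), top_k_edges(scores_2, k)
    jaccard_distance = 1.0 - len(s1 & s2) / len(s1 | s2)
    return 1.0 - jaccard_distance

# Explanations for G_1 and a slightly perturbed G_2 (illustrative scores)
expl_1 = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.1}
expl_2 = {(0, 1): 0.9, (1, 2): 0.1, (2, 3): 0.8}
score = stability(expl_1, expl_2, k=2)
```

The two top-2 sets share only edge (0, 1) out of three distinct edges, so stability is 1/3. Identical explanations score 1.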

Chapter 8: Best Practices

Guidelines for XGNN

✅ Do's:

  • Use multiple explanation methods
  • Validate with domain experts
  • Consider computational efficiency
  • Evaluate explanation quality
  • Provide visualizations
  • Document limitations

❌ Don'ts:

  • Trust explanations blindly
  • Ignore graph structure
  • Focus only on node features
  • Overinterpret small changes