Interpretability for Graph Neural Networks
SuNaAI Lab
Technical Guide Series
Making Graph Neural Networks interpretable and trustworthy
Graph Neural Networks excel at learning from relational data, but understanding why they make specific predictions on graphs is challenging. Explainability methods for GNNs must handle both node-level and graph-level predictions.
Graph structures add complexity to interpretability. We need to identify which subgraphs, nodes, or edges are most important for predictions, while respecting the relational structure.
Understanding the building blocks of Graph Neural Networks
Graph Neural Networks aggregate information from a node's neighborhood to learn node representations. The key challenge is designing permutation-invariant aggregation functions.
# At each layer l:
# 1. Create messages from neighbors
#    m_uv = f(h_v^(l), h_u^(l), e_uv)  for u in N(v)
# 2. Aggregate messages
#    m_v = AGG({m_uv : u in N(v)})
# 3. Update node representation
#    h_v^(l+1) = UPDATE(h_v^(l), m_v)
Where:
- h_v^(l): features of node v at layer l
- e_uv: features of edge (u, v)
- N(v): neighbors of node v
- AGG: permutation-invariant aggregation function (mean, sum, max, attention)
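The following is a minimal sketch of one such layer with mean aggregation in plain PyTorch; the class name, linear layers, and tensor conventions are illustrative rather than taken from any particular library.

import torch
import torch.nn as nn

class MeanAggregationLayer(nn.Module):
    """One message-passing layer with mean aggregation (illustrative sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.msg = nn.Linear(2 * in_dim, out_dim)            # builds messages from (h_v, h_u)
        self.update = nn.Linear(in_dim + out_dim, out_dim)   # combines old state with aggregated message

    def forward(self, h, edge_index):
        # h: [num_nodes, in_dim]; edge_index: [2, num_edges] with rows (source u, target v)
        src, dst = edge_index
        messages = self.msg(torch.cat([h[dst], h[src]], dim=-1))   # one message per edge
        # Mean-aggregate messages per target node
        agg = torch.zeros(h.size(0), messages.size(-1), device=h.device)
        agg.index_add_(0, dst, messages)
        deg = torch.bincount(dst, minlength=h.size(0)).clamp(min=1).unsqueeze(-1)
        agg = agg / deg
        # Update node representations from the previous state and the aggregated message
        return torch.relu(self.update(torch.cat([h, agg], dim=-1)))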
GCN: Graph Convolutional Network with spectral convolutions
Formula: H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l)), with Ã = A + I and D̃ its degree matrix
GAT: Graph Attention Network with learnable attention weights
Formula: h_i^(l+1) = σ(Σ_{j∈N(i)} α_ij W h_j^(l))
GraphSAGE: Sample-and-aggregate approach with neighborhood subsampling
Formula: h_v^(l+1) = σ(W^(l) · CONCAT(h_v^(l), AGG({h_u^(l) : u ∈ N(v)})))
GIN: Graph Isomorphism Network, as expressive as the 1-WL test (the upper bound for message-passing GNNs)
Formula: h_v^(l+1) = MLP((1+ε) h_v^(l) + Σ_{u∈N(v)} h_u^(l))
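For reference, here is a minimal sketch of how these four layer types can be instantiated in PyTorch Geometric; the dimensions and hyperparameters below are illustrative.

import torch
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv

in_dim, hidden_dim = 16, 32

gcn = GCNConv(in_dim, hidden_dim)                  # spectral-style graph convolution
gat = GATConv(in_dim, hidden_dim, heads=4)         # attention over neighbors
sage = SAGEConv(in_dim, hidden_dim, aggr='mean')   # sample-and-aggregate
gin = GINConv(torch.nn.Sequential(                 # sum aggregation followed by an MLP
    torch.nn.Linear(in_dim, hidden_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_dim, hidden_dim),
))

# All four layers share the same call signature:
# x: [num_nodes, in_dim], edge_index: [2, num_edges]
x = torch.randn(10, in_dim)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = gcn(x, edge_index)   # [10, hidden_dim]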
Techniques for explaining GNN predictions
Gradient-based methods
Use gradients of the prediction with respect to node features or edges to identify important nodes/edges (see the saliency sketch after this list).
Examples: Saliency maps, Integrated Gradients
Pros: Fast; applicable to any differentiable GNN
Cons: Only provides local importance
Perturbation-based methods
Remove or mask parts of the graph and observe how the prediction changes.
Examples: GNNExplainer, PGExplainer
Pros: Model-agnostic, interpretable
Cons: Computationally expensive
Attention-based methods
Leverage the built-in attention mechanisms of attention-based GNNs.
Examples: GAT attention weights, multi-head attention
Pros: No additional computation needed
Cons: Only applicable to attention-based models
Surrogate models
Train simpler interpretable models to approximate the GNN's behavior locally.
Examples: GraphLIME, PGM-Explainer
Pros: Highly interpretable
Cons: The approximation may be inaccurate
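As a concrete instance of the gradient-based approach, the sketch below computes a simple saliency score per node by backpropagating the predicted class score for a target node to the input features. The model, x, and edge_index variables are assumed to exist, and the model is assumed to return per-node logits; none of this is a specific library API.

import torch

def node_saliency(model, x, edge_index, node_idx):
    """Gradient-based node importance for the prediction at node_idx (illustrative sketch)."""
    model.eval()
    x = x.clone().detach().requires_grad_(True)
    logits = model(x, edge_index)                  # assumed shape: [num_nodes, num_classes]
    target_class = logits[node_idx].argmax()
    logits[node_idx, target_class].backward()      # gradient of the predicted class score
    # Node importance = L1 norm of the gradient of that node's input features
    return x.grad.abs().sum(dim=-1)                # [num_nodes]

# Usage (assuming a trained model and data tensors are available):
# scores = node_saliency(model, x, edge_index, node_idx=0)
# top_nodes = scores.topk(5).indices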
Using attention to understand GNN decisions
For GAT and other attention-based GNNs, the attention weights themselves can provide explanations. These weights indicate how much each neighbor influences a node's representation.
from torch_geometric.nn import GATConv
import torch
import torch.nn.functional as F

class GATWithAttention(torch.nn.Module):
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.gat1 = GATConv(num_features, hidden_dim, heads=4)
        self.gat2 = GATConv(hidden_dim * 4, num_classes, heads=1)

    def forward(self, x, edge_index):
        # First layer: also return the per-edge attention coefficients
        x, attention_weights = self.gat1(x, edge_index, return_attention_weights=True)
        x = F.dropout(x, training=self.training)
        # Second layer
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1), attention_weights

# Extract attention explanations from a trained model
model.eval()
logits, (attn_edge_index, attention) = model(x, edge_index)  # attention: [num_edges, num_heads]

# Attention over the incoming edges of node i (one column per head)
incoming = attn_edge_index[1] == i
node_i_attention = attention[incoming]
Attention weights can be visualized as heatmaps showing which edges are most important. Higher weights indicate stronger influence on the prediction.
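One way to do this, sketched below, is to average the heads to get one weight per edge and draw the graph with edge width and color proportional to attention; networkx and matplotlib are assumed to be available, and the variable names follow the snippet above.

import networkx as nx
import matplotlib.pyplot as plt

# Average attention over heads -> one scalar weight per edge
edge_weights = attention.mean(dim=1).detach().cpu().numpy()
edges = attn_edge_index.t().cpu().numpy()

G = nx.DiGraph()
for (u, v), w in zip(edges, edge_weights):
    G.add_edge(int(u), int(v), weight=float(w))

pos = nx.spring_layout(G, seed=0)
widths = [3.0 * G[u][v]['weight'] for u, v in G.edges()]
nx.draw(G, pos, node_size=50, width=widths, edge_color=widths,
        edge_cmap=plt.cm.viridis, with_labels=False)
plt.show()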
Identifying causal subgraphs for predictions
Many explainable-GNN (XGNN) methods identify important subgraphs: small connected pieces of the input graph that are most responsible for the prediction. These provide interpretable, local explanations.
from torch_geometric.nn import GNNExplainer

# A GNN trained beforehand (placeholder)
model = trained_gnn_model()

# Create explainer (this is the torch_geometric.nn API; newer PyG releases
# expose GNNExplainer through torch_geometric.explain instead)
explainer = GNNExplainer(model, epochs=100, return_type='log_prob')

# Explain a specific node
node_idx = 100  # node to explain
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
# node_feat_mask: importance of each input feature dimension
# edge_mask: importance of each edge in the computation graph around node_idx

# Visualize the important subgraph (edges above the threshold)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, threshold=0.5)
GNNExplainer finds subgraph G_S that maximizes mutual information:
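Formally, the GNNExplainer objective is:
max_{G_S} MI(Y, (G_S, X_S)) = H(Y) - H(Y | G = G_S, X = X_S)
where X_S are the features of the nodes in G_S and H denotes entropy.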
This maximizes the information that subgraph G_S provides about prediction Y, while keeping the subgraph small and connected.
"What would happen if we changed the graph?"
Counterfactual explanations show minimal changes needed to flip a prediction. For graphs, this means adding/removing edges or changing node features.
# Find minimal changes to flip prediction
from explainability import CounterfactualGNN
explainer = CounterfactualGNN(model,
perturbation=('edge', 0.1))
# Generate counterfactual for node i
original_pred = model(graph, node_idx=i)
original_class = original_pred.argmax()
# Generate counterfactual graph
cf_graph, changes = explainer.explain(graph, node_idx=i)
cf_pred = model(cf_graph, node_idx=i)
cf_class = cf_pred.argmax()
print(f"Original: Class {original_class}")
print(f"Counterfactual: Class {cf_class}")
print(f"Changes: {changes}")The counterfactual should actually change the prediction to the desired class.
Minimal changes are more interpretable and actionable.
Changes should be realistic and implementable in the real world.
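A simple baseline for counterfactual search is brute force over single-edge deletions. The sketch below assumes a model that maps (x, edge_index) to per-node logits; the helper name is hypothetical.

import torch

def single_edge_counterfactual(model, x, edge_index, node_idx):
    """Brute-force counterfactual: drop one edge at a time and check for a class flip (sketch)."""
    model.eval()
    with torch.no_grad():
        original_class = model(x, edge_index)[node_idx].argmax()
        for e in range(edge_index.size(1)):
            # Only consider edges that touch the node we want to explain
            if node_idx not in (edge_index[0, e].item(), edge_index[1, e].item()):
                continue
            keep = torch.ones(edge_index.size(1), dtype=torch.bool)
            keep[e] = False
            new_class = model(x, edge_index[:, keep])[node_idx].argmax()
            if new_class != original_class:
                return edge_index[:, e], new_class   # the removed edge and the flipped class
    return None, original_class  # no single-edge counterfactual found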
Metrics for explanation quality
Fidelity: How well the explanation reproduces the model's behavior.
Formula: Fidelity = 1 - |P_m(G_s) - P_m(G)|
Where G_s is the important subgraph and P_m is the model's predicted probability
Sparsity: Explanations should identify small, focused subgraphs.
Formula: Sparsity = 1 - |G_s| / |G|
Stability: Explanations should be consistent across similar inputs.
Formula: Stability = 1 - D(explain(G_1), explain(G_2))
Where D is a distance metric between explanations of similar graphs G_1 and G_2
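Fidelity and Sparsity are easy to compute once an edge mask is available. The sketch below follows the definitions above and assumes a model whose forward returns per-node logits; the function and variable names are illustrative.

import torch

def fidelity_and_sparsity(model, x, edge_index, edge_mask, node_idx, threshold=0.5):
    """Compute Fidelity and Sparsity for a node-level explanation (illustrative sketch)."""
    model.eval()
    with torch.no_grad():
        keep = edge_mask >= threshold                                # edges kept in the explanation G_s
        p_full = model(x, edge_index)[node_idx].softmax(-1)          # P_m(G)
        p_sub = model(x, edge_index[:, keep])[node_idx].softmax(-1)  # P_m(G_s)
        target = p_full.argmax()
        fidelity = 1.0 - (p_sub[target] - p_full[target]).abs().item()  # 1 - |P_m(G_s) - P_m(G)|
        sparsity = 1.0 - keep.float().mean().item()                  # 1 - |G_s| / |G| (edge count)
    return fidelity, sparsity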
Guidelines for XGNN