
[Review] Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking (ICLR 2021)

by Schlichtkrull et al.

Paper Link : https://openreview.net/pdf?id=WznmQa42ZAx

Contents

1. Introduction

2. Related Works

3. Methods

4. Synthetic Experiment

5. Question Answering

6. Semantic Role Labeling

7. Summary & Conclusion


Abstract

  • GNNs are a popular approach for integrating structural inductive biases into NLP models
  • There is little work on interpreting them, i.e. understanding which parts of the graphs (e.g. syntactic trees or co-reference structures) contribute to a prediction
  • We train the classifier in a fully differentiable fashion, employing stochastic gates and encouraging sparsity through the expected $L_0$ norm
  • We can drop a large proportion of edges w/o deteriorating the performance of the model

Method

  • A post-hoc method for interpreting the predictions of GNNs that identifies unnecessary edges
  • An attribution method for analysing GNN models, providing insight into the information flow inside them

Tasks

  • Multi-hop question answering
  • Semantic role labelling

1. Introduction

The authors focus on post-hoc analysis of GNNs, developing a method for understanding how a GNN uses the input graph.

They aim to identify which edges in the graph the GNN relies on, and at which layer those edges are used.

 

Principles for an interpretation method

  1. able to identify relevant paths in the input graph : to reveal the GNN's reasoning patterns
  2. tractable : so it can be used with GNN-based NLP methods in practice
  3. faithful : to show how the model actually arrives at its predictions

Contributions

  • The study proposes a novel interpretation method for GNNs
  • The authors address shortcomings of previous studies, and their method improves faithfulness
  • They use GraphMask to analyse GNN models for NLP tasks (QA and semantic role labeling)

Erasure Search : an approach wherein attribution happens by searching for a maximal subset of features that can be entirely removed w/o affecting model predictions

Put simply, this is the "leave-one-out" method:

A word’s importance can be measured by the difference in model confidence before and after that word is removed from the input—the word is important if confidence decreases significantly. This is the leave-one-out method (Li et al., 2016b).
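To make the idea concrete, here is a minimal sketch of leave-one-out attribution. `predict_proba` is a hypothetical interface for whatever classifier is being explained, not an API from the paper.

```python
def leave_one_out_importance(model, tokens, target):
    """Score each token by the confidence drop when it is removed."""
    base_conf = model.predict_proba(tokens)[target]
    scores = []
    for i in range(len(tokens)):
        reduced = tokens[:i] + tokens[i + 1:]  # drop the i-th token
        conf = model.predict_proba(reduced)[target]
        scores.append(base_conf - conf)  # large drop => important token
    return scores
```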

 

(Ref 1) Understanding Neural Networks through Representation Erasure https://arxiv.org/pdf/1612.08220.pdf

(Ref 2) Pathologies of Neural Models Make Interpretations Difficult https://arxiv.org/pdf/1804.07781.pdf

2. Related Works

GNNExplainer identifies a compact subgraph structure and a small subset of node features that play a crucial role in the GNN's prediction. It is formulated as an optimization task that maximizes the mutual information between the subgraph and the prediction.

(Ref 1) Paper link

(Ref 2) Explanation on the Tobigs blog (Korean)

 

3. Methods

3.1 GNN

A GNN is a layered architecture that takes an input graph $G = \langle V, E \rangle$ (V: nodes, E: edges) and produces a prediction

At every layer k, a GNN computes a node representation $h_u^{(k)}$ for each node $u \in V$

 

Message function M : takes the previous layer's representations of the endpoints of edge (u, v), together with the relation type, and produces the message $m_{u,v}^{(k)}$

$$m_{u,v}^{(k)} = M^{(k)} (h_u^{(k-1)}, h_v^{(k-1)}, r_{u,v})$$

 

Aggregation function A : takes the messages from v's neighbor nodes N(v) as input and produces the node representation $h_{v}^{(k)}$

$$h_{v}^{(k)} = A^{(k)} (\{m_{u,v}^{(k)}: u \in N(v)\})$$

  • $r_{u,v}$ : relation type between nodes
  • $N(v)$ : the set of neighbor nodes of v
  • aggregation function : mean-, sum-, or max-pooling
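For reference, a minimal PyTorch sketch of one such layer, assuming sum-pooling for A and a small MLP for M; the names and featurisation are illustrative, not the paper's code.

```python
import torch
import torch.nn as nn

class GNNLayer(nn.Module):
    """One message-passing layer: M computes per-edge messages,
    A aggregates incoming messages by sum-pooling."""
    def __init__(self, dim, num_relations):
        super().__init__()
        self.rel_emb = nn.Embedding(num_relations, dim)
        self.message_mlp = nn.Sequential(  # plays the role of M^{(k)}
            nn.Linear(3 * dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, h, edge_index, edge_type):
        # h: [num_nodes, dim]; edge_index: [2, num_edges], rows are (u, v)
        u, v = edge_index
        m = self.message_mlp(
            torch.cat([h[u], h[v], self.rel_emb(edge_type)], dim=-1))
        out = torch.zeros_like(h)   # plays the role of A^{(k)}
        out.index_add_(0, v, m)     # sum messages into each target node v
        return out
```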

 

3.2 GraphMask

Goal : to detect which edges (u,v) at layer k can be ignored w/o affecting model predictions

 

$$\hat m_{u,v}^{(k)} = z_{u,v}^{(k)} m_{u,v}^{(k)} + b^{(k)} (1-z_{u,v}^{(k)})$$

  • $z_{u,v}^{(k)} \in \{0,1\}$ : a binary gate deciding whether the message is kept
  • $b^{(k)}$ : a learned baseline vector that replaces erased messages
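A minimal sketch of this gating step, where each message is interpolated toward the learned layer-wise baseline $b^{(k)}$ (illustrative code, continuing the hypothetical layer above):

```python
import torch
import torch.nn as nn

class MaskedMessages(nn.Module):
    """m_hat = z * m + (1 - z) * b : the gate z keeps the message or
    replaces it with the learned baseline b^{(k)}."""
    def __init__(self, dim):
        super().__init__()
        self.baseline = nn.Parameter(torch.zeros(dim))  # b^{(k)}

    def forward(self, messages, z):
        # messages: [num_edges, dim]; z: [num_edges], values in [0, 1]
        z = z.unsqueeze(-1)
        return z * messages + (1.0 - z) * self.baseline
```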

 

Erasure search breaks these principles in two ways : it is not tractable, and it is in danger of hindsight bias

In erasure search,

  1. optimization happens individually for each example, so the search does not scale
  2. even non-superfluous (genuinely needed) edges are aggressively pruned, because
  3. a similar prediction could often be made using an alternative, smaller subgraph

Points 2 and 3 are what the authors call hindsight bias

 

Amortisation

Goal : use a parameterized erasure function, rather than an individual per-edge choice, to prevent hindsight bias

  • compute $z_{u,v}^{(k)}$ through a simple function g learned once per task across data points (see the sketch below)
    • the parameters $\pi$ are then used to explain predictions for examples unseen at training time
  • the explainer is not provided with a look-ahead
    • look-ahead : access to layers above the studied layer, which would make the model vulnerable to hindsight bias
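A sketch of what the amortised gate predictor g (with shared parameters $\pi$) could look like; feeding it the endpoint states and the message at layer k is an assumption made for illustration:

```python
import torch
import torch.nn as nn

class GatePredictor(nn.Module):
    """Amortised erasure function g: one shared set of parameters (pi),
    trained over the whole dataset, predicts a gate logit per edge."""
    def __init__(self, dim):
        super().__init__()
        self.scorer = nn.Sequential(
            nn.Linear(3 * dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, h_u, h_v, m):
        # Because parameters are shared across gates and examples, the
        # explainer generalises to examples unseen at training time.
        return self.scorer(torch.cat([h_u, h_v, m], dim=-1)).squeeze(-1)
```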

Alternative : non-amortized version

  • choose the parameters $\pi$ independently for each gate, w/o any parameter sharing across gates
  • susceptible to hindsight bias

 

3.3 Parameter Estimation

f : the GNN being analysed
L : the number of layers
G : the input graph
X : the input embeddings

 

Goal : to identify a set $G_S = \{ G_S^{(1)}, \dots, G_S^{(L)}\}$ of informative sub-graphs, one per layer

We search for sub-graphs with the minimal # of edges while maintaining $f(G_S, X) \approx f(G,X)$, using

  1. a divergence $D_* [f(G,X) \,\|\, f(G_S,X)]$ to measure how far the two outputs differ
  2. a tolerance level $\beta \in \mathbb{R}_{>0}$ within which differences are regarded as acceptable

A practical way to minimize the # of non-zeros predicted by g is minimizing the $L_0$ norm (total # of edges that are not masked)

 

Objective

$$\max_{\lambda} \min_{\pi, b} \sum_{(G,X) \in D} \left( \sum_{k=1}^{L} \sum_{(u,v) \in E} \mathbf{1}_{[z \neq 0]} \left( z_{u,v}^{(k)} \right) \right) + \lambda \left( D_* \left[ f(G,X) \,\|\, f(G_S,X) \right] - \beta \right)$$

  • $\mathbf{1}_{[z \neq 0]}$ : indicator function counting non-zero gates
  • $\lambda$ : Lagrange multiplier

This objective is not differentiable

 

We cannot use gradient-based optimization, since

  1. $L_0$ is discontinuous and has zero derivatives almost everywhere
  2. outputting a binary value requires a discontinuous activation (e.g. a step function, which is not differentiable)

-> Solution: the Hard Concrete distribution, which assigns non-zero probability to exact zeros
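A sketch of Hard Concrete sampling and its closed-form expected-$L_0$ surrogate, using the stretch constants and temperature from Louizos et al. (2018); this is an illustration, not the paper's exact code.

```python
import math
import torch

GAMMA, ZETA, TEMP = -0.1, 1.1, 2.0 / 3.0  # stretch interval and temperature

def sample_hard_concrete(logits):
    """Reparameterised sample: stretch a binary Concrete to (GAMMA, ZETA),
    then clip to [0, 1], so exact zeros get non-zero probability."""
    u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
    s = torch.sigmoid((u.log() - (1 - u).log() + logits) / TEMP)
    return (s * (ZETA - GAMMA) + GAMMA).clamp(0.0, 1.0)

def expected_l0(logits):
    """Differentiable surrogate for the L_0 term: P(z != 0) per gate."""
    return torch.sigmoid(logits - TEMP * math.log(-GAMMA / ZETA)).sum()
```

During training, this expected-$L_0$ penalty stands in for the indicator sum in the objective: gradient descent updates $\pi$ and $b$, while gradient ascent updates the multiplier $\lambda$.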

4. Synthetic Experiment

Task

  • A star graph G with a single centroid vertex $v_0$, leaf vertices $v_1, ..., v_n$, and edges $(v_1, v_0), ..., (v_n, v_0)$
  • every edge (u, v) is assigned one of several colors $c_{u,v} \in C$
  • given a query $\langle x, y \rangle \in C \times C$, the task is to predict whether more edges are assigned color x than color y (see the sketch below)
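A sketch of how one such toy instance could be generated; the number of leaves and the color set are assumptions:

```python
import random

COLORS = ["red", "green", "blue", "yellow"]

def make_star_instance(num_leaves=8):
    """Star graph: each leaf v_i connects to the centroid v_0 with a
    colored edge. Query <x, y>: more x-colored than y-colored edges?"""
    edge_colors = [random.choice(COLORS) for _ in range(num_leaves)]
    x, y = random.sample(COLORS, 2)
    label = edge_colors.count(x) > edge_colors.count(y)
    edges = [(i + 1, 0) for i in range(num_leaves)]  # (v_i, v_0)
    return edges, edge_colors, (x, y), label
```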

Toy examples (figure from the paper)

Results

  • Only the amortized version of the method approximately replicates the gold standard
  • Erasure search, GNNExplainer, and the non-amortized version exploit their training regime to reach the same low-penalty solution with perfect model performance, which is however not faithful to the original model behavior
  • Amortization prevents overfitting to the objective
  • Integrated gradients : scalar attribution scores vary greatly across examples with different numbers of edges

5. Question Answering

Task : multi-hop QA

Goal : to investigate which edge types are used across the three layers

6. Semantic Role Labeling

Task : to identify arguments of a given predicate and assign them to semantic roles

Goal : to investigate which dependency types the GNN relies on

Task illustration (figure from the paper)

 

The figure shows the proportion of edges retained on paths of varying length between the predicate and the predicted roles: for each path length and predicate type, the probability that 0-2 of the edges on the path are retained.
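A sketch of how such a path statistic could be computed from binary gate decisions; the networkx graph and the `gates` dictionary are assumed inputs, not the paper's code:

```python
import networkx as nx

def retained_on_path(graph, gates, predicate, role):
    """gates maps (u, v) -> True if GraphMask kept that edge.
    Returns (#edges kept, path length) on a shortest predicate->role path."""
    path = nx.shortest_path(graph, predicate, role)
    edges = list(zip(path[:-1], path[1:]))
    kept = sum(gates.get(e, False) for e in edges)
    return kept, len(edges)
```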

7. Summary & Conclusion

  • The authors introduce GraphMask, a post-hoc interpretation method applicable to any GNN
  • By learning end-to-end differentiable hard gates for every message and amortising over the training set, GraphMask is faithful, scalable, and capable of identifying how both edges and paths influence predictions
  • They present analyses of the predictions of two NLP models (QA and SRL)

 

※ This post was written by a graph beginner for study and personal learning purposes,

so it may contain some inaccuracies. (Feedback is welcome!)


Q&A

What is "look-ahead" in section 3.2, and why does it help to avoid hindsight bias?

  • look-ahead : access to layers above the studied layer, which makes the explainer vulnerable to hindsight bias
  • By not providing the explainer with look-ahead, hindsight bias can be prevented.

 

What is the meaning of the attribution score distribution converging to near-binary values? Is this more beneficial for interpretation or for computation?

  • The objective of this work is not differentiable
  • The Hard Concrete distribution is therefore used; it assigns non-zero probability to exact zeros (hence near-binary values)
  • A nice property of this distribution is that it can be sampled via a differentiable reparameterization trick (transforming the problematic variable so that sampling becomes differentiable)
  • The attribution score is then an expectation over sampled non-zero masks, and these non-zero values could leak information
  • GraphMask converges to a distribution whose expected scores are near-binary

 
