# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Gradient-based attribution."""
from typing import cast, List, Text, Optional
from absl import logging
from lit_nlp.api import components as lit_components
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import dtypes
from lit_nlp.api import model as lit_model
from lit_nlp.api import types
from lit_nlp.components.citrus import utils as citrus_utils
from lit_nlp.lib import utils
import numpy as np
JsonDict = types.JsonDict
Spec = types.Spec


class GradientNorm(lit_components.Interpreter):
  """Salience map from gradient L2 norm."""

  def find_fields(self, output_spec: Spec) -> List[Text]:
    # Find TokenGradients fields.
    grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)
    # Check that these are aligned to Tokens fields.
    for f in grad_fields:
      tokens_field = output_spec[f].align  # pytype: disable=attribute-error
      assert tokens_field in output_spec
      assert isinstance(output_spec[tokens_field], types.Tokens)
    return grad_fields

  def _interpret(self, grads: np.ndarray, tokens: np.ndarray):
    assert grads.shape[0] == len(tokens)
    # Norm of dy/d(embs).
    grad_norm = np.linalg.norm(grads, axis=1)
    grad_norm /= np.sum(grad_norm)
    # <float32>[num_tokens]
    return grad_norm
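
  # Illustrative sketch (hypothetical values) of what _interpret computes:
  #   grads = np.array([[3.0, 4.0], [0.0, 3.0]])  # per-token gradients
  #   tokens = ['hello', 'world']
  #   GradientNorm()._interpret(grads, tokens)
  #   # -> array([0.625, 0.375]); L2 norms [5, 3], normalized to sum to 1.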

  def run(self,
          inputs: List[JsonDict],
          model: lit_model.Model,
          dataset: lit_dataset.Dataset,
          model_outputs: Optional[List[JsonDict]] = None,
          config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    """Run this component, given a model and input(s)."""
    # Find gradient fields to interpret.
    output_spec = model.output_spec()
    grad_fields = self.find_fields(output_spec)
    logging.info('Found fields for gradient attribution: %s', str(grad_fields))
    if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
      return None

    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(inputs))
    assert len(model_outputs) == len(inputs)

    all_results = []
    for o in model_outputs:
      # Dict[field name -> interpretations]
      result = {}
      for grad_field in grad_fields:
        token_field = cast(types.TokenGradients, output_spec[grad_field]).align
        tokens = o[token_field]
        scores = self._interpret(o[grad_field], tokens)
        result[grad_field] = dtypes.SalienceMap(tokens, scores)
      all_results.append(result)
    return all_results


class GradientDotInput(lit_components.Interpreter):
  """Salience map using the values of gradient * input as attribution."""

  def find_fields(self, input_spec: Spec, output_spec: Spec) -> List[Text]:
    # Find TokenGradients fields.
    grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)
    # Check that these are aligned to Tokens fields.
    aligned_fields = []
    for f in grad_fields:
      tokens_field = output_spec[f].align  # pytype: disable=attribute-error
      assert tokens_field in output_spec
      assert isinstance(output_spec[tokens_field], types.Tokens)

      embeddings_field = output_spec[f].grad_for
      if embeddings_field is not None:
        assert embeddings_field in input_spec
        assert isinstance(input_spec[embeddings_field], types.TokenEmbeddings)
        assert embeddings_field in output_spec
        assert isinstance(output_spec[embeddings_field], types.TokenEmbeddings)
        aligned_fields.append(f)
      else:
        logging.info('Skipping %s since embeddings field not found.', str(f))
    return aligned_fields

  def _interpret(self, grads: np.ndarray, embs: np.ndarray):
    assert grads.shape == embs.shape
    # Dot product of gradients and embeddings.
    # <float32>[num_tokens]
    grad_dot_input = np.sum(grads * embs, axis=-1)
    scores = citrus_utils.normalize_scores(grad_dot_input)
    return scores
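
  # Illustrative sketch (hypothetical values) of what _interpret computes:
  #   grads = np.array([[1.0, 0.0], [0.5, 0.5]])
  #   embs = np.array([[2.0, 3.0], [4.0, -4.0]])
  #   # Per-token dot products: [2.0, 0.0], which are then rescaled by
  #   # citrus_utils.normalize_scores before being returned.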

  def run(self,
          inputs: List[JsonDict],
          model: lit_model.Model,
          dataset: lit_dataset.Dataset,
          model_outputs: Optional[List[JsonDict]] = None,
          config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    """Run this component, given a model and input(s)."""
    # Find gradient fields to interpret.
    input_spec = model.input_spec()
    output_spec = model.output_spec()
    grad_fields = self.find_fields(input_spec, output_spec)
    logging.info('Found fields for gradient attribution: %s', str(grad_fields))
    if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
      return None

    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(inputs))
    assert len(model_outputs) == len(inputs)

    all_results = []
    for o in model_outputs:
      # Dict[field name -> interpretations]
      result = {}
      for grad_field in grad_fields:
        embeddings_field = cast(types.TokenGradients,
                                output_spec[grad_field]).grad_for
        scores = self._interpret(o[grad_field], o[embeddings_field])

        token_field = cast(types.TokenGradients, output_spec[grad_field]).align
        tokens = o[token_field]
        result[grad_field] = dtypes.SalienceMap(tokens, scores)
      all_results.append(result)
    return all_results


class IntegratedGradients(lit_components.Interpreter):
  """Salience map from Integrated Gradients.

  Integrated Gradients is an attribution method originally proposed in
  Sundararajan et al. (https://arxiv.org/abs/1703.01365), which attributes an
  importance value to each input feature based on the gradients of the model
  output with respect to the input. The feature attribution values are
  calculated by taking the integral of gradients along a straight path from a
  baseline to the input being analyzed. The original implementation can be
  found at: https://github.com/ankurtaly/Integrated-Gradients/blob/master/
  BertModel/bert_model_utils.py

  This component requires the following fields in the model spec. Field
  names like `embs` are placeholders; you can call them whatever you like,
  and as with other LIT components, multiple segments are supported.

  Output:
    - TokenEmbeddings (`embs`) to return the input embeddings
    - TokenGradients (`grads`) to return gradients w.r.t. `embs`
    - A label field (`target`) to return the label that `grads`
      was computed for. This is usually a CategoryLabel, but can be anything
      since it will just be fed back into the model.

  Input:
    - TokenEmbeddings (`embs`) to accept the modified input embeddings
    - A label field (`target`) to pin the gradient target to the same
      label for all integral steps, since the argmax prediction may change.
  """

  def __init__(self, interpolation_steps=30):
    # TODO(b/168042999): Make this parameter configurable in the UI.
    self.interpolation_steps = interpolation_steps

  def find_fields(self, input_spec: Spec, output_spec: Spec) -> List[Text]:
    # Find TokenGradients fields.
    grad_fields = utils.find_spec_keys(output_spec, types.TokenGradients)
    # Check that these are aligned to Tokens fields.
    aligned_fields = []
    for f in grad_fields:
      tokens_field = output_spec[f].align  # pytype: disable=attribute-error
      assert tokens_field in output_spec
      assert isinstance(output_spec[tokens_field], types.Tokens)

      embeddings_field = output_spec[f].grad_for
      grad_class_key = output_spec[f].grad_target
      if embeddings_field is not None and grad_class_key is not None:
        assert embeddings_field in input_spec
        assert isinstance(input_spec[embeddings_field], types.TokenEmbeddings)
        assert embeddings_field in output_spec
        assert isinstance(output_spec[embeddings_field], types.TokenEmbeddings)
        assert grad_class_key in input_spec
        assert grad_class_key in output_spec
        aligned_fields.append(f)
      else:
        logging.info('Skipping %s since embeddings or gradient class fields '
                     'were not found.', str(f))
    return aligned_fields

  def get_interpolated_inputs(self, baseline: np.ndarray, target: np.ndarray,
                              num_steps: int) -> np.ndarray:
    """Gets num_steps + 1 inputs linearly interpolated from baseline to target."""
    if num_steps <= 0: return np.array([])
    if num_steps == 1: return np.array([baseline, target])

    delta = target - baseline  # <float32>[num_tokens, emb_size]
    # Creates a scales array holding the interpolation coefficients from
    # np.linspace, ready to broadcast against delta. Note np.linspace with
    # num_steps + 1 points includes both endpoints.
    # <float32>[num_steps + 1, 1, 1]
    scales = np.linspace(0, 1, num_steps + 1,
                         dtype=np.float32)[:, np.newaxis, np.newaxis]

    shape = (num_steps + 1,) + delta.shape
    # <float32>[num_steps + 1, num_tokens, emb_size]
    deltas = scales * np.broadcast_to(delta, shape)
    interpolated_inputs = baseline + deltas
    # <float32>[num_steps + 1, num_tokens, emb_size]
    return interpolated_inputs
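
  # Illustrative sketch (hypothetical values) of the interpolation:
  #   baseline = np.zeros((1, 2))
  #   target = np.array([[1.0, 2.0]])
  #   IntegratedGradients().get_interpolated_inputs(baseline, target, 2)
  #   # -> [[[0. , 0. ]],
  #   #     [[0.5, 1. ]],
  #   #     [[1. , 2. ]]]   # three points at scales 0, 0.5, 1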

  def estimate_integral(self, path_gradients: np.ndarray) -> np.ndarray:
    """Estimates the integral of path_gradients using the trapezoid rule."""
    path_gradients = (path_gradients[:-1] + path_gradients[1:]) / 2
    # After pairwise averaging there are num_steps - 1 terms left; np.average
    # sums them and divides by num_steps - 1, which matches the trapezoid rule
    # with uniform spacing over [0, 1].
    return np.average(path_gradients, axis=0)
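
  # Worked sketch (hypothetical values) of the trapezoid estimate:
  #   path_gradients = np.array([0.0, 1.0, 2.0])[:, np.newaxis]
  #   # Pairwise averages: [0.5, 1.5]; np.average -> 1.0, i.e. the trapezoid
  #   # estimate (g0 + 2*g1 + g2) / 4 = (0 + 2 + 2) / 4 = 1.0 for three
  #   # evenly spaced samples on [0, 1].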

  def get_baseline(self, embeddings: np.ndarray) -> np.ndarray:
    """Returns baseline embeddings to use in Integrated Gradients."""
    # Replaces embeddings in the original input with the zero embedding, or
    # with the specified token embedding.
    baseline = np.zeros_like(embeddings)
    # TODO(ellenj): Add option to use a token's embedding as the baseline.
    return baseline

  def get_salience_result(self, model_input: JsonDict, model: lit_model.Model,
                          model_output: JsonDict, grad_fields: List[Text]):
    result = {}

    output_spec = model.output_spec()
    # We ensure that the embedding and gradient class fields are present in the
    # model's input spec in find_fields().
    embeddings_fields = [
        cast(types.TokenGradients,
             output_spec[grad_field]).grad_for for grad_field in grad_fields]

    # The gradient class input is used to specify the target class of the
    # gradient calculation (if unspecified, this option defaults to the argmax,
    # which could flip between interpolated inputs).
    grad_class_key = cast(types.TokenGradients,
                          output_spec[grad_fields[0]]).grad_target
    # TODO(b/168042999): Add option to specify the class to explain in the UI.
    grad_class = model_output[grad_class_key]

    interpolated_inputs = {}
    all_embeddings = []
    all_baselines = []
    for embed_field in embeddings_fields:
      # <float32>[num_tokens, emb_size]
      embeddings = np.array(model_output[embed_field])
      all_embeddings.append(embeddings)

      # Starts with baseline of zeros. <float32>[num_tokens, emb_size]
      baseline = self.get_baseline(embeddings)
      all_baselines.append(baseline)

      # Get interpolated inputs from baseline to original embedding.
      # <float32>[interpolation_steps, num_tokens, emb_size]
      interpolated_inputs[embed_field] = self.get_interpolated_inputs(
          baseline, embeddings, self.interpolation_steps)

    # Create model inputs and populate embedding field(s).
    inputs_with_embeds = []
    for i in range(self.interpolation_steps):
      input_copy = model_input.copy()
      # Interpolates embeddings for all inputs simultaneously.
      for embed_field in embeddings_fields:
        # <float32>[num_tokens, emb_size]
        input_copy[embed_field] = interpolated_inputs[embed_field][i]
        input_copy[grad_class_key] = grad_class

      inputs_with_embeds.append(input_copy)
    embed_outputs = model.predict(inputs_with_embeds)

    # Create list with concatenated gradients for each interpolated input.
    gradients = []
    for o in embed_outputs:
      # <float32>[total_num_tokens, emb_size]
      interp_gradients = np.concatenate([o[field] for field in grad_fields])
      gradients.append(interp_gradients)
    # <float32>[interpolation_steps, total_num_tokens, emb_size]
    path_gradients = np.stack(gradients, axis=0)

    # Calculate the integral.
    # <float32>[total_num_tokens, emb_size]
    integral = self.estimate_integral(path_gradients)

    # <float32>[total_num_tokens, emb_size]
    concat_embeddings = np.concatenate(all_embeddings)
    # <float32>[total_num_tokens, emb_size]
    concat_baseline = np.concatenate(all_baselines)

    # Multiply the integral values by (embeddings - baseline) elementwise,
    # then sum over the embedding dimension (a per-token dot product).
    # <float32>[total_num_tokens, emb_size]
    integrated_gradients = integral * (concat_embeddings - concat_baseline)
    # <float32>[total_num_tokens]
    attributions = np.sum(integrated_gradients, axis=-1)
    # TODO(b/168042999): Make normalization customizable in the UI.
    # <float32>[total_num_tokens]
    scores = citrus_utils.normalize_scores(attributions)

    for grad_field in grad_fields:
      # Format as salience map result.
      token_field = cast(types.TokenGradients, output_spec[grad_field]).align
      tokens = model_output[token_field]

      # Only use the scores that correspond to the tokens in this grad_field.
      # The gradients for all input embeddings were concatenated in the order
      # of the grad fields, so they can be sliced out in the same order.
      sliced_scores = scores[:len(tokens)]  # <float32>[num_tokens in field]
      scores = scores[len(tokens):]  # <float32>[num_remaining_tokens]

      assert len(tokens) == len(sliced_scores)
      result[grad_field] = dtypes.SalienceMap(tokens, sliced_scores)
    return result
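
  # Illustrative sketch (hypothetical values) of the score slicing above:
  # with two grad fields aligned to token lists of lengths 3 and 2,
  #   scores = np.array([.1, .2, .3, .25, .15])
  #   # field 1 gets scores[:3] -> [.1, .2, .3]
  #   # field 2 gets scores[3:] -> [.25, .15]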

  def run(self,
          inputs: List[JsonDict],
          model: lit_model.Model,
          dataset: lit_dataset.Dataset,
          model_outputs: Optional[List[JsonDict]] = None,
          config: Optional[JsonDict] = None) -> Optional[List[JsonDict]]:
    """Run this component, given a model and input(s)."""
    # Find gradient fields to interpret.
    input_spec = model.input_spec()
    output_spec = model.output_spec()
    grad_fields = self.find_fields(input_spec, output_spec)
    logging.info('Found fields for integrated gradients: %s', str(grad_fields))
    if len(grad_fields) == 0:  # pylint: disable=g-explicit-length-test
      return None

    # Run model, if needed.
    if model_outputs is None:
      model_outputs = list(model.predict(inputs))

    all_results = []
    for model_output, model_input in zip(model_outputs, inputs):
      result = self.get_salience_result(model_input, model, model_output,
                                        grad_fields)
      all_results.append(result)
    return all_results
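

# Minimal usage sketch (hypothetical `model`, `dataset`, and `inputs`; any
# lit_model.Model whose specs satisfy find_fields() should work):
#
#   ig = IntegratedGradients(interpolation_steps=30)
#   results = ig.run(inputs, model, dataset)
#   # results[i][grad_field] is a dtypes.SalienceMap of (tokens, scores).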