forked from PAIR-code/lit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathglue_models.py
502 lines (430 loc) · 19.8 KB
/
glue_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
# Lint as: python3
"""Wrapper for fine-tuned HuggingFace models in LIT."""
import os
import re
from typing import Optional, Dict, List, Iterable
from absl import logging
import attr
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils
import numpy as np
import tensorflow as tf
import transformers
JsonDict = lit_types.JsonDict
Spec = lit_types.Spec
def _from_pretrained(cls, *args, **kw):
"""Load a transformers model in TF2, with fallback to PyTorch weights."""
try:
return cls.from_pretrained(*args, **kw)
except OSError as e:
logging.warning("Caught OSError loading model: %s", e)
logging.warning(
"Re-trying to convert from PyTorch checkpoint (from_pt=True)")
return cls.from_pretrained(*args, from_pt=True, **kw)
@attr.s(auto_attribs=True, kw_only=True)
class GlueModelConfig(object):
"""Config options for a GlueModel."""
# Preprocessing options
max_seq_length: int = 128
inference_batch_size: int = 32
# Input options
text_a_name: str = "sentence1"
text_b_name: Optional[str] = "sentence2" # set to None for single-segment
label_name: str = "label"
# Output options
labels: Optional[List[str]] = None # set to None for regression
null_label_idx: Optional[int] = None
compute_grads: bool = True # if True, compute and return gradients.
class GlueModel(lit_model.Model):
"""GLUE benchmark model, using Keras/TF2 and Huggingface Transformers.
This is a general-purpose classification or regression model. It works for
one- or two-segment input, and predicts either a multiclass label or
a regression score. See GlueModelConfig for available options.
This implements the LIT API for inference (e.g. input_spec(), output_spec(),
and predict()), but also provides a train() method to run fine-tuning.
This is a full-featured implementation, which includes embeddings, attention,
gradients, as well as support for the different input and output types above.
For a more minimal example, see ../simple_tf2_demo.py.
"""
@property
def is_regression(self) -> bool:
return self.config.labels is None
# TODO(lit-dev): upgrade version of huggingface so we can just pass
# output_hidden_states and output_attentions at inference time, rather than
# as part of the model config. Then we don't need a special for_training mode.
def __init__(self,
model_name_or_path="bert-base-uncased",
for_training=False,
**config_kw):
self.config = GlueModelConfig(**config_kw)
self._load_model(model_name_or_path, for_training)
def _load_model(self, model_name_or_path, for_training):
"""Load model. Can be overridden for testing."""
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path)
model_config = transformers.AutoConfig.from_pretrained(
model_name_or_path,
num_labels=1 if self.is_regression else len(self.config.labels),
output_hidden_states=(not for_training),
output_attentions=(not for_training),
)
self.model = _from_pretrained(
transformers.TFAutoModelForSequenceClassification,
model_name_or_path,
config=model_config)
def _preprocess(self, inputs: Iterable[JsonDict]) -> Dict[str, tf.Tensor]:
segments = [
(ex[self.config.text_a_name],
ex[self.config.text_b_name] if self.config.text_b_name else None)
for ex in inputs
]
encoded_input = self.tokenizer.batch_encode_plus(
segments,
return_tensors="tf",
add_special_tokens=True,
max_length=self.config.max_seq_length,
pad_to_max_length=True)
# Trim everything to the actual max length, to remove extra padding.
max_tokens = tf.reduce_max(
tf.reduce_sum(encoded_input["attention_mask"], axis=1))
encoded_input = {k: v[:, :max_tokens] for k, v in encoded_input.items()}
return encoded_input
def _make_dataset(self, inputs: Iterable[JsonDict]) -> tf.data.Dataset:
"""Make a tf.data.Dataset from inputs in LIT format."""
encoded_input = self._preprocess(inputs)
if self.is_regression:
labels = tf.constant([ex[self.config.label_name] for ex in inputs],
dtype=tf.float32)
else:
labels = tf.constant([
self.config.labels.index(ex[self.config.label_name]) for ex in inputs
],
dtype=tf.int64)
# encoded_input is actually a transformers.tokenization_utils.BatchEncoding
# object, which tf.data.Dataset doesn't like. Convert to a regular dict.
return tf.data.Dataset.from_tensor_slices((dict(encoded_input), labels))
def train(self,
train_inputs: List[JsonDict],
validation_inputs: List[JsonDict],
learning_rate=2e-5,
batch_size=32,
num_epochs=3,
keras_callbacks=None):
"""Run fine-tuning."""
train_dataset = self._make_dataset(train_inputs).shuffle(128).batch(
batch_size).repeat(-1)
# Use larger batch for validation since inference is about 1/2 memory usage
# of backprop.
eval_batch_size = 2 * batch_size
validation_dataset = self._make_dataset(validation_inputs).batch(
eval_batch_size)
# Prepare model for training.
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate, epsilon=1e-08)
if self.is_regression:
loss = tf.keras.losses.MeanSquaredError()
metric = tf.keras.metrics.RootMeanSquaredError("rmse")
else:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
self.model.compile(optimizer=opt, loss=loss, metrics=[metric])
steps_per_epoch = len(train_inputs) // batch_size
validation_steps = len(validation_inputs) // eval_batch_size
history = self.model.fit(
train_dataset,
epochs=num_epochs,
steps_per_epoch=steps_per_epoch,
validation_data=validation_dataset,
validation_steps=validation_steps,
callbacks=keras_callbacks,
verbose=2)
return history
def save(self, path: str):
"""Save model weights and tokenizer info.
To re-load, pass the path to the constructor instead of the name of a
base model.
Args:
path: directory to save to. Will write several files here.
"""
if not os.path.isdir(path):
os.mkdir(path)
self.tokenizer.save_pretrained(path)
self.model.save_pretrained(path)
def _segment_slicers(self, tokens: List[str]):
"""Slicers along the tokens dimension for each segment.
For tokens ['[CLS]', a0, a1, ..., '[SEP]', b0, b1, ..., '[SEP]'],
we want to get the slices [a0, a1, ...] and [b0, b1, ...]
Args:
tokens: <string>[num_tokens], including special tokens
Returns:
(slicer_a, slicer_b), slice objects
"""
try:
split_point = tokens.index(self.tokenizer.sep_token)
except ValueError:
split_point = len(tokens) - 1
slicer_a = slice(1, split_point) # start after [CLS]
slicer_b = slice(split_point + 1, len(tokens) - 1) # end before last [SEP]
return slicer_a, slicer_b
def _postprocess(self, output: Dict[str, np.ndarray]):
"""Per-example postprocessing, on NumPy output."""
ntok = output.pop("ntok")
output["tokens"] = self.tokenizer.convert_ids_to_tokens(
output.pop("input_ids")[:ntok])
# Tokens for each segment, individually.
slicer_a, slicer_b = self._segment_slicers(output["tokens"])
output["tokens_" + self.config.text_a_name] = output["tokens"][slicer_a]
if self.config.text_b_name:
output["tokens_" + self.config.text_b_name] = output["tokens"][slicer_b]
# Embeddings for each segment, individually.
output["input_embs_" + self.config.text_a_name] = (
output["input_embs"][slicer_a])
if self.config.text_b_name:
output["input_embs_" + self.config.text_b_name] = (
output["input_embs"][slicer_b])
# Gradients for each segment, individually.
if self.config.compute_grads:
output["token_grad_" +
self.config.text_a_name] = output["input_emb_grad"][slicer_a]
if self.config.text_b_name:
output["token_grad_" +
self.config.text_b_name] = output["input_emb_grad"][slicer_b]
if self.is_regression:
output["grad_class"] = None
else:
# Return the label corresponding to the class index used for gradients.
output["grad_class"] = self.config.labels[output["grad_class"]]
# Process attention.
for key in output:
if not re.match(r"layer_(\d+)/attention", key):
continue
# Select only real tokens, since most of this matrix is padding.
# <float32>[num_heads, max_seq_length, max_seq_length]
# -> <float32>[num_heads, num_tokens, num_tokens]
output[key] = output[key][:, :ntok, :ntok].transpose((0, 2, 1))
# Make a copy of this array to avoid memory leaks, since NumPy otherwise
# keeps a pointer around that prevents the source array from being GCed.
output[key] = output[key].copy() # pytype: disable=attribute-error
return output
def _scatter_embs(self, passed_input_embs, input_embs, batch_indices,
offsets):
"""Scatters custom passed embeddings into the default model embeddings.
Args:
passed_input_embs: <tf.float32>[num_scatter_tokens], the custom passed
embeddings to be scattered into the default model embeddings.
input_embs: the default model embeddings.
batch_indices: the indices of the embeddings to replace in the format
(batch_index, sequence_index).
offsets: the offset from which to scatter the custom embedding (number of
tokens from the start of the sequence).
Returns:
The default model embeddings with scattered custom embeddings.
"""
# <float32>[scatter_batch_size, num_tokens, emb_size]
filtered_embs = [emb for emb in passed_input_embs if emb is not None]
# Prepares update values that should be scattered in, i.e. one for each
# of the (scatter_batch_size * num_tokens) word embeddings.
# <np.float32>[scatter_batch_size * num_tokens, emb_size]
updates = np.concatenate(filtered_embs)
# Prepares indices in format (batch_index, sequence_index) for all
# values that should be scattered in, i.e. one for each of the
# (scatter_batch_size * num_tokens) word embeddings.
scatter_indices = []
for (batch_index, sentence_embs, offset) in zip(batch_indices,
filtered_embs, offsets):
for (token_index, emb) in enumerate(sentence_embs):
scatter_indices.append([batch_index, token_index + offset])
# Scatters passed word embeddings into embeddings gathered from tokens.
# <tf.float32>[batch_size, num_tokens + num_special_tokens, emb_size]
return tf.tensor_scatter_nd_update(input_embs, scatter_indices, updates)
def scatter_all_embeddings(self, inputs, input_embs):
"""Scatters custom passed embeddings for text segment inputs.
Args:
inputs: the model inputs, which contain any custom embeddings to scatter.
input_embs: the default model embeddings.
Returns:
The default model embeddings with scattered custom embeddings.
"""
# Gets batch indices of any word embeddings that were passed for text_a.
passed_input_embs_a = [ex.get("input_embs_" + self.config.text_a_name)
for ex in inputs]
batch_indices_a = [index for (index, emb) in enumerate(
passed_input_embs_a) if emb is not None]
# If word embeddings were passed in for text_a, scatter them into the
# embeddings, gathered from the input ids. 1 is passed in as the offset
# for each, since text_a starts at index 1, after the [CLS] token.
if batch_indices_a:
input_embs = self._scatter_embs(
passed_input_embs_a, input_embs, batch_indices_a,
offsets=np.ones(len(batch_indices_a), dtype=np.int64))
if self.config.text_b_name:
# Gets batch indices of any word embeddings that were passed for text_b.
passed_input_embs_b = [ex.get("input_embs_" + self.config.text_b_name)
for ex in inputs]
batch_indices_b = [
index for (index, emb) in enumerate(passed_input_embs_b)
if emb is not None]
# If word embeddings were also passed in for text_b, scatter them into the
# embeddings gathered from the input ids. The offsets are the [lengths
# of the corresponding text_a embeddings] + 2, since text_b starts after
# [CLS] [text_a tokens] [SEP]. (This assumes that text_b embeddings
# will only be passed together with text_a embeddings.)
if batch_indices_b:
lengths = np.array([len(embed) for embed in passed_input_embs_a
if embed is not None])
input_embs = self._scatter_embs(
passed_input_embs_b, input_embs, batch_indices_b,
offsets=(lengths + 2))
return input_embs
##
# LIT API implementation
def max_minibatch_size(self):
return self.config.inference_batch_size
def predict_minibatch(self, inputs: Iterable[JsonDict]):
# Use watch_accessed_variables to save memory by having the tape do nothing
# if we don't need gradients.
with tf.GradientTape(
watch_accessed_variables=self.config.compute_grads) as tape:
encoded_input = self._preprocess(inputs)
# Gathers word embeddings from BERT model embedding layer using input ids
# of the tokens.
input_ids = encoded_input["input_ids"]
word_embeddings = self.model.bert.embeddings.word_embeddings
# <tf.float32>[batch_size, num_tokens, emb_size]
input_embs = tf.gather(word_embeddings, input_ids)
# Scatter in any passed in embeddings.
# <tf.float32>[batch_size, num_tokens, emb_size]
input_embs = self.scatter_all_embeddings(inputs, input_embs)
tape.watch(input_embs) # Watch input_embs for gradient calculation.
model_inputs = encoded_input.copy()
model_inputs["input_ids"] = None
logits, embs, attentions = self.model(model_inputs,
inputs_embeds=input_embs,
training=False)
batched_outputs = {
"input_ids": encoded_input["input_ids"],
"ntok": tf.reduce_sum(encoded_input["attention_mask"], axis=1),
"cls_emb": embs[-1][:, 0], # last layer, first token
"input_embs": input_embs
}
assert len(attentions) == self.model.config.num_hidden_layers
for i, layer_attention in enumerate(attentions):
batched_outputs[f"layer_{i}/attention"] = layer_attention
if self.is_regression:
# <tf.float32>[batch_size]
batched_outputs["score"] = tf.squeeze(logits, axis=-1)
scalar_pred_for_gradients = batched_outputs["score"]
else:
# <tf.float32>[batch_size, num_labels]
batched_outputs["probas"] = tf.nn.softmax(logits, axis=-1)
# If a class for the gradients has been specified in the input,
# calculate gradients for that class. Otherwise, calculate gradients for
# the arg_max class.
arg_max = tf.math.argmax(batched_outputs["probas"], axis=-1).numpy()
grad_classes = [ex.get("grad_class", arg_max[i]) for (i, ex) in
enumerate(inputs)]
# Convert the class names to indices if needed.
grad_classes = [self.config.labels.index(label)
if isinstance(label, str) else label
for label in grad_classes]
gather_indices = list(enumerate(grad_classes))
# <tf.float32>[batch_size]
scalar_pred_for_gradients = tf.gather_nd(batched_outputs["probas"],
gather_indices)
if self.config.compute_grads:
batched_outputs["grad_class"] = tf.convert_to_tensor(grad_classes)
# Request gradients after the tape is run.
# Note: embs[0] includes position and segment encodings, as well as subword
# embeddings.
if self.config.compute_grads:
# <tf.float32>[batch_size, num_tokens, emb_dim]
batched_outputs["input_emb_grad"] = tape.gradient(
scalar_pred_for_gradients, input_embs)
detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
# Sequence of dicts, one per example.
unbatched_outputs = utils.unbatch_preds(detached_outputs)
return map(self._postprocess, unbatched_outputs)
def input_spec(self) -> Spec:
ret = {}
ret[self.config.text_a_name] = lit_types.TextSegment()
if self.config.text_b_name:
ret[self.config.text_b_name] = lit_types.TextSegment()
if self.is_regression:
ret[self.config.label_name] = lit_types.RegressionScore(required=False)
else:
ret[self.config.label_name] = lit_types.CategoryLabel(
required=False, vocab=self.config.labels)
# The input_embs_ and grad_class fields are used for Integrated Gradients.
ret["input_embs_" + self.config.text_a_name] = lit_types.TokenEmbeddings(
align="tokens", required=False)
if self.config.text_b_name:
ret["input_embs_" + self.config.text_b_name] = lit_types.TokenEmbeddings(
align="tokens", required=False)
ret["grad_class"] = lit_types.CategoryLabel(required=False,
vocab=self.config.labels)
return ret
def output_spec(self) -> Spec:
ret = {"tokens": lit_types.Tokens()}
ret["tokens_" + self.config.text_a_name] = lit_types.Tokens()
if self.config.text_b_name:
ret["tokens_" + self.config.text_b_name] = lit_types.Tokens()
if self.is_regression:
ret["score"] = lit_types.RegressionScore(parent=self.config.label_name)
else:
ret["probas"] = lit_types.MulticlassPreds(
parent=self.config.label_name,
vocab=self.config.labels,
null_idx=self.config.null_label_idx)
ret["cls_emb"] = lit_types.Embeddings()
# The input_embs_ and grad_class fields are used for Integrated Gradients.
ret["input_embs_" + self.config.text_a_name] = lit_types.TokenEmbeddings(
align="tokens_" + self.config.text_a_name)
if self.config.text_b_name:
ret["input_embs_" + self.config.text_b_name] = lit_types.TokenEmbeddings(
align="tokens_" + self.config.text_b_name)
# Gradients, if requested.
if self.config.compute_grads:
ret["grad_class"] = lit_types.CategoryLabel(required=False,
vocab=self.config.labels)
ret["token_grad_" + self.config.text_a_name] = lit_types.TokenGradients(
align="tokens_" + self.config.text_a_name,
grad_for="input_embs_" + self.config.text_a_name,
grad_target="grad_class")
if self.config.text_b_name:
ret["token_grad_" + self.config.text_b_name] = lit_types.TokenGradients(
align="tokens_" + self.config.text_b_name,
grad_for="input_embs_" + self.config.text_b_name,
grad_target="grad_class")
# Attention heads, one field for each layer.
for i in range(self.model.config.num_hidden_layers):
ret[f"layer_{i}/attention"] = lit_types.AttentionHeads(
align=("tokens", "tokens"))
return ret
class SST2Model(GlueModel):
"""Classification model on SST-2."""
def __init__(self, *args, **kw):
super().__init__(
*args,
text_a_name="sentence",
text_b_name=None,
labels=["0", "1"],
null_label_idx=0,
**kw)
class MNLIModel(GlueModel):
"""Classification model on MultiNLI."""
def __init__(self, *args, **kw):
super().__init__(
*args,
text_a_name="premise",
text_b_name="hypothesis",
labels=["entailment", "neutral", "contradiction"],
**kw)
class STSBModel(GlueModel):
"""Regression model on STS-B."""
def __init__(self, *args, **kw):
super().__init__(
*args,
text_a_name="sentence1",
text_b_name="sentence2",
labels=None,
**kw)