forked from PAIR-code/lit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxnli_demo.py
79 lines (61 loc) · 2.85 KB
/
xnli_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Lint as: python3
r"""Example demo for multilingual NLI on the XNLI eval set.
To run locally with our trained model:
python -m lit_nlp.examples.xnli_demo --port=5432
Then navigate to localhost:5432 to access the demo UI.
To train a model for this task, use tools/glue_trainer.py or your favorite
trainer script to fine-tune a multilingual encoder, such as
bert-base-multilingual-cased, on the MNLI task.

Note: the LIT UI can handle around 10k examples comfortably, depending on your
hardware. The monolingual (English) eval sets for MNLI are about 9.8k examples
each, while each language for XNLI is about 2.5k examples, so we recommend
using the --languages flag to load only the languages you're interested in
(and/or --max_examples to cap the size of each dataset).
"""
from absl import app
from absl import flags
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.examples.datasets import classification
from lit_nlp.examples.datasets import glue
from lit_nlp.examples.models import glue_models
import transformers # for path caching
# NOTE: additional flags defined in server_flags.py
FLAGS = flags.FLAGS

# XNLI language subsets to load; main() passes this list to XNLIData.
flags.DEFINE_list(
    "languages", ["en", "es", "hi", "zh"],
    "Languages to load from XNLI. Available languages: "
    "ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,zh,vi")

# Either a local directory or a .tar.gz URL/path; main() downloads and
# extracts archives before loading the model.
flags.DEFINE_string(
    "model_path",
    "https://storage.googleapis.com/what-if-tool-resources/lit-models/mbert_mnli.tar.gz",
    "Path to fine-tuned model files. Expects model to be in standard "
    "transformers format, e.g. as saved by model.save_pretrained() and "
    "tokenizer.save_pretrained().")

# None (the default) means "load everything". lower_bound=0 rejects negative
# values at flag-parsing time: otherwise a negative value would reach the
# `slice[:max_examples]` call in main(), where (assuming Python slice
# semantics) it would silently drop examples from the END of each dataset
# rather than truncating it.
flags.DEFINE_integer(
    "max_examples", None, "Maximum number of examples to load into LIT. "
    "Note: MNLI eval set is 10k examples, so will take a while to run and may "
    "be slow on older machines. Set --max_examples=200 for a quick start.",
    lower_bound=0)
def main(_):
  """Builds the NLI model and eval datasets, then starts the LIT server.

  Args:
    _: Positional command-line arguments passed in by absl's app.run();
      unused. All configuration comes from FLAGS (see also server_flags.py).
  """
  # Normally path is a directory; if it's an archive file, download and
  # extract to the transformers cache.
  # NOTE(review): transformers.file_utils.cached_path is not present in newer
  # transformers releases — confirm it exists in the pinned version.
  model_path = FLAGS.model_path
  if model_path.endswith(".tar.gz"):
    model_path = transformers.file_utils.cached_path(
        model_path, extract_compressed_file=True)

  models = {"nli": glue_models.MNLIModel(model_path, inference_batch_size=16)}
  # XNLI for the multilingual eval, plus both English MNLI dev splits.
  datasets = {
      "xnli": classification.XNLIData("validation", FLAGS.languages),
      "mnli_dev": glue.MNLIData("validation_matched"),
      "mnli_dev_mm": glue.MNLIData("validation_mismatched"),
  }

  # Truncate datasets only if --max_examples is set. With the default (None)
  # each dataset is served as loaded, without a redundant re-slice or a
  # misleading "truncated" log line. Iterate over a snapshot of the keys
  # because values are replaced inside the loop.
  for name in list(datasets):
    logging.info("Dataset: '%s' with %d examples", name, len(datasets[name]))
    if FLAGS.max_examples is None:
      continue
    datasets[name] = datasets[name].slice[:FLAGS.max_examples]
    logging.info(" truncated to %d examples", len(datasets[name]))

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  lit_demo.serve()
if __name__ == "__main__":
  # Hand control to absl: app.run parses the command-line flags (including
  # the ones defined in server_flags.py) and then calls main().
  app.run(main)