You'll need Python >= 3.6 to run the code in this repo.
Install the dependencies:
pip install --upgrade pip
pip install -r requirements.txt
If you're running on a shared machine and don't have the privileges to install Python packages globally, or if you just don't want to install these packages permanently, take a look at the "Virtual environments" section further down in the README.
To make sure pip is installing packages for the right Python version, run pip --version and check that the path it reports is for the right Python interpreter.
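A small Python check along these lines can confirm both requirements from inside the interpreter you plan to use. It is a sketch, not part of the repo. It compares sys.version_info against the 3.6 minimum stated above and runs pip as python -m pip, which ties pip to that exact interpreter.

```python
import subprocess
import sys

# Minimum version stated in this README.
MIN_VERSION = (3, 6)


def interpreter_ok(min_version=MIN_VERSION):
    """Return True if the running interpreter is at least min_version."""
    return sys.version_info[:2] >= tuple(min_version)


def pip_version_for_this_python():
    """Run 'python -m pip --version' with this interpreter and return its output.

    The path in the output should point into this interpreter's site-packages.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "--version"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,  # text output; also works on Python 3.6
    )
    return result.stdout.strip()


if __name__ == "__main__":
    print("Interpreter:", sys.executable)
    print("Meets minimum version:", interpreter_ok())
    print(pip_version_for_this_python())
```

Running the install as python3 -m pip install -r requirements.txt is an equivalent way to make sure the packages land in the same interpreter you will use to run run.py.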
To train an ELECTRA-small model on the SNLI natural language inference dataset, you can run the following command:
python3 run.py --do_train --task nli --dataset snli --output_dir ./trained_model/
Checkpoints will be written to sub-folders of the trained_model output directory.
To evaluate the final trained model on the SNLI dev set, you can use:
python3 run.py --do_eval --task nli --dataset snli --model ./trained_model/ --output_dir ./eval_output/
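If you want to look at an intermediate checkpoint rather than the final model, a helper like the one below can locate the newest checkpoint folder. This is a sketch under two assumptions not stated in this README: that run.py trains with the Hugging Face Trainer, which names its checkpoint sub-folders checkpoint-<step>, and that --model accepts such a folder the same way it accepts ./trained_model/.

```python
import os
import re
import tempfile

# Assumed naming scheme (Hugging Face Trainer default): checkpoint-<global step>.
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def list_checkpoints(output_dir):
    """Return (step, path) pairs for checkpoint-<step> sub-folders, sorted by step."""
    found = []
    for name in os.listdir(output_dir):
        path = os.path.join(output_dir, name)
        match = CHECKPOINT_RE.match(name)
        if match and os.path.isdir(path):
            found.append((int(match.group(1)), path))
    # Sort numerically: a plain string sort would rank checkpoint-500 above checkpoint-1500.
    return sorted(found)


def latest_checkpoint(output_dir):
    """Return the path of the highest-step checkpoint, or None if there are none."""
    checkpoints = list_checkpoints(output_dir)
    return checkpoints[-1][1] if checkpoints else None


if __name__ == "__main__":
    # Demo on a throwaway directory that mimics a training output folder.
    with tempfile.TemporaryDirectory() as out:
        for step in (500, 1500, 1000):
            os.makedirs(os.path.join(out, "checkpoint-{}".format(step)))
        open(os.path.join(out, "config.json"), "w").close()
        print(latest_checkpoint(out))
```

Under those assumptions, the returned path would replace ./trained_model/ as the --model argument in the evaluation command.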
To prevent run.py from trying to use a GPU for training, pass the argument --no_cuda.
To train/evaluate a question answering model on SQuAD instead, change --task nli and --dataset snli to --task qa and --dataset squad.
Descriptions of other important arguments are available in the comments in run.py.
Data and models will be automatically downloaded and cached in ~/.cache/huggingface/.
To change the caching directory, you can modify the shell environment variable HF_HOME or TRANSFORMERS_CACHE.
For more details, see this doc.
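Setting the variable from Python also works, provided it happens before the Hugging Face libraries are imported, since, as we understand it, they read HF_HOME at import time. The sketch below uses a hypothetical scratch path, and its hf_home() helper only approximates the libraries' default lookup (HF_HOME, then $XDG_CACHE_HOME/huggingface, then ~/.cache/huggingface); check the documentation for your installed versions.

```python
import os

# Hypothetical cache location (e.g. a scratch disk); replace with a directory you can write to.
CACHE_DIR = os.path.expanduser("~/scratch/hf_cache")

# Set this before importing datasets or transformers.
os.environ["HF_HOME"] = CACHE_DIR


def hf_home():
    """Approximate the default cache-root lookup: HF_HOME if set,
    otherwise $XDG_CACHE_HOME/huggingface, falling back to ~/.cache/huggingface.
    """
    default = os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")
    return os.path.expanduser(os.getenv("HF_HOME", default))


if __name__ == "__main__":
    print("Hugging Face cache root:", hf_home())
    # Imports of datasets/transformers would go here, after HF_HOME is set.
```

Exporting HF_HOME in the shell before launching run.py achieves the same thing without touching any code.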
An ELECTRA-small-based NLI model trained on SNLI for 3 epochs (e.g. with the command above) should reach an accuracy of around 89%, depending on batch size. An ELECTRA-small-based QA model trained on SQuAD for 3 epochs should reach an exact match (EM) score of around 78 and an F1 score of around 86.
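This README does not say exactly how run.py computes the SQuAD scores. The sketch below follows the answer normalization and token-overlap F1 of the official SQuAD v1.1 evaluation script, which, to our knowledge, Hugging Face's squad metric also reproduces. Each prediction is scored against every reference answer, the best score is kept, and the averages are reported on a 0 to 100 scale, which is the scale of the 78 and 86 above.

```python
import collections
import re
import string


def normalize_answer(text):
    """SQuAD v1.1-style normalization: lowercase, drop punctuation and articles, collapse spaces."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(reference))


def f1(prediction, reference):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_scores(predictions, references):
    """Average the best score over each question's reference answers, scaled to 0-100."""
    em_total = f1_total = 0.0
    for pred, refs in zip(predictions, references):
        em_total += max(exact_match(pred, r) for r in refs)
        f1_total += max(f1(pred, r) for r in refs)
    n = len(predictions)
    return {"exact_match": 100.0 * em_total / n, "f1": 100.0 * f1_total / n}


if __name__ == "__main__":
    preds = ["the Eiffel Tower", "1889"]
    refs = [["Eiffel Tower in Paris"], ["1889", "in 1889"]]
    print(squad_scores(preds, refs))
```

In this toy example the first prediction gets no exact-match credit but a partial F1 of 2/3, so the run scores 50 EM and about 83.3 F1. This is why F1 is typically higher than EM.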
This repo uses Hugging Face Datasets to load data.
The Dataset objects loaded by this library can be filtered and updated easily using the Dataset.filter and Dataset.map methods.
For more information on working with datasets loaded as HF Dataset objects, see this page.