Showing 15 changed files with 2,454 additions and 1 deletion.
.dockerignore
@@ -0,0 +1,15 @@
# Files to ignore.
.git
.idea
.vscode
.ipynb_checkpoints
alphanli
!alphanli/dev.jsonl
!alphanli/dev-labels.lst
runs
ckpts
__pycache__
*.py[cod]
*.ipynb
*.sh
!run_model.sh
.gitignore
@@ -0,0 +1,11 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Jupyter Notebook
.ipynb_checkpoints

alphanli

.idea/
Dockerfile
@@ -0,0 +1,18 @@
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-runtime

WORKDIR /workspace/anli
ENV LC_ALL=C.UTF-8
ENV LANG=C.UTF-8

# Install dependencies.
COPY ./requirements.txt ./
RUN pip install --no-cache-dir -r requirements.txt

# Copy remaining code.
COPY . .
RUN chmod +x *.sh && \
    mkdir /results

# Run code.
#CMD ["python", "-u"]
CMD ["/bin/bash"]
README.md
@@ -1,3 +1,118 @@
# L2R2

This repository provides the authors' implementation of [L2R2: Leveraging Ranking for Abductive Reasoning](https://arxiv.org/abs/2005.11223) (SIGIR 2020).

## Usage

### Set up environment

L2R2 is tested on Python 3.6 and PyTorch 1.0.1.

```shell script
$ pip install -r requirements.txt
```

### Prepare data

Download the [αNLI](https://leaderboard.allenai.org/anli/submissions/get-started) dataset:

```shell script
$ wget https://storage.googleapis.com/ai2-mosaic/public/alphanli/alphanli-train-dev.zip
$ unzip -d alphanli alphanli-train-dev.zip
```
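
For orientation, here is a minimal sketch of how the unpacked dev split can be inspected; the field names (`obs1`, `obs2`, `hyp1`, `hyp2`) follow the public αNLI release:

```python
import json

# Each line of dev.jsonl is one instance with two observations and two
# candidate hypotheses; dev-labels.lst holds the index (1 or 2) of the
# more plausible hypothesis for the corresponding line.
with open("alphanli/dev.jsonl") as f:
    instances = [json.loads(line) for line in f]
with open("alphanli/dev-labels.lst") as f:
    labels = [int(line.strip()) for line in f]

ex, label = instances[0], labels[0]
print(ex["obs1"], ex["obs2"])         # beginning / ending observations
print(ex["hyp1"], ex["hyp2"], label)  # candidate hypotheses and gold index
```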

### Training

We train the L2R2 models on 4 K80 GPUs. The largest batch size that fits on one K80 is 1, so the effective batch size in our experiments is 4.
The available `criterion` options for optimization are (a sketch of the `list_net` option follows the list):
- list_net: list-wise *KLD* loss used in ListNet
- list_mle: list-wise *Likelihood* loss used in ListMLE
- approx_ndcg: list-wise *ApproxNDCG* loss used in ApproxNDCG
- rank_net: pair-wise *Logistic* loss used in RankNet
- hinge: pair-wise *Hinge* loss used in Ranking SVM
- lambda: pair-wise *LambdaRank* loss used in LambdaRank
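
As an illustration, the ListNet (top-one) loss takes the cross entropy between a softmax over the relevance labels and a softmax over the predicted scores, which equals the KL divergence up to a constant. The sketch below is a generic PyTorch rendering under our own naming, not the repository's exact implementation:

```python
import torch
import torch.nn.functional as F

def list_net_loss(scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """ListNet top-one loss over a list of candidates.

    scores: (batch, list_size) predicted relevance scores
    labels: (batch, list_size) graded relevance labels
    """
    true_dist = F.softmax(labels, dim=-1)     # target top-one distribution
    log_pred = F.log_softmax(scores, dim=-1)  # predicted log-distribution
    return -(true_dist * log_pred).sum(dim=-1).mean()

# Example: rank 22 candidate hypotheses per instance, as in --max_hyp_num=22.
scores = torch.randn(4, 22, requires_grad=True)
labels = torch.rand(4, 22)
loss = list_net_loss(scores, labels)
loss.backward()
```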

Note that in our experiments, we reduce the learning rate manually rather than using an automatic learning rate scheduler.

For example, we first fine-tune the pre-trained RoBERTa-large model for up to 10 epochs with a learning rate of 5e-6 and save the model checkpoint that performs best on the dev set.
```shell script
$ python run.py \
    --data_dir=alphanli/ \
    --output_dir=ckpts/ \
    --model_type='roberta' \
    --model_name_or_path='roberta-large' \
    --linear_dropout_prob=0.6 \
    --max_hyp_num=22 \
    --tt_max_hyp_num=22 \
    --max_seq_len=72 \
    --do_train \
    --do_eval \
    --criterion='list_net' \
    --per_gpu_train_batch_size=1 \
    --per_gpu_eval_batch_size=1 \
    --learning_rate=5e-6 \
    --weight_decay=0.0 \
    --num_train_epochs=10 \
    --seed=42 \
    --log_period=50 \
    --eval_period=100 \
    --overwrite_output_dir
```

Then, we continue fine-tuning the saved checkpoint for up to 3 epochs with a smaller learning rate, such as 3e-6, 1e-6, or 5e-7, until performance on the dev set no longer improves.
```shell script
$ python run.py \
    --data_dir=alphanli/ \
    --output_dir=ckpts/ \
    --model_type='roberta' \
    --model_name_or_path=ckpts/H22_L72_E3_B4_LR5e-06_WD0.0_MMddhhmmss/checkpoint-best_acc/ \
    --linear_dropout_prob=0.6 \
    --max_hyp_num=22 \
    --tt_max_hyp_num=22 \
    --max_seq_len=72 \
    --do_train \
    --do_eval \
    --criterion='list_net' \
    --per_gpu_train_batch_size=1 \
    --per_gpu_eval_batch_size=1 \
    --learning_rate=1e-6 \
    --weight_decay=0.0 \
    --num_train_epochs=3 \
    --seed=43 \
    --log_period=50 \
    --eval_period=100 \
    --overwrite_output_dir
```
Note: change the seed to reshuffle the training samples.

### Evaluation

Evaluate the performance on the dev set:
```shell script
$ export MODEL_DIR="ckpts/H22_L72_E3_B4_LR5e-07_WD0.0_MMddhhmmss/checkpoint-best_acc/"
$ python run.py \
    --data_dir=alphanli/ \
    --output_dir=$MODEL_DIR \
    --model_type='roberta' \
    --model_name_or_path=$MODEL_DIR \
    --max_hyp_num=2 \
    --max_seq_len=72 \
    --do_eval \
    --per_gpu_eval_batch_size=1
```
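
For intuition, with `--max_hyp_num=2` the dev metric reduces to binary-choice accuracy: the prediction is the higher-scored of the two candidate hypotheses. The sketch below uses our own hypothetical names; the actual output format of `run.py` may differ:

```python
import numpy as np

# scores: (num_instances, 2) model scores for hyp1/hyp2 (hypothetical array);
# labels: gold indices in {1, 2}, as read from alphanli/dev-labels.lst.
scores = np.array([[0.9, 0.1], [0.2, 0.7]])
labels = np.array([1, 2])

pred = scores.argmax(axis=1) + 1  # higher-scored hypothesis, 1-indexed
accuracy = (pred == labels).mean()
print(f"dev accuracy: {accuracy:.4f}")
```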

### Inference

```shell script
$ ./run_model.sh
```

## Citation

```
@article{zhu2020l2r2,
  title={L2R2: Leveraging Ranking for Abductive Reasoning},
  author={Zhu, Yunchang and Pang, Liang and Lan, Yanyan and Cheng, Xueqi},
  journal={arXiv preprint arXiv:2005.11223},
  year={2020}
}
```