Adaptive Span

Adaptive Span is a novel self-attention mechanism that can learn its optimal attention span. This lets us significantly extend the maximum context size used in Transformers while keeping control over their memory footprint and computation time. It is trained with truncated backpropagation through time (truncated BPTT), as in Transformer-XL.
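Concretely, the Adaptive Attention Span paper masks each head's attention weights with a soft ramp controlled by a learnable span parameter z_i ∈ [0, S], one per head. The formulas below restate the paper's notation as we understand it. x is the distance between query position t and key position r, R is a fixed ramp-width hyperparameter, S is the maximum span (presumably what --attn-span sets), s_{tr} is the raw attention score, and an L1 penalty on the spans, weighted by λ and averaged over the M heads, is added to the language modeling loss. We assume, without claiming it is the exact mapping in the code, that --aux-loss-scaler plays the role of λ.

```latex
m_z(x) = \min\!\left[\max\!\left[\frac{1}{R}\,(R + z - x),\; 0\right],\; 1\right]

a_{tr} = \frac{m_z(t - r)\,\exp(s_{tr})}{\sum_{q=t-S}^{t-1} m_z(t - q)\,\exp(s_{tq})}

L = -\log P(w_1, \dots, w_T) + \frac{\lambda}{M} \sum_{i} z_i
```

Because the penalty grows with every z_i, a smaller λ lets heads keep longer spans, which is consistent with the advice in this README to lower --aux-loss-scaler for better performance.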

Adaptive Span was introduced in the paper Adaptive Attention Span in Transformers, which achieved state-of-the-art language modeling results at the time of publication.

We reproduce their results in fairseq while keeping most of the original implementation untouched. If any hyperparameter combination is unclear, you can also refer to their sweep file.

0. Setup

First, preprocess the Enwik8 dataset. We use the pre-tokenized version of the dataset from the Adaptive Span paper. Download the dataset, then run:

fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
    --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
    --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
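Before training, it can help to confirm that preprocessing produced every binarized split. The function below is a minimal sketch: check_data_bin is a hypothetical helper, not part of fairseq, and it assumes that fairseq-preprocess, run with --only-source and no --source-lang, writes dict.txt plus train/valid/test .bin and .idx files into --destdir.

```shell
# Hypothetical helper (not part of fairseq): verify that a fairseq-preprocess
# output directory contains the files expected for a source-only dataset.
# Assumption: with --only-source and no --source-lang, fairseq names the
# outputs dict.txt, <split>.bin and <split>.idx.
check_data_bin() {
  local dir="$1"
  local missing=0
  local f
  for f in dict.txt train.bin train.idx valid.bin valid.idx test.bin test.idx; do
    if [ ! -f "$dir/$f" ]; then
      echo "missing: $dir/$f" >&2
      missing=1
    fi
  done
  # Non-zero exit status if any expected file is absent.
  return "$missing"
}

# Example: check_data_bin ~/data/enwik8/data-bin && echo "data-bin looks complete"
```

The function only checks that the files exist; it does not validate their contents.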
1. Train an Adaptive Span model on Enwik8

We will train a 12-layer Adaptive Span model following the hyperparameters used in the original paper.

The following command assumes 4 GPUs, so that the total batch size is 64 sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/adaptive_span \
    ~/data/enwik8/data-bin/ \
    --fp16 --fp16-no-flatten-grads --max-update 600000 \
    --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
    --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
    --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
    --validate-interval-updates 1000 \
    --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
    --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
    --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07

This should reach about 1.05 bpc on validation and 1.03 bpc on test, roughly a 0.03 bpc improvement over the Transformer-XL baseline here. Lowering --aux-loss-scaler lets the model learn longer spans and can further improve performance. If training on a single GPU, set --update-freq=4 to accumulate gradients over 4 batches and simulate training on 4 GPUs.

You can also reproduce the Transformer-XL result on enwik8 with this code base; it should reach about 1.06 bpc on test, matching the original paper. To try it, run:

CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    ~/data/enwik8/data-bin/ \
    --task truncated_bptt_lm --fp16 --max-update 400000 \
    --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
    --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
    --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 \
    --lr 0.00025 --batch-size 15 \
    --update-freq 1 --seed 2 --log-format json --log-interval 25
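In these commands, the number of sequences per optimizer update is the number of GPUs times --batch-size times --update-freq. This is why a single GPU with --update-freq 4 matches the 64-sequence batch of the 4-GPU Adaptive Span command. The sketch below just does that arithmetic. effective_batch is a hypothetical helper, and it assumes every sequence holds exactly --tokens-per-sample tokens.

```shell
# Hypothetical helper: sequences and tokens processed per optimizer update.
# Usage: effective_batch NUM_GPUS BATCH_SIZE UPDATE_FREQ TOKENS_PER_SAMPLE
# Assumes every sequence is exactly TOKENS_PER_SAMPLE tokens long.
effective_batch() {
  local gpus="$1" bsz="$2" freq="$3" tps="$4"
  local seqs=$(( gpus * bsz * freq ))
  echo "$seqs sequences, $(( seqs * tps )) tokens per update"
}

effective_batch 4 16 1 512   # Adaptive Span on 4 GPUs
effective_batch 1 16 4 512   # Adaptive Span on 1 GPU with --update-freq 4
effective_batch 4 15 1 512   # Transformer-XL command on 4 GPUs
```

The first two calls both print "64 sequences, 32768 tokens per update". The third prints "60 sequences, 30720 tokens per update".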
2. Evaluate

For Adaptive Span:

fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/adaptive_span \
    --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test

For Transformer-XL evaluation:

fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
    --tokens-per-sample 80 \
    --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
    --gen-subset valid

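Note: During training the model saw 512 tokens of context (--tokens-per-sample=512). The Transformer-XL evaluation command above instead uses 80-token segments (--tokens-per-sample 80) with mem_len 2100, clamp_len 820, same_length, and batch size 8. These evaluation settings match those of the original paper.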
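The results above are quoted in bits per character (bpc). Assuming each enwik8 token is a single character (byte), the per-token cross-entropy in base 2 equals the bpc, and a perplexity can be converted with bpc = log2(perplexity). ppl_to_bpc is a hypothetical helper that does this conversion with awk. It assumes the perplexity you pass in is computed per token.

```shell
# Hypothetical helper: convert a per-token perplexity into bits per token.
# Assumption: one token = one character, so bits per token = bits per character.
ppl_to_bpc() {
  # awk's log() is the natural log, so divide by log(2) to get log base 2.
  awk -v p="$1" 'BEGIN { printf "%.4f\n", log(p) / log(2) }'
}

ppl_to_bpc 2      # prints 1.0000
ppl_to_bpc 2.04   # a perplexity of ~2.04 corresponds to ~1.03 bpc
```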