We reduce the size of GPT2 by compressing the MLP weights of each transformer layer (roughly two thirds of the model's non-embedding parameters) with Kronecker products, effectively down-sizing the 124M-parameter GPT2 into two main variants of 81M and 95M parameters.
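To make the idea concrete, here is a minimal sketch of a linear layer that stores its weight as a Kronecker product of two small factors instead of a dense matrix. The class name and factor shapes are illustrative assumptions, not the exact Krony-PT implementation (which also supports multiple factors).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class KroneckerLinear(nn.Module):
    """Linear layer whose dense weight W = A ⊗ B is stored as two small factors."""
    def __init__(self, a_shape=(64, 48), b_shape=(48, 16)):
        super().__init__()
        # The materialized weight has shape (64*48, 48*16) = (3072, 768), the size of a
        # GPT2 MLP up-projection, yet only 64*48 + 48*16 = 3840 parameters are stored
        # instead of 3072*768 ≈ 2.36M.
        self.A = nn.Parameter(0.02 * torch.randn(*a_shape))
        self.B = nn.Parameter(0.02 * torch.randn(*b_shape))
        self.bias = nn.Parameter(torch.zeros(a_shape[0] * b_shape[0]))

    def forward(self, x):
        W = torch.kron(self.A, self.B)     # (3072, 768), materialized on the fly
        return F.linear(x, W, self.bias)

layer = KroneckerLinear()
y = layer(torch.randn(2, 10, 768))         # (batch, seq, n_embd) -> (batch, seq, 4*n_embd)
print(y.shape)                             # torch.Size([2, 10, 3072])
```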
We train three classes of models: 67M (the smallest we can get), 81M (mid-size), and 95M (the largest). Below are some numbers; all values are perplexities (lower is better).

Krony-PT (81M) outperforms DistilGPT2 (82M) on all three benchmarks, especially on the Lambada dataset.
| # Params | Model | wikitext-103 | wikitext-2 | Lambada |
|---|---|---|---|---|
| 124M | GPT2 | 29.16 | 24.67 | 45.28 |
| 82M | DistilGPT2 | 44.53 | 36.48 | 76.00 |
| 81M | KronyPT-81M-1350 | 41.98 | 34.99 | - |
| 81M | KronyPT-81M-3950 | - | - | 64.92 |
Our 81M model performs on par with other Kronecker-based compression methods (TQCompressedGPT2 and KnGPT-2), while having roughly 38M fewer parameters, and even outperforms KnGPT-2 on Lambada.
| # Params | Model | wikitext-103 | wikitext-2 | Lambada |
|---|---|---|---|---|
| 81M | KronyPT-81M-1350 | 41.98 | 34.99 | - |
| 81M | KronyPT-81M-3950 | - | - | 64.92 |
| 119M(*) | TQCompressedGPT2 | 40.28 | 32.25 | 64.72 |
| 119M(*) | KnGPT-2 (Huawei) | 40.97 | 32.81 | 67.62 |
Here we compare different initialization strategies: Van Loan (VL) and a new pruning-based initialization (results for the pruning-based method are still to be added).
| Model | wikitext-103 | wikitext-2 | Lambada |
|---|---|---|---|
| 95M - VL | 41.80 | 35.50 | 61.34 |
| 95M - prune | - | - | - |
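The Van Loan initialization computes the nearest Kronecker product to a pretrained dense weight via a rank-1 SVD of a rearranged matrix. Below is a minimal sketch; the function name and factor shapes are illustrative assumptions, not the repository's API.

```python
import torch

def van_loan_init(W, m1, n1, m2, n2):
    """Best A (m1 x n1), B (m2 x n2) minimizing ||W - A ⊗ B||_F (Van Loan & Pitsianis)."""
    assert W.shape == (m1 * m2, n1 * n2)
    # Rearrange W so that each row of R is one flattened (m2 x n2) block of W.
    R = W.reshape(m1, m2, n1, n2).permute(0, 2, 1, 3).reshape(m1 * n1, m2 * n2)
    U, S, Vh = torch.linalg.svd(R, full_matrices=False)
    # A rank-1 truncation of R corresponds to the nearest single Kronecker product.
    A = (S[0].sqrt() * U[:, 0]).reshape(m1, n1)
    B = (S[0].sqrt() * Vh[0, :]).reshape(m2, n2)
    return A, B

# Sanity check on a weight that is exactly a Kronecker product:
A0, B0 = torch.randn(32, 32), torch.randn(96, 24)
W = torch.kron(A0, B0)                       # shape (3072, 768)
A, B = van_loan_init(W, 32, 32, 96, 24)
print((torch.kron(A, B) - W).abs().max())    # ~0, exact up to floating-point error
```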
- Clone the repository.
- Create the data: check `./data/owt/prepare.py`. We solely use OpenWebText (OWT) for training.
- Generate a valid Kronecker decomposition: use the script `kron_decompose.py` and specify the dimensions and the number of factors.

```
$ python kron_decompose.py dim_1 dim_2 n_factors
```

- Update your training configuration at `./config/train_gpt2.py` and change the training specifications to your needs (see the sample configuration sketch after this list).
- Train the model:

```
$ python train.py config/train_gpt2.py
```

- After training, you should have a checkpoint (say `./checks/my_checkpoint.pt`) ready for evaluation.
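For orientation, here is what `./config/train_gpt2.py` might look like. The option names and values below follow nanoGPT-style conventions and are assumptions; the actual file in the repository is authoritative.

```python
# Illustrative training configuration (assumed nanoGPT-style option names).
wandb_log = False

batch_size = 12
block_size = 1024
gradient_accumulation_steps = 40

learning_rate = 6e-4        # peak learning rate
max_iters = 600000
lr_decay_iters = 600000
min_lr = 6e-5
weight_decay = 1e-1

eval_interval = 1000
eval_iters = 200
```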
Assuming you have your Kronecker checkpoint stored at `./checks/my_checkpoint.pt`:

- Convert the KronyPT checkpoint to a GPT-like format:

```
$ python krony_to_gpt.py ./path/to/check.pt output_dir
```

This converts your KronyPT model to a GPT-like format stored at `output_dir`. The dimensions and the number of factors are inferred directly from the checkpoint, so there is no need to provide them.
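For intuition, here is a sketch of what such a conversion does: each pair of Kronecker factors is expanded back into a dense weight so that a standard GPT2 loader can read the state dict. The key-naming convention is an assumption, not the actual logic of `krony_to_gpt.py`, and with several factors the dense weight would be the sum of the individual Kronecker products.

```python
import torch

def expand_kronecker_checkpoint(state_dict):
    """Expand (A, B) factor pairs back into dense GPT2-style weights (assumed key names)."""
    out = {}
    for name, tensor in state_dict.items():
        if name.endswith(".kron_A"):                  # assumed naming convention
            base = name[: -len(".kron_A")]
            A, B = tensor, state_dict[base + ".kron_B"]
            out[base + ".weight"] = torch.kron(A, B)  # dense weight = A ⊗ B
        elif name.endswith(".kron_B"):
            continue                                  # already consumed with its partner
        else:
            out[name] = tensor                        # non-factorized tensors pass through
    return out
```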
- Test the perplexity for wikitext and lambada:

```
$ python perplexity_eval.py output_dir wiki103
$ python perplexity_eval.py output_dir   # to evaluate on all 3 datasets
```

You have 3 options: `wiki103`, `wiki2`, and `lambada`. Not specifying a dataset returns the perplexity for all three.
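For reference, here is a minimal sketch of how perplexity is typically computed for a converted (GPT-like) checkpoint. The use of the Hugging Face `GPT2LMHeadModel` API and the non-overlapping block strategy are assumptions, not necessarily what `perplexity_eval.py` does.

```python
import math
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

@torch.no_grad()
def perplexity(model_dir, text, block_size=1024, device="cpu"):
    """Perplexity of a causal LM over one long text, evaluated in non-overlapping blocks."""
    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained(model_dir).to(device).eval()
    ids = tok(text, return_tensors="pt").input_ids.to(device)

    total_nll, total_tokens = 0.0, 0
    for start in range(0, ids.size(1) - 1, block_size):
        chunk = ids[:, start : start + block_size]
        if chunk.size(1) < 2:
            break
        # labels == input_ids: the model shifts them internally for next-token loss
        out = model(chunk, labels=chunk)
        n = chunk.size(1) - 1                 # number of predicted tokens in this block
        total_nll += out.loss.item() * n
        total_tokens += n
    return math.exp(total_nll / total_tokens)

# Hypothetical usage: ppl = perplexity("output_dir", open("wiki103_test.txt").read())
```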
- Check the file `script/progress.md`.
- Add link to report.