This repo contains Python scripts for extracting the attention values from existing BERT-based NLP models, visualizing the sparsity of the multi-head attention, and pruning/quantizing the attention to see how doing so affects model performance. The results are analyzed in our paper "On the Distribution, Sparsity, and Inference-time Quantization of Attention Values in Transformers" at ACL Findings 2021.
git clone --recursive https://github.com/StonyBrookNLP/spiqa.git
cd spiqa
conda env create -f environment.yml
conda activate nlp
cd transformers
pip install -e .
cd ..
The model will be downloaded automatically on the first run.
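A quick way to confirm that the editable transformers install is the one being picked up (a minimal sketch; the exact version string depends on the bundled fork):

```python
# Verify that Python resolves `transformers` to the locally installed editable copy.
import transformers
print(transformers.__version__, transformers.__file__)
```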
Feature | Size |
---|---|
#layers of multi-head attention | 12 |
#heads per layer | 12 |
Max token length | 320 |
Attention dim | 320×320 |
Context + question size | 600~700 |
Dataset | squad-dev-v1.1 |
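These numbers can be checked directly from the model outputs. The snippet below is a minimal sketch; `deepset/roberta-base-squad2` is only a stand-in checkpoint (the analyzer script downloads its own SQuAD v1.1 model), but it shows the 12 layers of attention tensors with 12 heads each and one seq_len × seq_len map per head:

```python
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

# Stand-in SQuAD-fine-tuned RoBERTa; the analyzer downloads its own checkpoint.
name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForQuestionAnswering.from_pretrained(name, output_attentions=True)

inputs = tokenizer("Who wrote Hamlet?", "Hamlet was written by Shakespeare.",
                   return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# One attention tensor per layer, each of shape (batch, num_heads, seq_len, seq_len).
print(len(outputs.attentions), outputs.attentions[0].shape)
```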
The file roberta_squad_analyzer.py
can be used to extract and analyze the attention on the SQuAD v1.1 dataset.
The usage is shown below:
usage: roberta_squad_analyzer.py [-h] [-at ATT_THRESHOLD] [-ht HS_THRESHOLD] [-d] [-e] [-m] [-s] [-qv] [-od] [-hs] [-sa SAMPLES] [-aq ATT_QUANT_BITS] [-hq HSTATE_QUANT_BITS]
roberta squad analyzer: analyze sparsity of the roberta on squad
optional arguments:
-h, --help show this help message and exit
-at ATT_THRESHOLD, --att_threshold ATT_THRESHOLD
set attention sparsity threshold
-ht HS_THRESHOLD, --hs_threshold HS_THRESHOLD
set hidden states sparsity threshold
-d, --distribution print histogram
-e, --evaluation evaluate model only without any plot
-m, --heatmap print heatmap
-s, --sparsity compute sparsity
-qv, --quant_visualize
quantize the attention
-od, --otf_distribution
print attention histogram without saving aggregated params
-hs, --hidden_states print hidden states histogram without saving aggregated params
-sa SAMPLES, --samples SAMPLES
number of samples for distribution
-aq ATT_QUANT_BITS, --att_quant_bits ATT_QUANT_BITS
base for attention quantization
-hq HSTATE_QUANT_BITS, --hstate_quant_bits HSTATE_QUANT_BITS
base for hidden states quantization
For example, to extract the attention values for 100 instances:
python roberta_squad_analyzer.py -e -sa 100
# results will be in ./params/attention_sampled.npy
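The saved array can then be inspected offline. A sketch, assuming `attention_sampled.npy` holds a float array of sampled attention probabilities (the exact layout is set by roberta_squad_analyzer.py):

```python
import numpy as np

# Assumes ./params/attention_sampled.npy is a float array of attention probabilities.
att = np.load("./params/attention_sampled.npy")

threshold = 0.001
print(f"{(att < threshold).mean():.1%} of sampled attention values are below {threshold}")
```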
To plot the distribution of all attention values in the dataset:
python roberta_squad_analyzer.py -od
# results will be in ./res_fig/at_hist_per_token_layer_N_head_M.png
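For a quick look at a single distribution without the full plotting pipeline, the sampled values from the extraction step above can be histogrammed on a log scale (a sketch; file name and layout as assumed above):

```python
import numpy as np
import matplotlib.pyplot as plt

att = np.load("./params/attention_sampled.npy").ravel()
att = att[att > 0]                      # log-scale bins need positive values

plt.hist(att, bins=np.logspace(-6, 0, 60))
plt.xscale("log")
plt.xlabel("attention value")
plt.ylabel("count")
plt.savefig("attention_histogram.png")
```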
To collect the pruned attention for 100 instances with a threshold of 0.001:
python roberta_squad_analyzer.py -e -sa 100 -at 0.001
To collect the pruned and quantized attention for 100 instances with a threshold of 0.001 and 3 quantization bits:
python roberta_squad_analyzer.py -e -sa 100 -at 0.001 -aq 3
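Conceptually, the `-at` threshold prunes (zeroes out) attention values below it before any quantization is applied. A toy numpy illustration (not the script's exact implementation):

```python
import numpy as np

# Toy attention row (softmax output, sums to 1).
att = np.array([0.0004, 0.0008, 0.02, 0.15, 0.8288])

threshold = 0.001                       # corresponds to -at 0.001
pruned = np.where(att < threshold, 0.0, att)
print(pruned)                           # values below the threshold become zero
```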
There are different ways to quantize the attention:
Function call | Quantization method | Assigned value |
---|---|---|
quantize_attention_linear_slinear(attention, bits) | uniform quantization, linear scale, w/o pruning | the closest bin edge to the original |
quantize_attention_linear_slinear_midval(att, bits) | uniform quantization, linear scale, w/o pruning | mid-point of the bin edges |
quantize_attention_linear_slinear_clamped(att, bits) | uniform quantization, linear scale, w/ pruning | the closest bin edge to the original |
quantize_attention_linear_slinear_clamped_midval(att, bits) | uniform quantization, linear scale, w/ pruning | mid-point of the bin edges |
quantize_attention_linear_slog(att, bits) | uniform quantization, log scale, w/o pruning | the closest bin edge to the original |
quantize_attention_linear_slog_midval(att, bits) | uniform quantization, log scale, w/o pruning | mid-point of the bin edges |
quantize_attention_linear_slog_clamped_midval(att, bits) | uniform quantization, log scale, w/ pruning | mid-point of the bin edges |
quantize_attention_uniform_slinear_clamped_mean(att, bits) | uniform quantization, linear scale, w/ pruning | average of the original in each bin |
quantize_attention_uniform_slog_clamped_mean(att, bits) | uniform quantization, log scale, w/ pruning | average of the original in each bin |
quantize_attention_uniform_slog_mean(att, bits) | uniform quantization, log scale, w/o pruning | average of the original in each bin |
quantize_attention_binarization(att, bits=1) | binarization | 0 or 1.0/(number of values > threshold) |
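As a rough illustration of what these methods do (not the repo's exact code), the sketch below prunes values below a threshold, quantizes the survivors into 2**bits bins that are uniform on a log scale, and assigns each value the mid-point of its bin, in the spirit of quantize_attention_linear_slog_clamped_midval:

```python
import numpy as np

def log_quantize_midval(att, bits, threshold=1e-3):
    """Prune below `threshold`, then quantize on a log scale into 2**bits bins,
    assigning each surviving value the mid-point of its bin.
    A sketch only; the repo's quantize_attention_* functions define the exact behavior."""
    att = np.asarray(att, dtype=float)
    pruned = np.where(att < threshold, 0.0, att)

    # Bin edges are uniform in log space between the threshold and 1.0.
    edges = np.logspace(np.log10(threshold), 0.0, num=2**bits + 1)
    mids = np.sqrt(edges[:-1] * edges[1:])          # geometric mid-points

    idx = np.clip(np.digitize(pruned, edges) - 1, 0, 2**bits - 1)
    return np.where(pruned > 0.0, mids[idx], 0.0)

att = np.array([0.0004, 0.003, 0.02, 0.15, 0.8266])
print(log_quantize_midval(att, bits=3))
```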
To switch quantization methods, replace the quantization function call with the desired one:
# replace quantize_attention_linear_slog_clamped_midval as needed
attention_probs = self.quantize_attention_linear_slog_clamped_midval(attention_probs, quantize)
To help reproduce our results on how quantization affects model accuracy, we provide our experiment data online. To reproduce the figures:
./download_quant_params.sh
python xmodel_comparator.py
- We observed high levels of inherent sparsity in the attention distributions, present widely across heads and layers.
- Most attention values can be pruned (i.e., set to zero) and the remaining non-zero values can be mapped to a small number of discrete levels (i.e., unique values) without any significant impact on accuracy. Approximately 80% of the values can be set to zero without significantly affecting accuracy on QA and sentiment analysis tasks.
- When we add quantization using log scaling, we find that a 3-bit discrete representation is sufficient to achieve accuracy within 1% of the original full-precision (floating-point) model.
@inproceedings{ji-etal-2021-on,
title = "{O}n the {D}istribution, {S}parsity, and
{I}nference-time {Q}uantization of {A}ttention {V}alues in {T}ransformers",
author = "Ji, Tianchu and
Jain, Shraddhan and
Ferdman, Michael and
Milder, Peter and
Schwartz, H. Andrew and
Balasubramanian, Niranjan",
booktitle = "Findings of the Association for Computational Linguistics: ACL 2021",
month = aug,
year = "2021",
address = "Online",
publisher = "Association for Computational Linguistics"
}