Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extracted Subset of AutoGPTQ library into Accelerated-Peft Plugin #48

Merged
merged 20 commits into from
Jul 15, 2024

Conversation

achew010
Copy link
Contributor

@achew010 achew010 commented Jul 3, 2024

Description

This PR addresses #38 and extracts a subset of GPTQModel, a refactored fork of AutoGPTQ into fms_acceleration_peft/src/gptqmodel to do away with the problematic installation of AutoGPTQ.

This is because

  • AutoGPTQ hasn't had a release since Mar'24 and newer additions such as triton quantization kernels can only be used with custom installation from the main branch.
  • Installation of the main branch is unnecessarily slow given that this plugin only uses the triton kernel and a bunch of model code while the bulk of the time is spent compiling the cuda kernels. This also allows for easier testing and building of the plugin without the need for dependencies like cudatoolkit.

Additions

  • new folder in src/gptqmodel containing extracted code
  • modified autogptq plugin with compatibility between using external AutoGPTQ library (if installed and available) or the local refactored subset gptqmodel
  • unit tests in tests/test_gptq_model.py to ensure the extracted subset maintains the same behaviour as the original
  • fixes to FOAK plugin for compatibility to local extracted gptq package
  • comparison tool against reference benchmarks

Issues:

  1. Comparing new benchmarks against our current reference scripts/benchmarks/ref, we noticed a non-zero lora dropout will incur some memory overhead that make experiments for large models run out of memory (elaborated in Quantized Peft Benchmark Experiments Run Out of Memory with Non-Zero Lora Dropout #50). The comparison tool will pick this difference in experiment result as an outlier but will also flag out the parameter change in the report.

  2. Temporary fix to FOAK dequantization triton kernel to only offset if using official AutoGPTQ package.

    • Without the fix, the dequantization produces wrong base outputs that affect the loss badly when FOAK plugin is used with the local AutoGPTQ package
    • The reason is FOAK runs a dequant function not compatible with the local package in acceleration_peft. FOAK Plugin currently maintains its own triton kernels e.g. dequantization (similar with official AutoGPTQ) for GPTQ fused ops. But offset was removed in the dequantization function for our local package (see here).
    • A more permanent fix would be for FOAK plugin to rely on the accelerated_peft plugin to manage the dequantization function to use (local autogptq package or official autogptq) rather than maintaining a similar set of functions itself. Introduce a Better Dequantization Fix on Triton Function for FOAK Plugin's GPTQ Fused Operations #52 has been created to follow up on this .

Benchmarks

There seems to be an improvement to throughput with the new library on FOAK. Comparing the throughput from our reference against the updated benches on Mistral-7B-GPTQ and Llama70B-GPTQ in the table below.

  1. We see similar throughput to previous reference throughput for accelerated-peft-autogptq plugin

  2. We see higher throughput on the FOAK rows

  • Mistral-7B-GPTQ

    • 3954 -> 4101 tokens (1 gpu, bs=4)
    • 3911 -> 4040 tokens (2 gpu, bs=4)
  • Llama2-70B-GPTQ

    • 500 -> 554 tokens (1 gpu, bs=4)
    • 496 -> 551 (2 gpu, bs=4)

Mistral-7B-GPTQ

Reference

model_name_or_path framework_config num_gpus batch_size tokens_per_second mem_alloc_in_GIB peak_mem_alloc_in_GIB
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq 1 4 3332 4.87 15.8
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq 2 4 3236 2.74 16.0
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq-foak 1 4 3954 4.87 13.6
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq-foak 2 4 3911 2.74 15.6

Updated

model_name_or_path framework_config num_gpus batch_size tokens_per_second mem_alloc_in_bytes peak_mem_alloc_in_GIB
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq 1 4 3404 4.87 15.9
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq 2 4 3293 2.79 16.7
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq-foak 1 4 3965 4.87 13.6
TheBloke/Mistral-7B-v0.1-GPTQ accelerated-peft-autogptq-foak 2 4 3944 2.79 16.1

Llama70B-GPTQ

Reference

model_name_or_path framework_config num_gpus batch_size tokens_per_second mem_alloc_in_GIB peak_mem_alloc_in_GIB
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq 1 4 450 36.2 65.8
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq 2 4 444 18.1 70.1
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq-foak 1 4 500 36.2 65.0
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq-foak 2 4 496 18.1 69.4

Updated

model_name_or_path framework_config num_gpus batch_size tokens_per_second mem_alloc_in_bytes peak_mem_alloc_in_GIB
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq 1 4 455 36.2 67.2
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq 2 4 446 18.3 71.7
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq-foak 1 4 499 36.2 66.1
TheBloke/Llama-2-70B-GPTQ accelerated-peft-autogptq-foak 2 4 495 18.1 70.5

Unit Tests

=================================================================================================================== test session starts ===================================================================================================================
platform linux -- Python 3.10.12, pytest-8.2.2, pluggy-1.5.0
rootdir: /data/aaron/experimental/fms-acceleration/plugins/accelerated-peft
configfile: pyproject.toml
collected 7 items                                                                                                                                                                                                                                         

tests/test_gptqmodel.py ..                                                                                                                                                                                                                          [ 28%]
tests/test_peft_plugins.py ..                                                                                                                                                                                                                       [ 57%]
tests/test_q4_triton.py ..                                                                                                                                                                                                                          [ 85%]
tests/test_triton.py .                                                                                                                                                                                                                              [100%]

==================================================================================================================== warnings summary =====================================================================================================================
.tox/py/lib/python3.10/site-packages/transformers/utils/hub.py:124
  /data/aaron/experimental/fms-acceleration/plugins/accelerated-peft/.tox/py/lib/python3.10/site-packages/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
    warnings.warn(

tests/test_gptqmodel.py::test_pre_quantized_model_outputs_match
tests/test_gptqmodel.py::test_quantizing_pretrained_model_outputs_match
tests/test_gptqmodel.py::test_quantizing_pretrained_model_outputs_match
tests/test_q4_triton.py::TestsQ4Triton::test_generation_desc_act_false
tests/test_q4_triton.py::TestsQ4Triton::test_generation_desc_act_true
tests/test_triton.py::TestTriton::test_triton_qlinear
  /data/aaron/experimental/fms-acceleration/plugins/accelerated-peft/.tox/py/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
    warnings.warn(

tests/test_gptqmodel.py::test_pre_quantized_model_outputs_match
  /data/aaron/experimental/fms-acceleration/plugins/accelerated-peft/.tox/py/lib/python3.10/site-packages/auto_gptq/utils/peft_utils.py:360: UserWarning: You can just ignore this warning if the peft type you use isn't in ['LORA', 'ADALORA'].
  LlamaGPTQForCausalLM supports injecting fused attention but not enables this time. If you are training adapters, you must also disable fused attention injection when loading quantized base model at inference time, otherwise adapters may not be added to base model properly. If you are loading adapters to do inference, you can reference to adapter's config file to check whether the adapters are trained using base model that not enable fused attention injection.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================== 7 passed, 8 warnings in 130.42s (0:02:10) ========================================================================================================

Comparison Tool

The tool compares the set of benchmark results against a previous reference. It generates a chart for every metric compared (e.g. train_loss, train_tokens_per_second, mem_alloc...) as well as a csv file of outliers that are significantly different from the reference.

Usage

python scripts/benchmarks/compare_with_reference.py \
--result_dir $BENCHMARK_RESULTS_DIR \
--reference_benchmark_filepath $REFERENCE_BENCHMARK_CSV_FILEPATH

Chart:

Generally we see the new benchmark results from the extracted gptq package (New axis) match closely with that of the previous benchmark using the official autogptq package (Ref axis).

Table:

In the table below, the values from the reference column refer to values seen in previous benchmarks and values from the new column refer to values seen in the current benchmark. Outliers will have significant difference between the 2 columns. The outliers seen below are reported outliers due to the OOM issue in #50.

Note: Any hyperparameter difference between the new bench results and the reference will be the rightmost columns appended at the back following reference and new.

outlier.csv

Copy link
Contributor

@fabianlim fabianlim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks quite good overall, but requesting first round of changes

plugins/accelerated-peft/pyproject.toml Outdated Show resolved Hide resolved
plugins/accelerated-peft/pyproject.toml Outdated Show resolved Hide resolved
plugins/accelerated-peft/requirements.txt Outdated Show resolved Hide resolved
plugins/accelerated-peft/src/gptqmodel/__init__.py Outdated Show resolved Hide resolved
plugins/accelerated-peft/tests/test_gptqmodel.py Outdated Show resolved Hide resolved
plugins/accelerated-peft/src/gptqmodel/utils/peft.py Outdated Show resolved Hide resolved
plugins/framework/tox.ini Outdated Show resolved Hide resolved
@achew010 achew010 marked this pull request as ready for review July 4, 2024 09:29
@fabianlim
Copy link
Contributor

fabianlim commented Jul 4, 2024

@achew010 this needs a formatting, and some bench results

Copy link
Contributor

@fabianlim fabianlim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

approved

@fabianlim fabianlim merged commit c3a069c into foundation-model-stack:main Jul 15, 2024
4 checks passed
@achew010 achew010 mentioned this pull request Jul 15, 2024
@achew010 achew010 deleted the extracted_autogptq branch July 26, 2024 04:02
fabianlim added a commit that referenced this pull request Nov 5, 2024
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim added a commit that referenced this pull request Nov 8, 2024
* remove skip on test now #48 is complete

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix fusedops test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix model patching in test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix test to tail on input grads

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix dropout in fused_lora

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fmt + lint

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants