
Fix FSDP when performing GPTQ-LoRA with Triton V2 #15

Merged
merged 2 commits into dev from autogptq-fsdp on May 21, 2024

Conversation

@fabianlim (Contributor) commented May 20, 2024

GPTQ-LoRA depends on the AutoGPTQ package, but there are issues that prevent the base GPTQ model from being wrapped by FSDP.

The issue stems from the QuantLinear class storing its quantized parameters (i.e. qweight, qzeros) as torch.int32 tensors, which results in:

  • FSDP skipping over these parameters, because FSDP only shards torch.nn.Parameter and these are plain torch.Tensor;
  • FSDP being unable to handle integer dtypes.

The fix is then to use torch.Tensor.view, which performs a C++-style reinterpret cast, in QuantLinear.forward before calling the QuantLinearFunction autograd function. We create the nn.Parameter in the same vein, by doing a qweight.view(torch_type) to force the parameter to be of torch_type (which will be a float type).
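
A minimal sketch of this view/reinterpret-cast idea is shown below. It is illustrative only; the helper names cast_quant_attrs_to_parameters and view_back_to_int32 are assumptions for this sketch, not the actual code in this PR.

```python
import torch

def cast_quant_attrs_to_parameters(quant_linear, torch_type=torch.float16):
    # Reinterpret the int32 buffers as a float dtype so FSDP will shard them.
    # .view(dtype) does not convert values; it reinterprets the raw bytes
    # (the last dimension changes length because int32 and float16 differ in size).
    for name in ("qweight", "qzeros"):
        tensor = getattr(quant_linear, name)
        setattr(
            quant_linear,
            name,
            torch.nn.Parameter(tensor.view(torch_type), requires_grad=False),
        )

def view_back_to_int32(quant_linear):
    # In QuantLinear.forward, view the parameters back to int32 before calling
    # the QuantLinearFunction autograd function, so the GPTQ kernels see the
    # original bit patterns.
    qweight = quant_linear.qweight.view(torch.int32)
    qzeros = quant_linear.qzeros.view(torch.int32)
    return qweight, qzeros
```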

Reproduce

To reproduce the fix, run the following command:

export CUDA_VISIBLE_DEVICES=0,1
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29501 \
-m tuning.sft_trainer --model_name_or_path TheBloke/Nous-Hermes-Llama2-70B-GPTQ --acceleration_framework_config_file /data/flim/fms-acceleration-oss/scripts/benchmarks/../../sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml --packing True --max_seq_len 4096 --learning_rate 2e-4 --fp16 True --torch_dtype float16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.0 --target_modules q_proj k_proj v_proj o_proj --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --training_data_path benchmark_outputs/data/cache.json --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_1/hf
Memory readings with (Y) and without (N) the fix, at each stage:

fix | prepare | forward | backward
--- | ------- | ------- | --------
N   | 35 B    | 48 B    | 36 B
Y   | 18 B    | 32 B    | 19 B

Losses and Throughputs AFTER FIX

{'loss': 1.0727, 'grad_norm': 0.09820556640625, 'learning_rate': 0.0002, 'epoch': 0.01}
{'loss': 0.9477, 'grad_norm': 0.19384765625, 'learning_rate': 0.00017777777777777779, 'epoch': 0.03}
{'loss': 0.9168, 'grad_norm': 0.07080078125, 'learning_rate': 0.00015555555555555556, 'epoch': 0.04}
{'loss': 0.9182, 'grad_norm': 0.0616455078125, 'learning_rate': 0.00013333333333333334, 'epoch': 0.06}
{'loss': 0.8815, 'grad_norm': 0.056671142578125, 'learning_rate': 0.00011111111111111112, 'epoch': 0.07}
{'loss': 0.9014, 'grad_norm': 0.058685302734375, 'learning_rate': 8.888888888888889e-05, 'epoch': 0.08}
{'loss': 0.8754, 'grad_norm': 0.06451416015625, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.1}
{'loss': 0.8863, 'grad_norm': 0.05322265625, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.11}
{'loss': 0.8929, 'grad_norm': 0.0596923828125, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.13}
{'loss': 0.8821, 'grad_norm': 0.057037353515625, 'learning_rate': 0.0, 'epoch': 0.14}
{'train_runtime': 1882.4921, 'train_samples_per_second': 0.212, 'train_steps_per_second': 0.053, 'train_tokens_per_second': 435.168, 'train_loss': 0.9174925422668457, 'epoch': 0.14}

TODO:

  • apply the casting only when FSDP is enabled (see the sketch after this list).
  • consider also making g_idx and scales parameters so they can be sharded. Update: the code is quite flexible now and it is easy to add more parameters.
  • low_cpu_mem_usage is not handled properly yet; the model currently loads the full weights into GPU memory unnecessarily before prepare, which should be avoided.
  • extend this fix to other QuantLinear implementations (e.g. Marlin); may not do.
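
As a rough illustration of the first TODO item, the patch could be gated on whether FSDP is actually in use. A minimal sketch, assuming the ACCELERATE_USE_FSDP environment variable that accelerate launch sets for FSDP runs, and reusing the hypothetical cast_quant_attrs_to_parameters helper from the sketch above:

```python
import os
import torch

def maybe_patch_for_fsdp(model, torch_type=torch.float16):
    # Only reinterpret the quantized buffers when running under FSDP;
    # single-GPU and DDP runs keep the original int32 buffers untouched.
    if os.environ.get("ACCELERATE_USE_FSDP", "false").lower() != "true":
        return model
    for module in model.modules():
        # AutoGPTQ's Triton V2 quantized linear layer
        if module.__class__.__name__ == "QuantLinear":
            cast_quant_attrs_to_parameters(module, torch_type)
    return model
```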

@fabianlim fabianlim marked this pull request as draft May 20, 2024 06:51
@wynterl wynterl closed this May 20, 2024
@wynterl wynterl reopened this May 20, 2024
@fabianlim fabianlim changed the base branch from gen-configs to dev May 20, 2024 08:36
@fabianlim fabianlim marked this pull request as ready for review May 20, 2024 14:36
@fabianlim fabianlim merged commit 2003a3e into dev May 21, 2024
2 checks passed
@fabianlim fabianlim deleted the autogptq-fsdp branch May 21, 2024 12:46
fabianlim added a commit that referenced this pull request May 27, 2024
…or GPTQ-LoRA (#20)

* Add GitHub Workflow for Linting , Formatting and Test. Activate Workflow for Framework (#7)

* add lint workflow

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add pylintrc, update .tox fix files

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* activate test and minor fix

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* lint benchmarks.py and add workflow to dev

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Improvements to Benchmark Scripts and Config Generation Workflow (#13)

* fix benches and add verify configs

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update readme and add workflow

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add packaging dep

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update torch dep in framework and run-benches

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* take host env in run-benches

* add display bench results script

* rename summary.csv to raw_summary.csv and update run_benchmarks.sh

* export environment variables in shell command

* dump out pip requirements for repro, and add default FHT_branch

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Added support for running official HF baseline FSDP-QLoRA benchmark (#16)

* new baseline scenario

* rename variables

* added warning when plugin allows SFTTrainer to handle PEFT on single device

* Fix FSDP when performing GPTQ-LoRA with Triton V2  (#15)

* wrap in parameters and torch view to correct dtype

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* refactor to apply patch only on FSDP and simplify

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* Provide Memory Benchmarking Feature to Benchmarking Code (#14)

* add gpu memory logging support

* made improvements to GPU reference and result collation

* Renamed memory logging argument to reflect its readings as reserved memory using nvidia-smi and changed aggregation function in result collation

* variable renames

* manual linting

* added memory logging functionality via HFTrainer

* added support to benchmark memory using HFTrainer and updated README with explanation of the 2 memory benchmarking options

* addressed changes requested in PR #14

* fix bug and simplify gpu logs aggregation logic

* fixes to calculation of HFTrainer Mem Logging values

* fix calculations

* more fixes

* fix to ignore including stage inside max calculation of alloc memory

* more comments and README updates

* added fix to keyerror due to empty output dict from OOM

* manual linting

* added benchmark results to refs

* remove unnecessary columns in results gathering

* made changes to results gathering

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Co-authored-by: achew010 <[email protected]>
@fabianlim (Contributor, Author) commented:
So, because of the casting, we are now facing this error in #25:


  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl

    return forward_call(*args, **kwargs)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 75, in _forward_q

    _fused_op(X)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 64, in _fused_op

    Q, K, V = fused_operation(attn, X)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 620, in apply_lora_qkv

    Q, K, V = LoRA_QKV.apply(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply

    return super().apply(*args, **kwargs)  # type: ignore[misc]

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd

    return fwd(*args, **kwargs)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 464, in forward

    QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py", line 137, in dequant248

    dequant_kernel_248[grid](

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run

    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>

    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench

    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench

    fn()

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call

    self.fn.run(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run

    self.cache[device][key] = compile(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 543, in compile

    next_module = compile_kernel(module)

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 435, in <lambda>

    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir

    raise CompilationError(fn.src, node, repr(e)) from e
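
One way this could be handled (a sketch only, not necessarily what #25 ends up doing): the fused-ops path would need to view the float-viewed parameters back to torch.int32 before handing them to dequant248, mirroring what the patched QuantLinear.forward does. The helper name below is an assumption for illustration.

```python
import torch

def as_int32_gptq_args(qweight, qzeros):
    # Undo the float reinterpretation applied for FSDP before calling the
    # Triton dequant kernel, which expects the original int32 bit patterns.
    if qweight.dtype != torch.int32:
        qweight = qweight.view(torch.int32)
    if qzeros.dtype != torch.int32:
        qzeros = qzeros.view(torch.int32)
    return qweight, qzeros
```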


Successfully merging this pull request may close these issues.

Memory Consumption for GPTQ-LoRA is higher than QLoRA in Distributed Finetuning