
feat: add liger kernel with fused cross entropy loss #93

Open

anhuong wants to merge 14 commits into main from fused-cross-entropyloss-simplified

Conversation

@anhuong (Collaborator) commented Oct 16, 2024

No description provided.

@anhuong force-pushed the fused-cross-entropyloss-simplified branch from e7e7c3d to 05cdbe6 on October 16, 2024 23:30
Comment on lines 300 to 302:

```python
# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models?
# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
```
Collaborator Author:

Question: do we need this `add_start_docstrings_to_model_forward` header that exists in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1133 but is different for each model type?

Contributor:

I think we don't need it.
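
One hedged alternative, not part of this PR, that would sidestep the per-model docstring decorators: copy the metadata from the forward being replaced, so each model keeps its own docs without re-applying `add_start_docstrings_to_model_forward` per architecture. A minimal sketch:

```python
# Illustrative sketch only (not the PR's approach): reuse the original
# forward's docstring and metadata when patching in a replacement forward.
import functools


def patch_forward_keep_docs(model_cls, new_forward):
    # functools.wraps copies __doc__, __name__, __qualname__, etc. from the
    # forward being replaced, so per-model docstring decorators such as
    # add_start_docstrings_to_model_forward need not be copied by hand.
    model_cls.forward = functools.wraps(model_cls.forward)(new_forward)
    return model_cls
```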

Comment on lines 409 to 412:

```python
# TODO: differing line below in granite models compared to llama/mistral model type
# logits = logits / self.config.logits_scaling
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
```
Collaborator Author:

The forward methods are very similar for llama, granite, and mistral; however, the line above is specific to granite, see https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/granite/modeling_granite.py#L1102. How should we handle this difference?

Contributor:

The naive method is to copy and modify multiple forward functions for the fused kernel. The better solution is the other design.
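
For illustration, one way a single shared forward could absorb the granite-only difference, assuming (as in transformers v4.45.2) that only GraniteConfig defines `logits_scaling`; this is a sketch, not the PR's code:

```python
import torch


def scale_logits(logits: torch.Tensor, config) -> torch.Tensor:
    # GraniteConfig defines logits_scaling, while the llama/mistral configs do
    # not, so a shared lce_forward could branch on its presence instead of
    # copying the whole forward per architecture.
    scaling = getattr(config, "logits_scaling", None)
    return logits / scaling if scaling is not None else logits
```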

Comment on lines 440 to 444:

```python
# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward?
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def lce_forward_mixtral(
```
Collaborator Author:

Added a different forward method for mixtral since it differs from the others in lines 550-561 and the returned object is of a different type. Is adding a separate forward method the best way to handle this?
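
For context, the piece that forces the split is the MoE-specific output; a sketch assuming the field names of transformers' `MoeCausalLMOutputWithPast` (not the PR's code):

```python
from transformers.modeling_outputs import MoeCausalLMOutputWithPast


def build_moe_output(loss, aux_loss, logits, router_logits):
    # Mixtral folds the router load-balancing loss into the LM loss and
    # returns extra fields (aux_loss, router_logits) that the shared
    # llama/mistral lce_forward does not produce.
    if loss is not None and aux_loss is not None:
        loss = loss + aux_loss
    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        router_logits=router_logits,
    )
```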

```
@@ -40,6 +41,7 @@ def get_mp_rules(base_type: str):
    try:
        # Third Party
        from transformers.models.granite.modeling_granite import (  # pylint: disable=import-outside-toplevel
            GraniteForCausalLM,
```
Collaborator Author:

I'm guessing granite here refers to the GraniteForCausalLM models, and the granite models with llama architecture would be handled within llama.py, correct? I'm not sure if the liger kernels work for any of the granite models, do you know?

Contributor:

That's the problem with this solution: because we copy and paste the forward functions wholesale, if the forwards are different this doesn't work.

@anhuong (Collaborator Author) commented Oct 16, 2024

Will update with benchmark data as I gather it; currently I do see the new kernel being used:

```
***** FMS AccelerationFramework *****
INFO:framework.py:***** FMS AccelerationFramework *****
Active Plugin: FastKernelsAccelerationPlugin. Python package: fms_acceleration_foak. Version: 0.2.1.dev0.
INFO:framework.py:Active Plugin: FastKernelsAccelerationPlugin. Python package: fms_acceleration_foak. Version: 0.2.1.dev0.
***************** Module Forwards Patching *************
INFO:framework.py:***************** Module Forwards Patching *************
Rule: llama-fused-lce Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-fused-lce Module:                           Class: LlamaForCausalLM Num:  1
Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: input_layernorm           Class: LlamaRMSNorm    Num: 32
Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
INFO:framework.py:Rule: llama-rms       Module: model                     Class: LlamaRMSNorm    Num:  1
Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
INFO:framework.py:Rule: llama-rms       Module: post_attention_layernorm  Class: LlamaRMSNorm    Num: 32
Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
INFO:framework.py:Rule: llama-rope      Module:                           Class: LlamaForCausalLM Num:  1
***************** Accelerator Patching *************
INFO:framework.py:***************** Accelerator Patching *************
Currently training with a batch size of: 2
max_steps is given, it will override any value given in num_train_epochs
/home/tuning/.local/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:421: UserWarning: You passed `packing=True` to the SFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.
  warnings.warn(
***** Running training *****
  Num examples = 2,364
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 100
  Number of trainable parameters = 1,703,936
```

Running with the llama3-8b model, I compared no framework, foak-fast-kernels (fast_loss, fast_rsm_layernorm, fast_rope_embeddings), and foak-fast-kernels-liger (fused_linear_loss, fast_rsm_layernorm, fast_rope_embeddings). Memory decreased further with the fused cross entropy loss, but train runtime was the same as using fast kernels with the regular cross entropy loss; see the sketch after the findings for why the fused loss saves memory.

Findings:

| Framework | Num GPUs | per_device_batch | train_mem_gpu_peaked_delta | mem_nvidia_mem_reserved | mem_peak_torch_mem_alloc_in_bytes | train_runtime | train_samples_per_second | train_steps_per_second | train_tokens_per_second |
|---|---|---|---|---|---|---|---|---|---|
| No framework | 2 | 2 | 17GB | 31KB | 25GB | 253.9422 | 1.575 | 0.394 | 3225.931 |
| No framework | 4 | 1 | 9.5GB | 19.8KB | 13.5GB | 144.9639 | 2.759 | 0.69 | 2825.53 |
| Foak-fast-kernels | 2 | 2 | 12.8GB | 31.5KB | 20.8GB | 225.3918 | 1.775 | 0.444 | 3634.559 |
| Foak-fast-kernels | 4 | 1 | 7.4GB | 17.8KB | 11.5GB | 134.7904 | 2.968 | 0.742 | 3038.792 |
| Foak-fast-kernels-liger | 2 | 2 | 8.6GB | 19.6KB | 16.7GB | 225.7871 | 1.772 | 0.443 | 3628.197 |
| Foak-fast-kernels-liger | 4 | 1 | 5.6GB | 15.8KB | 9.6GB | 135.1398 | 2.96 | 0.74 | 3030.935 |
  • train_mem_gpu_peaked_delta
    • 25% reduction from base to foak-fast-kernels
    • 49% reduction from base to foak-fast-kernels-liger
  • mem_peak_torch_mem_alloc_in_bytes
    • 16.8% reduction from base to foak-fast-kernels
    • 33% reduction from base to foak-fast-kernels-liger
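
For intuition on why the fused linear cross entropy loss reduces memory (referenced above): the unfused path materializes the full (tokens × vocab) logits tensor before the loss. Below is a minimal PyTorch sketch of the chunking idea only; the actual Liger kernel also computes gradients during the forward pass so chunk logits are not retained for backward:

```python
import torch
import torch.nn.functional as F


def chunked_linear_ce(hidden: torch.Tensor,          # (num_tokens, hidden_dim)
                      lm_head_weight: torch.Tensor,  # (vocab_size, hidden_dim)
                      labels: torch.Tensor,          # (num_tokens,)
                      chunk_size: int = 1024) -> torch.Tensor:
    """Cross entropy over lm_head(hidden) without building the full logits tensor."""
    n_valid = (labels != -100).sum().clamp(min=1)
    loss = hidden.new_zeros(())
    for h, y in zip(hidden.split(chunk_size), labels.split(chunk_size)):
        logits = h @ lm_head_weight.T  # only a (chunk, vocab) slice exists at a time
        loss = loss + F.cross_entropy(logits, y, reduction="sum", ignore_index=-100)
    return loss / n_valid
```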


```yaml
# fast rms norm triton kernels
fast_rsm_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True

# fused linear cross entropy loss
fused_linear_loss: True
```
Contributor:

I think this is problematic because fast_loss and fused_linear_loss both switch the loss, so it can cause weird combinations.

Maybe it's better to convert fast_loss to accept both a boolean and a string: fast_loss: True means the legacy loss, and fast_loss: 'fused' means the fused loss.

@fabianlim (Contributor) commented Oct 16, 2024:

Or use fast_loss = True / False plus a fast_loss_impl that can be the default (non-fused) or "fused".

Collaborator Author:

But that's why we only allow the user to set one; if one is true, the other must be false. That prevents this issue, right?

Contributor:

I just feel that having True/False switches is already a bit annoying, as you mentioned before. Now having constraints between them might be even more confusing.
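
To make the proposal concrete, a hedged sketch of how a single tri-state fast_loss could be resolved; the value "fused_ce_liger" is the one adopted later in this thread, while the helper name and the "fast_ce" label are purely illustrative:

```python
from typing import Optional, Union


def resolve_loss_kernel(fast_loss: Union[bool, str, None]) -> Optional[str]:
    # One switch with three states avoids conflicting boolean flags such as
    # fast_loss=True combined with fused_linear_loss=True.
    if fast_loss is True:
        return "fast_ce"          # legacy fast cross entropy kernel (label assumed)
    if fast_loss in (False, None):
        return None               # leave the model's loss untouched
    if fast_loss == "fused_ce_liger":
        return "fused_ce_liger"   # liger fused linear cross entropy
    raise ValueError(f"unsupported fast_loss value: {fast_loss!r}")
```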

@fabianlim mentioned this pull request Oct 17, 2024

@fabianlim (Contributor):
@anhuong it is a bit unexpected that you are not seeing speedups, as it was reported previously for the same model (albeit different seq lengths and batch sizes) that there are significant speedups #76

Even though the bench parameters that you tried are different, in ours we actually pack the sequences to the full 2048, so we expect that this packing will end up with similar throughputs.

One difference I can see is that yours are multi-GPU runs; can we look at the issue using 1 GPU first?

@fabianlim (Contributor) commented Oct 31, 2024

@anhuong I was able to repro the liger benches by running them like this:

```bash
cd $HOME/Liger-Kernel/examples/huggingface
pip install -r requirements.txt
bash run_benchmarks.sh
```

and doing some hacking to enable/disable the fused cross entropy loss. I ran this with two GPUs.


Benches of FMS Accel

I have also tested this PR on FMS Accel on a set of different parameters, basically more aggressive batch sizes.


It is possible that the speedups only show up at batch sizes where the non-fused version OOMs.

benchmarks.csv

@anhuong (Collaborator Author) commented Oct 31, 2024

For reference, using the fms-acceleration benchmarks I tried:

  • I removed the memory probes, only set memory logging to nvidia, and tried higher batch sizes, but the benchmarking still shows improved memory with the same train_runtime as with the previous kernels. I had to run with the mistral model, as the llama model hit OOM at the larger batch sizes.
    • GPU=1, batch_size=16, train_runtime=1803 (none) --> 1573 (foak-fast-kernels) --> 1570 (foak-fast-kernels-liger)
    • GPU=2, batch_size=32, train_runtime=1792 --> 1563 --> 1560
    • GPU=4, batch_size=64, train_runtime=1863 --> 1632 --> 1629
    • GPU=4, batch_size=64, FT, train_runtime=2169 --> 1929 --> 1927
  • Llama3-8b, 4 GPUs, max_seq_length=512, dataset_num_samples=10000
    • Lora, batch_size=32 (per_device=8): 149 --> 131 --> 133
    • Lora, batch_size=64 (per_device=16): 132 --> 115 --> 116
    • Lora, batch_size=100 (per_device=25): 128 --> 111 --> 110
    • FT, batch_size=32: 191 --> 172 --> 179
    • FT, batch_size=64: 167 --> 149 --> 152
    • FT, batch_size=100: 158 --> 141 --> 158

As you can see, train_runtime improves going from no kernels to fast kernels, but going from fast kernels to the liger fused kernel shows minimal or no speedup, as Fabian showed as well.

As Fabian noted, I had to decrease the dataset size and max_seq_length in order to run the full benchmarks, which is what Aaron had done in the issue.

@fabianlim (Contributor):

@anhuong but isn't this a speedup for llama3?

GPU=4, batch_size=64, FT, train_runtime=80777-->78182-->72640

We expect more speedup for a larger vocabulary, in which case llama3 will see more benefit than mistral.
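
For rough intuition (my own numbers, not from the PR): the unfused path materializes a (tokens × vocab) logits tensor, so the potential saving grows with vocabulary size. Assuming a batch of 4 packed sequences of length 2048 and an fp32 upcast of the logits:

```python
def logits_gib(tokens: int, vocab: int, bytes_per_elem: int = 4) -> float:
    # rough size of the materialized logits tensor in GiB
    return tokens * vocab * bytes_per_elem / 2**30


print(logits_gib(tokens=4 * 2048, vocab=128_256))  # llama3 vocab: ~3.9 GiB
print(logits_gib(tokens=4 * 2048, vocab=32_000))   # mistral vocab: ~1.0 GiB
```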

@anhuong (Collaborator Author) commented Nov 1, 2024

Sorry, those were memory metrics; I wrote down the wrong info. I did see a small speedup in FT with batch_size=100 and was adding it above... never mind, it actually took longer :(

Signed-off-by: Anh Uong <[email protected]>
@fabianlim force-pushed the fused-cross-entropyloss-simplified branch from f1671b1 to d58960c on November 14, 2024 08:53
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim force-pushed the fused-cross-entropyloss-simplified branch from d58960c to 2c202ef on November 14, 2024 16:29
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim force-pushed the fused-cross-entropyloss-simplified branch 2 times, most recently from 2462613 to 0df953c on November 18, 2024 00:42
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
@fabianlim force-pushed the fused-cross-entropyloss-simplified branch from 0df953c to 1a69314 on November 18, 2024 01:12
@fabianlim (Contributor):

@anhuong I made quite a few changes to this PR:

  • refactor: I moved the liger code under the fused_ops folder, since it is more of that.
  • benches: I updated the scenarios-liger.yaml benches to better reflect what should be tested.
  • configs: I changed fast_loss to now accept True, False, or "fused_ce_liger".

New Benches (benchmark charts attached)

@fabianlim marked this pull request as ready for review on November 18, 2024 01:14