feat: add liger kernel with fused cross entropy loss #93
Conversation
Signed-off-by: 1000850000 user <[email protected]>
Signed-off-by: Anh Uong <[email protected]>
Force-pushed from e7e7c3d to 05cdbe6
# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models?
# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
Question: do we need this add_start_docstrings_to_model_forward header that exists in https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1133 but is different for each model type?
I think we don't need it.
# TODO: differing line below in granite models compared to llama/mistral model type
# logits = logits / self.config.logits_scaling
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
The forward methods are very similar for llama, granite, and mistral; however, the line above is Granite-specific, as you can see here: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/granite/modeling_granite.py#L1102 -- how should we handle this difference?
The naive method is to copy and modify multiple forward functions for the fused kernel. The better solution is the other design.
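One way to absorb the Granite-only scaling in a single shared forward would be to key off the config attribute rather than copying the whole function. A minimal sketch, assuming GraniteConfig's logits_scaling field; the getattr guard is an illustration, not necessarily what this PR does:

```python
# Sketch: compute the trimmed logits as in llama/mistral, then apply Granite's
# extra scaling only when the loaded config actually defines `logits_scaling`.
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

logits_scaling = getattr(self.config, "logits_scaling", None)
if logits_scaling:  # present only on Granite configs
    logits = logits / logits_scaling
```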
# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward?
@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
# Ignore copy
def lce_forward_mixtral(
I added a different forward method for mixtral since it differs from the others in lines 550-561, and the returned object is of a different type. Is adding a separate forward method the best way to handle this?
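For context on what these lce_forward variants share: the fused path replaces the usual lm_head projection plus CrossEntropyLoss with Liger's fused linear cross entropy, so the full logits tensor is never materialized. A simplified sketch, assuming the LigerFusedLinearCrossEntropyLoss module from liger-kernel (argument order should be checked against the installed version):

```python
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

def fused_causal_lm_loss(hidden_states, lm_head_weight, labels):
    """Compute the causal LM loss without materializing [tokens, vocab] logits."""
    # shift so that tokens < n predict token n, as in the standard HF causal LM forward
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # the kernel fuses the lm_head matmul with the cross entropy computation
    lce = LigerFusedLinearCrossEntropyLoss()
    return lce(
        lm_head_weight,
        shift_hidden.view(-1, shift_hidden.size(-1)),
        shift_labels.view(-1),
    )
```

Mixtral additionally returns a MoeCausalLMOutputWithPast and adds the router auxiliary loss on top of this, which is the difference driving the separate copy (or a branch in a single lce_forward).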
@@ -40,6 +41,7 @@ def get_mp_rules(base_type: str):
    try:
        # Third Party
        from transformers.models.granite.modeling_granite import (  # pylint: disable=import-outside-toplevel
            GraniteForCausalLM,
I'm guessing granite here refers to the GraniteForCausalLM models, and the granite models with llama architecture would be handled within llama.py, correct? I'm not sure if the liger kernels work for any of the granite models, do you know?
That's the problem with this solution: because we copy and paste the forward functions wholesale, if the forwards are different this doesn't work.
Will update with benchmark data as I gather it; currently I do see the new kernel being used:

***** FMS AccelerationFramework *****
Active Plugin: FastKernelsAccelerationPlugin. Python package: fms_acceleration_foak. Version: 0.2.1.dev0.
***************** Module Forwards Patching *************
Rule: llama-fused-lce Module: Class: LlamaForCausalLM Num: 1
Rule: llama-rms Module: input_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rms Module: model Class: LlamaRMSNorm Num: 1
Rule: llama-rms Module: post_attention_layernorm Class: LlamaRMSNorm Num: 32
Rule: llama-rope Module: Class: LlamaForCausalLM Num: 1
***************** Accelerator Patching *************
Currently training with a batch size of: 2
max_steps is given, it will override any value given in num_train_epochs
/home/tuning/.local/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:421: UserWarning: You passed `packing=True` to the SFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.
  warnings.warn(
***** Running training *****
  Num examples = 2,364
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 100
  Number of trainable parameters = 1,703,936

Running with the llama3-8b model, I ran with no framework, foak-fast-kernels (fast_loss, fast_rsm_layernorm, fast_rope_embeddings), and foak-fast-kernels-liger (fused_linear_loss, fast_rsm_layernorm, fast_rope_embeddings), and saw that memory decreased further by using the fused cross entropy loss, but train runtime was the same as using fast kernels with cross entropy loss. Findings:
# fast rms norm triton kernels
fast_rsm_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True

# fused linear cross entropy loss
fused_linear_loss: True
I think this is problematic because fast_loss and fused_linear_loss both switch the loss, so it allows weird combinations. Maybe it's better to convert fast_loss to accept both a boolean and a string: fast_loss: True means the legacy kernel, fast_loss: 'fused' means the fused one.
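A sketch of how the bool-or-string idea could be parsed on the plugin side (the helper name is illustrative, not the actual config parsing in this repo):

```python
def resolve_loss_impl(fast_loss):
    """Map the fast_loss config value to which loss patch to apply.

    Returns None (leave the loss untouched), "legacy" (current fast cross
    entropy kernel), or "fused" (liger fused linear cross entropy).
    """
    if fast_loss is True:
        return "legacy"
    if isinstance(fast_loss, str) and fast_loss.lower() == "fused":
        return "fused"
    if fast_loss in (False, None):
        return None
    raise ValueError(f"unsupported fast_loss value: {fast_loss!r}")
```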
Or use fast_loss = True / False plus a separate fast_loss_impl, where we can have the default (non-fused) and "fused".
But that's why we only allow the user to set one: if one is true, the other must be false. That prevents this issue, right?
I just feel that having True/False switches is already a bit annoying, as you mentioned before. Now having constraints between them might be even more confusing.
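For concreteness, the constraint being discussed amounts to a check along these lines (hypothetical helper, just to make the trade-off tangible):

```python
def validate_loss_flags(config: dict) -> None:
    """Reject configs that enable both loss switches, since each replaces the loss."""
    if config.get("fast_loss") and config.get("fused_linear_loss"):
        raise ValueError(
            "fast_loss and fused_linear_loss both replace the loss computation; "
            "enable at most one of them"
        )
```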
@anhuong it is a bit unexpected that you are not seeing speedups, as it was reported previously for the same model (albeit with different sequence lengths and batch sizes) that there are significant speedups (#76). Even though the bench parameters you tried are different, in ours we actually pack the sequences to the full 2048, so we expect that this packing will end up with similar throughputs. One difference I can see is that yours are multi-GPU runs; can we first look at the issue using 1 GPU?
@anhuong I was able to repro the liger benches by running them like this,
and doing some hacking to enable/disable the fused cross entropy loss. I ran this with two GPUs. Benches of FMS Accel: I have also tested this PR on FMS Accel with a set of different parameters, basically more aggressive batch sizes. It is possible that the speedups do not show until the non-fused version OOMs.
For reference, using the fms-acceleration benchmarks I tried:
As you can see, the train_runtime improves going from 0 kernels to fast kernels, but going from fast kernels to the liger fused kernel showed minimal or no speedup, as Fabian showed as well. As Fabian noted, I had to decrease the dataset size and max_seq_length in order to run full benchmarks, which is what Aaron had done in the issue.
@anhuong but isn't this a speedup for llama3?
We expect more speedup for a larger vocabulary, in which case llama3 will see more benefit than mistral.
Sorry, that was memory metrics; I wrote down the wrong info.
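A rough back-of-envelope for the vocabulary-size point above, using this run's total train batch size of 4 and packed sequence length of 2048, and assuming fp32 logits (the non-fused path upcasts them when computing the loss):

```python
def logits_bytes(batch, seq_len, vocab, bytes_per_elem=4):
    """Memory needed to materialize the full logits tensor that the fused kernel avoids."""
    return batch * seq_len * vocab * bytes_per_elem

llama3_vocab, mistral_vocab = 128_256, 32_000
print(logits_bytes(4, 2048, llama3_vocab) / 2**30)   # ~3.9 GiB of logits for llama3-8b
print(logits_bytes(4, 2048, mistral_vocab) / 2**30)  # ~1.0 GiB for mistral-7b
```

So the memory saved (and any speedup from skipping that allocation) scales with vocabulary size.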
@anhuong I made quite a few changes to this PR.
New Benches: