Skip to content

Commit

Permalink
Add MLP & QLoRA Fused Ops and Kernels, Mixtral (#29)
Browse files Browse the repository at this point in the history
* refactor

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fixes

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* refactor mistral

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add mixtral

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* some refactoring after introducing mlp

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* remove extranous files

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add bnb

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* lint + fmt and improvements to readme

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* bench fixes

* need to handle lora adapters device due to #26

* allow replay of failed benches, addressing comment in #14

* update benches (remove l40)

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim authored Jun 2, 2024
1 parent b2b8fe6 commit 8103238
Show file tree
Hide file tree
Showing 23 changed files with 626 additions and 326 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status
--|--|--|--|--
[framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton Kernels and acclerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon

## Usage with FMS HF Tuning
Expand Down Expand Up @@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark

See below CSV files for various results:
- [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv).
- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv).

### Code Architecture

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
Expand Down Expand Up @@ -123,6 +124,7 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module


# consider to move this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
Expand All @@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call(
):
# patch old_forward to view attribtues to torch_dype
# before call

if submodule_names is None:
submodule_names = ''
submodule_names = ""
if isinstance(submodule_names, str):
submodule_names = [submodule_names]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM,
BaseQuantizeConfig,
)
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

# Local
from .autogptq_utils import ( #pylint: disable=import-outside-toplevel
patch_forward_to_view_attributes_before_call,
PATCH_FOR_FSDP_TRITON_V2
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
patch_forward_to_view_attributes_before_call,
)

# Currently we allow only a quantized checkpoint to be loaded, we do not
Expand Down Expand Up @@ -214,8 +219,14 @@ def augmentation(
):
# guarded imports
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)

# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
Expand Down
36 changes: 11 additions & 25 deletions plugins/fused-ops-and-kernels/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:


1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth).
1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth).
- Low-Rank Adapter Fused Operations
- Fast RoPE Triton Kernels
- Fast RMS LayerNorm Triton Kernels
Expand All @@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅

### Code Extracted from Unsloth

<!--
NOTE: the
- fused_ops/unsloth_lora -> unsloth main
* utils (fast_dequant, fast_gemv, fast_linear_forward, matmul_lora)
* geglu, swiglu (this can be reused across other models, but currently used inside MLP fused ops only)
* bnb (fast_lora.py)
* gtqp (fast_lora, triton) -> jeromeku
- kernels
* cross_ent, rms, rope -> unsloth main
-->

Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
```
it would require a commercial license if used to run on more than 4 GPUs, see
https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143
it would require a commercial license if used to run on more than 4 GPUs ...
```
- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc).
* These model files are **not extracted**.
- All code extracted here before the Feb 2024 Release, see table below.
- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
- These exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
- Instead in its place, we have adopted a different approach; we adopt the approach of model patching, as opposed unsloths' approach to rewrite the model. Our approach is novel and **completely rewritten from scratch**.
- All extracted code appears before the Feb 2024 Release.
- In the table below we record what was extracted, and the exact commit from which it was taken.

Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024

<!--
[models/](./src/fms_accelerate_unsloth/models/) | Model Forwards | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc)<br><br>`tohrnii/mixtral` @ [a55b7400](https://github.com/tohrnii/unsloth/commit/a55b740062b4fc8ce8f5196bfabe3cf860020ca7) | `llama.py`<br>`mistral.py`<br>`mixtral.py`| `llama.py`<br>`mistral.py`<br>`mixtral.py` | 6 Feb 2024<br><br> 22 Feb 2024
-->

[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024

## Known Issues

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Callable, Dict, Tuple

# Third Party
from accelerate.utils import set_module_tensor_to_device
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
Expand Down Expand Up @@ -63,9 +64,20 @@ def _all_reduce_hook(grad):
return grad

for mod in modules:
# NOTE: assuming lora has no bias
A = mod.lora_A.default
B = mod.lora_B.default

# install hooks on the adapters
mod.lora_A.default.weight.register_hook(_all_reduce_hook)
mod.lora_B.default.weight.register_hook(_all_reduce_hook)
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)

# because we will ignore these from FSDP, we need to manually
# move them to gpu if they are already not on them
if not A.weight.is_cuda:
set_module_tensor_to_device(A, "weight", "cuda")
if not B.weight.is_cuda:
set_module_tensor_to_device(B, "weight", "cuda")


class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
Expand All @@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]):

self._base_layer = self._check_config_and_maybe_check_values(
key="peft.quantization.fused_ops_and_kernels.base_layer",
values=[
"auto_gptq",
# "bitsandbytes" # enable later when we have BNB implemented
],
values=["auto_gptq", "bitsandbytes"],
)

# only support these at the moment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,10 @@ def apply_lora_o(self, X):
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass

# added by [email protected]
# this will be patchable on the actual module
def apply_lora_o_v2(self, X):
OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
Original file line number Diff line number Diff line change
Expand Up @@ -735,3 +735,10 @@ def apply_lora_o(self, X):
Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
return O

# added by [email protected]
# this version can be directly patched on the output linear
def apply_lora_o_v2(self, X):
Oqstate, OA, OB, OS = get_lora_parameters(self)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
return O
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,9 @@ def backward(ctx, dY):
pass
pass


def fast_rope_embedding(Q, K, cos, sin):
# modified by [email protected]
# NOTE: fast_rope embeddings currently does not account for position ids
def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
return Q, K
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Local
from .model_patcher import ModelPatcher

PATCHES = [".models.llama", ".models.mistral"]
PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
PLUGIN_PREFIX = "fms_acceleration_foak"

# TODO: remove the need for the prefix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,24 @@
from functools import partial

# Third Party
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from .utils import build_lora_fused_ops, trigger_fused_ops
from .model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
Expand All @@ -42,18 +52,54 @@
ModelPatcher.register(
ModelPatcherRule(
rule_id="llama-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic="APPEND",
),
forward_builder_args=["base_type"],
)
)

ModelPatcher.register(
ModelPatcherRule(
rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
qkv_module_names=["q_proj", "k_proj", "v_proj"],
o_module_name="o_proj",
attn_cls=LlamaMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
qkv_module_names=["q_proj", "k_proj", "v_proj"],
o_module_name="o_proj",
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
),
forward_builder_args=["base_type"],
)
Expand Down
Loading

0 comments on commit 8103238

Please sign in to comment.