Disable MLP Fused Ops if Not SwiGLU, Deprecated Fast Quantized Peft Plugin, Update Benchmarks (#106)

* disable MLP fused op for non-silu, and remove the qpeft plugin

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix the filter drops rule

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix all models

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* accurately set trl in bnb qpeft fix and file rename

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update bench

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim authored Nov 14, 2024
1 parent 5b35eae commit 9239802
Showing 13 changed files with 391 additions and 370 deletions.
@@ -235,7 +235,7 @@ def get_callbacks_and_ready_for_train(
# the meta device fix for quantized models is available since this transformers version
# or, if trl is installed, only since this trl version
if _transformers_version >= "4.45" and (
not _trl_installed or (_trl_installed and _trl_version >= "0.12")
not _trl_installed or (_trl_installed and _trl_version >= "0.11.4")
):
# guarded
# NOTE: replace this later with a more specific accelerate version check
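
For reference, the gate above requires transformers >= 4.45 and, when trl is installed, a minimum trl version as well (0.11.4 after this change). A minimal sketch of an equivalent check written with `packaging.version` — the helper name and the assumption that the versions arrive as plain strings are illustrative, not taken from this repository:

```python
# Hedged sketch only; the function name and string-typed inputs are assumptions.
from packaging.version import Version

def meta_device_fix_available(
    transformers_version: str, trl_installed: bool, trl_version: str = ""
) -> bool:
    # the meta device fix needs transformers >= 4.45 ...
    if Version(transformers_version) < Version("4.45"):
        return False
    # ... and, if trl is installed, additionally trl >= 0.11.4
    return (not trl_installed) or Version(trl_version) >= Version("0.11.4")
```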
1 change: 0 additions & 1 deletion plugins/fused-ops-and-kernels/README.md
@@ -13,7 +13,6 @@ This library contains fused operations and custom kernels, to be expanded over t

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE (**Disabled**) | Contains extracted code | | ✅
[fast_kernels](./src/fms_accelerate_foak/framework_plugin_fast_kernels.py) | Enhanced version of `fast_quantized_peft`, also works for full-FT and non-quant peft | Contains extracted code | | ✅

### Supported DataType Settings
@@ -14,4 +14,3 @@

# Local
from .framework_plugin_fast_kernels import FastKernelsAccelerationPlugin
from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
@@ -19,16 +19,21 @@
from fms_acceleration import AccelerationPlugin, AccelerationPluginConfigError
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
from transformers import TrainingArguments
from transformers import PretrainedConfig, TrainingArguments
import torch

# Local
from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp
from .utils import lora_adapters_switch_ddp_from_fsdp
from .models.utils import filter_mp_rules


# consider rewriting register_foak_model_patch_rules into something
# like this also
def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None):
def register_foak_model_patch_rules(
base_type: str,
filter_endswith: Set[str] = None,
config: PretrainedConfig = None,
):

# Third Party
from fms_acceleration.model_patcher import ( # pylint: disable=import-outside-toplevel
@@ -45,20 +50,21 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] =
mixtral,
)

# create model specific rules
rules = [
*gpt_bigcode.get_mp_rules(base_type),
*granite.get_mp_rules(base_type),
*granite.get_mp_rules(base_type, config),
*granitemoe.get_mp_rules(base_type),
*llama.get_mp_rules(base_type),
*mistral.get_mp_rules(base_type),
*llama.get_mp_rules(base_type, config),
*mistral.get_mp_rules(base_type, config),
*mixtral.get_mp_rules(base_type),
]

if filter_endswith is not None:
# filter rules
rules = [
r for r in rules if any(r.rule_id.endswith(x) for x in filter_endswith)
]
# for filtering rules that apply regardless of model arch
# - this would be useful for implementing switches for
# turning off rules that affect all models
if filter_endswith:
rules = filter_mp_rules(rules, filter_endswith)

for _rule in rules:
ModelPatcher.register(_rule)
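
The inline suffix filter is replaced here by the shared `filter_mp_rules` helper imported from `.models.utils`, whose body is not part of this excerpt. A hedged sketch of what it plausibly does, inferred from the old inline filter and from the `drop=True` call site in the granite rules further down:

```python
# Hedged sketch; the actual helper lives in models/utils.py and is not shown in this diff.
def filter_mp_rules(rules, filter_endswith, drop: bool = False):
    if drop:
        # drop rules whose rule_id ends with any of the given suffixes
        return [
            r for r in rules
            if not any(r.rule_id.endswith(x) for x in filter_endswith)
        ]
    # keep only rules whose rule_id ends with any of the given suffixes
    return [
        r for r in rules
        if any(r.rule_id.endswith(x) for x in filter_endswith)
    ]
```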
@@ -151,18 +157,22 @@ def augmentation(

terms = set()
for k, v in self.configurations.items():
if isinstance(v, bool) and v is False:
continue

if k in FILTER_MAP and k not in omitted:
ts = FILTER_MAP[k]
if isinstance(ts, str):
ts = {ts}
if isinstance(v, bool) and v is False:
continue

terms.update(ts)

# wrapper function to register foak patches
# - the base layer setting below will be ignored in non quantized-lora settings
register_foak_model_patch_rules2(
base_type=self.configurations["base_layer"], filter_endswith=terms
register_foak_model_patch_rules(
base_type=self.configurations["base_layer"],
filter_endswith=terms,
config=model.config,
)
return model, modifiable_args
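
The loop above translates boolean plugin switches into rule-id suffixes via `FILTER_MAP`, which is referenced here but not shown in the excerpt. An illustrative shape only — the keys and suffix values below are assumptions:

```python
# Illustrative only; the real keys and suffixes are defined elsewhere in this plugin.
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},   # set form: several fused-op rule suffixes
    "fast_loss": "cross-ent",        # str form: a single rule suffix
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}
```

Switches whose value is `False` contribute nothing, so `terms` only carries suffixes for the features that are enabled.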


This file was deleted.

@@ -14,6 +14,7 @@

# Standard
from functools import partial
import warnings

# Third Party
from fms_acceleration.model_patcher import (
@@ -22,15 +23,24 @@
combine_functions,
combine_triggers,
)
from transformers import PretrainedConfig

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
from .utils import (
KEY_MLP,
KEY_O,
KEY_QKV,
build_lora_fused_ops,
filter_mp_rules,
get_hidden_activation_fn_key,
trigger_fused_ops,
)


def get_mp_rules(base_type: str):
def get_mp_rules(base_type: str, config: PretrainedConfig = None):
"""
Function to access all patch rules in this module.
If it is a forward_builder rule with `base_type` in
@@ -47,7 +57,7 @@ def get_mp_rules(base_type: str):
except ImportError:
return []

return [
rules = [
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -133,3 +143,15 @@ def get_mp_rules(base_type: str):
),
),
]

# perform model specific filtering
act = get_hidden_activation_fn_key(config)
if config and act != "silu":
warnings.warn(
f"Granite Rules: activation is {act}, "
"thus disabling LoRA fused-op for MLP, since only SwiGLU "
"is supported. This only affects quantized-peft."
)
rules = filter_mp_rules(rules, {"mlp"}, drop=True)

return rules
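
`get_hidden_activation_fn_key` is imported from `.utils` and its body is not shown in this excerpt. A hedged sketch of what it might look like — the config attribute names below are assumptions about common HF `PretrainedConfig` fields:

```python
# Hedged sketch; attribute names are assumptions, not taken from this repository.
from transformers import PretrainedConfig

def get_hidden_activation_fn_key(config: PretrainedConfig = None):
    if config is None:
        return None
    # look for the activation name under common HF config attributes
    for attr in ("hidden_act", "hidden_activation", "activation_function"):
        act = getattr(config, attr, None)
        if act is not None:
            return act
    return None
```

Under this reading, a Granite config whose activation is anything other than `silu` triggers the warning above and the `mlp` fused-op rules are dropped.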