
Commit

add mistral and granite model patch
Signed-off-by: Anh Uong <[email protected]>
anhuong committed Oct 16, 2024
1 parent 4322843 commit 05cdbe6
Showing 3 changed files with 15 additions and 3 deletions.
Changed file 1:
@@ -27,6 +27,7 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


@@ -40,6 +41,7 @@ def get_mp_rules(base_type: str):
try:
# Third Party
from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel
GraniteForCausalLM,
GraniteAttention,
GraniteMLP,
GraniteRMSNorm,
@@ -120,6 +122,11 @@ def get_mp_rules(base_type: str):
"transformers.models.granite.modeling_granite",
),
),
ModelPatcherRule(
rule_id="granite-fused-lce",
trigger=ModelPatcherTrigger(check=GraniteForCausalLM),
forward=lce_forward,
),
# TODO: have a generic version of this rule
# - get the module name
# - check if "apply_rotary_pos_emb" exists
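Note on the pattern: the new granite-fused-lce rule (and the mistral-fused-lce rule further down) pairs a trigger on the model's *ForCausalLM class with the Liger lce_forward as the replacement forward. The sketch below is only a hypothetical illustration of that trigger-and-replace idea; SimpleRule, apply_rules, and the rebinding logic are invented names for illustration, not the plugin's actual ModelPatcher implementation.

from dataclasses import dataclass
from typing import Callable, Type

import torch.nn as nn


@dataclass
class SimpleRule:
    rule_id: str
    trigger: Type[nn.Module]   # patch modules that are instances of this class
    forward: Callable          # replacement forward, e.g. a fused-loss forward


def apply_rules(model: nn.Module, rules: list[SimpleRule]) -> None:
    # Walk the module tree and rebind `forward` on every module a rule matches.
    for module in model.modules():
        for rule in rules:
            if isinstance(module, rule.trigger):
                module.forward = rule.forward.__get__(module, type(module))

# e.g. apply_rules(model, [SimpleRule("granite-fused-lce", GraniteForCausalLM, lce_forward)])

Patching at the *ForCausalLM level, rather than inside individual decoder layers, is what lets a fused loss see both the final hidden states and the LM head in a single forward call.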
Changed file 2:
@@ -33,9 +33,8 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

def get_mp_rules(base_type: str):
"""
Changed file 3:
@@ -23,6 +23,7 @@
combine_triggers,
)
from transformers.models.mistral.modeling_mistral import (
MistralForCausalLM,
MistralAttention,
MistralMLP,
MistralRMSNorm,
@@ -32,9 +33,9 @@
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops


def get_mp_rules(base_type):
"""
Function to access all patch rules in this module.
@@ -110,6 +111,11 @@ def get_mp_rules(base_type):
"transformers.models.mistral.modeling_mistral",
),
),
ModelPatcherRule(
rule_id="mistral-fused-lce",
trigger=ModelPatcherTrigger(check=MistralForCausalLM),
forward=lce_forward,
),
ModelPatcherRule(
rule_id="mistral-rope",
import_and_maybe_reload=(
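Why lce_forward: a fused linear + cross-entropy loss avoids materializing the full [num_tokens, vocab_size] logits tensor, which is usually the memory peak when fine-tuning large-vocabulary models. The plain-PyTorch sketch below only illustrates the chunking idea behind that fusion; it is not the Liger kernel's lce_forward, and chunked_linear_cross_entropy and its arguments are hypothetical names.

import torch
import torch.nn.functional as F


def chunked_linear_cross_entropy(
    hidden: torch.Tensor,          # [num_tokens, hidden_dim] final hidden states
    lm_head_weight: torch.Tensor,  # [vocab_size, hidden_dim]
    labels: torch.Tensor,          # [num_tokens], -100 marks ignored positions
    chunk_size: int = 1024,
) -> torch.Tensor:
    total_loss = hidden.new_zeros(())
    total_tokens = hidden.new_zeros((), dtype=torch.long)
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start : start + chunk_size]
        y = labels[start : start + chunk_size]
        # Only a [chunk_size, vocab_size] slice of logits exists at any time.
        logits = h @ lm_head_weight.t()
        total_loss = total_loss + F.cross_entropy(
            logits, y, ignore_index=-100, reduction="sum"
        )
        total_tokens = total_tokens + (y != -100).sum()
    return total_loss / total_tokens.clamp(min=1)

The rules added in this commit swap that kind of fused loss computation into GraniteForCausalLM and MistralForCausalLM via the forward= argument, so training code keeps the usual model(input_ids, labels=...) call.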
