diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
index a2be13a..d40b7e1 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py
@@ -27,6 +27,7 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
 
@@ -40,6 +41,7 @@ def get_mp_rules(base_type: str):
     try:
         # Third Party
         from transformers.models.granite.modeling_granite import (  # pylint: disable=import-outside-toplevel
+            GraniteForCausalLM,
             GraniteAttention,
             GraniteMLP,
             GraniteRMSNorm,
@@ -120,6 +122,11 @@ def get_mp_rules(base_type: str):
                 "transformers.models.granite.modeling_granite",
             ),
         ),
+        ModelPatcherRule(
+            rule_id="granite-fused-lce",
+            trigger=ModelPatcherTrigger(check=GraniteForCausalLM),
+            forward=lce_forward,
+        ),
         # TODO: have a generic version of this rule
         # - get the module name
         # - check if "apply_rotary_pos_emb" exists
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
index 4226b6a..be66811 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -33,9 +33,8 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
-
 from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
 def get_mp_rules(base_type: str):
     """
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
index 8e773a2..0ea886d 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
@@ -23,6 +23,7 @@
     combine_triggers,
 )
 from transformers.models.mistral.modeling_mistral import (
+    MistralForCausalLM,
     MistralAttention,
     MistralMLP,
     MistralRMSNorm,
@@ -32,9 +33,9 @@
 from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
 from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
 from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward
 from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
 
-
 def get_mp_rules(base_type):
     """
     Function to access all patch rules in this module.
@@ -110,6 +111,11 @@ def get_mp_rules(base_type):
                 "transformers.models.mistral.modeling_mistral",
             ),
         ),
+        ModelPatcherRule(
+            rule_id="mistral-fused-lce",
+            trigger=ModelPatcherTrigger(check=MistralForCausalLM),
+            forward=lce_forward,
+        ),
         ModelPatcherRule(
             rule_id="mistral-rope",
             import_and_maybe_reload=(