From 4322843e873e4a2cac51ab1943dd76328e4bedf1 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 16 Oct 2024 10:54:31 -0600 Subject: [PATCH] add mixtral model patch Signed-off-by: Anh Uong --- .../liger/fused_linear_cross_entropy_loss.py | 155 +++++++++++++++++- .../fms_acceleration_foak/models/mixtral.py | 7 + 2 files changed, 159 insertions(+), 3 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py index 5ab9dc9e..91bbc4cc 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py @@ -34,6 +34,14 @@ _CONFIG_FOR_DOC, LLAMA_INPUTS_DOCSTRING, ) +from transformers.models.mixtral.modeling_mixtral import ( + _CONFIG_FOR_DOC, + MIXTRAL_INPUTS_DOCSTRING, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) from transformers.utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, @@ -289,7 +297,8 @@ def forward(self, lin_weight, _input, target, bias=None): self.reduction, ) -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models? +# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) @@ -328,9 +337,9 @@ def lce_forward( Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, AutoModelForCausalLM - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" @@ -374,6 +383,7 @@ def lce_forward( loss = None logits = None + # patch change if self.training and (labels is not None): shift_hidden_states = hidden_states[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -425,4 +435,143 @@ def lce_forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + ) + +# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward? +@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +# Ignore copy +def lce_forward_mixtral( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + # patch change + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # TODO: unique differing part to mixtral model forward + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + # TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss? + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, ) \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py index 67eada1c..fe832aea 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -23,6 +23,7 @@ combine_triggers, ) from transformers.models.mixtral.modeling_mixtral import ( + MixtralForCausalLM, MixtralAttention, MixtralRMSNorm, ) @@ -31,6 +32,7 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward_mixtral from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops @@ -93,6 +95,11 @@ def get_mp_rules(base_type): "transformers.models.mixtral.modeling_mixtral", ), ), + ModelPatcherRule( + rule_id="mixtral-fused-lce", + trigger=ModelPatcherTrigger(check=MixtralForCausalLM), + forward=lce_forward_mixtral, + ), ModelPatcherRule( rule_id="mixtral-rope", import_and_maybe_reload=(