diff --git a/prismatic/conf/models.py b/prismatic/conf/models.py
index 09e40b94..cf2b96d4 100644
--- a/prismatic/conf/models.py
+++ b/prismatic/conf/models.py
@@ -259,7 +259,7 @@ class Exp_13B_Llama2(Exp_13B_One_Stage):
     llm_backbone_id: str = "llama2-13b-pure"
 
 
-# ~ Additional LLM Backbones ~
+# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
 @dataclass
 class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
     model_id: str = "llama2-chat+7b"
@@ -272,6 +272,18 @@ class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
     llm_backbone_id: str = "llama2-13b-chat"
 
 
+@dataclass
+class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage):
+    model_id: str = "mistral-v0.1+7b"
+    llm_backbone_id: str = "mistral-v0.1-7b-pure"
+
+
+@dataclass
+class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
+    model_id: str = "mistral-instruct-v0.1+7b"
+    llm_backbone_id: str = "mistral-v0.1-7b-instruct"
+
+
 # Section 4.3B :: ✌️ --> Co-training on Language-only Data
 # =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
 @dataclass
@@ -514,9 +526,11 @@ class ModelRegistry(Enum):
     EXP_LLAMA2_7B = Exp_7B_Llama2
     EXP_LLAMA2_13B = Exp_13B_Llama2
 
-    # ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat ~
+    # ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
     EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
     EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
+    EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
+    EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
 
     # Cotraining w/ Unimodal Data
     EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
diff --git a/prismatic/models/backbones/llm/__init__.py b/prismatic/models/backbones/llm/__init__.py
index 22df7ec6..dcd89551 100644
--- a/prismatic/models/backbones/llm/__init__.py
+++ b/prismatic/models/backbones/llm/__init__.py
@@ -1,2 +1,3 @@
 from .base_llm import LLMBackbone
 from .llama2 import LLaMa2LLMBackbone
+from .mistral import MistralLLMBackbone
diff --git a/prismatic/models/backbones/llm/base_llm.py b/prismatic/models/backbones/llm/base_llm.py
index e7b089b0..37840c98 100644
--- a/prismatic/models/backbones/llm/base_llm.py
+++ b/prismatic/models/backbones/llm/base_llm.py
@@ -145,7 +145,12 @@ def __init__(
 
         # Load (Fast) Tokenizer
         overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
-        self.tokenizer = AutoTokenizer.from_pretrained(hf_hub_path, model_max_length=self.llm_max_length, token=hf_token)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            hf_hub_path,
+            model_max_length=self.llm_max_length,
+            token=hf_token,
+            padding_side="right",
+        )
 
         # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
         #                starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
diff --git a/prismatic/models/backbones/llm/mistral.py b/prismatic/models/backbones/llm/mistral.py
new file mode 100644
index 00000000..9bc25553
--- /dev/null
+++ b/prismatic/models/backbones/llm/mistral.py
@@ -0,0 +1,72 @@
+"""
+mistral.py
+
+Class definition for all LLMs derived from MistralForCausalLM.
+""" + +from typing import Optional, Type + +import torch +from torch import nn as nn +from transformers import MistralForCausalLM +from transformers.models.mistral.modeling_mistral import MistralDecoderLayer + +from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone +from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder + +# Registry =>> Support Mistral Models (from HF Transformers) +# fmt: off +MISTRAL_MODELS = { + # === Base Mistral v0.1 === + "mistral-v0.1-7b-pure": { + "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1" + }, + + # === Mistral Instruct v0.1 === + "mistral-v0.1-7b-instruct": { + "llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1" + } +} +# fmt: on + + +class MistralLLMBackbone(HFCausalLLMBackbone): + def __init__( + self, + llm_backbone_id: str, + llm_max_length: int = 2048, + hf_token: Optional[str] = None, + inference_mode: bool = False, + use_flash_attention_2: bool = True, + ) -> None: + super().__init__( + llm_backbone_id, + llm_max_length=llm_max_length, + hf_token=hf_token, + inference_mode=inference_mode, + use_flash_attention_2=use_flash_attention_2, + **MISTRAL_MODELS[llm_backbone_id], + ) + + # [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize) + self.tokenizer.add_special_tokens({"pad_token": ""}) + self.llm.config.pad_token_id = self.tokenizer.pad_token_id + self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64) + + @property + def prompt_builder_fn(self) -> Type[PromptBuilder]: + if self.identifier.endswith("-pure"): + return PurePromptBuilder + + elif self.identifier.endswith("-instruct"): + return MistralInstructPromptBuilder + + raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`") + + @property + def transformer_layer_cls(self) -> Type[nn.Module]: + return MistralDecoderLayer + + @property + def half_precision_dtype(self) -> torch.dtype: + return torch.bfloat16 diff --git a/prismatic/models/backbones/llm/prompting/__init__.py b/prismatic/models/backbones/llm/prompting/__init__.py index 6e2779f0..f6501be0 100644 --- a/prismatic/models/backbones/llm/prompting/__init__.py +++ b/prismatic/models/backbones/llm/prompting/__init__.py @@ -1,3 +1,4 @@ from .base_prompter import PromptBuilder, PurePromptBuilder from .llama2_chat_prompter import LLaMa2ChatPromptBuilder +from .mistral_instruct_prompter import MistralInstructPromptBuilder from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder diff --git a/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py new file mode 100644 index 00000000..dd8f124d --- /dev/null +++ b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py @@ -0,0 +1,61 @@ +""" +mistral_instruct_prompter.py + +Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorial.s + +Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format +""" + +from typing import Optional + +from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder + + +class MistralInstructPromptBuilder(PromptBuilder): + def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None: + super().__init__(model_family, system_prompt) + + # Note =>> Mistral Tokenizer is an 
+        #      =>> Mistral Instruct *does not* use a System Prompt
+        self.bos, self.eos = "<s>", "</s>"
+
+        # Get role-specific "wrap" functions
+        #   =>> Note that {bos} should only be added on FIRST utterance (key departure vs. LLaMa-2 Chat)
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        if (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.removeprefix(self.bos).rstrip()
+
+    def get_prompt(self) -> str:
+        # Remove prefix <s> because it gets auto-inserted by tokenizer!
+        return self.prompt.removeprefix(self.bos).rstrip()
diff --git a/prismatic/models/materialize.py b/prismatic/models/materialize.py
index 3e5d65a7..995fb445 100644
--- a/prismatic/models/materialize.py
+++ b/prismatic/models/materialize.py
@@ -9,7 +9,7 @@
 
 from transformers import PreTrainedTokenizerBase
 
-from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone
+from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone
 from prismatic.models.backbones.vision import (
     CLIPViTBackbone,
     DinoCLIPViTBackbone,
@@ -63,6 +63,10 @@
 
     # === Vicuna-v1.5 Backbones ===
     "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
     "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
+
+    # === Mistral v0.1 Backbones ===
+    "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
+    "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
 }
 # fmt: on
diff --git a/scripts/pretrain.py b/scripts/pretrain.py
index 85cf1eee..c2d72025 100644
--- a/scripts/pretrain.py
+++ b/scripts/pretrain.py
@@ -51,7 +51,7 @@ class PretrainConfig:
 
     # ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
     model: ModelConfig = field(
-        default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_LLAMA2_CHAT_13B.model_id)
+        default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_MISTRAL_INSTRUCT_V1_7B.model_id)
     )
 
     # DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
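
As a quick sanity check on the new prompt format, the short sketch below exercises the MistralInstructPromptBuilder added above; the "mistral" model_family string and the example messages are placeholders, and the expected output is inferred from the wrap functions in the prompter rather than from a recorded run.

from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder

# Build a two-turn conversation (roles must alternate human -> gpt).
builder = MistralInstructPromptBuilder(model_family="mistral")
builder.add_turn("human", "What is in the image?")
builder.add_turn("gpt", "A cat sitting on a couch.")

# Expected per the wrap functions above (no leading <s>, since the tokenizer inserts BOS itself):
#   [INST] What is in the image? [/INST] A cat sitting on a couch.</s>
print(builder.get_prompt())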