From 1d4370cd5a3db3446c64ecc5b10b35b4d03c4fd8 Mon Sep 17 00:00:00 2001
From: Sidd Karamcheti
Date: Wed, 24 Apr 2024 10:29:10 -0700
Subject: [PATCH 1/2] Add Phi-2 LLM with fixed tokenization

---
 prismatic/conf/models.py                      |  9 ++-
 prismatic/models/backbones/llm/__init__.py    |  1 +
 prismatic/models/backbones/llm/base_llm.py    | 22 +++++--
 prismatic/models/backbones/llm/phi.py         | 64 ++++++++++++++++++
 .../backbones/llm/prompting/__init__.py       |  1 +
 .../llm/prompting/llama2_chat_prompter.py     |  2 +-
 .../prompting/mistral_instruct_prompter.py    |  3 +-
 .../backbones/llm/prompting/phi_prompter.py   | 65 +++++++++++++++++++
 prismatic/models/materialize.py               |  5 +-
 prismatic/preprocessing/datasets/datasets.py  |  9 ++-
 scripts/pretrain.py                           |  2 +-
 11 files changed, 170 insertions(+), 13 deletions(-)
 create mode 100644 prismatic/models/backbones/llm/phi.py
 create mode 100644 prismatic/models/backbones/llm/prompting/phi_prompter.py

diff --git a/prismatic/conf/models.py b/prismatic/conf/models.py
index cf2b96d4..6f507b0d 100644
--- a/prismatic/conf/models.py
+++ b/prismatic/conf/models.py
@@ -259,7 +259,7 @@ class Exp_13B_Llama2(Exp_13B_One_Stage):
     llm_backbone_id: str = "llama2-13b-pure"


-# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
+# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~
 @dataclass
 class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
     model_id: str = "llama2-chat+7b"
@@ -284,6 +284,12 @@ class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
     llm_backbone_id: str = "mistral-v0.1-7b-instruct"


+@dataclass
+class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
+    model_id: str = "phi-2+3b"
+    llm_backbone_id: str = "phi-2-3b"
+
+
 # Section 4.3B :: ✌️ --> Co-training on Language-only Data
 # =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
 @dataclass
@@ -531,6 +537,7 @@ class ModelRegistry(Enum):
     EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
     EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
     EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
+    EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2

     # Cotraining w/ Unimodal Data
     EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
diff --git a/prismatic/models/backbones/llm/__init__.py b/prismatic/models/backbones/llm/__init__.py
index dcd89551..1cdb1995 100644
--- a/prismatic/models/backbones/llm/__init__.py
+++ b/prismatic/models/backbones/llm/__init__.py
@@ -1,3 +1,4 @@
 from .base_llm import LLMBackbone
 from .llama2 import LLaMa2LLMBackbone
 from .mistral import MistralLLMBackbone
+from .phi import PhiLLMBackbone
diff --git a/prismatic/models/backbones/llm/base_llm.py b/prismatic/models/backbones/llm/base_llm.py
index 37840c98..4f9d33b3 100644
--- a/prismatic/models/backbones/llm/base_llm.py
+++ b/prismatic/models/backbones/llm/base_llm.py
@@ -152,13 +152,28 @@ def __init__(
             padding_side="right",
         )

+        # Explicitly verify that Tokenizer padding_side is set to right for training!
+        assert self.tokenizer.padding_side == "right", "Tokenizer `padding_side` is not set to `right`!"
+
         # Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
         #                starts with a BOS token unless `add_special_tokens = False`; for these models, we empirically
-        #                find that adding image patches *after* the BOS leads to much better performance
+        #                find that adding image patches *after* the BOS leads to much better performance.
         #
         # As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
         #   line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
-        #   override this, but make sure to make the appropriate changes in the `datasets.py` and VLM `forward()` logic!
+        #   override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
+        #   and VLM `forward()` logic!
+        SPECIAL_CASES = {
+            # Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
+            #   =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
+            #       this works well with base LLM generation).
+            #   =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
+            "phi-2-3b",
+        }
+        if self.identifier in SPECIAL_CASES:
+            return
+
+        # Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral!)
         assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
             self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
         ), (
@@ -166,9 +181,6 @@ def __init__(
             "Please read the comment in `base_llm.py` for more information!"
         )

-        # Additionally, explicitly verify that Tokenizer padding_side is set to right for training!
-        assert self.tokenizer.padding_side == "right", "Tokenizer `padding_side` is not set to `right`!"
-
     def get_fsdp_wrapping_policy(self) -> Callable:
         """Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
         transformer_block_policy = partial(
diff --git a/prismatic/models/backbones/llm/phi.py b/prismatic/models/backbones/llm/phi.py
new file mode 100644
index 00000000..27bb8b79
--- /dev/null
+++ b/prismatic/models/backbones/llm/phi.py
@@ -0,0 +1,64 @@
+"""
+phi.py
+
+Class definition for all LLMs derived from PhiForCausalLM.
+"""
+
+from typing import Optional, Type
+
+import torch
+from torch import nn as nn
+from transformers import PhiForCausalLM
+from transformers.models.phi.modeling_phi import PhiDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder
+
+# Registry ==> Support Phi Models (from HF Transformers)
+# fmt: off
+PHI_MODELS = {
+    # === Phi-2 ===
+    "phi-2-3b": {
+        "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
+    }
+}
+# fmt: on
+
+
+class PhiLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **PHI_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("phi-2"):
+            return PhiPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return PhiDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
diff --git a/prismatic/models/backbones/llm/prompting/__init__.py b/prismatic/models/backbones/llm/prompting/__init__.py
index f6501be0..8292789c 100644
--- a/prismatic/models/backbones/llm/prompting/__init__.py
+++ b/prismatic/models/backbones/llm/prompting/__init__.py
@@ -1,4 +1,5 @@
 from .base_prompter import PromptBuilder, PurePromptBuilder
 from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
 from .mistral_instruct_prompter import MistralInstructPromptBuilder
+from .phi_prompter import PhiPromptBuilder
 from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
diff --git a/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py b/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
index 66c41aeb..5dd3ffdf 100644
--- a/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
+++ b/prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
@@ -36,7 +36,7 @@ def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> No
         self.bos, self.eos = "<s>", "</s>"

         # Get role-specific "wrap" functions
-        self.wrap_human = lambda msg: f"{self.bos}[INST] {msg} [/INST] "
+        self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
         self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

         # === `self.prompt` gets built up over multiple turns ===
diff --git a/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
index dd8f124d..35a5eab8 100644
--- a/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
+++ b/prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
@@ -15,12 +15,11 @@ class MistralInstructPromptBuilder(PromptBuilder):
     def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
         super().__init__(model_family, system_prompt)

-        # Note =>> Mistral Tokenizer is an instance of LlamaTokenizer
+        # Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
         #      =>> Mistral Instruct *does not* use a System Prompt
         self.bos, self.eos = "<s>", "</s>"

         # Get role-specific "wrap" functions
-        #   =>> Note that {bos} should only be added on FIRST utterance (key departure vs. LLaMa-2 Chat)
         self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
         self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

diff --git a/prismatic/models/backbones/llm/prompting/phi_prompter.py b/prismatic/models/backbones/llm/prompting/phi_prompter.py
new file mode 100644
index 00000000..b350ea3a
--- /dev/null
+++ b/prismatic/models/backbones/llm/prompting/phi_prompter.py
@@ -0,0 +1,65 @@
+"""
+phi_prompter.py
+
+Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
+Also handles Phi special case BOS token additions.
+
+Reference: https://huggingface.co/microsoft/phi-2#qa-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class PhiPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
+        #      =>> By default, does *not* append BOS / EOS tokens --> we handle that here (IMPORTANT)!
+        self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"
+
+        # Get role-specific "wrap" functions
+        #   =>> Note that placement of BOS / EOS was based on experiments generating from Phi-2 in Input/Output mode
+        self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "first" input --> prepend a BOS token (expected by Prismatic)
+        if self.turn_count == 0:
+            bos_human_message = f"{self.bos}{self.wrap_human(message)}"
+            wrapped_message = bos_human_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.rstrip()
+
+    def get_prompt(self) -> str:
+        return self.prompt.rstrip()
diff --git a/prismatic/models/materialize.py b/prismatic/models/materialize.py
index 995fb445..df1ad612 100644
--- a/prismatic/models/materialize.py
+++ b/prismatic/models/materialize.py
@@ -9,7 +9,7 @@

 from transformers import PreTrainedTokenizerBase

-from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone
+from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
 from prismatic.models.backbones.vision import (
     CLIPViTBackbone,
     DinoCLIPViTBackbone,
@@ -67,6 +67,9 @@
     # === Mistral v0.1 Backbones ===
     "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
     "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
+
+    # === Phi-2 Backbone ===
+    "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
 }
 # fmt: on

diff --git a/prismatic/preprocessing/datasets/datasets.py b/prismatic/preprocessing/datasets/datasets.py
index ffaa127f..35f866ed 100644
--- a/prismatic/preprocessing/datasets/datasets.py
+++ b/prismatic/preprocessing/datasets/datasets.py
@@ -17,7 +17,7 @@
 import torch
 from PIL import Image
 from torch.utils.data import Dataset
-from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase
+from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase

 from prismatic.models.backbones.llm.prompting import PromptBuilder
 from prismatic.models.backbones.vision import ImageTransform
@@ -52,7 +52,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        the "prompt" from the human, and instead directly predict the caption from the image.

        As a concrete example given the "raw data" for the first example:
-            example = self.examples[0]["conversations]` = {
+            example = self.examples[0]["conversations"]` = {
                [
                    {"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
                    {"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
                ]
            }
@@ -144,6 +144,11 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
                 # Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
                 if isinstance(self.tokenizer, LlamaTokenizerFast):
                     msg = msg.rstrip()
+
+                # Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
+                elif isinstance(self.tokenizer, CodeGenTokenizerFast):
+                    pass
+
                 else:
                     raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")
diff --git a/scripts/pretrain.py b/scripts/pretrain.py
index c2d72025..6fad6446 100644
--- a/scripts/pretrain.py
+++ b/scripts/pretrain.py
@@ -51,7 +51,7 @@ class PretrainConfig:
     # ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
     model: ModelConfig = field(
-        default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_MISTRAL_INSTRUCT_V1_7B.model_id)
+        default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_PHI_2_3B.model_id)
     )

     # DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`

From 97cf31be32b39e3da96cf52ba435b03fccb4e9d0 Mon Sep 17 00:00:00 2001
From: Sidd Karamcheti
Date: Thu, 25 Apr 2024 05:21:03 -0700
Subject: [PATCH 2/2] Update launch

---
 scripts/sagemaker/launch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/sagemaker/launch.py b/scripts/sagemaker/launch.py
index 78fdfece..fb23dba3 100644
--- a/scripts/sagemaker/launch.py
+++ b/scripts/sagemaker/launch.py
@@ -42,10 +42,10 @@ class LaunchConfig:

     # Prismatic VLM Pretraining Parameters
     model_type: str = (                                         # Unique Model ID (specifies config)
-        ModelRegistry.PRISM_DINOSIGLIP_224PX_7B.model_id
+        ModelRegistry.EXT_EXP_MISTRAL_V1_7B.model_id
     )
     dataset_type: str = (                                       # Unique Dataset ID (specifies config)
-        DatasetRegistry.LLAVA_LVIS4V_LRV.dataset_id
+        DatasetRegistry.LLAVA_V15.dataset_id
     )

     # Stage & Batch Size Parameters =>> Set dynamically based on instance count!
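
For reference, a minimal usage sketch of the new `PhiPromptBuilder` introduced in this patch, showing the Phi-2 "Input:/Output:" prompt format it produces. Illustrative only -- the conversation strings and the `model_family` value are invented for this example, not part of the patch:

    from prismatic.models.backbones.llm.prompting import PhiPromptBuilder

    # Build a two-turn prompt; the first human turn gets the "<|endoftext|>" BOS
    # prepended, and each gpt turn is terminated with "<|endoftext|>" as EOS.
    builder = PhiPromptBuilder(model_family="phi")
    builder.add_turn("human", "What is shown in this image?")
    builder.add_turn("gpt", "A cat sleeping on a sofa.")

    print(builder.get_prompt())
    # <|endoftext|>Input: What is shown in this image?
    # Output: A cat sleeping on a sofa.
    # <|endoftext|>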