Add Phi-2 Backbone (TRI-ML#10)
Adds the 2.7B Phi-2 LLM, with non-LLaMa tokenizer handling and the recommended "Input: / Output:" prompt formatting.
siddk authored Apr 30, 2024
2 parents a867c18 + 97cf31b commit 0467cbb
Showing 12 changed files with 172 additions and 15 deletions.
9 changes: 8 additions & 1 deletion prismatic/conf/models.py
@@ -259,7 +259,7 @@ class Exp_13B_Llama2(Exp_13B_One_Stage):
llm_backbone_id: str = "llama2-13b-pure"


# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct, Phi-2 ~
@dataclass
class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
model_id: str = "llama2-chat+7b"
@@ -284,6 +284,12 @@ class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
llm_backbone_id: str = "mistral-v0.1-7b-instruct"


@dataclass
class Ext_Exp_3B_Phi_2(Exp_7B_One_Stage):
model_id: str = "phi-2+3b"
llm_backbone_id: str = "phi-2-3b"


# Section 4.3B :: ✌️ --> Co-training on Language-only Data
# =>> Note :: Run with `--dataset.type "llava-multimodal"` (multimodal data only / no co-training)
@dataclass
@@ -531,6 +537,7 @@ class ModelRegistry(Enum):
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1
EXT_EXP_PHI_2_3B = Ext_Exp_3B_Phi_2

# Cotraining w/ Unimodal Data
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
1 change: 1 addition & 0 deletions prismatic/models/backbones/llm/__init__.py
@@ -1,3 +1,4 @@
from .base_llm import LLMBackbone
from .llama2 import LLaMa2LLMBackbone
from .mistral import MistralLLMBackbone
from .phi import PhiLLMBackbone
22 changes: 17 additions & 5 deletions prismatic/models/backbones/llm/base_llm.py
@@ -152,23 +152,35 @@ def __init__(
padding_side="right",
)

# Explicitly verify that Tokenizer padding_side is set to right for training!
assert self.tokenizer.padding_side == "right", "Tokenizer `padding_side` is not set to `right`!"

# Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
# starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
# find that adding image patches *after* the BOS leads to much better performance
# find that adding image patches *after* the BOS leads to much better performance.
#
# As a result we explicitly validate that a tokenizer conforms to the expected behavior; if you're reading this
# line, it's probably because you're adding a new LLM with a different tokenizer behavior. If so, feel free to
# override this, but make sure to make the appropriate changes in the `datasets.py` and VLM `forward()` logic!
# override the `SPECIAL_CASES` set below, but make sure to make the appropriate changes in the `datasets.py`
# and VLM `forward()` logic!
SPECIAL_CASES = {
# Phi-2 Tokenizer doesn't add any BOS tokens by default, and sets BOS == EOS == "<|endoftext|>"
# =>> We'll prepend BOS to first input (to play nicely with image token insertion logic; verified that
# this works well with base LLM generation).
# =>> Like Llama-2 Tokenizers -- we'll add a special PAD token for training purposes.
"phi-2-3b",
}
if self.identifier in SPECIAL_CASES:
return

# Note =>> this assert should hold for all Llama-derived tokenizers (`LlamaTokenizerFast` ==> includes Mistral!)
assert (self.tokenizer("Test 123", add_special_tokens=True).input_ids[0] == self.tokenizer.bos_token_id) and (
self.tokenizer("Test 123", add_special_tokens=False).input_ids[0] != self.tokenizer.bos_token_id
), (
f"Default Tokenizer of type `{type(self.tokenizer)}` does not automatically prefix inputs with BOS token!\n"
"Please read the comment in `base_llm.py` for more information!"
)

# Additionally, explicitly verify that Tokenizer padding_side is set to right for training!
assert self.tokenizer.padding_side == "right", "Tokenizer `padding_side` is not set to `right`!"

def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a `transformer_auto_wrap_policy` where we wrap each instance of `self.transformer_layer_cls`"""
transformer_block_policy = partial(
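For reference, the BOS check that the assert above encodes can be reproduced in isolation. A minimal sketch (Python), assuming `transformers` is installed and the listed checkpoints are accessible (the Llama-2 repo is gated):

from transformers import AutoTokenizer

def prepends_bos(tokenizer) -> bool:
    # Mirrors the validation above: a conforming tokenizer prefixes <BOS> when
    # `add_special_tokens=True` and omits it when `add_special_tokens=False`.
    with_special = tokenizer("Test 123", add_special_tokens=True).input_ids
    without_special = tokenizer("Test 123", add_special_tokens=False).input_ids
    return with_special[0] == tokenizer.bos_token_id and without_special[0] != tokenizer.bos_token_id

# LLaMa-derived tokenizers (including Mistral) pass; Phi-2's CodeGen tokenizer does not,
# which is why "phi-2-3b" sits in SPECIAL_CASES and gets <BOS> prepended by its PromptBuilder.
print(prepends_bos(AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")))  # expected: True
print(prepends_bos(AutoTokenizer.from_pretrained("microsoft/phi-2")))           # expected: False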
64 changes: 64 additions & 0 deletions prismatic/models/backbones/llm/phi.py
@@ -0,0 +1,64 @@
"""
phi.py
Class definition for all LLMs derived from PhiForCausalLM.
"""

from typing import Optional, Type

import torch
from torch import nn as nn
from transformers import PhiForCausalLM
from transformers.models.phi.modeling_phi import PhiDecoderLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder

# Registry ==> Support Phi Models (from HF Transformers)
# fmt: off
PHI_MODELS = {
# === Phi-2 ===
"phi-2-3b": {
"llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
}
}
# fmt: on


class PhiLLMBackbone(HFCausalLLMBackbone):
def __init__(
self,
llm_backbone_id: str,
llm_max_length: int = 2048,
hf_token: Optional[str] = None,
inference_mode: bool = False,
use_flash_attention_2: bool = True,
) -> None:
super().__init__(
llm_backbone_id,
llm_max_length=llm_max_length,
hf_token=hf_token,
inference_mode=inference_mode,
use_flash_attention_2=use_flash_attention_2,
**PHI_MODELS[llm_backbone_id],
)

# [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
self.llm.config.pad_token_id = self.tokenizer.pad_token_id
self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)

@property
def prompt_builder_fn(self) -> Type[PromptBuilder]:
if self.identifier.startswith("phi-2"):
return PhiPromptBuilder

raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")

@property
def transformer_layer_cls(self) -> Type[nn.Module]:
return PhiDecoderLayer

@property
def half_precision_dtype(self) -> torch.dtype:
return torch.bfloat16
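The PAD-token special case above can also be exercised with the plain Hugging Face classes. A minimal sketch (Python), assuming a recent `transformers` with native Phi support and using `AutoModelForCausalLM` / `AutoTokenizer` instead of the backbone wiring above:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

# Phi-2 ships without a PAD token (BOS == EOS == "<|endoftext|>"), so register one for
# right-padded training batches, then grow the embedding matrix; padding the new size up
# to a multiple of 64 is a common choice to keep the matrix shape GPU-friendly.
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
model.config.pad_token_id = tokenizer.pad_token_id
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

print(len(tokenizer), model.get_input_embeddings().weight.shape[0])  # vocab size vs. padded embedding rows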
1 change: 1 addition & 0 deletions prismatic/models/backbones/llm/prompting/__init__.py
@@ -1,4 +1,5 @@
from .base_prompter import PromptBuilder, PurePromptBuilder
from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
from .mistral_instruct_prompter import MistralInstructPromptBuilder
from .phi_prompter import PhiPromptBuilder
from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
prismatic/models/backbones/llm/prompting/llama2_chat_prompter.py
@@ -36,7 +36,7 @@ def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> No
self.bos, self.eos = "<s>", "</s>"

# Get role-specific "wrap" functions
self.wrap_human = lambda msg: f"{self.bos}[INST] {msg} [/INST] "
self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

# === `self.prompt` gets built up over multiple turns ===
prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
@@ -15,12 +15,11 @@ class MistralInstructPromptBuilder(PromptBuilder):
def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
super().__init__(model_family, system_prompt)

# Note =>> Mistral Tokenizer is an instance of LlamaTokenizer
# Note =>> Mistral Tokenizer is an instance of `LlamaTokenizer(Fast)`
# =>> Mistral Instruct *does not* use a System Prompt
self.bos, self.eos = "<s>", "</s>"

# Get role-specific "wrap" functions
# =>> Note that {bos} should only be added on FIRST utterance (key departure vs. LLaMa-2 Chat)
self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

65 changes: 65 additions & 0 deletions prismatic/models/backbones/llm/prompting/phi_prompter.py
@@ -0,0 +1,65 @@
"""
phi_prompter.py
Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
Also handles Phi special case BOS token additions.
Reference: https://huggingface.co/microsoft/phi-2#qa-format
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder


class PhiPromptBuilder(PromptBuilder):
def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
super().__init__(model_family, system_prompt)

# Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
# =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"

# Get role-specific "wrap" functions
# =>> Note that placement of <bos>/<eos> was based on experiments generating from Phi-2 in Input/Output mode
self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"

# === `self.prompt` gets built up over multiple turns ===
self.prompt, self.turn_count = "", 0

def add_turn(self, role: str, message: str) -> str:
assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
message = message.replace("<image>", "").strip()

# Special Handling for "first" input --> prepend a <BOS> token (expected by Prismatic)
if self.turn_count == 0:
bos_human_message = f"{self.bos}{self.wrap_human(message)}"
wrapped_message = bos_human_message
elif (self.turn_count % 2) == 0:
human_message = self.wrap_human(message)
wrapped_message = human_message
else:
gpt_message = self.wrap_gpt(message)
wrapped_message = gpt_message

# Update Prompt
self.prompt += wrapped_message

# Bump Turn Counter
self.turn_count += 1

# Return "wrapped_message" (effective string added to context)
return wrapped_message

def get_potential_prompt(self, message: str) -> str:
# Assumes that it's always the user's (human's) turn!
prompt_copy = str(self.prompt)

human_message = self.wrap_human(message)
prompt_copy += human_message

return prompt_copy.rstrip()

def get_prompt(self) -> str:
return self.prompt.rstrip()
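Tracing the builder above over one image-grounded exchange makes the "Input: / Output:" format concrete. A usage sketch (the messages are made up, and it assumes the base `PromptBuilder` constructor simply records `model_family` / `system_prompt`):

from prismatic.models.backbones.llm.prompting import PhiPromptBuilder

builder = PhiPromptBuilder("phi-2")

# First (human) turn: "<image>" is stripped, the message is wrapped, and <BOS> is prepended.
print(repr(builder.add_turn("human", "What is shown in the photo?\n<image>")))
# -> '<|endoftext|>Input: What is shown in the photo?\nOutput: '

# Second (gpt) turn: the answer is wrapped and terminated with <EOS>.
print(repr(builder.add_turn("gpt", "A cat sleeping on a couch.")))
# -> 'A cat sleeping on a couch.\n<|endoftext|>'

print(repr(builder.get_prompt()))
# -> '<|endoftext|>Input: What is shown in the photo?\nOutput: A cat sleeping on a couch.\n<|endoftext|>'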
5 changes: 4 additions & 1 deletion prismatic/models/materialize.py
@@ -9,7 +9,7 @@

from transformers import PreTrainedTokenizerBase

from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone
from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone
from prismatic.models.backbones.vision import (
CLIPViTBackbone,
DinoCLIPViTBackbone,
@@ -67,6 +67,9 @@
# === Mistral v0.1 Backbones ===
"mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
"mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},

# === Phi-2 Backbone ===
"phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
}

# fmt: on
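The registry entry above is what lets the rest of the codebase instantiate the backbone purely from its string ID. A hedged sketch of that dispatch (`LLM_BACKBONES` is an assumed name for the surrounding dict, which the hunk does not show):

from prismatic.models.backbones.llm import PhiLLMBackbone

LLM_BACKBONES = {"phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}}}  # assumed registry name

entry = LLM_BACKBONES["phi-2-3b"]
backbone = entry["cls"]("phi-2-3b", inference_mode=True, use_flash_attention_2=False, **entry["kwargs"])
tokenizer = backbone.tokenizer  # already carries the extra <|pad|> token added in PhiLLMBackbone.__init__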
9 changes: 7 additions & 2 deletions prismatic/preprocessing/datasets/datasets.py
@@ -17,7 +17,7 @@
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase
from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase

from prismatic.models.backbones.llm.prompting import PromptBuilder
from prismatic.models.backbones.vision import ImageTransform
@@ -52,7 +52,7 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
the "prompt" from the human, and instead directly predict the caption from the image.
As a concrete example given the "raw data" for the first example:
example = self.examples[0]["conversations]` = {
example = self.examples[0]["conversations"]` = {
[
{"from": "human", "value": "Render a clear and concise summary of the photo.\n<image>"},
{"from": "gpt", "value": "select luxury furniture 3 - inch gel memory foam mattress topper"}
@@ -144,6 +144,11 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
# Llama Tokenizer (Fast) adds extra character if a string ends in whitespace --> strip if non-empty!
if isinstance(self.tokenizer, LlamaTokenizerFast):
msg = msg.rstrip()

# Phi-2 Tokenizer == CodeGenTokenizer (Fast) -- no special handling!
elif isinstance(self.tokenizer, CodeGenTokenizerFast):
pass

else:
raise ValueError(f"Tokenizer of type `{type(self.tokenizer)}` is not explicitly handled!")

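The per-tokenizer branch above is the only place the dataset needs to know which tokenizer family it is feeding. Restated as a standalone helper (a sketch only; `normalize_turn_text` is a hypothetical name, not part of the commit):

from transformers import CodeGenTokenizerFast, LlamaTokenizerFast, PreTrainedTokenizerBase

def normalize_turn_text(msg: str, tokenizer: PreTrainedTokenizerBase) -> str:
    if isinstance(tokenizer, LlamaTokenizerFast):
        # Llama fast tokenizers emit an extra token when text ends in whitespace --> strip it.
        return msg.rstrip()
    elif isinstance(tokenizer, CodeGenTokenizerFast):
        # Phi-2's tokenizer needs no adjustment.
        return msg
    raise ValueError(f"Tokenizer of type `{type(tokenizer)}` is not explicitly handled!")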
2 changes: 1 addition & 1 deletion scripts/pretrain.py
@@ -51,7 +51,7 @@ class PretrainConfig:

# ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
model: ModelConfig = field(
default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_MISTRAL_INSTRUCT_V1_7B.model_id)
default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_PHI_2_3B.model_id)
)

# DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
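With the default flipped above, `pretrain.py` now resolves to the Phi-2 experiment config unless overridden on the command line. A hedged sketch of that resolution (assuming `prismatic.conf` exposes `ModelConfig` / `ModelRegistry` and that the config dataclass instantiates from its defaults):

from prismatic.conf import ModelConfig, ModelRegistry

# Same lookup as the `default_factory` above.
cfg_cls = ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_PHI_2_3B.model_id)
cfg = cfg_cls()

print(cfg.model_id)         # "phi-2+3b"
print(cfg.llm_backbone_id)  # "phi-2-3b" --> dispatched to PhiLLMBackbone via materialize.py
# Equivalent CLI override per the comment above: --model.type "phi-2+3b"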
4 changes: 2 additions & 2 deletions scripts/sagemaker/launch.py
@@ -42,10 +42,10 @@ class LaunchConfig:

# Prismatic VLM Pretraining Parameters
model_type: str = ( # Unique Model ID (specifies config)
ModelRegistry.PRISM_DINOSIGLIP_224PX_7B.model_id
ModelRegistry.EXT_EXP_MISTRAL_V1_7B.model_id
)
dataset_type: str = ( # Unique Dataset ID (specifies config)
DatasetRegistry.LLAVA_LVIS4V_LRV.dataset_id
DatasetRegistry.LLAVA_V15.dataset_id
)

# Stage & Batch Size Parameters =>> Set dynamically based on instance count!
