forked from TRI-ML/prismatic-vlms
Commit
Adds 2.7B Phi-2 LLM with non-LLaMa tokenizer handling and the recommended "Input: / Output:" formatting.
Showing 12 changed files with 172 additions and 15 deletions.
prismatic/models/backbones/llm/__init__.py
@@ -1,3 +1,4 @@
 from .base_llm import LLMBackbone
 from .llama2 import LLaMa2LLMBackbone
 from .mistral import MistralLLMBackbone
+from .phi import PhiLLMBackbone
prismatic/models/backbones/llm/phi.py
@@ -0,0 +1,64 @@
+"""
+phi.py
+
+Class definition for all LLMs derived from PhiForCausalLM.
+"""
+
+from typing import Optional, Type
+
+import torch
+from torch import nn as nn
+from transformers import PhiForCausalLM
+from transformers.models.phi.modeling_phi import PhiDecoderLayer
+
+from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
+from prismatic.models.backbones.llm.prompting import PhiPromptBuilder, PromptBuilder
+
+# Registry ==> Support Phi Models (from HF Transformers)
+# fmt: off
+PHI_MODELS = {
+    # === Phi-2 ===
+    "phi-2-3b": {
+        "llm_family": "phi", "llm_cls": PhiForCausalLM, "hf_hub_path": "microsoft/phi-2"
+    }
+}
+# fmt: on
+
+
+class PhiLLMBackbone(HFCausalLLMBackbone):
+    def __init__(
+        self,
+        llm_backbone_id: str,
+        llm_max_length: int = 2048,
+        hf_token: Optional[str] = None,
+        inference_mode: bool = False,
+        use_flash_attention_2: bool = True,
+    ) -> None:
+        super().__init__(
+            llm_backbone_id,
+            llm_max_length=llm_max_length,
+            hf_token=hf_token,
+            inference_mode=inference_mode,
+            use_flash_attention_2=use_flash_attention_2,
+            **PHI_MODELS[llm_backbone_id],
+        )
+
+        # [Special Case] Phi PAD Token Handling --> for clarity, we add an extra token (and resize)
+        self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
+        self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)
+
+    @property
+    def prompt_builder_fn(self) -> Type[PromptBuilder]:
+        if self.identifier.startswith("phi-2"):
+            return PhiPromptBuilder
+
+        raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")
+
+    @property
+    def transformer_layer_cls(self) -> Type[nn.Module]:
+        return PhiDecoderLayer
+
+    @property
+    def half_precision_dtype(self) -> torch.dtype:
+        return torch.bfloat16
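
A note on the PAD-token special case above: Phi-2's tokenizer ships without a PAD token, so the backbone registers one explicitly and resizes the embedding matrix to match. Below is a minimal standalone sketch of the same steps outside the backbone class (not part of this commit; it assumes a transformers release with native Phi support):

from transformers import AutoTokenizer, PhiForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = PhiForCausalLM.from_pretrained("microsoft/phi-2")

# Phi-2 defines no PAD token by default, so batched training would otherwise
# have to reuse <|endoftext|>, conflating padding with end-of-sequence.
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
model.config.pad_token_id = tokenizer.pad_token_id

# Grow the embedding matrix to cover the new token; rounding the vocabulary
# up to a multiple of 64 keeps matmul shapes GPU-friendly.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)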
prismatic/models/backbones/llm/prompting/__init__.py
@@ -1,4 +1,5 @@
 from .base_prompter import PromptBuilder, PurePromptBuilder
 from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
 from .mistral_instruct_prompter import MistralInstructPromptBuilder
+from .phi_prompter import PhiPromptBuilder
 from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
prismatic/models/backbones/llm/prompting/phi_prompter.py
@@ -0,0 +1,65 @@
+"""
+phi_prompter.py
+
+Defines a PromptBuilder for building Phi-2 Input/Output Prompts --> recommended pattern used by HF / Microsoft.
+Also handles Phi special case BOS token additions.
+
+Reference: https://huggingface.co/microsoft/phi-2#qa-format
+"""
+
+from typing import Optional
+
+from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder
+
+
+class PhiPromptBuilder(PromptBuilder):
+    def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
+        super().__init__(model_family, system_prompt)
+
+        # Note =>> Phi Tokenizer is an instance of `CodeGenTokenizer(Fast)`
+        #      =>> By default, does *not* append <BOS> / <EOS> tokens --> we handle that here (IMPORTANT)!
+        self.bos, self.eos = "<|endoftext|>", "<|endoftext|>"
+
+        # Get role-specific "wrap" functions
+        #   =>> Note that placement of <bos>/<eos> were based on experiments generating from Phi-2 in Input/Output mode
+        self.wrap_human = lambda msg: f"Input: {msg}\nOutput: "
+        self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}\n{self.eos}"
+
+        # === `self.prompt` gets built up over multiple turns ===
+        self.prompt, self.turn_count = "", 0
+
+    def add_turn(self, role: str, message: str) -> str:
+        assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
+        message = message.replace("<image>", "").strip()
+
+        # Special Handling for "first" input --> prepend a <BOS> token (expected by Prismatic)
+        if self.turn_count == 0:
+            bos_human_message = f"{self.bos}{self.wrap_human(message)}"
+            wrapped_message = bos_human_message
+        elif (self.turn_count % 2) == 0:
+            human_message = self.wrap_human(message)
+            wrapped_message = human_message
+        else:
+            gpt_message = self.wrap_gpt(message)
+            wrapped_message = gpt_message
+
+        # Update Prompt
+        self.prompt += wrapped_message
+
+        # Bump Turn Counter
+        self.turn_count += 1
+
+        # Return "wrapped_message" (effective string added to context)
+        return wrapped_message
+
+    def get_potential_prompt(self, message: str) -> str:
+        # Assumes that it's always the user's (human's) turn!
+        prompt_copy = str(self.prompt)
+
+        human_message = self.wrap_human(message)
+        prompt_copy += human_message
+
+        return prompt_copy.rstrip()
+
+    def get_prompt(self) -> str:
+        return self.prompt.rstrip()
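
For illustration, here is a hypothetical two-turn trace through PhiPromptBuilder (not part of this commit; the model_family value is an assumed placeholder), showing the "Input: / Output:" format referenced in the commit message:

from prismatic.models.backbones.llm.prompting import PhiPromptBuilder

builder = PhiPromptBuilder(model_family="phi")

# Turn 0 (human) --> prepends <BOS> and wraps the message as "Input: ...\nOutput: "
builder.add_turn("human", "What is the capital of France?")
# returns: "<|endoftext|>Input: What is the capital of France?\nOutput: "

# Turn 1 (gpt) --> appends the response followed by <EOS>
builder.add_turn("gpt", "Paris.")
# returns: "Paris.\n<|endoftext|>"

print(builder.get_prompt())
# <|endoftext|>Input: What is the capital of France?
# Output: Paris.
# <|endoftext|>

Each add_turn call returns the exact string appended to the running prompt, so callers can tokenize turn-by-turn without re-encoding the full context.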