Add Mistral/Mistral-Instruct Backbones
siddk committed Apr 23, 2024
1 parent f5c097d commit 9e4c962
Showing 8 changed files with 163 additions and 5 deletions.
18 changes: 16 additions & 2 deletions prismatic/conf/models.py
@@ -259,7 +259,7 @@ class Exp_13B_Llama2(Exp_13B_One_Stage):
llm_backbone_id: str = "llama2-13b-pure"


# ~ Additional LLM Backbones ~
# ~ Additional LLM Backbones :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
@dataclass
class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
model_id: str = "llama2-chat+7b"
@@ -272,6 +272,18 @@ class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
llm_backbone_id: str = "llama2-13b-chat"


@dataclass
class Ext_Exp_7B_Mistral_V1(Exp_7B_One_Stage):
model_id: str = "mistral-v0.1+7b"
llm_backbone_id: str = "mistral-v0.1-7b-pure"


@dataclass
class Ext_Exp_7B_Mistral_Instruct_V1(Exp_7B_One_Stage):
model_id: str = "mistral-instruct-v0.1+7b"
llm_backbone_id: str = "mistral-v0.1-7b-instruct"


# Section 4.3B :: ✌️ --> Co-training on Language-only Data
# =>> Note :: Run with `--dataset.type "llava-multimodal"` (multimodal data only / no co-training)
@dataclass
@@ -514,9 +526,11 @@ class ModelRegistry(Enum):
EXP_LLAMA2_7B = Exp_7B_Llama2
EXP_LLAMA2_13B = Exp_13B_Llama2

# ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat ~
# ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat, Mistral v0.1, Mistral v0.1 Instruct ~
EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat
EXT_EXP_MISTRAL_V1_7B = Ext_Exp_7B_Mistral_V1
EXT_EXP_MISTRAL_INSTRUCT_V1_7B = Ext_Exp_7B_Mistral_Instruct_V1

# Cotraining w/ Unimodal Data
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
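For reference, the new registry entries resolve through the same `ModelConfig.get_choice_class` lookup that the training scripts use. A minimal sketch, assuming `ModelConfig` and `ModelRegistry` are importable from `prismatic.conf` and relying only on the dataclass defaults declared above:

from prismatic.conf import ModelConfig, ModelRegistry  # assumed re-export of prismatic/conf/models.py

# Resolve the new Mistral Instruct entry by its model_id (the same lookup scripts/pretrain.py performs)
cfg_cls = ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_MISTRAL_INSTRUCT_V1_7B.model_id)

# Dataclass defaults from Ext_Exp_7B_Mistral_Instruct_V1 above
assert cfg_cls.model_id == "mistral-instruct-v0.1+7b"
assert cfg_cls.llm_backbone_id == "mistral-v0.1-7b-instruct"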
1 change: 1 addition & 0 deletions prismatic/models/backbones/llm/__init__.py
@@ -1,2 +1,3 @@
from .base_llm import LLMBackbone
from .llama2 import LLaMa2LLMBackbone
from .mistral import MistralLLMBackbone
7 changes: 6 additions & 1 deletion prismatic/models/backbones/llm/base_llm.py
@@ -145,7 +145,12 @@ def __init__(

# Load (Fast) Tokenizer
overwatch.info(f"Loading [bold]{llm_family}[/] (Fast) Tokenizer via the AutoTokenizer API", ctx_level=1)
self.tokenizer = AutoTokenizer.from_pretrained(hf_hub_path, model_max_length=self.llm_max_length, token=hf_token)
self.tokenizer = AutoTokenizer.from_pretrained(
hf_hub_path,
model_max_length=self.llm_max_length,
token=hf_token,
padding_side="right",
)

# Validation =>> Our VLM logic currently operates under the assumption that the tokenization of a new input
# starts with a <BOS> token unless `add_special_tokens = False`; for these models, we empirically
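The validation mentioned above relies on the tokenizer prepending `<BOS>` unless `add_special_tokens=False`. A minimal sketch of that check, assuming Hub access to the Mistral v0.1 repository registered below; the prompt string is just a placeholder:

from transformers import AutoTokenizer

# Same load pattern as above; padding_side="right" matches the training-time assumption
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mistral-7B-v0.1", model_max_length=2048, padding_side="right"
)

# With special tokens (the default), the first id should be <BOS> ...
assert tokenizer("A photo of a dog").input_ids[0] == tokenizer.bos_token_id

# ... and without them, it should not be
assert tokenizer("A photo of a dog", add_special_tokens=False).input_ids[0] != tokenizer.bos_token_id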
72 changes: 72 additions & 0 deletions prismatic/models/backbones/llm/mistral.py
@@ -0,0 +1,72 @@
"""
mistral.py
Class definition for all LLMs derived from MistralForCausalLM.
"""

from typing import Optional, Type

import torch
from torch import nn as nn
from transformers import MistralForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

from prismatic.models.backbones.llm.base_llm import HFCausalLLMBackbone
from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder, PromptBuilder, PurePromptBuilder

# Registry =>> Support Mistral Models (from HF Transformers)
# fmt: off
MISTRAL_MODELS = {
# === Base Mistral v0.1 ===
"mistral-v0.1-7b-pure": {
"llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-v0.1"
},

# === Mistral Instruct v0.1 ===
"mistral-v0.1-7b-instruct": {
"llm_family": "mistral", "llm_cls": MistralForCausalLM, "hf_hub_path": "mistralai/Mistral-7B-Instruct-v0.1"
}
}
# fmt: on


class MistralLLMBackbone(HFCausalLLMBackbone):
def __init__(
self,
llm_backbone_id: str,
llm_max_length: int = 2048,
hf_token: Optional[str] = None,
inference_mode: bool = False,
use_flash_attention_2: bool = True,
) -> None:
super().__init__(
llm_backbone_id,
llm_max_length=llm_max_length,
hf_token=hf_token,
inference_mode=inference_mode,
use_flash_attention_2=use_flash_attention_2,
**MISTRAL_MODELS[llm_backbone_id],
)

# [Special Case] Mistral PAD Token Handling --> for clarity, we add an extra token (and resize)
self.tokenizer.add_special_tokens({"pad_token": "<PAD>"})
self.llm.config.pad_token_id = self.tokenizer.pad_token_id
self.llm.resize_token_embeddings(len(self.tokenizer), pad_to_multiple_of=64)

@property
def prompt_builder_fn(self) -> Type[PromptBuilder]:
if self.identifier.endswith("-pure"):
return PurePromptBuilder

elif self.identifier.endswith("-instruct"):
return MistralInstructPromptBuilder

raise ValueError(f"No PromptBuilder defined for LLM Backbone `{self.identifier}`")

@property
def transformer_layer_cls(self) -> Type[nn.Module]:
return MistralDecoderLayer

@property
def half_precision_dtype(self) -> torch.dtype:
return torch.bfloat16
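A minimal usage sketch for the new backbone, assuming the backbone's `identifier` mirrors its `llm_backbone_id`; instantiation downloads/loads the 7B checkpoint from the Hub, and the keyword values here are placeholders rather than recommended settings:

from prismatic.models.backbones.llm import MistralLLMBackbone
from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder

# Heavyweight: resolves mistralai/Mistral-7B-Instruct-v0.1 via the MISTRAL_MODELS registry above;
# set use_flash_attention_2=False if flash-attn is not installed
backbone = MistralLLMBackbone("mistral-v0.1-7b-instruct", inference_mode=True, use_flash_attention_2=False)

# The "-instruct" suffix selects the new prompt builder, and <PAD> handling is applied on init
assert backbone.prompt_builder_fn is MistralInstructPromptBuilder
assert backbone.tokenizer.pad_token == "<PAD>"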
1 change: 1 addition & 0 deletions prismatic/models/backbones/llm/prompting/__init__.py
@@ -1,3 +1,4 @@
from .base_prompter import PromptBuilder, PurePromptBuilder
from .llama2_chat_prompter import LLaMa2ChatPromptBuilder
from .mistral_instruct_prompter import MistralInstructPromptBuilder
from .vicuna_v15_prompter import VicunaV15ChatPromptBuilder
61 changes: 61 additions & 0 deletions prismatic/models/backbones/llm/prompting/mistral_instruct_prompter.py
@@ -0,0 +1,61 @@
"""
mistral_instruct_prompter.py
Defines a PromptBuilder for building Mistral Instruct Chat Prompts --> recommended pattern used by HF / online tutorials.
Reference: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
"""

from typing import Optional

from prismatic.models.backbones.llm.prompting.base_prompter import PromptBuilder


class MistralInstructPromptBuilder(PromptBuilder):
def __init__(self, model_family: str, system_prompt: Optional[str] = None) -> None:
super().__init__(model_family, system_prompt)

# Note =>> Mistral Tokenizer is an instance of LlamaTokenizer
# =>> Mistral Instruct *does not* use a System Prompt
self.bos, self.eos = "<s>", "</s>"

# Get role-specific "wrap" functions
# =>> Note that {bos} should only be added on FIRST utterance (key departure vs. LLaMa-2 Chat)
self.wrap_human = lambda msg: f"[INST] {msg} [/INST] "
self.wrap_gpt = lambda msg: f"{msg if msg != '' else ' '}{self.eos}"

# === `self.prompt` gets built up over multiple turns ===
self.prompt, self.turn_count = "", 0

def add_turn(self, role: str, message: str) -> str:
assert (role == "human") if (self.turn_count % 2 == 0) else (role == "gpt")
message = message.replace("<image>", "").strip()

if (self.turn_count % 2) == 0:
human_message = self.wrap_human(message)
wrapped_message = human_message
else:
gpt_message = self.wrap_gpt(message)
wrapped_message = gpt_message

# Update Prompt
self.prompt += wrapped_message

# Bump Turn Counter
self.turn_count += 1

# Return "wrapped_message" (effective string added to context)
return wrapped_message

def get_potential_prompt(self, message: str) -> str:
# Assumes that it's always the user's (human's) turn!
prompt_copy = str(self.prompt)

human_message = self.wrap_human(message)
prompt_copy += human_message

return prompt_copy.removeprefix(self.bos).rstrip()

def get_prompt(self) -> str:
# Remove prefix <bos> because it gets auto-inserted by tokenizer!
return self.prompt.removeprefix(self.bos).rstrip()
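A short, self-contained example of how the builder assembles a prompt over one human/GPT exchange; the `model_family` argument is just a label here, and the expected output follows directly from the wrap functions above:

from prismatic.models.backbones.llm.prompting import MistralInstructPromptBuilder

builder = MistralInstructPromptBuilder("mistral")
builder.add_turn("human", "What is in this <image>?")
builder.add_turn("gpt", "A cat sleeping on a couch.")

# <image> is stripped, the human turn is wrapped in [INST] ... [/INST], and the GPT turn ends with </s>
print(builder.get_prompt())
# [INST] What is in this ? [/INST] A cat sleeping on a couch.</s>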
6 changes: 5 additions & 1 deletion prismatic/models/materialize.py
@@ -9,7 +9,7 @@

from transformers import PreTrainedTokenizerBase

from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone
from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone
from prismatic.models.backbones.vision import (
CLIPViTBackbone,
DinoCLIPViTBackbone,
@@ -63,6 +63,10 @@
# === Vicuna-v1.5 Backbones ===
"vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
"vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

# === Mistral v0.1 Backbones ===
"mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
"mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},
}

# fmt: on
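A minimal sketch of the dispatch, assuming the mapping shown above is the module-level `LLM_BACKBONES` dictionary (the name is not visible in this hunk) that materialize.py uses to instantiate backbones:

from prismatic.models.backbones.llm import MistralLLMBackbone
from prismatic.models.materialize import LLM_BACKBONES  # assumed name of the registry dict above

# Both new ids should dispatch to the Mistral backbone class with default kwargs
for llm_backbone_id in ("mistral-v0.1-7b-pure", "mistral-v0.1-7b-instruct"):
    entry = LLM_BACKBONES[llm_backbone_id]
    assert entry["cls"] is MistralLLMBackbone and entry["kwargs"] == {}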
2 changes: 1 addition & 1 deletion scripts/pretrain.py
@@ -51,7 +51,7 @@ class PretrainConfig:

# ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
model: ModelConfig = field(
default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_LLAMA2_CHAT_13B.model_id)
default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_MISTRAL_INSTRUCT_V1_7B.model_id)
)

# DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
