Skip to content

Commit

Permalink
Add LLaMa-2 Chat Backbone
Browse files Browse the repository at this point in the history
  • Loading branch information
siddk committed Apr 23, 2024
1 parent f45f7d2 commit f5c097d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
18 changes: 18 additions & 0 deletions prismatic/conf/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,19 @@ class Exp_13B_Llama2(Exp_13B_One_Stage):
llm_backbone_id: str = "llama2-13b-pure"


# ~ Additional LLM Backbones ~
@dataclass
class Ext_Exp_7B_Llama2_Chat(Exp_7B_One_Stage):
    """Experiment config pairing the 7B one-stage recipe with the LLaMa-2 7B *Chat* LLM backbone.

    Inherits all training hyperparameters from `Exp_7B_One_Stage`; only the model identity
    and the LLM backbone selection are overridden.
    """

    # Registry/model identifier for this experiment variant.
    model_id: str = "llama2-chat+7b"
    # Key selecting the instruction-tuned (chat) LLaMa-2 7B backbone.
    llm_backbone_id: str = "llama2-7b-chat"


@dataclass
class Ext_Exp_13B_Llama2_Chat(Exp_13B_One_Stage):
    """Experiment config pairing the 13B one-stage recipe with the LLaMa-2 13B *Chat* LLM backbone.

    Inherits all training hyperparameters from `Exp_13B_One_Stage`; only the model identity
    and the LLM backbone selection are overridden.
    """

    # Registry/model identifier for this experiment variant.
    model_id: str = "llama2-chat+13b"
    # Key selecting the instruction-tuned (chat) LLaMa-2 13B backbone.
    llm_backbone_id: str = "llama2-13b-chat"


# Section 4.3B :: ✌️ --> Co-training on Language-only Data
# =>> Note :: Run with `--dataset.type "llava-multimodal" (multimodal data only / no co-training)
@dataclass
Expand Down Expand Up @@ -501,6 +514,11 @@ class ModelRegistry(Enum):
EXP_LLAMA2_7B = Exp_7B_Llama2
EXP_LLAMA2_13B = Exp_13B_Llama2

# ~ Additional LLM Backbone Experiments :: LLaMa-2 Chat ~
EXT_EXP_LLAMA2_CHAT_7B = Ext_Exp_7B_Llama2_Chat
EXT_EXP_LLAMA2_CHAT_13B = Ext_Exp_13B_Llama2_Chat

# Cotraining w/ Unimodal Data
EXP_VICUNA_NO_COTRAINING_7B = Exp_7B_Vicuna_No_Cotraining
EXP_LLAMA2_NO_COTRAINING_7B = Exp_7B_Llama2_No_Cotraining

Expand Down
2 changes: 1 addition & 1 deletion scripts/hf-hub/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def upload(cfg: UploadConfig) -> None:
    # Push the converted checkpoint directory to the Hub repo under `<model_id>/`,
    # with hf_transfer enabled (HF_HUB_ENABLE_HF_TRANSFER=1) for faster uploads;
    # `check=True` raises CalledProcessError if the CLI exits non-zero.
    # NOTE(review): `shell=True` with f-string interpolation of `cfg.hub_repo` /
    # `cfg.model_id` is shell-injection prone if these values are ever untrusted —
    # presumably they come from a trusted local config; confirm. The env-var prefix
    # is why a shell string (rather than an argv list) is used here.
    subprocess.run(
        f"HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli upload {cfg.hub_repo} {convert_path!s} {cfg.model_id}/",
        shell=True,
        check=True,
    )

# Done
Expand Down
2 changes: 1 addition & 1 deletion scripts/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class PretrainConfig:

    # ModelConfig (`prismatic/conf/models.py`); override with --model.type `ModelRegistry.<MODEL>.model_id`
    # `get_choice_class(...)` presumably resolves the model_id to its registered config
    # dataclass, which then serves as the `default_factory` — TODO confirm against ModelConfig.
    # NOTE(review): this commit switches the default from PRISM_DINOSIGLIP_CONTROLLED_7B to the
    # newly added EXT_EXP_LLAMA2_CHAT_13B — verify this default change is intentional and not a
    # leftover from local experimentation.
    model: ModelConfig = field(
        default_factory=ModelConfig.get_choice_class(ModelRegistry.EXT_EXP_LLAMA2_CHAT_13B.model_id)
    )

    # DatasetConfig (`prismatic/conf/datasets.py`); override with --dataset.type `DatasetRegistry.<DATASET>.dataset_id`
Expand Down

0 comments on commit f5c097d

Please sign in to comment.