Alit/mamba 2 0 migration (#10338)
JRD971000 authored Sep 7, 2024
1 parent 7ba0681 commit 62c1dce
Showing 11 changed files with 607 additions and 10 deletions.
16 changes: 16 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -31,6 +31,11 @@
     Baichuan2Config,
     Baichuan2Config7B,
     Baichuan2Model,
+    BaseMambaConfig1_3B,
+    BaseMambaConfig2_7B,
+    BaseMambaConfig130M,
+    BaseMambaConfig370M,
+    BaseMambaConfig780M,
     ChatGLM2Config6B,
     ChatGLM3Config6B,
     ChatGLMConfig,
@@ -71,12 +76,15 @@
     Nemotron4Config340B,
     NemotronConfig,
     NemotronModel,
+    NVIDIAMambaConfig8B,
+    NVIDIAMambaHybridConfig8B,
     Qwen2Config,
     Qwen2Config1P5B,
     Qwen2Config7B,
     Qwen2Config72B,
     Qwen2Config500M,
     Qwen2Model,
+    SSMConfig,
     Starcoder2Config,
     Starcoder2Config3B,
     Starcoder2Config7B,
@@ -120,6 +128,14 @@
     "Nemotron4Config22B",
     "Nemotron4Config340B",
     "NemotronConfig",
+    "SSMConfig",
+    "BaseMambaConfig130M",
+    "BaseMambaConfig370M",
+    "BaseMambaConfig780M",
+    "BaseMambaConfig1_3B",
+    "BaseMambaConfig2_7B",
+    "NVIDIAMambaConfig8B",
+    "NVIDIAMambaHybridConfig8B",
     "LlamaConfig",
     "Llama2Config7B",
     "Llama2Config13B",
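
With these re-exports in place, the new state-space-model (SSM) configs become importable from the collection root. A minimal sketch (the class names come straight from the __all__ additions above; instantiating with no arguments assumes dataclass-style defaults that this diff does not show):

    from nemo.collections.llm import BaseMambaConfig2_7B, NVIDIAMambaHybridConfig8B, SSMConfig

    # Preset for the 2.7B base Mamba model (assumes default field values).
    mamba_cfg = BaseMambaConfig2_7B()

    # Preset for NVIDIA's 8B hybrid (Mamba plus attention) variant.
    hybrid_cfg = NVIDIAMambaHybridConfig8B()

    # The naming suggests both presets specialize a shared SSMConfig base;
    # that inheritance is an assumption, not shown in this diff.
    assert isinstance(mamba_cfg, SSMConfig)
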
5 changes: 5 additions & 0 deletions nemo/collections/llm/gpt/data/fine_tuning.py
@@ -63,6 +63,7 @@ def __init__(
         num_workers: int = 8,
         pin_memory: bool = True,
         persistent_workers: bool = False,
+        pad_to_max_length: bool = False,
     ):
         super().__init__()
         self.seq_length = seq_length
@@ -78,6 +79,7 @@ def __init__(
         self.rampup_batch_size = rampup_batch_size
         self.data_sampler = None
         self.max_train_samples = None
+        self.pad_to_max_length = pad_to_max_length

     def setup(self, stage: str):
         self.data_sampler = MegatronDataSampler(
@@ -97,6 +99,7 @@ def train_dataloader(self) -> DataLoader:
             self._create_dataset(
                 str(self.train_path),
                 max_num_samples=self.max_train_samples,
+                pad_to_max_length=self.pad_to_max_length,
             )
         )

@@ -105,6 +108,7 @@ def val_dataloader(self) -> DataLoader:
             self._create_dataset(
                 str(self.validation_path),
                 is_test=True,
+                pad_to_max_length=self.pad_to_max_length,
             ),
         )

@@ -114,6 +118,7 @@ def test_dataloader(self) -> DataLoader:
                 str(self.test_path),
                 tokens_to_generate=32,
                 is_test=True,
+                pad_to_max_length=self.pad_to_max_length,
             )
         )
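
The flag added here is stored once in __init__ and then forwarded to _create_dataset from all three dataloader builders, so the train, validation, and test splits pad identically. A hedged usage sketch, assuming the class in this file is FineTuningDataModule and that dataset_root plus the batch-size arguments exist under these names (they are not visible in this diff):

    from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule

    # pad_to_max_length=True pads every sample out to seq_length, trading
    # wasted tokens for fixed-shape batches; False keeps the old behavior.
    datamodule = FineTuningDataModule(
        dataset_root="/data/my_finetune_set",  # hypothetical path
        seq_length=2048,
        micro_batch_size=4,    # assumed argument name
        global_batch_size=32,  # assumed argument name
        pad_to_max_length=True,
    )
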
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/squad.py
@@ -53,6 +53,7 @@ def __init__(
         num_workers: int = 8,
         pin_memory: bool = True,
         persistent_workers: bool = False,
+        pad_to_max_length: bool = False,
     ):
         self.force_redownload = force_redownload
         self.delete_raw = delete_raw
@@ -69,6 +70,7 @@
             num_workers=num_workers,
             pin_memory=pin_memory,
             persistent_workers=persistent_workers,
+            pad_to_max_length=pad_to_max_length,
         )

     def prepare_data(self) -> None:
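
Because the base class now owns the flag, this module only has to accept it and pass it through to the parent constructor, which is exactly what the two hunks above do. A sketch of enabling it, assuming the class here is named SquadDataModule and that the omitted arguments have usable defaults:

    from nemo.collections.llm.gpt.data.squad import SquadDataModule

    squad = SquadDataModule(
        seq_length=2048,         # assumed to be accepted here as well
        pad_to_max_length=True,  # forwarded verbatim to the base data module
    )

Any other subclass built on the same fine-tuning base picks up the flag with the same one-line constructor plumbing.
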
18 changes: 18 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -71,6 +71,16 @@
     Qwen2Config500M,
     Qwen2Model,
 )
+from nemo.collections.llm.gpt.model.ssm import (
+    BaseMambaConfig1_3B,
+    BaseMambaConfig2_7B,
+    BaseMambaConfig130M,
+    BaseMambaConfig370M,
+    BaseMambaConfig780M,
+    NVIDIAMambaConfig8B,
+    NVIDIAMambaHybridConfig8B,
+    SSMConfig,
+)
 from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig, StarcoderConfig15B, StarcoderModel
 from nemo.collections.llm.gpt.model.starcoder2 import (
     Starcoder2Config,
@@ -137,6 +147,14 @@
     "Qwen2Config7B",
     "Qwen2Config72B",
     "Qwen2Model",
+    "SSMConfig",
+    "BaseMambaConfig130M",
+    "BaseMambaConfig370M",
+    "BaseMambaConfig780M",
+    "BaseMambaConfig1_3B",
+    "BaseMambaConfig2_7B",
+    "NVIDIAMambaConfig8B",
+    "NVIDIAMambaHybridConfig8B",
     "MaskedTokenLossReduction",
     "gpt_data_step",
     "gpt_forward_step",
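
Since gpt/model/__init__.py re-exports the ssm module's classes and llm/__init__.py re-exports them once more, the same objects are reachable at three import depths. A quick sanity check, assuming an environment with this commit installed:

    from nemo.collections import llm
    from nemo.collections.llm.gpt import model
    from nemo.collections.llm.gpt.model import ssm

    # All three paths should resolve to the identical class object.
    assert llm.NVIDIAMambaConfig8B is model.NVIDIAMambaConfig8B
    assert model.NVIDIAMambaConfig8B is ssm.NVIDIAMambaConfig8B
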
(Diffs for the remaining 7 of the 11 changed files were not loaded on this page.)
