Chat dataset support (#11423)
* chat dataset support

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* add ci test

Signed-off-by: Chen Cui <[email protected]>

* address comment

Signed-off-by: Chen Cui <[email protected]>

* Apply isort and black reformatting

Signed-off-by: cuichenx <[email protected]>

* address comment

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Signed-off-by: cuichenx <[email protected]>
Co-authored-by: cuichenx <[email protected]>
cuichenx and cuichenx authored Dec 16, 2024
1 parent e92ec13 commit 356a3a6
Showing 8 changed files with 112 additions and 15 deletions.
33 changes: 32 additions & 1 deletion .github/workflows/cicd-main.yml
@@ -4551,6 +4551,36 @@ jobs:
          --pp_size 1 \
          --mbs 1 --packed

  L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat') || needs.cicd-test-container-setup.outputs.all == 'true'
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python tests/collections/llm/gpt_finetuning.py \
          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
          --devices 2 \
          --max_steps 3 \
          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
          --peft lora \
          --tp_size 1 \
          --pp_size 1 \
          --mbs 1 \
          --chat_dataset_path /home/TestData/nemo2_data/chat

        python tests/collections/llm/gpt_finetuning.py \
          --restore_path /home/TestData/nemo2_ckpt/llama_68M \
          --devices 2 \
          --max_steps 6 \
          --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \
          --peft lora \
          --tp_size 1 \
          --pp_size 1 \
          --mbs 1 \
          --chat_dataset_path /home/TestData/nemo2_data/chat

  L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -4661,7 +4691,7 @@ jobs:
      AFTER_SCRIPT: |
        rm -rf /tmp/nemo2_ckpt
        rm -rf /tmp/nemo2_ptq_engine

  L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -4820,6 +4850,7 @@ jobs:
      - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS2
      - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2
      - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2
      - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_Chat
      - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED
      - L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED
      - L2_NeMo_2_GPT_CLoRA_TP1PP1_MBS1_PACKED
2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -32,6 +32,7 @@
)
from nemo.collections.llm.gpt.data import (
    AlpacaDataModule,
    ChatDataModule,
    DollyDataModule,
    FineTuningDataModule,
    HFDatasetDataModule,
@@ -220,6 +221,7 @@
    "Qwen2Config72B",
    "PreTrainingDataModule",
    "FineTuningDataModule",
    "ChatDataModule",
    "SquadDataModule",
    "T5PreTrainingDataModule",
    "T5FineTuningDataModule",
8 changes: 5 additions & 3 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule
from nemo.collections.llm.gpt.data.chat import ChatDataModule
from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule
@@ -21,12 +22,13 @@
from nemo.collections.llm.gpt.data.squad import SquadDataModule

__all__ = [
    "FineTuningDataModule",
    "AlpacaDataModule",
    "SquadDataModule",
    "ChatDataModule",
    "DollyDataModule",
    "FineTuningDataModule",
    "HFDatasetDataModule",
    "MockDataModule",
    "PreTrainingDataModule",
    "build_pretraining_datamodule",
    "HFDatasetDataModule",
    "SquadDataModule",
]
41 changes: 41 additions & 0 deletions nemo/collections/llm/gpt/data/chat.py
@@ -0,0 +1,41 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import lru_cache

from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule


class ChatDataModule(FineTuningDataModule):
    """
    Base class for fine-tuning an LLM on chat datasets.
    This class calls `GPTSFTChatDataset` for chat template processing.
    See the base class `FineTuningDataModule` for more details.
    """

    @lru_cache
    def _create_dataset(self, path, is_test=False, **kwargs):
        # pylint: disable=C0115,C0116
        return create_sft_dataset(
            path,
            tokenizer=self.tokenizer,
            seq_length=(self.seq_length if is_test or self.packed_sequence_size <= 0 else self.packed_sequence_size),
            memmap_workers=self.memmap_workers,
            seed=self.seed,
            chat=True,
            is_test=is_test,
            **kwargs,
        )
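For orientation, a minimal usage sketch that mirrors how the new CI test constructs the module below; the dataset_root path is a placeholder and is not taken from this diff.

from nemo.collections import llm

# Minimal sketch: point dataset_root at a directory of chat-formatted JSONL splits.
# All values other than the placeholder path follow the test added in this PR.
chat_data = llm.ChatDataModule(
    dataset_root="/path/to/chat_data",  # placeholder path
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=8,
    num_workers=0,
)
# chat_data can then be passed as `data=` to llm.finetune(...), as the test does.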
10 changes: 9 additions & 1 deletion nemo/collections/llm/gpt/data/core.py
@@ -47,9 +47,17 @@ def create_sft_dataset(
    memmap_workers: int = 2,
    hf_dataset: bool = False,
    global_sample_mapping: bool = False,
    chat: bool = False,
    **kwargs,
) -> "GPTSFTDataset":
    if path.suffix == '.npy':
    """
    Create the dataset class (GPTSFTDataset, GPTSFTChatDataset or GPTSFTPackedDataset)
    """
    if chat:
        from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import GPTSFTChatDataset

        dataset_cls = GPTSFTChatDataset
    elif path.suffix == '.npy':
        from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTPackedDataset

        dataset_cls = GPTSFTPackedDataset
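A minimal sketch of calling create_sft_dataset directly with the new flag; the tokenizer choice and file path below are assumptions, and the JSONL records must already be in the chat format GPTSFTChatDataset expects.

from pathlib import Path

from nemo.collections.llm.gpt.data.core import create_sft_dataset
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

# Placeholder tokenizer; any tokenizer accepted by the SFT dataset classes works here.
tokenizer = get_nmt_tokenizer(library="huggingface", model_name="meta-llama/Meta-Llama-3-8B")
chat_ds = create_sft_dataset(
    Path("/path/to/chat_data/training.jsonl"),  # placeholder path to a chat-formatted JSONL file
    tokenizer=tokenizer,
    seq_length=2048,
    chat=True,  # selects GPTSFTChatDataset instead of GPTSFTDataset / GPTSFTPackedDataset
)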
1 change: 0 additions & 1 deletion nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
@@ -310,7 +310,6 @@ def _maybe_validate_prompt_template(self):

    def _build_samples_mapping(self):
        super()._build_samples_mapping()
        assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
        LABEL_START = self.special_tokens['label_start']
        END_NAME_SIGNAL = self.special_tokens['end_of_name']

2 changes: 1 addition & 1 deletion nemo/lightning/io/connector.py
@@ -274,7 +274,7 @@ def on_import_ckpt(self, model: pl.LightningModule):
    def save_hf_tokenizer_assets(self, tokenizer_name_or_path, save_path="/tmp/nemo_tokenizer"):
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        tok = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        # Save tokenizer assets to save_path.
        tok.save_pretrained(save_path)
        return save_path
30 changes: 22 additions & 8 deletions tests/collections/llm/gpt_finetuning.py
@@ -46,6 +46,9 @@ def get_args():
    parser.add_argument('--tp_size', type=int, default=1, help="tensor parallel size")
    parser.add_argument('--pp_size', type=int, default=1, help="pipeline parallel size")
    parser.add_argument('--packed', action='store_true', help="use packed sequence dataset")
    parser.add_argument(
        '--chat_dataset_path', type=str, default="", help="path to chat dataset. Uses dolly if this is empty."
    )

    return parser.parse_args()

@@ -105,13 +108,24 @@ def get_args():
    packed_sequence_specs = (
        PackedSequenceSpecs(packed_sequence_size=2048, tokenizer_model_name="dummy_tokenizer") if args.packed else None
    )
    dolly = llm.DollyDataModule(
        seq_length=2048,
        micro_batch_size=args.mbs,
        global_batch_size=4,
        num_workers=0,
        packed_sequence_specs=packed_sequence_specs,
    )
    if args.chat_dataset_path:
        assert not args.packed
        data = llm.ChatDataModule(
            dataset_root=args.chat_dataset_path,
            seq_length=2048,
            micro_batch_size=args.mbs,
            global_batch_size=8,
            num_workers=0,
            packed_sequence_specs=packed_sequence_specs,
        )
    else:
        data = llm.DollyDataModule(
            seq_length=2048,
            micro_batch_size=args.mbs,
            global_batch_size=8,
            num_workers=0,
            packed_sequence_specs=packed_sequence_specs,
        )

    tokenizer = get_nmt_tokenizer(tokenizer_model=os.path.join(args.restore_path, "dummy_tokenizer.model"))
    llama3_8b = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer)
@@ -123,7 +137,7 @@ def get_args():

    llm.finetune(
        model=llama3_8b,
        data=dolly,
        data=data,
        trainer=trainer,
        peft=peft,
        log=logger,
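The CI job points --chat_dataset_path at a prepared directory (/home/TestData/nemo2_data/chat) whose contents are not part of this diff. As an illustration only, a record in such a dataset's training.jsonl might look like the sketch below; the exact keys are assumptions based on NeMo's chat SFT format and should be verified against GPTSFTChatDataset.

import json

# Hypothetical chat record; the key names ("system", "mask", "conversations", "from", "value")
# are assumptions and are not taken from this PR.
sample = {
    "system": "You are a helpful assistant.",
    "mask": "User",  # turns from this speaker are excluded from the loss
    "conversations": [
        {"from": "User", "value": "What does ChatDataModule do?"},
        {"from": "Assistant", "value": "It fine-tunes a GPT model on chat-formatted data."},
    ],
}
with open("/path/to/chat_data/training.jsonl", "w") as f:  # placeholder path
    f.write(json.dumps(sample) + "\n")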
