From 3867529db38c7f9120c3c9fd799cb69bc72d7b4c Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com>
Date: Wed, 11 Dec 2024 19:54:10 -0800
Subject: [PATCH] Add from_dict to HFDatasetDataModule (#11559)

* Add from_dict method

Signed-off-by: Alexandros Koumparoulis

* add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis

* fix

Signed-off-by: Alexandros Koumparoulis

* fix

Signed-off-by: Alexandros Koumparoulis

* add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis

* fix

Signed-off-by: Alexandros Koumparoulis

* Apply isort and black reformatting

Signed-off-by: akoumpa

---------

Signed-off-by: Alexandros Koumparoulis
Signed-off-by: akoumpa
Co-authored-by: akoumpa
---
 nemo/collections/llm/gpt/data/hf_dataset.py | 80 ++++++++++++-------
 .../llm/gpt/data/test_hf_datamodule.py      | 36 +++++++--
 2 files changed, 79 insertions(+), 37 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 039e5b90b096..73b6444a6e9c 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datasets.dataset_dict
 import lightning.pytorch as pl
 import torch
 from datasets import load_dataset
@@ -22,38 +21,40 @@ from nemo.utils import logging
 
 
-def make_dataset_splits(path, split, split_aliases, kwargs):
+def make_dataset_splits(dataset, split, split_aliases):
     """
-    Loads a dataset with datasets.load_dataset and
-    returns a dictionary containing all dataset splits.
+    Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict) it
+    returns a dictionary containing the corresponding dataset splits.
 
     For example:
-    ans = make_dataset_splits("dataset-id")
-    $ ds = load_dataset("dataset-id")
-    $ print(ds)
-    > DatasetDict({
-    >     train: Dataset({
-    >         features: ['id', 'title', 'context', 'question', 'answers'],
-    >         num_rows: 87599
-    >     })
-    >     validation: Dataset({
-    >         features: ['id', 'title', 'context', 'question', 'answers'],
-    >         num_rows: 10570
-    >     })
-    > })
-
-    In this case the value of `ans` (returned value) will be:
+    $ ds = load_dataset("dataset-id")
+    $ ans = make_dataset_splits(ds)
+
+    # `ds` contains the following
+    $ print(ds)
+    > DatasetDict({
+    >     train: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 87599
+    >     })
+    >     validation: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 10570
+    >     })
+    > })
+
+    # In this case the value of `ans` (returned value) will be:
     $ print(ans)
     > {
     >     "train": Dataset .. (with 87599 rows),
     >     "val": Dataset .. (with 10570 rows),
     > }
     """
-    dataset = load_dataset(path, split=split, **kwargs)
+    from datasets import Dataset, DatasetDict
 
     split_names = ['train', 'test', 'val']
-    dataset_splits = {split: None for split in split_names}
+    dataset_splits = {_split: None for _split in split_names}
 
     alias_to_split = {}
     for split_name, _split_aliases in split_aliases.items():
@@ -61,7 +62,10 @@ def make_dataset_splits(path, split, split_aliases, kwargs):
         for alias in _split_aliases:
             alias_to_split[alias] = split_name
 
-    if isinstance(dataset, datasets.dataset_dict.DatasetDict):
+    if isinstance(dataset, Dataset):
+        assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split))
+        dataset_splits[split] = dataset
+    elif isinstance(dataset, DatasetDict):
         dataset_split_names = dataset.keys()
         logging.info(f"HF dataset has the following splits: {dataset_split_names}")
         for alias_split_name, split in dataset.items():
@@ -89,9 +93,8 @@ def make_dataset_splits(path, split, split_aliases, kwargs):
     else:
         raise ValueError("Expected split name to be None, str or a list")
 
-    assert (
-        sum(map(lambda x: x is not None, dataset_splits.values())) > 0
-    ), "Expected at least one dataset to have been initialized"
+    num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values()))
+    assert num_init_splits > 0, f"Expected at least one split to have been initialized {num_init_splits}"
 
     return dataset_splits
 
@@ -111,9 +114,9 @@ class HFDatasetDataModule(pl.LightningDataModule):
     def __init__(
         self,
-        path,
-        collate_fn=None,
+        path_or_dataset,
         split=None,
+        collate_fn=None,
         num_workers=2,
         pin_memory=True,
         persistent_workers=True,
@@ -130,8 +133,7 @@ def __init__(
     ) -> None:
         super().__init__()
         assert pad_token_id is not None
-
-        logging.info(f"Loading HF dataset from {path}")
+        from datasets import Dataset, DatasetDict
 
         # A dataset usually will have several splits (e.g. train, val, test, etc).
         # We map synonym names to canonical names (train, test, val).
@@ -139,7 +141,18 @@ def __init__(
         split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}
 
         # self.dataset_splits will hold the actual dataset for each split.
-        self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
+        if isinstance(path_or_dataset, str):
+            logging.info(f"Loading HF dataset from {path_or_dataset}")
+            dataset = load_dataset(path_or_dataset, split=split, **kwargs)
+        elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict):
+            logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
+            dataset = path_or_dataset
+        else:
+            raise ValueError(
+                "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got " + str(type(path_or_dataset))
+            )
+
+        self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)
 
         if collate_fn is None:
             self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
@@ -157,6 +170,13 @@ def __init__(
         self.use_mcore_sampler = use_mcore_sampler
         self.mcore_dataloader_type = mcore_dataloader_type
 
+    @staticmethod
+    def from_dict(dataset_dict, split, **kwargs):
+        from datasets import Dataset
+
+        dataset = Dataset.from_dict(dataset_dict)
+        return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)
+
     @staticmethod
     def collate_fn(batch, pad_token_id=0):
         def batchify(tensor):
diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index a8d264701d39..58f7c02e091b 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -19,7 +19,7 @@ def test_load_single_split():
     ds = llm.HFDatasetDataModule(
-        path=DATA_PATH,
+        path_or_dataset=DATA_PATH,
         split='train',
         seq_length=512,
         micro_batch_size=2,
@@ -46,7 +46,7 @@ def test_load_nonexistent_split():
     expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
     try:
         llm.HFDatasetDataModule(
-            path=DATA_PATH,
+            path_or_dataset=DATA_PATH,
             split='this_split_name_should_not_exist',
             seq_length=512,
             micro_batch_size=2,
@@ -59,7 +59,7 @@ def test_load_multiple_split():
     ds = llm.HFDatasetDataModule(
-        path=DATA_PATH,
+        path_or_dataset=DATA_PATH,
         split=['train', 'validation'],
         seq_length=512,
         micro_batch_size=2,
@@ -88,7 +88,7 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist():
     raised_exception = False
     try:
         llm.HFDatasetDataModule(
-            path="/this/path/should/not/exist/",
+            path_or_dataset="/this/path/should/not/exist/",
            seq_length=512,
             micro_batch_size=2,
             global_batch_size=2,
@@ -103,12 +103,34 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
     raised_exception = False
     try:
         llm.HFDatasetDataModule(
-            path=None,
+            path_or_dataset=None,
             seq_length=512,
             micro_batch_size=2,
             global_batch_size=2,
         )
-    except TypeError:
-        raised_exception = True
+    except ValueError as e:
+        raised_exception = (
+            str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "
+        )
 
     assert raised_exception == True, "Expected to raise a ValueError"
+
+
+def test_load_from_dict():
+    data = {'text': "Below is an instruction that describes a task, paired with an input that "}
+
+    datamodule = llm.HFDatasetDataModule.from_dict(
+        {"text": [data['text'] for _ in range(101)]},
+        split='train',
+        global_batch_size=4,
+        micro_batch_size=1,
+    )
+    assert datamodule is not None
+    assert isinstance(datamodule, llm.HFDatasetDataModule)
+    assert hasattr(datamodule, 'train')
+    assert datamodule.train is not None
+    assert len(datamodule.train) == 101
+    assert hasattr(datamodule, 'val')
+    assert datamodule.val is None
+    assert hasattr(datamodule, 'test')
+    assert datamodule.test is None
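
For reference, a minimal usage sketch of the `from_dict` entry point this patch adds, modeled on `test_load_from_dict` above; the "text" column, row count, and batch sizes are illustrative rather than required by the API, and the snippet assumes an environment where `nemo.collections.llm` is importable:

    # Minimal sketch; mirrors test_load_from_dict above.
    from nemo.collections import llm

    # Wrap an in-memory dict in a datasets.Dataset and build the datamodule around it.
    datamodule = llm.HFDatasetDataModule.from_dict(
        {"text": ["an example instruction"] * 8},  # illustrative column name and contents
        split='train',  # a split name is required when a single Dataset is passed
        global_batch_size=4,
        micro_batch_size=1,
    )

    # The named split is populated; the other splits stay None for a single-split dataset.
    print(len(datamodule.train))               # 8
    print(datamodule.val, datamodule.test)     # None None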