Add from_dict to HFDatasetDataModule (#11559)
* Add from_dict method

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
2 people authored and BoxiangW committed Dec 23, 2024
1 parent 867dd0c commit 3867529
Showing 2 changed files with 79 additions and 37 deletions.
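
For orientation before reading the diff, here is a minimal usage sketch of the API this commit adds. It mirrors the new test_load_from_dict test below; the `from nemo.collections import llm` import path and the argument defaults not shown in the diff are assumptions.

from nemo.collections import llm

# Build a datamodule directly from an in-memory dict (new in this commit);
# from_dict wraps datasets.Dataset.from_dict and forwards the rest to __init__.
datamodule = llm.HFDatasetDataModule.from_dict(
    {"text": ["Below is an instruction that describes a task."] * 8},
    split='train',
    micro_batch_size=1,
    global_batch_size=4,
)

# The requested split is exposed as an attribute; unused splits stay None.
assert datamodule.train is not None and len(datamodule.train) == 8
assert datamodule.val is None and datamodule.test is None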
80 changes: 50 additions & 30 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datasets.dataset_dict
import lightning.pytorch as pl
import torch
from datasets import load_dataset
@@ -22,46 +21,51 @@
from nemo.utils import logging


def make_dataset_splits(path, split, split_aliases, kwargs):
def make_dataset_splits(dataset, split, split_aliases):
"""
Loads a dataset with datasets.load_dataset and
returns a dictionary containing all dataset splits.
Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict) it
returns a dictionary containing the corresponding dataset splits.
For example:
ans = make_dataset_splits("dataset-id")
$ ds = load_dataset("dataset-id")
$ print(ds)
> DatasetDict({
> train: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 87599
> })
> validation: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 10570
> })
> })
In this case the value of `ans` (returned value) will be:
$ ds = load_dataset("dataset-id")
$ ans = make_dataset_splits(ds)
# `ds` contains the following
$ print(ds)
> DatasetDict({
> train: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 87599
> })
> validation: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 10570
> })
> })
# In this case the value of `ans` (returned value) will be:
$ print(ans)
> {
> "train": Dataset .. (with 87599 rows),
> "val": Dataset .. (with 10570 rows),
> }
"""
dataset = load_dataset(path, split=split, **kwargs)
from datasets import Dataset, DatasetDict

split_names = ['train', 'test', 'val']
dataset_splits = {split: None for split in split_names}
dataset_splits = {_split: None for _split in split_names}

alias_to_split = {}
for split_name, _split_aliases in split_aliases.items():
assert split_name in split_names
for alias in _split_aliases:
alias_to_split[alias] = split_name

if isinstance(dataset, datasets.dataset_dict.DatasetDict):
if isinstance(dataset, Dataset):
assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split))
dataset_splits[split] = dataset
elif isinstance(dataset, DatasetDict):
dataset_split_names = dataset.keys()
logging.info(f"HF dataset has the following splits: {dataset_split_names}")
for alias_split_name, split in dataset.items():
@@ -89,9 +93,8 @@ def make_dataset_splits(path, split, split_aliases, kwargs):
else:
raise ValueError("Expected split name to be None, str or a list")

assert (
sum(map(lambda x: x is not None, dataset_splits.values())) > 0
), "Expected at least one dataset to have been initialized"
num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values()))
assert num_init_splits > 0, f"Expected at least one split to have been initialized {num_init_splits}"
return dataset_splits


@@ -111,9 +114,9 @@ class HFDatasetDataModule(pl.LightningDataModule):

def __init__(
self,
path,
collate_fn=None,
path_or_dataset,
split=None,
collate_fn=None,
num_workers=2,
pin_memory=True,
persistent_workers=True,
@@ -130,16 +133,26 @@ ) -> None:
) -> None:
super().__init__()
assert pad_token_id is not None

logging.info(f"Loading HF dataset from {path}")
from datasets import Dataset, DatasetDict

# A dataset usually will have several splits (e.g. train, val, test, etc).
# We map synonym names to canonical names (train, test, val).
# A synonym can be a prefix/suffixed word e.g. train <> training.
split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}

# self.dataset_splits will hold the actual dataset for each split.
self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
if isinstance(path_or_dataset, str):
logging.info(f"Loading HF dataset from {path_or_dataset}")
dataset = load_dataset(path_or_dataset, split=split, **kwargs)
elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict):
logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
dataset = path_or_dataset
else:
raise ValueError(
"Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got " + str(type(path_or_dataset))
)

self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)

if collate_fn is None:
self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
@@ -157,6 +170,13 @@ def __init__(
self.use_mcore_sampler = use_mcore_sampler
self.mcore_dataloader_type = mcore_dataloader_type

@staticmethod
def from_dict(dataset_dict, split, **kwargs):
from datasets import Dataset

dataset = Dataset.from_dict(dataset_dict)
return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)

@staticmethod
def collate_fn(batch, pad_token_id=0):
def batchify(tensor):
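
To summarize the behavior change in hf_dataset.py above: the constructor now takes path_or_dataset and accepts either a Hub path string (loaded via load_dataset) or an already-built Dataset/DatasetDict that is passed straight to make_dataset_splits. A minimal sketch, using the placeholder "dataset-id" from the docstring example and the constructor arguments used in the tests below:

from datasets import Dataset
from nemo.collections import llm

# 1) String path: the module calls datasets.load_dataset(path, split=...) itself.
dm_from_hub = llm.HFDatasetDataModule(
    path_or_dataset="dataset-id",  # hypothetical Hub id, as in the docstring example
    split='train',
    seq_length=512,
    micro_batch_size=2,
    global_batch_size=2,
)

# 2) Pre-built Dataset (or DatasetDict): forwarded unchanged to make_dataset_splits.
ds = Dataset.from_dict({"text": ["hello world"] * 4})
dm_from_ds = llm.HFDatasetDataModule(
    path_or_dataset=ds,
    split='train',
    seq_length=512,
    micro_batch_size=2,
    global_batch_size=2,
)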
36 changes: 29 additions & 7 deletions tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -19,7 +19,7 @@

def test_load_single_split():
ds = llm.HFDatasetDataModule(
path=DATA_PATH,
path_or_dataset=DATA_PATH,
split='train',
seq_length=512,
micro_batch_size=2,
@@ -46,7 +46,7 @@ def test_load_nonexistent_split():
expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
try:
llm.HFDatasetDataModule(
path=DATA_PATH,
path_or_dataset=DATA_PATH,
split='this_split_name_should_not_exist',
seq_length=512,
micro_batch_size=2,
@@ -59,7 +59,7 @@

def test_load_multiple_split():
ds = llm.HFDatasetDataModule(
path=DATA_PATH,
path_or_dataset=DATA_PATH,
split=['train', 'validation'],
seq_length=512,
micro_batch_size=2,
@@ -88,7 +88,7 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist():
raised_exception = False
try:
llm.HFDatasetDataModule(
path="/this/path/should/not/exist/",
path_or_dataset="/this/path/should/not/exist/",
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
@@ -103,12 +103,34 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
raised_exception = False
try:
llm.HFDatasetDataModule(
path=None,
path_or_dataset=None,
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
except TypeError:
raised_exception = True
except ValueError as e:
raised_exception = (
str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got <class 'NoneType'>"
)

assert raised_exception == True, "Expected to raise a ValueError"


def test_load_from_dict():
data = {'text': "Below is an instruction that describes a task, paired with an input that "}

datamodule = llm.HFDatasetDataModule.from_dict(
{"text": [data['text'] for _ in range(101)]},
split='train',
global_batch_size=4,
micro_batch_size=1,
)
assert datamodule is not None
assert isinstance(datamodule, llm.HFDatasetDataModule)
assert hasattr(datamodule, 'train')
assert datamodule.train is not None
assert len(datamodule.train) == 101
assert hasattr(datamodule, 'val')
assert datamodule.val is None
assert hasattr(datamodule, 'test')
assert datamodule.test is None
