Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa committed Dec 11, 2024
1 parent 11f1202 commit 329d737
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 6 additions & 2 deletions nemo/collections/llm/gpt/data/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def make_dataset_splits(dataset, split, split_aliases):
> }
"""
from datasets import Dataset, DatasetDict

split_names = ['train', 'test', 'val']
dataset_splits = {_split: None for _split in split_names}

Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
super().__init__()
assert pad_token_id is not None
from datasets import Dataset, DatasetDict

# A dataset usually will have several splits (e.g. train, val, test, etc).
# We map synonym names to canonical names (train, test, val).
# A synonym can be a prefix/suffixed word e.g. train <> training.
Expand All @@ -146,8 +148,9 @@ def __init__(
logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
dataset = path_or_dataset
else:
raise ValueError("Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
+ str(type(path_or_dataset)))
raise ValueError(
"Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got " + str(type(path_or_dataset))
)

self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)

Expand All @@ -170,6 +173,7 @@ def __init__(
@staticmethod
def from_dict(dataset_dict, split, **kwargs):
    """Build an HFDatasetDataModule from an in-memory column mapping.

    Args:
        dataset_dict: Mapping of column name -> list of values, in the form
            accepted by ``datasets.Dataset.from_dict``.
        split: Split specification forwarded unchanged to
            ``HFDatasetDataModule``.
        **kwargs: Additional keyword arguments forwarded unchanged to
            ``HFDatasetDataModule``.

    Returns:
        An ``HFDatasetDataModule`` wrapping the constructed ``Dataset``.
    """
    # Function-scope import, matching the other `datasets` imports in this
    # file — presumably to defer the dependency until actually needed
    # (TODO confirm intent).
    from datasets import Dataset

    dataset = Dataset.from_dict(dataset_dict)
    return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)

Expand Down
7 changes: 5 additions & 2 deletions tests/collections/llm/gpt/data/test_hf_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,13 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
global_batch_size=2,
)
except ValueError as e:
raised_exception = str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got <class 'NoneType'>"
raised_exception = (
str(e) == "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got <class 'NoneType'>"
)

assert raised_exception == True, "Expected to raise a ValueError"


def test_load_from_dict():
data = {'text': "Below is an instruction that describes a task, paired with an input that "}

Expand All @@ -130,4 +133,4 @@ def test_load_from_dict():
assert hasattr(datamodule, 'val')
assert datamodule.val is None
assert hasattr(datamodule, 'test')
assert datamodule.test is None
assert datamodule.test is None

0 comments on commit 329d737

Please sign in to comment.