Add with_rank to Dataset.from_generator #7199

Status: Open. Wants to merge 2 commits into base: main.
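A minimal usage sketch of the feature this PR adds, assuming only what the diff below shows: when with_rank=True, each preparation job calls the user's generator with an extra rank keyword argument equal to its job index. The generator name, column names, and shard paths here are illustrative, not taken from the PR.

from datasets import Dataset

def gen(shards, rank=None):
    # With with_rank=True the builder passes rank=<job index>; with the default
    # with_rank=False no rank kwarg is passed and the default None is used.
    for shard in shards:
        yield {"shard": shard, "rank": rank}

# Hypothetical example: 4 shards split across 2 worker processes (ranks 0 and 1).
ds = Dataset.from_generator(
    gen,
    gen_kwargs={"shards": ["a.txt", "b.txt", "c.txt", "d.txt"]},
    with_rank=True,
    num_proc=2,
)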
2 changes: 2 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -1032,6 +1032,7 @@ def from_generator(
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
with_rank: bool = False,
**kwargs,
):
"""Create a Dataset from a generator.
@@ -1095,6 +1096,7 @@ def from_generator(
gen_kwargs=gen_kwargs,
num_proc=num_proc,
split=split,
with_rank=with_rank,
**kwargs,
).read()

12 changes: 9 additions & 3 deletions src/datasets/builder.py
@@ -1438,6 +1438,7 @@ def _prepare_split(
file_format="arrow",
num_proc: Optional[int] = None,
max_shard_size: Optional[Union[int, str]] = None,
with_rank: bool = False,
):
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)

@@ -1483,7 +1484,7 @@ def _prepare_split(
job_id = 0
with pbar:
for job_id, done, content in self._prepare_split_single(
gen_kwargs=gen_kwargs, job_id=job_id, **_prepare_split_args
gen_kwargs=gen_kwargs, job_id=job_id, with_rank=with_rank, **_prepare_split_args
):
if done:
result = content
@@ -1496,7 +1497,7 @@
]
else:
kwargs_per_job = [
{"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args}
{"gen_kwargs": gen_kwargs, "job_id": job_id, "with_rank": with_rank, **_prepare_split_args}
for job_id, gen_kwargs in enumerate(
_split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc)
)
@@ -1582,8 +1583,13 @@ def _prepare_split_single(
split_info: SplitInfo,
check_duplicate_keys: bool,
job_id: int,
with_rank: bool = False,
) -> Iterable[Tuple[int, bool, Union[int, tuple]]]:
generator = self._generate_examples(**gen_kwargs)
if with_rank:
generator = self._generate_examples(rank=job_id, **gen_kwargs)
else:
generator = self._generate_examples(**gen_kwargs)

writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
shard_lengths = []
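For orientation, a simplified sketch of the dispatch the builder.py changes above implement (not the actual datasets internals): one job per gen_kwargs shard, with the job index forwarded as rank only when with_rank is set, so existing generators are unaffected by default. The helper name run_jobs is hypothetical.

from typing import Callable, Dict, Iterable, List

def run_jobs(
    generate_examples: Callable[..., Iterable[dict]],
    kwargs_per_job: List[Dict],
    with_rank: bool = False,
) -> Iterable[dict]:
    # Mirrors _prepare_split / _prepare_split_single in spirit: enumerate the
    # per-job gen_kwargs and pass rank=job_id only when requested.
    for job_id, gen_kwargs in enumerate(kwargs_per_job):
        if with_rank:
            yield from generate_examples(rank=job_id, **gen_kwargs)
        else:
            yield from generate_examples(**gen_kwargs)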
3 changes: 3 additions & 0 deletions src/datasets/io/generator.py
@@ -16,6 +16,7 @@ def __init__(
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
with_rank: bool = False,
**kwargs,
):
super().__init__(
@@ -26,6 +27,7 @@
num_proc=num_proc,
**kwargs,
)
self.with_rank = with_rank
self.builder = Generator(
cache_dir=cache_dir,
features=features,
@@ -52,6 +54,7 @@ def read(self):
verification_mode=verification_mode,
base_path=base_path,
num_proc=self.num_proc,
with_rank=self.with_rank,
)
dataset = self.builder.as_dataset(
split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
52 changes: 52 additions & 0 deletions tests/test_arrow_dataset.py
@@ -3953,6 +3953,58 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path):
_check_generator_dataset(dataset, expected_features, expected_split)


@pytest.mark.parametrize(
"with_rank",
[True, False],
)
@pytest.mark.parametrize(
"num_proc",
[None, 1, 2, 3, 4],
)
def test_dataset_from_generator_with_rank(with_rank, num_proc, tmp_path):
cache_dir = tmp_path / "cache"

def _gen_with_rank(arg_1, **kwargs):
for a in arg_1:
res = {"col_1": a}
if with_rank:
assert "rank" in kwargs
res["rank"] = kwargs["rank"]
else:
assert "rank" not in kwargs
yield res

num_examples = 10
examples = list(range(num_examples))
gen_kwargs = {"arg_1": examples}
datasets = [
Dataset.from_generator(
_gen_with_rank, gen_kwargs=gen_kwargs, cache_dir=cache_dir, with_rank=_with_rank, num_proc=num_proc
)
for _with_rank in [with_rank, not with_rank]
]
assert datasets[0]._fingerprint == datasets[1]._fingerprint
dataset = datasets[0]
assert isinstance(dataset, Dataset)
assert dataset.num_rows == num_examples
assert dataset.num_columns == (2 if with_rank else 1)
assert dataset.split == NamedSplit("train")
assert dataset.column_names == (["col_1", "rank"] if with_rank else ["col_1"])
if with_rank:
expected_features = {"col_1": "int64", "rank": "int64"}
else:
expected_features = {"col_1": "int64"}
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
npt.assert_array_equal(examples, dataset["col_1"])
if with_rank:
if num_proc is None:
ranks = [0]
else:
ranks = np.arange(num_proc)
npt.assert_array_equal(ranks, list(set(dataset["rank"])))


@require_not_windows
@require_dill_gt_0_3_2
@require_pyspark