Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add repeat method to datasets #7198

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4040,6 +4040,40 @@ def skip(self, n: int) -> "Dataset":
"""
return self.select(range(n, len(self)))

def repeat(self, num_times: int) -> "Dataset":
"""
Create a new [`Dataset`] that repeats the underlying dataset `num_times` times.

Like itertools.repeat, repeating once just returns the full dataset.

Args:
num_times (`int`):
Number of times to repeat the dataset.

Example:
```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="train")
>>> ds = ds.take(2).repeat(2)
>>> list(ds)
[{'label': 1,
'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
{'label': 1,
'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
{'label': 1, 'text': 'effective but too-tepid biopic'},
{'label': 1,
'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
{'label': 1,
'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
{'label': 1, 'text': 'effective but too-tepid biopic'}]
```
"""
if num_times is None:
raise ValueError("Map style datasets do not support indefinite repetition.")
num_times = max(num_times, 0)
indices = list(range(len(self))) * num_times
return self.select(indices)

def take(self, n: int) -> "Dataset":
"""
Create a new [`Dataset`] with only the first `n` elements.
Expand Down
91 changes: 91 additions & 0 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,54 @@ def n_shards(self) -> int:
return self.ex_iterable.n_shards


class RepeatExamplesIterable(_BaseExamplesIterable):
"""
Iterable that repeats the underlying iterable a given number of times.
"""

def __init__(
self,
ex_iterable: _BaseExamplesIterable,
num_times: int,
):
super().__init__()
self.ex_iterable = ex_iterable
self.num_times = num_times

def _init_state_dict(self) -> dict:
self._state_dict = {
"repeat_index": 0,
"ex_iterable": self.ex_iterable._init_state_dict(),
}
return self._state_dict

def __iter__(self):
repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0
while True:
if self.num_times and repeat_index >= max(self.num_times, 0):
break
yield from self.ex_iterable
repeat_index += 1
if self._state_dict:
self._state_dict["repeat_index"] = repeat_index
self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict()

def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
"""Shuffle the underlying iterable, then repeat."""
return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)

def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
"""Shard, then repeat shards."""
return RepeatExamplesIterable(
self.ex_iterable.shard_data_sources(worker_id, num_workers),
num_times=self.num_times,
)

@property
def n_shards(self) -> int:
return self.ex_iterable.n_shards


class TakeExamplesIterable(_BaseExamplesIterable):
def __init__(
self,
Expand Down Expand Up @@ -2513,6 +2561,49 @@ def skip(self, n: int) -> "IterableDataset":
token_per_repo_id=self._token_per_repo_id,
)

def repeat(self, num_times: Optional[int]) -> "IterableDataset":
"""
Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times.

N.B. The effect of calling shuffle after repeat depends significantly on buffer size.
With buffer_size 1, duplicate data is never seen in the same iteration, even after shuffling:
ds.repeat(n).shuffle(seed=42, buffer_size=1) is equivalent to ds.shuffle(seed=42, buffer_size=1).repeat(n),
and only shuffles shard orders within each iteration.
With buffer size >= (num samples in the dataset * num_times), we get full shuffling of the repeated data, i.e. we can observe duplicates in
the same iteration.

Args:
num_times (`int`) or (`None`):
Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely.

Example:
```py
>>> from datasets import load_dataset
>>> ds = load_dataset("rotten_tomatoes", split="train")
>>> ds = ds.take(2).repeat(2)
>>> list(ds)
[{'label': 1,
'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
{'label': 1,
'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
{'label': 1, 'text': 'effective but too-tepid biopic'},
{'label': 1,
'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
{'label': 1,
'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
{'label': 1, 'text': 'effective but too-tepid biopic'}]
```
"""
return IterableDataset(
ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times),
info=self._info,
split=self._split,
formatting=self._formatting,
shuffling=copy.deepcopy(self._shuffling),
distributed=copy.deepcopy(self._distributed),
token_per_repo_id=self._token_per_repo_id,
)

def take(self, n: int) -> "IterableDataset":
"""
Create a new [`IterableDataset`] with only the first `n` elements.
Expand Down