Commit

modify data split to use HF api
ghstack-source-id: e23d5e0b70abc427a13bc8bf195c876c007f4939
Pull Request resolved: #65
tianyu-l committed Feb 21, 2024
1 parent 4dbe0af commit a5597ad
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions torchtrain/datasets/alpaca.py
@@ -9,6 +9,7 @@
 from torchtrain.datasets.tokenizer import TokenizerIf
 
 from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 
 class AlpacaDataset(IterableDataset):
@@ -44,32 +45,24 @@ def __init__(
         rank: int = 0,
         **kwargs
     ) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        # For larger datasets we need to use a more scalable approach.
+        # Setting `streaming=True` works for large dataset, but the speed is slow.
+        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
         self._tokenizer = tokenizer
-        self.data_iterator = iter(self._data)
         self.seq_len = seq_len
-        self.world_size = world_size
-        self.rank = rank
         self.response_tag = "\n\n### Response:\n"
 
-    def __len__(self):
-        return len(self._data)
-
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for idx, sample in enumerate(self.data_iterator):
-            # select samples to pack in a round-robin fashion
-            # TODO: This is a temporary solution for small datasets like Alpaca.
-            # For larger datasets we need to use a more scalable approach.
-            if idx % self.world_size != self.rank:
-                continue
+        for sample in self.data_iterator:
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
 
-            if len(all_tokens) >= max_buffer_token_len:
+            while len(all_tokens) >= max_buffer_token_len:
                 x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                 # batched_x = x.reshape(self.batch_size, -1)
                 # update tokens to the remaining tokens
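Note: the substantive change is delegating per-rank data sharding to the Hugging Face datasets API instead of filtering samples manually. A minimal sketch of how split_dataset_by_node behaves, for illustration only (the two-rank setup is an assumption; nothing below is part of the commit):

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# Same dataset the commit loads.
ds = load_dataset("tatsu-lab/alpaca", split="train")

# Each rank keeps a disjoint shard of the dataset; together the shards
# cover every sample exactly once.
shard0 = split_dataset_by_node(ds, rank=0, world_size=2)
shard1 = split_dataset_by_node(ds, rank=1, world_size=2)
assert len(shard0) + len(shard1) == len(ds)

Unlike the removed round-robin filter (if idx % self.world_size != self.rank: continue), each rank iterates only over its own shard rather than visiting and discarding the other ranks' samples.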

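Note: the other behavioral change is the if -> while swap in the packing loop: a sample whose tokens overfill the buffer now yields every full sequence it contains, not at most one per incoming sample. A self-contained sketch of that pattern, with hypothetical token counts (not from the repository):

# Pack a token buffer into fixed-length sequences.
seq_len = 4
max_buffer_token_len = 1 + seq_len
all_tokens = list(range(12))  # pretend one sample encoded to 12 tokens

chunks = []
# With `if`, only one chunk would be emitted here; `while` drains the
# buffer until fewer than max_buffer_token_len tokens remain.
while len(all_tokens) >= max_buffer_token_len:
    chunks.append(all_tokens[:max_buffer_token_len])
    all_tokens = all_tokens[max_buffer_token_len:]

print(chunks)      # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
print(all_tokens)  # [10, 11] -- carried over to the next sample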