From a5597ad8dd4c514ffbf78b456dd61e0c959d8dd7 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 21 Feb 2024 11:15:09 -0800
Subject: [PATCH] modify data split to use HF api

ghstack-source-id: e23d5e0b70abc427a13bc8bf195c876c007f4939
Pull Request resolved: https://github.com/pytorch-labs/torchtrain/pull/65
---
 torchtrain/datasets/alpaca.py | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py
index f52d21121..3ee1442dc 100644
--- a/torchtrain/datasets/alpaca.py
+++ b/torchtrain/datasets/alpaca.py
@@ -9,6 +9,7 @@
 from torchtrain.datasets.tokenizer import TokenizerIf
 
 from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 
 class AlpacaDataset(IterableDataset):
@@ -44,32 +45,24 @@ def __init__(
         rank: int = 0,
         **kwargs
     ) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        #       For larger datasets we need to use a more scalable approach.
+        # Setting `streaming=True` works for large dataset, but the speed is slow.
+        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
         self._tokenizer = tokenizer
-        self.data_iterator = iter(self._data)
         self.seq_len = seq_len
-        self.world_size = world_size
-        self.rank = rank
-        self.response_tag = "\n\n### Response:\n"
-
-    def __len__(self):
-        return len(self._data)
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for idx, sample in enumerate(self.data_iterator):
-            # select samples to pack in a round-robin fashion
-            # TODO: This is a temporary solution for small datasets like Alpaca.
-            #       For larger datasets we need to use a more scalable approach.
-            if idx % self.world_size != self.rank:
-                continue
+        for sample in self.data_iterator:
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
 
-            if len(all_tokens) >= max_buffer_token_len:
+            while len(all_tokens) >= max_buffer_token_len:
                 x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                 # batched_x = x.reshape(self.batch_size, -1)
                 # update tokens to the remaining tokens
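
For reference, a minimal standalone sketch of the Hugging Face `datasets` helper this patch adopts. It is not part of the commit; the `world_size` and `rank` values below are placeholders, whereas in torchtrain they are passed into `AlpacaDataset` by the caller.

    # Sketch: sharding the Alpaca dataset across data-parallel workers with
    # datasets.distributed.split_dataset_by_node. Values here are illustrative.
    from datasets import load_dataset
    from datasets.distributed import split_dataset_by_node

    world_size = 4  # assumed number of data-parallel ranks
    rank = 0        # assumed rank of this worker

    ds = load_dataset("tatsu-lab/alpaca", split="train")
    shard = split_dataset_by_node(ds, rank=rank, world_size=world_size)

    # Each rank iterates only over its own subset of the examples, which replaces
    # the manual `idx % world_size != rank` round-robin filter the patch removes.
    for sample in shard:
        text = sample["text"]
        # ... tokenize and pack into fixed-length sequences, as AlpacaDataset does

As the added TODO comment notes, this loads the full split before sharding, which is fine for a small dataset like Alpaca; `streaming=True` would avoid that for larger datasets at the cost of iteration speed.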