From db4cf53ad24eff0b5fe044ea2dabc6219f9b8e7f Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Mon, 26 Feb 2024 15:57:02 -0800
Subject: [PATCH] support infinite loop over alpaca dataset

[ghstack-poisoned]
---
 torchtrain/datasets/alpaca.py | 40 +++++++++++++++++++++++++---------------
 train.py                      |  4 +++-
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py
index 3ee1442d..734792b6 100644
--- a/torchtrain/datasets/alpaca.py
+++ b/torchtrain/datasets/alpaca.py
@@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
+        infinite: whether to loop infinitely over the dataset
 
     Data input format:
         {
@@ -43,38 +44,47 @@ def __init__(
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
+        infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets we need to use a more scalable approach.
         # Setting `streaming=True` works for large dataset, but the speed is slow.
         ds = load_dataset("tatsu-lab/alpaca", split="train")
-        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
+        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
+        self.infinite = infinite
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
-            sample_text = sample["text"]
-            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-            all_tokens.extend(sample_tokens)
+        while True:
+            for sample in iter(self._data):
+                sample_text = sample["text"]
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                all_tokens.extend(sample_tokens)
 
-            while len(all_tokens) >= max_buffer_token_len:
-                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
-                # batched_x = x.reshape(self.batch_size, -1)
-                # update tokens to the remaining tokens
-                all_tokens = all_tokens[max_buffer_token_len:]
-                input = x[:-1]
-                label = x[1:]
-                yield input, label
+                while len(all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                    # update tokens to the remaining tokens
+                    all_tokens = all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield input, label
+            if not self.infinite:
+                break
 
 
 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
+    tokenizer: TokenizerIf,
+    batch_size: int,
+    seq_len: int,
+    world_size: int,
+    rank: int,
+    infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)
     return DataLoader(alpaca_ds, batch_size=batch_size)
 
diff --git a/train.py b/train.py
index 904ebf17..5ce5de37 100644
--- a/train.py
+++ b/train.py
@@ -167,6 +167,8 @@ def main(job_config: JobConfig):
     )
     checkpoint.load()
 
+    data_iterator = iter(data_loader)
+
     with maybe_run_profiler(job_config) as torch_profiler:
         checkpoint.reset()
         # variables used to keep info for metrics logging
@@ -180,7 +182,7 @@ def main(job_config: JobConfig):
             train_state.step += 1
             # get batch
             data_load_start = timer()
-            batch = next(iter(data_loader))
+            batch = next(data_iterator)
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()