add miniPile dataset for pretraining, 1M entries (solves the 'out of data' at 40 iters issue) (#88)

This PR adds the minipile dataset (1M entries, 6GB) as an option for pretraining
with torchtrain.
It resolves the issue where we run out of data after 40 iterations with
the default alpaca dataset.
Per @tianyu-l's excellent suggestion, this has been refactored into a single
hf_datasets.py file that supports both minipile and alpaca, since it turned
out neither needs a different tokenizer or any other special handling.
Also cleaned up the datasets package so that create_tokenizer is exposed
directly, and thus all public APIs can be used directly from
torchtrain.datasets.
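For illustration, here is a minimal sketch of how the refactored public API can be used (the create_tokenizer arguments and the tokenizer model path are placeholders/assumptions, not something this PR defines):

```python
# Minimal usage sketch of the refactored torchtrain.datasets API.
# Assumption: create_tokenizer takes a tokenizer type and a model path.
from torchtrain.datasets import build_hf_data_loader, create_tokenizer

tokenizer = create_tokenizer("sentencepiece", "./tokenizer.model")  # placeholder path
data_loader = build_hf_data_loader(
    "minipile",   # or "alpaca"
    tokenizer,
    batch_size=8,
    seq_len=2048,
    world_size=1,
    rank=0,
)
input_ids, labels = next(iter(data_loader))  # one training batch of token tensors
```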
Lastly, added a warning for when a dataset is being re-looped, so users
don't get burned by overfitting:
(screenshot: re-loop warning in the training log)


Also added a color highlight to show which dataloader was built:
(screenshots: colored 'Built Dataloader' log lines)


Usage:
Just set dataset = "minipile" or dataset = "alpaca" in the training config
toml file.
(screenshot: dataset entry in the training config toml)

Testing:
verified training loss is improving and ran for 100 iters to verify no
issue with out of data any longer with minipile.
reran with alpaca and saw the expected out of data at 40 iters without
infinite loop option, runs to 100 with infinite.

Notes:
I did not make minipile the default dataset since, for the debug model,
running 10 iters is usually fine and minipile is a 6GB download.
lessw2020 authored Mar 7, 2024
1 parent f31adb0 commit 6927e45
Showing 4 changed files with 64 additions and 17 deletions.
10 changes: 7 additions & 3 deletions torchtrain/datasets/__init__.py
@@ -1,11 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from torchtrain.datasets.alpaca import build_alpaca_data_loader
+from torchtrain.datasets.hf_datasets import build_hf_data_loader
 from torchtrain.datasets.tokenizer import create_tokenizer
 
-__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]
+__all__ = [
+    "build_hf_data_loader",
+    "create_tokenizer",
+]
 
 dataloader_fn = {
-    "alpaca": build_alpaca_data_loader,
+    "alpaca": build_hf_data_loader,
+    "minipile": build_hf_data_loader,
 }
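For reference, a minimal sketch of how the dataloader_fn registry is meant to be consumed; the dictionary lookup below mirrors the build_dataloader_fn call site in the train.py diff further down, but the exact resolution code is not shown in this PR, so treat it as an assumption:

```python
# Sketch: resolving the dataloader builder by the dataset name from the config.
# The registry lookup is an assumption; only the resulting call appears in the train.py diff.
from torchtrain.datasets import create_tokenizer, dataloader_fn

dataset_name = "minipile"  # stands in for job_config.training.dataset
tokenizer = create_tokenizer("sentencepiece", "./tokenizer.model")  # placeholder path

build_dataloader_fn = dataloader_fn[dataset_name]
data_loader = build_dataloader_fn(dataset_name, tokenizer, 8, 2048, 1, 0)
```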
65 changes: 52 additions & 13 deletions torchtrain/datasets/{alpaca.py → hf_datasets.py}
@@ -7,22 +7,46 @@
 from torch.utils.data import DataLoader, IterableDataset
 
 from torchtrain.datasets.tokenizer import TokenizerIf
+from torchtrain.logging_utils import rank0_log
+from torchtrain.utils import Color
 
 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
 
+_supported_datasets = {
+    "alpaca": "tatsu-lab/alpaca",
+    "minipile": "JeanKaddour/minipile",
+}
 
-class AlpacaDataset(IterableDataset):
-    """PyTorch Representation of the Alpaca Dataset from Hugging Face.
+class HuggingFaceDataset(IterableDataset):
+    """PyTorch Representation of a Dataset from Hugging Face.
 
+    We currently support two datasets:
+    minipile (1M training entries)
+    alpaca (52K training entries)
+
+    >> MiniPile <<:
+    MiniPile dataset is detailed in the following paper:
+    https://arxiv.org/abs/2304.08442
+
     Args:
-        tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
+        dataset_name (str): name of the dataset to load
+        tokenizer (TokenizerIf): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
-        infinite: whether to loop infinitely over the dataset
+        infinite (bool): whether to loop infinitely over the dataset
 
-    Data input format:
+    Data input format (minipile):
+    {
+        "text": "Open-end spinning devices with such rotor bearing arrangements are known in
+                 various different embodiments, and have been extensively described,
+                 for example in German Patent Publications"
+    }
+
+    >> Alpaca <<:
+    Data input format (alpaca):
     {
         "instruction": "Create a classification task by clustering the given list of items.",
         "input": "Apples, oranges, bananas, strawberries, pineapples",
@@ -32,25 +56,31 @@ class AlpacaDataset(IterableDataset):
     }
     Example:
-    >>> alpaca_ds = AlpacaDataset(tokenizer=tokenizer)
+    >>> alpaca_ds = HuggingFaceDataset(tokenizer=tokenizer)
     >>> for batch in Dataloader(alpaca_ds, batch_size=8):
            print(f"Batch size: {len(batch)}")
        Batch size: 8
     """
 
     def __init__(
         self,
+        dataset_name: str,
         tokenizer: TokenizerIf,
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
         infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets we need to use a more scalable approach.
         # Setting `streaming=True` works for large dataset, but the speed is slow.
-        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        if dataset_name not in _supported_datasets:
+            raise ValueError(
+                f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
+            )
+
+        ds = load_dataset(_supported_datasets[dataset_name], split="train")
+        self.dataset_name = dataset_name
         self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
@@ -75,16 +105,25 @@ def __iter__(self):
                     yield input, label
             if not self.infinite:
                 break
+            else:
+                # we are re-looping on the same dataset, warn user
+                rank0_log(
+                    f"{Color.red}WARNING:{Color.reset} dataset {Color.yellow}'{self.dataset_name}'{Color.reset} is "
+                    f"being re-looped. Loss related metrics might be misleading.{Color.reset}"
+                )
 
 
-def build_alpaca_data_loader(
+def build_hf_data_loader(
+    dataset_name: str,
     tokenizer: TokenizerIf,
     batch_size: int,
     seq_len: int,
-    world_size: int,
-    rank: int,
+    world_size,
+    rank,
     infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)
+    hf_ds = HuggingFaceDataset(
+        dataset_name, tokenizer, seq_len, world_size, rank, infinite
+    )
 
-    return DataLoader(alpaca_ds, batch_size=batch_size)
+    return DataLoader(hf_ds, batch_size=batch_size)
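For illustration, a sketch of using HuggingFaceDataset directly with data-parallel sharding. The tokenizer setup is a placeholder, and the shape of the yielded tensors comes from code collapsed in this diff, so treat those details as assumptions:

```python
# Sketch: one data-parallel rank iterating its shard of the alpaca dataset.
# With infinite=True the dataset re-loops instead of stopping and logs the warning above.
from torchtrain.datasets import create_tokenizer
from torchtrain.datasets.hf_datasets import HuggingFaceDataset

tokenizer = create_tokenizer("sentencepiece", "./tokenizer.model")  # placeholder path
ds = HuggingFaceDataset(
    "alpaca", tokenizer, seq_len=2048, world_size=2, rank=0, infinite=True
)
for step, (input_ids, labels) in enumerate(ds):
    # input_ids and labels are token sequences offset by one position
    if step == 2:
        break
```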
4 changes: 4 additions & 0 deletions train.py
@@ -114,12 +114,16 @@ def main(job_config: JobConfig):
     else:
         dp_degree, dp_rank = 1, 0
     data_loader = build_dataloader_fn(
+        job_config.training.dataset,
         tokenizer,
         job_config.training.batch_size,
         job_config.training.seq_len,
         dp_degree,
         dp_rank,
     )
+    rank0_log(
+        f"{Color.green}Built Dataloader for '{job_config.training.dataset}' dataset.{Color.reset}"
+    )
 
     # build model
     model_cls = model_name_to_cls[model_name]
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,4 +37,4 @@ compile = false
 checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
-dataset = "alpaca"
+dataset = "alpaca"  # supported datasets = minipile (1M), alpaca (52K)
