add miniPile dataset for pretraining, 1M entries (solves the 'out of data' at 40 iters issue) #88

Merged (16 commits) on Mar 7, 2024
11 changes: 7 additions & 4 deletions torchtrain/datasets/__init__.py
@@ -1,11 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

-from torchtrain.datasets.alpaca import build_alpaca_data_loader
 from torchtrain.datasets.tokenizer import create_tokenizer
+from torchtrain.datasets.hf_datasets import build_hf_data_loader

-__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]
+__all__ = [
+    "build_hf_data_loader",
+    "create_tokenizer",
+]

 dataloader_fn = {
-    "alpaca": build_alpaca_data_loader,
+    "alpaca": build_hf_data_loader,
+    "minipile": build_hf_data_loader,
 }
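Both registry entries now resolve to the same Hugging Face backed builder, which is how train.py picks its loader by name (see the train.py hunk further down). Below is a minimal sketch of that lookup; the toy tokenizer class and the hyperparameter values are illustrative stand-ins, not part of this PR, and constructing the loader downloads the dataset from the Hugging Face hub.

```python
from torchtrain.datasets import dataloader_fn


class ToyTokenizer:
    """Stand-in exposing the encode/decode surface that TokenizerIf expects (illustration only)."""

    def encode(self, text: str, bos: bool = True, eos: bool = True) -> list:
        ids = [ord(ch) % 256 for ch in text]
        return ([1] if bos else []) + ids + ([2] if eos else [])

    def decode(self, ids: list) -> str:
        return "".join(chr(i) for i in ids if i > 2)


# Look up the builder by dataset name; "alpaca" and "minipile" both map to
# build_hf_data_loader.
build_fn = dataloader_fn["minipile"]
data_loader = build_fn(
    "minipile", ToyTokenizer(), 8, 2048, 1, 0
)  # dataset_name, tokenizer, batch_size, seq_len, world_size, rank
```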
90 changes: 0 additions & 90 deletions torchtrain/datasets/alpaca.py

This file was deleted.

112 changes: 112 additions & 0 deletions torchtrain/datasets/hf_datasets.py
@@ -0,0 +1,112 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from typing import List

import torch
from torch.utils.data import DataLoader, IterableDataset

from torchtrain.datasets.tokenizer import TokenizerIf

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

_supported_datasets = {
    "alpaca": "tatsu-lab/alpaca",
    "minipile": "JeanKaddour/minipile",
}


class HuggingFaceDataset(IterableDataset):
    """PyTorch representation of a dataset from Hugging Face.

    We currently support two datasets:
        minipile (1M training entries)
        alpaca (52K training entries)

    >> MiniPile <<:
    The MiniPile dataset is described in the following paper:
    https://arxiv.org/abs/2304.08442

    Args:
        dataset_name (str): name of the dataset to load
        tokenizer (TokenizerIf): tokenizer used to encode data. The tokenizer must implement an `encode` and a `decode` method.
        seq_len (int): max sequence length
        world_size (int): number of data parallel processes participating in training
        rank (int): rank of the current data parallel process

    Data input format (minipile):
    {
        "text": "Open-end spinning devices with such rotor bearing arrangements are known in
                 various different embodiments, and have been extensively described,
                 for example in German Patent Publications"
    }

    >> Alpaca <<:
    Data input format (alpaca):
    {
        "instruction": "Create a classification task by clustering the given list of items.",
        "input": "Apples, oranges, bananas, strawberries, pineapples",
        "output": "Class 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",
        "text": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nCreate a classification task by clustering the given list of items.\n\n### Input:\nApples, oranges, bananas, strawberries, pineapples\n\n### Response:\nClass 1: Apples, Oranges\nClass 2: Bananas, Strawberries\nClass 3: Pineapples",  # noqa: B950
    }

    Example:
        >>> minipile_ds = HuggingFaceDataset(dataset_name="minipile", tokenizer=tokenizer)
        >>> for batch in DataLoader(minipile_ds, batch_size=8):
        ...     print(f"Batch size: {len(batch[0])}")
        Batch size: 8
    """

    def __init__(
        self,
        dataset_name: str,
        tokenizer: TokenizerIf,
        seq_len: int = 2048,
        world_size: int = 1,
        rank: int = 0,
    ) -> None:
        # TODO: This is a temporary solution for small datasets like Alpaca.
        # For larger datasets we need to use a more scalable approach.
        # Setting `streaming=True` works for large datasets, but the speed is slow.
        if dataset_name not in _supported_datasets:
            raise ValueError(
                f"Dataset {dataset_name} is not supported. Supported datasets are: {_supported_datasets.keys()}"
            )
        ds = load_dataset(_supported_datasets[dataset_name], split="train")
        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
        self._tokenizer = tokenizer
        self.seq_len = seq_len

    def __iter__(self):
        max_buffer_token_len = 1 + self.seq_len
        all_tokens: List[int] = []

        for sample in self.data_iterator:
            sample_text = sample["text"]
            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
            all_tokens.extend(sample_tokens)

            while len(all_tokens) >= max_buffer_token_len:
                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                # update tokens to the remaining tokens
                all_tokens = all_tokens[max_buffer_token_len:]
                input = x[:-1]
                label = x[1:]
                yield input, label


def build_hf_data_loader(
    dataset_name: str,
    tokenizer: TokenizerIf,
    batch_size: int,
    seq_len: int,
    world_size,
    rank,
):
    hf_ds = HuggingFaceDataset(dataset_name, tokenizer, seq_len, world_size, rank)

    return DataLoader(hf_ds, batch_size=batch_size)
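For reference, `HuggingFaceDataset.__iter__` concatenates tokenized samples into a running buffer and slices fixed windows of `seq_len + 1` tokens, yielding input/label pairs shifted by one position. The following is a self-contained sketch of the same packing scheme; the `pack_tokens` helper and the integer token stream are toy stand-ins for illustration, not code from this PR.

```python
from typing import Iterable, List, Tuple

import torch


def pack_tokens(token_stream: Iterable[int], seq_len: int) -> Iterable[Tuple[torch.Tensor, torch.Tensor]]:
    """Mirror of the buffering logic in HuggingFaceDataset.__iter__."""
    buffer: List[int] = []
    window = seq_len + 1
    for tok in token_stream:
        buffer.append(tok)
        while len(buffer) >= window:
            x = torch.LongTensor(buffer[:window])
            buffer = buffer[window:]  # keep only the unused remainder
            yield x[:-1], x[1:]  # input and label, shifted by one token


# Toy example: tokens 0..9 with seq_len=4 produce two (input, label) pairs.
for inp, label in pack_tokens(range(10), seq_len=4):
    print(inp.tolist(), label.tolist())
# [0, 1, 2, 3] [1, 2, 3, 4]
# [5, 6, 7, 8] [6, 7, 8, 9]
```

Note that any trailing tokens that do not fill a complete window are dropped, and the iterator simply stops once the underlying dataset is exhausted, which is why the much larger minipile split addresses the "out of data" behavior seen with alpaca.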
7 changes: 6 additions & 1 deletion train.py
@@ -19,7 +19,8 @@
 from torchtrain.config_manager import JobConfig

 # torchtrain related
-from torchtrain.datasets import create_tokenizer, dataloader_fn
+from torchtrain.datasets import dataloader_fn
+from torchtrain.datasets.tokenizer import create_tokenizer
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
 from torchtrain.meta_init import meta_model_init
@@ -108,12 +109,16 @@ def main(job_config: JobConfig):
     dp_rank = dp_mesh.get_local_rank()
     build_dataloader_fn = dataloader_fn[job_config.training.dataset]
     data_loader = build_dataloader_fn(
+        job_config.training.dataset,
         tokenizer,
         job_config.training.batch_size,
         job_config.training.seq_len,
         dp_degree,
         dp_rank,
     )
+    rank0_log(
+        f"{Color.green}Built Dataloader for '{job_config.training.dataset}' dataset.{Color.reset}"
+    )

     # build model
     model_cls = model_name_to_cls[model_name]
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -37,4 +37,4 @@ compile = false
 checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
-dataset = "alpaca"
+dataset = "alpaca"  # supported datasets = minipile (1M), alpaca (52K)