streaming datasets doesn't work properly with multi-node #6623

Open
rohitgr7 opened this issue Jan 27, 2024 · 23 comments
Labels
enhancement New feature or request

Comments

@rohitgr7

rohitgr7 commented Jan 27, 2024

Feature request

Let’s say I have a dataset with 5 samples with values [1, 2, 3, 4, 5], with 2 GPUs (for DDP) and batch size of 2. This dataset is an IterableDataset since I am streaming it.

Now I split the dataset using split_dataset_by_node to ensure it doesn’t get repeated. And since it’s already split, I don’t have to use DistributedSampler (which doesn't work with iterable datasets anyway)?

But in this case I noticed that:

First iteration:
first GPU will get → [1, 2]
second GPU will get → [3, 4]

Second iteration:
first GPU will get → [5]
second GPU will get → Nothing

which actually creates an issue, since with DistributedSampler the samples are repeated internally to ensure none of the GPUs is missing data at any iteration for gradient sync.

So my questions are:

  1. Since the splitting happens beforehand, how do I make sure each GPU gets a batch at each iteration to avoid gradient sync issues?
  2. Do we need to use DistributedSampler? If yes, how?
  3. In the docstring of split_dataset_by_node, this is mentioned: "If the dataset has a number of shards that is a factor of world_size (i.e. if dataset.n_shards % world_size == 0), then the shards are evenly assigned across the nodes, which is the most optimized. Otherwise, each node keeps 1 example out of world_size, skipping the other examples." Can you explain the last part here?
  4. If dataset.n_shards % world_size != 0, is it possible to shard the streaming dataset on the fly to avoid the case where data is missing?

Motivation

Streaming datasets should work with DDP: training big LLMs requires a lot of data, DDP/multi-node setups are mostly what is used to train such models, and streaming can help solve the data side of it.

Your contribution

Yes, I can help in submitting the PR once we get mutual understanding on how it should behave.

@rohitgr7 rohitgr7 added the enhancement New feature or request label Jan 27, 2024
@rohitgr7
Author

@mariosasko, @lhoestq, @albertvillanova
hey guys! can anyone help? or can you guys suggest who can help with this?

@lhoestq
Member

lhoestq commented Jan 31, 2024

Hi !

  1. When the dataset is running out of examples, the last batches received by the GPUs can be incomplete or empty/missing. We haven't yet implemented a way to ignore the last batch. It might require the datasets to provide the number of examples per shard though, so that we can know when to stop.
  2. Samplers are not compatible with IterableDatasets in PyTorch.
  3. If dataset.n_shards % world_size != 0 then all the nodes will read/stream the full dataset in order (possibly reading/streaming the same data multiple times), BUT will only yield one example out of world_size so that each example goes to exactly one GPU.
  4. No, sharding should be done up-front and can take some time depending on the dataset size and format.
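
The fallback described in point 3 can be sketched in plain Python. This is a single-process simulation, assuming each rank keeps the examples whose position satisfies `index % world_size == rank`; `examples_for_rank` is a hypothetical helper written for illustration, not a function from `datasets`.

```python
def examples_for_rank(stream, rank, world_size):
    """Read the whole stream but keep only one example out of world_size."""
    for index, example in enumerate(stream):
        # Assumption: the rank keeps positions rank, rank + world_size, ...
        if index % world_size == rank:
            yield example


data = [1, 2, 3, 4, 5]
world_size = 2
per_rank = [
    list(examples_for_rank(data, rank, world_size)) for rank in range(world_size)
]
print(per_rank)  # [[1, 3, 5], [2, 4]]
```

Every rank iterates over all five examples, yet each example is yielded by exactly one rank. With 5 examples and 2 ranks, one rank ends up with one example more than the other.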

@rohitgr7
Author

> if dataset.n_shards % world_size != 0 then all the nodes will read/stream the full dataset in order (possibly reading/streaming the same data multiple times), BUT will only yield one example out of world_size so that each example goes to exactly one GPU.

considering there's just 1 shard and 2 worker nodes, do you mean each worker node will load the whole dataset but still receive half of that shard while streaming?

@lhoestq
Member

lhoestq commented Feb 1, 2024

Yes, both nodes will stream from the 1 shard, but each node will skip half of the examples. This way, in total, each example is seen once and exactly once during your distributed training.

Though in terms of I/O, the dataset is effectively read/streamed twice.

@rohitgr7
Author

rohitgr7 commented Feb 1, 2024

What if the number of samples in that shard % num_nodes != 0? Will it break or get stuck? Or is the data repeated in that case for gradient sync?

@lhoestq
Member

lhoestq commented Feb 2, 2024

In that case at least one of the nodes will get an empty/incomplete batch. The data is not repeated in that case. If the training loop doesn't take this into account it can lead to unexpected behaviors indeed.

In the future we'd like to add a feature that would allow the nodes to ignore the last batch, this way all the nodes would only have full batches.
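
To make the mismatch concrete, here is a small sketch that batches the per-rank examples from the 5-sample, 2-rank case with a batch size of 2. The per-rank lists assume the "one example out of world_size" split, and `batched` is a hypothetical stand-in for what a DataLoader without `drop_last` would do.

```python
def batched(examples, batch_size):
    """Group examples into lists of batch_size, keeping a short final batch."""
    batch = []
    for example in examples:
        batch.append(example)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # incomplete last batch


# Assumed per-rank examples for data [1, 2, 3, 4, 5] and world_size = 2.
per_rank_examples = {0: [1, 3, 5], 1: [2, 4]}
per_rank_batches = {
    rank: list(batched(examples, 2)) for rank, examples in per_rank_examples.items()
}
steps = {rank: len(batches) for rank, batches in per_rank_batches.items()}
print(per_rank_batches)  # {0: [[1, 3], [5]], 1: [[2, 4]]}
print(steps)  # {0: 2, 1: 1}
```

Rank 0 runs two steps while rank 1 runs only one, so a training loop that synchronizes gradients at every step would have rank 0 waiting on a rank that has already finished.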

@kkkjyu

kkkjyu commented Mar 8, 2024

> In that case at least one of the nodes will get an empty/incomplete batch. The data is not repeated in that case. If the training loop doesn't take this into account it can lead to unexpected behaviors indeed.
>
> In the future we'd like to add a feature that would allow the nodes to ignore the last batch, this way all the nodes would only have full batches.

Is there any way to change a dataset's n_shards? Is changing the number of files OK? Is one file == one shard?

@lhoestq
Member

lhoestq commented Mar 8, 2024

> modify the number of files is ok? one file == one shard?

Yep, one file == one shard :)

@alex-hh
Contributor

alex-hh commented Sep 22, 2024

Hi @lhoestq, do you have any advice on how to implement a fix for the case dataset.n_shards % world_size != 0 while such a fix is not supported in the library?

It seems essential for performing validation in a ddp setting

Simply limiting the number of files is a bit brittle as it relies on world size being consistent to ensure different runs see the same data

How should a user either ignore the last batch or handle the empty batch?

Is the issue of overhanging batches also relevant for map-style datasets?

@lhoestq
Member

lhoestq commented Sep 23, 2024

> How should a user either ignore the last batch or handle the empty batch?

Check the batch size in the training loop and use all_reduce (or any communication method) to make sure all the nodes got their data before passing them to the model. If some data are missing you can decide to stop the training loop or repeat examples until all the nodes have exhausted their data.
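
The "stop the training loop" option can be sketched as a single-process simulation. Here `min(flags)` stands in for a `dist.all_reduce(flag, op=dist.ReduceOp.MIN)` call in a real multi-process run, and `lockstep_steps` is a hypothetical helper, so treat this as an illustration of the agreement logic rather than a drop-in training loop.

```python
def lockstep_steps(per_rank_batches, batch_size):
    """Yield one tuple of batches per step, stopping as soon as any rank is short."""
    iterators = [iter(batches) for batches in per_rank_batches]
    while True:
        current = [next(it, None) for it in iterators]
        # Each rank contributes 1 if it holds a full batch, else 0.
        flags = [int(b is not None and len(b) == batch_size) for b in current]
        # Stand-in for an all_reduce with MIN: every rank learns the smallest flag.
        if min(flags) == 0:
            return
        yield tuple(current)


# Rank 0 has batches [1, 3] and [5]; rank 1 only has [2, 4].
steps = list(lockstep_steps([[[1, 3], [5]], [[2, 4]]], batch_size=2))
print(steps)  # [([1, 3], [2, 4])]
```

All ranks agree to stop before the second step, so the leftover sample 5 is never passed to the model. That is the trade-off of stopping rather than repeating examples.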

Cc @andrewkho in case you know a way to make the DataLoader stop or add extra samples automatically in case of distributed + unevenly divisible iterable dataset

> Is the issue of overhanging batches also relevant for map-style datasets?

By default (drop_last=False) the DistributedSampler pads the index list with repeated samples to make the dataset evenly divisible; with drop_last=True it drops the tail of the data instead.

@andrewkho

@lhoestq Unfortunately for IterableDataset there isn't a way to do this in general without introducing communication between ranks, or having all the ranks read all the data before starting in order to figure out when to stop (which is pretty impractical). My recommendation for situations where you don't know the total number of samples a priori is to configure the iterable dataset to yield a fixed number of samples before raising StopIteration and, if necessary, to repeat/reshuffle samples to hit that number.
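
A minimal sketch of that recommendation in plain Python, without reshuffling. `fixed_length` and its `make_iterable` argument are hypothetical names chosen for this illustration; the idea is only that every rank is configured with the same `num_samples`, so no communication is needed to agree on when to stop.

```python
def fixed_length(make_iterable, num_samples):
    """Yield exactly num_samples items, restarting the source when it runs out."""
    produced = 0
    while produced < num_samples:
        before = produced
        for item in make_iterable():
            yield item
            produced += 1
            if produced == num_samples:
                return
        if produced == before:
            raise ValueError("source yielded no samples")


# Two ranks with uneven local data, both configured to yield 3 samples.
rank0 = list(fixed_length(lambda: [1, 3, 5], 3))
rank1 = list(fixed_length(lambda: [2, 4], 3))
print(rank0, rank1)  # [1, 3, 5] [2, 4, 2]
```

The source is passed as a callable so that it can be re-iterated from the start on each pass. The rank with less data repeats its first sample to reach the agreed count.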

@andrewkho

A heads up that we're planning to land something new in torchdata by end-of-year to help with these scenarios. We'll update this thread when we have some code landed.

@lhoestq
Member

lhoestq commented Sep 27, 2024

I made a quick example with communication between ranks to stop once all the data from all the ranks are exhausted (and repeating data if necessary to end up with a number of samples evenly divisible)

import torch
import torch.distributed as dist
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader


# simulate a streaming dataset
num_shards = 1  # change here if you want to simulate a dataset made of many files/shards
ds = Dataset.from_dict({"x": [1, 2, 3, 4, 5]}).to_iterable_dataset(num_shards=num_shards)

# split the dataset for distributed training
dist.init_process_group()
rank, world_size = dist.get_rank(), dist.get_world_size()
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
dl = DataLoader(ds)

exhausted = torch.zeros(world_size, dtype=torch.bool)

# IMPORTANT: Loop over the local dataset until the data from each rank has been exhausted

def loop():
    while True:
        yield from dl
        yield "end"

for x in loop():
    if x == "end":
        exhausted[rank] = True
        continue
    # stop once the data from all the ranks are exhausted
    dist.all_reduce(exhausted)
    if torch.all(exhausted):
        break
    # do your forward pass + loss here
    # model.forward(...)
    print(x)

On my laptop I ran torchrun --nnodes=1 --nproc-per-node=2 main.py and got:

{'x': tensor([2])}
{'x': tensor([1])}
{'x': tensor([3])}
{'x': tensor([4])}
{'x': tensor([5])}
{'x': tensor([2])}

We indeed end up with 6 samples: {'x': tensor([2])} was repeated to get 6 examples in total, which is divisible by the world size of 2.

I also tried with more ranks and with num_workers in DataLoader and it works as expected (don't forget to add if __name__ == '__main__': if necessary for DataLoader multiprocessing)

EDIT: replaced cycle(chain(dl, ["end"])) by loop() after comment #6623 (comment) by @ragavsachdeva

@alex-hh
Contributor

alex-hh commented Sep 27, 2024

great thanks for the example, will give it a try!

@alex-hh
Contributor

alex-hh commented Oct 2, 2024

@lhoestq in the case where dataset.n_shards is divisible by world_size, is it important that each shard contains exactly the same number of samples? what happens if this isn't the case (in what circumstances will this cause a timeout)?

@lhoestq
Member

lhoestq commented Oct 2, 2024

If your data are not evenly divisible, you'll need some logic to make the GPUs happy at the end of training (whether dataset.n_shards is divisible by world_size only changes the logic used to distribute the data). E.g. with my example above, you stop once all the data from all the ranks are exhausted, repeating data if necessary to end up with an evenly divisible number of samples.

Though if dataset.n_shards is divisible by world_size and each shard contains the same amount of data then your data IS evenly divisible so you are all good

@alex-hh
Contributor

alex-hh commented Oct 2, 2024

Ok makes sense, thanks for the explanation. I guess even if the shards all contain the same amount of data you still have an issue if you do any filtering (#6719)

What do you think of dataset.repeat(n).take(samples_per_epoch) as a simple way of handling this kind of situation? (cf. the issue I just opened, #7192.)

@lhoestq
Member

lhoestq commented Oct 3, 2024

yes it makes sense indeed

@ragavsachdeva

> I made a quick example with communication between ranks to stop once all the data from all the ranks are exhausted (and repeating data if necessary to end up with a number of samples evenly divisible)

from itertools import cycle, chain
...
# IMPORTANT: Loop over the local dataset until the data from each rank has been exhausted
for x in cycle(chain(dl, ["end"])):
    if x == "end":
        exhausted[rank] = True
        continue
    # stop once the data from all the ranks are exhausted
    dist.all_reduce(exhausted)
    if torch.all(exhausted):
        break
    # do your forward pass + loss here
    # model.forward(...)
    print(x)

Just in case someone copy-pastes this into their code (like I did), please be aware of pytorch/pytorch#23900 and use the approach from pytorch/pytorch#23900 (comment) instead.
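
For context, `itertools.cycle` is documented to save a copy of each item from its first pass and then replay the saved copies. The sketch below shows this with a hypothetical `CountingLoader` class standing in for a DataLoader: `cycle` iterates the loader once and keeps every item in memory, while a plain generator loop re-iterates it on each pass.

```python
from itertools import cycle, islice


class CountingLoader:
    """Stand-in for a DataLoader that records how often it is iterated."""

    def __init__(self, items):
        self.items = items
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return iter(self.items)


def loop(loader):
    """Re-iterate the loader forever without caching its items."""
    while True:
        yield from loader


cycled = CountingLoader([1, 2])
from_cycle = list(islice(cycle(cycled), 6))

looped = CountingLoader([1, 2])
from_loop = list(islice(loop(looped), 6))

print(from_cycle == from_loop)  # True
print(cycled.iter_calls, looped.iter_calls)  # 1 3
```

Both variants yield the same sequence here, but only the generator version starts a fresh pass over the loader each time, so it does not hold a full epoch of batches in memory and would pick up any per-epoch reshuffling.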

@lhoestq
Member

lhoestq commented Oct 9, 2024

Thanks for noticing @ragavsachdeva ! I edited my code to fix the issue

@JohnHerry

I have a node with 8 cards and training files split into 56 sub-files, so my n_shards = 56 / 8 = 7. My initial num_workers = 32, and it reports that n_shards = 7 < num_workers, so 25 workers are stopped. As a result, my training can only use 7 CPU cores. Should I set num_workers to less than 7 to get more CPU cores working?

@lhoestq
Member

lhoestq commented Oct 15, 2024

In your case each rank has a DataLoader with 7 running workers (and 25 stopped workers) so actually in total there are 8*7=56 DataLoader workers running (one per shard).

If you want to use more CPU for the DataLoader you can shard your dataset in more files than 56. E.g. if you want each rank to run 32 DataLoader workers you need 8*32=256 files.
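
The arithmetic above can be written out as a small helper. `dataloader_workers` is a hypothetical function for this illustration, and it assumes the number of files is divisible by the world size so that shards are assigned evenly.

```python
def dataloader_workers(n_files, world_size, num_workers):
    """Return (shards per rank, running workers per rank, running workers in total)."""
    shards_per_rank = n_files // world_size
    # A worker needs at least one shard, so extra workers stay idle.
    running_per_rank = min(num_workers, shards_per_rank)
    return shards_per_rank, running_per_rank, running_per_rank * world_size


current = dataloader_workers(n_files=56, world_size=8, num_workers=32)
print(current)  # (7, 7, 56)

# Files needed so that all 32 workers on each of the 8 ranks have a shard.
files_needed = 8 * 32
resharded = dataloader_workers(n_files=files_needed, world_size=8, num_workers=32)
print(files_needed, resharded)  # 256 (32, 32, 256)
```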

@JohnHerry

Thank you for the help, I will have a try
