Potentially incorrect calculation of total_updates on >=4.46.0 since #34198 affecting multi gpu training #35387

Closed
2 of 4 tasks
chiragjn opened this issue Dec 21, 2024 · 2 comments

Comments

chiragjn commented Dec 21, 2024

System Info

Okay, I have been pulling my hair out for a few hours, and it turns out this bug only happens when average_tokens_across_devices is True and epochs > 1.

Simplest case to reproduce

DDP
world size 2
dataset length = 4
epochs = 2
micro batch size = 1 (aka per gpu batch size)
gradient accumulation = 1
average_tokens_across_devices = True

So every epoch has 2 steps in total on each of the two devices.

But as the first epoch finishes, we get:

[rank1]:   File "/home/jovyan/axolotl/src/axolotl/train.py", line 191, in train
[rank1]:     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/transformers/trainer.py", line 2473, in _inner_training_loop
[rank1]:     batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
[rank1]:                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/transformers/trainer.py", line 5142, in get_batch_samples
[rank1]:     num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
[rank1]:                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/accelerator.py", line 2459, in gather
[rank1]:     return gather(tensor)
[rank1]:            ^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 376, in wrapper
[rank1]:     return function(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 437, in gather
[rank1]:     return _gpu_gather(tensor)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 356, in _gpu_gather
[rank1]:     return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/home/jovyan/.conda/envs/jupyter-base/lib/python3.11/site-packages/accelerate/utils/operations.py", line 129, in recursively_apply
[rank1]:     raise TypeError(
[rank1]: TypeError: Unsupported types (<class 'NoneType'>) passed to `_gpu_gather_one`. Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.

The main culprit here is:

total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
for _ in range(total_updates):
    update_step += 1
    num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
    for i, inputs in enumerate(batch_samples):

steps_in_epoch per rank is correctly calculated as 2, but total_updates is 3.
Normally that is harmless: the dataloader is already exhausted, so get_batch_samples returns an empty batch and the inner loop over batch_samples (the one that follows line 2473) is never entered.
However, with the recently added option average_tokens_across_devices, get_batch_samples tries to gather the total number of items in the batches across all ranks, and gather raises a TypeError when it is passed None.
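
To make the off-by-one concrete, here is a small sketch that plugs the reproduction settings into the quoted formula. The variable names mirror the snippet above; the way steps_in_epoch is derived here (dataset length divided by world size times per-device batch size) is a simplifying assumption for this even-split case, not the Trainer's exact code path.

```python
# Values taken from the reproduction settings in this report.
world_size = 2
dataset_length = 4
per_device_batch_size = 1
gradient_accumulation_steps = 1

# Assumption: the dataset splits evenly, so each rank sees
# dataset_length / (world_size * per_device_batch_size) batches per epoch.
steps_in_epoch = dataset_length // (world_size * per_device_batch_size)

# The formula quoted from the training loop: floor division plus one.
total_updates = steps_in_epoch // gradient_accumulation_steps + 1

print(steps_in_epoch)  # 2
print(total_updates)   # 3
```

The loop therefore runs one iteration more than there are batches available on each rank.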

def get_batch_samples(self, epoch_iterator, num_batches):
    batch_samples = []
    num_items_in_batch = None
    for _ in range(num_batches):
        try:
            batch_samples += [next(epoch_iterator)]
        except StopIteration:
            break
    if len(batch_samples) > 0 and "labels" in batch_samples[0]:
        # For now we don't support object detection
        try:
            num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
        except (TypeError, AttributeError):
            pass
    if self.args.average_tokens_across_devices:
        num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
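
The following is a simplified, dependency-free stand-in for the method above, meant only to show where None comes from. It uses plain Python lists instead of tensors and leaves out the gather call; the function is a sketch for illustration, not the Trainer's implementation.

```python
def get_batch_samples(epoch_iterator, num_batches):
    """Simplified stand-in for Trainer.get_batch_samples (no tensors, no gather)."""
    batch_samples = []
    num_items_in_batch = None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break
    if batch_samples and "labels" in batch_samples[0]:
        # Count label positions that are not the ignore index (-100).
        num_items_in_batch = sum(
            sum(1 for label in batch["labels"] if label != -100)
            for batch in batch_samples
        )
    return batch_samples, num_items_in_batch


# Two batches per rank per epoch, as in the reproduction (labels are made up).
epoch_iterator = iter([{"labels": [5, -100, 7]}, {"labels": [3, 4, -100]}])

# total_updates is 3, so the loop asks for one batch more than exists.
results = [get_batch_samples(epoch_iterator, 1) for _ in range(3)]
for batch_samples, num_items_in_batch in results:
    print(len(batch_samples), num_items_in_batch)
```

The last call finds the iterator exhausted, so it returns an empty list and leaves num_items_in_batch as None, which is the value that later reaches gather.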

The problem does not surface with 1 GPU, because average_tokens_across_devices is automatically set to False, nor with epochs = 1, because DefaultFlowCallback stops training once the global step reaches the expected max steps.

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Train any LM on more than one GPU, setting at least average_tokens_across_devices = True and epochs > 1.

Expected behavior

Either we fix the total_updates count, or we handle None for gather.
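
For the first option, one possible direction is sketched below: round up instead of always adding one. This is only an illustration of the arithmetic; the function names are hypothetical, and it is not claimed to be the change that was eventually merged.

```python
import math


def total_updates_quoted(steps_in_epoch, gradient_accumulation_steps):
    # Formula quoted in this report: floor division plus one.
    return steps_in_epoch // gradient_accumulation_steps + 1


def total_updates_ceil(steps_in_epoch, gradient_accumulation_steps):
    # Hypothetical alternative: round up, so no extra step is added
    # when steps_in_epoch divides evenly.
    return math.ceil(steps_in_epoch / gradient_accumulation_steps)


# (steps_in_epoch, gradient_accumulation_steps)
cases = [(2, 1), (8, 4), (10, 4)]
comparison = {
    case: (total_updates_quoted(*case), total_updates_ceil(*case)) for case in cases
}
for case, counts in comparison.items():
    print(case, counts)
```

The two formulas agree when there is a remainder (10 steps with accumulation 4) and differ by one when steps_in_epoch is an exact multiple of gradient_accumulation_steps, which is the situation in the reproduction.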

SunMarc (Member) commented Dec 23, 2024

Thanks for the detailed report @chiragjn! I think the simplest solution would be to set num_items_in_batch to a tensor of 0 when the value is None, since that means batch_samples doesn't have any elements. Can you check if this solves the issue?

if self.args.average_tokens_across_devices:
    if num_items_in_batch is None:
        num_items_in_batch = torch.tensor(0)
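
The effect of this guard can be sketched without torch or accelerate. In the snippet below, gather_and_sum and normalize are hypothetical helpers: the first only imitates the behaviour seen in the traceback (a TypeError when None is passed in), and the second applies the suggested replacement of None with zero using a plain integer instead of a tensor.

```python
def gather_and_sum(per_rank_values):
    """Hypothetical stand-in for accelerator.gather(...).sum().item()."""
    for value in per_rank_values:
        if value is None:
            raise TypeError("Unsupported types (<class 'NoneType'>) passed to gather")
    return sum(per_rank_values)


def normalize(num_items_in_batch):
    """Apply the suggested guard: treat an empty batch as zero items."""
    return 0 if num_items_in_batch is None else num_items_in_batch


# Without the guard, the extra update step fails because every rank holds None.
try:
    gather_and_sum([None, None])
    failed = False
except TypeError:
    failed = True

# With the guard, the gathered total is simply zero.
total = gather_and_sum([normalize(None), normalize(None)])
print(failed, total)
```

In the real Trainer the zero would presumably need to be a tensor on the right device for gather to accept it; that detail is outside this sketch.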

chiragjn (Author) commented Jan 7, 2025

Fixed by #35102

chiragjn closed this as completed Jan 7, 2025