Memory grows due to keeping losses on device #763

Closed
carmocca opened this issue Dec 27, 2024 · 3 comments · Fixed by #779
Labels
better_engineering Repo code quality improvements good first issue Good for newcomers

Comments

@carmocca
Contributor

If logging is disabled (or very infrequent), memory usage slowly grows because the max and average losses are kept in a list on-device: https://github.com/pytorch/torchtitan/blob/main/train.py#L353-L354

The training loop should offload these tensors to the CPU right after their aggregation is finished, especially since the logging prints will do that anyway under the hood.
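A minimal sketch of the offloading suggested above, assuming the per-step losses are collected as 0-dim tensors in a Python list. The helper name `drain_losses` and the variable names are hypothetical and not taken from train.py. It stacks the pending losses once, copies the result to the host in a single transfer, and clears the list so that no device tensors stay referenced:

```python
import torch


def drain_losses(pending: list[torch.Tensor]) -> tuple[float, float]:
    """Return (average, maximum) of the pending losses as Python floats.

    Hypothetical helper: stacks the 0-dim loss tensors, moves them to the
    CPU in one copy, then empties the list so the device tensors can be freed.
    """
    if not pending:
        raise ValueError("no losses to aggregate")
    # One stack + one device-to-host copy instead of one copy per loss.
    stacked = torch.stack([loss.detach() for loss in pending]).float().cpu()
    pending.clear()  # drop the references to the on-device tensors
    return stacked.mean().item(), stacked.max().item()


# Usage with CPU tensors; in a GPU run the tensors would live on the device.
losses = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(3.0)]
avg_loss, max_loss = drain_losses(losses)
```

The copy to the CPU synchronizes with the device, so calling such a helper every N steps would bound the list at N entries at the cost of one synchronization per call.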

@tianyu-l
Contributor

Thanks for raising this issue, @carmocca !

It looks to me that what's kept on device is https://github.com/pytorch/torchtitan/blob/main/train.py#L330 (the max and average you mentioned are on CPU?).

I think the reason we keep it is that we don't want to call .item() (which incurs synchronization between CPU and GPU) unless we hit a log step. I do agree that if logging is disabled or infrequent, this overhead is unnecessary. That said, may I ask what the use case is where you'd log so infrequently that this overhead becomes unacceptable?
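One way to reconcile both concerns, sketched here under assumptions rather than as what torchtitan does: keep a running sum and running max as two 0-dim tensors on the loss's own device, so memory stays constant and `.item()` is only called when a summary is actually requested. The class name `RunningLoss` and its methods are hypothetical.

```python
from typing import Optional

import torch


class RunningLoss:
    """Tracks the sum and max of losses on the tensors' own device.

    Hypothetical class: only two 0-dim tensors are kept, and no host
    synchronization happens until `summary()` is called.
    """

    def __init__(self) -> None:
        self.total: Optional[torch.Tensor] = None
        self.peak: Optional[torch.Tensor] = None
        self.count = 0

    def update(self, loss: torch.Tensor) -> None:
        loss = loss.detach().float()  # detach so no autograd graph is retained
        self.total = loss if self.total is None else self.total + loss
        self.peak = loss if self.peak is None else torch.maximum(self.peak, loss)
        self.count += 1

    def summary(self) -> tuple[float, float]:
        """Return (average, maximum) as floats and reset; `.item()` syncs here."""
        if self.count == 0:
            raise ValueError("no losses recorded")
        avg, peak = (self.total / self.count).item(), self.peak.item()
        self.total, self.peak, self.count = None, None, 0
        return avg, peak


# Usage with CPU tensors standing in for per-step losses.
tracker = RunningLoss()
for value in (2.0, 4.0, 3.0):
    tracker.update(torch.tensor(value))
avg_loss, max_loss = tracker.summary()
```

If `summary()` is never called, for example because logging is disabled, the tracker still holds only two small tensors.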

@carmocca
Contributor Author

carmocca commented Jan 6, 2025

One case would be debugging or testing: you don't want to dump logs to TensorBoard, but then your memory will grow until OOM. The code currently expects that you eventually log.

Another is small and fast models where the step time is small and it's excessive to log that often. There you might notice the ragged look in the memory metrics.

Also, I don't personally like the default behaviour of smoothing the loss over the last log_freq steps. I find it misleading when compared to log_freq=1. My expectation was that log_freq > 1 would simply log the loss at that modulo step. But that might just be me.
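To make the difference concrete, here is a small sketch comparing the two behaviours on made-up loss values. The function names are hypothetical, and steps are assumed to be 1-indexed, which is a simplification of the real training loop.

```python
def smoothed_log(losses: list[float], log_freq: int) -> dict[int, float]:
    """Log the mean of the last `log_freq` losses at every log step."""
    logged = {}
    for step in range(log_freq, len(losses) + 1, log_freq):
        window = losses[step - log_freq:step]
        logged[step] = sum(window) / len(window)
    return logged


def modulo_log(losses: list[float], log_freq: int) -> dict[int, float]:
    """Log only the loss of the steps where `step % log_freq == 0`."""
    return {
        step: loss
        for step, loss in enumerate(losses, start=1)
        if step % log_freq == 0
    }


losses = [8.0, 6.0, 4.0, 2.0]  # made-up values for steps 1..4
smoothed = smoothed_log(losses, log_freq=2)   # {2: 7.0, 4: 3.0}
pointwise = modulo_log(losses, log_freq=2)    # {2: 6.0, 4: 2.0}
```

With `log_freq=1` both functions return the same values, which is why the smoothed curve can look misleading next to a `log_freq=1` run. The modulo variant also needs no history of past losses, so nothing has to accumulate between log steps.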

@tianyu-l
Contributor

tianyu-l commented Jan 7, 2025

> My expectation was that log_freq > 1 would simply log the loss at that modulo step.

Hmm, makes sense to me. Please feel free to change the behavior with a PR. Looks like it'll address both issues.
