
Create a mini version containing only ZB-H1 and the essential changes so that other Megatron forks can easily integrate it #10

Open
ufotalent opened this issue Dec 15, 2023 · 11 comments

Comments

@ufotalent

@ufotalent: implement a version using our own runtime engine and async IO
@QPHutu: implement a version by modifying the 1F1B schedule, using sync IO

@robotsp

robotsp commented Mar 2, 2024

I ran a LLaMA-7B instance in pipeline-parallel mode (PP size = 8, TP size = 1) using ZB-H1, but I saw no performance improvement over the original schedule: the duration of each step is the same in both cases.

Does it make sense? @QPHutu @ufotalent @P2333

@ufotalent

Hi, thanks for trying out ZB-H1. This result looks problematic, because in this setting ZB-H1 should provide some speedup. Could you share some details of the setup, e.g. which code repo you are using and, briefly, what changes you made to enable ZB-H1? Also, what is the number of mini-batches in the pipeline?
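For a rough sense of how much speedup to expect here, the sketch below compares estimated step times under 1F1B and ZB-H1. It assumes the bubble sizes given in the Zero Bubble paper, (p - 1)(T_F + T_B + T_W) for 1F1B and (p - 1)(T_F + T_B - T_W) for ZB-H1, where T_B covers only the input-gradient pass and T_W the weight-gradient pass, and it ignores communication and optimizer time. The function `step_time` and the example values (p = 8, m = 64, equal F/B/W times) are hypothetical.

```python
def step_time(p: int, m: int, t_f: float, t_b: float, t_w: float, schedule: str) -> float:
    """Estimated step time: m microbatches of useful work plus the pipeline bubble.

    Assumed bubble sizes (Zero Bubble paper, handcrafted schedules):
      1F1B:  (p - 1) * (t_f + t_b + t_w)
      ZB-H1: (p - 1) * (t_f + t_b - t_w)
    """
    useful = m * (t_f + t_b + t_w)
    if schedule == "1f1b":
        bubble = (p - 1) * (t_f + t_b + t_w)
    elif schedule == "zb-h1":
        bubble = (p - 1) * (t_f + t_b - t_w)
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return useful + bubble


if __name__ == "__main__":
    p, m = 8, 64  # hypothetical: PP size 8, 64 microbatches per step
    base = step_time(p, m, 1.0, 1.0, 1.0, "1f1b")
    zb = step_time(p, m, 1.0, 1.0, 1.0, "zb-h1")
    print(f"1F1B: {base:.0f}  ZB-H1: {zb:.0f}  saving: {1 - zb / base:.1%}")
```

Under these assumed numbers the estimate is 213 vs. 199 time units, a saving of about 6.6% per step: modest, but large enough that identical step times are worth investigating.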

@robotsp

robotsp commented Mar 4, 2024

I ran the latest version of Megatron-LM and applied the quick ZB-H1 implementation (commit 95212f7). The global batch size and the mini-batch size in the test are 256 and 4, respectively. @ufotalent @QPHutu @P2333
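Assuming the "mini-batch size" here corresponds to Megatron-LM's `--micro-batch-size`, the number of microbatches each pipeline processes per step follows from global batch size = micro-batch size × data-parallel size × number of microbatches. The helper below is a hypothetical sketch of that arithmetic (context parallelism is ignored); the GPU counts in the usage are assumptions, since the data-parallel size was not stated.

```python
def num_microbatches(global_batch_size: int, micro_batch_size: int,
                     world_size: int, tp: int, pp: int) -> int:
    """Microbatches per step for each pipeline, from the relation
    global_batch_size = micro_batch_size * data_parallel_size * num_microbatches."""
    if world_size % (tp * pp) != 0:
        raise ValueError("world_size must be divisible by tp * pp")
    dp = world_size // (tp * pp)  # data-parallel size (context parallelism ignored)
    per_round = micro_batch_size * dp
    if global_batch_size % per_round != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * dp")
    return global_batch_size // per_round


if __name__ == "__main__":
    # GBS 256 and micro-batch size 4 are from the report; the GPU counts are assumptions.
    for gpus in (8, 16):
        print(gpus, "GPUs ->", num_microbatches(256, 4, gpus, tp=1, pp=8), "microbatches")
```

With 8 GPUs (PP = 8, TP = 1, so DP = 1) this gives 64 microbatches; with 16 GPUs it gives 32.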

@ufotalent

@robotsp, is it possible to share the training script?
One possible suspect is that the call path uses megatron.core.models.gpt.GPTModel instead of megatron.model.GPTModel. Currently, the patch only takes effect on megatron.model.GPTModel.
Another suspect is that you're using interleaved 1F1B (enabled by setting the num_layers_per_virtual_pipeline_stage flag). Our current simple patch is for the 1F1B schedule, not interleaved 1F1B.
Since your patched code runs but produces identical run times, I suspect the patched code path is not being executed.
Thanks
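Both suspects can be checked at runtime. This is a minimal sketch, assuming the model chunk may be wrapped in modules that expose the inner model as `.module` (as Megatron's DDP and Float16Module wrappers typically do) and that the parsed training arguments carry `num_layers_per_virtual_pipeline_stage`. The helper names `unwrap_model_chunk` and `diagnose_zb_h1` are hypothetical, not part of Megatron-LM or the patch.

```python
import argparse


def unwrap_model_chunk(model):
    """Peel off wrapper modules that expose the wrapped model as `.module`."""
    while hasattr(model, "module"):
        model = model.module
    return model


def diagnose_zb_h1(model_chunk, args) -> list:
    """Return reasons why the simple ZB-H1 patch might not be exercised."""
    problems = []
    cls = type(unwrap_model_chunk(model_chunk))
    # The patch targets megatron.model.GPTModel, not megatron.core.models.gpt.GPTModel.
    if not cls.__module__.startswith("megatron.model."):
        problems.append(f"model class is {cls.__module__}.{cls.__name__}, "
                        "expected megatron.model.GPTModel")
    # A non-None value here means the interleaved 1F1B schedule is in use.
    if getattr(args, "num_layers_per_virtual_pipeline_stage", None) is not None:
        problems.append("interleaved 1F1B enabled via num_layers_per_virtual_pipeline_stage")
    return problems


if __name__ == "__main__":
    # Stand-in classes for illustration only; in a real run, pass a model chunk
    # and the parsed training arguments instead.
    class GPTModel:
        pass
    GPTModel.__module__ = "megatron.core.models.gpt.gpt_model"

    class Wrapper:
        def __init__(self, module):
            self.module = module

    args = argparse.Namespace(num_layers_per_virtual_pipeline_stage=2)
    for problem in diagnose_zb_h1(Wrapper(GPTModel()), args):
        print(problem)
```

An empty result does not prove the patch is active; it only rules out these two causes.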

@GeLee-Q

GeLee-Q commented Jan 3, 2025

@ufotalent Hello, I have also run into the issue of the patch not taking effect. Could you explain what the following sentence means? I noticed that the changes in the submitted example code are made in megatron.core.

Currently the patch is effective on megatron.model.GPTModel.

Here are some of my parameter settings:

MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 2
    --pipeline-model-parallel-size 4
    --use-distributed-optimizer
    --overlap-grad-reduce
    --overlap-param-gather
    --distributed-backend nccl
    # --sequence-parallel
    --no-gradient-accumulation-fusion
    --transformer-impl local
)

@ufotalent

Hi @GeLee-Q, thanks for your interest in our work. Could you share which version (or git commit) you are using? Thanks

@GeLee-Q

GeLee-Q commented Jan 9, 2025

Thank you. I have integrated these two modifications into my own version of the Megatron-LM library:

NVIDIA@95212f7#diff-6078a722754eba8b855a8156b2dc22283858d10acd2d6bc8115086f35d4fbb7b

NVIDIA@a84d634

@ufotalent

Oh, I see the problem. I think the cause is that you turned on 'no_gradient_accumulation_fusion', which skips our modification in layers.py.

My suggestion is to remove this flag. If you need to turn off gradient accumulation fusion for some reason, then in the W pass you'll have to do the W matmul and gradient accumulation manually, with something like this:

grad_weight = grad_output.t().matmul(total_input)
weight.main_grad += grad_weight
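To illustrate what such a manual W pass involves, here is a minimal single-process PyTorch sketch of a linear layer whose backward is split into a B pass (input gradient, computed right away) and a deferred W pass (weight gradient accumulated into `main_grad`). The class and method names are hypothetical and this is not the patch's actual code; everything stays in fp32 on 2-D inputs for simplicity.

```python
from collections import deque

import torch


class DeferredWLinear:
    """Hypothetical linear layer (y = x @ W^T) with backward split into B and W passes."""

    def __init__(self, in_features: int, out_features: int):
        self.weight = torch.randn(out_features, in_features)
        # fp32 gradient buffer, playing the role of Megatron's `weight.main_grad`.
        self.weight.main_grad = torch.zeros_like(self.weight)
        self._fwd_inputs = deque()  # inputs saved by F, consumed by B
        self._w_queue = deque()     # (grad_output, input) pairs saved by B, consumed by W

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._fwd_inputs.append(x)
        return x.matmul(self.weight.t())

    def backward_b(self, grad_output: torch.Tensor) -> torch.Tensor:
        # B pass: compute only dL/dx, which the previous pipeline stage is waiting on.
        total_input = self._fwd_inputs.popleft()
        self._w_queue.append((grad_output, total_input))
        return grad_output.matmul(self.weight)

    def backward_w(self) -> None:
        # W pass: the manual weight-gradient matmul plus accumulation into main_grad.
        grad_output, total_input = self._w_queue.popleft()
        grad_weight = grad_output.t().matmul(total_input)
        self.weight.main_grad += grad_weight


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = DeferredWLinear(4, 3)
    xs = [torch.randn(5, 4) for _ in range(2)]     # two microbatches
    grads = [torch.randn(5, 3) for _ in range(2)]  # upstream gradients
    for x in xs:
        layer.forward(x)
    for g in grads:
        layer.backward_b(g)  # B passes run first...
    for _ in grads:
        layer.backward_w()   # ...W passes are deferred
    # Reference: autograd on the same computation.
    w = layer.weight.detach().clone().requires_grad_(True)
    loss = sum((x.matmul(w.t()) * g).sum() for x, g in zip(xs, grads))
    loss.backward()
    print(torch.allclose(layer.weight.main_grad, w.grad, atol=1e-5))
```

The FIFO queues reflect that, in 1F1B, microbatches run their backward passes in the same order as their forward passes, so each deferred W pass can pair the right `grad_output` with the right saved input.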

@ufotalent

Sorry, forgot to tag @GeLee-Q.

@GeLee-Q

GeLee-Q commented Jan 9, 2025

Thank you very much for your guidance. I will continue to study your work in depth.


4 participants