
[QUESTION] The timing for B and W appears to be incorrect #22

Open
RookieHong opened this issue May 15, 2024 · 2 comments


@RookieHong

Your question
It seems that B's timing includes W, while W merely accounts for the time of gradient accumulation.

In the megatron/core/pipeline_parallel/zb_schedules.py file, the function schedule_b counts the duration of this:

input_tensor_grad = backward_step(
    input_tensor, output_tensor, output_tensor_grad, self.model_type,
    self.config
)

This is actually B+W; it computes the gradients with respect to the inputs and the weights.
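
To make the two parts concrete, here is a rough PyTorch sketch of a single linear layer's backward pass (the shapes and names are made up for illustration and are not from the Megatron code):

import torch

x = torch.randn(32, 1024)    # saved forward-pass input of a Linear(1024 -> 4096)
w = torch.randn(4096, 1024)  # layer weight
g = torch.randn(32, 4096)    # gradient of the loss w.r.t. the layer output

# "B" part: input gradient, needed immediately so it can be sent to the previous stage
input_grad = g @ w           # shape (32, 1024)

# "W" part: weight gradient, only needed by the optimizer, so it could be deferred
weight_grad = g.t() @ x      # shape (4096, 1024)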

Meanwhile, in the schedule_w function, W counts the duration of this:

WeightGradStore.pop(chunk=scheduled_node.chunk)

After conducting a global search for WeightGradStore.put, I found that it only enqueues gradient-accumulation operations, specifically fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 or fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16. Therefore, W counts only the time of the gradient-accumulation kernel!
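
From my reading, the put/pop pattern looks roughly like this simplified sketch (my own illustration, not the actual WeightGradStore code):

from collections import deque

class TinyWeightGradStore:
    """Simplified illustration of a deferred weight-gradient queue."""

    def __init__(self):
        # one queue of pending weight-gradient callables per model chunk
        self._queues = {}

    def put(self, chunk, wgrad_fn):
        # called during B: enqueue the deferred call (in my search, the
        # fused wgrad_gemm_accum_* kernels) instead of running it now
        self._queues.setdefault(chunk, deque()).append(wgrad_fn)

    def pop(self, chunk=0):
        # called during W: drain the queue and run the deferred calls
        queue = self._queues.get(chunk, deque())
        while queue:
            queue.popleft()()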

The timing statistics from the timer also prove this point, with W's duration being very short and B's duration being almost double that of F:
[timer output screenshot]

Is this the expected result?

@ufotalent

Hi, thanks for the interest in our work and implementation.

The wgrad_gemm_accum_fp16 kernel actually does both the weight-gradient calculation and the accumulation. It is a fused kernel that computes the weight-gradient GEMM and accumulates the result in one pass, which is why it's called gemm + accum.
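
In unfused PyTorch terms, what that kernel computes is roughly the following (an illustrative reference only; the real CUDA kernel does the GEMM and the accumulation in a single launch):

import torch

def wgrad_gemm_accum_reference(total_input, grad_output, main_grad):
    # total_input: (tokens, in_features)  activations saved from the forward pass
    # grad_output: (tokens, out_features) gradient w.r.t. the layer output
    # main_grad:   (out_features, in_features) fp32 accumulation buffer
    main_grad.add_(grad_output.t().float() @ total_input.float())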

Just to make sure, is gradient_accumulation_fusion enabled in your setting? Our implementation of the B-W split only works when gradient_accumulation_fusion is enabled.

@RookieHong
Author

Thanks for the reply!

I did not add the --no-gradient-accumulation-fusion parameter, and get_args().gradient_accumulation_fusion is True at runtime. However, W still has a very short runtime, while B takes almost twice as long as F. I wonder if there could be other reasons for this?
