Your question

It seems that B's timing includes W, while W merely accounts for the time of gradient accumulation.

In megatron/core/pipeline_parallel/zb_schedules.py, the schedule_b function times the backward pass. This is actually B+W: it computes the gradients with respect to both the inputs and the weights. Meanwhile, in the schedule_w function, W times only this call:

`WeightGradStore.pop(chunk=scheduled_node.chunk)`

After a global search for WeightGradStore.put, I found that it only enqueues gradient-accumulation operations, specifically fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 or fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16. So W appears to count only the time spent on gradient accumulation!

The timer statistics support this as well: W's duration is very short, while B takes almost twice as long as F.

Is this the expected result?
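For reference, here is a minimal, runnable sketch of the put/pop deferral pattern described above. The class and function bodies are simplified stand-ins for illustration, not the repository's actual implementation; only the names `WeightGradStore.put`/`pop` and the fused-kernel name come from this thread.

```python
from collections import deque
from functools import partial

import torch

class WeightGradStore:
    """Simplified sketch of the deferral queue: B enqueues the weight-grad
    work as a closure, and schedule_w's pop() dequeues and executes it."""
    cache = deque()

    @classmethod
    def put(cls, func):
        # Called inside the backward (B) pass: defer the wgrad work.
        cls.cache.append(func)

    @classmethod
    def pop(cls):
        # Called by schedule_w: everything timed as W happens here.
        while cls.cache:
            cls.cache.popleft()()

def wgrad_gemm_accum(inp, grad_output, main_grad):
    # Stand-in for fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32:
    # the weight-gradient GEMM fused with accumulation into main_grad.
    main_grad += grad_output.t() @ inp

# Toy usage: during B, only dgrad runs; the wgrad work is queued for W.
inp = torch.randn(8, 4)          # saved activations
grad_out = torch.randn(8, 3)     # gradient w.r.t. the layer output
weight = torch.randn(3, 4)
main_grad = torch.zeros(3, 4)    # fp32 accumulation buffer

grad_input = grad_out @ weight   # dgrad: timed as part of B
WeightGradStore.put(partial(wgrad_gemm_accum, inp, grad_out, main_grad))

WeightGradStore.pop()            # timed as W: runs the deferred wgrad GEMM
```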
Hi, thanks for the interest in our work and implementation.
wgrad_gemm_accum_fp16 actually does both the weight-gradient calculation and the accumulation. It is a fused kernel of weight-grad calculation and accumulation; that's why it's called gemm + accum.
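In other words, the deferred call is the full weight-gradient GEMM, with the accumulation fused into it. A small sketch of the kernel's semantics (a reference reimplementation for illustration, not the CUDA source), checked against autograd:

```python
import torch

def wgrad_gemm_accum_reference(inp, grad_output, main_grad):
    # Reference semantics of fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_*:
    # the wgrad GEMM fused with the "+=" accumulation into main_grad.
    main_grad += grad_output.t() @ inp

# Sanity check: for y = x @ w.t(), autograd gives dL/dw == grad_output.t() @ x.
x = torch.randn(8, 4)
w = torch.randn(3, 4, requires_grad=True)
y = x @ w.t()
y.backward(gradient=torch.ones_like(y))

main_grad = torch.zeros(3, 4)
wgrad_gemm_accum_reference(x, torch.ones(8, 3), main_grad)
assert torch.allclose(main_grad, w.grad)
```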
Just to make sure: is gradient_accumulation_fusion enabled in your setting? Our implementation of the B-W split only works when gradient_accumulation_fusion is enabled.
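For what it's worth, the reason the flag matters can be sketched like this (a hypothetical simplified structure, not the verbatim Megatron source): only the fused path packages the wgrad work as a single deferrable call.

```python
import torch

def linear_backward(inp, weight, grad_output, main_grad,
                    gradient_accumulation_fusion, wgrad_queue):
    """Simplified sketch of a linear layer's backward under the two settings."""
    grad_input = grad_output @ weight  # dgrad: always computed during B

    if gradient_accumulation_fusion:
        # Fused path: the whole wgrad GEMM + accumulate is one call, so it
        # can be enqueued (cf. WeightGradStore.put) and deferred to W.
        wgrad_queue.append(lambda: main_grad.add_(grad_output.t() @ inp))
    else:
        # Non-fused path: wgrad runs eagerly here, inside B, so there is
        # nothing left to defer and the B-W split has no effect.
        main_grad.add_(grad_output.t() @ inp)

    return grad_input
```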
I did not add the --no-gradient-accumulation-fusion parameter, and get_args().gradient_accumulation_fusion is True at runtime. Still, W has a very short runtime while B takes almost twice as long as F. Could there be other reasons for this?