From 0dfa5e576be4d9f2251af20f8fdfd92c6fca3fa1 Mon Sep 17 00:00:00 2001 From: Naman Goyal Date: Fri, 30 Aug 2024 13:15:04 -0700 Subject: [PATCH] support for grad acc --- fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index cdd6e6e8c..5584e6d02 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1821,7 +1821,10 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> param._saved_grad_shard.data += reduced_grad.data reduced_grad = param._saved_grad_shard.data elif (param.grad is None) and self.fp32_reduce_scatter: - param.main_grad = reduced_grad.data + if getattr(param, "main_grad", None) is not None: + param.main_grad.add_(reduced_grad.data) + else: + param.main_grad = reduced_grad.data # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full # backwards pass completes, we will set `.grad` to the CPU copy.