From 02c403642da7b8844064e8d78bc246a0077ca439 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Sun, 1 Oct 2023 14:57:52 -0700 Subject: [PATCH 1/3] Fix fsdp+pp+te WPS decreasing issue --- .../fully_sharded_data_parallel.py | 38 ++++++++++++++++--- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 4e8e62180..050c132c7 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -557,6 +557,7 @@ def __init__( self.dont_wait_current_stream_for_post_all_gather = False self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None + self._module_fqn = None def _get_gradient_predivide_factor(self, world_size: int) -> float: factor: int = 1 @@ -1220,6 +1221,9 @@ def _lazy_init(self) -> None: self._set_is_root() self._setup_streams() self._setup_output_hook_list() + for module_name, module in self.named_modules(): + if isinstance(module, FullyShardedDataParallel): + module._module_fqn = module_name if self._is_root: # Buffers stay on GPU, and don't get sharded. Since _cast_buffers @@ -1650,6 +1654,9 @@ def _register_post_backward_hooks(self) -> None: assert p_tmp.grad_fn is not None grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object. handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p)) + if not hasattr(p, "_shard_bwd_hooks"): + p._shard_bwd_hooks = [] + p._shard_bwd_hooks.append((grad_acc, handle)) p._shard_bwd_hook = (grad_acc, handle) @torch.no_grad() @@ -1710,6 +1717,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Switch to FP32 shard after backward. self._use_fp32_param_shard([param]) + if self.mixed_precision and self.fp32_reduce_scatter: + if getattr(param, "main_grad", None) is None: + param.main_grad = param.grad.to(torch.float32) + else: + param.main_grad.add_(param.grad.data) + + param.grad = None if not self._require_backward_grad_sync: return @@ -1718,15 +1732,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # reductions in post_backward stream. self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._streams["post_backward"]): - orig_grad_data = param.grad.data if self.fp32_reduce_scatter: # Cast grad to FP32. - param.grad.data = param.grad.data.float() + orig_grad_data = param.grad.data.float() + else: + orig_grad_data = param.grad.data if self.gradient_predivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. - param.grad.data.div_(self.gradient_predivide_factor) + if getattr(param, "main_grad", None) is not None: + param.main_grad.data.div_(self.gradient_predivide_factor) + else: + param.grad.data.div_(self.gradient_predivide_factor) if param._is_sharded: assert self._reducer is not None @@ -1734,7 +1752,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't # matter, neglecting rounding. 
- grad = param.grad.data + if getattr(param, "main_grad", None) is not None: + grad = param.main_grad.data + param.main_grad = None + else: + grad = param.grad.data # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction. # # The effect on memory consumption is not usually significant. No extra memory is allocated if this @@ -1860,6 +1882,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}") # p._shard_bwd_hook[1].remove() # delattr(p, "_shard_bwd_hook") + if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync: + for _, handle in p._shard_bwd_hooks: + handle.remove() # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard @@ -1876,7 +1901,10 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None: p.device == p._saved_grad_shard.device, f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}", ) - p.grad = p._saved_grad_shard + if p._saved_grad_shard.dtype != p.dtype: + p.grad = p._saved_grad_shard.to(p.dtype) + else: + p.grad = p._saved_grad_shard if hasattr(p, "_saved_grad_shard"): delattr(p, "_saved_grad_shard") From 65837b28b7cccc96cbe39bc5e55e005db9731a49 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Sun, 1 Oct 2023 17:16:34 -0700 Subject: [PATCH 2/3] Address comment; remove unused stuff --- .../nn/data_parallel/fully_sharded_data_parallel.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 050c132c7..2f6eb6d54 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -557,7 +557,6 @@ def __init__( self.dont_wait_current_stream_for_post_all_gather = False self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None - self._module_fqn = None def _get_gradient_predivide_factor(self, world_size: int) -> float: factor: int = 1 @@ -1221,9 +1220,6 @@ def _lazy_init(self) -> None: self._set_is_root() self._setup_streams() self._setup_output_hook_list() - for module_name, module in self.named_modules(): - if isinstance(module, FullyShardedDataParallel): - module._module_fqn = module_name if self._is_root: # Buffers stay on GPU, and don't get sharded. Since _cast_buffers @@ -1657,7 +1653,7 @@ def _register_post_backward_hooks(self) -> None: if not hasattr(p, "_shard_bwd_hooks"): p._shard_bwd_hooks = [] p._shard_bwd_hooks.append((grad_acc, handle)) - p._shard_bwd_hook = (grad_acc, handle) + # p._shard_bwd_hook = (grad_acc, handle) @torch.no_grad() def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: @@ -1735,9 +1731,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: if self.fp32_reduce_scatter: # Cast grad to FP32. - orig_grad_data = param.grad.data.float() - else: - orig_grad_data = param.grad.data + param.grad.data = param.grad.data.float() + + orig_grad_data = param.grad.data if self.gradient_predivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. 
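For context on the WPS fix itself: rather than overwriting a single p._shard_bwd_hook tuple each time hooks are registered, the series records every (grad_acc, handle) pair in p._shard_bwd_hooks and removes the handles when parameters are finalized, so stale hooks on the AccumulateGrad nodes cannot pile up across iterations, which is the kind of accumulation that shows up as steadily decreasing WPS in the fsdp+pp+te setup. Below is a minimal standalone sketch of that bookkeeping pattern; the function names (register_post_backward_hooks_sketch, finalize_backward_sketch) and the no-op post_backward callback are illustrative stand-ins, not the FSDP implementation:

    import functools

    import torch

    def post_backward(param: torch.nn.Parameter, *unused) -> None:
        # Stand-in for FSDP's _post_backward_hook (the grad reduction would go here).
        pass

    def register_post_backward_hooks_sketch(module: torch.nn.Module) -> None:
        for p in module.parameters():
            if not p.requires_grad:
                continue
            # p.expand_as(p) yields a view whose grad_fn leads back to the
            # parameter's AccumulateGrad node; a hook on that node fires once
            # the gradient for p has been fully accumulated.
            p_tmp = p.expand_as(p)
            assert p_tmp.grad_fn is not None
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
            handle = grad_acc.register_hook(functools.partial(post_backward, p))
            # Track every (grad_acc, handle) pair instead of overwriting a
            # single attribute, so repeated registrations can all be cleaned up.
            if not hasattr(p, "_shard_bwd_hooks"):
                p._shard_bwd_hooks = []
            p._shard_bwd_hooks.append((grad_acc, handle))

    def finalize_backward_sketch(module: torch.nn.Module) -> None:
        # Remove the hooks at the end of the backward pass so the next
        # iteration starts from a clean slate.
        for p in module.parameters():
            for _, handle in getattr(p, "_shard_bwd_hooks", []):
                handle.remove()
            if hasattr(p, "_shard_bwd_hooks"):
                p._shard_bwd_hooks = []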
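Patch 1 also keeps a dtype guard when handing the reduced shard back to the parameter in _finalize_parameters: with fp32_reduce_scatter the saved shard can be FP32 while the flat parameter is FP16/BF16, and PyTorch rejects assigning a .grad whose dtype differs from the parameter's. A small sketch of that guard in isolation, assuming p._saved_grad_shard already holds the reduced shard in the parameter's (sharded) shape; restore_grad_shard is an illustrative name, not the FSDP API:

    import torch

    def restore_grad_shard(p: torch.nn.Parameter) -> None:
        if not hasattr(p, "_saved_grad_shard"):
            return
        saved = p._saved_grad_shard
        assert p.device == saved.device, (
            f"incorrect saved_grad_shard device {p.device} vs {saved.device}"
        )
        # Cast back to the parameter's dtype only when the reduction ran in a
        # different precision (e.g. an FP32 reduce-scatter for a BF16 param).
        p.grad = saved.to(p.dtype) if saved.dtype != p.dtype else saved
        delattr(p, "_saved_grad_shard")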
From 45cd03858b34f8b945827cb803dd4bdcaae3b113 Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Sun, 1 Oct 2023 18:05:58 -0700 Subject: [PATCH 3/3] split into wps fix P841842878 only and main_grad fix --- .../fully_sharded_data_parallel.py | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 2f6eb6d54..759b9f445 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1713,13 +1713,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # Switch to FP32 shard after backward. self._use_fp32_param_shard([param]) - if self.mixed_precision and self.fp32_reduce_scatter: - if getattr(param, "main_grad", None) is None: - param.main_grad = param.grad.to(torch.float32) - else: - param.main_grad.add_(param.grad.data) - - param.grad = None if not self._require_backward_grad_sync: return @@ -1728,19 +1721,15 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # reductions in post_backward stream. self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._streams["post_backward"]): + orig_grad_data = param.grad.data if self.fp32_reduce_scatter: # Cast grad to FP32. param.grad.data = param.grad.data.float() - orig_grad_data = param.grad.data - if self.gradient_predivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. - if getattr(param, "main_grad", None) is not None: - param.main_grad.data.div_(self.gradient_predivide_factor) - else: - param.grad.data.div_(self.gradient_predivide_factor) + param.grad.data.div_(self.gradient_predivide_factor) if param._is_sharded: assert self._reducer is not None @@ -1748,11 +1737,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't # matter, neglecting rounding. - if getattr(param, "main_grad", None) is not None: - grad = param.main_grad.data - param.main_grad = None - else: - grad = param.grad.data + grad = param.grad.data # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction. # # The effect on memory consumption is not usually significant. No extra memory is allocated if this
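The main_grad half that this last patch carves out folds each low-precision param.grad into a persistent FP32 accumulator, param.main_grad (the attribute name Transformer Engine's fused weight-gradient accumulation expects, hence the "te" in the series title), right after backward and then drops param.grad, so the predivide and reduce-scatter in the patch-1 version of _post_backward_hook operate on full-precision values. A hedged, standalone sketch of that accumulation step; accumulate_main_grad is an illustrative helper, not part of FSDP:

    import torch

    @torch.no_grad()
    def accumulate_main_grad(param: torch.nn.Parameter) -> None:
        if param.grad is None:
            return
        if getattr(param, "main_grad", None) is None:
            # First backward for this parameter: allocate the FP32 accumulator.
            param.main_grad = param.grad.to(torch.float32)
        else:
            # Later microbatches (e.g. pipeline-parallel chunks) accumulate
            # into the existing buffer.
            param.main_grad.add_(param.grad)
        # Release the low-precision grad; the reduction path reads
        # param.main_grad instead.
        param.grad = None

Keeping the accumulator in FP32 avoids repeated rounding when many microbatch gradients are summed before the reduce-scatter, which is why this piece is worth landing separately from the hook bookkeeping above.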