Fix fsdp+pp+te WPS decreasing issue (#1139)
* Fix fsdp+pp+te WPS decreasing issue

* Address comment; remove unused stuff

* split into wps fix P841842878 only and main_grad fix
jianyuh authored Oct 2, 2023
1 parent e910320 commit 0db6e62
Showing 1 changed file with 11 additions and 2 deletions.

fairscale/nn/data_parallel/fully_sharded_data_parallel.py
```diff
@@ -1650,7 +1650,10 @@ def _register_post_backward_hooks(self) -> None:
             assert p_tmp.grad_fn is not None
             grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
             handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
-            p._shard_bwd_hook = (grad_acc, handle)
+            if not hasattr(p, "_shard_bwd_hooks"):
+                p._shard_bwd_hooks = []
+            p._shard_bwd_hooks.append((grad_acc, handle))
+            # p._shard_bwd_hook = (grad_acc, handle)
 
     @torch.no_grad()
     def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
```
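Why a list instead of a single handle: `_register_post_backward_hooks` runs on every forward pass, and with pipeline parallelism (pp and te in the title presumably abbreviate pipeline parallelism and Transformer Engine) each micro-batch is another forward, so the old code overwrote `p._shard_bwd_hook` and the earlier handle was never removed. A plausible reading of the WPS (words-per-second) regression is that stale hooks piled up on the parameter's `AccumulateGrad` node and fired on every backward pass. A minimal standalone sketch of the underlying PyTorch pattern, independent of fairscale's API:

```python
import torch

p = torch.nn.Parameter(torch.randn(4))

def post_backward_hook(*unused):
    # Fires after gradients for `p` have been accumulated into p.grad.
    print("post-backward fired for p")

handles = []
for _ in range(3):  # e.g. one registration per micro-batch forward
    p_tmp = p.expand_as(p)  # non-leaf view whose grad_fn leads back to p
    grad_acc = p_tmp.grad_fn.next_functions[0][0]  # p's AccumulateGrad node
    handles.append(grad_acc.register_hook(post_backward_hook))

(p * 2).sum().backward()  # prints three times: all three hooks are live

for h in handles:  # keeping every handle lets us remove all of them
    h.remove()

(p * 2).sum().backward()  # prints nothing: no stale hooks remain
```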
```diff
@@ -1860,6 +1863,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                 p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                 # p._shard_bwd_hook[1].remove()
                 # delattr(p, "_shard_bwd_hook")
+            if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
+                for _, handle in p._shard_bwd_hooks:
+                    handle.remove()
 
             # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
             # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
```
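This hunk is the removal side of the same fix: once gradients are actually synchronized, every collected handle is dropped. The `self._require_backward_grad_sync` guard keeps the hooks alive across `no_sync()` gradient-accumulation passes so later micro-batches still trigger the post-backward reduction. A standalone sketch of this finalize-side counterpart (the function name is hypothetical; only `_shard_bwd_hooks` comes from the patch):

```python
def finalize_param_hooks(p, require_backward_grad_sync: bool) -> None:
    # On a syncing pass, remove every handle registered during the
    # micro-batch forwards so registrations cannot pile up across iterations.
    if hasattr(p, "_shard_bwd_hooks") and require_backward_grad_sync:
        for _, handle in p._shard_bwd_hooks:
            handle.remove()
        # Assumption, not in the patch: also clear the list so a later pass
        # starts fresh (RemovableHandle.remove() is idempotent, so this is
        # belt-and-braces rather than strictly required).
        p._shard_bwd_hooks.clear()
```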
```diff
@@ -1876,7 +1882,10 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     p.device == p._saved_grad_shard.device,
                     f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                 )
-                p.grad = p._saved_grad_shard
+                if p._saved_grad_shard.dtype != p.dtype:
+                    p.grad = p._saved_grad_shard.to(p.dtype)
+                else:
+                    p.grad = p._saved_grad_shard
 
             if hasattr(p, "_saved_grad_shard"):
                 delattr(p, "_saved_grad_shard")
```
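This last hunk corresponds to the `main_grad` part of the commit message: PyTorch refuses to assign a `p.grad` whose dtype differs from `p`, and with Transformer Engine the reduced grad shard can live in a different precision than the parameter (for example fp32 main gradients alongside bf16 weights), so the shard is cast before assignment. A standalone illustration of the failure mode and the cast (the tensors here are made up for demonstration):

```python
import torch

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
saved_grad_shard = torch.ones(4, dtype=torch.float32)  # e.g. an fp32 "main grad"

try:
    p.grad = saved_grad_shard  # dtype mismatch: PyTorch raises a RuntimeError
except RuntimeError as err:
    print(f"rejected: {err}")

p.grad = saved_grad_shard.to(p.dtype)  # the cast this commit applies
```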