Fix fsdp+pp+te WPS decreasing issue #1139

Merged
13 changes: 11 additions & 2 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1650,7 +1650,10 @@ def _register_post_backward_hooks(self) -> None:
                 assert p_tmp.grad_fn is not None
                 grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
                 handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
-                p._shard_bwd_hook = (grad_acc, handle)
+                if not hasattr(p, "_shard_bwd_hooks"):
+                    p._shard_bwd_hooks = []
+                p._shard_bwd_hooks.append((grad_acc, handle))
+                # p._shard_bwd_hook = (grad_acc, handle)
 
     @torch.no_grad()
     def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
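This hunk replaces the single `p._shard_bwd_hook` tuple with a `p._shard_bwd_hooks` list. `_register_post_backward_hooks` runs on every forward pass, and with pipeline parallelism each micro-batch forward registers a fresh hook on the parameter's gradient accumulator; overwriting one slot kept only the newest handle, so the earlier hooks could never be removed and, per the PR title, throughput (WPS) decayed over time. Below is a minimal sketch of the pattern in isolation, with hypothetical helper and attribute names (`register_post_backward_hook`, `_hook_handles`) rather than fairscale's API:

```python
import functools
import torch

def register_post_backward_hook(p: torch.nn.Parameter, callback) -> None:
    """Attach `callback` to p's AccumulateGrad node; called as callback(p, *grad_args)."""
    p_tmp = p.expand_as(p)  # A view of p, so p_tmp has a grad_fn.
    assert p_tmp.grad_fn is not None
    grad_acc = p_tmp.grad_fn.next_functions[0][0]  # p's AccumulateGrad node.
    handle = grad_acc.register_hook(functools.partial(callback, p))
    # Keep every (grad_acc, handle) pair. If this runs once per forward
    # (e.g. once per pipeline micro-batch), storing only the newest handle
    # would leave the older hooks registered with no way to remove them.
    if not hasattr(p, "_hook_handles"):
        p._hook_handles = []
    p._hook_handles.append((grad_acc, handle))

def remove_post_backward_hooks(p: torch.nn.Parameter) -> None:
    # Mirror of the finalize step in the next hunk: release every handle
    # collected during this step's forward passes.
    for _, handle in getattr(p, "_hook_handles", []):
        handle.remove()
    p._hook_handles = []
```

Keeping `grad_acc` itself in the pair is deliberate: holding a reference keeps the AccumulateGrad node alive so the same node (and its hooks) is reused across backward passes.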
@@ -1860,6 +1863,9 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                     # p._shard_bwd_hook[1].remove()
                     # delattr(p, "_shard_bwd_hook")
+                if hasattr(p, "_shard_bwd_hooks") and self._require_backward_grad_sync:
+                    for _, handle in p._shard_bwd_hooks:
+                        handle.remove()
 
                 # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                 # remains the unsharded gradient accumulated from prior no-sync passes, and p._saved_grad_shard
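The removal loop is the other half of the fix: once the backward pass is finalized (and gradients are actually synchronized this pass, per `self._require_backward_grad_sync`), every handle collected during the forward passes is released. The toy example below, independent of fairscale, shows why this matters: an AccumulateGrad hook keeps firing on every backward until its `RemovableHandle` is removed, so unremoved handles pile up.

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
# Holding grad_acc keeps p's AccumulateGrad node (and its hooks) alive.
grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]

calls = []
handle = grad_acc.register_hook(lambda *unused: calls.append(1))

(p * 2).sum().backward()
assert len(calls) == 1   # fired on the first backward

(p * 2).sum().backward()
assert len(calls) == 2   # still registered: fires again

handle.remove()
(p * 2).sum().backward()
assert len(calls) == 2   # removed: no further calls
```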
@@ -1876,7 +1882,10 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                         p.device == p._saved_grad_shard.device,
                         f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
                     )
-                    p.grad = p._saved_grad_shard
+                    if p._saved_grad_shard.dtype != p.dtype:
+                        p.grad = p._saved_grad_shard.to(p.dtype)
+                    else:
+                        p.grad = p._saved_grad_shard
 
                 if hasattr(p, "_saved_grad_shard"):
                     delattr(p, "_saved_grad_shard")
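The final hunk guards against a dtype mismatch between the saved gradient shard and the parameter, which can arise when gradient shards are kept in a higher precision than the parameters (plausibly the Transformer Engine low-precision case the title refers to). Plain PyTorch rejects assigning a `.grad` whose dtype differs from the parameter's, hence the cast. A small illustration of that constraint, assuming only standard PyTorch `.grad` assignment semantics:

```python
import torch

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
saved_grad_shard = torch.ones(4, dtype=torch.float32)  # e.g. shard reduced in fp32

try:
    p.grad = saved_grad_shard             # dtype mismatch -> RuntimeError
except RuntimeError as err:
    print(f"direct assignment rejected: {err}")

p.grad = saved_grad_shard.to(p.dtype)     # cast first, as the fix does
print(p.grad.dtype)                       # torch.bfloat16
```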