Fsdp ptd #1085

Draft
wants to merge 2 commits into base: fixing_memory_issues_with_keeping_overlap_may24
Changes from all commits
81 changes: 66 additions & 15 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import collections
import contextlib
import copy
from dataclasses import dataclass
@@ -75,6 +76,43 @@
pass


class _FreeEventQueue:
"""
This tracks all pending frees corresponding to inflight all-gathers. The
queueing pattern is iterative enqueues followed by a flush, and the current
heuristic for the flush is based on the number of inflight all-gathers.
"""

def __init__(self) -> None:
self._queue: Deque[torch.cuda.Event] = collections.deque()
self._max_num_inflight_all_gathers = 2 # empirically chosen

def enqueue(self, free_event: torch.cuda.Event) -> None:
"""Enqueues a free event."""
self._queue.append(free_event)

def flush_if_needed(self) -> List[torch.cuda.Event]:
"""
        If the queue should be flushed (based on an internal criterion), then
this returns a non-empty :class:`list` of free events. Otherwise, this
returns an empty :class:`list`.
"""
events: List[torch.cuda.Event] = []
if len(self._queue) >= self._max_num_inflight_all_gathers:
while self._queue:
event = self._dequeue()
assert event is not None
events.append(event)
return events

def _dequeue(self) -> Optional[torch.cuda.Event]:
"""Dequeues a free event if possible."""
if self._queue:
event = self._queue.popleft()
return event
return None


class TrainingState(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
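Note (not part of the diff): the sketch below shows how the new _FreeEventQueue is meant to be driven. Enqueue happens when full params are freed, and flush_if_needed plus a synchronize on the newest event happens right before the next all-gather, which caps the number of inflight all-gathers at the empirically chosen limit of 2. The driving loop and the CUDA-availability guard are illustrative assumptions, not code from this PR.

import torch

if torch.cuda.is_available():
    queue = _FreeEventQueue()  # the class added above

    for _ in range(4):  # stand-in for successive FSDP modules in forward/backward
        # Before rebuilding full params (which issues an all-gather), flush the
        # queue if too many frees are still inflight and wait on the newest one.
        events = queue.flush_if_needed()
        if events:
            events[-1].synchronize()

        # ... all-gather and compute on the gathered params would happen here ...

        # After freeing the full params, record an event on the current stream so
        # a later flush can wait for this free before more all-gathers are issued.
        free_event = torch.cuda.Event()
        free_event.record()
        queue.enqueue(free_event)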
@@ -311,7 +349,7 @@ def __init__(
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
        # The type of process_group_reduce_scatter can only be ProcessGroup or ProcessGroupName
process_group_reduce_scatter: Any = ProcessGroupName.default,
process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
reshard_after_forward: bool = True,
disable_reshard_on_root: bool = True,
mixed_precision: bool = False,
@@ -1157,6 +1195,7 @@ def _reset_lazy_init(self) -> None:
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
self._fsdp_forward_ordering: List[nn.Module] = []
self._free_event_queue = _FreeEventQueue()
self._my_fsdp_instance_idx: Optional[int] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
@@ -1330,6 +1369,8 @@ def _set_is_root(self) -> None:
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)
m._fsdp_forward_ordering = self._fsdp_forward_ordering
m._free_event_queue = self._free_event_queue


def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
@@ -1404,13 +1445,14 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()

if (
self._fsdp_forward_ordering is not None
and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1
):
self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params(
wait_for_all_gather=False
)
# if (
# self._fsdp_forward_ordering is not None
# and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1
# ):

# self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params(
# wait_for_all_gather=False
# )

# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
@@ -1500,12 +1542,12 @@ def _pre_backward_hook(*unused: Any) -> None:
# overhead.
if self.reshard_after_forward:
self._rebuild_full_params()
if (
self.reshard_after_forward
and self._fsdp_forward_ordering is not None
and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx > 0
):
self._fsdp_forward_ordering[self._my_fsdp_instance_idx - 1]._rebuild_full_params(wait_for_all_gather=False)
# if (
# self.reshard_after_forward
# and self._fsdp_forward_ordering is not None
# and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx > 0
# ):
# self._fsdp_forward_ordering[self._my_fsdp_instance_idx - 1]._rebuild_full_params(wait_for_all_gather=False)
else:
self._use_full_params()

@@ -1889,6 +1931,12 @@ def _rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_
caller to free the full-sized param. This will be ``None`` if
``force_full_precision=False`` and the full params are already gathered.
"""

events = self._free_event_queue.flush_if_needed()
if events:
# As a minor optimization, only synchronize the latest event
events[-1].synchronize()

output_tensors: List[Tuple[torch.Tensor, bool]] = []

def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
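Note (not part of the diff): in the addition above only the newest free event is synchronized. A minimal sketch of why that is sufficient, assuming a CUDA device is available: events recorded on the same stream complete in submission order, so once the last one has fired, every earlier one has fired as well.

import torch

if torch.cuda.is_available():
    stream = torch.cuda.current_stream()
    earlier, latest = torch.cuda.Event(), torch.cuda.Event()

    torch.ones(1 << 20, device="cuda").sum()   # some work on the stream
    earlier.record(stream)
    torch.ones(1 << 20, device="cuda").sum()   # more work on the same stream
    latest.record(stream)

    latest.synchronize()    # blocks until `latest` and all work before it completes
    assert earlier.query()  # the earlier event is therefore already complete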
@@ -2071,7 +2119,10 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)
torch.cuda.current_stream().synchronize()

free_event = torch.cuda.Event()
free_event.record()
self._free_event_queue.enqueue(free_event)

def local_metadata_dict(self) -> Dict[str, Any]:
"""
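Note (not part of the diff): the change above replaces a blocking torch.cuda.current_stream().synchronize() after freeing the full params with recording a CUDA event that is consumed later in _rebuild_full_params. A minimal sketch of the difference, where work() is a hypothetical stand-in for the kernels still queued on the stream:

import torch

if torch.cuda.is_available():
    x = torch.ones(1 << 20, device="cuda")

    def work() -> torch.Tensor:
        return x * 2  # stand-in for kernels still running on the current stream

    # Old behaviour: the CPU blocks until everything queued on the stream finishes.
    work()
    torch.cuda.current_stream().synchronize()

    # New behaviour: record an event and keep going; the consumer of the event
    # queue decides later when (and whether) to wait on it.
    work()
    free_event = torch.cuda.Event()
    free_event.record()  # marks this point on the current stream
    # ... CPU keeps issuing work; call free_event.synchronize() only when needed ...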