From 59b2ab5dfdeaad25d189150e239a67af89afcc11 Mon Sep 17 00:00:00 2001
From: Ruan Silva
Date: Sat, 24 Sep 2022 00:40:04 +0000
Subject: [PATCH 1/2] try fsdp with cuda event queue

---
 .../fully_sharded_data_parallel.py | 79 +++++++++++++++----
 1 file changed, 65 insertions(+), 14 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 674601c49..cec48a5b5 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+import collections
 import contextlib
 import copy
 from dataclasses import dataclass
@@ -75,6 +76,43 @@
     pass
 
 
+class _FreeEventQueue:
+    """
+    This tracks all pending frees corresponding to inflight all-gathers. The
+    queueing pattern is iterative enqueues followed by a flush, and the current
+    heuristic for the flush is based on the number of inflight all-gathers.
+    """
+
+    def __init__(self) -> None:
+        self._queue: Deque[torch.cuda.Event] = collections.deque()
+        self._max_num_inflight_all_gathers = 2  # empirically chosen
+
+    def enqueue(self, free_event: torch.cuda.Event) -> None:
+        """Enqueues a free event."""
+        self._queue.append(free_event)
+
+    def flush_if_needed(self) -> List[torch.cuda.Event]:
+        """
+        If the queue should be flushed (based on an internal criteria), then
+        this returns a non-empty :class:`list` of free events. Otherwise, this
+        returns an empty :class:`list`.
+        """
+        events: List[torch.cuda.Event] = []
+        if len(self._queue) >= self._max_num_inflight_all_gathers:
+            while self._queue:
+                event = self._dequeue()
+                assert event is not None
+                events.append(event)
+        return events
+
+    def _dequeue(self) -> Optional[torch.cuda.Event]:
+        """Dequeues a free event if possible."""
+        if self._queue:
+            event = self._queue.popleft()
+            return event
+        return None
+
+
 class TrainingState(Enum):
     """
     Simple enum to indicate what state FSDP is in. Used for asserting
@@ -1157,6 +1195,7 @@ def _reset_lazy_init(self) -> None:
         self._streams: Dict[str, torch.cuda.Stream] = {}
         self._reducer: Optional[ReduceScatterBucketer] = None
         self._fsdp_forward_ordering: List[nn.Module] = []
+        self._free_event_queue = _FreeEventQueue()
         self._my_fsdp_instance_idx: Optional[int] = None
         for p in self.params:
             if hasattr(p, "_fp32_shard"):
@@ -1330,6 +1369,8 @@ def _set_is_root(self) -> None:
                     (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
                 )
                 m._fsdp_forward_ordering = self._fsdp_forward_ordering
+                m._free_event_queue = self._free_event_queue
+
 
     def _setup_streams(self) -> None:
         """Create streams to overlap data transfer and computation."""
@@ -1404,13 +1445,14 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
         self._rebuild_full_params()
-        if (
-            self._fsdp_forward_ordering is not None
-            and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1
-        ):
-            self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params(
-                wait_for_all_gather=False
-            )
+        # if (
+        #     self._fsdp_forward_ordering is not None
+        #     and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1
+        # ):
+
+        #     self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params(
+        #         wait_for_all_gather=False
+        #     )
 
         # Register backward hooks to reshard params and reduce-scatter grads.
         # These need to be re-registered every forward pass.
@@ -1500,12 +1542,12 @@ def _pre_backward_hook(*unused: Any) -> None:
             # overhead.
             if self.reshard_after_forward:
                 self._rebuild_full_params()
-                if (
-                    self.reshard_after_forward
-                    and self._fsdp_forward_ordering is not None
-                    and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx > 0
-                ):
-                    self._fsdp_forward_ordering[self._my_fsdp_instance_idx - 1]._rebuild_full_params(wait_for_all_gather=False)
+                # if (
+                #     self.reshard_after_forward
+                #     and self._fsdp_forward_ordering is not None
+                #     and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx > 0
+                # ):
+                #     self._fsdp_forward_ordering[self._my_fsdp_instance_idx - 1]._rebuild_full_params(wait_for_all_gather=False)
             else:
                 self._use_full_params()
 
@@ -1889,6 +1931,12 @@ def _rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_
             caller to free the full-sized param. This will be ``None`` if
             ``force_full_precision=False`` and the full params are already gathered.
         """
+
+        events = self._free_event_queue.flush_if_needed()
+        if events:
+            # As a minor optimization, only synchronize the latest event
+            events[-1].synchronize()
+
         output_tensors: List[Tuple[torch.Tensor, bool]] = []
 
         def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
@@ -2071,7 +2119,10 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
             # Storage object and unshard it in-place. For now, just resize
             # the Storage to 0 to save memory.
             free_storage_(p._full_param_padded)
-            torch.cuda.current_stream().synchronize()
+
+            free_event = torch.cuda.Event()
+            free_event.record()
+            self._free_event_queue.enqueue(free_event)
 
     def local_metadata_dict(self) -> Dict[str, Any]:
         """

From 592c0b93500140e5afbbb50e7cf89261d427be68 Mon Sep 17 00:00:00 2001
From: Ruan Silva
Date: Tue, 27 Sep 2022 20:30:22 +0000
Subject: [PATCH 2/2] update PG

---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index cec48a5b5..55478e59e 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -349,7 +349,7 @@ def __init__(
         module: nn.Module,
         process_group: Optional[ProcessGroup] = None,
         # The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName
-        process_group_reduce_scatter: Any = ProcessGroupName.default,
+        process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter,
         reshard_after_forward: bool = True,
         disable_reshard_on_root: bool = True,
         mixed_precision: bool = False,
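
Note for reviewers: the sketch below is a minimal, standalone illustration of the pattern the first patch introduces, assuming only stock PyTorch. Instead of calling torch.cuda.current_stream().synchronize() on every free, it records a CUDA event when a full-parameter buffer is "freed" and blocks the CPU only once a small number of frees (inflight all-gathers) have accumulated. Everything here (FreeEventQueue, MAX_INFLIGHT_ALL_GATHERS, the dummy allocation loop) is a hypothetical stand-in, not fairscale code or part of this patch.

# sketch_free_event_queue.py -- illustrative only; not part of the patch series.
import collections
from typing import Deque, List

import torch

MAX_INFLIGHT_ALL_GATHERS = 2  # mirrors the empirically chosen cap in the patch


class FreeEventQueue:
    """Tracks CUDA events recorded when full-param buffers are freed."""

    def __init__(self) -> None:
        self._queue: Deque[torch.cuda.Event] = collections.deque()

    def enqueue(self, event: torch.cuda.Event) -> None:
        self._queue.append(event)

    def flush_if_needed(self) -> List[torch.cuda.Event]:
        # Only drain once enough frees (i.e. inflight all-gathers) have piled up.
        events: List[torch.cuda.Event] = []
        if len(self._queue) >= MAX_INFLIGHT_ALL_GATHERS:
            while self._queue:
                events.append(self._queue.popleft())
        return events


def main() -> None:
    if not torch.cuda.is_available():
        return  # the pattern is only meaningful with CUDA streams

    queue = FreeEventQueue()
    for _ in range(8):
        # Throttle before "rebuilding" full params: block only when the cap
        # is reached, and only on the newest pending event.
        pending = queue.flush_if_needed()
        if pending:
            pending[-1].synchronize()

        # Stand-in for allocating the padded full param and all-gathering into it.
        full_param = torch.empty(1 << 20, device="cuda")
        full_param.fill_(1.0)

        # "Free" the buffer and record an event instead of a blocking
        # torch.cuda.current_stream().synchronize().
        del full_param
        free_event = torch.cuda.Event()
        free_event.record()  # records on the current stream
        queue.enqueue(free_event)


if __name__ == "__main__":
    main()

Synchronizing only the newest event is sufficient because events recorded on the same stream complete in order, which is the same reasoning as the "minor optimization" comment in _rebuild_full_params above.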