## Update on "[FSDP] Add limiter using CUDA events"
This PR tackles the high GPU reserved memory issue for FSDP.

Currently:
- This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter.
- If enabled, the limiter is only meaningful for `FULL_SHARD`, not for `SHARD_GRAD_OP` or `NO_SHARD`, since (1) we track free events, not all-gather events, and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used anyway.
- Given this, ideally each sharding strategy could have its own attributes, and `all_gather_issue_limit` could become an attribute of `FULL_SHARD` only. The same idea applies to `HYBRID_SHARD`, where one option would be to pass the second process group as an attribute.
- I want to discuss this since it does not seem backward compatible. I am not sure whether, with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum member.
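On the enum question: Python enums can give each member its own attributes, for example by making each member's value a small config object. A hypothetical sketch (all names here are invented for illustration, not the FSDP API):

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@dataclass(frozen=True)
class _FullShardConfig:
    # Hypothetical: the limiter only makes sense for FULL_SHARD,
    # so only this config carries it.
    all_gather_issue_limit: Optional[int] = 2


@dataclass(frozen=True)
class _NoShardConfig:
    pass  # no limiter attribute at all


class ShardingStrategy(Enum):
    FULL_SHARD = _FullShardConfig()
    NO_SHARD = _NoShardConfig()


# Only FULL_SHARD's value exposes the limiter attribute:
print(ShardingStrategy.FULL_SHARD.value.all_gather_issue_limit)  # 2
print(hasattr(ShardingStrategy.NO_SHARD.value, "all_gather_issue_limit"))  # False
```

This keeps the enum surface but changes member values, which is exactly the backward-compatibility concern raised above.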

### High-GPU Reserved Memory

#### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972)
- [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394)
- [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1)
- [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1)
- Prefetch before freeing the padded unsharded flattened parameter
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch)
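Both fairscale approaches reduce to the same bookkeeping: record each module's position the first time it runs, then on later iterations prefetch the neighbor at index +1 (forward) or -1 (backward). A minimal sketch of that idea, with invented names independent of either codebase:

```python
class PrefetchOrder:
    """Records first-seen forward order and suggests a prefetch target."""

    def __init__(self):
        self.order = []   # modules in first-seen order
        self.index = {}   # module -> its position in `order`

    def record(self, module):
        # Only the first time a module is seen defines its position.
        if module not in self.index:
            self.index[module] = len(self.order)
            self.order.append(module)

    def next_to_prefetch(self, module, offset=1):
        # Pre-forward prefetching uses offset=+1; pre-backward uses -1.
        i = self.index[module] + offset
        if 0 <= i < len(self.order):
            return self.order[i]
        return None
```

For example, after recording modules `a`, `b`, `c` in forward order, `next_to_prefetch("a")` suggests `b` and `next_to_prefetch("b", offset=-1)` suggests `a`.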

#### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052)
- [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431)
- [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522)
- [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1)
- [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1)
- [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106)
- [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch)

#### PT-D FSDP Approach
- In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter
- In `_unshard()`, before actually unsharding, check whether the number of pending free events exceeds the limit; if so, synchronize on the earliest event, blocking the CPU thread until that event completes
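The two steps above amount to a bounded queue of free events: `_reshard()` enqueues, and `_unshard()` blocks on the oldest event once the queue is full. A sketch of that logic with the CUDA event abstracted behind a `synchronize()` method (hypothetical names; the actual FSDP internals differ):

```python
from collections import deque


class FreeEventLimiter:
    """Caps how many all-gathers may be issued ahead of their frees."""

    def __init__(self, limit):
        self.limit = limit    # e.g. all_gather_issue_limit=2; None disables
        self.events = deque()

    def record_free(self, event):
        # Called after freeing the padded unsharded flattened parameter;
        # `event` would be a torch.cuda.Event recorded at the free.
        self.events.append(event)

    def wait_if_needed(self):
        # Called before issuing the next all-gather: once the number of
        # outstanding free events reaches the limit, block the CPU thread
        # on the earliest one.
        if self.limit is not None and len(self.events) >= self.limit:
            self.events.popleft().synchronize()
```

Blocking the CPU thread (rather than a stream wait) is what keeps the caching allocator from reserving memory for all-gathers issued arbitrarily far ahead.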


#### T5 (500M), 2 Nodes, 16 A100 GPUs, Batch Size 256

<details>
  <summary> `all_gather_issue_limit=None` </summary>
  
![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png)

Peak GPU reserved memory: 6784 MB = 6.784 GB
Time / batch: 3.4 s

</details>

<details>
  <summary> `all_gather_issue_limit=2` </summary>
  
![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png)

Peak GPU reserved memory: 5846 MB = 5.846 GB
Time / batch: 3.4 s

</details>


awgu committed Aug 22, 2022
2 parents ae67c2d + 536474e commit 746d0b3
Showing 1 changed file with 8 additions and 12 deletions: `torch/distributed/fsdp/fully_sharded_data_parallel.py`
```diff
@@ -473,42 +473,38 @@ def record_post_forward(self, handles: List[FlatParamHandle]) -> None:
         iteration with the expectation that the recorded order is reset in
         :meth:`next_iter`.
         """
-        if not handles:
-            return
         handles_key = tuple(handles)
-        if handles_key in self.handles_to_post_forward_order_index:
+        if handles_key and handles_key in self.handles_to_post_forward_order_index:
             return
         index = len(self.handles_post_forward_order)
-        self.handles_to_post_forward_order_index[handles_key] = index
+        if handles_key:
+            self.handles_to_post_forward_order_index[handles_key] = index
         self.handles_post_forward_order.append(handles_key)

     def record_pre_forward(self, handles: List[FlatParamHandle]) -> None:
         """
         Records ``handles`` in the pre-forward order on the first iteration,
         where ``handles`` should be a group of handles used in the same
-        module's forward. If ``handles`` is empty, then it is omitted.
+        module's forward.
         If the distributed debug level is at least INFO, then this additionally
         checks the execution order across ranks. See :meth:`_check_order` for
         details.
         """
-        # TODO (awgu): For now, we exclude modules with no parameters from the
-        # order, which is different from the existing implementation.
-        if not handles:
-            return
         handles_key = tuple(handles)
-        if self._checking_order:
+        if self._checking_order and handles_key:
             self._check_order(handles_key)
         # Fix the order after the first iteration
         # TODO (awgu): For now, only record the first usage of a module, which
         # is consistent with the existing implementation.
         if (
             not self.is_first_iter
-            or handles_key in self.handles_to_pre_forward_order_index
+            or (handles_key and handles_key in self.handles_to_pre_forward_order_index)
         ):
             return
         index = len(self.handles_pre_forward_order)
-        self.handles_to_pre_forward_order_index[handles_key] = index
+        if handles_key:
+            self.handles_to_pre_forward_order_index[handles_key] = index
         self.handles_pre_forward_order.append(handles_key)

     def _check_order(self, handles_key: Tuple[FlatParamHandle, ...]) -> None:
```
