Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "[FSDP] Add limiter using CUDA events"
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
- Loading branch information