-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix for high gpu reserved memory #972
base: main
Are you sure you want to change the base?
Conversation
Thanks Naman! Looks great and we should definitely incorporate. Couple of questions:
|
@@ -1978,8 +2001,8 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: | |||
|
|||
if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype): | |||
self._free_fp16_param_shard([p]) | |||
|
|||
torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious about this change? Why do we condition on the boolean?
|
cc @zhaojuanmao |
@ngoyal2707 are you disabling multiple process groups for reduce_scattter and all_gather in the backward, instead, using the backward prefetch approach here? PyTorch version of FSDP is doing the backward prefetch right now. This will better avoid over prefetching more layers in the backward pass using multiple process group approach |
but will it be conflict with multi process group approach? will the backward prefetch be a config? |
Tried this on AWS. Find it doubles roughly WPS for my particular workload. (1700 -> 3200) |
@@ -2047,6 +2070,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: | |||
# Storage object and unshard it in-place. For now, just resize | |||
# the Storage to 0 to save memory. | |||
free_storage_(p._full_param_padded) | |||
torch.cuda.current_stream().synchronize() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ngoyal2707 I guess this helped releasing GPU reserved memory...
For @stephenroller models, were your training slowed down by high reserved memory and slow CUDA cache allocator? if so, maybe this change helped a lot...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but I'm wondering whether this will result in regression for other models which are not suffered from high reserved memory and slow CUDA cache allocator issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah my jobs are fine tuned of a large pre trained model. We consistently saw on both 40gb and 80gb nodes that the memory reserved was always 99%
if ( | ||
self._fsdp_forward_ordering is not None | ||
and self._my_fsdp_instance_idx is not None and self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1 | ||
): | ||
self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._rebuild_full_params( | ||
wait_for_all_gather=False | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice to add this forward prefetching!
Let's push this through |
To push this through we need an owner who can:
Let me know if anyone is up for taking this on! |
@anj-s and @stephenroller, I suspected the speedup is possibly mostly from this single sentence change "torch.cuda.current_stream().synchronize()" for @stephenroller's use case, possibly it is good to verify that. e..g, @stephenroller, just use FairScale master branch + this change "torch.cuda.current_stream().synchronize()" to your use case, and see whether there is a speed up. But this change is better to be optional, not be default, otherwise other FSDP use cases that are not suffered from high GPU reserved memory issue possibly will get regression. |
Can you explain a little bit about why you think there might be a regression? |
I'm not entirely sure about this. We also have large scale, not latency bound pretraining, and this change increased WPS there--when ideal reserved memory is only 25% of the node. What's clear in both cases was we were thrashing the allocator aggressively. This could maybe be a regression for places where the full state and activations fit in memory, and the network is able to fetch faster than the compute can utilize the state. But that case is better addressed with zero2 anyway, no? |
@ngoyal2707 is clearly the owner and should push this through. |
"What's clear in both cases was we were thrashing the allocator aggressively." ==> that mostly means "torch.cuda.current_stream().synchronize()" helps here, because CUDA allocation is a async operation, it could try to allocate a large storage before completely free full parameters here. So it is trashing. Adding this sync for your use case basically makes sure full parameters are freed before allocating. This helped reducing CUDA allocator trash. For the use cases that are not suffering from slow cudaAllocator, this blocking operation will wait for all pending async operations and reduce parallelizations, and thus could result in regression, no? |
So for unblocking this effort, suggestion here is:
|
So I agree with everyone here. I should be the owner but honestly I still don't know if this is the best approach here and I want opinion of people who understands cuda and pytorch caching allocator better The core reason of the issue is due to unbounded queue of launching kernels combined with multiple cuda streams. One very simple pytorch example of this is following:
above code outputs at end:
That means, we only reserve 16MB of memory for the 4M fp32 tensor cause its only been used by single stream, so pytorch caching allocator can just assign every time the same memory block as operations within stream are sequential. comparing above to following:
This outputs:
i.e. pytorch reserves 16MB tensor for every new allocation cause each memory block is being used by both current stream and So in FSDP case, So I think there possibly can be some regression, specially to very small models cause GPU can be idle while CPU thread is launching new kernels. |
Want to add a data point about the performance improvement on AWS brought by this PR : )
More context: With similar training setup we were getting wps ~12000 on Azure. For 175B model fine-tuning, this PR helps us close a lot the speed gap between Azure and AWS. |
adding @mrshenli who is looking into pluggable cudaCacheAllocator, and hopefully can resolve this issue in general |
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
… CUDA events" ### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
### High-GPU Reserved Memory As I was generalizing PyTorch FSDP's prefetching, I inadvertently stumbled into the high GPU reserved memory issue (first surfaced from Fairscale FSDP). #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### [WIP] PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### Observations & Discussion 1. Blocking the CPU thread is critical for preventing the over-all-gathering. 2. For static graphs, pre-forward prefetching should use the pre-forward order, and post-forward prefetching should use the post-forward order. - Fairscale and PT-D FSDPs all follow this. - Post-forward prefetching is more conservative than pre-forward prefetching. Post-forward prefetching targets sibling-level prefetching only. Pre-forward prefetching follows the execution order. - We should investigate the performance difference between pre- and post-forward prefetching. - It seems that the post-forward prefetching is motivated by having the `current_stream().synchronize()` _after_ the unsharded parameter is freed. 3. For static graphs, backward prefetching should use the pre-backward order. 4. A mistargeted prefetch may be either (1) targeting an already unsharded parameter, (2) targeting a not yet unsharded, or (3) targeting an already resharded parameter. - Since `_rebuild_full_params()` has side effects (e.g. for mixed precision and CPU offloading), even (1) may cause performance degradation. - The previous PR makes `FullyShardedDataParallel._unshard()` a no-op for sharded strategies if already unsharded. This addresses case (1). - We may want to add some logic to guard against case (3). #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `allow_over_all_gather=True` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `allow_over_all_gather=False` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
This PR tackles the high GPU reserved memory issue for FSDP. Currently: - This adds an argument `all_gather_issue_limit: Optional[int]` to the FSDP constructor, where `None` disables the limiter and a positive integer enables the limiter. - If enabled, this limiter is only meaningful for `FULL_SHARD` and not for `SHARD_GRAD_OP` and `NO_SHARD` (since (1) we track free events, not all-gather events and (2) for the non-`FULL_SHARD` strategies, the reserved memory will inevitably be used). - Given this, ideally each sharding strategy can have its own attributes, and we can move this `all_gather_issue_limit` to only be an attribute for `FULL_SHARD`. This idea also applies to `HYBRID_SHARD` since one option then is to pass the second process group as an attribute there. - I want to discuss this since this does not seem backward compatible. I am not sure that with [enums](https://stackoverflow.com/questions/12680080/python-enums-with-attributes), we can have different attributes per enum. ### High-GPU Reserved Memory #### [Fairscale FSDP Approach 1](facebookresearch/fairscale#972) - [Record pre-forward order](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1392-L1394) - [Use pre-forward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1406-L1412) (pre-forward order index + 1) - [Use pre-backward prefetching](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1502-L1507) (pre-forward order index - 1) - Prefetch before freeing the padded unsharded flattened parameter - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/191553190d73a5ef4a48687c889d4b1d94532135/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2073) (regardless of prefetch) #### [Fairscale FSDP Approach 2](facebookresearch/fairscale#1052) - [Record post-forward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1430-L1431) - [Record pre-backward order](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1521-L1522) - [Use post-forward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1434) (post-forward index + 1) - [Use post-backward prefetching](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L1675) (pre-backward index + 1) - [Prefetch after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2106) - [Synchronize the current stream after freeing the padded unsharded flattened parameter](https://github.com/facebookresearch/fairscale/blob/7d46cba0ac2bc2d69922d75a454c08edf07bb6ce/fairscale/nn/data_parallel/fully_sharded_data_parallel.py#L2132) (regardless of prefetch) #### PT-D FSDP Approach - In `_reshard()`, record a CUDA event after freeing the padded unsharded flattened parameter - In `_unshard()`, before actually unsharding, check if the number of saved free events exceeds a max number and if so, synchronize the earliest event, blocking the CPU thread until that event completes #### T5 (500M) 2 Nodes 16 A100 GPUs 256 Batch Size <details> <summary> `all_gather_issue_limit=None` </summary> ![Screen Shot 2022-08-16 at 4 51 25 PM](https://user-images.githubusercontent.com/31054793/184982990-166e97e9-b0af-4bd7-ae9a-2716bf5b8f48.png) Peak GPU reserved memory: 6784 MB = 6.784 GB Time / batch: 3.4 s </details> <details> <summary> `all_gather_issue_limit=2` </summary> ![Screen Shot 2022-08-16 at 4 51 14 PM](https://user-images.githubusercontent.com/31054793/184983007-5e81ae54-fcb0-4a06-a4af-73f0e52b5949.png) Peak GPU reserved memory: 5846 MB = 5.846 GB Time / batch: 3.4 s </details> [ghstack-poisoned]
Before this change:
Training 175b on 1024 A100s:
After the change:
No idea if the approach is perfect or not, but for above use case, I am getting slightly better WPS (this most likely could be
a flake though, and I'd expect them to match), much better cuda gb reserved (which I can almost explain now, where its being used) and ppl matches.
plus I can train much much larger models now. (will post additional details and results in post).
Will do additional testing