diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index bf3e6c018..0cff54acb 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -332,6 +332,7 @@ def __init__( offload_config: Optional[OffloadConfig] = None, state_dict_on_rank_0_only: bool = False, gradient_predivide_factor: Optional[float] = None, + zero2_process_group: Optional[ProcessGroup] = None, ): try: import torch._C @@ -380,6 +381,9 @@ def __init__( "parameter uses all the available ranks for the optimal performance." ) self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward + + self.zero2_process_group = zero2_process_group + self.disable_reshard_on_root = disable_reshard_on_root self.mixed_precision = mixed_precision self.fp32_reduce_scatter = fp32_reduce_scatter @@ -518,6 +522,9 @@ def __init__( if isinstance(m, FullyShardedDataParallel): m._free_ssd_offload() + if self.zero2_process_group is not None: + assert not self.move_params_to_cpu + def _get_gradient_predivide_factor(self, world_size: int) -> float: factor: int = 1 while world_size % factor == 0 and world_size / factor > factor: @@ -1419,7 +1426,10 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: outputs = self.module(*args, **kwargs) if self.reshard_after_forward: - self._free_full_params() + if self.zero2_process_group is not None: + self._zero2_shard_to_smaller_group() + else: + self._free_full_params() if self.mixed_precision or self.move_params_to_cpu: self._free_fp16_param_shard() @@ -1499,7 +1509,10 @@ def _pre_backward_hook(*unused: Any) -> None: # idempotent. So in case they are called unnecessarily, they don't incur much # overhead. if self.reshard_after_forward: - self._rebuild_full_params() + if self.zero2_process_group is not None: + self._zero2_rebuild_full_params() + else: + self._rebuild_full_params() if ( self.reshard_after_forward and self._fsdp_forward_ordering is not None @@ -2006,6 +2019,126 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) return output_tensors + + @torch.no_grad() + def _zero2_rebuild_full_params(self, force_full_precision: bool = False, wait_for_all_gather = True) -> Optional[List[Tuple[torch.Tensor, bool]]]: + """ + Gather all shards of params. + + Note, this is idempotent if full params are already gathered. Callers + assume the idempotency. So please keep it that way. + + Args: + force_full_precision (bool, Optional): by default params will be gathered + in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is + ``True``, in which case they will be gathered in full precision + (e.g., FP32), possibly in fresh storage. The parameter that's being + rebuilt will end up in full precision as well. + + Returns: + A list of tuples, where the first element is the full-sized param + and the second element is a bool indicating if it's safe for the + caller to free the full-sized param. This will be ``None`` if + ``force_full_precision=False`` and the full params are already gathered. + """ + output_tensors: List[Tuple[torch.Tensor, bool]] = [] + + def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None: + """ + Helper function to update p.data pointer. + + Args: + custom_output_tensor (torch.Tensor, Optional): if not None, this + tensor contains the data we just gathered. + """ + if custom_output_tensor is not None: + assert p._is_sharded + p.data = custom_output_tensor + output_tensors.append((p.data, True)) + elif not p._is_sharded: + if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision: + assert p._fp16_shard is not None + p.data = p._fp16_shard + output_tensors.append((p.data, True)) + else: + # Here p.data == p._fp32_shard, so it's not safe to free. + output_tensors.append((p.data, False)) + else: + p.data = p._full_param_padded + output_tensors.append((p.data, True)) + # Trim any padding and reshape to match original size. + p.data = p.data[: p._orig_size.numel()].view(p._orig_size) + + if self._has_shared_params: + # self.has_full_params flag can be out of sync if a shared param is + # sharded by another FSDP instance. An example is that in eval case + # with reshard_after_forward=False but the sharing instance has + # reshard_after_forward=True. Then, on the second forward, the + # other instance can shard the shared param and but this instance + # can mistakenly think the full param is already gathered from the + # has_full_params flag. + # + # Therefore, we update the flag accordingly here. + self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params) + + # Early exit if we already have full params and don't need full precision. + if self.has_full_params and not force_full_precision: + if wait_for_all_gather: + torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) + for p in self.params: + update_p_data() + return output_tensors + + self.has_full_params = True + + with torch.cuda.stream(self._streams["all_gather"]): + + for p in self.params: + if not p._is_sharded: # e.g., when world_size == 1 + update_p_data() + else: + # Skip if already built. Only shared param can be rebuilt multiple times. + # A corner case is p._orig_size = (1,), which means the shape equality is + # not a perfect check. But we assume we don't share a param with shape (1,). + if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared: + continue + # If self.move_params_to_cpu and force_full_precision, we need to cast + # the FP32 CPU param to CUDA for the all-gather. + p_data = p.data.to(p._full_param_padded.device, non_blocking=True) + + p_size = p._full_param_padded.size() + assert p_size.numel() % self.world_size == 0 + if self.mixed_precision and force_full_precision: + # Allocate fresh tensor in full precision since we are in + # mixed precision and full precision rebuild is asked. + output_tensor = p_data.new_zeros(p_size) + else: + if p._full_param_padded.storage().size() != p_size.numel(): + # Allocate based on full size from all shards. + alloc_storage_(p._full_param_padded, size=p_size) + output_tensor = p._full_param_padded + + # Fill output_tensor with (p.data for each shard in self.world_size) + if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives: + # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather. + dist._all_gather_base(output_tensor, p._zero2_fp16_shard , group=self.zero2_process_group) + else: + chunks = list(output_tensor.chunk(self.world_size)) + dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group) + + # Set p.data = output_tensor (with padding trimmed) + update_p_data(output_tensor) + + if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision: + self._free_zero2_param_shard([p]) + + if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype): + self._free_zero2_param_shard([p]) + if wait_for_all_gather: + torch.cuda.current_stream().wait_stream(self._streams["all_gather"]) + return output_tensors + + @torch.no_grad() def _use_full_params(self) -> None: """ @@ -2074,6 +2207,38 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None: free_storage_(p._full_param_padded) torch.cuda.current_stream().synchronize() + + def _zero2_shard_to_smaller_group(self, params: Optional[List[Parameter]] = None): + if params is None: + params = self.params + self.has_full_params = False + current_stream = torch.cuda.current_stream() + for p in params: + if not p._is_sharded: # e.g., world_size == 1 + if self.mixed_precision or self.move_params_to_cpu: + self._free_fp16_param_shard([p]) + continue + # Cases for when zero2 world size > 1 but less than zero3 size + zero2_world_size = dist.get_world_size(self.zero2_process_group) + zero2_rank = dist.get_rank(self.zero2_process_group) + chunks = p._full_param_padded.chunk(zero2_world_size) + + p._zero2_fp16_shard = torch.empty_like(chunks[zero2_rank]) + p._zero2_fp16_shard.copy_(chunks[zero2_rank]) + + # Don't let PyTorch reuse this memory until all work in the current + # stream is complete. + p._full_param_padded.record_stream(current_stream) + # There may be external references to the Tensor Storage that we + # can't modify, such as references that are created by + # ctx.save_for_backward in the forward pass. Thus when we + # unshard parameters, we should reuse the original Tensor + # Storage object and unshard it in-place. For now, just resize + # the Storage to 0 to save memory. + free_storage_(p._full_param_padded) + torch.cuda.current_stream().synchronize() + + def local_metadata_dict(self) -> Dict[str, Any]: """ Get the information needed to reconstruct the model from shards offline. @@ -2238,6 +2403,19 @@ def _free_fp16_param_shard(self, params: Optional[List[Parameter]] = None) -> No p._fp16_shard.record_stream(current_stream) free_storage_(p._fp16_shard) + @torch.no_grad() + def _free_zero2_param_shard(self, params: Optional[List[Parameter]] = None) -> None: + """Free storage for FP16 shards for a list of params.""" + if params is None: + params = self.params + current_stream = torch.cuda.current_stream() + for p in params: + if p._zero2_fp16_shard is not None: + # _fp16_shard is allocated in "fp32_to_fp16" stream, so we can't + # free it until the work in the current stream completes. + p._zero2_fp16_shard.record_stream(current_stream) + free_storage_(p._zero2_fp16_shard) + def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None: """Assert we are in the given state.""" # Since assert can be turned off and this error checking