diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index 952031351..be874e125 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -311,7 +311,7 @@ def __init__( module: nn.Module, process_group: Optional[ProcessGroup] = None, # The type for the process_group_reduce_scatter only can be either ProcessGroup or ProcessGroupName - process_group_reduce_scatter: Any = ProcessGroupName.reduce_scatter, + process_group_reduce_scatter: Any = ProcessGroupName.default, reshard_after_forward: bool = True, disable_reshard_on_root: bool = True, mixed_precision: bool = False,