bf16 changes and attribute for cpu activations
ngoyal2707 committed Jun 28, 2022
1 parent 1915531 commit ba38cf3
Showing 3 changed files with 22 additions and 5 deletions.
14 changes: 9 additions & 5 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1386,8 +1386,9 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:

         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
+        is_bf16 = self.compute_dtype == torch.bfloat16
         if self._is_root and self.mixed_precision:
-            args, kwargs = cast_floats_to_right_precision(True, True, *args, **kwargs)
+            args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)

         if self not in self._fsdp_forward_ordering:
             self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
@@ -1397,7 +1398,7 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         # no_grad is not used because the input might be for a non-root instance,
         # which mean autograd needs to go through the conversion.
         if self.force_input_to_fp32 and not self.mixed_precision:
-            args, kwargs = cast_floats_to_right_precision(False, False, *args, **kwargs)
+            args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)

         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
@@ -2054,6 +2055,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         if params is None:
             params = self.params
         self.has_full_params = False
+
         current_stream = torch.cuda.current_stream()
         for p in params:
             if not p._is_sharded:  # e.g., world_size == 1
@@ -2182,7 +2184,6 @@ def consolidate_shard_weights(
                 for n, t, s in zip(names, full_param.split(numels), shapes):
                     out_state_dict_key = ".".join([fsdp_path, n]) if fsdp_path else n
                     consolidated_weights[out_state_dict_key] = t.view(s)
-
             # copy shared parameters
             for src_path, dest_path in metadata["shared_param_info"]:
                 consolidated_weights[dest_path] = consolidated_weights[src_path]
@@ -2462,15 +2463,18 @@ def _get_default_cuda_device(module: nn.Module) -> torch.device:
     return torch.device("cuda")


-def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
+def cast_floats_to_right_precision(to_fp16: bool, no_grad: bool, is_bf16: bool, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
     Cast floating point Tensors in *args or **kwargs to FP16 or FP32 if they are not.
     We also retain the requires_grad flag so that casting doesn't affect the autograd graph.
     """

     def fn_fp16(x: torch.Tensor) -> torch.Tensor:
         if x.dtype is torch.float32:
-            y = x.half()
+            if is_bf16:
+                y = x.bfloat16()
+            else:
+                y = x.half()
             if x.is_leaf:
                 y.requires_grad = x.requires_grad
             return y
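For reference, a minimal standalone sketch of the new cast path, assuming the root-input case where the conversion runs under no_grad (cast_fp32_input is a hypothetical helper for illustration, not the fairscale implementation):

import torch

def cast_fp32_input(x: torch.Tensor, is_bf16: bool) -> torch.Tensor:
    # Mirror of the fn_fp16 branch above: cast fp32 tensors to bf16 or fp16.
    if x.dtype is not torch.float32:
        return x
    with torch.no_grad():  # root inputs are converted outside the autograd graph
        y = x.bfloat16() if is_bf16 else x.half()
    # Carry over requires_grad so the cast input still participates in autograd.
    if x.is_leaf:
        y.requires_grad = x.requires_grad
    return y

# An fp32 leaf that requires grad keeps that flag after the cast.
t = torch.randn(4, requires_grad=True)
out = cast_fp32_input(t, is_bf16=True)
assert out.dtype is torch.bfloat16 and out.requires_grad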
2 changes: 2 additions & 0 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -372,6 +372,7 @@ def _unflatten_params_as_views(self) -> None:
         ps = self.get_param_views()
         param_views = []
         for (_, m, n), p in zip(self._param_infos, ps):
+            setattr(p, '_fsdp_weight', True)
             setattr(m, n, p)  # This will set as plain attr
             param_views.append(p)

@@ -382,6 +383,7 @@ def _unflatten_params_as_views(self) -> None:
         for (_, _, m, n, shared_m, shared_n) in self._shared_param_infos:
             setattr(m, n, getattr(shared_m, shared_n))

+
     @contextmanager
     def unflatten_params(self, flat_params: Optional[List[Tensor]] = None) -> Generator:
         """
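The commit title suggests the new _fsdp_weight marker is there so CPU-activation handling can tell unflattened weight views apart from activation tensors. A hedged sketch of how downstream code might consume such a marker (maybe_offload_activation is hypothetical, not a fairscale API):

import torch

def maybe_offload_activation(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical consumer of the marker set in _unflatten_params_as_views:
    # tensors tagged as FSDP weight views are left alone, while plain
    # activation tensors may be moved to CPU.
    if getattr(t, "_fsdp_weight", False):
        return t
    return t.cpu()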
11 changes: 11 additions & 0 deletions tests/nn/data_parallel/test_fsdp.py
@@ -211,6 +211,17 @@ def test_mixed_precision_autocast_fp32_compute(self):
             expected_buffer_type=torch.float32,
         )

+    def test_mixed_precision_bfloat16(self):
+        self._spawn_test_case(
+            {"mixed_precision": True, "compute_dtype": torch.bfloat16},
+            True,  # autocast enabled
+            torch.bfloat16,  # expected_input_dtype
+            torch.bfloat16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.bfloat16,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
+        )
+
     def test_fp32_reduce_scatter(self):
         self._spawn_test_case(
             {"mixed_precision": True, "fp32_reduce_scatter": True},
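The new test exercises the same configuration a user would pass. A minimal usage sketch, assuming a distributed process group is already initialized and the GPU supports bfloat16:

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Wrap a module with bf16 mixed precision: parameters and inputs are cast to
# torch.bfloat16 for compute, matching the test configuration above.
model = nn.Linear(16, 16).cuda()
fsdp_model = FSDP(model, mixed_precision=True, compute_dtype=torch.bfloat16)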
