diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index dee596e9a..7de3d0251 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -41,8 +41,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
+import transformer_engine.pytorch as te
 
 from fairscale.nn.misc import FlattenParamsWrapper
+from fairscale.nn.misc.flatten_params_wrapper import FlatParameter
 from fairscale.nn.wrap import auto_wrap, config_auto_wrap_policy, enable_wrap
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import (
@@ -55,6 +57,8 @@
 from fairscale.utils.params import calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
+from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, DType, FP8FwdTensors
+from transformer_engine.pytorch.fp8 import amax_and_scale_update, FP8GlobalStateManager
 
 from . import fsdp_optim_utils as ou
 
@@ -149,6 +153,16 @@ class OffloadConfig:
     # Path to the directory for storing parameters offloaded to disk.
     dir: Optional[str] = None
 
+
+def _is_te_module_with_weights(m: nn.Module) -> bool:
+    """Return True if ``m`` is a ``te.Linear``, ``te.LayerNormLinear`` or ``te.LayerNormMLP``."""
+    return isinstance(m, (te.Linear, te.LayerNormLinear, te.LayerNormMLP))
+
+
+def _is_fp8_dtype(dtype: torch.dtype) -> bool:
+    """Return True if ``dtype`` is ``torch.float8_e5m2`` or ``torch.float8_e4m3fn``."""
+    return dtype in [torch.float8_e5m2, torch.float8_e4m3fn]
+
 
 class FullyShardedDataParallel(nn.Module):
     """
@@ -487,9 +501,13 @@ def __init__(
         non_flatten_params = params
         param_name_groups = [[n] for n in param_names]
         if self.flatten_parameters:
-            to_be_flatten_params = [params]
-            non_flatten_params = []
-            param_name_groups = [param_names]
+            # Keep "norm_weight" params out of the flat parameter; they form a single un-flattened group.
+            to_be_flatten_params = [[p for p, n in zip(params, param_names) if "norm_weight" not in n]]
+            non_flatten_params = [p for p, n in zip(params, param_names) if "norm_weight" in n]
+            param_name_groups = [
+                [n for n in param_names if "norm_weight" not in n],
+                [n for n in param_names if "norm_weight" in n],
+            ]
         del param_names
 
         self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
@@ -558,6 +576,10 @@ def __init__(
         self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
         self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
 
+    def _is_fp8_dtype(self) -> bool:
+        """Return True if ``self.compute_dtype`` is an FP8 dtype."""
+        return _is_fp8_dtype(self.compute_dtype)
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -687,7 +709,7 @@ def _cast_buffers(
     @property
     def params_with_grad(self) -> List[Parameter]:
         """[p for p in self.parameters() if p.grad is not None]"""
-        return [p for p in self.parameters() if p.grad is not None]
+        return [p for p in self.parameters() if p.grad is not None or getattr(p, "main_grad", None) is not None]
 
     @torch.no_grad()
     def clip_grad_norm_(
@@ -790,7 +812,9 @@ def _shard_parameters_(self) -> None:
                 assert p.dtype == torch.float32
 
             # If world_size is 1, then we all-reduce grads instead of sharding.
-            p._is_sharded = self.world_size > 1
+            p._is_sharded = self.world_size > 1 and (
+                not self._is_fp8_dtype() or isinstance(p, FlatParameter)
+            )
             p._orig_size = p.data.size()
 
             if not p._is_sharded:
@@ -1174,7 +1198,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                         non_shared_params
                     ), f"{len(full_tensors)} vs. {len(non_shared_params)}"
                     for p, (full_tensor, safe_to_free) in zip(non_shared_params, full_tensors):
-                        if not volatile:
+                        if not volatile and p._is_sharded:
                             # Copy any changes made to the full params back into
                             # the corresponding local shards.
                             local_shard, _ = self._get_shard(full_tensor)
@@ -1294,7 +1318,15 @@ def _init_param_attributes(self, p: Parameter) -> None:
             # storage to size 0 at init (here) and re-materialize (by copying
             # from _fp32_shard) as needed. If offloading params to CPU, the
             # dtype of the fp16 shard will depend on the *`compute_dtype`*.
-            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+            if self._is_fp8_dtype() and not isinstance(p, FlatParameter):
+                # assume non flattened are precision critical like norm
+                assert not p._is_sharded
+                dtype = torch.bfloat16
+            else:
+                dtype = self.compute_dtype
+            p._fp16_shard = torch.zeros_like(
+                p._fp32_shard, device=self.compute_device, dtype=dtype
+            )
             free_storage_(p._fp16_shard)
 
         if self.mixed_precision:
@@ -1428,13 +1460,19 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
 
         # For root and mixed precision, we convert the input to FP16 (no_grad is needed for
         # the conversion).
-        is_bf16 = self.compute_dtype == torch.bfloat16
+        is_bf16 = self.compute_dtype in [
+            torch.bfloat16,
+            torch.float8_e5m2,
+            torch.float8_e4m3fn,
+        ]
         if self._is_root and self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(True, True, is_bf16, *args, **kwargs)
 
+        just_added_to_fsdp_forward_ordering = False
         if self not in self._fsdp_forward_ordering:
             self._my_fsdp_instance_idx = len(self._fsdp_forward_ordering)
             self._fsdp_forward_ordering.append(self)
+            just_added_to_fsdp_forward_ordering = True
 
         # If enabled, convert the input to FP32 if we are in full precision.
         # no_grad is not used because the input might be for a non-root instance,
@@ -1442,8 +1480,50 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         if self.force_input_to_fp32 and not self.mixed_precision:
             args, kwargs = cast_floats_to_right_precision(False, False, is_bf16, *args, **kwargs)
 
+        # need to use fp32_to_fp16 stream since _cast_fp32_param_shards_to_fp16
+        # depends on this block.
+        with torch.no_grad(), torch.cuda.stream(self._streams["fp32_to_fp16"]):
+            # Collect parameters to update scale/scale_inv before we
+            # _cast_fp32_param_shards_to_fp16 that uses fp8 scale to quantize
+            # before all-gather.
+            # These include params we prefetch all-gather.
+            params = []
+            if self._my_fsdp_instance_idx < len(self._fsdp_forward_ordering) - 1:
+                if self._my_fsdp_instance_idx == 0 and self._is_fp8_dtype():
+                    # The first FSDP instance didn't have chance to prefetch
+                    params = list(self.params)  # copy: extended below, must not alias self.params
+                if self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1]._is_fp8_dtype():
+                    # FSDP instance we'll prefetch all-gather
+                    params.extend(self._fsdp_forward_ordering[self._my_fsdp_instance_idx + 1].params)
+            elif just_added_to_fsdp_forward_ordering:
+                # In the first iteration, we didn't have chance to record
+                # fsdp_instance_idx to prefetch
+                if self._is_fp8_dtype():
+                    params = self.params
+
+            for p in params:
+                if not isinstance(p, FlatParameter):
+                    continue
+                d = {info[0]: info[1] for info in p._param_infos}
+                for n, m in d.items():
+                    # Previous iteration was grad_enabled
+                    if m.fp8_meta.get("update_amax_and_scale_fwd", False):
+                        if m.fp8_meta["recipe"].reduce_amax:
+                            FP8GlobalStateManager.copy_amax_from_global_buffer(
+                                m.fp8_meta, forward=True
+                            )
+                            # FIXME update_weight_scale_inv is only True for the first micro-batch
+                            amax_and_scale_update(m.fp8_meta, True)
+                            FP8GlobalStateManager.set_amax_buffer_key_deletion(
+                                m.fp8_meta, forward=True
+                            )
+                        else:
+                            amax_and_scale_update(m.fp8_meta, True)
+
         # All-gather full parameters. This will also transfer FP32 parameters to
         # ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
+        self.module.is_first_batch = not getattr(self, "is_not_first_batch", False)
+        self.is_not_first_batch = True
         self._rebuild_full_params()
 
         if (
@@ -1680,7 +1760,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # then subsequent hook callbacks will see POST state.
         self.assert_state([TrainingState.BACKWARD_PRE, TrainingState.BACKWARD_POST])
         self.training_state = TrainingState.BACKWARD_POST
-        if param.grad is None:
+        grad_or_main_grad = (
+            param.main_grad if getattr(param, "main_grad", None) is not None else param.grad
+        )
+        if grad_or_main_grad is None:
             return
 
         if hasattr(param, "_linked_param"):
@@ -1693,8 +1776,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             if hasattr(param._linked_param, "_is_shared") and param._linked_param._is_shared:
                 param = param._linked_param
 
-        assert param.grad is not None, param.shape
-        if param.grad.requires_grad:
+        assert grad_or_main_grad is not None, param.shape
+        if grad_or_main_grad.requires_grad:
             raise RuntimeError("FSDP only works with gradients that don't require gradients")
 
         if self._require_backward_grad_sync or self.reshard_after_forward:
@@ -1705,11 +1788,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             # bandwidth but uses more GPU memory.
             self._free_full_params([param])
 
-        if self.mixed_precision:
-            # This is a no-op if reshard_after_forward is True, since we already
-            # free the param shard when rebuilding the full params in the
-            # pre_backward_hook.
-            self._free_fp16_param_shard([param])
+        # if self.mixed_precision:
+        #     # This is a no-op if reshard_after_forward is True, since we already
+        #     # free the param shard when rebuilding the full params in the
+        #     # pre_backward_hook.
+        #     self._free_fp16_param_shard([param])
 
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
@@ -1721,15 +1804,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
+            orig_grad_data = grad_or_main_grad
 
-            if self.fp32_reduce_scatter:
-                # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+            if self.mixed_precision:
+                if self.fp32_reduce_scatter:
+                    # Cast grad to FP32.
+                    grad_or_main_grad.data = grad_or_main_grad.to(param.dtype)
+                elif self._is_fp8_dtype():
+                    # Use bf16 wgrad for fp8 weights (TODO: handle fp8 wgrad)
+                    grad_or_main_grad.data = grad_or_main_grad.to(torch.bfloat16)
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                grad_or_main_grad.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1737,7 +1824,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad
+                    param.main_grad = None
+                else:
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
@@ -1758,8 +1849,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # Currently the only way for _is_sharded to be False is if
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
-                assert self.world_size == 1
-                self._post_reduction_hook(param, param.grad.data)
+                # NOTE(review): reachable with world_size > 1 for un-sharded params; grads are not all-reduced here.
+                self._post_reduction_hook(param, grad_or_main_grad)
 
             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
@@ -1886,6 +1977,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     # again after post-backward
                     if p.shape != p._saved_grad_shard.shape:
                         self._use_fp32_param_shard([p])
+                    assert getattr(p, "main_grad", None) is None
                     if p._saved_grad_shard.dtype != p.dtype:
                         p.grad = p._saved_grad_shard.to(p.dtype)
                     else:
@@ -2136,6 +2228,16 @@ def _prep_grads_for_backward(self) -> None:
         right shape, device, accumulated values, etc.
         """
         for p in self.params:
+            # Flat params made only of TE weight modules get an fp32 main_grad buffer, viewed per original weight.
+            if isinstance(p, FlatParameter) and all(
+                _is_te_module_with_weights(info[1]) for info in p._param_infos
+            ):
+                if getattr(p, "main_grad", None) is None:
+                    p.main_grad = torch.empty_like(p, dtype=torch.float)
+                main_grad_views = p.get_param_views(p.main_grad)
+                for (_, m, n), main_grad in zip(p._param_infos, main_grad_views):
+                    getattr(m, n).main_grad = main_grad
+
             if p.grad is not None:
                 if p.grad.device != p.data.device:
                     p.grad = None
@@ -2157,6 +2259,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
+        self.is_not_first_batch = False
         self.has_full_params = False
 
         current_stream = torch.cuda.current_stream()
@@ -2203,8 +2306,8 @@ def local_metadata_dict(self) -> Dict[str, Any]:
                         backing_param_name = m.module.flat_param_names[i]
                         names, shapes, numels = m.module.metadata(i)
                     else:
-                        assert len(m._param_name_groups[i]) == 1
-                        backing_param_name = m._param_name_groups[i][0]
+                        non_flat_names = [n for g in m._param_name_groups[m._num_flatten_params :] for n in g]
+                        backing_param_name = non_flat_names[i - m._num_flatten_params]
                         names = [backing_param_name]
                         shapes = [p._orig_size]
                         numels = [p._orig_size.numel()]
@@ -2320,12 +2423,48 @@ def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = No
             for p in params:
                 assert p._fp16_shard is not None
                 alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                p._fp16_shard.copy_(
-                    # If move_params_to_cpu is True, this will be non-blocking
-                    # because _fp32_shard is pinned, otherwise it's a no-op.
-                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
-                )
-                p.data = p._fp16_shard
+                if self._is_fp8_dtype() and _is_fp8_dtype(p._fp16_shard.dtype):
+                    # fp8 quantization
+                    assert isinstance(p, FlatParameter)
+                    assert len(p._param_infos) == len(p._param_numels)
+                    numel_per_shard = p.numel()
+                    offset = -numel_per_shard * self.rank
+                    for (_, m, n), numel in zip(p._param_infos, p._param_numels):
+                        if offset + numel <= 0 or offset >= numel_per_shard:
+                            offset += numel
+                            continue
+                        assert _is_te_module_with_weights(m)
+                        fp8_dtype_forward = te.fp8.get_fp8_te_dtype(
+                            m.fp8_meta["recipe"], fprop_tensor=True
+                        )
+                        if not m.fp8_initialized:
+                            m.fp8_init(
+                                num_gemms=2 if isinstance(m, te.LayerNormMLP) else 1
+                            )
+                        begin = max(offset, 0)
+                        end = min(offset + numel, numel_per_shard)
+                        cast_to_fp8(
+                            p._fp32_shard[begin:end],
+                            m.fp8_meta["scaling_fwd"],
+                            FP8FwdTensors.GEMM2_WEIGHT
+                            if n == "fc2_weight"
+                            else FP8FwdTensors.GEMM1_WEIGHT,
+                            fp8_dtype_forward,
+                            out=p._fp16_shard[begin:end],
+                        )
+                        offset += numel
+                    p.data = p._fp16_shard.view(
+                        torch.float8_e4m3fn
+                        if fp8_dtype_forward == DType.kFloat8E4M3
+                        else torch.float8_e5m2
+                    )
+                else:
+                    p._fp16_shard.copy_(
+                        # If move_params_to_cpu is True, this will be non-blocking
+                        # because _fp32_shard is pinned, otherwise it's a no-op.
+                        p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
+                    )
+                    p.data = p._fp16_shard
         torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
 
     @torch.no_grad()
diff --git a/fairscale/nn/misc/flatten_params_wrapper.py b/fairscale/nn/misc/flatten_params_wrapper.py
index 38265dd2b..2adc22396 100644
--- a/fairscale/nn/misc/flatten_params_wrapper.py
+++ b/fairscale/nn/misc/flatten_params_wrapper.py
@@ -486,7 +486,9 @@ def load_state_dict(
         return super().load_state_dict(state_dict, strict)
 
     def forward(self, *inputs: Any, **kwinputs: Any) -> Any:
-        self._unflatten_params_as_views()
+        is_first_batch = getattr(self, "is_first_batch", True)  # unset (no FSDP owner): always refresh views
+        if is_first_batch:
+            self._unflatten_params_as_views()
         return self.module(*inputs, **kwinputs)
 
     def get_param_views(self, external_data_list: Optional[List[Optional[Tensor]]] = None) -> Iterator[Tensor]: