From 603efc2340cf25d368461072589e25624d8f2072 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Mon, 22 Jul 2024 19:52:15 -0700
Subject: [PATCH] change x, w, dL_dY variable names to input, weight, grad_output (#323)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/323

The following naming scheme matches the rest of PyTorch better:

```Python
# forward
output = input @ weight_t

# backward
grad_input = grad_output @ weight
grad_weight = input_t @ grad_output
```

This PR changes all the previous references to `x`, `w`, `dL_dY` to match the naming scheme above. (A brief plain-PyTorch sketch of this convention is appended after the patch.)

Reviewed By: drisspg

Differential Revision: D60072596

fbshipit-source-id: 74e89d154a698a0dae8c92f39e2267409b151642
---
 .github/workflows/ufmt.yml                    |   3 +
 README.md                                     |  10 +-
 benchmarks/bench_linear_float8.py             |  40 ++---
 benchmarks/bench_multi_gpu.py                 |   6 +-
 benchmarks/profile_linear_float8.py           |  27 ++--
 float8_experimental/float8_linear.py          | 149 ++++++++++--------
 float8_experimental/float8_linear_utils.py    | 124 ++++++++-------
 float8_experimental/float8_tensor.py          |   2 +
 float8_experimental/float8_tensor_parallel.py |   4 +-
 float8_experimental/fsdp_utils.py             |   3 +-
 test/test_base.py                             |  99 ++++++------
 test/test_compile.py                          |  78 ++++-----
 test/test_dtensor.py                          |   6 +-
 test/test_fsdp.py                             |   8 +-
 test/test_fsdp2/test_fsdp2.py                 |  44 +++---
 test/test_fsdp2/test_fsdp2_common.py          |   6 +-
 test/test_fsdp_compile.py                     |   6 +-
 test/test_inference_flows.py                  |  12 +-
 test/test_numerics_integration.py             |  27 ++--
 19 files changed, 349 insertions(+), 305 deletions(-)

diff --git a/.github/workflows/ufmt.yml b/.github/workflows/ufmt.yml
index f4fa6be2..4a5f52be 100644
--- a/.github/workflows/ufmt.yml
+++ b/.github/workflows/ufmt.yml
@@ -23,4 +23,7 @@ jobs:
         pip install black==23.3.0 usort==1.0.6 ufmt==2.1.0 libcst==1.0.1
     - name: Analyzing the code with ufmt
       run: |
+        ufmt format .
+        git diff
+        git restore .
         ufmt check .
diff --git a/README.md b/README.md
index 3e8e304c..9436696a 100644
--- a/README.md
+++ b/README.md
@@ -28,9 +28,9 @@ pip install -e ".[dev]"

 # Single GPU User API

-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).

-## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`

 This is the most accurate recipe as every tensor is scaled dynamically.

@@ -95,9 +95,9 @@ m = Model(...)
 # type
 swap_linear_with_float8_linear(
     m,
-    scaling_type_x=TensorScalingType.DELAYED,
-    scaling_type_w=TensorScalingType.DELAYED,
-    scaling_type_dL_dY=TensorScalingType.DELAYED,
+    scaling_type_input=TensorScalingType.DELAYED,
+    scaling_type_weight=TensorScalingType.DELAYED,
+    scaling_type_grad_output=TensorScalingType.DELAYED,
 )

 # optional: use FSDP.
Note that workarounds gated with config.enable_amax_init and diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index 967de570..a931ee88 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -95,16 +95,16 @@ def main( n_limit: Optional[int] = None, fast_accum_filter: Optional[bool] = None, shape_name_filter: Optional[str] = None, - scaling_type_x: str = "dynamic", - scaling_type_w: str = "dynamic", - scaling_type_dL_dY: str = "dynamic", + scaling_type_input: str = "dynamic", + scaling_type_weight: str = "dynamic", + scaling_type_grad_output: str = "dynamic", ): device = "cuda" print(f"Compile is set to | {compile}") - scaling_type_x = TensorScalingType(scaling_type_x) - scaling_type_w = TensorScalingType(scaling_type_w) - scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY) + scaling_type_input = TensorScalingType(scaling_type_input) + scaling_type_weight = TensorScalingType(scaling_type_weight) + scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) # LLaMa 2 70B single-node weight shapes # assumes fused attn.wqkv and ffn.w13 @@ -136,9 +136,9 @@ def main( linear_float8 = Float8Linear.from_float( copy.deepcopy(linear_ref), emulate=False, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, + scaling_type_input=scaling_type_input, + scaling_type_weight=scaling_type_weight, + scaling_type_grad_output=scaling_type_grad_output, ) scaling_repr = linear_float8.scaling_repr() @@ -153,7 +153,9 @@ def main( ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward() def float8_forw_backward(): - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): sync_float8_amax_and_scale_history(linear_float8) linear_float8(input_tensor).sum().backward() @@ -278,18 +280,18 @@ def invoke_main() -> None: parser.add_argument("-n", "--n_limit", type=int, required=False) parser.add_argument("--fast_accum_filter", type=bool, required=False) parser.add_argument("--shape_name_filter", type=str, required=False) - parser.add_argument("--scaling_type_x", type=str, required=False) - parser.add_argument("--scaling_type_w", type=str, required=False) - parser.add_argument("--scaling_type_dL_dY", type=str, required=False) + parser.add_argument("--scaling_type_input", type=str, required=False) + parser.add_argument("--scaling_type_weight", type=str, required=False) + parser.add_argument("--scaling_type_grad_output", type=str, required=False) args = parser.parse_args() output_path = Path(args.output_path) if args.output_path is not None else None kwargs = {} - if args.scaling_type_x is not None: - kwargs["scaling_type_x"] = args.scaling_type_x - if args.scaling_type_w is not None: - kwargs["scaling_type_w"] = args.scaling_type_w - if args.scaling_type_dL_dY is not None: - kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY + if args.scaling_type_input is not None: + kwargs["scaling_type_input"] = args.scaling_type_input + if args.scaling_type_weight is not None: + kwargs["scaling_type_weight"] = args.scaling_type_weight + if args.scaling_type_grad_output is not None: + kwargs["scaling_type_grad_output"] = args.scaling_type_grad_output main( output_path, not args.disable_compile, diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index 00a549c4..ac196951 100644 --- a/benchmarks/bench_multi_gpu.py +++ 
b/benchmarks/bench_multi_gpu.py @@ -68,9 +68,9 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32): swap_linear_with_float8_linear( m, emulate=False, - scaling_type_x=TensorScalingType.DELAYED, - scaling_type_w=TensorScalingType.DELAYED, - scaling_type_dL_dY=TensorScalingType.DELAYED, + scaling_type_input=TensorScalingType.DELAYED, + scaling_type_weight=TensorScalingType.DELAYED, + scaling_type_grad_output=TensorScalingType.DELAYED, ) return m diff --git a/benchmarks/profile_linear_float8.py b/benchmarks/profile_linear_float8.py index 503a01a3..c7b8d38b 100644 --- a/benchmarks/profile_linear_float8.py +++ b/benchmarks/profile_linear_float8.py @@ -204,20 +204,23 @@ def profile_function( def main( profile_path_prefix: Path, compile: bool = True, - scaling_type_x: str = "dynamic", - scaling_type_w: str = "dynamic", - scaling_type_dL_dY: str = "dynamic", + scaling_type_input: str = "dynamic", + scaling_type_weight: str = "dynamic", + scaling_type_grad_output: str = "dynamic", model_type: str = "linear", dtype_filter: str = "both", ): assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported" assert dtype_filter in ("both", "float8", "bfloat16") - scaling_type_x = TensorScalingType(scaling_type_x) - scaling_type_w = TensorScalingType(scaling_type_w) - scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY) + scaling_type_input = TensorScalingType(scaling_type_input) + scaling_type_weight = TensorScalingType(scaling_type_weight) + scaling_type_grad_output = TensorScalingType(scaling_type_grad_output) scaling_repr = "_".join( - [s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)] + [ + s.short_str() + for s in (scaling_type_input, scaling_type_weight, scaling_type_grad_output) + ] ) print(f"Compile is set to | {compile}") @@ -254,9 +257,9 @@ def main( m_ref = m_ref.to(device).to(ref_dtype) extra_kwargs = { - "scaling_type_x": scaling_type_x, - "scaling_type_w": scaling_type_w, - "scaling_type_dL_dY": scaling_type_dL_dY, + "scaling_type_input": scaling_type_input, + "scaling_type_weight": scaling_type_weight, + "scaling_type_grad_output": scaling_type_grad_output, } m_float8 = copy.deepcopy(m_ref) @@ -278,7 +281,9 @@ def float8_forw_backward_wrapper(x): # inspection of the fw+bw torch.compile without the scale # syncing code # TODO(future): make this better - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): with record_function("scale_amax_and_scales"): sync_amax_history(m_float8) out = float8_forw(x) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 37cf0d5d..de7f6117 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -82,14 +82,16 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function): def forward( ctx, tensor, - fp8_amax_dL_dY, - fp8_amax_history_dL_dY, - fp8_scale_dL_dY, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, scale_fn_name, is_amax_initialized, linear_mm_config: LinearMMConfig, ): - ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY) + ctx.save_for_backward( + fp8_amax_grad_output, fp8_amax_history_grad_output, fp8_scale_grad_output + ) ctx.scale_fn_name = scale_fn_name ctx.is_amax_initialized = is_amax_initialized ctx.linear_mm_config = linear_mm_config @@ -97,26 +99,30 @@ def forward( @staticmethod def backward(ctx, go): - fp8_amax_dL_dY, 
fp8_amax_history_dL_dY, fp8_scale_dL_dY = ctx.saved_tensors + ( + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, + ) = ctx.saved_tensors scale_fn_name = ctx.scale_fn_name is_amax_initialized = ctx.is_amax_initialized _maybe_initialize_amaxes_scales_for_float8_cast( go, - fp8_amax_dL_dY, - fp8_amax_history_dL_dY, - fp8_scale_dL_dY, + fp8_amax_grad_output, + fp8_amax_history_grad_output, + fp8_scale_grad_output, scale_fn_name, e5m2_dtype, is_amax_initialized, reduce_amax=True, ) - fp8_amax_dL_dY.fill_(tensor_to_amax(go)) + fp8_amax_grad_output.fill_(tensor_to_amax(go)) res = to_fp8_no_autograd( go, - fp8_scale_dL_dY, + fp8_scale_grad_output, e5m2_dtype, linear_mm_config=ctx.linear_mm_config, gemm_input_role=GemmInputRole.DL_DY, @@ -164,9 +170,9 @@ def __init__(self, *args, **kwargs): """ Additional arguments on top of `torch.nn.Linear`'s arguments: * `delayed_scaling_recipe`: configuration for delayed scaling - * `scaling_type_x`: delayed vs dynamic scaling for `x` - * `scaling_type_w`: delayed vs dynamic scaling for `w` - * `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY` + * `scaling_type_input`: delayed vs dynamic scaling for `input` + * `scaling_type_weight`: delayed vs dynamic scaling for `weight` + * `scaling_type_grad_output`: delayed vs dynamic scaling for `grad_output` """ delayed_scaling_recipe = kwargs.pop( @@ -175,20 +181,24 @@ def __init__(self, *args, **kwargs): # Amax scales should always be kept as float32. self.always_float32_buffers = set() emulate = kwargs.pop("emulate", False) - scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DYNAMIC) - scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DYNAMIC) - scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DYNAMIC) + scaling_type_input = kwargs.pop("scaling_type_input", TensorScalingType.DYNAMIC) + scaling_type_weight = kwargs.pop( + "scaling_type_weight", TensorScalingType.DYNAMIC + ) + scaling_type_grad_output = kwargs.pop( + "scaling_type_grad_output", TensorScalingType.DYNAMIC + ) super().__init__(*args, **kwargs) - # Defines the scaling behavior of x, w, dL_dY - self.scaling_type_x = scaling_type_x - self.scaling_type_w = scaling_type_w - self.scaling_type_dL_dY = scaling_type_dL_dY + # Defines the scaling behavior of input, weight, grad_output + self.scaling_type_input = scaling_type_input + self.scaling_type_weight = scaling_type_weight + self.scaling_type_grad_output = scaling_type_grad_output # Convenience flag to skip code related to delayed scaling self.has_any_delayed_scaling = ( - self.scaling_type_x is TensorScalingType.DELAYED - or self.scaling_type_w is TensorScalingType.DELAYED - or self.scaling_type_dL_dY is TensorScalingType.DELAYED + self.scaling_type_input is TensorScalingType.DELAYED + or self.scaling_type_weight is TensorScalingType.DELAYED + or self.scaling_type_grad_output is TensorScalingType.DELAYED ) # TODO(future): have a unique recipe per buffer instead of one per @@ -200,15 +210,15 @@ def __init__(self, *args, **kwargs): # TODO(future): user level configuration of gemms self.linear_mm_config = LinearMMConfig( - # x + # input ScaledMMConfig( emulate, True if not emulate else False, False, config.pad_inner_dim ), - # w + # weight ScaledMMConfig( emulate, True if not emulate else False, False, config.pad_inner_dim ), - # dL_dY + # grad_output ScaledMMConfig(emulate, False, False, config.pad_inner_dim), ) @@ -239,9 +249,9 @@ def create_buffers(self): device = self.weight.device # TODO(future PR): dtype values 
below don't have the other float8 # flavors, fix it - default_x = torch.finfo(torch.float8_e4m3fn).max - default_w = torch.finfo(torch.float8_e4m3fn).max - default_dl_dy = torch.finfo(torch.float8_e5m2).max + default_input = torch.finfo(torch.float8_e4m3fn).max + default_weight = torch.finfo(torch.float8_e4m3fn).max + default_grad_output = torch.finfo(torch.float8_e5m2).max # Note: for now, create all the buffers if any are needed, to postpone # the work to make the scale and amax syncing and history calculation @@ -249,31 +259,32 @@ def create_buffers(self): # show it is worth doing. if self.has_any_delayed_scaling: self.register_always_float32_buffer( - "fp8_amax_x", torch.tensor([default_x], device=device) + "fp8_amax_input", torch.tensor([default_input], device=device) ) self.register_always_float32_buffer( - "fp8_amax_history_x", torch.zeros(history_len, device=device) + "fp8_amax_history_input", torch.zeros(history_len, device=device) ) self.register_always_float32_buffer( - "fp8_scale_x", torch.tensor([1.0], device=device) + "fp8_scale_input", torch.tensor([1.0], device=device) ) self.register_always_float32_buffer( - "fp8_amax_w", torch.tensor([default_w], device=device) + "fp8_amax_weight", torch.tensor([default_weight], device=device) ) self.register_always_float32_buffer( - "fp8_amax_history_w", torch.zeros(history_len, device=device) + "fp8_amax_history_weight", torch.zeros(history_len, device=device) ) self.register_always_float32_buffer( - "fp8_scale_w", torch.tensor([1.0], device=device) + "fp8_scale_weight", torch.tensor([1.0], device=device) ) self.register_always_float32_buffer( - "fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device) + "fp8_amax_grad_output", + torch.tensor([default_grad_output], device=device), ) self.register_always_float32_buffer( - "fp8_amax_history_dL_dY", torch.zeros(history_len, device=device) + "fp8_amax_history_grad_output", torch.zeros(history_len, device=device) ) self.register_always_float32_buffer( - "fp8_scale_dL_dY", torch.tensor([1.0], device=device) + "fp8_scale_grad_output", torch.tensor([1.0], device=device) ) def register_always_float32_buffer( @@ -303,13 +314,13 @@ def cast_x_to_float8( autocast_dtype = torch.get_autocast_gpu_dtype() x = x.to(autocast_dtype) - if self.scaling_type_x is TensorScalingType.DELAYED: + if self.scaling_type_input is TensorScalingType.DELAYED: scale_fn_name = self.recipe.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( x, - self.fp8_amax_x, - self.fp8_amax_history_x, - self.fp8_scale_x, + self.fp8_amax_input, + self.fp8_amax_history_input, + self.fp8_scale_input, scale_fn_name, e4m3_dtype, is_amax_initialized, @@ -317,30 +328,30 @@ def cast_x_to_float8( ) x_fp8 = Float8Tensor.to_float8( x, - self.fp8_scale_x, + self.fp8_scale_input, e4m3_dtype, - self.fp8_amax_x, + self.fp8_amax_input, linear_mm_config=self.linear_mm_config, gemm_input_role=GemmInputRole.X, ) else: - assert self.scaling_type_x is TensorScalingType.DYNAMIC + assert self.scaling_type_input is TensorScalingType.DYNAMIC x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config) return x_fp8 def cast_w_to_float8( self, w: torch.Tensor, is_amax_initialized: bool ) -> torch.Tensor: - if self.scaling_type_w is TensorScalingType.DELAYED: + if self.scaling_type_weight is TensorScalingType.DELAYED: if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: scale_fn_name = self.recipe.scale_fn_name _maybe_initialize_amaxes_scales_for_float8_cast( w, - self.fp8_amax_w, - self.fp8_amax_history_w, - 
self.fp8_scale_w, + self.fp8_amax_weight, + self.fp8_amax_history_weight, + self.fp8_scale_weight, scale_fn_name, e4m3_dtype, is_amax_initialized, @@ -349,14 +360,14 @@ def cast_w_to_float8( w_fp8 = Float8Tensor.to_float8( w, - self.fp8_scale_w, + self.fp8_scale_weight, e4m3_dtype, - self.fp8_amax_w, + self.fp8_amax_weight, linear_mm_config=self.linear_mm_config, gemm_input_role=GemmInputRole.W, ) else: - assert self.scaling_type_w is TensorScalingType.DYNAMIC + assert self.scaling_type_weight is TensorScalingType.DYNAMIC if isinstance(self.weight, Float8Tensor): # cast by FSDP w_fp8 = self.weight else: @@ -366,19 +377,19 @@ def cast_w_to_float8( return w_fp8 def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor: - if self.scaling_type_dL_dY is TensorScalingType.DELAYED: + if self.scaling_type_grad_output is TensorScalingType.DELAYED: scale_fn_name = self.recipe.scale_fn_name y = NoopFwToFloat8E5M2Bw.apply( y, - self.fp8_amax_dL_dY, - self.fp8_amax_history_dL_dY, - self.fp8_scale_dL_dY, + self.fp8_amax_grad_output, + self.fp8_amax_history_grad_output, + self.fp8_scale_grad_output, scale_fn_name, self.is_amax_initialized, self.linear_mm_config, ) else: - assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC + assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config) return y @@ -425,7 +436,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def scaling_repr(self): # add scaling settings without using too many characters # example: "x:del,w:del,dldy:dyn" - return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}" + return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}" def extra_repr(self): s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"' @@ -436,9 +447,9 @@ def from_float( cls, mod, emulate: bool = False, - scaling_type_x=TensorScalingType.DYNAMIC, - scaling_type_w=TensorScalingType.DYNAMIC, - scaling_type_dL_dY=TensorScalingType.DYNAMIC, + scaling_type_input=TensorScalingType.DYNAMIC, + scaling_type_weight=TensorScalingType.DYNAMIC, + scaling_type_grad_output=TensorScalingType.DYNAMIC, ): """ Create an nn.Linear with fp8 compute from a regular nn.Linear @@ -452,9 +463,9 @@ def from_float( mod.in_features, mod.out_features, bias=False, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, + scaling_type_input=scaling_type_input, + scaling_type_weight=scaling_type_weight, + scaling_type_grad_output=scaling_type_grad_output, emulate=emulate, ) new_mod.weight = mod.weight @@ -469,7 +480,7 @@ def from_float( # 2. 
buffers need to be already created for the delayed scaling version # of the weight wrapper to be initialized if config.enable_fsdp_fp8_all_gather: - if scaling_type_w is TensorScalingType.DYNAMIC: + if scaling_type_weight is TensorScalingType.DYNAMIC: new_mod.weight = torch.nn.Parameter( WeightWithDynamicFloat8CastTensor( new_mod.weight, @@ -477,13 +488,13 @@ def from_float( ) ) else: - assert scaling_type_w is TensorScalingType.DELAYED + assert scaling_type_weight is TensorScalingType.DELAYED new_mod.weight = torch.nn.Parameter( WeightWithDelayedFloat8CastTensor( new_mod.weight, - new_mod.fp8_amax_w, - new_mod.fp8_amax_history_w, - new_mod.fp8_scale_w, + new_mod.fp8_amax_weight, + new_mod.fp8_amax_history_weight, + new_mod.fp8_scale_weight, new_mod.linear_mm_config, new_mod.is_amax_initialized, ) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 26364b67..8140baa0 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -23,16 +23,16 @@ def linear_requires_sync( - scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC, - scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, - scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_input: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_grad_output: TensorScalingType = TensorScalingType.DYNAMIC, ): """Returns whether the given linear_type requires sync before forward.""" return any( [ - scaling_type_x is TensorScalingType.DELAYED, - scaling_type_w is TensorScalingType.DELAYED, - scaling_type_dL_dY is TensorScalingType.DELAYED, + scaling_type_input is TensorScalingType.DELAYED, + scaling_type_weight is TensorScalingType.DELAYED, + scaling_type_grad_output is TensorScalingType.DELAYED, ] ) @@ -132,9 +132,9 @@ def swap_linear_with_float8_linear( *, emulate: bool = False, module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None, - scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC, - scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, - scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_input: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_grad_output: TensorScalingType = TensorScalingType.DYNAMIC, ) -> Optional[nn.Module]: """ Swaps `torch.nn.Linear` in `module` with `Float8Linear`. @@ -145,9 +145,9 @@ def swap_linear_with_float8_linear( module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that that pass the filter function will be swapped. The inputs to the filter function are the FQN and module instance. - scaling_type_x (TensorScalingType): scaling type for `x` - scaling_type_w (TensorScalingType): scaling type for `w` - scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY` + scaling_type_input (TensorScalingType): scaling type for `input` + scaling_type_weight (TensorScalingType): scaling type for `weight` + scaling_type_grad_output (TensorScalingType): scaling type for `grad_output` Returns: nn.Module: The modified module with swapped linear layers. 
@@ -155,9 +155,9 @@ def swap_linear_with_float8_linear( from_float = lambda m: Float8Linear.from_float( m, emulate=emulate, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, + scaling_type_input=scaling_type_input, + scaling_type_weight=scaling_type_weight, + scaling_type_grad_output=scaling_type_grad_output, ) return swap_linear_layers( module, @@ -228,25 +228,25 @@ def inner_func(): the inner function will not. """ # Loop over all fp8 layers and grab the needed tensors - fp8_amax_x_tensor_list = [None] * len(fp8_layers) - fp8_amax_w_tensor_list = [None] * len(fp8_layers) - fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers) + fp8_amax_input_tensor_list = [None] * len(fp8_layers) + fp8_amax_weight_tensor_list = [None] * len(fp8_layers) + fp8_amax_grad_output_tensor_list = [None] * len(fp8_layers) - fp8_x_amax_history_stack = [None] * len(fp8_layers) - fp8_w_amax_history_stack = [None] * len(fp8_layers) - fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers) + fp8_input_amax_history_stack = [None] * len(fp8_layers) + fp8_weight_amax_history_stack = [None] * len(fp8_layers) + fp8_grad_output_amax_history_stack = [None] * len(fp8_layers) x_dtypes = set() scale_fn_recipes = set() for idx, child in enumerate(fp8_layers): - fp8_amax_x_tensor_list[idx] = child.fp8_amax_x - fp8_amax_w_tensor_list[idx] = child.fp8_amax_w - fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY + fp8_amax_input_tensor_list[idx] = child.fp8_amax_input + fp8_amax_weight_tensor_list[idx] = child.fp8_amax_weight + fp8_amax_grad_output_tensor_list[idx] = child.fp8_amax_grad_output - fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x - fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w - fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY + fp8_input_amax_history_stack[idx] = child.fp8_amax_history_input + fp8_weight_amax_history_stack[idx] = child.fp8_amax_history_weight + fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output x_dtypes.add(child.last_seen_input_dtype) scale_fn_recipes.add(child.recipe.scale_fn_name) @@ -265,16 +265,16 @@ def inner_func(): scale_fn_recipe = next(iter(scale_fn_recipes)) assert ( - len(fp8_amax_x_tensor_list) - == len(fp8_amax_w_tensor_list) - == len(fp8_amax_dL_dY_tensor_list) + len(fp8_amax_input_tensor_list) + == len(fp8_amax_weight_tensor_list) + == len(fp8_amax_grad_output_tensor_list) ), "Mismatched lengths of amax tensors." 
if dist.is_initialized(): all_amax_tensors = torch.cat( - fp8_amax_x_tensor_list - + fp8_amax_w_tensor_list - + fp8_amax_dL_dY_tensor_list + fp8_amax_input_tensor_list + + fp8_amax_weight_tensor_list + + fp8_amax_grad_output_tensor_list ) all_reduced_amax_tensor = all_reduce( all_amax_tensors, "MAX", list(range(dist.get_world_size())) @@ -283,46 +283,52 @@ def inner_func(): all_reduced_amax_tensor = all_reduced_amax_tensor.wait() ( - reduced_fp8_amax_x_tensor, - reduced_fp8_amax_w_tensor, - reduced_fp8_amax_dL_dY_tensor, - ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) + reduced_fp8_amax_input_tensor, + reduced_fp8_amax_weight_tensor, + reduced_fp8_amax_grad_output_tensor, + ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_input_tensor_list)) for idx, child in enumerate(fp8_layers): - child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx]) - child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) - child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx]) + child.fp8_amax_input.copy_(reduced_fp8_amax_input_tensor[idx]) + child.fp8_amax_weight.copy_(reduced_fp8_amax_weight_tensor[idx]) + child.fp8_amax_grad_output.copy_( + reduced_fp8_amax_grad_output_tensor[idx] + ) # We create two stacked tensor groups, one for the amax history and one for the current scales - fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list) - fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list) - fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list) - - fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack) - fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack) - fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack) + fp8_amax_input_tensors = torch.vstack(fp8_amax_input_tensor_list) + fp8_amax_weight_tensors = torch.vstack(fp8_amax_weight_tensor_list) + fp8_amax_grad_output_tensors = torch.vstack(fp8_amax_grad_output_tensor_list) + + fp8_input_amax_history_stack = torch.vstack(fp8_input_amax_history_stack) + fp8_weight_amax_history_stack = torch.vstack(fp8_weight_amax_history_stack) + fp8_grad_output_amax_history_stack = torch.vstack( + fp8_grad_output_amax_history_stack + ) # Update the history stacks with the new amax values - _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack) - _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack) - _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack) + _update_history_stack(fp8_amax_input_tensors, fp8_input_amax_history_stack) + _update_history_stack(fp8_amax_weight_tensors, fp8_weight_amax_history_stack) + _update_history_stack( + fp8_amax_grad_output_tensors, fp8_grad_output_amax_history_stack + ) # Calculate the new scales from the updated history stacks - new_x_scales = amax_history_to_scale_stack( - fp8_x_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + new_input_scales = amax_history_to_scale_stack( + fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe ) - new_w_scales = amax_history_to_scale_stack( - fp8_w_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe + new_weight_scales = amax_history_to_scale_stack( + fp8_weight_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe ) - new_dL_dY_scales = amax_history_to_scale_stack( - fp8_dL_dY_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe + new_grad_output_scales = amax_history_to_scale_stack( + fp8_grad_output_amax_history_stack, e5m2_dtype, x_dtype, scale_fn_recipe ) # Iterate through the layers and update the scales for idx, child in 
enumerate(fp8_layers): - child.fp8_scale_x.copy_(new_x_scales[idx]) - child.fp8_scale_w.copy_(new_w_scales[idx]) - child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx]) + child.fp8_scale_input.copy_(new_input_scales[idx]) + child.fp8_scale_weight.copy_(new_weight_scales[idx]) + child.fp8_scale_grad_output.copy_(new_grad_output_scales[idx]) # This allows for the compile to succede on the inner func and fail on the graph breaks # at the beginning and and of syncing diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 475a17a2..bc2b9eaf 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -22,6 +22,8 @@ # # A note on configuration of float8 logic in a linear # TODO(future): move all the configs to separate file +# TODO(future): change this to input/weight/grad_output notation, +# can be separate PR because none of this is user facing # # There are three gemms in a forward + backward of a Linear layer: # diff --git a/float8_experimental/float8_tensor_parallel.py b/float8_experimental/float8_tensor_parallel.py index 4c5297cf..42f88ffd 100644 --- a/float8_experimental/float8_tensor_parallel.py +++ b/float8_experimental/float8_tensor_parallel.py @@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m): # TODO(future): add support for delayed scaling for activations # and gradients return ( - m.scaling_type_x == TensorScalingType.DYNAMIC - and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC + m.scaling_type_input == TensorScalingType.DYNAMIC + and m.scaling_type_grad_output == TensorScalingType.DYNAMIC ) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 04cd797e..5a231ee2 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -37,7 +37,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: from torch.distributed._tensor import DTensor if any( - isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED + isinstance(m, Float8Linear) + and m.scaling_type_weight is TensorScalingType.DELAYED for m in module.modules() ): raise NotImplementedError("Only supports delayed scaling") diff --git a/test/test_base.py b/test/test_base.py index 7c212e6f..6841c6f3 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -148,19 +148,21 @@ def _test_linear_impl( x, m_ref, emulate: bool, - scaling_type_x: TensorScalingType = TensorScalingType.DELAYED, - scaling_type_w: TensorScalingType = TensorScalingType.DELAYED, - scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED, + scaling_type_input: TensorScalingType = TensorScalingType.DELAYED, + scaling_type_weight: TensorScalingType = TensorScalingType.DELAYED, + scaling_type_grad_output: TensorScalingType = TensorScalingType.DELAYED, ): m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, ) for _ in range(2): - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): sync_float8_amax_and_scale_history(m_fp8) y_fp8 = m_fp8(x) y_fp8.sum().backward() @@ -178,24 +180,26 @@ def _test_linear_impl( torch.testing.assert_close(m_ref.bias.grad, m_fp8.bias.grad) # verify all of the amax buffers got updated - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if 
linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): # only check buffers that are actually used, based on per-tensor # scaling settings amax_buffer_names = [] amax_history_buffer_names = [] scale_buffer_names = [] - if scaling_type_x is TensorScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_x") - amax_history_buffer_names.append("fp8_amax_history_x") - scale_buffer_names.append("fp8_scale_x") - if scaling_type_w is TensorScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_w") - amax_history_buffer_names.append("fp8_amax_history_w") - scale_buffer_names.append("fp8_scale_w") - if scaling_type_dL_dY is TensorScalingType.DELAYED: - amax_buffer_names.append("fp8_amax_dL_dY") - amax_history_buffer_names.append("fp8_amax_history_dL_dY") - scale_buffer_names.append("fp8_scale_dL_dY") + if scaling_type_input is TensorScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_input") + amax_history_buffer_names.append("fp8_amax_history_input") + scale_buffer_names.append("fp8_scale_input") + if scaling_type_weight is TensorScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_weight") + amax_history_buffer_names.append("fp8_amax_history_weight") + scale_buffer_names.append("fp8_scale_weight") + if scaling_type_grad_output is TensorScalingType.DELAYED: + amax_buffer_names.append("fp8_amax_grad_output") + amax_history_buffer_names.append("fp8_amax_history_grad_output") + scale_buffer_names.append("fp8_scale_grad_output") # verify all of the amax buffers got updated max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES} @@ -224,13 +228,14 @@ def _test_linear_impl( @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]) @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", + [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], ) @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("linear_bias", [False, True]) @@ -239,9 +244,9 @@ def test_linear( self, x_shape, emulate: bool, - scaling_type_x: TensorScalingType, - scaling_type_w: TensorScalingType, - scaling_type_dL_dY: TensorScalingType, + scaling_type_input: TensorScalingType, + scaling_type_weight: TensorScalingType, + scaling_type_grad_output: TensorScalingType, linear_dtype: torch.dtype, linear_bias: bool, ): @@ -260,9 +265,9 @@ def test_linear( x, m_ref, emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, ) @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True]) @@ -287,9 +292,9 @@ def test_autocast_outputs( m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) kwargs = { - "scaling_type_x": TensorScalingType.DELAYED, - "scaling_type_w": TensorScalingType.DELAYED, - "scaling_type_dL_dY": TensorScalingType.DELAYED, + "scaling_type_input": TensorScalingType.DELAYED, + "scaling_type_weight": TensorScalingType.DELAYED, + "scaling_type_grad_output": TensorScalingType.DELAYED, } m = 
Float8Linear.from_float(copy.deepcopy(m_ref), emulate, **kwargs) @@ -327,9 +332,9 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) kwargs = { - "scaling_type_x": TensorScalingType.DYNAMIC, - "scaling_type_w": TensorScalingType.DYNAMIC, - "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + "scaling_type_input": TensorScalingType.DYNAMIC, + "scaling_type_weight": TensorScalingType.DYNAMIC, + "scaling_type_grad_output": TensorScalingType.DYNAMIC, } m = Float8Linear.from_float(copy.deepcopy(m), emulate, **kwargs) @@ -338,15 +343,15 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): if linear_requires_sync(**kwargs): # Check amax buffer types for key in [ - "fp8_amax_x", - "fp8_amax_history_x", - "fp8_scale_x", - "fp8_amax_w", - "fp8_amax_history_w", - "fp8_scale_w", - "fp8_amax_dL_dY", - "fp8_amax_history_dL_dY", - "fp8_scale_dL_dY", + "fp8_amax_input", + "fp8_amax_history_input", + "fp8_scale_input", + "fp8_amax_weight", + "fp8_amax_history_weight", + "fp8_scale_weight", + "fp8_amax_grad_output", + "fp8_amax_history_grad_output", + "fp8_scale_grad_output", ]: assert ( m._buffers[key].dtype == torch.float32 @@ -379,9 +384,9 @@ def test_repr(self): m = Float8Linear.from_float( copy.deepcopy(m), emulate=True, - scaling_type_x=TensorScalingType.DYNAMIC, - scaling_type_w=TensorScalingType.DELAYED, - scaling_type_dL_dY=TensorScalingType.DYNAMIC, + scaling_type_input=TensorScalingType.DYNAMIC, + scaling_type_weight=TensorScalingType.DELAYED, + scaling_type_grad_output=TensorScalingType.DYNAMIC, ) s = m.__repr__() assert "x:dyn,w:del,dldy:dyn" in s diff --git a/test/test_compile.py b/test/test_compile.py index f72425d8..59dadf84 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -32,9 +32,9 @@ def _test_compile_base( backend: str, fullgraph: bool, emulate: bool, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, dtype: torch.dtype, ): random.seed(0) @@ -48,9 +48,9 @@ def _test_compile_base( m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, ) m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=fullgraph) @@ -68,13 +68,13 @@ def _test_compile_base( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @@ -82,9 +82,9 @@ def _test_compile_base( def test_eager_only( fullgraph, emulate: bool, - scaling_type_x: TensorScalingType, - scaling_type_w: TensorScalingType, - scaling_type_dL_dY: TensorScalingType, + scaling_type_input: TensorScalingType, + scaling_type_weight: TensorScalingType, + scaling_type_grad_output: TensorScalingType, dtype: torch.dtype, ): torch._dynamo.reset() @@ 
-92,9 +92,9 @@ def test_eager_only( "eager", fullgraph, emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, dtype, ) @@ -102,22 +102,22 @@ def test_eager_only( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True]) @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_aot_eager( fullgraph, emulate: bool, - scaling_type_x: TensorScalingType, - scaling_type_w: TensorScalingType, - scaling_type_dL_dY: TensorScalingType, + scaling_type_input: TensorScalingType, + scaling_type_weight: TensorScalingType, + scaling_type_grad_output: TensorScalingType, dtype: torch.dtype, ): torch._dynamo.reset() @@ -125,9 +125,9 @@ def test_aot_eager( "aot_eager", fullgraph, emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, dtype, ) @@ -135,22 +135,22 @@ def test_aot_eager( @pytest.mark.parametrize("fullgraph", [True]) @pytest.mark.parametrize("emulate", [False]) @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_inductor( fullgraph, emulate: bool, - scaling_type_x: TensorScalingType, - scaling_type_w: TensorScalingType, - scaling_type_dL_dY: TensorScalingType, + scaling_type_input: TensorScalingType, + scaling_type_weight: TensorScalingType, + scaling_type_grad_output: TensorScalingType, dtype: torch.dtype, ): torch._dynamo.reset() @@ -158,9 +158,9 @@ def test_inductor( "inductor", fullgraph, emulate, - scaling_type_x, - scaling_type_w, - scaling_type_dL_dY, + scaling_type_input, + scaling_type_weight, + scaling_type_grad_output, dtype, ) @@ -257,9 +257,9 @@ def test_sync_amax_func(): ) float8_mod = swap_linear_with_float8_linear( module, - scaling_type_x=TensorScalingType.DELAYED, - scaling_type_w=TensorScalingType.DELAYED, - scaling_type_dL_dY=TensorScalingType.DELAYED, + scaling_type_input=TensorScalingType.DELAYED, + scaling_type_weight=TensorScalingType.DELAYED, + scaling_type_grad_output=TensorScalingType.DELAYED, ) compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts) compiled_swap_func(float8_mod) @@ -292,9 +292,9 
@@ def test_sync_amax_func_cuda_graph_success(): ).to("cuda") swap_linear_with_float8_linear( my_module, - scaling_type_x=TensorScalingType.DELAYED, - scaling_type_w=TensorScalingType.DELAYED, - scaling_type_dL_dY=TensorScalingType.DELAYED, + scaling_type_input=TensorScalingType.DELAYED, + scaling_type_weight=TensorScalingType.DELAYED, + scaling_type_grad_output=TensorScalingType.DELAYED, ) inpt = torch.randn( 16, 16, device="cuda", dtype=torch.float32, requires_grad=True diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 1cd14dba..ade3cfd5 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -185,9 +185,9 @@ def _test_fp8_mlp_tensor_parallelism_base( # TODO(future): add support for float8 all-gather with delayed scaling # for activations and gradients. extra_kwargs = { - "scaling_type_x": TensorScalingType.DYNAMIC, - "scaling_type_w": TensorScalingType.DYNAMIC, - "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + "scaling_type_input": TensorScalingType.DYNAMIC, + "scaling_type_weight": TensorScalingType.DYNAMIC, + "scaling_type_grad_output": TensorScalingType.DYNAMIC, } toy_model = ToyModel().to(device) diff --git a/test/test_fsdp.py b/test/test_fsdp.py index 48b28daa..2624e7b7 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -73,18 +73,18 @@ def fsdp_main(rank, world_size, args): model = get_model(K, N, base_dtype=base_dtype).to(rank) model_fp8 = copy.deepcopy(model) - scaling_type_w = ( + scaling_type_weight = ( TensorScalingType.DYNAMIC if use_weight_dynamic_scaling else TensorScalingType.DELAYED ) - # Note: we only iterate over `scaling_type_w` because FSDP only interacts + # Note: we only iterate over `scaling_type_weight` because FSDP only interacts # with weights. swap_linear_with_float8_linear( model_fp8, emulate=False, - scaling_type_w=scaling_type_w, + scaling_type_weight=scaling_type_weight, ) # To compile FSDP, we need use_orig_params to True @@ -132,7 +132,7 @@ def forward_backward(model, optim, is_fp8, i): y_local.backward(ref_grad_local[i]) if is_fp8 and linear_requires_sync( TensorScalingType.DYNAMIC, - scaling_type_w, + scaling_type_weight, TensorScalingType.DYNAMIC, ): sync_float8_func(model) diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index 1cbec778..f4826193 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -85,7 +85,7 @@ def test_transformer_parity(self): { "enable_fsdp_fp8_all_gather": [False, True], "precompute": [False, True], - "scaling_type_w": [ + "scaling_type_weight": [ TensorScalingType.DYNAMIC, TensorScalingType.DELAYED, ], @@ -98,12 +98,12 @@ def _test_transformer_parity( self, enable_fsdp_fp8_all_gather: bool, precompute: bool, - scaling_type_w: TensorScalingType, + scaling_type_weight: TensorScalingType, compile_transformer_block: bool, ): if not enable_fsdp_fp8_all_gather and precompute: return - elif scaling_type_w is TensorScalingType.DELAYED and precompute: + elif scaling_type_weight is TensorScalingType.DELAYED and precompute: return # NOTE: Weight-tying does not compose with fp8 all-gather because the @@ -113,13 +113,17 @@ def _test_transformer_parity( weight_tying = not enable_fsdp_fp8_all_gather module = self.init_transformer(weight_tying=weight_tying).cuda() ref_module = copy.deepcopy(module) - swap_linear_with_float8_linear(ref_module, scaling_type_w=scaling_type_w) + swap_linear_with_float8_linear( + ref_module, scaling_type_weight=scaling_type_weight + ) if compile_transformer_block: for layer_id, transformer_block in 
ref_module.layers.named_children(): transformer_block = torch.compile(transformer_block, dynamic=False) ref_module.layers.register_module(layer_id, transformer_block) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): - swap_linear_with_float8_linear(module, scaling_type_w=scaling_type_w) + swap_linear_with_float8_linear( + module, scaling_type_weight=scaling_type_weight + ) for layer_id, transformer_block in module.layers.named_children(): if compile_transformer_block: transformer_block = torch.compile(transformer_block, dynamic=False) @@ -139,7 +143,7 @@ def _test_transformer_parity( optim, local_inp, precompute, - scaling_type_w=scaling_type_w, + scaling_type_weight=scaling_type_weight, compile_transformer_block=compile_transformer_block, ) @@ -276,9 +280,9 @@ def world_size(self) -> int: @unittest.skipIf(not TEST_CUDA, "no cuda") def test_weight_subclass_dynamic(self): extra_kwargs = { - "scaling_type_x": TensorScalingType.DYNAMIC, - "scaling_type_w": TensorScalingType.DYNAMIC, - "scaling_type_dL_dY": TensorScalingType.DYNAMIC, + "scaling_type_input": TensorScalingType.DYNAMIC, + "scaling_type_weight": TensorScalingType.DYNAMIC, + "scaling_type_grad_output": TensorScalingType.DYNAMIC, } tensor_cls = WeightWithDynamicFloat8CastTensor # Check for a single FSDP paramter group @@ -405,16 +409,16 @@ def test_fp32_fp8_single_module_parity(self): [False, True], [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], ) - for enable_fsdp_fp8_all_gather, scaling_type_w in choices: + for enable_fsdp_fp8_all_gather, scaling_type_weight in choices: module_fp32 = self.init_single_module() ref_module = copy.deepcopy(module_fp32) ref_module = swap_linear_with_float8_linear( - ref_module, scaling_type_w=scaling_type_w + ref_module, scaling_type_weight=scaling_type_weight ) ref_module = ref_module.cuda() with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): module = swap_linear_with_float8_linear( - module_fp32, scaling_type_w=scaling_type_w + module_fp32, scaling_type_weight=scaling_type_weight ) fully_shard(module) ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) @@ -427,7 +431,7 @@ def test_fp32_fp8_single_module_parity(self): module, optim, local_inp, - scaling_type_w=scaling_type_w, + scaling_type_weight=scaling_type_weight, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -440,15 +444,15 @@ def test_fp32_fp8_multi_module_parity(self): [False, True], [TensorScalingType.DYNAMIC, TensorScalingType.DELAYED], ) - for enable_fsdp_fp8_all_gather, scaling_type_w in choices: + for enable_fsdp_fp8_all_gather, scaling_type_weight in choices: module = self.init_multi_module().cuda() ref_module = copy.deepcopy(module) ref_module = swap_linear_with_float8_linear( - ref_module, scaling_type_w=scaling_type_w + ref_module, scaling_type_weight=scaling_type_weight ) with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather): module = swap_linear_with_float8_linear( - module, scaling_type_w=scaling_type_w + module, scaling_type_weight=scaling_type_weight ) for submodule in module: fully_shard(submodule) @@ -463,7 +467,7 @@ def test_fp32_fp8_multi_module_parity(self): module, optim, local_inp, - scaling_type_w=scaling_type_w, + scaling_type_weight=scaling_type_weight, ) @unittest.skipIf(not TEST_CUDA, "no cuda") @@ -507,13 +511,13 @@ def test_delayed_scaling_inplace_update(self): with set_enable_fsdp_fp8_all_gather(True): m_fp8 = swap_linear_with_float8_linear( module, - scaling_type_w=TensorScalingType.DELAYED, + scaling_type_weight=TensorScalingType.DELAYED, ) - 
fp8_amax_w_old = m_fp8.fp8_amax_w.clone().detach() + fp8_amax_weight_old = m_fp8.fp8_amax_weight.clone().detach() dummy_mesh = None data, scale = m_fp8.weight.fsdp_pre_all_gather(dummy_mesh) - self.assertNotEqual(fp8_amax_w_old.item(), m_fp8.fp8_amax_w.item()) + self.assertNotEqual(fp8_amax_weight_old.item(), m_fp8.fp8_amax_weight.item()) if __name__ == "__main__": diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 61edac91..7140cecb 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -22,7 +22,7 @@ def check_parity_no_mp( fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, precompute: bool = False, - scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC, + scaling_type_weight: TensorScalingType = TensorScalingType.DYNAMIC, compile_transformer_block: bool = False, ): for iter_idx in range(10): @@ -36,14 +36,14 @@ def check_parity_no_mp( dist.all_reduce(param.grad) param.grad.div_(dist.get_world_size()) - if linear_requires_sync(scaling_type_w=scaling_type_w): + if linear_requires_sync(scaling_type_weight=scaling_type_weight): sync_float8_amax_and_scale_history(model) optim.step() if ( model is fsdp_model and precompute - and scaling_type_w is TensorScalingType.DYNAMIC + and scaling_type_weight is TensorScalingType.DYNAMIC ): precompute_float8_dynamic_scale_for_fsdp(model) diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index 3f1b5dcb..f1681dbb 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -52,9 +52,9 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): swap_linear_with_float8_linear( m, emulate=emulate, - scaling_type_x=TensorScalingType.DELAYED, - scaling_type_w=TensorScalingType.DELAYED, - scaling_type_dL_dY=TensorScalingType.DELAYED, + scaling_type_input=TensorScalingType.DELAYED, + scaling_type_weight=TensorScalingType.DELAYED, + scaling_type_grad_output=TensorScalingType.DELAYED, ) return m diff --git a/test/test_inference_flows.py b/test/test_inference_flows.py index 55543aef..2e3b8e45 100644 --- a/test/test_inference_flows.py +++ b/test/test_inference_flows.py @@ -193,9 +193,9 @@ def test_fp8_save_and_load(self, dtype: torch.dtype): fp8_mlp.reset_parameters() swap_linear_with_float8_linear( fp8_mlp, - scaling_type_x=TensorScalingType.DYNAMIC, - scaling_type_w=TensorScalingType.DYNAMIC, - scaling_type_dL_dY=TensorScalingType.DYNAMIC, + scaling_type_input=TensorScalingType.DYNAMIC, + scaling_type_weight=TensorScalingType.DYNAMIC, + scaling_type_grad_output=TensorScalingType.DYNAMIC, ) # Train the model @@ -217,9 +217,9 @@ def test_fp8_save_and_load(self, dtype: torch.dtype): new_fp8_mlp = FeedForward().to(dtype=dtype) swap_linear_with_float8_linear( new_fp8_mlp, - scaling_type_x=TensorScalingType.DYNAMIC, - scaling_type_w=TensorScalingType.DYNAMIC, - scaling_type_dL_dY=TensorScalingType.DYNAMIC, + scaling_type_input=TensorScalingType.DYNAMIC, + scaling_type_weight=TensorScalingType.DYNAMIC, + scaling_type_grad_output=TensorScalingType.DYNAMIC, ) # Load the actual data diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py index 845c9ea6..59b8204a 100644 --- a/test/test_numerics_integration.py +++ b/test/test_numerics_integration.py @@ -75,21 +75,22 @@ def init_weights(self, init_std: float): class TestFloat8NumericsIntegrationTest: @pytest.mark.parametrize( - "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_input", [TensorScalingType.DELAYED, 
TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_weight", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] ) @pytest.mark.parametrize( - "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + "scaling_type_grad_output", + [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC], ) @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") def test_encoder_fw_bw( self, - scaling_type_x: TensorScalingType, - scaling_type_w: TensorScalingType, - scaling_type_dL_dY: TensorScalingType, + scaling_type_input: TensorScalingType, + scaling_type_weight: TensorScalingType, + scaling_type_grad_output: TensorScalingType, ): # TODO(later): maybe add float16 back if it becomes important data_dtype = torch.bfloat16 @@ -111,9 +112,9 @@ def test_encoder_fw_bw( swap_linear_with_float8_linear( model_fp8, emulate=False, - scaling_type_x=scaling_type_x, - scaling_type_w=scaling_type_w, - scaling_type_dL_dY=scaling_type_dL_dY, + scaling_type_input=scaling_type_input, + scaling_type_weight=scaling_type_weight, + scaling_type_grad_output=scaling_type_grad_output, ) lr = 0.01 @@ -135,13 +136,17 @@ def test_encoder_fw_bw( model_ref_out = model_ref(data2) model_ref_out.sum().backward() - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): sync_float8_amax_and_scale_history(model_fp8) model_fp8(data1).sum().backward() # zero out grads without stepping, since we just want to compare grads # of the second datum optim_fp8.zero_grad() - if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY): + if linear_requires_sync( + scaling_type_input, scaling_type_weight, scaling_type_grad_output + ): sync_float8_amax_and_scale_history(model_fp8) model_fp8_out = model_fp8(data2) model_fp8_out.sum().backward()
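
To make the renamed convention concrete, the following is a minimal, self-contained sketch in plain PyTorch. It is not part of this patch or of `float8_experimental`, and the tensor shapes and variable names are illustrative only; it spells out the three gemms named in the summary and checks the hand-written backward against autograd.

```python
import torch

# Shapes follow torch.nn.Linear: input is (M, K), weight is (N, K), output is (M, N).
# The names deliberately mirror the PR's convention; `input` shadows the Python
# builtin here, which is acceptable for a short sketch.
input = torch.randn(4, 8, requires_grad=True)
weight = torch.randn(16, 8, requires_grad=True)

# forward: output = input @ weight_t
output = input @ weight.t()

# backward, written out as explicit gemms
grad_output = torch.randn_like(output)        # (M, N), stands in for dL/d(output)
grad_input = grad_output @ weight             # (M, N) @ (N, K) -> (M, K)
grad_weight = (input.t() @ grad_output).t()   # input_t @ grad_output, transposed to weight's (N, K) layout

# sanity check against autograd
output.backward(grad_output)
torch.testing.assert_close(input.grad, grad_input)
torch.testing.assert_close(weight.grad, grad_weight)
```

Note that the `input_t @ grad_output` gemm produces the gradient with respect to `weight_t`; transposing it recovers `grad_weight` in the layout `torch.nn.Linear` stores.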