diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 5d1bf9f..22d49ee 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -21,6 +21,18 @@ def short_str(self): return "dyn" +class ScalingGranularity(enum.Enum): + """ + Defines the granularity of scaling strategies for casting to float8 + """ + + # A single scaling factor for the entire tensor + TENSORWISE = "tensorwise" + # Scaling factors computed along one axis of the tensor, reducing it to + # size 1. + AXISWISE = "axiswise" + + @dataclass(frozen=True) class CastConfig: """ diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index bfacd65..1a5c467 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -4,8 +4,11 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple, Union + import torch +from float8_experimental.config import ScalingGranularity from float8_experimental.float8_tensor import ( Float8Tensor, GemmInputRole, @@ -52,10 +55,12 @@ def cast_to_float8_e4m3_dynamic( linear_mm_config: LinearMMConfig, reduce_amax: bool = False, gemm_input_role: GemmInputRole = GemmInputRole.INPUT, + granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, + dim: Optional[Union[int, Tuple[int]]] = None, ) -> Float8Tensor: if tensor_already_casted_to_fp8(inpt_tensor): return inpt_tensor - scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax) + scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax, granularity, dim) return Float8Tensor.to_float8( inpt_tensor, scale, diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index a46e7ce..13b1ccb 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -290,12 +290,6 @@ def __new__( linear_mm_config: Optional[LinearMMConfig], gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT, ): - assert ( - scale.numel() == 1 - ), "Scale should contain a single value, but got: {} elements".format( - scale.numel() - ) - self = torch.Tensor._make_wrapper_subclass( cls, data.size(), diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 2be568e..7d5ed53 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -4,9 +4,10 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, Literal, Tuple, Union +from typing import Iterable, Literal, Optional, Tuple, Union import float8_experimental.config as config +from float8_experimental.config import ScalingGranularity import torch import torch.distributed as dist @@ -100,8 +101,18 @@ def amax_history_to_scale_stack( @torch.no_grad() -def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: - amax = torch.max(torch.abs(x)) +def tensor_to_amax( + x: torch.Tensor, + reduce_amax: bool = False, + granularity: ScalingGranularity = ScalingGranularity.AXISWISE, + dim: Optional[Union[int, Tuple[int]]] = None, +) -> torch.Tensor: + if granularity is ScalingGranularity.TENSORWISE: + amax = torch.max(torch.abs(x)) + else: + assert granularity is ScalingGranularity.AXISWISE, "unsupported" + assert dim is not None, "unsupported" + amax = torch.amax(torch.abs(x), dim=dim, keepdim=True) # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor: @torch.no_grad() def tensor_to_scale( - x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False + x: torch.Tensor, + float8_dtype: torch.dtype, + reduce_amax: bool = False, + granularity: ScalingGranularity = ScalingGranularity.AXISWISE, + dim: Optional[Union[int, Tuple[int]]] = None, ) -> torch.Tensor: - amax = tensor_to_amax(x, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, reduce_amax=reduce_amax, granularity=granularity, dim=dim) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/test/test_base.py b/test/test_base.py index 2f7c717..8c27c23 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -143,6 +143,17 @@ def test_weights_only_load(self): buffer.seek(0) _ = torch.load(buffer, weights_only=True) + def test_axiswise_dynamic_cast(self): + a = torch.randn(16, 32, dtype=torch.bfloat16) + linear_mm_config = LinearMMConfig() + a_fp8 = cast_to_float8_e4m3_dynamic( + a, + linear_mm_config, + granularity=ScalingGranularity.AXISWISE, + dim=0, + ) + print(a_fp8) + class TestFloat8Linear: def _test_linear_impl(