This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[wip] add scaling granularity #338

Open

Wants to merge 2 commits into base: gh/vkuzo/47/base
12 changes: 12 additions & 0 deletions float8_experimental/config.py
@@ -21,6 +21,18 @@ def short_str(self):
return "dyn"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
7 changes: 6 additions & 1 deletion float8_experimental/float8_dynamic_utils.py
@@ -4,8 +4,11 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
@@ -52,10 +55,12 @@ def cast_to_float8_e4m3_dynamic(
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
dim: Optional[Union[int, Tuple[int]]] = None,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax, granularity, dim)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
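
For context, a sketch of how a call site opts in to the new arguments (illustration only; it assumes the import locations shown in this diff, and the LinearMMConfig import path is an assumption). Existing callers of cast_to_float8_e4m3_dynamic are unaffected because granularity defaults to TENSORWISE here.

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_tensor import LinearMMConfig

x = torch.randn(16, 32, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# Default path: unchanged behavior, a single tensorwise scale.
x_fp8_tensorwise = cast_to_float8_e4m3_dynamic(x, linear_mm_config)

# Opt-in path: one scale per slice along dim=0, as exercised by the new test below.
x_fp8_axiswise = cast_to_float8_e4m3_dynamic(
    x,
    linear_mm_config,
    granularity=ScalingGranularity.AXISWISE,
    dim=0,
)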
6 changes: 0 additions & 6 deletions float8_experimental/float8_tensor.py
@@ -290,12 +290,6 @@ def __new__(
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )

self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
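
Presumably the scalar-only assert is removed because an axiswise scale no longer has a single element; the looser invariant that still has to hold is that the scale broadcasts against the data. A minimal illustration in plain torch (not the library's validation logic):

import torch

data = torch.randn(16, 32)
scale = torch.rand(1, 32) + 0.5  # axiswise scale: one value per column

# Multiplying either broadcasts cleanly or raises a RuntimeError on shape mismatch.
scaled = data * scale
assert scaled.shape == data.shape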
25 changes: 20 additions & 5 deletions float8_experimental/float8_utils.py
@@ -4,9 +4,10 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config
from float8_experimental.config import ScalingGranularity

import torch
import torch.distributed as dist
@@ -100,8 +101,18 @@ def amax_history_to_scale_stack(


@torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+def tensor_to_amax(
+    x: torch.Tensor,
+    reduce_amax: bool = False,
+    granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+) -> torch.Tensor:
+    if granularity is ScalingGranularity.TENSORWISE:
+        amax = torch.max(torch.abs(x))
+    else:
+        assert granularity is ScalingGranularity.AXISWISE, "unsupported"
+        assert dim is not None, "unsupported"
+        amax = torch.amax(torch.abs(x), dim=dim, keepdim=True)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
@@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:

@torch.no_grad()
def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
+    dim: Optional[Union[int, Tuple[int]]] = None,
) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax=reduce_amax, granularity=granularity, dim=dim)
return amax_to_scale(amax, float8_dtype, x.dtype)
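
Putting the pieces together, a standalone sketch in plain torch of why keepdim=True matters: the axiswise scale broadcasts against the original tensor, so the cast itself looks the same as in the tensorwise case. The scale formula below is an editorial approximation rather than a copy of amax_to_scale, and it needs a PyTorch build that ships torch.float8_e4m3fn.

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

x = torch.randn(16, 32)
amax = torch.amax(torch.abs(x), dim=0, keepdim=True)  # shape (1, 32): one amax per column
scale = E4M3_MAX / torch.clamp(amax, min=1e-12)       # one scale per column
x_fp8 = (x * scale).to(torch.float8_e4m3fn)           # the scale broadcasts over the rows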


11 changes: 11 additions & 0 deletions test/test_base.py
@@ -143,6 +143,17 @@ def test_weights_only_load(self):
buffer.seek(0)
_ = torch.load(buffer, weights_only=True)

def test_axiswise_dynamic_cast(self):
a = torch.randn(16, 32, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = cast_to_float8_e4m3_dynamic(
a,
linear_mm_config,
granularity=ScalingGranularity.AXISWISE,
dim=0,
)
print(a_fp8)


class TestFloat8Linear:
def _test_linear_impl(
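
The WIP test above only prints the result; here is a sketch of the shape checks it could grow into (hypothetical: the _data and _scale attribute names are assumed from Float8Tensor's usual layout and do not appear in this diff).

def test_axiswise_dynamic_cast_shapes(self):
    a = torch.randn(16, 32, dtype=torch.bfloat16)
    linear_mm_config = LinearMMConfig()
    a_fp8 = cast_to_float8_e4m3_dynamic(
        a,
        linear_mm_config,
        granularity=ScalingGranularity.AXISWISE,
        dim=0,
    )
    # With dim=0 and keepdim=True, expect one scale per column and an unchanged data shape.
    assert a_fp8._scale.shape == (1, 32)
    assert a_fp8._data.shape == (16, 32)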