Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
[wip] add scaling granularity
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d66074dab84512d32b149f9df79987c5133fd293
Pull Request resolved: #338
  • Loading branch information
vkuzo committed Jul 25, 2024
1 parent 4736e44 commit 6bd32e3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 12 deletions.
12 changes: 12 additions & 0 deletions float8_experimental/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def short_str(self):
return "dyn"


class ScalingGranularity(enum.Enum):
"""
Defines the granularity of scaling strategies for casting to float8
"""

# A single scaling factor for the entire tensor
TENSORWISE = "tensorwise"
# Scaling factors computed along one axis of the tensor, reducing it to
# size 1.
AXISWISE = "axiswise"


@dataclass(frozen=True)
class CastConfig:
"""
Expand Down
7 changes: 6 additions & 1 deletion float8_experimental/float8_dynamic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
Expand Down Expand Up @@ -52,10 +55,12 @@ def cast_to_float8_e4m3_dynamic(
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
dim: Optional[Union[int, Tuple[int]]] = None,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax, granularity, dim)
return Float8Tensor.to_float8(
inpt_tensor,
scale,
Expand Down
6 changes: 0 additions & 6 deletions float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,6 @@ def __new__(
linear_mm_config: Optional[LinearMMConfig],
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
assert (
scale.numel() == 1
), "Scale should contain a single value, but got: {} elements".format(
scale.numel()
)

self = torch.Tensor._make_wrapper_subclass(
cls,
data.size(),
Expand Down
25 changes: 20 additions & 5 deletions float8_experimental/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterable, Literal, Tuple, Union
from typing import Iterable, Literal, Optional, Tuple, Union

import float8_experimental.config as config
from float8_experimental.config import ScalingGranularity

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -100,8 +101,18 @@ def amax_history_to_scale_stack(


@torch.no_grad()
def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
amax = torch.max(torch.abs(x))
def tensor_to_amax(
x: torch.Tensor,
reduce_amax: bool = False,
granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
dim: Optional[Union[int, Tuple[int]]] = None,
) -> torch.Tensor:
if granularity is ScalingGranularity.TENSORWISE:
amax = torch.max(torch.abs(x))
else:
assert granularity is ScalingGranularity.AXISWISE, "unsupported"
assert dim is not None, "unsupported"
amax = torch.amax(torch.abs(x), dim=dim, keepdim=True)

# If the user asked for distributed reduction, do it.
# If the user did not ask for it, assume that it will
Expand All @@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:

@torch.no_grad()
def tensor_to_scale(
x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
x: torch.Tensor,
float8_dtype: torch.dtype,
reduce_amax: bool = False,
granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
dim: Optional[Union[int, Tuple[int]]] = None,
) -> torch.Tensor:
amax = tensor_to_amax(x, reduce_amax=reduce_amax)
amax = tensor_to_amax(x, reduce_amax=reduce_amax, granularity=granularity, dim=dim)
return amax_to_scale(amax, float8_dtype, x.dtype)


Expand Down
11 changes: 11 additions & 0 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,17 @@ def test_weights_only_load(self):
buffer.seek(0)
_ = torch.load(buffer, weights_only=True)

def test_axiswise_dynamic_cast(self):
a = torch.randn(16, 32, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = cast_to_float8_e4m3_dynamic(
a,
linear_mm_config,
granularity=ScalingGranularity.AXISWISE,
dim=0,
)
print(a_fp8)


class TestFloat8Linear:
def _test_linear_impl(
Expand Down

0 comments on commit 6bd32e3

Please sign in to comment.