This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

clean up casting: ToFloat8ConstrFunc -> hp_tensor_and_scale_to_float8 (#348)

Summary:
Pull Request resolved: #348

Makes `ToFloat8ConstrFunc` private and adds `hp_tensor_and_scale_to_float8`
as the official wrapper, whose name clearly describes what the function does.

A future PR will rename the scaling-aware functions to match this
naming.
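
For illustration, a minimal sketch of a typical call-site migration (variable
names below are placeholders, not taken from the diff; the imports mirror the
ones touched in this PR):

```python
import torch

from float8_experimental.float8_tensor import (
    GemmInputRole,
    hp_tensor_and_scale_to_float8,  # replaces the ToFloat8ConstrFunc import
    LinearMMConfig,
    ScaledMMConfig,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(x, e4m3_dtype)
mm_config = ScaledMMConfig(False, False, False, False)
config = LinearMMConfig(mm_config, mm_config, mm_config)

# before: ToFloat8ConstrFunc.apply(x, scale, e4m3_dtype, config, GemmInputRole.INPUT)
x_fp8 = hp_tensor_and_scale_to_float8(x, scale, e4m3_dtype, config, GemmInputRole.INPUT)
```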

Reviewed By: drisspg

Differential Revision: D60310240

fbshipit-source-id: 954e7c910cee36f2ea0b0d1984fe163862b47ee5
vkuzo authored and facebook-github-bot committed Jul 26, 2024
1 parent 7e0182f commit 6ac2f82
Showing 7 changed files with 72 additions and 53 deletions.
6 changes: 3 additions & 3 deletions benchmarks/bench_padding.py
@@ -6,9 +6,9 @@
import torch
from float8_experimental.float8_tensor import (
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
@@ -58,14 +58,14 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
a_config = LinearMMConfig(a_config, a_config, a_config)
b_config = LinearMMConfig(b_config, b_config, b_config)

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
A,
scale_a,
fp8_dtype,
a_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
B,
scale_b,
fp8_dtype,
10 changes: 5 additions & 5 deletions float8_experimental/float8_scaling_utils.py
@@ -15,10 +15,10 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
tensor_already_casted_to_fp8,
ToFloat8ConstrFunc,
)

from float8_experimental.float8_utils import (
@@ -39,7 +39,7 @@ def cast_to_float8_e4m3_dynamic(
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
inpt_tensor,
scale,
e4m3_dtype,
@@ -58,7 +58,7 @@ def cast_to_float8_delayed(
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
amax_buffer.fill_(tensor_to_amax(tensor))
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
tensor,
scale,
float8_dtype,
@@ -145,7 +145,7 @@ def backward(ctx, go):

fp8_amax_grad_output.fill_(tensor_to_amax(go))

res = ToFloat8ConstrFunc.apply(
res = hp_tensor_and_scale_to_float8(
go,
fp8_scale_grad_output,
e5m2_dtype,
@@ -177,7 +177,7 @@ def backward(ctx, gradY):
if tensor_already_casted_to_fp8(gradY):
return gradY, None
gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
fp8_tensor = ToFloat8ConstrFunc.apply(
fp8_tensor = hp_tensor_and_scale_to_float8(
gradY,
gradY_scale,
e5m2_dtype,
43 changes: 31 additions & 12 deletions float8_experimental/float8_tensor.py
@@ -129,7 +129,7 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:


@torch._dynamo.allow_in_graph
class ToFloat8ConstrFunc(torch.autograd.Function):
class _ToFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion to fp8.
* forward: convert from high precision to float8
@@ -154,15 +154,6 @@ def forward(
with that composing with FakeTensor, so we special case here.
DTensor Invariant: DTensor must always be the outer most tensor subclass
Args:
tensor: the tensor to convert
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
tensor_scaled = tensor * scale
bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
@@ -205,7 +196,7 @@ def backward(ctx, g):


@torch._dynamo.allow_in_graph
class FromFloat8ConstrFunc(torch.autograd.Function):
class _FromFloat8ConstrFunc(torch.autograd.Function):
"""
A differentiable conversion from fp8.
* forward: convert from float8 to high precision
@@ -221,6 +212,34 @@ def backward(ctx, g):
return g, None, None


def hp_tensor_and_scale_to_float8(
hp_tensor: torch.Tensor,
s: torch.Tensor,
float8_dtype=e4m3_dtype,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
"""
Given a high precision tensor `hp_tensor` and a precalculated scale `s`,
scales `hp_tensor` by `s` and returns a `Float8Tensor` of the result.
Autograd-aware, the derivative is pass-through.
DTensor-aware, if the input is a DTensor the output will be DTensor(Float8Tensor).
Args:
hp_tensor: the tensor to convert
s: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
return _ToFloat8ConstrFunc.apply(
hp_tensor, s, float8_dtype, linear_mm_config, gemm_input_role
)


class Float8Tensor(torch.Tensor):
"""
Note: this is **not** a public API and is only intended to be used
@@ -309,7 +328,7 @@ def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride
)

def to_original_precision(self):
return FromFloat8ConstrFunc.apply(self)
return _FromFloat8ConstrFunc.apply(self)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
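
As a quick, illustrative round-trip of the wrapper added above (mirroring
`test_preserves_dtype` in the tests further down; shape and dtype here are arbitrary):

```python
import torch

from float8_experimental.float8_tensor import hp_tensor_and_scale_to_float8
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

x_hp = torch.randn(4, 4, dtype=torch.bfloat16)
s = tensor_to_scale(x_hp, e4m3_dtype)

# scale and cast to a Float8Tensor, then cast back to the original precision
x_fp8 = hp_tensor_and_scale_to_float8(x_hp, s, e4m3_dtype)
assert x_fp8.to_original_precision().dtype == torch.bfloat16
```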
4 changes: 2 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -18,8 +18,8 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ToFloat8ConstrFunc,
)

from float8_experimental.float8_utils import e4m3_dtype, EPS
@@ -167,7 +167,7 @@ def __repr__(self):

def fsdp_pre_all_gather(self, mesh):
if self._precomputed_scale is not None:
float8_tensor = ToFloat8ConstrFunc.apply(
float8_tensor = hp_tensor_and_scale_to_float8(
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
6 changes: 3 additions & 3 deletions float8_experimental/inference.py
@@ -19,10 +19,10 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
tensor_already_casted_to_fp8,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale

@@ -127,7 +127,7 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype)
quantized_weight = ToFloat8ConstrFunc.apply(
quantized_weight = hp_tensor_and_scale_to_float8(
self.weight,
scale,
dtype,
@@ -200,7 +200,7 @@ def cast_to_float8_e4m3_inference(
if static_quantization_scale is not None
else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
)
return ToFloat8ConstrFunc.apply(
return hp_tensor_and_scale_to_float8(
inpt_tensor,
scale,
e4m3_dtype,
42 changes: 21 additions & 21 deletions test/test_base.py
@@ -28,9 +28,9 @@
from float8_experimental.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
ToFloat8ConstrFunc,
)
from float8_experimental.float8_utils import (
compute_error,
@@ -66,7 +66,7 @@ def test_preserves_dtype(self) -> None:
for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
x1_hp = torch.randn(4, 4, dtype=hp_dtype)
x1_s = tensor_to_scale(x1_hp, lp_dtype)
x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype)
x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
x3_hp = x2_lp.to_original_precision()
self.assertTrue(x3_hp.dtype == hp_dtype)

@@ -76,7 +76,7 @@ def test_differentiable_casts(self) -> None:
x = torch.randn(1).requires_grad_()
grad = torch.randn(1)
x_s = tensor_to_scale(x, f8_dtype)
x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype)
x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
x_f8_hp = x_f8.to_original_precision()
x_f8_hp.backward(grad)
# the gradient should be unchanged through both casts
Expand All @@ -85,7 +85,7 @@ def test_differentiable_casts(self) -> None:
def test_split_cat(self):
a = torch.rand(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(a, e4m3_dtype)
fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype)
fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

splits = torch.split(fp8_a, 16)
catted = torch.cat(splits, dim=0)
@@ -94,14 +94,14 @@ def test_split_cat(self):
def test_index_put(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

index = torch.randint(0, 15, (16,), dtype=torch.long)

b = torch.rand(16, 16, dtype=torch.bfloat16)
scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn)
fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

with self.assertRaises(AssertionError):
b[index] = fp8_a
@@ -112,7 +112,7 @@ def test_index_put(self):
def test_copy_(self):
a = torch.rand(16, dtype=torch.bfloat16)
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

b = torch.empty(16, dtype=torch.bfloat16)
b.copy_(fp8_a) # Should work
@@ -407,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype)
b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype)
a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

out_scaled_mm = addmm_float8_unwrapped(
a_fp8._data,
@@ -447,14 +447,14 @@ def test_different_configs_error(self):
ScaledMMConfig(True, False, False, False),
ScaledMMConfig(True, False, False, False),
)
a = ToFloat8ConstrFunc.apply(
a = hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
linear_config_a,
GemmInputRole.INPUT,
)
b = ToFloat8ConstrFunc.apply(
b = hp_tensor_and_scale_to_float8(
x_fp32,
x_scale,
fp8_dtype,
@@ -486,10 +486,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a_scale = tensor_to_scale(a, input_dtype).float()
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a, a_scale, input_dtype, None, GemmInputRole.INPUT
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
)

@@ -506,14 +506,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
scaled_mm_config, scaled_mm_config, scaled_mm_config
)

a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
pad_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -529,14 +529,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
emulated_scaled_mm_config,
emulated_scaled_mm_config,
)
a_fp8 = ToFloat8ConstrFunc.apply(
a_fp8 = hp_tensor_and_scale_to_float8(
a,
a_scale,
input_dtype,
emulated_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b_fp8 = hp_tensor_and_scale_to_float8(
b,
b_scale,
input_dtype,
@@ -695,19 +695,19 @@ def test_fp8_tensor_statistics(self):

# Overflow caused by a too large scaling factor
s_overflow = torch.tensor(1e9)
fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype)
fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

# Underflow caused by a too small scaling factor
s_underflow = torch.tensor(1e-9)
fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype)
fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

# Both overflow and underflow
x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
fp8_over_underflow = ToFloat8ConstrFunc.apply(
fp8_over_underflow = hp_tensor_and_scale_to_float8(
x2_hp, torch.tensor(1.0), lp_dtype
)
(zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)