Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
add test cases for type cast
Browse files Browse the repository at this point in the history
  • Loading branch information
y-sq committed Nov 21, 2023
1 parent e70b4ed commit 1e41dc8
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,43 @@ def test_linear_float8_weight_tag(self):
m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))
assert m_fp8.weight._is_fp8_weight

@pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
emulate = (not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0))
x_shape = (16, 16)

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
self._test_linear_impl(x, m_ref, linear_type, emulate)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
m = Float8Linear.from_float(m, emulate)

# Cast the module to dtype
m = m.to(dtype=linear_dtype)

# autocast off
x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"

# autocast on
with torch.autocast("cuda"):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"

with torch.autocast("cuda", dtype=torch.bfloat16):
sync_float8_amax_and_scale_history(m)
y = m(x)
assert (
y.dtype == torch.bfloat16
), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"


class TestScaledMM:
@unittest.skipIf(
Expand Down

1 comment on commit 1e41dc8

@y-sq
Copy link
Contributor Author

@y-sq y-sq commented on 1e41dc8 Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also verified the new test cases failed on main branch.

Please sign in to comment.