From 1e41dc8ffa2331ad1556385e5497055998c44b8f Mon Sep 17 00:00:00 2001
From: y-sq
Date: Tue, 21 Nov 2023 11:59:45 -0800
Subject: [PATCH] add test cases for type cast

---
 test/test_base.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/test/test_base.py b/test/test_base.py
index 3f4f36d5..33a4b638 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -177,6 +177,43 @@ def test_linear_float8_weight_tag(self):
         m_fp8 = Float8Linear.from_float(copy.deepcopy(m_ref))
         assert m_fp8.weight._is_fp8_weight
 
+    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
+    @pytest.mark.parametrize(
+        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
+        emulate = (not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0))
+        x_shape = (16, 16)
+
+        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+        self._test_linear_impl(x, m_ref, linear_type, emulate)
+
+        m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
+        m = Float8Linear.from_float(m, emulate)
+
+        # Cast the module to dtype
+        m = m.to(dtype=linear_dtype)
+
+        # autocast off
+        x = torch.randn(16, 32, device="cuda", dtype=linear_dtype)
+        sync_float8_amax_and_scale_history(m)
+        y = m(x)
+        assert y.dtype == linear_dtype, f"y.dtype is {y.dtype}, expected {linear_dtype}"
+
+        # autocast on
+        with torch.autocast("cuda"):
+            sync_float8_amax_and_scale_history(m)
+            y = m(x)
+        assert y.dtype == torch.half, f"y.dtype is {y.dtype}, expected {torch.half}"
+
+        with torch.autocast("cuda", dtype=torch.bfloat16):
+            sync_float8_amax_and_scale_history(m)
+            y = m(x)
+        assert (
+            y.dtype == torch.bfloat16
+        ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
+
 
 class TestScaledMM:
     @unittest.skipIf(
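
The output-dtype assertions in the new test follow standard torch.autocast semantics for linear layers. A minimal standalone sketch with a plain nn.Linear (no Float8Linear; assumes a CUDA device is available) shows the same expected dtypes:

    import torch
    import torch.nn as nn

    # Plain nn.Linear under autocast; illustrates the output-dtype expectations
    # asserted by test_type_cast (sketch only, assumes CUDA is available).
    m = nn.Linear(32, 16, device="cuda", dtype=torch.float32)
    x = torch.randn(16, 32, device="cuda", dtype=torch.float32)

    # autocast off: output keeps the module/input dtype
    y = m(x)
    assert y.dtype == torch.float32

    # autocast on with the default dtype: linear ops run in float16
    with torch.autocast("cuda"):
        y = m(x)
    assert y.dtype == torch.half

    # autocast on with an explicit dtype: linear ops run in bfloat16
    with torch.autocast("cuda", dtype=torch.bfloat16):
        y = m(x)
    assert y.dtype == torch.bfloat16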