diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 29a965a..428d5c9 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -207,7 +206,6 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=e4m3_dtype,
-        # amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
@@ -216,11 +215,8 @@ def forward(
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
             float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
             emulate: whether to emulate the matmuls in fp32
         """
-        # if amax_buffer is not None:
-        #     amax_buffer.fill_(tensor_to_amax(tensor))
         return to_fp8_no_autograd(
             tensor,