diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index fd76a8e..81a5c52 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -71,6 +71,62 @@ def _maybe_initialize_amaxes_scales_for_float8_cast( scale.copy_(new_scale) +# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files +@torch._dynamo.allow_in_graph +class manual_float8_matmul(torch.autograd.Function): + """ + Like torch.matmul, but with the arguments in float8 + """ + + @staticmethod + def forward( + ctx, + input_fp8, + weight_fp8_t, + ): + ctx.save_for_backward(input_fp8, weight_fp8_t) + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, orig_shape[-1]) + res_bits = torch.mm(input_fp8_reshaped, weight_fp8_t) + res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1]) + return res_bits + + @staticmethod + def backward(ctx, grad_output_fp8): + input_fp8, weight_fp8_t = ctx.saved_tensors + + # the reshapes are needed in order to make the shapes compatible with + # torch.mm + grad_output_fp8_orig_shape = grad_output_fp8.shape + grad_output_fp8_reshaped = grad_output_fp8.reshape( + -1, grad_output_fp8_orig_shape[-1] + ) + + # calculate grad_input + grad_input = torch.mm( + grad_output_fp8_reshaped, + weight_fp8_t.t(), + ) + grad_input = grad_input.reshape( + *grad_output_fp8_orig_shape[:-1], grad_input.shape[-1] + ) + + input_fp8_orig_shape = input_fp8.shape + input_fp8_reshaped = input_fp8.reshape(-1, input_fp8_orig_shape[-1]) + + # calculate grad_weight + # Note: the variant below is slightly faster on LLaMa 3 8B pretraining + # compared to than calculating `grad_weight_t = input_fp8_t @ grad_output_fp8_reshaped` + grad_weight = torch.mm( + grad_output_fp8_reshaped.t(), + input_fp8_reshaped, + ) + + return grad_input, grad_weight.t() + + @torch._dynamo.allow_in_graph class NoopFwToFloat8E5M2Bw(torch.autograd.Function): """ @@ -393,7 +449,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized) weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized) - output = torch.matmul(input_fp8, weight_fp8.t()) + output = manual_float8_matmul.apply(input_fp8, weight_fp8.t()) # Cast grad_output to float8_e5m2 during backward output = self.cast_output_to_float8_in_bw(output)