From 43db2717e0d2e9ef09b566f7e7bbad049d63dceb Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Oct 2023 07:05:12 -0700 Subject: [PATCH] [LSC] Ignore incorrect type annotations related to jax.numpy APIs PiperOrigin-RevId: 570055305 --- praxis/layers/quantization/quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/praxis/layers/quantization/quantizer.py b/praxis/layers/quantization/quantizer.py index 976853cd..c90e40fe 100644 --- a/praxis/layers/quantization/quantizer.py +++ b/praxis/layers/quantization/quantizer.py @@ -489,7 +489,7 @@ def _get_optimal_scale_and_min( contract_dims: int | Sequence[int], ) -> tuple[JTensor, JTensor | None]: def quantization_error_and_scale(clipping): - q_scale, x_min = self._get_scale_and_min( + q_scale, x_min = self._get_scale_and_min( # pytype: disable=wrong-arg-types # jnp-type x, contract_dims, clipping_coeff=clipping ) x_scaled, zp_time_scale = self._scale(x, q_scale, x_min)