Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570055305
  • Loading branch information
Jake VanderPlas authored and pax authors committed Oct 2, 2023
1 parent 68bb8e8 commit 43db271
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion praxis/layers/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _get_optimal_scale_and_min(
contract_dims: int | Sequence[int],
) -> tuple[JTensor, JTensor | None]:
def quantization_error_and_scale(clipping):
q_scale, x_min = self._get_scale_and_min(
q_scale, x_min = self._get_scale_and_min( # pytype: disable=wrong-arg-types # jnp-type
x, contract_dims, clipping_coeff=clipping
)
x_scaled, zp_time_scale = self._scale(x, q_scale, x_min)
Expand Down

0 comments on commit 43db271

Please sign in to comment.