Add support for asymmetric fake-quantization to AQTv2.
Integration of native quantization with biases will require computing the cross terms, likely in the AQT operation quantizer (`DefaultGeneralQuantizer`).
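
For reference, a sketch of where those cross terms come from: with an affine (asymmetric) scheme, `x ≈ s_x * (q_x - z_x)` and `w ≈ s_w * (q_w - z_w)`, so a matmul expands into one integer-only product plus three zero-point terms. A minimal NumPy illustration (the affine mapping and names here are illustrative, not AQT's API):

```python
import numpy as np

def minmax_affine(t, bits=8):
  # Map [t.min(), t.max()] affinely onto the integer grid [0, 2**bits - 1].
  qmin, qmax = 0, 2**bits - 1
  scale = (t.max() - t.min()) / (qmax - qmin)
  zero_point = np.round(-t.min() / scale)
  q = np.clip(np.round(t / scale) + zero_point, qmin, qmax)
  return q, scale, zero_point

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))
w = rng.normal(size=(8, 2))
qx, sx, zx = minmax_affine(x)
qw, sw, zw = minmax_affine(w)

# Expanding (sx*(qx - zx)) @ (sw*(qw - zw)) gives one integer-only
# matmul plus three zero-point "cross terms".
n = x.shape[1]
cross = (
    -zw * qx.sum(axis=1, keepdims=True)   # shape (4, 1), broadcasts
    - zx * qw.sum(axis=0, keepdims=True)  # shape (1, 2), broadcasts
    + n * zx * zw
)
expanded = sx * sw * (qx @ qw + cross)
fake_quant = (sx * (qx - zx)) @ (sw * (qw - zw))
assert np.allclose(expanded, fake_quant)       # the identity is exact
assert np.allclose(expanded, x @ w, atol=0.5)  # loose tolerance on the approximation
```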

Itemized changes:

- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound` to reflect that the calibration bound may span the whole quantization range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add `MinMaxCalibration`; a minimal sketch of the idea follows this list.
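
A sketch of what min-max calibration computes, assuming a plain affine mapping onto the integer grid (illustrative only, not `MinMaxCalibration`'s actual interface):

```python
import jax.numpy as jnp

def minmax_scale_and_zero_point(x, bits, axis=None):
  """Maps [min(x), max(x)] onto the asymmetric grid [0, 2**bits - 1]."""
  # Unlike symmetric calibration, the bound spans the whole quantization
  # range, and real zero needs no special flag: the zero point absorbs it.
  x_min = jnp.min(x, axis=axis, keepdims=True)
  x_max = jnp.max(x, axis=axis, keepdims=True)
  scale = (x_max - x_min) / (2**bits - 1)
  zero_point = jnp.round(-x_min / scale)
  return scale, zero_point
```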

I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
phoenix-meadowlark authored and copybara-github committed Aug 29, 2024
1 parent b391500 commit cef287a
Showing 19 changed files with 908 additions and 394 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -169,7 +169,7 @@ from aqt.jax.v2 import utils as aqt_utils
 from aqt.jax.v2.numerics import int_numerics

 q = aqt_quantizer.Quantizer(
-    numerics=int_numerics.IntNumerics(
+    numerics=int_numerics.SymIntNumerics(
         bits=4,
         preserve_zero=True,
         preserve_max_val=True,
28 changes: 19 additions & 9 deletions aqt/jax/v2/aqt_conv_general_test.py
@@ -15,6 +15,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 from aqt.jax.v2 import aqt_quantizer
+from aqt.jax.v2 import config
 import aqt.jax.v2.aqt_conv_general as aqt_conv
 import flax.linen.linear as fl
 import jax
@@ -48,13 +49,17 @@ def test_conv_general_dilated(
       rhs_maxval=20.0,
       seed=0,
   ):
-    dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
-
+    dg_raw_conv = aqt_conv.conv_general_dilated_make(
+        2, lhs_bits, rhs_bits, initialize_calibration=False
+    )
+    # Power-of-2 scales allow FQ and AQT to be exactly the same.
+    dg_quantizer = dg_raw_conv.dg_quantizer
     if dg_raw_conv.lhs:
-      # Power-of-2 scales allow FQ and AQT to be exactly the same.
-      dg_raw_conv.dg_quantizer.lhs.po2_scale = True
+      config.set_quantizer_calibration_config(dg_quantizer.lhs, po2_scale=True)
+      dg_quantizer.lhs.init_calibration()
     if dg_raw_conv.rhs:
-      dg_raw_conv.dg_quantizer.rhs.po2_scale = True
+      config.set_quantizer_calibration_config(dg_quantizer.rhs, po2_scale=True)
+      dg_quantizer.rhs.init_calibration()

     batch_n = 10
     contr_n = 20
@@ -94,12 +99,17 @@ def test_conv_general_dilated_quantized(
       seed=0,
   ):
     """Check that passing quantized lhs/rhs to aqt_conv_fn works."""
-    dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
+    dg_raw_conv = aqt_conv.conv_general_dilated_make(
+        2, lhs_bits, rhs_bits, initialize_calibration=False
+    )
+    # Power-of-2 scales allow FQ and AQT to be exactly the same.
+    dg_quantizer = dg_raw_conv.dg_quantizer
     if dg_raw_conv.lhs:
-      # Power-of-2 scales allow FQ and AQT to be exactly the same.
-      dg_raw_conv.dg_quantizer.lhs.po2_scale = True
+      config.set_quantizer_calibration_config(dg_quantizer.lhs, po2_scale=True)
+      dg_quantizer.lhs.init_calibration()
     if dg_raw_conv.rhs:
-      dg_raw_conv.dg_quantizer.rhs.po2_scale = True
+      config.set_quantizer_calibration_config(dg_quantizer.rhs, po2_scale=True)
+      dg_quantizer.rhs.init_calibration()

     batch_n = 10
     contr_n = 20
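
A note on the repeated "Power-of-2 scales allow FQ and AQT to be exactly the same" comment: when a scale is a power of two, multiplying or dividing a float by it only changes the exponent, so the fake-quantized (FQ) path and the true integer AQT path round identically and can be compared bit-for-bit. An illustrative po2 rounding (the exact rule AQT applies may differ):

```python
import jax.numpy as jnp

def round_scale_to_po2(scale):
  # Snap a positive scale up to the nearest power of two, so the
  # quantize/dequantize round trip adds no extra rounding error.
  return 2.0 ** jnp.ceil(jnp.log2(scale))
```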