Skip to content

Commit

Permalink
Add support asymmetric fake-quantization to AQTv2.
Browse files Browse the repository at this point in the history
Integration of native quantization with biases will require computing the cross terms, likely in the AQT operation quantizer (`DefaultGeneralQuantizer`).

Itemized changes:

- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound` to reflect that the calibration bound may span the whole quantization range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add `MinMaxCalibration`.

I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
  • Loading branch information
phoenix-meadowlark authored and copybara-github committed Sep 9, 2024
1 parent d2cfb75 commit 097249a
Show file tree
Hide file tree
Showing 19 changed files with 926 additions and 388 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ from aqt.jax.v2 import utils as aqt_utils
from aqt.jax.v2.numerics import int_numerics

q = aqt_quantizer.Quantizer(
numerics=int_numerics.IntNumerics(
numerics=int_numerics.SymIntNumerics(
bits=4,
preserve_zero=True,
preserve_max_val=True,
Expand Down
39 changes: 30 additions & 9 deletions aqt/jax/v2/aqt_conv_general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax.v2 import aqt_quantizer
Expand All @@ -28,6 +30,16 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32):
)


def _apply_po2_scale(quantizer):
calibration_cls = quantizer.calibration
keywords = {}
if isinstance(calibration_cls, functools.partial):
keywords = calibration_cls.keywords
calibration_cls = calibration_cls.func
keywords.update(po2_scale=True)
quantizer.calibration = functools.partial(calibration_cls, **keywords)


class AqtConvGeneralTest(parameterized.TestCase):

@parameterized.parameters([
Expand All @@ -48,13 +60,17 @@ def test_conv_general_dilated(
rhs_maxval=20.0,
seed=0,
):
dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)

dg_raw_conv = aqt_conv.conv_general_dilated_make(
2, lhs_bits, rhs_bits, initialize_calibration=False
)
# Power-of-2 scales allow FQ and AQT to be exactly the same.
dg_quantizer = dg_raw_conv.dg_quantizer
if dg_raw_conv.lhs:
# Power-of-2 scales allow FQ and AQT to be exactly the same.
dg_raw_conv.dg_quantizer.lhs.po2_scale = True
_apply_po2_scale(dg_quantizer.lhs)
dg_quantizer.lhs.init_calibration()
if dg_raw_conv.rhs:
dg_raw_conv.dg_quantizer.rhs.po2_scale = True
_apply_po2_scale(dg_quantizer.rhs)
dg_quantizer.rhs.init_calibration()

batch_n = 10
contr_n = 20
Expand Down Expand Up @@ -94,12 +110,17 @@ def test_conv_general_dilated_quantized(
seed=0,
):
"""Check that passing quantized lhs/rhs to aqt_conv_fn works."""
dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
dg_raw_conv = aqt_conv.conv_general_dilated_make(
2, lhs_bits, rhs_bits, initialize_calibration=False
)
# Power-of-2 scales allow FQ and AQT to be exactly the same.
dg_quantizer = dg_raw_conv.dg_quantizer
if dg_raw_conv.lhs:
# Power-of-2 scales allow FQ and AQT to be exactly the same.
dg_raw_conv.dg_quantizer.lhs.po2_scale = True
_apply_po2_scale(dg_quantizer.lhs)
dg_quantizer.lhs.init_calibration()
if dg_raw_conv.rhs:
dg_raw_conv.dg_quantizer.rhs.po2_scale = True
_apply_po2_scale(dg_quantizer.rhs)
dg_quantizer.rhs.init_calibration()

batch_n = 10
contr_n = 20
Expand Down
Loading

0 comments on commit 097249a

Please sign in to comment.