From 33de2e95a4a50e1a313310a15552be9a09501891 Mon Sep 17 00:00:00 2001
From: Phoenix Meadowlark
Date: Thu, 11 Jul 2024 17:07:03 -0700
Subject: [PATCH] Add support for asymmetric fake-quantization to AQTv2.

Integration of native quantization with biases will require computing the
cross terms, likely in the AQT operation quantizer
(`DefaultGeneralQuantizer`).

Itemized changes:

- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound`
    to reflect that the calibration bound may span the whole quantization
    range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and
    `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add support for biases to `QTensor`.
- Add `MinMaxCalibration`.
- `flax_e2e_model`:
  - Add support for passing a dataset to `flax_e2e_model.train_and_evaluate`
    (this saves time when comparing configurations).
  - Make statistics printing easier to read.

I additionally tested this change by training MNIST models using
`flax_e2e_model`. With symmetric quantization the model fails to converge for
`config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN`
losses). With asymmetric quantization the model converges even with
`config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
---
 README.md                                  |   2 +-
 aqt/jax/v2/aqt_dot_general.py              |  18 ++-
 aqt/jax/v2/aqt_dot_general_test.py         | 127 ++++++++++++--
 aqt/jax/v2/aqt_quantizer.py                |  62 +++++++-
 aqt/jax/v2/aqt_tensor.py                   |  34 ++++-
 aqt/jax/v2/calibration.py                  |  39 ++++-
 aqt/jax/v2/config.py                       |  83 ++++++++++-
 aqt/jax/v2/config_test.py                  |   1 +
 aqt/jax/v2/examples/flax_e2e_model.py      |  31 ++--
 aqt/jax/v2/examples/flax_e2e_model_test.py |  91 +++++++++---
 .../gptq/examples/gptq_flax_e2e_model.py   |   4 +-
 .../gptq/gptq_dot_general_quantizer.py     |   4 +-
 aqt/jax/v2/flax/aqt_flax.py                |   9 ++
 aqt/jax/v2/numerics/fp8_numerics.py        |   2 +-
 aqt/jax/v2/numerics/fp_numerics.py         |   4 +-
 aqt/jax/v2/numerics/int_numerics.py        | 140 ++++++++++++------
 aqt/jax/v2/numerics/no_numerics.py         |   2 +-
 aqt/jax/v2/numerics/numerics.py            |  13 +-
 18 files changed, 532 insertions(+), 134 deletions(-)

diff --git a/README.md b/README.md
index b00475a7..64376f08 100644
--- a/README.md
+++ b/README.md
@@ -169,7 +169,7 @@ from aqt.jax.v2 import utils as aqt_utils
 from aqt.jax.v2.numerics import int_numerics
 
 q = aqt_quantizer.Quantizer(
-    numerics=int_numerics.IntNumerics(
+    numerics=int_numerics.SymIntNumerics(
         bits=4,
         preserve_zero=True,
         preserve_max_val=True,
diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py
index 56f7849d..584e8980 100644
--- a/aqt/jax/v2/aqt_dot_general.py
+++ b/aqt/jax/v2/aqt_dot_general.py
@@ -574,6 +574,11 @@ def _maybe_use_fwd_quant(
   )
   if use_fwd_quant:
     assert fwd_quantized, msg
+    if rhs.qx.bias is not None:
+      raise NotImplementedError(
+          'Quantization biases are not supported in forward quantization.'
+ ) + scale_t = transpose.rhs_scale_transpose_for_lhs_input( rhs.qx.scale[0], dimension_numbers, lhs.shape ) @@ -664,6 +669,17 @@ def __call__( self.allow_dummy_gradient_into_qtensor ) + msg = ( + 'biases are only supported in fake quant mode, but got a {arg} bias ' + 'and self.{arg}.dequant_mode == {mode} != DequantMode.THIS_INPUT' + ) + assert ( + lhs_qt.bias is None or self.lhs.dequant_mode == DequantMode.THIS_INPUT + ), msg.format(arg='lhs', mode=self.lhs.dequant_mode) + assert ( + rhs_qt.bias is None or self.rhs.dequant_mode == DequantMode.THIS_INPUT + ), msg.format(arg='rhs', mode=self.rhs.dequant_mode) + lhs_mt = MultiTensor(x=lhs, qx=lhs_qt) lhs_res = TensorRes(mt=lhs_mt, quant_grad=lhs_quant_grad) @@ -869,7 +885,7 @@ def assert_config_validity(self: Self): expected_fwd_quant = False msg_fwd_quant = ( f'use_fwd_quant should be set to {expected_fwd_quant} when remaining' - ' axis are used for calibration axis.' + ' axis are used for calibration axis. ' ) if self.fwd.rhs.calibration_mode == CalibrationMode.REMAINING_AXIS: diff --git a/aqt/jax/v2/aqt_dot_general_test.py b/aqt/jax/v2/aqt_dot_general_test.py index a9bdaa10..58fa3ef7 100644 --- a/aqt/jax/v2/aqt_dot_general_test.py +++ b/aqt/jax/v2/aqt_dot_general_test.py @@ -163,7 +163,7 @@ class _TrickyNumerics(numerics.AqtNumerics, flax.struct.PyTreeNode): def get_dtype(self): return self.dtype - def abs_val_mapped_to(self) -> jnp.ndarray: + def get_scaled_bound(self) -> jnp.ndarray: return jnp.array(1.0) def fwd(self, x, context): @@ -191,6 +191,7 @@ def _modify_dg( fwd_lhs_tricky_clip_and_round: bool = False, local_aqt: aqt.LocalAqt | None = None, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> aqt.DotGeneral: dg = copy.deepcopy(readonly_dg) if fwd_lhs_tricky_clip_and_round: @@ -242,11 +243,11 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True): # that the scales are not too large. def disable_quant(c): _disable_quant_types(c) - if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.IntNumerics): + if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.BaseIntNumerics): c.dg_quantizer.lhs.numerics = ( c.dg_quantizer.lhs.numerics.replace(round=False) ) - if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.IntNumerics): + if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.BaseIntNumerics): c.dg_quantizer.rhs.numerics = ( c.dg_quantizer.rhs.numerics.replace(round=False) ) @@ -271,15 +272,18 @@ def disable_quant(c): dg.drhs.local_aqt = local_aqt # When using abs-max scaling, this should be a no-op. 
- if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.IntNumerics): + if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.SymIntNumerics): dg.fwd.dg_quantizer.lhs.numerics = ( dg.fwd.dg_quantizer.lhs.numerics.replace(clip_gradient=clip_gradient) ) - if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics): + if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics): dg.fwd.dg_quantizer.rhs.numerics = ( dg.fwd.dg_quantizer.rhs.numerics.replace(clip_gradient=clip_gradient) ) + if use_asymmetric: + config.set_asymmetric_quantization(dg) + return dg @@ -295,6 +299,7 @@ def _aqt_dg_full_lr_diff( readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: dg = _modify_dg( readonly_dg, @@ -306,6 +311,7 @@ def _aqt_dg_full_lr_diff( fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round, local_aqt=local_aqt, clip_gradient=clip_gradient, + use_asymmetric=use_asymmetric, ) dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None) return lambda lhs, rhs: dg(lhs, rhs, dims) @@ -321,6 +327,7 @@ def _aqt_dg_full( readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: return _aqt_dg_full_lr_diff( dequant_mode, @@ -332,7 +339,8 @@ def _aqt_dg_full( local_aqt, readonly_dg=readonly_dg, dims=dims, - clip_gradient=clip_gradient + clip_gradient=clip_gradient, + use_asymmetric=use_asymmetric, ) @@ -344,6 +352,7 @@ def _aqt_dg_raw_lr_diff( *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: dg = _modify_dg( readonly_dg, @@ -351,6 +360,7 @@ def _aqt_dg_raw_lr_diff( rhs_dequant_mode=rhs_dequant_mode, lhs_calibration_mode=lhs_calibration_mode, rhs_calibration_mode=rhs_calibration_mode, + use_asymmetric=use_asymmetric, ) dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None) dg.fwd.dg_quantizer.init_calibration() @@ -363,6 +373,7 @@ def _aqt_dg_raw( *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: return _aqt_dg_raw_lr_diff( dequant_mode, @@ -371,6 +382,7 @@ def _aqt_dg_raw( calibration_mode, readonly_dg=readonly_dg, dims=dims, + use_asymmetric=use_asymmetric, ) @@ -389,7 +401,7 @@ def test_empty(self): def test_fq_noise(self, preserve_zero, prec, v, seed): key = jax.random.PRNGKey(seed) quantizer = config.quantizer_make(prec) - if isinstance(quantizer.numerics, int_numerics.IntNumerics): + if isinstance(quantizer.numerics, int_numerics.SymIntNumerics): quantizer.numerics.preserve_zero = preserve_zero if not preserve_zero: quantizer.numerics.dtype = None @@ -541,6 +553,24 @@ def test_dot_general_calibration_with_contracting_axis( dtype=jnp.float32, clip_gradient=False, ): + # This should be removed once asymmetric quant supports use_fwd_quant. 
+ test_asym = not any([ + dg.fwd.lhs.use_fwd_quant, + dg.fwd.rhs.use_fwd_quant, + dg.dlhs.lhs.use_fwd_quant, + dg.dlhs.rhs.use_fwd_quant, + dg.drhs.lhs.use_fwd_quant, + dg.drhs.rhs.use_fwd_quant, + ]) + is_quantized = not all([ + isinstance(dg.fwd.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + isinstance(dg.dlhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.dlhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + isinstance(dg.drhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.drhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + ]) + readonly_dg = dg del dg @@ -555,9 +585,25 @@ def test_dot_general_calibration_with_contracting_axis( dims=dims, clip_gradient=clip_gradient, ) + asym_dg_full = functools.partial( + _aqt_dg_full, + readonly_dg=readonly_dg, + dims=dims, + clip_gradient=clip_gradient, + # As an argument to _modify_dg this must be None, not False. + # Unrelated things happen when False. + use_fwd_quant=None, + use_asymmetric=True, + ) aqt_dg_raw = functools.partial( _aqt_dg_raw, readonly_dg=readonly_dg, dims=dims ) + asym_dg_raw = functools.partial( + _aqt_dg_raw, + readonly_dg=readonly_dg, + dims=dims, + use_asymmetric=True, + ) modify_dg = functools.partial(_modify_dg, readonly_dg=readonly_dg) check = functools.partial(_check_result_eq, lhs=lhs, rhs=rhs, gra=gra) @@ -593,19 +639,57 @@ def test_dot_general_calibration_with_contracting_axis( dict(test_gradient=False), ), ]) + if test_asym: + check([ + ("default ", asym_dg_full(aqt.DequantMode.OUTPUT), dict()), + ("FQ ", asym_dg_full(aqt.DequantMode.THIS_INPUT), dict()), + ( + "raw fwd ", + asym_dg_raw(aqt.DequantMode.OUTPUT), + dict(test_gradient=False), + ), + ( + "raw fwd FQ ", + asym_dg_raw(aqt.DequantMode.THIS_INPUT), + dict(test_gradient=False), + ), + ]) check([ ( - "fwd_quant=T", + "fwd_quant=F", aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=False), dict(), ), ( - "fwd_quant=F", + "fwd_quant=T", aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=True), dict(), ), ]) + if test_asym and is_quantized: + # Asymmetric quantization does not currently support forward quantization. 
+ with self.assertRaisesRegex(NotImplementedError, r"biases.*forward"): + check([ + ( + "fwd_quant=F", + aqt_dg_full( + aqt.DequantMode.OUTPUT, + use_fwd_quant=False, + use_asymmetric=True, + ), + dict(), + ), + ( + "fwd_quant=T", + aqt_dg_full( + aqt.DequantMode.OUTPUT, + use_fwd_quant=True, + use_asymmetric=True, + ), + dict(), + ), + ]) check([ ( @@ -617,7 +701,7 @@ def test_dot_general_calibration_with_contracting_axis( dict(), ), ( - "default ", + "FQ ", aqt_dg_full( aqt.DequantMode.THIS_INPUT, local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), @@ -625,10 +709,29 @@ def test_dot_general_calibration_with_contracting_axis( dict(), ), ]) + if test_asym: + check([ + ( + "default ", + asym_dg_full( + aqt.DequantMode.OUTPUT, + local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), + ), + dict(), + ), + ( + "FQ ", + asym_dg_full( + aqt.DequantMode.THIS_INPUT, + local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), + ), + dict(), + ), + ]) if isinstance( readonly_dg.fwd.dg_quantizer.lhs.numerics, - int_numerics.IntNumerics, + int_numerics.SymIntNumerics, ): check([ ( @@ -1057,7 +1160,7 @@ def test_local_aqt(self, shard_count, lhs, expected_product): def test_per_tensor(self): # TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes bits = 4 - my_numerics = int_numerics.IntNumerics( + my_numerics = int_numerics.SymIntNumerics( bits=bits, preserve_zero=True, preserve_max_val=False, diff --git a/aqt/jax/v2/aqt_quantizer.py b/aqt/jax/v2/aqt_quantizer.py index 338be803..44c17b1f 100644 --- a/aqt/jax/v2/aqt_quantizer.py +++ b/aqt/jax/v2/aqt_quantizer.py @@ -13,7 +13,7 @@ # limitations under the License. """Configuration dataclasses.""" -from typing import Literal, Sequence +from typing import Callable, Literal, Sequence from aqt.jax.v2 import aqt_tensor from aqt.jax.v2 import calibration from aqt.jax.v2 import utils @@ -25,7 +25,9 @@ AbstractAqtNumerics = numerics.AqtNumerics +BaseIntNumerics = int_numerics.BaseIntNumerics AbstractAqtCalibration = calibration.Calibration +Axes = Sequence[utils.AxisIdx] @utils.flax_slots_kw_only_dataclass @@ -33,9 +35,7 @@ class Quantizer: """Configuration of quantization of one tensor.""" numerics: AbstractAqtNumerics = utils.static_field() - calib_shared_axes: Sequence[utils.AxisIdx] | Literal["per_tensor"] | None = ( - utils.static_field() - ) + calib_shared_axes: Axes | Literal["per_tensor"] | None = utils.static_field() scale_stop_grad: bool = utils.static_field() # noise+clip+round # We apply gradient of clip_and_round in bwd pass. @@ -48,6 +48,10 @@ class Quantizer: scale_dtype: jnp.dtype | None = utils.static_field(default=None) # TODO(yichizh): Factor out auxilliary dataclasses into a separate file. context: utils.Context + calculate_bias: ( + Callable[[jnp.ndarray, jnp.ndarray, Axes, BaseIntNumerics], jnp.ndarray] + | None + ) = utils.static_field(default=None) # we need to speed up this initialization for the backward pass to happen # outside of bwd pass. @@ -55,6 +59,23 @@ def init_calibration(self): assert self._calibrator is None, "second call to self.init_calibration()" self._calibrator = self.calibration() + def _validate_asymmetric_config(self): + # DO_NOT_SUBMIT: As things stand a check like this is necessary to prevent + # configurations that silently produce bad numerical results, but it's + # not pretty. 
+ if not ( + isinstance(self.numerics, int_numerics.AsymIntNumerics) + == (self.calculate_bias is not None) + == isinstance(self._calibrator, calibration.MinMaxCalibration) + ): + raise ValueError( + "Asymmetric quantization requires self.numerics=AsymIntNumerics," + " self.calculate_bias is not None, and" + " self.calibrator=MinMaxCalibration, but got the following partial" + f" configuration: {self.numerics=}, {self.calculate_bias=}, and" + f" {self._calibrator=}" + ) + # TODO(yichizh): Need to add type annotation back to cfg. def quant( self, @@ -71,7 +92,7 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor: """Create incomplete QTensor with only quantization parameters.""" if isinstance(self.numerics, no_numerics.NoNumerics): qt = aqt_tensor.QTensor( - qvalue=x, scale=[], scale_t=None, dequant_dtype=x.dtype + qvalue=x, scale=[], scale_t=None, bias=None, dequant_dtype=x.dtype ) return qt @@ -85,9 +106,11 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor: shared_axes = self.calib_shared_axes or calibration_axes assert self._calibrator is not None, "forgot self.init_calibration()?" + self._validate_asymmetric_config() + bound = self._calibrator.get_bound(x, shared_axes, self.context) - abs_max_mapped_to = self.numerics.abs_val_mapped_to() - scale = bound / abs_max_mapped_to + scaled_bound = self.numerics.get_scaled_bound() + scale = bound / scaled_bound if self.po2_scale: # With floor the biggest value (we are using jnp.max) is in the range of @@ -101,10 +124,16 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor: if self.scale_dtype is not None: scale = scale.astype(self.scale_dtype) + if self.calculate_bias is not None: + bias = [self.calculate_bias(x, scale, shared_axes, self.numerics)] + else: + bias = None + qt = aqt_tensor.QTensor( qvalue=None, scale=[scale], scale_t=None, + bias=bias, dequant_dtype=dequant_dtype, ) return qt @@ -131,6 +160,23 @@ def calculate_qvalue( return qt, quant_grad +def calculate_asymmetric_bias( + x: jnp.ndarray, + scale: jnp.ndarray, + shared_axes: Axes, + numerics_: BaseIntNumerics, +) -> jnp.ndarray: + """Calculates the bias for asymmetric quantization.""" + if not isinstance(numerics_, int_numerics.AsymIntNumerics): + raise NotImplementedError( + "calculate_signed_asymmetric_bias only supports " + f" AsymIntNumerics, but got {numerics}" + ) + # Calcualte bias s.t. quant(min(x)) = (min(x) + bias) / scale = quant_min. + quant_min, _ = numerics_.get_quant_range() + return quant_min * scale - jnp.min(x, axis=shared_axes, keepdims=True) + + def quantizer_make( n_bits: int | None, preserve_max_val: bool = False, @@ -142,7 +188,7 @@ def quantizer_make( else: pz = False if n_bits == 1 else True dtype = utils.infer_dtype_from_bits(n_bits) if pz else None - effective_numerics = int_numerics.IntNumerics( + effective_numerics = int_numerics.SymIntNumerics( bits=n_bits, preserve_zero=pz, preserve_max_val=preserve_max_val, diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index 703a2447..b8c6b314 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -22,6 +22,7 @@ # pylint: disable=g-explicit-bool-comparison # pylint: disable=g-explicit-length-test +import itertools import typing from typing import Any, Callable, Optional, Sequence, TypeAlias from aqt.jax.v2 import utils @@ -69,6 +70,11 @@ class QTensor: # TODO(lew): Move scale_t from QTensor to some dot-general specific type? 
scale_t: Optional[list[ArrayT]] + # (bias == None) means that bias should not be applied; + # If bias is not None, it should have the same length as scale. + # Biases operate on the scale of the unquatnized tesnsors. + bias: Optional[list[ArrayT]] = utils.static_field(default=None) + # DType of the tensor before quantized. # NOTE: AQT Users should use the public property, dtype, instead. dequant_dtype: Optional[jnp.dtype] = flax.struct.field( @@ -97,13 +103,20 @@ def quant(self, x): """Quantizes the QTensor.""" assert not self.is_full(), 'Already quantized QTensor.' assert self.scale is not None, 'Missing scales to be used for quantization.' + assert self.bias is None or len(self.bias) == len( + self.scale + ), 'self.bias must be None or have the same length as self.scale.' qvalue = x - for s in self.scale: + bias = [] if self.bias is None else self.bias + for s, b in itertools.zip_longest(self.scale, bias): + # quant(x) = (x + b) / s + if b is not None: + qvalue += b # TODO(lew): We could store s_inv for faster activation quantization. s_inv = jax.lax.reciprocal(s) s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv) - qvalue = qvalue * s_inv + qvalue *= s_inv # TODO(lew): We should apply numerics here, so that 'quant' function # Can be considered a part of API. @@ -112,18 +125,27 @@ def quant(self, x): def dequant(self) -> jnp.ndarray: """Dequantizes the QTensor.""" assert self.scale is not None, 'Missing scales when dequantizing a QTensor.' + assert self.bias is None or len(self.bias) == len( + self.scale + ), 'self.bias must be None or have the same length as self.scale.' msg = ( 'QTensor is manually created without setting a dequant_detype. It can' ' be used in dot_general, but to dequantize you need to set its dtype.' ) assert self.dequant_dtype is not None, msg assert self.is_full(), _MSG_NO_QVALUE + # pytype: disable=attribute-error ret = self.qvalue.astype(self.dequant_dtype) - for scale in self.scale: - ret = ret * scale - # In case the scale dtype is not the same as dequant_dtype, and it is a - # higher precision. + bias = [] if self.bias is None else self.bias + for s, b in itertools.zip_longest(self.scale, bias): + # dequant(x) = x * s - b + ret *= s + if b is not None: + ret -= b + + # In case the scale or bias dtypes are not the same as dequant_dtype, and it + # is a higher precision. ret = ret.astype(self.dequant_dtype) # pytype: enable=attribute-error return ret # pytype: disable=bad-return-type diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index 3497f89c..f9f4032b 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -58,7 +58,7 @@ class AbsMaxCalibration(Calibration): Attributes: scale: Set it to something like 0.3, 0.1, 0.03. If scale < 1.0, setting - IntNumerics.clip_gradient=True is likely to be important. + SymIntNumerics.clip_gradient=True is likely to be important. """ scale: float | None = None @@ -79,11 +79,44 @@ def get_bound( assert shared_axes is not None, msg # NOTE: If you want to clip, consider using clip and clip_gradient in - # int_numerics.IntNumerics. + # int_numerics.SymIntNumerics. 
abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True) # TODO(yichizh): the zero filtering is not needed anymore because inf is - # filtered when calculating the reciprocal of scaline factor + # filtered when calculating the reciprocal of scaling factor abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max) if self.scale is not None: abs_max = abs_max * self.scale return abs_max.astype(x.dtype) + + +@utils.flax_slots_kw_only_dataclass +class MinMaxCalibration(Calibration): + """Calibration between the min and max values. + + Attributes: + eps: Optional epsilon to add to the bound to avoid division by zero. + """ + + eps: float | None = None + + def get_bound( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + context: utils.Context | None = None, + ) -> jnp.ndarray: + """Calibration.""" + del context + + msg = ( + 'Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot' + ' to set them.' + ) + assert shared_axes is not None, msg + + x_min = jnp.min(x, axis=shared_axes, keepdims=True) + x_max = jnp.max(x, axis=shared_axes, keepdims=True) + bound = x_max - x_min + if self.eps is not None: + bound += self.eps + return bound.astype(x.dtype) diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py index 1dd0a488..46d984bd 100644 --- a/aqt/jax/v2/config.py +++ b/aqt/jax/v2/config.py @@ -89,16 +89,35 @@ def set_dg_raw_context(cfg_raw: DotGeneralRaw, key: Optional[jax.Array]): return ret_cfg -def set_fwd_dequant_mode( - cfg: DotGeneral, +def set_dequant_mode( + cfg: DotGeneralRaw, *, lhs_dequant_mode: Optional[DequantMode] = None, rhs_dequant_mode: Optional[DequantMode] = None, ): + """Sets the dequant mode for the lhs and rhs of a single dot general.""" if lhs_dequant_mode is not None: - cfg.fwd.lhs.dequant_mode = lhs_dequant_mode + cfg.lhs.dequant_mode = lhs_dequant_mode if rhs_dequant_mode is not None: - cfg.fwd.rhs.dequant_mode = rhs_dequant_mode + cfg.rhs.dequant_mode = rhs_dequant_mode + + fake_quant = DequantMode.THIS_INPUT in [lhs_dequant_mode, rhs_dequant_mode] + if fake_quant and jnp.issubdtype(cfg.dg_accumulator_dtype, jnp.integer): + # Fake-quantization is not compatible with integer accumulation. 
+ cfg.dg_accumulator_dtype = None + + +def set_fwd_dequant_mode( + cfg: DotGeneral, + *, + lhs_dequant_mode: Optional[DequantMode] = None, + rhs_dequant_mode: Optional[DequantMode] = None, +): + set_dequant_mode( + cfg.fwd, + lhs_dequant_mode=lhs_dequant_mode, + rhs_dequant_mode=rhs_dequant_mode, + ) def set_fwd_calibration_mode( @@ -265,7 +284,7 @@ def set_int_numerics_preserve_zero(cfg: DotGeneral, preserve_zero: bool): for dot_general_raw in [cfg.fwd, cfg.dlhs, cfg.drhs]: dg_quantizer = dot_general_raw.dg_quantizer for q_numerics in [dg_quantizer.lhs.numerics, dg_quantizer.rhs.numerics]: - if isinstance(q_numerics, int_numerics.IntNumerics): + if isinstance(q_numerics, int_numerics.SymIntNumerics): q_numerics.preserve_zero = preserve_zero updated_dtype = ( utils.infer_dtype_from_bits(q_numerics.bits) # pytype: disable=attribute-error @@ -301,7 +320,7 @@ def set_absmax_calib_scale(cfg: DotGeneral, scale: float): calibration.AbsMaxCalibration, scale=scale ) if scale < 1.0 and isinstance( - quantizer.numerics, int_numerics.IntNumerics + quantizer.numerics, int_numerics.SymIntNumerics ): quantizer.numerics.clip_gradient = True @@ -330,7 +349,7 @@ def get_numerics(bits): else: pz = False if bits == 1 else True dtype = utils.infer_dtype_from_bits(bits) if pz else None - effective_numerics = int_numerics.IntNumerics( + effective_numerics = int_numerics.SymIntNumerics( bits=bits, preserve_zero=pz, preserve_max_val=False, @@ -355,6 +374,56 @@ def get_numerics(bits): return cfg +def _set_asymmetric_quantization(cfg: DotGeneralRaw): + """Replaces symmetric quantization with asymmetric quantization.""" + + def get_asym_numerics(numerics_: numerics.AqtNumerics): + if isinstance(numerics_, int_numerics.BaseIntNumerics): + # pytype: disable=attribute-error + return int_numerics.AsymIntNumerics( + bits=numerics_.bits, + clip=numerics_.clip, + clip_gradient=numerics_.clip_gradient, + round=numerics_.round, + noise_fn=numerics_.noise_fn, + dtype=numerics_.dtype, + ) + # pytype: enable=attribute-error + elif isinstance(numerics_, no_numerics.NoNumerics): + return numerics_ + else: + raise NotImplementedError( + 'Asymmetric quantization currently only supports integer numerics,' + f' but got {numerics_}' + ) + + set_numerics( + cfg, + get_asym_numerics(cfg.dg_quantizer.lhs.numerics), + get_asym_numerics(cfg.dg_quantizer.rhs.numerics), + ) + + cfg.dg_quantizer.lhs.calibration = calibration.MinMaxCalibration + cfg.dg_quantizer.rhs.calibration = calibration.MinMaxCalibration + + cfg.dg_quantizer.lhs.calculate_bias = aqt_quantizer.calculate_asymmetric_bias + cfg.dg_quantizer.rhs.calculate_bias = aqt_quantizer.calculate_asymmetric_bias + + # Only fake quantization currently supports quantization biases. + set_dequant_mode( + cfg, + lhs_dequant_mode=DequantMode.THIS_INPUT, + rhs_dequant_mode=DequantMode.THIS_INPUT, + ) + + +def set_asymmetric_quantization(cfg: DotGeneral): + """Replaces symmetric quantization with asymmetric quantization.""" + _set_asymmetric_quantization(cfg.fwd) + _set_asymmetric_quantization(cfg.dlhs) + _set_asymmetric_quantization(cfg.drhs) + + def set_scale_dtype(cfg: DotGeneral, scale_dtype: jnp.dtype): """Set the dtype for all scales in the given DotGeneral config.""" assert isinstance( diff --git a/aqt/jax/v2/config_test.py b/aqt/jax/v2/config_test.py index 2615c2aa..52879256 100644 --- a/aqt/jax/v2/config_test.py +++ b/aqt/jax/v2/config_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +# DO_NOT_SUBMIT: Update once an API has been finalized. """Test for AQT configs.""" from absl.testing import absltest diff --git a/aqt/jax/v2/examples/flax_e2e_model.py b/aqt/jax/v2/examples/flax_e2e_model.py index 725843d4..4c081714 100644 --- a/aqt/jax/v2/examples/flax_e2e_model.py +++ b/aqt/jax/v2/examples/flax_e2e_model.py @@ -33,6 +33,9 @@ import tensorflow_datasets as tfds +Dataset = dict[str, jnp.ndarray] + + class CNN(nn.Module): """A simple CNN model.""" bn_use_stats: bool @@ -131,7 +134,7 @@ def loss_fn(model): grad_fn = jax.value_and_grad(loss_fn, has_aux=True, allow_int=True) aux, grads = grad_fn(model_params) loss, (logits, updated_var) = aux - accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) + accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) * 100 return grads, loss, accuracy, updated_var @@ -225,9 +228,13 @@ def train_and_evaluate( workdir: str, aqt_cfg: aqt_config.DotGeneral | None = None, state: TrainState | None = None, + datasets: tuple[Dataset, Dataset] | None = None, ) -> TrainState: """Execute model training and evaluation loop.""" - train_ds, test_ds = get_datasets() + if datasets is None: + train_ds, test_ds = get_datasets() + else: + train_ds, test_ds = datasets rng = jax.random.key(0) summary_writer = tensorboard.SummaryWriter(workdir) @@ -247,18 +254,14 @@ def train_and_evaluate( state.model, test_ds['image'], test_ds['label'], state.cnn_eval.apply ) - print( - 'epoch:% 3d, train_loss: %.30f, train_accuracy: %.30f, test_loss:' - ' %.30f, test_accuracy: %.30f' - % ( - epoch, - train_loss, - train_accuracy * 100, - test_loss, - test_accuracy * 100, - ), - flush=True, - ) + stats = [ + f'epoch: {epoch:3d}', + f'{train_loss = :.30f}', + f'{test_loss = :.30f}', + f'{train_accuracy = :34.30f}', + f'{test_accuracy = :34.30f}', + ] + print('\n'.join(stats) + '\n', flush=True) summary_writer.scalar('train_loss', train_loss, epoch) summary_writer.scalar('train_accuracy', train_accuracy, epoch) diff --git a/aqt/jax/v2/examples/flax_e2e_model_test.py b/aqt/jax/v2/examples/flax_e2e_model_test.py index 0d9f4c79..28017fe1 100644 --- a/aqt/jax/v2/examples/flax_e2e_model_test.py +++ b/aqt/jax/v2/examples/flax_e2e_model_test.py @@ -58,33 +58,73 @@ class MnistTest(parameterized.TestCase): }, 4, ), + ( + { + "fwd_bits": 2, + "dlhs_bits": 2, + }, + 2, + False, + ), + ( + { + "fwd_bits": 2, + "dlhs_bits": 2, + }, + 2, + True, + ), ]) - def test_mnist_training(self, configs, bits): + def test_mnist_training(self, configs, bits, use_asymmetric=False): aqt_cfg = config.config_v4(**configs) + if use_asymmetric: + config.set_asymmetric_quantization(aqt_cfg) target_loss = { - 8: { - "cpu": [ - 3.123474359512329101562500000000, - 3.123474597930908203125000000000, - 3.123473882675170898437500000000, # colab - ], - "TPU v2": [3.198328018188476562500000000000], - "TPU v3": [3.198328018188476562500000000000], - "TPU v4": [3.198297500610351562500000000000], - "TPU v5 lite": [3.198297500610351562500000000000], + False: { # use_asymmetric + 8: { # bits + "cpu": [ + 3.123474359512329101562500000000, + 3.123474597930908203125000000000, + 3.123473882675170898437500000000, # colab + ], + "TPU v2": [3.198328018188476562500000000000], + "TPU v3": [3.198328018188476562500000000000], + "TPU v4": [3.198297500610351562500000000000], + "TPU v5 lite": [3.198297500610351562500000000000], + }, + 4: { + "cpu": [2.258865118026733398437500000000], + "TPU v2": [2.302409172058105468750000000000], + "TPU 
v3": [2.302409172058105468750000000000], + "TPU v4": [2.302409172058105468750000000000], + "TPU v5 lite": [2.302409172058105468750000000000], + }, + 2: { + "cpu": [2.067147016525268554687500000000], + "TPU v2": [2.052407503128051757812500000000], + "TPU v3": [2.052407503128051757812500000000], + "TPU v4": [2.052407741546630859375000000000], + "TPU v5 lite": [2.054144620895385742187500000000], + }, }, - 4: { - "cpu": [2.258865118026733398437500000000], - "TPU v2": [2.302409172058105468750000000000], - "TPU v3": [2.302409172058105468750000000000], - "TPU v4": [2.302409172058105468750000000000], - "TPU v5 lite": [2.302409172058105468750000000000], + True: { + 2: { + "cpu": [ + 3.539643526077270507812500000000, + 3.539642572402954101562500000000, + ], + "TPU v2": [2.984576702117919921875000000000], + "TPU v3": [2.984576702117919921875000000000], + "TPU v4": [2.984576702117919921875000000000], + "TPU v5 lite": [2.982401847839355468750000000000], + }, }, } # below 3 lines are differences between config_v4/v3 and fully_quantized config.set_stochastic_rounding(aqt_cfg, True, True, "jax.uniform") - aqt_cfg.dlhs.rhs.use_fwd_quant = True - aqt_cfg.drhs.rhs.use_fwd_quant = True + if not use_asymmetric: + aqt_cfg.dlhs.rhs.use_fwd_quant = True + aqt_cfg.drhs.rhs.use_fwd_quant = True def forward(model, apply_fn): return apply_fn( @@ -114,16 +154,23 @@ def forward(model, apply_fn): ) device_kind = jax.devices()[0].device_kind - expected_train_loss = target_loss[bits][device_kind] + expected_train_loss = target_loss[use_asymmetric][bits][device_kind] if train_loss not in expected_train_loss: - msg = "train_loss changed. Consider updating with the following:\n" - msg += f' "{device_kind}": [{train_loss:.30f}]' + msg = ( + "train_loss changed. Consider updating with the following:\n" + f' "{device_kind}": [{train_loss:.30f}]\n' + f" expected one of: {expected_train_loss}" + ) self.fail(msg) # Run forward once more in the same mode to get logits for testing below. logits_s1, _ = forward(state.model, state.cnn_eval.apply) # Stage 2: Model conversion (quantized weights freezing) + if use_asymmetric: + with self.assertRaisesRegex(NotImplementedError, "biases"): + flax_e2e_model.serving_conversion(state) + return # Exit early out of the serving tests. 
apply_serving, model_serving = flax_e2e_model.serving_conversion(state) diff --git a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py index 58cf6ad7..21c00ba5 100644 --- a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py +++ b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py @@ -42,10 +42,10 @@ def update_cfg_with_gptq(aqt_cfg: aqt_dot_general.DotGeneral) -> None: aqt_cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer ) assert isinstance( - aqt_cfg.fwd.dg_quantizer.lhs.numerics, int_numerics.IntNumerics + aqt_cfg.fwd.dg_quantizer.lhs.numerics, int_numerics.SymIntNumerics ) assert isinstance( - aqt_cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics + aqt_cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics ) lhs_bits = aqt_cfg.fwd.dg_quantizer.lhs.numerics.bits rhs_bits = aqt_cfg.fwd.dg_quantizer.rhs.numerics.bits diff --git a/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py b/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py index 510730c1..a8f23d29 100644 --- a/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py +++ b/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py @@ -270,11 +270,11 @@ def calibrate( # Follow the quantization mode and num_bits of the kernel. if self.is_rhs_kernel: quant_mode = _get_quant_mode(self.rhs.context) - assert isinstance(self.rhs.numerics, int_numerics.IntNumerics) + assert isinstance(self.rhs.numerics, int_numerics.SymIntNumerics) num_bits = self.rhs.numerics.bits else: quant_mode = _get_quant_mode(self.lhs.context) - assert isinstance(self.lhs.numerics, int_numerics.IntNumerics) + assert isinstance(self.lhs.numerics, int_numerics.SymIntNumerics) num_bits = self.lhs.numerics.bits if quant_mode == utils.QuantMode.TRAIN: diff --git a/aqt/jax/v2/flax/aqt_flax.py b/aqt/jax/v2/flax/aqt_flax.py index 92019fde..7620de82 100644 --- a/aqt/jax/v2/flax/aqt_flax.py +++ b/aqt/jax/v2/flax/aqt_flax.py @@ -128,6 +128,10 @@ def set(self, inputs: aqt_tensor.QTensor) -> None: self.qvalue.value = inputs.qvalue assert inputs.scale_t is not None and len(inputs.scale_t) == 1 self.scale_t.value = inputs.scale_t[0] + if inputs.bias is not None: + raise NotImplementedError( + 'Quantization biases are not supported in AQT Flax Legacy Freezer.' + ) elif self.quant_mode == QuantMode.SERVE: # TODO(lew): Optionally compare stored and served value. pass @@ -304,6 +308,11 @@ def init_wrapper( axis_metadata_wrapper: Optional[AxisMetadataWrapper], tile_map: tiled_dot_general.AqtTileMap, ): + if qt.bias is not None: + raise NotImplementedError( + 'Quantization biases are not supported in AQT Flax Freezer.' 
+ ) + if axis_metadata_wrapper is None: return qt diff --git a/aqt/jax/v2/numerics/fp8_numerics.py b/aqt/jax/v2/numerics/fp8_numerics.py index 25476bd6..1f87701f 100644 --- a/aqt/jax/v2/numerics/fp8_numerics.py +++ b/aqt/jax/v2/numerics/fp8_numerics.py @@ -73,7 +73,7 @@ def _get_edge_of_last_fp8_bucket(self): def get_dtype(self): return self.dtype - def abs_val_mapped_to(self): + def get_scaled_bound(self): return self._get_edge_of_last_fp8_bucket() def vjp_fwd(self, x, context): diff --git a/aqt/jax/v2/numerics/fp_numerics.py b/aqt/jax/v2/numerics/fp_numerics.py index 1b452da9..2b7d6d9b 100644 --- a/aqt/jax/v2/numerics/fp_numerics.py +++ b/aqt/jax/v2/numerics/fp_numerics.py @@ -196,7 +196,7 @@ class FpNumerics(numerics.AqtNumerics): stochastic_rounding: bool = utils.static_field(default=False) clip_gradient: bool = utils.static_field(default=False) - def abs_val_mapped_to(self): + def get_scaled_bound(self): return fp_largest_representable(cfg=self.cfg) def get_dtype(self): @@ -216,6 +216,6 @@ def vjp_bwd(self, res, grad): ret = grad if self.clip_gradient: (x,) = res - clip_bound = self.abs_val_mapped_to() + clip_bound = self.get_scaled_bound() ret *= (-clip_bound <= x) * (x <= clip_bound) return (ret, None) diff --git a/aqt/jax/v2/numerics/int_numerics.py b/aqt/jax/v2/numerics/int_numerics.py index 127a8923..ab3af08b 100644 --- a/aqt/jax/v2/numerics/int_numerics.py +++ b/aqt/jax/v2/numerics/int_numerics.py @@ -22,20 +22,81 @@ @utils.flax_slots_kw_only_dataclass -class IntNumerics(numerics.AqtNumerics): - """Numerics for int8, int4, binary, etc.""" +class BaseIntNumerics(numerics.AqtNumerics): + """Base numerics for sint8, sint4, binary, etc.""" bits: int - preserve_zero: bool - # false = map max val on the end of the last bucket - # true = map max val on the middle of the last - preserve_max_val: bool clip: bool clip_gradient: bool round: bool noise_fn: Optional[stochastic_rounding.NoiseFn] dtype: Optional[Any] = None + def get_dtype(self): + return self.dtype + + def get_scaled_bound(self): + raise NotImplementedError("Use either SymIntNumerics or AsymIntNumerics") + + def _get_fwd_clip_bound(self): + raise NotImplementedError("Use either SymIntNumerics or AsymIntNumerics") + + # DO_NOT_SUBMIT: Should this be a method on all AqtNumerics classes? + def get_quant_range(self): + raise NotImplementedError("Use either SymIntNumerics or AsymIntNumerics") + + def _maybe_round(self, x): + raise NotImplementedError("Use either SymIntNumerics or AsymIntNumerics") + + def vjp_fwd(self, x, context): + """Forward pass.""" + res = (x,) + input_dtype = x.dtype + assert self.bits <= 22, "Too many bits, float32 has less precision." + + # Maybe noise + if self.noise_fn: + assert context.key is not None, ( + "noise_fn is set, requestic stochastic rounding, but RNG was not " + "passed in Context.key" + ) + x = (x + self.noise_fn(x.shape, context.key)).astype(input_dtype) + + if self.clip: + lower_clip_bound, upper_clip_bound = self._get_fwd_clip_bound() + x = jnp.clip(x, lower_clip_bound, upper_clip_bound) + + x = self._maybe_round(x) + + # Maybe cast: return dtype is either int or the input dtype + dtype = self.get_dtype() + x = x.astype(dtype if dtype is not None else input_dtype) + return x, res + + def vjp_bwd(self, res, grad): + # Gradient of the clip function. + # For boundary values we will have full gradient. + # When using abs(max(x)) scaling, x is always in the interior, and the + # gradient clip is always 1. So, we can always set clip_gradient to false. 
+ # However, other types of scaling may result in x being outside (i.e., there + # is clipping). In that case it may be desirable to make the gradient zero. + ret = grad + if self.clip_gradient: + (x,) = res + lower_clip_bound, upper_clip_bound = self._get_fwd_clip_bound() + ret *= (lower_clip_bound <= x) * (x <= upper_clip_bound) + return (ret, None) + + +@utils.flax_slots_kw_only_dataclass +class SymIntNumerics(BaseIntNumerics): + """Symmetric numerics for sint8, sint4, binary, etc.""" + + preserve_zero: bool + # false = map max val on the end of the last bucket + # true = map max val on the middle of the last + preserve_max_val: bool + # pylint: disable=line-too-long # Verifying the correctness of these functions amounts to verifying this table: # if preserve_zero == F, zero might be rounded either to [-1, 0] bucket or to [0, 1] bucket @@ -56,7 +117,7 @@ def get_edge_of_last_int_bucket(self): def get_center_of_last_int_bucket(self): return self.get_edge_of_last_int_bucket() - 0.5 - def abs_val_mapped_to(self): + def get_scaled_bound(self): if self.preserve_max_val: return self.get_center_of_last_int_bucket() else: @@ -71,30 +132,9 @@ def _get_fwd_clip_bound(self): if self.round: # Reducing fwd_clip_bound by any value in (0.0, 1.0) is correct. fwd_clip_bound -= 0.5 - return fwd_clip_bound + return -fwd_clip_bound, fwd_clip_bound - def get_dtype(self): - return self.dtype - - def vjp_fwd(self, x, context): - """Forward pass.""" - res = (x,) - input_dtype = x.dtype - assert self.bits <= 22, 'Too many bits, float32 has less precision.' - - # Maybe noise - if self.noise_fn: - assert context.key is not None, ( - 'noise_fn is set, requestic stochastic rounding, but RNG was not ' - 'passed in Context.key' - ) - x = (x + self.noise_fn(x.shape, context.key)).astype(input_dtype) - - if self.clip: - fwd_clip_bound = self._get_fwd_clip_bound() - x = jnp.clip(x, -fwd_clip_bound, fwd_clip_bound) - - # Maybe round + def _maybe_round(self, x): if self.round: # TODO(lew): Have bucket centers at 2*k + 1, not at halves. round_to_halves = not self.preserve_zero @@ -102,22 +142,28 @@ def vjp_fwd(self, x, context): x = jnp.floor(x) + 0.5 else: x = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) + return x - # Maybe cast: return dtype is either int or the input dtype - dtype = self.get_dtype() - x = x.astype(dtype if dtype is not None else input_dtype) - return x, res - def vjp_bwd(self, res, grad): - # Gradient of the clip function. - # For boundary values we will have full gradient. - # When using abs(max(x)) scaling, x is always in the interior, and the - # gradient clip is always 1. So, we can always set clip_gradient to false. - # However, other types of scaling may result in x being outside (i.e., there - # is clipping). In that case it may be desirable to make the gradient zero. - ret = grad - if self.clip_gradient: - (x,) = res - clip_bound = self._get_fwd_clip_bound() - ret *= (-clip_bound <= x) * (x <= clip_bound) - return (ret, None) +@utils.flax_slots_kw_only_dataclass +class AsymIntNumerics(BaseIntNumerics): + """Asymmetric numerics for sint8, sint4, binary, etc.""" + + def get_scaled_bound(self): + return 2.0**self.bits - 1 + + def get_quant_range(self): + if self.bits > 1: + # Full signed int range. + return -(2.0 ** (self.bits - 1)), 2.0 ** (self.bits - 1) - 1 + else: + # Boolean range. 
+      return 0.0, 1.0
+
+  def _get_fwd_clip_bound(self):
+    return self.get_quant_range()
+
+  def _maybe_round(self, x):
+    if self.round:
+      x = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
+    return x
diff --git a/aqt/jax/v2/numerics/no_numerics.py b/aqt/jax/v2/numerics/no_numerics.py
index 60fa2fbb..cf03e1db 100644
--- a/aqt/jax/v2/numerics/no_numerics.py
+++ b/aqt/jax/v2/numerics/no_numerics.py
@@ -33,7 +33,7 @@ class NoNumerics(numerics.AqtNumerics):
   def get_dtype(self):
     return None
 
-  def abs_val_mapped_to(self):
+  def get_scaled_bound(self):
     pass
 
   def vjp_fwd(self, x, context):
diff --git a/aqt/jax/v2/numerics/numerics.py b/aqt/jax/v2/numerics/numerics.py
index a7cc83c9..066714bb 100644
--- a/aqt/jax/v2/numerics/numerics.py
+++ b/aqt/jax/v2/numerics/numerics.py
@@ -26,12 +26,15 @@ def get_dtype(self):
     pass
 
   @abc.abstractmethod
-  def abs_val_mapped_to(self):
-    """The value returned is the end of quantization range.
+  def get_scaled_bound(self):
+    """Returns the width that the scale corresponds to in the quantization range.
 
-    It could be biggest value that can be represented by numerical format
-    exactly. E.g. in case of int8, 127 . Or it could be edge of the last bucket.
-    Edge in case of int8, 127.5
+    For symmetric scaling (relative to a fixed zero point), it could be the
+    biggest value that can be represented exactly by the numerical format
+    (e.g. 127 for int8), or the edge of the last bucket (e.g. 127.5 for int8).
+
+    For asymmetric scaling, it corresponds to the width of the entire
+    quantization range, e.g. 255 for int8.
     """
     pass
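
Note for reviewers: the asymmetric fake-quant arithmetic this patch implements
(`MinMaxCalibration` bound, `AsymIntNumerics.get_quant_range`,
`calculate_asymmetric_bias`, and `QTensor.quant`/`dequant`) boils down to the
standalone sketch below. It uses only `jax.numpy`; `fake_quant_asym` and the
zero-division guard are illustrative simplifications and are not part of the
AQT API (AQT handles the degenerate range via `MinMaxCalibration.eps` and the
reciprocal filtering in `QTensor.quant`).

import jax.numpy as jnp


def fake_quant_asym(x: jnp.ndarray, bits: int = 8, axis=None) -> jnp.ndarray:
  """Round-trips x through asymmetric fake quantization (illustrative only)."""
  # Signed integer range, matching AsymIntNumerics.get_quant_range().
  quant_min = -(2.0 ** (bits - 1))
  quant_max = 2.0 ** (bits - 1) - 1
  x_min = jnp.min(x, axis=axis, keepdims=True)
  x_max = jnp.max(x, axis=axis, keepdims=True)
  # scale = calibration bound / get_scaled_bound(), where the MinMaxCalibration
  # bound is (max - min) and AsymIntNumerics.get_scaled_bound() is 2**bits - 1.
  scale = (x_max - x_min) / (2.0 ** bits - 1.0)
  scale = jnp.where(scale == 0.0, jnp.ones_like(scale), scale)  # avoid 0-division
  # Bias chosen so that quant(x_min) == quant_min, i.e.
  # (x_min + bias) / scale == quant_min, as in calculate_asymmetric_bias.
  bias = quant_min * scale - x_min
  # quant(x) = (x + bias) / scale, clipped to the integer range and rounded.
  q = jnp.clip(jnp.round((x + bias) / scale), quant_min, quant_max)
  # dequant(q) = q * scale - bias, as in QTensor.dequant.
  return q * scale - bias


x = jnp.array([-0.3, 0.0, 0.7, 1.9])
print(fake_quant_asym(x, bits=2))  # the endpoints -0.3 and 1.9 round-trip exactly

In AQT itself the equivalent behavior is enabled with
`config.set_asymmetric_quantization(cfg)`, which also forces
`DequantMode.THIS_INPUT`, since only fake quantization currently supports
quantization biases.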