Add support for asymmetric fake-quantization to AQTv2.
Integration of native quantization with biases will require computing the cross terms, likely in the AQT operation quantizer (`DefaultGeneralQuantizer`).
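
A minimal sketch of where those cross terms come from (plain `jax.numpy`, not AQT's API; the affine model `x ≈ q * scale + bias`, the scalar per-tensor scale/bias, and all names are illustrative assumptions): expanding a dot product of two affinely dequantized operands leaves one integer-only dot product plus three bias-dependent correction terms that a native kernel would have to add back.

```python
# Illustrative sketch of the cross terms in an asymmetric (affine) quantized
# dot product. Not AQT code: 1-D operands and scalar per-tensor scale/bias
# are simplifying assumptions.
import jax.numpy as jnp

def dequant(q, scale, bias):
  return q * scale + bias

def dot_via_cross_terms(q_x, s_x, b_x, q_w, s_w, b_w):
  n = q_x.shape[0]
  int_dot = jnp.dot(q_x, q_w)        # the integer-only dot product
  cross = (
      s_x * b_w * q_x.sum()          # lhs values x rhs bias
      + s_w * b_x * q_w.sum()        # rhs values x lhs bias
      + n * b_x * b_w                # bias x bias
  )
  return s_x * s_w * int_dot + cross

q_x, q_w = jnp.array([1.0, -2.0, 3.0]), jnp.array([4.0, 0.0, -1.0])
s_x, b_x, s_w, b_w = 0.5, 0.1, 0.25, -0.2
reference = jnp.dot(dequant(q_x, s_x, b_x), dequant(q_w, s_w, b_w))
assert jnp.allclose(dot_via_cross_terms(q_x, s_x, b_x, q_w, s_w, b_w), reference)
```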

Itemized changes:

- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound` to reflect that the calibration bound may span the whole quantization range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add support for biases to `QTensor`.
- Add `MinMaxCalibration` (a rough sketch of the idea follows this list).
- `flax_e2e_model`:
  - Add support for passing a dataset to `flax_e2e_model.train_and_evaluate` (this saves time when comparing configurations).
  - Make statistics printing easier to read.
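
A minimal sketch of the idea behind `MinMaxCalibration` (plain `jax.numpy`; the formulas and names are illustrative assumptions, not AQT's implementation): the scale is derived from the full observed range rather than an abs-max half-range, and the observed minimum becomes the quantization bias.

```python
# Illustrative sketch of asymmetric fake quantization driven by per-tensor
# min/max statistics. Not AQT's MinMaxCalibration or AsymIntNumerics.
import jax.numpy as jnp

def fake_quant_minmax(x, bits):
  x_min, x_max = x.min(), x.max()
  n_levels = 2 ** bits - 1
  scale = (x_max - x_min) / n_levels           # bound spans the whole range
  q = jnp.clip(jnp.round((x - x_min) / scale), 0, n_levels)
  return q * scale + x_min                     # dequantize: q * scale + bias

x = jnp.array([-0.3, 0.0, 0.45, 1.2])
print(fake_quant_minmax(x, bits=2))            # ~[-0.3, 0.2, 0.7, 1.2]
```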

I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
phoenix-meadowlark authored and copybara-github committed Jul 23, 2024
1 parent 5e5897c commit 33de2e9
Showing 18 changed files with 532 additions and 134 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -169,7 +169,7 @@ from aqt.jax.v2 import utils as aqt_utils
from aqt.jax.v2.numerics import int_numerics

q = aqt_quantizer.Quantizer(
numerics=int_numerics.IntNumerics(
numerics=int_numerics.SymIntNumerics(
bits=4,
preserve_zero=True,
preserve_max_val=True,
18 changes: 17 additions & 1 deletion aqt/jax/v2/aqt_dot_general.py
@@ -574,6 +574,11 @@ def _maybe_use_fwd_quant(
)
if use_fwd_quant:
assert fwd_quantized, msg
if rhs.qx.bias is not None:
raise NotImplementedError(
'Quantization biases are not supported in forward quantization.'
)

scale_t = transpose.rhs_scale_transpose_for_lhs_input(
rhs.qx.scale[0], dimension_numbers, lhs.shape
)
@@ -664,6 +669,17 @@ def __call__(
self.allow_dummy_gradient_into_qtensor
)

msg = (
'biases are only supported in fake quant mode, but got a {arg} bias '
'and self.{arg}.dequant_mode == {mode} != DequantMode.THIS_INPUT'
)
assert (
lhs_qt.bias is None or self.lhs.dequant_mode == DequantMode.THIS_INPUT
), msg.format(arg='lhs', mode=self.lhs.dequant_mode)
assert (
rhs_qt.bias is None or self.rhs.dequant_mode == DequantMode.THIS_INPUT
), msg.format(arg='rhs', mode=self.rhs.dequant_mode)

lhs_mt = MultiTensor(x=lhs, qx=lhs_qt)
lhs_res = TensorRes(mt=lhs_mt, quant_grad=lhs_quant_grad)

@@ -869,7 +885,7 @@ def assert_config_validity(self: Self):
expected_fwd_quant = False
msg_fwd_quant = (
f'use_fwd_quant should be set to {expected_fwd_quant} when remaining'
' axis are used for calibration axis.'
' axis are used for calibration axis. '
)

if self.fwd.rhs.calibration_mode == CalibrationMode.REMAINING_AXIS:
127 changes: 115 additions & 12 deletions aqt/jax/v2/aqt_dot_general_test.py
@@ -163,7 +163,7 @@ class _TrickyNumerics(numerics.AqtNumerics, flax.struct.PyTreeNode):
def get_dtype(self):
return self.dtype

def abs_val_mapped_to(self) -> jnp.ndarray:
def get_scaled_bound(self) -> jnp.ndarray:
return jnp.array(1.0)

def fwd(self, x, context):
@@ -191,6 +191,7 @@ def _modify_dg(
fwd_lhs_tricky_clip_and_round: bool = False,
local_aqt: aqt.LocalAqt | None = None,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> aqt.DotGeneral:
dg = copy.deepcopy(readonly_dg)
if fwd_lhs_tricky_clip_and_round:
@@ -242,11 +243,11 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
# that the scales are not too large.
def disable_quant(c):
_disable_quant_types(c)
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.IntNumerics):
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.BaseIntNumerics):
c.dg_quantizer.lhs.numerics = (
c.dg_quantizer.lhs.numerics.replace(round=False)
)
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.IntNumerics):
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.BaseIntNumerics):
c.dg_quantizer.rhs.numerics = (
c.dg_quantizer.rhs.numerics.replace(round=False)
)
@@ -271,15 +272,18 @@ def disable_quant(c):
dg.drhs.local_aqt = local_aqt

# When using abs-max scaling, this should be a no-op.
if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.IntNumerics):
if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.SymIntNumerics):
dg.fwd.dg_quantizer.lhs.numerics = (
dg.fwd.dg_quantizer.lhs.numerics.replace(clip_gradient=clip_gradient)
)
if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics):
if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics):
dg.fwd.dg_quantizer.rhs.numerics = (
dg.fwd.dg_quantizer.rhs.numerics.replace(clip_gradient=clip_gradient)
)

if use_asymmetric:
config.set_asymmetric_quantization(dg)

return dg


@@ -295,6 +299,7 @@ def _aqt_dg_full_lr_diff(
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
dg = _modify_dg(
readonly_dg,
@@ -306,6 +311,7 @@ def _aqt_dg_full_lr_diff(
fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round,
local_aqt=local_aqt,
clip_gradient=clip_gradient,
use_asymmetric=use_asymmetric,
)
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
return lambda lhs, rhs: dg(lhs, rhs, dims)
@@ -321,6 +327,7 @@ def _aqt_dg_full(
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
return _aqt_dg_full_lr_diff(
dequant_mode,
@@ -332,7 +339,8 @@ def _aqt_dg_full(
local_aqt,
readonly_dg=readonly_dg,
dims=dims,
clip_gradient=clip_gradient
clip_gradient=clip_gradient,
use_asymmetric=use_asymmetric,
)


@@ -344,13 +352,15 @@ def _aqt_dg_raw_lr_diff(
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
dg = _modify_dg(
readonly_dg,
lhs_dequant_mode=lhs_dequant_mode,
rhs_dequant_mode=rhs_dequant_mode,
lhs_calibration_mode=lhs_calibration_mode,
rhs_calibration_mode=rhs_calibration_mode,
use_asymmetric=use_asymmetric,
)
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
dg.fwd.dg_quantizer.init_calibration()
@@ -363,6 +373,7 @@ def _aqt_dg_raw(
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
return _aqt_dg_raw_lr_diff(
dequant_mode,
@@ -371,6 +382,7 @@ def _aqt_dg_raw(
calibration_mode,
readonly_dg=readonly_dg,
dims=dims,
use_asymmetric=use_asymmetric,
)


@@ -389,7 +401,7 @@ def test_empty(self):
def test_fq_noise(self, preserve_zero, prec, v, seed):
key = jax.random.PRNGKey(seed)
quantizer = config.quantizer_make(prec)
if isinstance(quantizer.numerics, int_numerics.IntNumerics):
if isinstance(quantizer.numerics, int_numerics.SymIntNumerics):
quantizer.numerics.preserve_zero = preserve_zero
if not preserve_zero:
quantizer.numerics.dtype = None
@@ -541,6 +553,24 @@ def test_dot_general_calibration_with_contracting_axis(
dtype=jnp.float32,
clip_gradient=False,
):
# This should be removed once asymmetric quant supports use_fwd_quant.
test_asym = not any([
dg.fwd.lhs.use_fwd_quant,
dg.fwd.rhs.use_fwd_quant,
dg.dlhs.lhs.use_fwd_quant,
dg.dlhs.rhs.use_fwd_quant,
dg.drhs.lhs.use_fwd_quant,
dg.drhs.rhs.use_fwd_quant,
])
is_quantized = not all([
isinstance(dg.fwd.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
isinstance(dg.dlhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.dlhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
isinstance(dg.drhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.drhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
])

readonly_dg = dg
del dg

@@ -555,9 +585,25 @@ def test_dot_general_calibration_with_contracting_axis(
dims=dims,
clip_gradient=clip_gradient,
)
asym_dg_full = functools.partial(
_aqt_dg_full,
readonly_dg=readonly_dg,
dims=dims,
clip_gradient=clip_gradient,
# As an argument to _modify_dg this must be None, not False.
# Unrelated things happen when False.
use_fwd_quant=None,
use_asymmetric=True,
)
aqt_dg_raw = functools.partial(
_aqt_dg_raw, readonly_dg=readonly_dg, dims=dims
)
asym_dg_raw = functools.partial(
_aqt_dg_raw,
readonly_dg=readonly_dg,
dims=dims,
use_asymmetric=True,
)
modify_dg = functools.partial(_modify_dg, readonly_dg=readonly_dg)
check = functools.partial(_check_result_eq, lhs=lhs, rhs=rhs, gra=gra)

@@ -593,19 +639,57 @@ def test_dot_general_calibration_with_contracting_axis(
dict(test_gradient=False),
),
])
if test_asym:
check([
("default ", asym_dg_full(aqt.DequantMode.OUTPUT), dict()),
("FQ ", asym_dg_full(aqt.DequantMode.THIS_INPUT), dict()),
(
"raw fwd ",
asym_dg_raw(aqt.DequantMode.OUTPUT),
dict(test_gradient=False),
),
(
"raw fwd FQ ",
asym_dg_raw(aqt.DequantMode.THIS_INPUT),
dict(test_gradient=False),
),
])

check([
(
"fwd_quant=T",
"fwd_quant=F",
aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=False),
dict(),
),
(
"fwd_quant=F",
"fwd_quant=T",
aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=True),
dict(),
),
])
if test_asym and is_quantized:
# Asymmetric quantization does not currently support forward quantization.
with self.assertRaisesRegex(NotImplementedError, r"biases.*forward"):
check([
(
"fwd_quant=F",
aqt_dg_full(
aqt.DequantMode.OUTPUT,
use_fwd_quant=False,
use_asymmetric=True,
),
dict(),
),
(
"fwd_quant=T",
aqt_dg_full(
aqt.DequantMode.OUTPUT,
use_fwd_quant=True,
use_asymmetric=True,
),
dict(),
),
])

check([
(
@@ -617,18 +701,37 @@ def test_dot_general_calibration_with_contracting_axis(
dict(),
),
(
"default ",
"FQ ",
aqt_dg_full(
aqt.DequantMode.THIS_INPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
])
if test_asym:
check([
(
"default ",
asym_dg_full(
aqt.DequantMode.OUTPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
(
"FQ ",
asym_dg_full(
aqt.DequantMode.THIS_INPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
])

if isinstance(
readonly_dg.fwd.dg_quantizer.lhs.numerics,
int_numerics.IntNumerics,
int_numerics.SymIntNumerics,
):
check([
(
@@ -1057,7 +1160,7 @@ def test_local_aqt(self, shard_count, lhs, expected_product):
def test_per_tensor(self):
# TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes
bits = 4
my_numerics = int_numerics.IntNumerics(
my_numerics = int_numerics.SymIntNumerics(
bits=bits,
preserve_zero=True,
preserve_max_val=False,
(Diffs for the remaining 15 changed files are not shown here.)
