Add support for asymmetric fake-quantization to AQTv2.
Integration of native quantization with biases will require computing the cross terms, likely in the AQT operation quantizer (`DefaultGeneralQuantizer`); a sketch of those cross terms follows below.

Itemized changes:
- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound` to reflect that the calibration bound may span the whole quantization range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add `MinMaxCalibration` (sketched below).

I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
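For illustration, here is a minimal sketch of asymmetric fake-quantization driven by min-max calibration. This is not AQT's implementation: the function names and the per-tensor scale/zero-point layout are assumptions made for the example.

```python
# Hypothetical sketch, not AQT's API: asymmetric fake-quantization with
# min-max calibration, assuming a single per-tensor scale and zero point.
import jax.numpy as jnp

def min_max_calibrate(x, qmin, qmax):
  # Map the observed [min(x), max(x)] onto the full integer range
  # [qmin, qmax], rather than a symmetric bound around zero.
  # (A real implementation would guard against x_max == x_min.)
  x_min, x_max = jnp.min(x), jnp.max(x)
  scale = (x_max - x_min) / (qmax - qmin)
  zero_point = jnp.round(qmin - x_min / scale)
  return scale, zero_point

def asym_fake_quant(x, bits=8):
  qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
  scale, zero_point = min_max_calibrate(x, qmin, qmax)
  # Quantize, clip to the integer range, then dequantize back to float.
  q = jnp.clip(jnp.round(x / scale + zero_point), qmin, qmax)
  return (q - zero_point) * scale
```

Because the zero point is rounded to an integer, real zero is exactly representable by construction, which is one way to read why `AsymIntNumerics` needs no `preserve_zero` flag.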
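And a sketch of the cross terms referenced above: with `x ≈ s_x * (q_x - z_x)` and `w ≈ s_w * (q_w - z_w)`, a natively quantized matmul expands into an integer matmul plus zero-point cross terms. The names below are illustrative and not taken from `DefaultGeneralQuantizer`.

```python
# Hypothetical sketch of the zero-point cross terms in an asymmetric
# integer matmul; per-tensor scales and zero points assumed.
import jax.numpy as jnp

def asym_int_matmul(q_x, z_x, s_x, q_w, z_w, s_w):
  # Expand s_x*(q_x - z_x) @ s_w*(q_w - z_w) into the pure integer
  # matmul, two zero-point cross terms, and a constant term.
  n = q_x.shape[-1]  # length of the contracted dimension
  int_mm = q_x @ q_w
  cross_x = z_w * jnp.sum(q_x, axis=-1, keepdims=True)  # q_x against z_w
  cross_w = z_x * jnp.sum(q_w, axis=-2, keepdims=True)  # z_x against q_w
  return s_x * s_w * (int_mm - cross_x - cross_w + n * z_x * z_w)
```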