Fix a numerical bug in operations.einsum when using asym weights and sym activations.

Also consolidates dequantization logic for weights and activations, allowing support for arbitrary weight shapes when `zp_act` is not `None`.

Tested by adding a test for this case and by examining the outputs of Gemma2B IT with asymmetric `int8` PTQ and symmetric `int8` activations.
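
For context, here is a minimal sketch of the kind of error being fixed, using illustrative names (`q_x`, `s_x`, `q_w`, `s_w`, `zp_w`) rather than the library's own code: with symmetric activations (`x ~ s_x * q_x`) and asymmetric weights (`w ~ s_w * q_w - zp_w`), the zero-point correction term must be built from the dequantized activation; building it from the raw integer codes drops the activation scale.

```python
import numpy as np

# Illustrative, deterministic values (not taken from the library).
q_x = np.array([[12., -3., 45.], [-20., 88., 5.]], np.float32)    # activation codes, shape (b, y)
s_x = np.float32(0.05)                                            # symmetric activation scale
q_w = np.array([[10., 200., 33.], [7., 250., 128.]], np.float32)  # weight codes, shape (z, y)
s_w = np.array([0.012, 0.017], np.float32)                        # per-channel weight scale
zp_w = np.array([0.7, 1.1], np.float32)                           # weight zero point (scale folded in)

x_f = s_x * q_x                                  # dequantized activation
w_f = s_w[:, None] * q_w - zp_w[:, None]         # dequantized asymmetric weight
expected = np.einsum('by,zy->bz', x_f, w_f)

acc = np.einsum('by,zy->bz', q_x, q_w) * s_x * s_w   # integer-domain einsum, rescaled
offset_fixed = np.einsum('by,z->bz', x_f, zp_w)      # offset from dequantized x (the fix)
offset_buggy = np.einsum('by,z->bz', q_x, zp_w)      # offset from raw codes: drops s_x

assert np.allclose(acc - offset_fixed, expected)
assert not np.allclose(acc - offset_buggy, expected)
```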

PiperOrigin-RevId: 633644780
phoenix-meadowlark authored and pax authors committed May 14, 2024
1 parent 6450194 commit a0362d4
Showing 2 changed files with 62 additions and 40 deletions.
52 changes: 37 additions & 15 deletions praxis/layers/quantization/operations.py
@@ -132,6 +132,26 @@ def _get_expand_dims_lhs(eqn: str) -> list[int]:
return filling_dims


def _dequantize(
    x: JTensor,
    scale: JTensor | None,
    zp: JTensor | None,
    contraction_dims: Sequence[int],
) -> JTensor:
  """Dequantize x, unsqueezing with contraction_dims if needed."""
  if scale is None and zp is not None:
    raise ValueError('scale must not be None if zp is not None')
  if scale is not None:
    if scale.ndim > 0 and x.ndim != scale.ndim:
      scale = jnp.expand_dims(scale, contraction_dims)
    x *= scale
  if zp is not None:
    if zp.ndim > 0 and x.ndim != zp.ndim:
      zp = jnp.expand_dims(zp, contraction_dims)
    x -= zp
  return x
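
A minimal usage sketch for this helper, assuming it is in scope and using hypothetical shapes: a rank-3 weight whose first two dims are contracted (as in an equation like `bxy,xyz->bz`) with a per-output-channel scale and zero point.

```python
import jax.numpy as jnp

# Hypothetical shapes for illustration; _dequantize is the helper defined above.
w = jnp.ones((3, 4, 5))        # quantized weight, already cast to float
scale = jnp.full((5,), 0.1)    # per-output-channel scale
zp = jnp.full((5,), 0.2)       # per-output-channel zero point
w_deq = _dequantize(w, scale, zp, contraction_dims=(0, 1))
# scale and zp are expanded to shape (1, 1, 5) before broadcasting, so any
# weight rank works as long as contraction_dims lists the contracted dims.
```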


def get_min_max(
bits: int = 8,
unsigned: bool = False,
@@ -176,6 +196,11 @@ def compute_offset(eqn_normalized: str, x: JTensor, zp: JTensor) -> JTensor:
'eqn_normalized should not contain broadcast ellipsis "...". Use'
' opt_einsum.parser to normalize the eqn before using this function.'
)
if x.dtype in INT_TYPES:
raise ValueError(f'x should not be quantized, but got {x.dtype=}')
if zp.dtype in INT_TYPES:
raise ValueError(f'zp should not be quantized, but got {zp.dtype=}')

ins, out = eqn_normalized.split('->')
lhs, rhs = ins.split(',')
rhs_out_dims = ''.join([c for c in out if c in rhs])
@@ -570,14 +595,11 @@ def einsum(
# Non-performant equation for inference testing purposes
# TODO: b/305735188 - Improve the performance by using the integer einsum op.
if zp_act is not None:
dequantized_x = jnp.multiply(x, scale_act) - zp_act
# explicit broadcast if necessary.
if w.ndim == 3 and scale.ndim == 1:
scale = jnp.expand_dims(scale, (1, 2))
dequantized_w = jnp.multiply(w, scale)
if zp is not None:
dequantized_w = dequantized_w - zp
return jnp.einsum(eqn, dequantized_x, dequantized_w)
x_dequantized = _dequantize(
x, scale_act, zp_act, eqn_to_activation_contract_dims(eqn)
)
w_dequantized = _dequantize(w, scale, zp, eqn_to_weight_contract_dims(eqn))
return jnp.einsum(eqn, x_dequantized, w_dequantized)
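
The `eqn_to_activation_contract_dims` / `eqn_to_weight_contract_dims` helpers are what let this path handle arbitrary operand ranks: roughly, they return the positions of an operand's dims that do not appear in the einsum output. A rough re-implementation of that idea for the weight operand, as a sketch only (the actual praxis helpers may treat ellipses and edge cases differently):

```python
def weight_contract_dims(eqn: str) -> list[int]:
  """Dims of the weight operand that do not appear in the einsum output."""
  ins, out = eqn.split('->')
  _, w_dims = ins.split(',')
  return [i for i, c in enumerate(w_dims) if c not in out]

assert weight_contract_dims('...y,yz->...z') == [0]    # expand a (z,) scale to (1, z)
assert weight_contract_dims('bxy,xyz->bz') == [0, 1]   # expand a (z,) scale to (1, 1, z)
```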

if (
jax.dtypes.scalar_type_of(w.dtype) == float
@@ -610,10 +632,7 @@ def einsum(
if scale_act.ndim == 0:
scale *= scale_act
else:
filling_dims_lhs = _get_expand_dims_lhs(eqn)
if filling_dims_lhs:
scale_act = jnp.expand_dims(scale_act, filling_dims_lhs)
ret = jnp.multiply(ret, scale_act)
ret *= jnp.expand_dims(scale_act, _get_expand_dims_lhs(eqn))

# Potentially expand dimensions of scale to match einsum output.
filling_dims_rhs = _get_expand_dims_rhs(eqn)
@@ -626,10 +645,13 @@ def einsum(
ret = jnp.multiply(ret, scale)

if zp is not None:
x_dequantized = _dequantize(
x, scale_act, zp_act, eqn_to_activation_contract_dims(eqn)
)
if zp_eqn is not None:
offset = jnp.einsum(zp_eqn, x, zp)
offset = jnp.einsum(zp_eqn, x_dequantized, zp)
else:
offset = compute_offset(eqn_normalized, x, zp)
offset = compute_offset(eqn_normalized, x_dequantized, zp)
ret = ret - offset

return ret
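
With a custom `zp_eqn`, the offset is likewise built from the dequantized activation. A shape sketch under the equation used in the new test, with hypothetical sizes:

```python
import jax.numpy as jnp

# Shape sketch for zp_eqn='byh,z->bzh' paired with eqn='...yh,zy->...zh'.
b, y, h, z = 2, 4, 5, 3
x_dequantized = jnp.ones((b, y, h))     # activation after dequantization
zp = jnp.ones((z,)) * 0.3               # weight zero point, one per output channel
offset = jnp.einsum('byh,z->bzh', x_dequantized, zp)
assert offset.shape == (b, z, h)        # matches the '...zh' output of the main eqn
```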
@@ -1061,7 +1083,7 @@ def reduce_precision_activation(
contract_dims,
need_gradient,
bits,
False,
optimization_on_bound=False,
use_symmetric=symmetric,
percentile=percentile,
)
50 changes: 25 additions & 25 deletions praxis/layers/quantization/operations_test.py
@@ -156,39 +156,39 @@ def test_quantized_einsum_with_zp(self, eqn):
)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

@parameterized.named_parameters(
('eqn_with_dot', '...y,yz->...z'),
@parameterized.product(
sym_weights=[True, False],
sym_acts=[True, False],
zp_eqn=[None, 'byh,z->bzh'],
)
def test_quantized_einsum_with_asym_weight_act(self, eqn):
w = jax.random.uniform(jax.random.PRNGKey(0), (4, 3))
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4))
def test_quantized_einsum_with_mixed_symmetry(
self, sym_weights, sym_acts, zp_eqn
):
eqn = '...yh,zy->...zh'
w = jax.random.uniform(jax.random.PRNGKey(0), (3, 4))
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4, 5))
qw, sw, zpw = operations.reduce_einsum_weight_precision(
eqn, w, use_symmetric=False
eqn, w, use_symmetric=sym_weights
)
qx, sx, zpx = operations.reduce_einsum_activation_precision(
eqn, x, symmetric=False
eqn,
x,
symmetric=sym_acts,
per_channel=True,
)

ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
expected = jnp.einsum(eqn, x, w)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)

@parameterized.named_parameters(
('eqn_with_dot', '...y,yz->...z'),
)
def test_quantized_einsum_with_aym_weight_asym_act(self, eqn):
w = jax.random.uniform(jax.random.PRNGKey(0), (4, 3))
x = jax.random.uniform(jax.random.PRNGKey(0), (2, 4))
qw, sw, zpw = operations.reduce_einsum_weight_precision(
eqn, w, use_symmetric=True
)
qx, sx, zpx = operations.reduce_einsum_activation_precision(
eqn, x, symmetric=False
ret = operations.einsum(
eqn=eqn,
x=qx,
w=qw,
scale=sw,
zp=zpw,
scale_act=sx,
zp_act=zpx,
zp_eqn=zp_eqn,
)

ret = operations.einsum(eqn, qx, qw, sw, zpw, sx, zpx)
expected = jnp.einsum(eqn, x, w)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.02)
self.assertAllClose(ret, expected, rtol=0.02, atol=0.01)
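
The `parameterized.product` decorator expands its keyword lists into a Cartesian product, so this single test body covers every weight/activation symmetry combination, including asymmetric weights with symmetric activations, the case the offset fix targets. A quick illustration of the expansion:

```python
from itertools import product

cases = list(product([True, False], [True, False], [None, 'byh,z->bzh']))
assert len(cases) == 8                 # sym_weights x sym_acts x zp_eqn
assert (False, True, None) in cases    # asym weights, sym activations (the fixed case)
```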

@parameterized.parameters(
('ab,bc->ac', (10, 4), (4, 5)),
