
Put numerics and numerics-related logic into QTensor
PiperOrigin-RevId: 636006198
Cerebra Catalyst Team authored and copybara-github committed May 22, 2024
1 parent 02baee0 commit e6267ff
Showing 5 changed files with 136 additions and 67 deletions.
28 changes: 9 additions & 19 deletions aqt/jax/v2/aqt_quantizer.py
@@ -54,16 +54,18 @@ def quant(
   ) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]:
     """The core quantizing function."""
     qt = self.calibrate(x, calibration_axes=calibration_axes)
-    qt, quant_grad = self.calculate_qvalue(x, qt)
-    return qt, quant_grad
+    return self.calculate_qvalue(x, qt)
 
   def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor:
     """Create incomplete QTensor with only quantization parameters."""
     if isinstance(self.numerics, no_numerics.NoNumerics):
-      qt = aqt_tensor.QTensor(
-          qvalue=x, scale=[], scale_t=None, dequant_dtype=x.dtype
+      return aqt_tensor.QTensor(
+          qvalue=x,
+          scale=[],
+          scale_t=None,
+          dequant_dtype=x.dtype,
+          numerics=self.numerics,
       )
-      return qt
 
     dequant_dtype = x.dtype
     # TODO(lew): We should cast earlier. xhs_q should be in cfg.xhs.dtype
@@ -94,6 +96,7 @@ def calibrate(self, x, *, calibration_axes) -> aqt_tensor.QTensor:
         scale=[scale],
         scale_t=None,
         dequant_dtype=dequant_dtype,
+        numerics=self.numerics,
     )
     return qt

@@ -103,20 +106,7 @@ def calculate_qvalue(
       qt: aqt_tensor.QTensor
   ) -> tuple[aqt_tensor.QTensor, aqt_tensor.GradientFn]:
     """Uses the quantization parameters in qt to quantize x."""
-    if isinstance(self.numerics, no_numerics.NoNumerics):
-      return qt, None
-
-    # TODO: b/333984742 - make numeric as a member of QTensor, and put
-    # numerics-related logics into the QTensor.
-    qt = qt.quant(x)
-
-    # TODO(lew): A logical thing would be if this call was part of
-    # QTensor.quant.
-    x_q, res = self.numerics.vjp_fwd(qt.qvalue, self.context)
-    quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res)
-
-    qt = qt.replace(qvalue=x_q)
-    return qt, quant_grad
+    return qt.quant(x, self.context)
 
 
 def quantizer_make(
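For context, the refactored call path now reads as follows. This is a minimal sketch, not code from the commit: quantizer_make(8) is assumed to build an int8 Quantizer and the shapes and axes are illustrative; only the calibrate/calculate_qvalue split is taken from the diff above.

import jax.numpy as jnp
from aqt.jax.v2 import aqt_quantizer

q = aqt_quantizer.quantizer_make(8)  # assumption: int8 quantizer factory
x = jnp.ones((4, 16), dtype=jnp.bfloat16)

# Calibration returns an incomplete QTensor that now carries its numerics
# alongside the scales.
qt = q.calibrate(x, calibration_axes=(1,))

# Quantization is delegated to the QTensor itself; calculate_qvalue is now a
# thin wrapper around qt.quant(x, self.context).
qt, quant_grad = q.calculate_qvalue(x, qt)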
64 changes: 45 additions & 19 deletions aqt/jax/v2/aqt_tensor.py
@@ -25,13 +25,17 @@
 import typing
 from typing import Any, Callable, Optional, Sequence, TypeAlias
 from aqt.jax.v2 import utils
+from aqt.jax.v2.numerics import int_numerics
+from aqt.jax.v2.numerics import no_numerics
+from aqt.jax.v2.numerics import numerics
 import flax.cursor
 import flax.struct
 import jax
 import jax.numpy as jnp
 import jax.typing as jax_typing
 from typing_extensions import Self  # for python version < 3.11
 
+AbstractAqtNumerics = numerics.AqtNumerics
 GradientFn = Callable[..., Any] | None  # None when there is no numerics
 _MSG_NO_QVALUE = (
     'QTensor does not have qvalue, but it is asked to access the qvalue.'
@@ -75,6 +79,9 @@ class QTensor:
       pytree_node=False, default=None
   )
 
+  # Numerics of the QTensor.
+  numerics: AbstractAqtNumerics = utils.static_field(default=None)
+
   @property
   def dtype(self) -> jnp.dtype | None:
     return self.dequant_dtype
@@ -89,8 +96,12 @@ def without_qvalue(self) -> Self:
   def astype(self, dtype: jnp.dtype) -> Self:
     return self.replace(dequant_dtype=dtype)  # pytype: disable=attribute-error
 
-  def quant(self, x):
-    """Quantizes the QTensor."""
+  def quant(self, x, context: utils.Context) -> tuple[Self, GradientFn]:
+    """Uses the quantization parameters in qt to quantize x."""
+    assert self.numerics is not None, 'Missing numerics used for quantization.'
+    if isinstance(self.numerics, no_numerics.NoNumerics):
+      return self, None
+
     assert not self.is_full(), 'Already quantized QTensor.'
     assert self.scale is not None, 'Missing scales to be used for quantization.'
@@ -101,9 +112,10 @@ def quant(self, x):
     s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv)
     qvalue = qvalue * s_inv
 
-    # TODO(lew): We should apply numerics here, so that 'quant' function
-    # Can be considered a part of API.
-    return self.replace(qvalue=qvalue)  # pytype: disable=attribute-error
+    x_q, res = self.numerics.vjp_fwd(qvalue, context)
+    quant_grad = jax.tree_util.Partial(self.numerics.vjp_bwd, res)
+
+    return self.replace(qvalue=x_q), quant_grad  # pytype: disable=attribute-error
 
   def dequant(self) -> jnp.ndarray:
     """Dequantizes the QTensor."""
@@ -152,26 +164,14 @@ def __len__(self) -> int:
 
 
 def zeros(
     shape: Sequence[int],
-    *,
-    container_dtype: jnp.dtype,
-    dequant_dtype: jnp.dtype = jnp.bfloat16,
-) -> QTensor:
-  return QTensor(
-      qvalue=jnp.zeros(shape, dtype=container_dtype),
-      scale=[],
-      scale_t=None,
-      dequant_dtype=dequant_dtype,
-  )
-
-
-def zeros_with_scale(
-    shape: Sequence[int],
     calibration_axis: Sequence[utils.AxisIdx],
     *,
     container_dtype: jnp.dtype,
     scale_dtype: jnp.dtype | None = None,
     dequant_dtype: jnp.dtype = jnp.bfloat16,
+    n_bits: int | None = None,
+    preserve_max_val: bool = False,
 ) -> QTensor:
   """Initializes a QTensor with empty qvalue along with empty scale value."""
   scale_shape = list(shape)
@@ -186,13 +186,16 @@ def zeros_with_scale(
       scale=[jnp.ones(scale_shape, dtype=scale_dtype)],
       scale_t=None,
       dequant_dtype=dequant_dtype,
+      numerics=_get_numerics(n_bits, preserve_max_val),
   )
 
 
 def partition_spec(
     partitions: Sequence[Any],
     calibration_axis: Sequence[utils.AxisIdx],
     dtype: jnp.dtype,
+    n_bits: int | None,
+    preserve_max_val: bool = False,
 ) -> QTensor:
   """Returns a QTensor filled with partition specs."""
   scale_partitions = list(partitions)
@@ -203,6 +206,26 @@
       scale=[jax.sharding.PartitionSpec(*scale_partitions)],
       scale_t=None,
       dequant_dtype=dtype,
+      numerics=_get_numerics(n_bits, preserve_max_val),
   )
 
 
+def _get_numerics(
+    n_bits: int | None, preserve_max_val: bool = False
+) -> numerics.AqtNumerics:
+  if n_bits is None:
+    return no_numerics.NoNumerics()
+  pz = False if n_bits == 1 else True
+  dtype = utils.infer_dtype_from_bits(n_bits) if pz else None
+  return int_numerics.IntNumerics(
+      bits=n_bits,
+      preserve_zero=pz,
+      preserve_max_val=preserve_max_val,
+      clip=True,
+      round=True,
+      noise_fn=None,
+      clip_gradient=False,  # This can be disabled when using abs-max scaling.
+      dtype=dtype,
+  )
+
+
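_get_numerics is now the single place where n_bits is turned into a numerics object: None yields NoNumerics (quantization disabled), n_bits == 1 switches preserve_zero off and leaves dtype=None, and any other width infers a storage dtype via utils.infer_dtype_from_bits. A hypothetical call to the merged zeros() factory, consistent with the updated test below (shape and calibration_axis values are illustrative):

import jax.numpy as jnp
from aqt.jax.v2 import aqt_tensor

qt = aqt_tensor.zeros(
    shape=(4, 16),
    calibration_axis=(1,),
    container_dtype=jnp.int8,
    dequant_dtype=jnp.float32,
    n_bits=8,  # attaches an int8 IntNumerics via _get_numerics
)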
@@ -242,6 +265,7 @@ def get_sliced_scales(scale):
       scale=[get_sliced_scales(s) for s in operand.scale],
       scale_t=None,
       dequant_dtype=operand.dequant_dtype,
+      numerics=operand.numerics,
   )


@@ -290,6 +314,7 @@ def dynamic_update_slice(
       scale=scales,
       scale_t=None,
       dequant_dtype=operand.dequant_dtype,
+      numerics=operand.numerics,
   )


@@ -306,4 +331,5 @@ def update_frame(operand: QTensor, frame: int, update: QTensor) -> QTensor:
       ],
       scale_t=None,
       dequant_dtype=operand.dequant_dtype,
+      numerics=operand.numerics,
   )
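Each structural helper now threads operand.numerics into its result, so slicing and frame updates cannot silently drop the quantization semantics. Sketched against the update_frame signature above, with qt and frame_qt assumed to be compatible QTensors:

new_qt = aqt_tensor.update_frame(qt, 0, frame_qt)
assert new_qt.numerics is qt.numerics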
5 changes: 4 additions & 1 deletion aqt/jax/v2/aqt_tensor_test.py
@@ -89,7 +89,10 @@ def test_dynamic_update(self):
 
   def test_dtype(self):
     qt = aqt_tensor.zeros(
-        shape=(1,), container_dtype=jnp.int8, dequant_dtype=jnp.float32
+        shape=(1,),
+        calibration_axis=(),
+        container_dtype=jnp.int8,
+        dequant_dtype=jnp.float32,
     )
     self.assertEqual(qt.dtype, jnp.float32)
     self.assertEqual(qt.dequant_dtype, jnp.float32)
