Quantized training, design general custom einsum.
PiperOrigin-RevId: 624999870
rybakov authored and pax authors committed Apr 15, 2024
1 parent cfe086d commit 08fed24
Showing 4 changed files with 392 additions and 79 deletions.
337 changes: 271 additions & 66 deletions praxis/layers/quantization/operations.py
@@ -233,88 +233,293 @@ def differentiable_dot_general_int(lhs_, rhs_):
return y, y_tangent


def _get_reduction(left: str, right: str) -> list[int]:
"""Gets reduction dims based on left, rights parts of einsum eqn.
Args:
left: Eqn for left tensor of eisnum eqn.
right: Eqn for left tensor of eisnum eqn.
Returns:
List of rediction dims.
"""
reduction = []
for i, ch in enumerate(left):
if ch == '.':
raise ValueError(f'left is not normalized: {left}')
if ch in right:
reduction.append(i)
return reduction

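# Worked example (illustrative): for eqn 'abc,cd->abd' the input part splits
# into left='abc' and right='cd'; the shared label 'c' sits at index 2 of
# 'abc', so _get_reduction('abc', 'cd') returns [2] and
# _get_reduction('cd', 'abc') returns [0], i.e. the contraction dims that are
# later passed to reduce_precision.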

def get_scale_eqn(left: str, right: str) -> str:
"""Gets scale eqn for einsum.
Args:
left: Eqn for left tensor of eisnum eqn.
right: Eqn for left tensor of eisnum eqn.
Returns:
Scale eqn.
"""
eqn = []
for ch in right:
if ch in left:
eqn.append(ch)
return ''.join(eqn)

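# Worked example (illustrative): with output eqn 'abd', get_scale_eqn('abd',
# 'cd') keeps the labels of 'cd' that survive in the output, i.e. 'd', which
# yields the scale-application eqn 'abd,d->abd'; get_scale_eqn('abd', 'abc')
# returns 'ab', yielding 'abd,ab->abd'.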

def custom_einsum(
x: JTensor,
w: JTensor,
*,
eqn: str,
prng_key: jax.Array,
params: quantization_hparams.QuantizedTrainingParams,
) -> jnp.ndarray:
"""Einsum with custom forward and backward propagation.
It has custom einsum for forward propagation with quantization
of both weights and activations. It also has custom backward propagation for
the einsum with quantization.
Args:
x: Activation (left side) tensor
w: Weights (right side).
eqn: Einsum equation.
prng_key: Rng key.
params: Quantization parameters.
Returns:
Einsum output.
"""

def custom_einsum_fwd(
x: JTensor,
w: JTensor,
prng_key: jax.Array,
bits_fwd: int | None,
bits_bwd: int | None,
):
"""Custom forward pass for custom_einsum."""
# Currently supports only abc,cd->abd.
# TODO(jianlijianli): make this more general.
assert x.ndim == 3
assert w.ndim == 2
assert x.shape[2] == w.shape[0]
if bits_fwd is not None:
qx, sx, _ = reduce_precision(x, bits=8, contract_dims=[2])
qw, sw, _ = reduce_precision(w, bits=8, contract_dims=[0])
else:
qx, sx = x, None
qw, sw = w, None
@jax.jit
def _custom_einsum_fwd(
x: JTensor,
w: JTensor,
key: jax.Array,
):
"""Custom forward pass for custom_einsum."""
# TODO(rybakov): Refactor the code.
input_str, output_str, _ = einsum_parser.parse_einsum_input((eqn, x, w))
left, right = input_str.split(',')[0], input_str.split(',')[1]
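# Worked example (illustrative): for eqn='abc,cd->abd', parse_einsum_input
# returns input_str='abc,cd' and output_str='abd', so left='abc' and
# right='cd'.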

if params.bits_fwd is not None:
qx, sx, _ = reduce_precision(
x, bits=params.bits_fwd, contract_dims=_get_reduction(left, right)
)
qw, sw, _ = reduce_precision(
w, bits=params.bits_fwd, contract_dims=_get_reduction(right, left)
)
sx = jnp.squeeze(sx)
sw = jnp.squeeze(sw)
else:
if params.cast_bits_fwd is not None:
# Explicit casting to estimate the speed-up opportunity. It only casts,
# without scale estimation or scale multiplication.
if params.cast_bits_fwd == 8:
qx, qw = x.astype(jnp.int8), w.astype(jnp.int8)
elif params.cast_bits_fwd == 4:
qx, qw = x.astype(jnp.int4), w.astype(jnp.int4)
else:
raise ValueError(
f'Unsupported params.cast_bits_fwd: {params.cast_bits_fwd}'
)
else:
# No quantization.
qx, qw = x, w

if bits_fwd is not None:
res = jnp.multiply(sx, jnp.multiply(acc, sw))
else:
res = acc

return res, (qx, qw, sx, sw, prng_key, bits_bwd)


def custom_einsum_bwd(res: Any, g: Any):
"""Custom gradient for custom_einsum."""
qx, qw, sx, sw, prng_key, bits_bwd = res
if bits_bwd is not None:
g_with_sw = jnp.multiply(g, sw)
g_with_sx = jnp.multiply(g, sx)
qg_for_w, sg_for_w, _ = reduce_precision(
t=g_with_sw,
bits=bits_bwd,
contract_dims=[2],
random_rounding=True,
key=prng_key,
)
# e.g. 'abc,cd->abd'
acc = jnp.einsum(
eqn, qx, qw, preferred_element_type=params.einsum_output_dtype
)

if params.bits_fwd is not None:
scale_right = get_scale_eqn(output_str, right)
scale_eqn = output_str + ',' + scale_right + '->' + output_str
res = jnp.einsum(
scale_eqn,
acc,
sw,
preferred_element_type=params.einsum_scale_output_dtype,
)

scale_left = get_scale_eqn(output_str, left)
scale_eqn = output_str + ',' + scale_left + '->' + output_str
res = jnp.einsum(
scale_eqn,
res,
sx,
preferred_element_type=params.einsum_scale_output_dtype,
)
else:
res = acc
sx, sw = None, None

if params.bwd_output_dtype is not None:
res = res.astype(params.bwd_output_dtype)

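# Residuals saved for the backward pass: the quantized operands, their
# scales (None when quantization is disabled), and the PRNG key used for
# random rounding of the gradients.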
return res, (qx, qw, sx, sw, key)

@jax.custom_vjp
def _custom_einsum(
x: JTensor,
w: JTensor,
key: jax.Array, # pylint: disable=unused-argument
) -> jnp.ndarray:
# Note on eqn: custom_vjp cannot take a str argument because of a JAX/XLA
# limitation (it raises: type <class 'str'> is not a valid JAX type), so
# "eqn: str" is captured from the enclosing custom_einsum scope instead.
return jnp.einsum(eqn, x, w)

@jax.jit
def _custom_einsum_bwd(res: Any, g: Any):
"""Custom gradient for custom_einsum."""
# The comments below assume eqn='abc,cd->abd'.
qx, qw, sx, sw, key = res
assert not (
params.cast_bits_bwd is not None and (sx is not None or sw is not None)
)

input_str, output_str, _ = einsum_parser.parse_einsum_input((eqn, qx, qw))
left1, right1 = input_str.split(',')[0], input_str.split(',')[1]
gx_eqn = output_str + ',' + right1 + '->' + left1
gw_eqn = left1 + ',' + output_str + '->' + right1
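# Worked example (illustrative): for eqn='abc,cd->abd' we get left1='abc',
# right1='cd', output_str='abd', so gx_eqn='abd,cd->abc' (gradient w.r.t. x)
# and gw_eqn='abc,abd->cd' (gradient w.r.t. w).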

if params.cast_bits_bwd is not None:
# Explicit casting to estimate the speed-up opportunity. It only casts,
# without scale estimation or scale multiplication.
if params.cast_bits_bwd == 8:
g = g.astype(jnp.int8)
elif params.cast_bits_bwd == 4:
g = g.astype(jnp.int4)
else:
raise ValueError(
f'Unsupported params.cast_bits_bwd: {params.cast_bits_bwd}'
)
if sw is not None:
scale_right = get_scale_eqn(output_str, right1)
scale_eqn = output_str + ',' + scale_right + '->' + output_str
g_with_sw = jnp.einsum(
scale_eqn, g, sw, preferred_element_type=params.einsum_output_dtype
)
else:
g_with_sw = g

if sx is not None:
scale_left = get_scale_eqn(output_str, left1)
scale_eqn = output_str + ',' + scale_left + '->' + output_str
g_with_sx = jnp.einsum(
scale_eqn, g, sx, preferred_element_type=params.einsum_output_dtype
)
else:
g_with_sx = g
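# Rationale (illustrative): qw ~= w / sw and qx ~= x / sx, so scaling the
# incoming gradient by sw before contracting with qw recovers the gradient
# w.r.t. x, and scaling it by sx before contracting with qx recovers the
# gradient w.r.t. w, up to quantization error.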

if params.bits_bwd is not None:
input_str, _, _ = einsum_parser.parse_einsum_input(
(gx_eqn, g_with_sw, qw)
)
left, right = input_str.split(',')[0], input_str.split(',')[1]
contract_dims = _get_reduction(left, right)
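# Worked example (illustrative): for gx_eqn='abd,cd->abc', left='abd' and
# right='cd' share only 'd' at index 2 of 'abd', so contract_dims=[2] here.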
qg_for_w, sg_for_w, _ = reduce_precision(
t=g_with_sw,
bits=params.bits_bwd,
contract_dims=contract_dims,
random_rounding=params.random_rounding_bwd,
key=key,
)
sg_for_w = jnp.squeeze(sg_for_w)
else:
sg_for_w = None
if params.cast_bits_bwd is not None:
# Explicit casting to estimate the speed-up opportunity. It only casts,
# without scale estimation or scale multiplication.
if params.cast_bits_bwd == 8:
qg_for_w = g_with_sw.astype(jnp.int8)
elif params.cast_bits_bwd == 4:
qg_for_w = g_with_sw.astype(jnp.int4)
else:
raise ValueError(
f'Unsupported params.cast_bits_bwd: {params.cast_bits_bwd}'
)
else:
qg_for_w = g_with_sw

if params.bits_bwd is not None:
input_str, _, _ = einsum_parser.parse_einsum_input(
(gw_eqn, qx, g_with_sx)
)
left, right = input_str.split(',')[0], input_str.split(',')[1]
contract_dims = _get_reduction(right, left)
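# Worked example (illustrative): for gw_eqn='abc,abd->cd', the call
# _get_reduction('abd', 'abc') shares 'a' and 'b' at indices 0 and 1 of
# 'abd', so contract_dims=[0, 1] here.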
qg_for_x, sg_for_x, _ = reduce_precision(
t=g_with_sx,
bits=params.bits_bwd,
contract_dims=contract_dims,
random_rounding=params.random_rounding_bwd,
key=key,
)
sg_for_x = jnp.squeeze(sg_for_x)
else:
sg_for_x = None
if params.cast_bits_bwd is not None:
# Explicit casting to estimate the speed-up opportunity. It only casts,
# without scale estimation or scale multiplication.
if params.cast_bits_bwd == 8:
qg_for_x = g_with_sx.astype(jnp.int8)
elif params.cast_bits_bwd == 4:
qg_for_x = g_with_sx.astype(jnp.int4)
else:
raise ValueError(
f'Unsupported params.cast_bits_bwd: {params.cast_bits_bwd}'
)
else:
qg_for_x = g_with_sx

# abd,cd->abc
gx = jnp.einsum(
gx_eqn, qg_for_w, qw, preferred_element_type=params.einsum_output_dtype
)
qg_for_x, sg_for_x, _ = reduce_precision(
t=g_with_sx,
bits=bits_bwd,
contract_dims=[0, 1],
random_rounding=True,
key=prng_key,
)

# abc,abd->cd
gw = jnp.einsum(
gw_eqn, qx, qg_for_x, preferred_element_type=params.einsum_output_dtype
)
else:
qg_for_w = g
qg_for_x = g

gx = jnp.einsum(
'abd,cd->abc', qg_for_w, qw, preferred_element_type=jnp.bfloat16
)
gw = jnp.einsum(
'abc,abd->cd', qx, qg_for_x, preferred_element_type=jnp.bfloat16
)
if sg_for_w is not None:
scale_left = get_scale_eqn(left1, output_str)
scale_eqn = left1 + ',' + scale_left + '->' + left1
gx = jnp.einsum(
scale_eqn,
gx,
sg_for_w,
preferred_element_type=params.einsum_scale_output_dtype,
)

if sg_for_x is not None:
scale = get_scale_eqn(right1, output_str)
scale_eqn = right1 + ',' + scale + '->' + right1
gw = jnp.einsum(
scale_eqn,
gw,
sg_for_x,
preferred_element_type=params.einsum_scale_output_dtype,
)

if bits_bwd is not None:
gx = jnp.multiply(gx, sg_for_w)
gw = jnp.multiply(gw, jnp.squeeze(sg_for_x))
if params.bwd_output_dtype is not None:
gx = gx.astype(params.bwd_output_dtype)
gw = gw.astype(params.bwd_output_dtype)

# Custom VJP bwd rule must produce an output with the same container (pytree)
# structure as the args tuple of the primal function (custom_einsum).
return gx, gw, None, None, None
# Custom VJP bwd rule must produce an output with the same container pytree
# structure as the args tuple of the primal function (custom_einsum).
return gx, gw, None

assert not (params.bits_fwd is not None and params.cast_bits_fwd is not None)
assert not (params.bits_bwd is not None and params.cast_bits_bwd is not None)

_custom_einsum.defvjp(_custom_einsum_fwd, _custom_einsum_bwd)
out = _custom_einsum(x, w, prng_key)
return out
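
# Usage sketch (illustrative; assumes QuantizedTrainingParams exposes the
# fields referenced above, e.g. bits_fwd, bits_bwd, random_rounding_bwd and
# the dtype fields, with cast_bits_fwd/cast_bits_bwd left at None so the
# asserts above hold):
#   params = quantization_hparams.QuantizedTrainingParams(
#       bits_fwd=8,
#       bits_bwd=8,
#       einsum_output_dtype=jnp.bfloat16,
#       einsum_scale_output_dtype=jnp.bfloat16,
#       bwd_output_dtype=jnp.bfloat16,
#       random_rounding_bwd=True,
#   )
#   y = custom_einsum(
#       x, w, eqn='abc,cd->abd', prng_key=jax.random.PRNGKey(0), params=params
#   )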


def einsum(