
Commit 0056f19
A lot of the time, positions in the kv states are just padding. We can identify the start and slice width of the non-padding kv states by checking atten_mask and the current time_step during extend_step, and then run qk_einsum and pv_einsum only over the non-padding slices. However, JAX does not allow dynamic_slice with variable widths, so we predefine `chunked_one_step_attn_num_seq_split` partial functions, each with a fixed slice width, and use jax.lax.switch to pick the corresponding function based on the actual dynamic non-padding width of the input batch.

PiperOrigin-RevId: 614069997
bignamehyp authored and pax authors committed Mar 9, 2024
1 parent f48a941 commit 0056f19
Showing 2 changed files with 18 additions and 7 deletions.
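
Before the per-file diffs, here is a minimal, self-contained JAX sketch of the pattern described in the commit message. It is not the Praxis implementation: the helper names (`one_step_attn`, `sliced_attn`, `chunked_one_step_attn`), the single shared kv head, and the omission of the attention mask inside the slice are all simplifications. What it demonstrates is the part stated in the message: `jax.lax.dynamic_slice` needs a static slice width, so one branch is built per allowed width and `jax.lax.switch` picks among them at run time.

# Minimal sketch of the fixed-width-slice + lax.switch pattern described in the
# commit message; helper names and shapes are illustrative, not the Praxis API.
import functools

import jax
import jax.numpy as jnp


def one_step_attn(query, key, value):
  # query: [B, N, H]; key/value: [B, S, H] (single shared kv head for brevity).
  logits = jnp.einsum('bnh,bsh->bns', query, key)   # qk_einsum
  probs = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum('bns,bsh->bnh', probs, value)   # pv_einsum


def sliced_attn(width, query, key, value, start):
  # `width` is a Python int bound via functools.partial, so dynamic_slice sees
  # a static slice size; `start` may be a traced scalar. JAX clamps
  # out-of-bounds starts, so the slice always stays within the sequence.
  k = jax.lax.dynamic_slice_in_dim(key, start, width, axis=1)
  v = jax.lax.dynamic_slice_in_dim(value, start, width, axis=1)
  return one_step_attn(query, k, v)


def chunked_one_step_attn(query, key, value, start, nonpad_width, num_seq_split=4):
  s = key.shape[1]
  chunk = s // num_seq_split
  # One branch per allowed width: chunk, 2*chunk, ..., num_seq_split*chunk.
  branches = [
      functools.partial(sliced_attn, (i + 1) * chunk)
      for i in range(num_seq_split)
  ]
  # Round the dynamic non-padding width up to the nearest allowed width.
  idx = jnp.clip((nonpad_width + chunk - 1) // chunk - 1, 0, num_seq_split - 1)
  return jax.lax.switch(idx, branches, query, key, value, start)


# Example: S = 128 split into num_seq_split = 4 chunks of 32; only the first
# 70 kv positions are valid, which rounds up to the 96-wide (third) branch.
q = jnp.ones((2, 8, 16))     # [B, N, H]
k = jnp.ones((2, 128, 16))   # [B, S, H]
v = jnp.ones((2, 128, 16))
out = chunked_one_step_attn(q, k, v, start=jnp.int32(0), nonpad_width=jnp.int32(70))
assert out.shape == (2, 8, 16)

Rounding the non-padding width up to the next chunk boundary keeps the number of compiled branches at `num_seq_split` (standing in here for `chunked_one_step_attn_num_seq_split`) while never dropping valid kv positions.
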
3 changes: 3 additions & 0 deletions praxis/layers/gpu_fast_attention.py
@@ -391,6 +391,7 @@ def _dot_atten_one_step(
       value_state_name: str,
       atten_mask: JTensor,
       relative_bias: JTensor | None = None,
+      time_step: JTensor | None = None,
   ) -> tuple[JTensor, JTensor]:
     """Dot attention function for queries with 1 time step.
@@ -404,11 +405,13 @@ def _dot_atten_one_step(
         be of size 1, if the mask is shared by all items in the batch (e.g.,
         only a causal mask).
       relative_bias: Relative bias of shape [1|B, N, 1, S].
+      time_step: A scalar. The time step tensor.

     Returns:
       encoded: JTensor of shape [B, N, H].
       probs: JTensor of shape [B, N, S].
     """
+    del time_step
     if not self.use_flash_decoding:
       return super()._dot_atten_one_step(
           query,
22 changes: 15 additions & 7 deletions praxis/layers/multi_query_attention.py
@@ -603,6 +603,7 @@ def _dot_atten_one_step(
       value_state_name: str,
       atten_mask: JTensor,
       relative_bias: JTensor | None = None,
+      time_step: JTensor | None = None,
   ) -> tuple[JTensor, JTensor]:
     """Dot attention function for queries with 1 time step.
@@ -616,6 +617,7 @@ def _dot_atten_one_step(
         is allowed to be of size 1, if the mask is shared by all items in the
         batch (e.g., only a causal mask).
       relative_bias: Relative bias of shape [1/B, N, 1, S].
+      time_step: A scalar or JTensor. Current time-step, 0-based.

     Returns:
       encoded: JTensor of shape [B, N, H].
@@ -632,7 +634,7 @@ def _dot_atten_one_step(
       key = self._shard_blh(key)
       value = self._shard_blh(value)
       encoded, probs = self._dot_atten_one_step_from_qkv(
-          query, key, value, atten_mask, relative_bias
+          query, key, value, atten_mask, relative_bias, time_step
       )
       return self._shard_bnh(encoded), probs
     else:
@@ -651,9 +653,9 @@ def _dot_atten_one_step(
       with self._context_for_kv_vmap():
         encoded, probs = jax.vmap(
             self._dot_atten_one_step_from_qkv,
-            in_axes=(1, 2, 2, None, 1),
+            in_axes=(1, 2, 2, None, 1, None),
             out_axes=(1, 1),
-        )(v_q, key, value, atten_mask, v_rb)
+        )(v_q, key, value, atten_mask, v_rb, time_step)
       encoded = self._shard_bnh(jnp.reshape(encoded, (b, n, h)))
       probs = jnp.reshape(probs, (b, n, -1))
       return encoded, probs
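
A toy example (hypothetical shapes and function, not the Praxis code) of what the added `None` entry in `in_axes` does: `jax.vmap` broadcasts the scalar `time_step` unchanged to every mapped call, while the queries and kv states are split along their head axes.

# Toy illustration (hypothetical shapes/function) of vmap's None in_axes entry.
import jax
import jax.numpy as jnp


def per_group_attn(q, k, v, mask, time_step):
  # q: [B, N, H]; k/v: [B, S, H]; mask: [B, 1, S]; time_step: scalar.
  del time_step  # used for slicing in the real kernel; unused in this toy
  logits = jnp.einsum('bnh,bsh->bns', q, k) + mask
  return jnp.einsum('bns,bsh->bnh', jax.nn.softmax(logits, axis=-1), v)


b, g, n, s, h = 2, 4, 2, 16, 8
v_q = jnp.zeros((b, g, n, h))      # queries grouped by kv head (axis 1)
key = jnp.zeros((b, s, g, h))      # kv head axis is 2
value = jnp.zeros((b, s, g, h))
mask = jnp.zeros((b, 1, s))
time_step = jnp.array(7)           # scalar, shared by every group

encoded = jax.vmap(
    per_group_attn,
    in_axes=(1, 2, 2, None, None),  # None: broadcast mask and time_step
    out_axes=1,
)(v_q, key, value, mask, time_step)
assert encoded.shape == (b, g, n, h)
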
@@ -665,8 +667,10 @@ def _dot_atten_one_step_from_qkv(
       value: JTensor,
       atten_mask: JTensor,
       relative_bias: JTensor | None,
+      time_step: JTensor | None = None,
   ) -> tuple[JTensor, JTensor]:
     """_dot_atten_one_step with tensors instead of state names."""
+    del time_step
     # query is 3d.
     extend_one_step = len(query.shape) == 3
     b, s, h = key.shape
@@ -1030,10 +1034,14 @@ def _extend_decode_state_and_shard_blh(name: str,
     else:
       relative_bias = None

-    encoded, atten_prob = self._dot_atten_one_step(query_proj,
-                                                   key_state_name,
-                                                   value_state_name, atten_mask,
-                                                   relative_bias)
+    encoded, atten_prob = self._dot_atten_one_step(
+        query_proj,
+        key_state_name,
+        value_state_name,
+        atten_mask,
+        relative_bias,
+        time_step,
+    )
     # TODO(yonghui): return atten_probs back to the caller.
     del atten_prob
     # Post projection.
