
Commit

Makes the dtype of intermediate attention masks the same as `atten_mask`. This avoids the unnecessary memory usage of float32 when `atten_mask` is bfloat16.

PiperOrigin-RevId: 665756973
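
For intuition, here is a minimal standalone sketch of why the constant's dtype matters. The large_negative_number helper below only mimics what a py_utils.get_large_negative_number-style utility might do; the -0.7 * max scaling is an assumed detail, not the Praxis implementation.

import jax.numpy as jnp

def large_negative_number(dtype):
  # A very negative but finite value representable in `dtype`, used in
  # place of -inf when masking attention logits. The 0.7 scaling is a
  # guessed detail; the point is only that the result carries `dtype`.
  return jnp.asarray(-0.7 * float(jnp.finfo(dtype).max), dtype=dtype)

# A bfloat16 mask uses 2 bytes per element; mixing in float32 constants
# upcasts every intermediate to 4 bytes. Deriving the constant from the
# mask's own dtype avoids that doubling.
atten_mask = jnp.zeros((8, 1, 1, 4096), dtype=jnp.bfloat16)
minus_inf = large_negative_number(atten_mask.dtype)
assert (atten_mask + minus_inf).dtype == jnp.bfloat16  # no float32 upcast
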
pen-li authored and pax authors committed Aug 21, 2024
1 parent 6b7edb2 commit f7c5e7f
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions praxis/layers/attentions.py
@@ -3272,7 +3272,10 @@ def _dot_atten(
     query_blocks = convert_to_block(query, block_size=block_size)
     _, _, w, _, _ = query_blocks.shape
 
-    minus_inf = py_utils.get_large_negative_number(jnp.float32)
+    # Avoids large values when dtype is float64, which causes numerical issues.
+    minus_inf = py_utils.get_large_negative_number(
+        jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
+    )
 
     if atten_mask.shape[2] == 1:
       # Attention mask with shape [1|B, 1, 1, S]
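
As a quick standalone illustration of the guard above (a sketch, not Praxis code): the constant's dtype follows the mask, except that float64 falls back to float32, whose largest magnitude (~3.4e38 rather than ~1.8e308) keeps later arithmetic well-behaved.

import jax.numpy as jnp

# For each mask dtype, the dtype chosen for the minus_inf constant.
for mask_dtype in (jnp.bfloat16, jnp.float32, jnp.float64):
  chosen = jnp.float32 if mask_dtype == jnp.float64 else mask_dtype
  print(mask_dtype.__name__, '->', chosen.__name__)
# bfloat16 -> bfloat16
# float32 -> float32
# float64 -> float32
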
@@ -3296,9 +3299,7 @@ def _dot_atten(
 
     # -> [B, U, W, S]
     mask_block_context = convert_to_block( # pytype: disable=wrong-arg-types # jax-ndarray
-        atten_mask[:, 0].astype(jnp.float32),
-        block_size=block_size,
-        padding_val=minus_inf,
+        atten_mask[:, 0], block_size=block_size, padding_val=minus_inf
     )
     mask_block_context = jnp.reshape(mask_block_context, [b * u * w, s])
     # -> [B, U, W, U, C]
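
The previous code cast the mask to float32 before blocking; with padding_val now in the mask's dtype, the cast can simply be dropped. Below is a toy stand-in for the blocking step (convert_to_block's actual signature and behavior are assumed, not copied):

import jax.numpy as jnp

# Pad the sequence axis up to a multiple of the block size using the
# dtype-matched sentinel, then reshape to [B, num_blocks, block_size].
mask = jnp.zeros((2, 7), dtype=jnp.bfloat16)  # [B, S]
minus_inf = jnp.asarray(-0.7 * float(jnp.finfo(mask.dtype).max), mask.dtype)
padded = jnp.pad(mask, ((0, 0), (0, 1)), constant_values=minus_inf)
blocked = padded.reshape(2, 2, 4)
assert blocked.dtype == jnp.bfloat16  # no float32 detour
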
@@ -3403,15 +3404,14 @@ def _dot_atten_one_step(
     # right_context can be non-zero if is_cross_attention is True.
     f = self.left_context + self.right_context
 
+    # Avoids large values when dtype is float64, which causes numerical issues.
+    minus_inf = py_utils.get_large_negative_number(
+        jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
+    )
+
     key = _padded_slice(key, time_step + 1 - l, f, 1, 0.0)
     value = _padded_slice(value, time_step + 1 - l, f, 1, 0.0)
-    atten_mask = _padded_slice(
-        atten_mask,
-        time_step + 1 - l,
-        f,
-        -1,
-        py_utils.get_large_negative_number(atten_mask.dtype),
-    )
+    atten_mask = _padded_slice(atten_mask, time_step + 1 - l, f, -1, minus_inf)
 
     b, f, n, h = key.shape
     asserts.eq(f, self.left_context + self.right_context)
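
_padded_slice is internal to attentions.py; the stand-in below uses an assumed (x, start, size, axis, pad_val) signature just to show the idea: window positions that fall outside the sequence are filled with the dtype-matched minus_inf instead of a float32 constant.

import jax
import jax.numpy as jnp

def padded_slice(x, start, size, axis, pad_val):
  # Hypothetical stand-in: pad `size` elements of `pad_val` on both ends
  # of `axis`, then slice so out-of-range positions read the padding.
  axis = axis % x.ndim
  pad = [(0, 0)] * x.ndim
  pad[axis] = (size, size)
  xp = jnp.pad(x, pad, constant_values=pad_val)
  return jax.lax.dynamic_slice_in_dim(xp, start + size, size, axis=axis)

mask = jnp.zeros((1, 1, 10), dtype=jnp.bfloat16)  # [B, N, S]
minus_inf = jnp.asarray(-0.7 * float(jnp.finfo(mask.dtype).max), mask.dtype)
window = padded_slice(mask, start=-2, size=4, axis=-1, pad_val=minus_inf)
assert window.dtype == jnp.bfloat16  # window keeps the mask's dtype
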
@@ -3451,10 +3451,9 @@ def _dot_atten_one_step(
 
     if self.zero_fully_masked:
       # Return zeros for tokens which don't attend anything.
-      fully_masked = jnp.all(
-          atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
-          axis=-1,
-      )[..., jnp.newaxis]
+      fully_masked = jnp.all(atten_mask < minus_inf / 2, axis=-1)[
+          ..., jnp.newaxis
+      ]
       encoded *= 1 - fully_masked
 
     encoded = self._shard_bnh(encoded)
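
Finally, the fully-masked check sketched standalone (toy shapes; zero_fully_masked semantics taken from the in-code comment): a query whose whole mask row sits at the sentinel attends to nothing, and comparing against minus_inf / 2 rather than the exact sentinel tolerates small arithmetic drift.

import jax.numpy as jnp

minus_inf = jnp.asarray(-0.7 * float(jnp.finfo(jnp.bfloat16).max), jnp.bfloat16)
atten_mask = jnp.stack([
    jnp.full((4,), minus_inf),            # row 0: attends to nothing
    jnp.zeros((4,), dtype=jnp.bfloat16),  # row 1: attends everywhere
])
# True where every key position in the row is masked out -> [2, 1].
fully_masked = jnp.all(atten_mask < minus_inf / 2, axis=-1)[..., jnp.newaxis]
encoded = jnp.ones((2, 3), dtype=jnp.bfloat16) * (1 - fully_masked)
# Row 0 of `encoded` is zeroed; row 1 is untouched.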
