From f7c5e7ffc81a2a57182eaa3a5666742f50e7c6c3 Mon Sep 17 00:00:00 2001
From: Yunpeng Li
Date: Wed, 21 Aug 2024 01:23:20 -0700
Subject: [PATCH] Makes the dtype of intermediate attention masks the same as
 `atten_mask`.

This avoids the unnecessary memory usage of float32 when `atten_mask` is
bfloat16.

PiperOrigin-RevId: 665756973
---
 praxis/layers/attentions.py | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py
index ed9bc9fd..515b7cac 100644
--- a/praxis/layers/attentions.py
+++ b/praxis/layers/attentions.py
@@ -3272,7 +3272,10 @@ def _dot_atten(
     query_blocks = convert_to_block(query, block_size=block_size)
     _, _, w, _, _ = query_blocks.shape
 
-    minus_inf = py_utils.get_large_negative_number(jnp.float32)
+    # Avoids large values when dtype is float64, which causes numerical issues.
+    minus_inf = py_utils.get_large_negative_number(
+        jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
+    )
 
     if atten_mask.shape[2] == 1:
       # Attention mask with shape [1|B, 1, 1, S]
@@ -3296,9 +3299,7 @@ def _dot_atten(
       # -> [B, U, W, S]
       mask_block_context = convert_to_block(  # pytype: disable=wrong-arg-types  # jax-ndarray
-          atten_mask[:, 0].astype(jnp.float32),
-          block_size=block_size,
-          padding_val=minus_inf,
+          atten_mask[:, 0], block_size=block_size, padding_val=minus_inf
       )
       mask_block_context = jnp.reshape(mask_block_context, [b * u * w, s])
       # -> [B, U, W, U, C]
@@ -3403,15 +3404,14 @@ def _dot_atten_one_step(
     # right_context can be non-zero if is_cross_attention is True.
     f = self.left_context + self.right_context
 
+    # Avoids large values when dtype is float64, which causes numerical issues.
+    minus_inf = py_utils.get_large_negative_number(
+        jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
+    )
+
     key = _padded_slice(key, time_step + 1 - l, f, 1, 0.0)
     value = _padded_slice(value, time_step + 1 - l, f, 1, 0.0)
-    atten_mask = _padded_slice(
-        atten_mask,
-        time_step + 1 - l,
-        f,
-        -1,
-        py_utils.get_large_negative_number(atten_mask.dtype),
-    )
+    atten_mask = _padded_slice(atten_mask, time_step + 1 - l, f, -1, minus_inf)
 
     b, f, n, h = key.shape
     asserts.eq(f, self.left_context + self.right_context)
@@ -3451,10 +3451,9 @@ def _dot_atten_one_step(
 
     if self.zero_fully_masked:
       # Return zeros for tokens which don't attend anything.
-      fully_masked = jnp.all(
-          atten_mask < py_utils.get_large_negative_number(jnp.float32) / 2,
-          axis=-1,
-      )[..., jnp.newaxis]
+      fully_masked = jnp.all(atten_mask < minus_inf / 2, axis=-1)[
+          ..., jnp.newaxis
+      ]
       encoded *= 1 - fully_masked
       encoded = self._shard_bnh(encoded)
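
For context, the following is a minimal standalone sketch (not part of the patch) of the dtype-selection pattern the change introduces: the large-negative padding constant is created in `atten_mask`'s own dtype, falling back to float32 only for float64 masks, so intermediate masks are no longer promoted to float32 when `atten_mask` is bfloat16. The helper `_large_negative_number` is a hypothetical stand-in for praxis' `py_utils.get_large_negative_number`, not its actual implementation.

import jax.numpy as jnp


def _large_negative_number(dtype):
  # Hypothetical stand-in for praxis' py_utils.get_large_negative_number:
  # a finite, very negative value in `dtype`, kept well away from the dtype's
  # lower bound so that expressions such as `minus_inf / 2` remain meaningful.
  return jnp.asarray(-0.7 * float(jnp.finfo(dtype).max), dtype=dtype)


def mask_padding_value(atten_mask):
  # Mirrors the patch: keep the padding constant in the mask's own dtype so
  # intermediate blocks stay bfloat16 when `atten_mask` is bfloat16, but fall
  # back to float32 for float64 masks to avoid numerically extreme values.
  dtype = jnp.float32 if atten_mask.dtype == jnp.float64 else atten_mask.dtype
  return _large_negative_number(dtype)


if __name__ == "__main__":
  bf16_mask = jnp.zeros((1, 1, 1, 8), dtype=jnp.bfloat16)
  minus_inf = mask_padding_value(bf16_mask)
  # Padding with a same-dtype constant keeps the result in bfloat16; a float32
  # constant (the old behavior) would have promoted the whole mask to float32.
  padded = jnp.concatenate(
      [bf16_mask, jnp.full((1, 1, 1, 4), minus_inf, dtype=bf16_mask.dtype)],
      axis=-1,
  )
  print(minus_inf.dtype, padded.dtype)  # bfloat16 bfloat16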