diff --git a/praxis/layers/attentions.py b/praxis/layers/attentions.py
index b7026d3b..652aa28c 100644
--- a/praxis/layers/attentions.py
+++ b/praxis/layers/attentions.py
@@ -1007,6 +1007,7 @@ class DotProductAttention(base_layer.BaseLayer):
   decode_cache: bool = True
   attention_mask_summary: bool = False
   zero_fully_masked: bool = False
+  mha_mask_addition_pattern: bool = True
   qk_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
   pv_einsum_tpl: LayerTpl = template_field(base_ops.EinsumOp)
 
@@ -1342,8 +1343,14 @@ def _dot_atten(
     logits = self._cap_logits(logits)
     # Attention softmax is always carried out in fp32.
     logits = logits.astype(jnp.float32)
+
     # Apply attention masking
-    padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+    if self.mha_mask_addition_pattern:
+      padded_logits = logits + atten_mask.astype(jnp.float32)
+    else:
+      padded_logits = py_utils.apply_mask_to_logits(logits, atten_mask)
+
+
     if self.attention_mask_summary:
       self.add_summary('attention_mask', atten_mask)
     if self.attention_extra_logit is None:
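
Note (not part of the patch): a minimal standalone sketch of what the new
`mha_mask_addition_pattern` switch toggles, assuming the praxis mask convention
in which `atten_mask` holds 0.0 at visible positions and a large negative value
at masked positions. The `_large_negative` and `_clamp_mask` helpers below are
hypothetical stand-ins for the library's own utilities (including
`py_utils.apply_mask_to_logits`), whose exact constants and edge-case handling
may differ.

    import jax
    import jax.numpy as jnp

    def _large_negative(dtype):
      # A large but finite negative value representable in `dtype`
      # (the 0.7 factor is an assumption, not the library's constant).
      return jnp.asarray(-0.7 * jnp.finfo(dtype).max, dtype=dtype)

    def _clamp_mask(logits, atten_mask):
      # Hypothetical stand-in for py_utils.apply_mask_to_logits: replace
      # masked positions outright instead of adding the mask to the logits.
      return jnp.where(atten_mask < 0.0, _large_negative(logits.dtype), logits)

    logits = jnp.array([[2.0, 1.0, 0.5]], dtype=jnp.float32)
    # Mask out the last position: 0.0 = visible, large negative = masked.
    atten_mask = jnp.array([[0.0, 0.0, 1.0]], jnp.float32) * _large_negative(jnp.float32)

    # Addition pattern (mha_mask_addition_pattern=True): the mask is simply
    # added to the fp32 logits, as in the standard additive-bias MHA form.
    added = logits + atten_mask.astype(jnp.float32)

    # Replace pattern (mha_mask_addition_pattern=False): masked positions are
    # overwritten, independent of how large the logits themselves are.
    clamped = _clamp_mask(logits, atten_mask)

    # Both paths drive the masked softmax weight to ~0.
    print(jax.nn.softmax(added, axis=-1))
    print(jax.nn.softmax(clamped, axis=-1))

The addition pattern keeps the masking as a single elementwise add, while the
replace pattern avoids summing two values that are both near the dtype's
negative limit; which trade-off motivated the flag here is not stated in the
diff itself.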