diff --git a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
index bb5c4e92f4..0344d2f8e7 100644
--- a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -122,21 +122,14 @@ def call(
         key = self._key_dense(key)
         value = self._value_dense(value)
 
-        query = ops.multiply(
-            query,
-            1.0 / ops.sqrt(ops.cast(self._key_dim, query.dtype)),
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self._masked_softmax(
-            attention_scores, attention_mask
-        )
-        attention_scores = self._dropout_layer(
-            attention_scores, training=training
+        attention_output, attention_scores = self._compute_attention(
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+            training=training,
         )
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
         attention_output = self._output_dense(attention_output)
 
         if cache is not None: