Refactoring: in CachedMultiHeadAttention call MHA methods instead of recoding the attention calculation (keras-team#1684)

* Call "_compute_attention" instead of recoding the calculation

* Fix formatting

---------

Co-authored-by: Matt Watson <[email protected]>
apehex and mattdangerw authored Jul 8, 2024
1 parent 880c7c6 commit 29c85c0
Showing 1 changed file with 6 additions and 13 deletions.
keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -122,21 +122,14 @@ def call(
         key = self._key_dense(key)
         value = self._value_dense(value)
 
-        query = ops.multiply(
-            query,
-            1.0 / ops.sqrt(ops.cast(self._key_dim, query.dtype)),
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self._masked_softmax(
-            attention_scores, attention_mask
-        )
-        attention_scores = self._dropout_layer(
-            attention_scores, training=training
+        attention_output, attention_scores = self._compute_attention(
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+            training=training,
         )
 
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
         attention_output = self._output_dense(attention_output)
 
         if cache is not None:
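For context, the deleted lines re-implemented scaled dot-product attention inline; `_compute_attention`, inherited from the base `MultiHeadAttention` layer, performs the same steps, so the refactor removes the duplication without changing behavior. Below is a minimal standalone sketch of that computation. It is a NumPy illustration with the einsum equations written out by hand; the real layer uses backend-agnostic `keras.ops`, builds its equations dynamically, and applies dropout to the scores, so treat the function name, shapes, and equations here as assumptions for illustration.

import numpy as np

def compute_attention_sketch(query, key, value, attention_mask=None):
    # query/key/value: (batch, seq, num_heads, key_dim), i.e. the outputs
    # of the layer's _query_dense/_key_dense/_value_dense projections.
    key_dim = query.shape[-1]
    # Scale queries by 1/sqrt(key_dim), as the removed ops.multiply did.
    query = query * (1.0 / np.sqrt(key_dim))
    # Dot-product scores (the layer's self._dot_product_equation):
    # shape (batch, heads, query_len, key_len).
    scores = np.einsum("bkhd,bqhd->bhqk", key, query)
    # Masked softmax: push masked positions toward -inf before normalizing.
    if attention_mask is not None:
        scores = np.where(attention_mask[:, None, :, :], scores, -1e9)
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    scores = scores / scores.sum(axis=-1, keepdims=True)
    # (The real layer also applies dropout to the scores at train time.)
    # Combine scores and values (the layer's self._combine_equation).
    output = np.einsum("bhqk,bkhd->bqhd", scores, value)
    return output, scores

For example, with query = key = value of shape (2, 5, 8, 16), the sketch returns an output of shape (2, 5, 8, 16) and scores of shape (2, 8, 5, 5), mirroring the (attention_output, attention_scores) pair returned by _compute_attention in the diff above.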
