Remove build_from_signature from MHA layers (keras-team#1687)
This was a Keras 2 workaround for build() only supporting a single argument. We no longer need it.
mattdangerw authored Jul 8, 2024
1 parent f9faaf1 commit 880c7c6
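As the commit message notes, Keras 2's Layer.build() accepted only a single input-shape argument, so MultiHeadAttention exposed a private _build_from_signature() helper that composite layers called to build it eagerly. In Keras 3, MultiHeadAttention.build() takes the query/value/key shapes directly, which is what the diffs below switch to. The sketch that follows only illustrates the new pattern; the layer name, sizes, and arguments are made up for illustration and are not part of this commit.

import keras

class TinySelfAttentionBlock(keras.layers.Layer):
    """Illustrative composite layer that builds its MHA sublayer eagerly."""

    def __init__(self, num_heads=2, key_dim=8, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim

    def build(self, inputs_shape):
        self._attention = keras.layers.MultiHeadAttention(
            num_heads=self.num_heads,
            key_dim=self.key_dim,
        )
        # Keras 3: build() accepts one shape per attention input, so no
        # _build_from_signature() shim is needed.
        self._attention.build(
            query_shape=inputs_shape,
            value_shape=inputs_shape,
        )
        self.built = True

    def call(self, inputs):
        # Self-attention: query and value are the same tensor.
        return self._attention(inputs, inputs)

layer = TinySelfAttentionBlock()
layer.build((None, 16, 8))  # (batch, sequence, feature) shape

This mirrors the pattern the changed layers now use: build the attention sublayer eagerly with explicit query/value shapes instead of going through the old private helper.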
Showing 4 changed files with 12 additions and 40 deletions.
7 changes: 0 additions & 7 deletions keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -88,13 +88,6 @@ def call(
        cache_update_index=None,
        training=None,
    ):
-        if (
-            hasattr(self, "_build_from_signature")
-            and hasattr(self, "_built_from_signature")
-            and not self._built_from_signature
-        ):
-            self._build_from_signature(query=query, value=value, key=key)
-
        if key is None:
            key = value

28 changes: 8 additions & 20 deletions keras_nlp/src/layers/modeling/transformer_decoder.py
@@ -160,16 +160,10 @@ def build(
            dtype=self.dtype_policy,
            name="self_attention",
        )
-        if hasattr(self._self_attention_layer, "_build_from_signature"):
-            self._self_attention_layer._build_from_signature(
-                query=decoder_sequence_shape,
-                value=decoder_sequence_shape,
-            )
-        else:
-            self._self_attention_layer.build(
-                query_shape=decoder_sequence_shape,
-                value_shape=decoder_sequence_shape,
-            )
+        self._self_attention_layer.build(
+            query_shape=decoder_sequence_shape,
+            value_shape=decoder_sequence_shape,
+        )
        self._self_attention_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
@@ -195,16 +189,10 @@
            dtype=self.dtype_policy,
            name="cross_attention",
        )
-        if hasattr(self._cross_attention_layer, "_build_from_signature"):
-            self._cross_attention_layer._build_from_signature(
-                query=decoder_sequence_shape,
-                value=encoder_sequence_shape,
-            )
-        else:
-            self._cross_attention_layer.build(
-                query_shape=decoder_sequence_shape,
-                value_shape=encoder_sequence_shape,
-            )
+        self._cross_attention_layer.build(
+            query_shape=decoder_sequence_shape,
+            value_shape=encoder_sequence_shape,
+        )
        self._cross_attention_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
14 changes: 4 additions & 10 deletions keras_nlp/src/layers/modeling/transformer_encoder.py
@@ -128,16 +128,10 @@ def build(self, inputs_shape):
            dtype=self.dtype_policy,
            name="self_attention_layer",
        )
-        if hasattr(self._self_attention_layer, "_build_from_signature"):
-            self._self_attention_layer._build_from_signature(
-                query=inputs_shape,
-                value=inputs_shape,
-            )
-        else:
-            self._self_attention_layer.build(
-                query_shape=inputs_shape,
-                value_shape=inputs_shape,
-            )
+        self._self_attention_layer.build(
+            query_shape=inputs_shape,
+            value_shape=inputs_shape,
+        )
        self._self_attention_layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self.dtype_policy,
@@ -149,6 +149,3 @@ def build(
        output_dense_input_shape[-1] = self._value_dim
        self._output_dense.build(tuple(output_dense_input_shape))
        self.built = True
-
-    def _build_from_signature(self, query, value, key=None):
-        pass