
[DO NOT MERGE] Experimental implementation of CausalLM with a Keras Functional backbone_with_cache #1598

Closed · wants to merge 4 commits
45 changes: 45 additions & 0 deletions keras_nlp/layers/modeling/transformer_decoder.py
@@ -503,3 +503,48 @@ def get_config(self):

    def compute_output_shape(self, decoder_sequence_shape):
        return decoder_sequence_shape

Member:
This is kinda clunky, but might be a good idea to add regardless. Can we just **kwargs the args we don't actually care about here?

I'm not sure if we need compute_output_shape if we do this.

Contributor Author:
Yes, compute_output_shape is probably not required when compute_output_spec is implemented. Here, compute_output_spec was necessary because the layer returns differently shaped outputs depending on its inputs.

    def compute_output_spec(
        self,
        decoder_sequence,
        encoder_sequence=None,
        decoder_padding_mask=None,
        decoder_attention_mask=None,
        encoder_padding_mask=None,
        encoder_attention_mask=None,
        self_attention_cache=None,
        self_attention_cache_update_index=None,
        cross_attention_cache=None,
        cross_attention_cache_update_index=None,
        use_causal_mask=True,
    ):
        if self_attention_cache is not None:
            has_cross_attention = self._cross_attention_layer is not None
            if has_cross_attention:
                return (
                    keras.KerasTensor(
                        decoder_sequence.shape, dtype=decoder_sequence.dtype
                    ),
                    keras.KerasTensor(
                        self_attention_cache.shape,
                        dtype=self_attention_cache.dtype,
                    ),
                    keras.KerasTensor(
                        cross_attention_cache.shape,
                        dtype=cross_attention_cache.dtype,
                    ),
                )
            else:
                return (
                    keras.KerasTensor(
                        decoder_sequence.shape, dtype=decoder_sequence.dtype
                    ),
                    keras.KerasTensor(
                        self_attention_cache.shape,
                        dtype=self_attention_cache.dtype,
                    ),
                )
        else:
            return keras.KerasTensor(
                decoder_sequence.shape, dtype=decoder_sequence.dtype
            )
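A minimal illustration of the behavior described in the reply above; the hyperparameters (hidden size 32, 2 heads, max cache length 8) are made-up values, not taken from this PR. With this compute_output_spec in place, a symbolic call without a cache yields a single KerasTensor, while a call with a self-attention cache yields an (output, cache) tuple:

    import keras
    from keras_nlp.layers import TransformerDecoder

    decoder = TransformerDecoder(intermediate_dim=64, num_heads=2)
    # (batch, sequence, hidden) and (batch, 2, max_length, num_heads, head_dim).
    decoder_sequence = keras.Input(shape=(None, 32))
    self_attention_cache = keras.Input(shape=(2, 8, 2, 16))

    outputs = decoder(decoder_sequence)  # a single KerasTensor
    outputs, next_cache = decoder(
        decoder_sequence,
        self_attention_cache=self_attention_cache,
        self_attention_cache_update_index=0,
    )  # a tuple of KerasTensors: (outputs, self_attention_cache)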
106 changes: 106 additions & 0 deletions keras_nlp/models/causal_lm.py
@@ -24,13 +24,16 @@
)
import tree

from keras.src.models.cloning import clone_model
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.task import Task
from keras_nlp.samplers.serialization import get as get_sampler
from keras_nlp.utils.tensor_utils import tensor_to_list
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder


@keras_nlp_export("keras_nlp.models.CausalLM")
@@ -387,3 +390,106 @@ def postprocess(x):
        outputs = [postprocess(x) for x in outputs]

        return self._normalize_generate_outputs(outputs, input_is_scalar)

"""Wires LLM decoding caches into a backbone.

Returns a new functional backbone with the same graph layout as the
original backbone, but with caches wired into TransformerDecoder blocks.
"""
@staticmethod
def _rewire_backbone_with_cache(backbone, cache_shape, cache_dtype):

Member:
Our cache might change to be a tuple of individual layer caches to avoid stacking/concatenating, as described in #1562.

And further down the road, we might want to add more cache shape options, e.g. for things like token attention.

Interestingly, a cache of tuples would invalidate our current restriction on functional model inputs. We'd want a nested structure where one dictionary key contains a tuple of inputs, which would break here: https://github.com/keras-team/keras/blob/9f4da5159a098256dfbccd2c926107953a6812e5/keras/src/models/functional.py#L134-L141

So we may need to do more thinking here if we "unstack our cache".

Contributor Author:
Expanding functional models to arbitrary pytree inputs and outputs (as long as the leaves are KerasTensors) is on the roadmap (look under "Modeling").
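For illustration, a sketch of the per-layer cache layout discussed above; the hyperparameter values and input names are assumptions, not part of this PR. A tuple of per-layer cache inputs nested under a single dictionary key is exactly the structure that, per the comment above, the current functional input handling would reject:

    import keras

    # Hypothetical per-layer cache inputs instead of one stacked cache tensor
    # (illustrative hyperparameters).
    num_layers, max_length, num_heads, head_dim = 12, 1024, 12, 64
    per_layer_caches = tuple(
        keras.Input(
            shape=(2, max_length, num_heads, head_dim),
            dtype="float32",
            name=f"cache_{i}",
        )
        for i in range(num_layers)
    )
    token_ids_input = keras.Input(shape=(None,), dtype="int32", name="token_ids")
    inputs = {
        "token_ids": token_ids_input,
        # A tuple nested under one dictionary key: the nested structure that
        # the linked functional.py check does not currently accept.
        "cache": per_layer_caches,
    }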

        # Define new inputs for caches.
        cache_update_index_input = keras.Input(
            shape=(), dtype="int32", name="cache_update_index"
        )
        # cache_update_index_input is always a scalar. We must force the
        # shape to scalar because keras.Input assumes a batch dim.
        cache_update_index_input.shape = ()

Member:
This seems kinda hacky, is this something we want to support generally in Keras? Unbatched functional model inputs? And if so, is this the way we would like to do it?

        # Input for a combined cache for all TransformerDecoder layers.
        cache_input = keras.Input(
            shape=cache_shape, dtype=cache_dtype, name="cache"
        )

        # Split the cache on the num_layers axis (axis=1, the number of
        # transformer blocks).
        caches = ops.unstack(cache_input, cache_input.shape[1], axis=1)
        decoder_block_idx = 0
        next_caches = []

        def rewire_positionembedding(layer, *args, **kwargs):
            # Wire in a new input, cache_update_index_input, by calling the
            # layer on it.
            nonlocal cache_update_index_input
            return layer(*args, start_index=cache_update_index_input, **kwargs)

        def rewire_transformerdecoder(layer, *args, **kwargs):
            # Wire in caches and collect next_caches by calling the layer on
            # these inputs.
            nonlocal caches, next_caches, decoder_block_idx
            # No mask when decoding with cache.
            kwargs.pop("decoder_padding_mask")
            output, next_cache = layer(
                *args,
                self_attention_cache=caches[decoder_block_idx],
                self_attention_cache_update_index=cache_update_index_input,
                **kwargs,
            )
            decoder_block_idx += 1
            next_caches.append(next_cache)
            return output

        def rewire_fn(layer, *args, **kwargs):
            if isinstance(layer, PositionEmbedding):
                return rewire_positionembedding(layer, *args, **kwargs)
            elif isinstance(layer, TransformerDecoder):
                return rewire_transformerdecoder(layer, *args, **kwargs)
            else:
                return layer(*args, **kwargs)  # identity

Member (on the PositionEmbedding check):
This would not work for a few models where the position embedding is part of a composite layer (TokenAndPositionEmbedding).

Member (on the TransformerDecoder check):
This would not work for most decoder models (as most decoder models write their own decoder block).

Contributor Author:
We can test for being in the set of backbone.transformer_layers rather than for a specific type. This can be solved with a convention for what a "backbone" should contain (which makes sense: not any backbone works for cached text generation).
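A minimal sketch of the membership-based dispatch suggested in the reply above, assuming the backbone exposes its decoder blocks as backbone.transformer_layers and an optional top-level position_embedding attribute; both are conventions assumed here, not guaranteed by this PR:

        # Hypothetical alternative to the isinstance() checks above: key the
        # rewiring on the backbone's own attributes instead of on layer types.
        def rewire_fn_by_membership(layer, *args, **kwargs):
            if layer is getattr(backbone, "position_embedding", None):
                return rewire_positionembedding(layer, *args, **kwargs)
            elif layer in backbone.transformer_layers:
                return rewire_transformerdecoder(layer, *args, **kwargs)
            return layer(*args, **kwargs)  # identity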

        # Original code with the "clone_layer_graph" API (to be deleted later).
        # # Rewire the graph of layers with caches.
        # token_ids_input = backbone.input["token_ids"]
        # input = {
        #     "token_ids": token_ids_input,
        #     "cache": cache_input,
        #     "cache_update_index": cache_update_index_input,
        # }
        # hidden_states = clone_layer_graph(input, backbone.output, rewire_fn)
        # # During the rewiring process, output caches were collected.
        # next_cache = ops.stack(next_caches, axis=1)
        # logits = backbone.token_embedding(hidden_states, reverse=True)
        #
        # # Create a new backbone that now uses caches in its forward pass.
        # output = (logits, hidden_states, next_cache)
        # return keras.Model(input, output, name=backbone.name + "_with_cache")

        # Copy the layer graph (not the layers themselves!) while adding cache
        # inputs and outputs to TransformerDecoder layers.
        rewired_backbone = clone_model(
            backbone,
            clone_function=lambda x: x,  # no cloning
            call_function=rewire_fn,
        )

        # Build a new model with caches in its inputs and outputs.
        input = {
            "token_ids": rewired_backbone.input["token_ids"],
            "cache": cache_input,
            "cache_update_index": cache_update_index_input,
        }

        # During the rewiring process, output caches were collected.
        next_cache = ops.stack(next_caches, axis=1)
        # This is the original output of the backbone.
        hidden_states = rewired_backbone.output
        # For text generation, we also want a decoded output.
        logits = backbone.token_embedding(hidden_states, reverse=True)
        output = (logits, hidden_states, next_cache)

        # Create a new backbone that now uses caches in its forward pass.
        return keras.Model(input, output, name=backbone.name + "_with_cache")

    # Cache shape without the batch dimension.
    @staticmethod
    def _compute_cache_shape(backbone, preprocessor):
        num_layers = backbone.num_layers
        max_length = preprocessor.sequence_length
        num_heads = backbone.num_heads
        head_dim = backbone.hidden_dim // backbone.num_heads
        return [num_layers, 2, max_length, num_heads, head_dim]
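For intuition, the axes above are (num_layers, key/value, max_length, num_heads, head_dim). A quick worked example with GPT-2-base-like hyperparameters (12 layers, 12 heads, hidden dim 768, sequence length 1024); these are illustrative numbers, not values read from this diff:

        # num_layers=12, 2 for key and value, max_length=1024, num_heads=12,
        # head_dim = 768 // 12 = 64, so:
        # cache_shape == [12, 2, 1024, 12, 64]
        # keras.Input(shape=cache_shape, ...) then prepends a batch dimension,
        # so the combined cache tensor fed to the rewired backbone has shape
        # (batch, 12, 2, 1024, 12, 64).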
30 changes: 14 additions & 16 deletions keras_nlp/models/gemma/gemma_causal_lm.py
@@ -158,6 +158,13 @@ def __init__(
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Backbone with cache ===
        # The backbone with a cache is used in call_with_cache.
        cache_shape = self._compute_cache_shape(backbone, preprocessor)
        self.backbone_with_cache = self._rewire_backbone_with_cache(
            backbone, cache_shape, backbone.compute_dtype
        )

        # === Functional Model ===
        inputs = backbone.input
        hidden_states = backbone(inputs)
@@ -210,22 +217,13 @@ def call_with_cache(
        the final hidden representation of the input tokens, and `cache` is
        the decoding cache.
        """
-        x = self.backbone.token_embedding(token_ids)
-        x = x * ops.cast(ops.sqrt(self.backbone.hidden_dim), x.dtype)
-        # Each decoder layer has a cache; we update them separately.
-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-            x, next_cache = transformer_layer(
-                x,
-                cache=current_cache,
-                cache_update_index=cache_update_index,
-            )
-            caches.append(next_cache)
-        cache = ops.stack(caches, axis=1)
-        hidden_states = x = self.backbone.layer_norm(x)
-        logits = self.backbone.token_embedding(x, reverse=True)
-        return logits, hidden_states, cache
+        return self.backbone_with_cache(
+            {
+                "token_ids": token_ids,
+                "cache": cache,
+                "cache_update_index": ops.convert_to_tensor(cache_update_index),
+            }
+        )

    def _build_cache(self, token_ids):
        """Build an empty cache for use with `call_with_cache()`."""
32 changes: 13 additions & 19 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -160,6 +160,13 @@ def __init__(
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Backbone with cache ===
        # The backbone with a cache is used in call_with_cache.
        cache_shape = self._compute_cache_shape(backbone, preprocessor)
        self.backbone_with_cache = self._rewire_backbone_with_cache(
            backbone, cache_shape, backbone.compute_dtype
        )

Member:
This might fail with Keras 2 saving for obscure reasons. Basically, we might try to save backbone_with_cache before the internal model layers, invalidating the whole checkpoint structure. (Just yet another reason to try to ditch Keras 2 asap.)

Contributor Author:
+1

        # === Functional Model ===
        inputs = backbone.input
        hidden_states = backbone(inputs)
@@ -195,26 +202,13 @@ def call_with_cache(
        the final hidden representation of the input tokens, and `cache` is
        the decoding cache.
        """
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.position_embedding(
-            tokens, start_index=cache_update_index
-        )
-        x = self.backbone.embeddings_add((tokens, positions))
-        x = self.backbone.embeddings_dropout(x)
-        # Each decoder layer has a cache; we update them separately.
-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-            x, next_cache = transformer_layer(
-                x,
-                self_attention_cache=current_cache,
-                self_attention_cache_update_index=cache_update_index,
-            )
-            caches.append(next_cache)
-        cache = ops.stack(caches, axis=1)
-        hidden_states = x = self.backbone.layer_norm(x)
-        logits = self.backbone.token_embedding(x, reverse=True)
-        return logits, hidden_states, cache
+        return self.backbone_with_cache(
+            {
+                "token_ids": token_ids,
+                "cache": cache,
+                "cache_update_index": ops.convert_to_tensor(cache_update_index),
+            }
+        )

Member (on the ops.convert_to_tensor call):
Why do we need convert_to_tensor here?

Contributor Author:
Not sure why. It would not work without it.

    def _build_cache(self, token_ids):
        """Build an empty cache for use with `call_with_cache()`."""