[DO NOT MERGE] Experimental implementation of CausalLM with a Keras Functional backbone_with_cache #1598

@@ -24,13 +24,16 @@
)
import tree

from keras.src.models.cloning import clone_model
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.models.task import Task
from keras_nlp.samplers.serialization import get as get_sampler
from keras_nlp.utils.tensor_utils import tensor_to_list
from keras_nlp.layers.modeling.position_embedding import PositionEmbedding
from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder


@keras_nlp_export("keras_nlp.models.CausalLM")
@@ -387,3 +390,106 @@ def postprocess(x):
        outputs = [postprocess(x) for x in outputs]

        return self._normalize_generate_outputs(outputs, input_is_scalar)

    @staticmethod
    def _rewire_backbone_with_cache(backbone, cache_shape, cache_dtype):
        """Wires LLM decoding caches into a backbone.

        Returns a new functional backbone with the same graph layout as the
        original backbone, but with caches wired into TransformerDecoder blocks.
        """
        # Define new inputs for caches.
        cache_update_index_input = keras.Input(
            shape=(), dtype="int32", name="cache_update_index"
        )

[Review comment] Our cache might change to be a tuple of individual layer caches to avoid stacking/concatenating, as described in #1562. And further down the road, we might want to add more cache shape options, e.g. for things like token attention. Interestingly, a cache of tuples would invalidate our current restriction on functional model inputs. We'd want a nested structure where one dictionary key contains a tuple of inputs, which would break here: https://github.com/keras-team/keras/blob/9f4da5159a098256dfbccd2c926107953a6812e5/keras/src/models/functional.py#L134-L141. So we may need to do more thinking here if we "unstack our cache".

[Review comment] Expanding functional models to arbitrary pytree inputs and outputs (as long as leaves are KerasTensors) is on the roadmap (look under "Modeling").
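To make the structure under discussion concrete, here is a small hypothetical sketch (not part of this PR) contrasting the single stacked cache input used here with the kind of per-layer cache inputs discussed in #1562. The shapes are illustrative, and, per the comment above, the nested variant is the input structure that current Functional model standardization would reject.

import keras

# Current approach in this PR: one stacked cache tensor as a single input.
# Illustrative cache_shape: [num_layers, 2, max_length, num_heads, head_dim].
token_ids = keras.Input(shape=(None,), dtype="int32", name="token_ids")
stacked_cache = keras.Input(shape=[12, 2, 1024, 12, 64], dtype="float32", name="cache")
inputs = {"token_ids": token_ids, "cache": stacked_cache}

# Hypothetical alternative: one cache input per TransformerDecoder block.
per_layer_caches = tuple(
    keras.Input(shape=[2, 1024, 12, 64], dtype="float32", name=f"cache_{i}")
    for i in range(12)
)
# A dict value that is itself a tuple of KerasTensors is the nested structure
# the comment above says current Functional input handling would not accept.
nested_inputs = {"token_ids": token_ids, "cache": per_layer_caches}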
        # cache_update_index_input is always a scalar. We must force the
        # shape to scalar because keras.Input assumes a batch dim.
        cache_update_index_input.shape = ()

[Review comment] This seems kinda hacky. Is this something we want to support generally in Keras, i.e. unbatched functional model inputs? And if so, is this the way we would like to do it?
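As a quick illustration of why the shape override above is needed (a sketch, not code from the PR; the assignment simply mirrors the one used above): keras.Input always prepends a batch dimension, so even shape=() yields a batched tensor.

import keras

idx = keras.Input(shape=(), dtype="int32", name="cache_update_index")
print(idx.shape)  # (None,): keras.Input prepends a batch dimension even for shape=()

# The workaround used above: overwrite the shape to get a true scalar KerasTensor.
idx.shape = ()
print(idx.shape)  # ()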

        # Input for a combined cache for all TransformerDecoder layers.
        cache_input = keras.Input(
            shape=cache_shape, dtype=cache_dtype, name="cache"
        )

        # Split the cache on the num_layers axis=1 (number of transformer blocks).
        caches = ops.unstack(cache_input, cache_input.shape[1], axis=1)
        decoder_block_idx = 0
        next_caches = []

        def rewire_positionembedding(layer, *args, **kwargs):
            # Wire in the new cache_update_index_input by calling the layer
            # with it as start_index.
            nonlocal cache_update_index_input
            return layer(*args, start_index=cache_update_index_input, **kwargs)

        def rewire_transformerdecoder(layer, *args, **kwargs):
            # Wire in the per-layer cache and collect the updated cache by
            # calling the layer on these inputs.
            nonlocal caches, next_caches, decoder_block_idx
            # No mask when decoding with cache.
            kwargs.pop("decoder_padding_mask")
            output, next_cache = layer(
                *args,
                self_attention_cache=caches[decoder_block_idx],
                self_attention_cache_update_index=cache_update_index_input,
                **kwargs,
            )
            decoder_block_idx += 1
            next_caches.append(next_cache)
            return output

        def rewire_fn(layer, *args, **kwargs):
            if isinstance(layer, PositionEmbedding):
                return rewire_positionembedding(layer, *args, **kwargs)
            elif isinstance(layer, TransformerDecoder):
                return rewire_transformerdecoder(layer, *args, **kwargs)
            else:
                return layer(*args, **kwargs)  # identity

[Review comment] The PositionEmbedding check would not work for a few models where the position embedding is part of a composite layer (…).

[Review comment] The TransformerDecoder check would not work for most decoder models (as most decoder models write their own decoder block).

[Review comment] We can test for being in the set of backbone.transformer_blocks rather than a specific type. This can be solved with a convention of what a "backbone" should contain (which makes sense: not any backbone works for cached text generation).
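Following the suggestion in the last comment, a hypothetical variant of rewire_fn (not part of this PR) could dispatch on membership in the backbone's list of decoder blocks rather than on concrete layer types. The attribute names position_embedding and transformer_layers below are assumptions (the comment calls it transformer_blocks); whatever the convention, this only works because clone_function=lambda x: x keeps the original layer objects, so identity checks against the backbone's attributes still match.

        def rewire_fn(layer, *args, **kwargs):
            # Dispatch on identity/membership rather than isinstance, so
            # backbones with custom decoder blocks (or composite embedding
            # layers) can still be rewired.
            if layer is getattr(backbone, "position_embedding", None):
                return rewire_positionembedding(layer, *args, **kwargs)
            elif layer in getattr(backbone, "transformer_layers", []):
                return rewire_transformerdecoder(layer, *args, **kwargs)
            else:
                return layer(*args, **kwargs)  # identity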

        # Original code with the "clone_layer_graph" API (to be deleted later).
        # # Rewire the graph of layers with caches
        # token_ids_input = backbone.input["token_ids"]
        # input = {
        #     "token_ids": token_ids_input,
        #     "cache": cache_input,
        #     "cache_update_index": cache_update_index_input,
        # }
        # hidden_states = clone_layer_graph(input, backbone.output, rewire_fn)
        # # During the rewiring process, output caches were collected
        # next_cache = ops.stack(next_caches, axis=1)
        # logits = backbone.token_embedding(hidden_states, reverse=True)
        #
        # # Create a new backbone that now uses caches in its forward pass
        # output = (logits, hidden_states, next_cache)
        # return keras.Model(input, output, name=backbone.name + "_with_cache")

        # Copy the layer graph (not the layers themselves!) while adding cache
        # inputs and outputs to TransformerDecoder layers.
        rewired_backbone = clone_model(
            backbone,
            clone_function=lambda x: x,  # no cloning
            call_function=rewire_fn,
        )

        # Build a new model with caches in its inputs and outputs.
        input = {
            "token_ids": rewired_backbone.input["token_ids"],
            "cache": cache_input,
            "cache_update_index": cache_update_index_input,
        }

        # During the rewiring process, output caches were collected.
        next_cache = ops.stack(next_caches, axis=1)
        # This is the original output of the backbone.
        hidden_states = rewired_backbone.output
        # For text generation, we also want a decoded output.
        logits = backbone.token_embedding(hidden_states, reverse=True)
        output = (logits, hidden_states, next_cache)

        # Create a new backbone that now uses caches in its forward pass.
        return keras.Model(input, output, name=backbone.name + "_with_cache")

    # Cache shape without batch dimension.
    @staticmethod
    def _compute_cache_shape(backbone, preprocessor):
        num_layers = backbone.num_layers
        max_length = preprocessor.sequence_length
        num_heads = backbone.num_heads
        head_dim = backbone.hidden_dim // backbone.num_heads
        return [num_layers, 2, max_length, num_heads, head_dim]
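As a worked example of the shape computed above (illustrative numbers only, assuming a GPT-2-sized backbone: 12 layers, 12 heads, hidden_dim 768, and a preprocessor sequence_length of 1024):

num_layers, num_heads, hidden_dim, max_length = 12, 12, 768, 1024
head_dim = hidden_dim // num_heads                  # 64
cache_shape = [num_layers, 2, max_length, num_heads, head_dim]
# [12, 2, 1024, 12, 64]; the size-2 axis holds the key and value states.

entries_per_sample = 12 * 2 * 1024 * 12 * 64        # 18,874,368 values
float32_bytes = entries_per_sample * 4              # ~72 MB of cache per sequence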

@@ -160,6 +160,13 @@ def __init__(
        self.backbone = backbone
        self.preprocessor = preprocessor

        # === Backbone with cache ===
        # The backbone with a cache is used in call_with_cache.
        cache_shape = self._compute_cache_shape(backbone, preprocessor)
        self.backbone_with_cache = self._rewire_backbone_with_cache(
            backbone, cache_shape, backbone.compute_dtype
        )

[Review comment] This might fail with Keras 2 saving for obscure reasons. Basically, we might try to save …

[Review comment] +1

        # === Functional Model ===
        inputs = backbone.input
        hidden_states = backbone(inputs)
@@ -195,26 +202,13 @@ def call_with_cache(
             the final hidden representation of the input tokens, and `cache` is
             the decoding cache.
         """
-        tokens = self.backbone.token_embedding(token_ids)
-        positions = self.backbone.position_embedding(
-            tokens, start_index=cache_update_index
-        )
-        x = self.backbone.embeddings_add((tokens, positions))
-        x = self.backbone.embeddings_dropout(x)
-        # Each decoder layer has a cache; we update them separately.
-        caches = []
-        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
-            current_cache = cache[:, i, ...]
-            x, next_cache = transformer_layer(
-                x,
-                self_attention_cache=current_cache,
-                self_attention_cache_update_index=cache_update_index,
-            )
-            caches.append(next_cache)
-        cache = ops.stack(caches, axis=1)
-        hidden_states = x = self.backbone.layer_norm(x)
-        logits = self.backbone.token_embedding(x, reverse=True)
-        return logits, hidden_states, cache
+        return self.backbone_with_cache(
+            {
+                "token_ids": token_ids,
+                "cache": cache,
+                "cache_update_index": ops.convert_to_tensor(cache_update_index),
+            }
+        )

[Review comment] Why do we need ops.convert_to_tensor here?

[Review comment] Not sure why. It would not work without it.

     def _build_cache(self, token_ids):
         """Build an empty cache for use with `call_with_cache()`."""
[Review comment] This is kinda clunky, but might be a good idea to add regardless. Can we just **kwargs the args we don't actually care about here? I'm not sure if we need compute_output_shape if we do this.

[Review comment] Yes, compute_output_shape is probably not required when compute_output_spec is implemented. Here, compute_output_spec was necessary because the layer returns differently shaped outputs depending on its inputs.
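To illustrate that last point, here is a minimal, hypothetical sketch (not code from this PR) of a compute_output_spec for a decoder-style layer whose output structure depends on whether a cache is passed; the class, signature, and argument names are assumptions for illustration.

import keras


class MyDecoderBlock(keras.layers.Layer):
    # A hypothetical custom decoder block; only the output-spec logic is shown.

    def compute_output_spec(
        self, decoder_sequence, self_attention_cache=None, **kwargs
    ):
        # The main output has the same shape/dtype as the incoming hidden states.
        output_spec = keras.KerasTensor(
            decoder_sequence.shape, dtype=self.compute_dtype
        )
        if self_attention_cache is not None:
            # When a cache is passed, the layer also returns the updated cache,
            # so the spec is a tuple rather than a single tensor.
            cache_spec = keras.KerasTensor(
                self_attention_cache.shape, dtype=self.compute_dtype
            )
            return output_spec, cache_spec
        return output_spec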