-
Notifications
You must be signed in to change notification settings - Fork 247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DO NOT MERGE] Experimental implementation of CausalLM with a Keras Functional backbone_with_cache #1598
Conversation
…This uses the new experimental clone_layer_graph functionality.
For a demo of the fuctionality see this Colab: Model rewiring demo with LLMs.ipynb. For example, you can insert control vectors into an LLM backbone with this def clone_fn(layer, *args, **kwargs):
if isinstance(layer, keras_nlp.layers.TransformerDecoder):
x = layer(*args, **kwargs)
x = ControlVectorLayer()(x)
return x
else:
return layer(*args, **kwargs) # identity |
Known issue: the |
I have changed the implementation to use the new For this use case, i.e. rewiring a language model backbone with KV caches, the new API is a bit awkward, as it forces the user to use an intermediate model. In simplified code:
The intermediate model The previously suggested API did not have this awkwardness as it did not involve an intermediate
Additional Note: I also noticed that the new API clones input tensors. For example, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Took a look!
Stopgap for now
Below shows a colab that would work today, and be decently readable (patching the call method of each transformer layer):
This is maybe not the most conceptually beautiful, but it's short, practical, and works for modifying attention and key/query/value projections. It's unclear to me if the approach on this PR could grow to include those types of surgery without substantive rewrites and breaking all weight compatibility.
I think for better or for worse, "layer patching" is the best cohesive approach we have for surgery today. It seems prudent to point users in this direction for now, as we think through new forms of model surgery, and work on a more flexible form of our current generation code.
Forward looking thoughts
What is on this PR is not general for all models (see comments below), and the general form might require pushing the rewiring to individual model implementations, to handle different layer types, input types, arg names, and overall structures (e.g. seq2seq).
I think it'd be doable in the technical sense, if a good bit more code. The rewiring code is a bit clunky and not super readable.
I'd be interested in scoping out support in Keras functional models for optional inputs. If we did this, we could write a backbone that supports caching out of the box. As well as other common "power user" inputs e.g. attention_mask
, token_positions
, that would cover a lot of other important CUJs without any need for cloning or surgery.
Optional inputs could allow a rewrite the our generative task code to treat the backbone as a functional black box, without any assumptions of internal layer structure. That could allow a lot of the types of functional surgeries you are interested in, with a much smaller blast radius on our model implementation code.
@@ -503,3 +503,48 @@ def get_config(self): | |||
|
|||
def compute_output_shape(self, decoder_sequence_shape): | |||
return decoder_sequence_shape | |||
|
|||
def compute_output_spec( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is kinda clunky, but might be a good idea to add regardless. Can we just **kwargs
the args we don't actually care about here?
I'm not sure if we need compute_ouptut_shape
if we do this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, compute_output shape is probably not required when compute_output_spec is implemented. Here compute_output_spec was necessary because the layer returns differently shaped outputs depending on inputs.
return output | ||
|
||
def rewire_fn(layer, *args, **kwargs): | ||
if isinstance(layer, PositionEmbedding): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would not work for a few models where the position embedding is part of a composite layer (TokenAndPositionEmbedding
).
def rewire_fn(layer, *args, **kwargs): | ||
if isinstance(layer, PositionEmbedding): | ||
return rewire_positionembedding(layer, *args, **kwargs) | ||
elif isinstance(layer, TransformerDecoder): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would not work for most decoder models (as model decoder models write their own decoder block).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can test for being in the set of backbone.transformer_blocks rather than a specific type. This can be solved with a convention of what "backbone" should contain (which makes sense - not any backbone works for cached text generation).
def _rewire_backbone_with_cache(backbone, cache_shape, cache_dtype): | ||
|
||
# Define new inputs for caches. | ||
cache_update_index_input = keras.Input( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our cache might change to be a tuple of individual layer caches to avoid stacking/concating as described here. #1562
And further down the road, we might want to add more cache shape options, e.g. for things like token attention.
Interestingly, a cache of tuples would invalidate our current restriction on functional model inputs. We'd want a nested structure where one dictionary key contains a tuple of inputs, would break here https://github.com/keras-team/keras/blob/9f4da5159a098256dfbccd2c926107953a6812e5/keras/src/models/functional.py#L134-L141
So we may need to do more thinking here if we "unstack our cache".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Expanding functional models to arbitrary pytree inputs and outputs (as long as leaves are KerasTensors) is on the roadmap(look under "Modeling").
) | ||
# cache_update_index_input is always a scalar. We must force the | ||
# shape to scalar because keras.Input assumes a batch dim. | ||
cache_update_index_input.shape = () |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems kinda hacky, is this something we want to support generally in Keras? Unbatched functional model inputs? And if so, is this the way we would like to do it?
# === Backbone with cache === | ||
# The backbone with a cache is used in call_with_cache | ||
cache_shape = self._compute_cache_shape(backbone, preprocessor) | ||
self.backbone_with_cache = self._rewire_backbone_with_cache( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might fail with Keras 2 saving for obscure reasons. Basically, we might try to save backbone_with_cache
before the internal model layers, invalidating the whole checkpoint structure. (Just yet another reason to try to ditch Keras 2 asap).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
{ | ||
"token_ids": token_ids, | ||
"cache": cache, | ||
"cache_update_index": ops.convert_to_tensor(cache_update_index), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need conver_to_tensor
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why. It would not work without it.
Update here... @fchollet has added support for optional functional inputs. So what I think we can do is write a backbone that allows two optional inputs I think this is the right solution abstraction wise, and will allow a lot more aggressive model surgeries. But landing this will still take some effort as we will need to drop Keras 2 codepaths in the library (Keras 2 will not support optional inputs). |
This is a proof of concept PR for the new layer graph cloning API in Keras (keras-team/keras#19600).
It is not meant to be merged as such but provide a tangible use case for the design of the new layer graph cloning API.
The problem to solve was:
In order to let users implement what they want in the backbone and have call_with_cache still work in XXXCausalLLM, it is necessary to add the caching to the backbone in a Keras Functional way, and respect the Functional layer graph of the backbone.
The new layer graph cloning API can be used:
This PR implements a Keras Functional call_with_cache for GPT2 and Gemma.