Skip to content

Commit

Permalink
Add Falcon CausalLM.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed May 20, 2024
1 parent 294304b commit 000e7d2
Show file tree
Hide file tree
Showing 3 changed files with 480 additions and 4 deletions.
287 changes: 287 additions & 0 deletions keras_nlp/src/models/falcon/falcon_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.backend import ops
from keras_nlp.src.models.causal_lm import CausalLM
from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.src.models.falcon.falcon_causal_lm_preprocessor import (
FalconCausalLMPreprocessor,
)


@keras_nlp_export("keras_nlp.models.FalconCausalLM")
class FalconCausalLM(CausalLM):
"""An end-to-end Falcon model for causal language modeling.
A causal language model (LM) predicts the next token based on previous
tokens. This task setup can be used to train the model unsupervised on
plain text input, or to autoregressively generate plain text similar to
the data used for training. This task can be used for pre-training or
fine-tuning a Falcon model, simply by calling `fit()`.
This model has a `generate()` method, which generates text based on a
prompt. The generation strategy used is controlled by an additional
`sampler` argument on `compile()`. You can recompile the model with
different `keras_nlp.samplers` objects to control the generation. By
default, `"greedy"` sampling will be used.
This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to string inputs during
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
when creating the model with `from_preset()`.
Args:
backbone: A `keras_nlp.models.FalconBackbone` instance.
preprocessor: A `keras_nlp.models.FalconCausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
Examples:
Use `generate()` to do text generation.
```python
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
falcon_lm.generate("I want to say", max_length=30)
# Generate with batched prompts.
falcon_lm.generate(["This is a", "Where are you"], max_length=30)
```
Compile the `generate()` function with a custom sampler.
```python
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
falcon_lm.compile(sampler="top_k")
falcon_lm.generate("I want to say", max_length=30)
falcon_lm.compile(sampler=keras_nlp.samplers.BeamSampler(num_beams=2))
falcon_lm.generate("I want to say", max_length=30)
```
Use `generate()` without preprocessing.
```python
prompt = {
# Token ids for "<bos> Keras is".
"token_ids": np.array([[2, 214064, 603, 0, 0, 0, 0]] * 2),
# Use `"padding_mask"` to indicate values that should not be overridden.
"padding_mask": np.array([[1, 1, 1, 0, 0, 0, 0]] * 2),
}
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset(
"falcon_refinedweb_1b_en",
preprocessor=None,
)
falcon_lm.generate(prompt)
```
Call `fit()` on a single batch.
```python
features = ["The quick brown fox jumped.", "I forgot my homework."]
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset("falcon_refinedweb_1b_en")
falcon_lm.fit(x=features, batch_size=2)
```
Call `fit()` without preprocessing.
```python
x = {
# Token ids for "<bos> Keras is deep learning library<eos>"
"token_ids": np.array([[2, 214064, 603, 5271, 6044, 9581, 1, 0]] * 2),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 0]] * 2),
}
y = np.array([[214064, 603, 5271, 6044, 9581, 3, 0, 0]] * 2)
sw = np.array([[1, 1, 1, 1, 1, 1, 0, 0]] * 2)
falcon_lm = keras_nlp.models.FalconCausalLM.from_preset(
"falcon_refinedweb_1b_en",
preprocessor=None,
)
falcon_lm.fit(x=x, y=y, sample_weight=sw, batch_size=2)
```
Custom backbone and vocabulary.
```python
tokenizer = keras_nlp.models.FalconTokenizer(
proto="proto.spm",
)
preprocessor = keras_nlp.models.FalconCausalLMPreprocessor(
tokenizer=tokenizer,
sequence_length=128,
)
backbone = keras_nlp.models.FalconBackbone(
vocabulary_size=30552,
num_layers=4,
num_heads=4,
hidden_dim=256,
intermediate_dim=512,
max_sequence_length=128,
)
falcon_lm = keras_nlp.models.FalconCausalLM(
backbone=backbone,
preprocessor=preprocessor,
)
falcon_lm.fit(x=features, batch_size=2)
```
"""

backbone_cls = FalconBackbone
preprocessor_cls = FalconCausalLMPreprocessor

def __init__(
self,
backbone,
preprocessor=None,
**kwargs,
):
# === Layers ===
self.backbone = backbone
self.preprocessor = preprocessor

# === Functional Model ===
inputs = backbone.input
hidden_states = backbone(inputs)
outputs = backbone.token_embedding(hidden_states, reverse=True)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

def call_with_cache(
self,
token_ids,
cache,
cache_update_index,
):
"""Forward pass of `FalconCausalLM` with cache.
`call_with_cache` adds an additional forward pass for the model for
autoregressive inference. Unlike calling the model directly, this method
allows caching previous key/value Tensors in multi-head attention layer,
and avoids recomputing the outputs of seen tokens.
Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs in the
whole sequence.
Returns:
A (logits, hidden_states, cache) tuple. Where `logits` is the
language model logits for the input token_ids, `hidden_states` is
the final hidden representation of the input tokens, and `cache` is
the decoding cache.
"""
x = self.backbone.token_embedding(token_ids)
# Each decoder layer has a cache; we update them separately.
caches = []
for i, transformer_layer in enumerate(self.backbone.transformer_layers):
current_cache = cache[:, i, ...]
x, next_cache = transformer_layer(
x,
attention_cache=current_cache,
attention_cache_update_index=cache_update_index,
)
caches.append(next_cache)
cache = ops.stack(caches, axis=1)
hidden_states = x = self.backbone.final_layernorm(x)
logits = self.backbone.token_embedding(x, reverse=True)
return logits, hidden_states, cache

def _build_cache(self, token_ids):
"""Build an empty cache for use with `call_with_cache()`."""
batch_size = ops.shape(token_ids)[0]
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_layers
num_heads = self.backbone.num_attention_heads
head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads
shape = [batch_size, num_layers, 2, max_length, num_heads, head_dim]
cache = ops.zeros(shape, dtype=self.compute_dtype)
# Seed the cache.
_, hidden_states, cache = self.call_with_cache(token_ids, cache, 0)
return hidden_states, cache

def generate_step(
self,
inputs,
stop_token_ids=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
stop_token_ids: Tuple of id's of end token's to stop on. If all
sequences have produced a new stop token, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted tokens ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
batch_size = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
prompt,
cache,
cache_update_index,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

token_ids = self.sampler(
next=next,
prompt=token_ids,
cache=cache,
index=index,
mask=padding_mask,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = ops.logical_and(
ops.equal(token_ids, stop_token_ids),
ops.logical_not(padding_mask),
)
end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
Loading

0 comments on commit 000e7d2

Please sign in to comment.