diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 87cd206495..b8d47dbe9c 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -14,11 +14,8 @@ if [[ -z "${KAGGLE_USERNAME}" ]]; then
 fi
 
 set -x
-
 cd "${KOKORO_ROOT}/"
-sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
-
 PYTHON_BINARY="/usr/bin/python3.9"
 
 "${PYTHON_BINARY}" -m venv venv
diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py
index 45f77ce494..b494d559bd 100644
--- a/keras_nlp/layers/modeling/rotary_embedding.py
+++ b/keras_nlp/layers/modeling/rotary_embedding.py
@@ -85,30 +85,42 @@ def __init__(
         self.built = True
 
     def call(self, inputs, start_index=0):
+        inputs = ops.moveaxis(
+            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
+        )
         cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
-        return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        return ops.moveaxis(
+            output, (-1, 1), (self.feature_axis, self.sequence_axis)
+        )
 
     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
-        half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
+        x1, x2 = ops.split(tensor, 2, axis=-1)
+        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
+        # compilation on jax. We should be able to remove this once the
+        # following PR is in all jax releases we care about:
+        # https://github.com/openxla/xla/pull/7875
+        half_rot_tensor = ops.stack((-x2, x1), axis=-2)
+        half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)
 
     def _compute_cos_sin_embedding(self, inputs, start_index=0):
-        def get_axis(axis):
-            return axis if axis > 0 else len(inputs.shape) + axis
+        start_index = ops.cast(start_index, dtype="float32")
 
-        feature_axis = get_axis(self.feature_axis)
-        sequence_axis = get_axis(self.sequence_axis)
+        feature_axis = len(inputs.shape) - 1
+        sequence_axis = 1
 
         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)
 
-        seq_len = ops.shape(inputs)[self.sequence_axis]
-        tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
+        seq_len = ops.shape(inputs)[sequence_axis]
+        tensor = ops.arange(seq_len, dtype="float32") + start_index
 
-        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
         freq = ops.einsum("i,j->ij", tensor, inverse_freq)
-        embedding = ops.concatenate((freq, freq), axis=-1)
+        embedding = ops.stack((freq, freq), axis=-2)
+        embedding = ops.reshape(
+            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
+        )
 
         # Reshape the embedding to be broadcastable with input shape.
         if feature_axis < sequence_axis:
@@ -117,17 +129,16 @@ def get_axis(axis):
         if axis != sequence_axis and axis != feature_axis:
             embedding = ops.expand_dims(embedding, axis)
 
-        return ops.cos(embedding), ops.sin(embedding)
+        cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
+        sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype)
+        return cos_emb, sin_emb
 
     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2)
-        freq_range = ops.cast(freq_range, self.compute_dtype)
-        freq_range = freq_range / ops.cast(
-            self.scaling_factor, self.compute_dtype
-        )
+        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
+        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
         inverse_freq = 1.0 / (
             self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, self.compute_dtype))
+            ** (freq_range / ops.cast(rotary_dim, "float32"))
         )
         return inverse_freq
diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py
index e01c1f8ce4..4b391264a2 100644
--- a/keras_nlp/models/gemma/gemma_attention.py
+++ b/keras_nlp/models/gemma/gemma_attention.py
@@ -15,6 +15,7 @@
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
+from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_nlp.utils.keras_utils import clone_initializer
 
 
@@ -87,28 +88,23 @@ def build(self, inputs_shape):
             (None, None, self.num_query_heads, self.head_dim)
         )
         self.softmax = keras.layers.Softmax(dtype="float32")
+
+        self.rope_layer = RotaryEmbedding(
+            max_wavelength=10_000.0, dtype=self.dtype_policy
+        )
+
         self.built = True
 
-    def _apply_rope(self, x, positions):
+    def _apply_rope(self, x, start_index):
         """Rope rotate q or k."""
-        # TODO: refactor to use RotaryEmbedding layer?
-        max_wavelength = 10000
-        x_shape = ops.shape(x)
-        freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
-            x_shape[-1] // 2, dtype="float32"
+        x = self.rope_layer(x, start_index=start_index)
+        # Gemma uses a different layout for positional embeddings.
+        # The transformation below ensures the embeddings are numerically
+        # equivalent to the original gemma implementation.
+        x = ops.reshape(
+            ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x)
         )
-        timescale = max_wavelength**freq_exponents
-        radians = positions[..., None] / timescale[None, None, :]
-        radians = radians[..., None, :]
-        sin = ops.cast(ops.sin(radians), self.compute_dtype)
-        cos = ops.cast(ops.cos(radians), self.compute_dtype)
-        x1, x2 = ops.split(x, 2, axis=-1)
-        # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA
-        # compilation on jax. We should be able to remove this once the
-        # following PR is in all jax releases we care about:
-        # https://github.com/openxla/xla/pull/7875
-        output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
-        return ops.reshape(output, x_shape)
+        return x
 
     def _compute_attention(
         self,
@@ -155,19 +151,14 @@ def call(
         cache_update_index=0,
         training=False,
     ):
-        seq_len = ops.shape(x)[1]
-        start_index = cache_update_index
-        positions = ops.arange(seq_len, dtype="float32")
-
-        positions = positions + ops.cast(start_index, "float32")
         query = self.query_dense(x)
-        query = self._apply_rope(query, positions)
+        query = self._apply_rope(query, cache_update_index)
 
         if cache is not None:
             key_cache = cache[:, 0, ...]
             value_cache = cache[:, 1, ...]
             key_update = self.key_dense(x)
-            key_update = self._apply_rope(key_update, positions)
+            key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
             start = [0, cache_update_index, 0, 0]
             key = ops.slice_update(key_cache, start, key_update)
@@ -175,7 +166,7 @@ def call(
             cache = ops.stack((key, value), axis=1)
         else:
             key = self.key_dense(x)
-            key = self._apply_rope(key, positions)
+            key = self._apply_rope(key, cache_update_index)
             value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
index c829aa948f..06f5b0f601 100644
--- a/keras_nlp/models/gemma/gemma_backbone.py
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)
 
     @staticmethod
-    def get_layout_map(device_mesh, model_parallel_dim_name="model"):
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
         """Get a `keras.distribution.LayoutMap` for model parallel distribution.
 
         The returned `LayoutMap` contains the sharding spec for the gemma
@@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 distribution.
             model_parallel_dim_name: The axis name of the device mesh, where
                 the weights should be partition on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partitioned on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec of
             all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 f"{model_parallel_dim_name} is not found in the "
                 f"device_mesh.axis_names. {device_mesh.axis_name=}"
            )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_name=}"
+            )
+        # Note that it is possible to further configure the mesh to be 3D,
+        # e.g. (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
         model_dim = model_parallel_dim_name
-        # The sharding is partition for the hidden_dim of the model.
+        # The sharding config is based on the Gemma team training config.
+        # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
-        layout_map["token_embedding/embeddings"] = (None, model_dim)
+        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
         layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
-            None,
             model_dim,
+            data_dim,
             None,
         )
         layout_map["decoder_block.*attention_output.*kernel"] = (
-            None,
-            None,
             model_dim,
+            None,
+            data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
+        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
 
         return layout_map
diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py
index 855d49658b..7b02de2b7a 100644
--- a/keras_nlp/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/models/gemma/gemma_backbone_test.py
@@ -106,26 +106,34 @@ def test_distribution(self):
         for w in model.weights:
             if "token_embedding/embeddings" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
             if "attention/query/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/key/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/value/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, "model", None)
+                    tuple(w.value.sharding.spec), ("model", "batch", None)
                 )
             if "attention/attention_output/kernel" in w.path:
                 self.assertEqual(
-                    tuple(w.value.sharding.spec), (None, None, "model")
+                    tuple(w.value.sharding.spec), ("model", None, "batch")
                 )
             if "ffw_gating/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_gating_2/kernel" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), ("model", None))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("batch", "model")
+                )
             if "ffw_linearl" in w.path:
-                self.assertEqual(tuple(w.value.sharding.spec), (None, "model"))
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), ("model", "batch")
+                )
diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..a221185582
--- /dev/null
+++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py
@@ -0,0 +1,185 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+from absl import logging
+
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.backend import ops
+from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+
+
+@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor")
+class LlamaCausalLMPreprocessor(LlamaPreprocessor):
+    """Llama Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this
+    preprocessor is attached to a `keras_nlp.models.LlamaCausalLM` instance,
+    these methods will be called implicitly in `generate()`. They can also be
+    called standalone (e.g. to precompute preprocessing inputs for generation
+    in a separate process).
+
+    Args:
+        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    sentence = tf.constant("League of legends")
+    preprocessor(sentence)
+    # Same output.
+    preprocessor("League of legends")
+
+    # Tokenize a batch of sentences.
+    sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
+    preprocessor(sentences)
+    # Same output.
+    preprocessor(["Taco tuesday", "Fish taco please!"])
+
+    # Map a dataset to preprocess a single sentence.
+    features = tf.constant(
+        [
+            "Avatar 2 is amazing!",
+            "Well, I am not sure.",
+        ]
+    )
+    labels = tf.constant([1, 0])
+    ds = tf.data.Dataset.from_tensor_slices((features, labels))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map a dataset to preprocess unlabeled sentences.
+    ds = tf.data.Dataset.from_tensor_slices(features)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+    ```
+    """
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        if y is not None or sample_weight is not None:
+            logging.warning(
+                "`LlamaCausalLMPreprocessor` generates `y` and "
+                "`sample_weight` based on your input data, but your data "
+                "already contains `y` or `sample_weight`. Your `y` and "
+                "`sample_weight` will be ignored."
+            )
+        sequence_length = sequence_length or self.sequence_length
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
+            x,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a
+        padding mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        if not self.built:
+            self.build(None)
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
+        )
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # Convert the inputs to numpy arrays if they aren't a tensor already.
+        if not isinstance(token_ids, tf.Tensor):
+            token_ids = ops.convert_to_numpy(token_ids)
+            # Make sure the numpy array has type `int32` since
+            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+            token_ids = token_ids.astype("int32")
+        if not isinstance(padding_mask, tf.Tensor):
+            padding_mask = ops.convert_to_numpy(padding_mask)
+            padding_mask = padding_mask.astype("bool")
+        # Strip any special tokens during detokenization (e.g. the start and
+        # end markers). In the future we could make this configurable.
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        padding_mask = padding_mask & (
+            token_ids != self.tokenizer.start_token_id
+        )
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
new file mode 100644
index 0000000000..aa4d155c8c
--- /dev/null
+++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py
@@ -0,0 +1,90 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+
+from keras_nlp.models.llama.llama_causal_lm_preprocessor import (
+    LlamaCausalLMPreprocessor,
+)
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class LlamaCausalLMPreprocessorTest(TestCase):
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer(
+            # Generated using create_llama_test_proto.py
+            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = (["the quick brown fox"],)
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=LlamaCausalLMPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
+                },
+                [[3, 8, 4, 6, 0, 0, 0, 0]],  # Pass through labels.
+                [[1, 1, 1, 1, 0, 0, 0, 0]],  # Pass through sample_weights.
+            ),
+        )
+
+    def test_no_start_end_token(self):
+        input_data = ["the quick brown fox"] * 4
+
+        preprocessor = LlamaCausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=False,
+            add_end_token=False,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "the quick brown fox"
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
+        }
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_postprocess(input_data)
+        self.assertAllEqual(x, "the quick brown fox")
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in LlamaCausalLMPreprocessor.presets:
+            self.run_preset_test(
+                cls=LlamaCausalLMPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py
new file mode 100644
index 0000000000..580557f50d
--- /dev/null
+++ b/keras_nlp/models/llama/llama_preprocessor.py
@@ -0,0 +1,191 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.LlamaPreprocessor")
+class LlamaPreprocessor(Preprocessor):
+    """A Llama preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`
+       with the appropriate tokens.
+    3. Construct a dictionary with keys `"token_ids"`, and `"padding_mask"`
+       that can be passed directly to `keras_nlp.models.LlamaBackbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be
+            converted to tensors. For multiple sequences, pass tensors
+            directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+
+    Directly calling the from_preset().
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of single sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Preprocess a batch of sentence pairs.
+    # When handling multiple sequences, always convert to tensors first!
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    preprocessor((first, second))
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices((first, second))
+
+    # Watch out for tf.data's default unpacking of tuples here!
+    # Best to invoke the `preprocessor` directly in this case.
+    ds = ds.map(
+        lambda first, second: preprocessor(x=(first, second)),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=1024,
+        add_start_token=True,
+        add_end_token=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+        self.sequence_length = sequence_length
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+        self.packer = StartEndPacker(
+            start_value=self.tokenizer.start_token_id,
+            end_value=self.tokenizer.end_token_id,
+            sequence_length=self.sequence_length,
+            return_padding_mask=True,
+        )
+        self.built = True
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "sequence_length": self.sequence_length,
+                "add_start_token": self.add_start_token,
+                "add_end_token": self.add_end_token,
+            }
+        )
+        return config
+
+    def call(
+        self,
+        x,
+        y=None,
+        sample_weight=None,
+        sequence_length=None,
+    ):
+        x = convert_inputs_to_list_of_tensor_segments(x)
+        if len(x) != 1:
+            raise ValueError(
+                "Llama requires each input feature to contain only "
+                f"one segment, but received {len(x)}. If you are using Llama"
+                " for a multi-segment classification task, please refer to "
+                "classification models like BERT or RoBERTa."
+            )
+        sequence_length = sequence_length or self.sequence_length
+        token_ids, padding_mask = self.packer(
+            self.tokenizer(x[0]),
+            sequence_length=sequence_length,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        x = {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    @property
+    def sequence_length(self):
+        """The padded length of model input sequences."""
+        return self._sequence_length
+
+    @sequence_length.setter
+    def sequence_length(self, value):
+        self._sequence_length = value
+        if self.packer is not None:
+            self.packer.sequence_length = value
+
+    @classproperty
+    def tokenizer_cls(cls):
+        return LlamaTokenizer
diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py
new file mode 100644
index 0000000000..6807886812
--- /dev/null
+++ b/keras_nlp/models/llama/llama_preprocessor_test.py
@@ -0,0 +1,57 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.tests.test_case import TestCase
+
+
+class LlamaPreprocessorTest(TestCase):
+    def setUp(self):
+        self.tokenizer = LlamaTokenizer(
+            # Generated using create_llama_test_proto.py
+            proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm")
+        )
+        self.init_kwargs = {
+            "tokenizer": self.tokenizer,
+            "sequence_length": 8,
+        }
+        self.input_data = (
+            ["the quick brown fox"],
+            [1],  # Pass through labels.
+            [1.0],  # Pass through sample_weights.
+        )
+
+    def test_preprocessor_basics(self):
+        self.run_preprocessor_test(
+            cls=LlamaPreprocessor,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output=(
+                {
+                    "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]],
+                    "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]],
+                },
+                [1],  # Pass through labels.
+                [1.0],  # Pass through sample_weights.
+            ),
+        )
+
+    def test_errors_for_2d_list_input(self):
+        preprocessor = LlamaPreprocessor(**self.init_kwargs)
+        ambiguous_input = [["one", "two"], ["three", "four"]]
+        with self.assertRaises(ValueError):
+            preprocessor(ambiguous_input)
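
Editor's note on the XLA workaround used in `RotaryEmbedding._apply_rotary_pos_emb` above: the `ops.stack(..., axis=-2)` followed by `ops.reshape` is intended to reproduce a concatenation of the two halves along the last axis without calling `ops.concatenate`. A minimal NumPy sketch of that equivalence (illustrative only, not part of the patch; the shapes are arbitrary):

```python
import numpy as np

# (batch, seq, feature) toy input, split into two halves on the feature axis.
x = np.arange(24, dtype="float32").reshape(2, 3, 4)
x1, x2 = np.split(x, 2, axis=-1)

# What the layer used to compute with concatenate.
concat = np.concatenate((-x2, x1), axis=-1)

# The stack + reshape workaround from the patch: stacking on a new
# second-to-last axis and flattening it back is the same concatenation.
stacked = np.stack((-x2, x1), axis=-2).reshape(x.shape)

assert np.array_equal(concat, stacked)
```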
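Editor's note on the updated `GemmaBackbone.get_layout_map` signature: a hedged usage sketch is below. The 2x4 mesh shape, the `batch`/`model` axis names, and the `gemma_2b_en` preset are assumptions for illustration, not part of the patch; the JAX backend and enough accelerators to fill the mesh are also assumed.

```python
import keras
from keras_nlp.models import GemmaBackbone

# Assumed 2x4 mesh: 2-way data parallel, 4-way model parallel.
device_mesh = keras.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)
# The new data_parallel_dim_name argument lets the layout map shard
# weights across both mesh axes, per the sharding config in the diff.
layout_map = GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
)
distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)
model = GemmaBackbone.from_preset("gemma_2b_en")  # hypothetical preset choice
```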