From 465ce7b2d87a83489a0d2fe578f07ba2e818ab13 Mon Sep 17 00:00:00 2001 From: Mohamed Abu El-Nasr <64566340+abuelnasr0@users.noreply.github.com> Date: Mon, 11 Mar 2024 20:16:19 +0200 Subject: [PATCH 01/13] Add bloom presets (#1501) --- keras_nlp/models/bloom/bloom_presets.py | 99 ++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 4 deletions(-) diff --git a/keras_nlp/models/bloom/bloom_presets.py b/keras_nlp/models/bloom/bloom_presets.py index 7d24c04aa5..134de5173d 100644 --- a/keras_nlp/models/bloom/bloom_presets.py +++ b/keras_nlp/models/bloom/bloom_presets.py @@ -17,14 +17,105 @@ "bloom_560m_multi": { "metadata": { "description": ( - "24-layer Bloom model. trained on 45 natural languages and " - "12 programming languages." + "24-layer Bloom model with hidden dimension of 1024. " + "trained on 45 natural languages and 12 programming languages." ), - "params": 816115712, + "params": 559214592, "official_name": "BLOOM", "path": "bloom", - "model_card": "https://huggingface.co/bigscience/bloom", + "model_card": "https://huggingface.co/bigscience/bloom-560m", }, "kaggle_handle": "kaggle://keras/bloom/keras/bloom_560m_multi/3", }, + "bloom_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 1065314304, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.1b_multi/1", + }, + "bloom_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 1722408960, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_1.7b_multi/1", + }, + "bloom_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "trained on 45 natural languages and 12 programming languages." + ), + "params": 3002557440, + "official_name": "BLOOM", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloom-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloom_3b_multi/1", + }, + "bloomz_560m_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1024. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 559214592, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-560m", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_560m_multi/1", + }, + "bloomz_1.1b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 1536. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 1065314304, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b1", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.1b_multi/1", + }, + "bloomz_1.7b_multi": { + "metadata": { + "description": ( + "24-layer Bloom model with hidden dimension of 2048. " + "finetuned on crosslingual task mixture (xP3) dataset." 
+ ), + "params": 1722408960, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-1b7", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_1.7b_multi/1", + }, + "bloomz_3b_multi": { + "metadata": { + "description": ( + "30-layer Bloom model with hidden dimension of 2560. " + "finetuned on crosslingual task mixture (xP3) dataset." + ), + "params": 3002557440, + "official_name": "BLOOMZ", + "path": "bloom", + "model_card": "https://huggingface.co/bigscience/bloomz-3b", + }, + "kaggle_handle": "kaggle://keras/bloom/keras/bloomz_3b_multi/1", + }, } From 361e392f716cbca77a514216790cda09f5fe62f4 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Mon, 11 Mar 2024 11:26:31 -0700 Subject: [PATCH 02/13] Create workflow for auto assignment of issues and for stale issues (#1495) * Create auto-assignment.js * Create auto-assignment.yml * Create stale-issue-pr.yml * Minor changes to auto_labler --- .github/workflows/auto-assignment.yml | 21 ++++++++ .github/workflows/scripts/auto-assignment.js | 43 +++++++++++++++++ .github/workflows/scripts/labeler.js | 10 ++-- .github/workflows/stale-issue-pr.yml | 50 ++++++++++++++++++++ 4 files changed, 121 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/auto-assignment.yml create mode 100644 .github/workflows/scripts/auto-assignment.js create mode 100644 .github/workflows/stale-issue-pr.yml diff --git a/.github/workflows/auto-assignment.yml b/.github/workflows/auto-assignment.yml new file mode 100644 index 0000000000..de72da8ba2 --- /dev/null +++ b/.github/workflows/auto-assignment.yml @@ -0,0 +1,21 @@ +name: auto-assignment +on: + issues: + types: + - opened + +permissions: + contents: read + issues: write + pull-requests: write + +jobs: + welcome: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/github-script@v7 + with: + script: | + const script = require('./\.github/workflows/scripts/auto-assignment.js') + script({github, context}) diff --git a/.github/workflows/scripts/auto-assignment.js b/.github/workflows/scripts/auto-assignment.js new file mode 100644 index 0000000000..176b305f39 --- /dev/null +++ b/.github/workflows/scripts/auto-assignment.js @@ -0,0 +1,43 @@ +/** Automatically assign issues and PRs to users in the `assigneesList` + * on a rotating basis. + + @param {!object} + GitHub objects can call GitHub APIs using their built-in library functions. + The context object contains issue and PR details. +*/ + +module.exports = async ({ github, context }) => { + let issueNumber; + let assigneesList; + // Is this an issue? If so, assign the issue number. Otherwise, assign the PR number. + if (context.payload.issue) { + //assignee List for issues. + assigneesList = ["SuryanarayanaY", "sachinprasadhs"]; + issueNumber = context.payload.issue.number; + } else { + //assignee List for PRs. 
+ assigneesList = [mattdangerw]; + issueNumber = context.payload.number; + } + console.log("assignee list", assigneesList); + console.log("entered auto assignment for this issue: ", issueNumber); + if (!assigneesList.length) { + console.log("No assignees found for this repo."); + return; + } + let noOfAssignees = assigneesList.length; + let selection = issueNumber % noOfAssignees; + let assigneeForIssue = assigneesList[selection]; + + console.log( + "issue Number = ", + issueNumber + " , assigning to: ", + assigneeForIssue + ); + return github.rest.issues.addAssignees({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + assignees: [assigneeForIssue], + }); +}; diff --git a/.github/workflows/scripts/labeler.js b/.github/workflows/scripts/labeler.js index 7240113cc3..aa4178645e 100644 --- a/.github/workflows/scripts/labeler.js +++ b/.github/workflows/scripts/labeler.js @@ -23,16 +23,20 @@ You may obtain a copy of the License at module.exports = async ({ github, context }) => { const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title - const issue_discription = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body + let issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number const keyword_label = { gemma:'Gemma' } const labelsToAdd = [] - console.log(issue_title,issue_discription,issue_number) + console.log(issue_title,issue_description,issue_number) + if (issue_description==null) + { + issue_description = '' + } for(const [keyword, label] of Object.entries(keyword_label)){ - if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_discription.toLowerCase().indexOf(keyword) !=-1 ){ + if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){ console.log(`'${keyword}'keyword is present inside the title or description. Pushing label '${label}' to row.`) labelsToAdd.push(label) } diff --git a/.github/workflows/stale-issue-pr.yml b/.github/workflows/stale-issue-pr.yml new file mode 100644 index 0000000000..034fb4c266 --- /dev/null +++ b/.github/workflows/stale-issue-pr.yml @@ -0,0 +1,50 @@ +name: Close inactive issues +on: + schedule: + - cron: "30 1 * * *" +jobs: + close-issues: + runs-on: ubuntu-latest + permissions: + issues: write + pull-requests: write + steps: + - name: Awaiting response issues + uses: actions/stale@v9 + with: + days-before-issue-stale: 14 + days-before-issue-close: 14 + stale-issue-label: "stale" + # reason for closed the issue default value is not_planned + close-issue-reason: completed + only-labels: "stat:awaiting response from contributor" + stale-issue-message: > + This issue is stale because it has been open for 14 days with no activity. + It will be closed if no further activity occurs. Thank you. + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:awaiting response from contributor" + close-issue-message: > + This issue was closed because it has been inactive for 28 days. + Please reopen if you'd like to work on this further. + days-before-pr-stale: 14 + days-before-pr-close: 14 + stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you." 
+ close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further." + repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Contribution issues + uses: actions/stale@v9 + with: + days-before-issue-stale: 180 + days-before-issue-close: 365 + stale-issue-label: "stale" + # reason for closed the issue default value is not_planned + close-issue-reason: not_planned + any-of-labels: "stat:contributions welcome,good first issue" + # List of labels to remove when issues/PRs unstale. + labels-to-remove-when-unstale: "stat:contributions welcome,good first issue" + stale-issue-message: > + This issue is stale because it has been open for 180 days with no activity. + It will be closed if no further activity occurs. Thank you. + close-issue-message: > + This issue was closed because it has been inactive for more than 1 year. + repo-token: ${{ secrets.GITHUB_TOKEN }} From 26a2fb892efa8b459003541af8a2b2cf89278ef7 Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Mon, 11 Mar 2024 16:33:05 -0500 Subject: [PATCH 03/13] Update requirements to TF 2.16 GA (#1503) --- requirements-jax-cuda.txt | 4 ++-- requirements-tensorflow-cuda.txt | 4 ++-- requirements-torch-cuda.txt | 4 ++-- requirements.txt | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 10d07dffce..2d53a76c87 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 7cc2e705e6..14f1441924 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. -tensorflow[and-cuda]==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 1bbe6a2e76..89362bb846 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index e7cc934b17..f1e0b31956 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. -tensorflow-cpu==2.16.0rc0 # Pin to rc until TF 2.16 release -tensorflow-text==2.16.0rc0 +tensorflow-cpu~=2.16.1 # Pin to TF 2.16 +tensorflow-text~=2.16.0 # Torch. --extra-index-url https://download.pytorch.org/whl/cpu From 9832d8ae2faaca08ade0ce762ac138576e042bc0 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:26:09 +0000 Subject: [PATCH 04/13] Expose Task and Backbone (#1506) These are already exposed on KerasCV, and I think it is time to also expose these in KerasNLP. 
This will give us a class to document common model functionality to all backbones such as `enable_lora` and `token_embedding` on keras.io. It can also open up a path for writing a custom architecture outside the library itself. --- keras_nlp/models/__init__.py | 2 ++ keras_nlp/models/backbone.py | 3 ++- keras_nlp/models/task.py | 3 ++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 1abfc0dc84..033a9dc874 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -20,6 +20,7 @@ ) from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.models.backbone import Backbone from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM @@ -130,6 +131,7 @@ from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.models.task import Task from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 867616da69..bfdc8207ad 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class @@ -20,7 +21,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Backbone") class Backbone(keras.Model): def __init__(self, *args, dtype=None, **kwargs): super().__init__(*args, **kwargs) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 0656d2194e..9957f6546f 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -16,6 +16,7 @@ from rich import markup from rich import table as rich_table +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.keras_utils import print_msg @@ -26,7 +27,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Task") class Task(PipelineModel): """Base class for Task models.""" From bfc4d8e2cb83fc5cc54466d41525328d402e37e9 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Mon, 11 Mar 2024 23:29:18 +0000 Subject: [PATCH 05/13] Clean up and add our gemma conversion script (#1493) * Clean up and add our gemma conversion script From flax -> keras. Useful to have as reference. 
* Fix comments * Convert to bfloat16 weights * Review comment --- .../convert_gemma_checkpoints.py | 224 ++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 tools/checkpoint_conversion/convert_gemma_checkpoints.py diff --git a/tools/checkpoint_conversion/convert_gemma_checkpoints.py b/tools/checkpoint_conversion/convert_gemma_checkpoints.py new file mode 100644 index 0000000000..ed81e023d4 --- /dev/null +++ b/tools/checkpoint_conversion/convert_gemma_checkpoints.py @@ -0,0 +1,224 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Convert Gemma flax checkpoints to the Keras format. + +Setup: +pip install -r requirements.txt +pip install git+https://github.com/google-deepmind/gemma.git +python pip_build.py --install + +Usage: +cd tools/checkpoint_conversion +python convert_gemma_checkpoints.py --preset gemma_2b_en +""" + +import os + +os.environ["KERAS_BACKEND"] = "jax" +# No GPU for conversion, makes memory management easier. +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import kagglehub # noqa: E402 +import keras # noqa: E402 +import numpy as np # noqa: E402 +import sentencepiece # noqa: E402 +from absl import app # noqa: E402 +from absl import flags # noqa: E402 +from gemma import params as params_lib # noqa: E402 +from gemma import sampler as sampler_lib # noqa: E402 +from gemma import transformer as transformer_lib # noqa: E402 + +import keras_nlp # noqa: E402 + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "gemma_2b_en": "google/gemma/flax/2b", + "gemma_7b_en": "google/gemma/flax/7b", + "gemma_instruct_2b_en": "google/gemma/flax/2b-it", + "gemma_instruct_7b_en": "google/gemma/flax/7b-it", +} + + +flags.DEFINE_string( + "preset", + None, + f'Must be one of {",".join(PRESET_MAP.keys())}', + required=True, +) + + +def download_flax_model(handle): + return kagglehub.model_download(handle) + + +def convert_model(flax_config, vocab_size): + return keras_nlp.models.GemmaBackbone( + vocabulary_size=vocab_size, + num_layers=flax_config.num_layers, + num_query_heads=flax_config.num_heads, + num_key_value_heads=flax_config.num_kv_heads, + hidden_dim=flax_config.embed_dim, + intermediate_dim=flax_config.hidden_dim * 2, + head_dim=flax_config.head_dim, + ) + + +def convert_tokenizer(proto_path): + return keras_nlp.models.GemmaTokenizer(proto=proto_path) + + +def convert_weights(keras_model, flax_config, flax_params): + # Chomp the embedding weights. Upstream pads for TPU efficiency, but this + # leads to weird gotchas (you need to disregard part of your output logits). 
+ embeddings = flax_params["transformer"]["embedder"]["input_embedding"] + embeddings = np.asarray(embeddings[: keras_model.vocabulary_size, :]) + keras_model.get_layer("token_embedding").set_weights([embeddings]) + keras_model.get_layer("final_normalization").set_weights( + [np.asarray(flax_params["transformer"]["final_norm"]["scale"])] + ) + for i in range(flax_config.num_layers): + flax_layer_name = f"layer_{i}" + keras_block = keras_model.get_layer(f"decoder_block_{i}") + + flax_block = flax_params["transformer"][flax_layer_name] + keras_block.pre_attention_norm.set_weights( + [flax_block["pre_attention_norm"]["scale"]] + ) + keras_block.pre_ffw_norm.set_weights( + [flax_block["pre_ffw_norm"]["scale"]] + ) + + keras_block.gating_ffw.set_weights( + [flax_block["mlp"]["gating_einsum"][0]] + ) + keras_block.gating_ffw_2.set_weights( + [flax_block["mlp"]["gating_einsum"][1]] + ) + keras_block.ffw_linear.set_weights([flax_block["mlp"]["linear"]]) + + attn_block = flax_block["attn"] + if flax_config.num_heads != flax_config.num_kv_heads: + # MQA. + keras_block.attention.query_dense.kernel.assign( + np.asarray(attn_block["q_einsum"]["w"][:, :, :]) + ) + keras_block.attention.key_dense.kernel.assign( + np.asarray(attn_block["kv_einsum"]["w"][0, :, :, :]) + ) + keras_block.attention.value_dense.kernel.assign( + np.asarray(attn_block["kv_einsum"]["w"][1, :, :, :]) + ) + else: + # MHA. + keras_block.attention.query_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][0, :, :, :]) + ) + keras_block.attention.key_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][1, :, :, :]) + ) + keras_block.attention.value_dense.kernel.assign( + np.asarray(attn_block["qkv_einsum"]["w"][2, :, :, :]) + ) + keras_block.attention.output_dense.kernel.assign( + flax_block["attn"]["attn_vec_einsum"]["w"] + ) + + +def validate_output( + keras_model, + keras_tokenizer, + flax_params, + flax_tokenizer, +): + input_str = "What is Keras?" + length = 32 + + # KerasNLP + preprocessor = keras_nlp.models.GemmaCausalLMPreprocessor(keras_tokenizer) + gemma_lm = keras_nlp.models.GemmaCausalLM( + backbone=keras_model, + preprocessor=preprocessor, + ) + keras_output = gemma_lm.generate([input_str], max_length=length) + keras_output = keras_output[0] + + # Flax + transformer_config = transformer_lib.TransformerConfig.from_params( + flax_params, + cache_size=length, + ) + transformer = transformer_lib.Transformer(transformer_config) + sampler = sampler_lib.Sampler( + transformer=transformer, + vocab=flax_tokenizer, + params=flax_params["transformer"], + ) + flax_output = sampler( + input_strings=[input_str], + total_generation_steps=length - 5, # Length of "What is Keras?" + ) + flax_output = input_str + flax_output.text[0] + + # Comparing the outputs. + print("🔶 KerasNLP output:", keras_output) + print("🔶 Flax output:", flax_output) + + +def main(_): + preset = FLAGS.preset + + assert ( + preset in PRESET_MAP.keys() + ), f'Invalid preset {preset}. Must be one of {",".join(PRESET_MAP.keys())}' + + print(f"🏃 Coverting {preset}") + + # Currently all flax weights are bfloat16 (and have much faster download + # times for it). We follow suit with Keras weights. 
+ keras.config.set_floatx("bfloat16") + + handle = PRESET_MAP[preset] + flax_dir = download_flax_model(handle) + proto_path = flax_dir + "/tokenizer.model" + print("✅ Flax model downloaded from kaggle") + + variant = handle.split("/")[-1] + flax_tokenier = sentencepiece.SentencePieceProcessor() + flax_tokenier.Load(proto_path) + flax_params = params_lib.load_and_format_params(flax_dir + "/" + variant) + flax_config = transformer_lib.TransformerConfig.from_params(flax_params) + print("✅ Flax model loaded") + + keras_tokenizer = convert_tokenizer(proto_path) + vocab_size = keras_tokenizer.vocabulary_size() + keras_model = convert_model(flax_config, vocab_size) + print("✅ Keras model loaded") + + convert_weights(keras_model, flax_config, flax_params) + print("✅ Weights converted") + + validate_output(keras_model, keras_tokenizer, flax_params, flax_tokenier) + print("✅ Output validated") + + keras_nlp.src.utils.preset_utils.save_to_preset(keras_model, preset) + keras_nlp.src.utils.preset_utils.save_to_preset( + keras_tokenizer, preset, config_filename="tokenizer.json" + ) + print(f"🏁 Preset saved to ./{preset}") + + +if __name__ == "__main__": + app.run(main) From a8da424b2ff183778d6f9fee3c3415a1f5fc07ef Mon Sep 17 00:00:00 2001 From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com> Date: Mon, 11 Mar 2024 19:32:19 -0500 Subject: [PATCH 06/13] Don't auto-update JAX GPU (#1507) * Don't auto-update JAX GPU * Ignore jax GPU updates --- .github/dependabot.yml | 3 +++ requirements-jax-cuda.txt | 2 +- requirements-tensorflow-cuda.txt | 2 +- requirements-torch-cuda.txt | 2 +- requirements.txt | 2 +- 5 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 0df37b1230..eb7a6ac0c5 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -21,3 +21,6 @@ updates: python: patterns: - "*" + ignore: + # ignore all updates for JAX GPU due to cuda version issue + - dependency-name: "jax[cuda12_pip]" diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 2d53a76c87..2ded131217 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. tensorflow-cpu~=2.16.1 # Pin to TF 2.16 -tensorflow-text~=2.16.0 +tensorflow-text~=2.16.1 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt index 14f1441924..5426beb5a3 100644 --- a/requirements-tensorflow-cuda.txt +++ b/requirements-tensorflow-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow with cuda support. tensorflow[and-cuda]~=2.16.1 # Pin to TF 2.16 -tensorflow-text~=2.16.0 +tensorflow-text~=2.16.1 # Torch cpu-only version. --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 89362bb846..43dc4c5ef5 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,6 +1,6 @@ # Tensorflow cpu-only version. tensorflow-cpu~=2.16.1 # Pin to TF 2.16 -tensorflow-text~=2.16.0 +tensorflow-text~=2.16.1 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 diff --git a/requirements.txt b/requirements.txt index f1e0b31956..8578a4199b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ # Tensorflow. tensorflow-cpu~=2.16.1 # Pin to TF 2.16 -tensorflow-text~=2.16.0 +tensorflow-text~=2.16.1 # Torch. 
--extra-index-url https://download.pytorch.org/whl/cpu From 09d2fdd3e15fad87a55860b22ed6543db1bfbee3 Mon Sep 17 00:00:00 2001 From: Gabriel Rasskin <43894452+grasskin@users.noreply.github.com> Date: Wed, 13 Mar 2024 14:20:35 -0400 Subject: [PATCH 07/13] Keep rope at float32 precision (#1497) * Keep rope at float32 precision * Carry out all of RoPE in float32 * Formatting * Cleanup * Do not cast x --- keras_nlp/models/gemma/gemma_attention.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py index 80c2ac6a63..e01c1f8ce4 100644 --- a/keras_nlp/models/gemma/gemma_attention.py +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -94,13 +94,14 @@ def _apply_rope(self, x, positions): # TODO: refactor to use RotaryEmbedding layer? max_wavelength = 10000 x_shape = ops.shape(x) - freq_exponents = (2.0 / x_shape[-1]) * ops.cast( - ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype + freq_exponents = (2.0 / x_shape[-1]) * ops.arange( + x_shape[-1] // 2, dtype="float32" ) timescale = max_wavelength**freq_exponents radians = positions[..., None] / timescale[None, None, :] radians = radians[..., None, :] - sin, cos = ops.sin(radians), ops.cos(radians) + sin = ops.cast(ops.sin(radians), self.compute_dtype) + cos = ops.cast(ops.cos(radians), self.compute_dtype) x1, x2 = ops.split(x, 2, axis=-1) # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA # compilation on jax. We should be able to remove this once the @@ -156,10 +157,9 @@ def call( ): seq_len = ops.shape(x)[1] start_index = cache_update_index - positions = ops.cast( - ops.arange(seq_len, dtype="float32"), self.compute_dtype - ) - positions = positions + ops.cast(start_index, self.compute_dtype) + positions = ops.arange(seq_len, dtype="float32") + + positions = positions + ops.cast(start_index, "float32") query = self.query_dense(x) query = self._apply_rope(query, positions) From 51368769cf82fc60964d062d1de07406aa8eb176 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:11:56 -0700 Subject: [PATCH 08/13] Bump the python group with 2 updates (#1509) Bumps the python group with 2 updates: torch and torchvision. Updates `torch` from 2.1.2 to 2.2.1+cu121 Updates `torchvision` from 0.16.2 to 0.17.1+cu121 --- updated-dependencies: - dependency-name: torch dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python - dependency-name: torchvision dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-torch-cuda.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 43dc4c5ef5..050dd85b1c 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -4,8 +4,8 @@ tensorflow-text~=2.16.1 # Torch with cuda support. --extra-index-url https://download.pytorch.org/whl/cu121 -torch==2.1.2 -torchvision==0.16.2 +torch==2.2.1+cu121 +torchvision==0.17.1+cu121 # Jax cpu-only version. 
jax[cpu] From a59a26fc7de85b9652b434cae380a73f748634f6 Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 13 Mar 2024 14:06:01 -0700 Subject: [PATCH 09/13] Fixes for the LLaMA backbone + add dropout (#1499) * Firxes for the LLaMA backbone + add dropout * Address review comments CachedLlamaAttention -> LlamaAttention and make parameter state public in the attention layer * Remove self._hidden_dim and self._head_dim --- keras_nlp/models/llama/llama_attention.py | 120 +++++++++--------- keras_nlp/models/llama/llama_backbone.py | 108 ++++++++++------ keras_nlp/models/llama/llama_backbone_test.py | 1 - keras_nlp/models/llama/llama_decoder.py | 75 +++++++---- 4 files changed, 182 insertions(+), 122 deletions(-) diff --git a/keras_nlp/models/llama/llama_attention.py b/keras_nlp/models/llama/llama_attention.py index 529e73b009..33ffcef209 100644 --- a/keras_nlp/models/llama/llama_attention.py +++ b/keras_nlp/models/llama/llama_attention.py @@ -18,34 +18,33 @@ class LlamaAttention(keras.layers.Layer): - """Grouped query attention for Llama models""" + """A cached grounded query attention layer with sliding window.""" def __init__( self, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) self.num_query_heads = num_query_heads self.num_key_value_heads = num_key_value_heads + self.dropout = dropout self.num_key_value_groups = num_query_heads // num_key_value_heads + self.rope_max_wavelength = rope_max_wavelength - self.kernel_initializer = keras.initializers.get(kernel_initializer) - self.max_sequence_length = max_sequence_length + self.kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) self.rope_scaling_factor = rope_scaling_factor - self.rope_max_wavelength = rope_max_wavelength def build(self, inputs_shape): - self.hidden_dim = inputs_shape[-1] - self.attn_head_size = self.hidden_dim // self.num_query_heads - # Einsum variables: # b = batch size # q = query length @@ -54,18 +53,27 @@ def build(self, inputs_shape): # u = num query heads # v = num key/value heads # h = head dim + hidden_dim = inputs_shape[-1] + head_dim = hidden_dim // self.num_query_heads + self._norm_factor = ops.sqrt(ops.cast(head_dim, self.compute_dtype)) + self._query_dense = keras.layers.EinsumDense( equation="bqm,muh->bquh", - output_shape=(None, self.num_query_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=(None, self.num_query_heads, head_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="query", ) self._query_dense.build(inputs_shape) + self._key_dense = keras.layers.EinsumDense( equation="bkm,mvh->bkvh", - output_shape=(None, self.num_key_value_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="key", ) @@ -73,8 +81,12 @@ def build(self, inputs_shape): self._value_dense = keras.layers.EinsumDense( equation="bkm,mvh->bkvh", - output_shape=(None, self.num_key_value_heads, self.attn_head_size), - kernel_initializer=clone_initializer(self.kernel_initializer), + output_shape=( + None, + self.num_key_value_heads, + head_dim, + ), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="value", ) @@ -86,21 
+98,28 @@ def build(self, inputs_shape): name="attention_softmax", ) + self._dropout_layer = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + ) + self._output_dense = keras.layers.EinsumDense( - equation="bqm,mh->bqh", - output_shape=(None, self.hidden_dim), - kernel_initializer=clone_initializer(self.kernel_initializer), + equation="bquh,uhm->bqm", + output_shape=(None, hidden_dim), + kernel_initializer=self.kernel_initializer, dtype=self.dtype_policy, name="attention_output", ) - self._output_dense.build(inputs_shape) + self._output_dense.build((None, None, self.num_query_heads, head_dim)) - self._rotary_embedding_layer = RotaryEmbedding( + self.rotary_embedding_layer = RotaryEmbedding( max_wavelength=self.rope_max_wavelength, scaling_factor=self.rope_scaling_factor, dtype=self.dtype_policy, ) - self._rotary_embedding_layer.build(inputs_shape) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" self.built = True @@ -110,6 +129,7 @@ def call( attention_mask=None, cache=None, cache_update_index=None, + training=None, ): query = self._query_dense(hidden_states) @@ -136,75 +156,61 @@ def call( key = self._key_dense(hidden_states) value = self._value_dense(hidden_states) - query = self._rotary_embedding_layer(query) - key = self._rotary_embedding_layer(key) + query = self.rotary_embedding_layer(query) + key = self.rotary_embedding_layer(key) - key = ops.tile(key, [1, 1, self.num_key_value_groups, 1]) - value = ops.tile(value, [1, 1, self.num_key_value_groups, 1]) + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) - attention_output, attention_scores = self._compute_attention( + attention_output = self._compute_attention( query, key, value, attention_mask ) - attention_output_shape = ops.shape(attention_output) - - attention_output = ops.reshape( - attention_output, - [ - attention_output_shape[0], - attention_output_shape[1], - self.hidden_dim, - ], + attention_output = self._dropout_layer( + attention_output, training=training ) attention_output = self._output_dense(attention_output) if cache is not None: - return (attention_output, cache) + return attention_output, cache return attention_output def _masked_softmax(self, attention_scores, attention_mask=None): if attention_mask is not None: - mask_expansion_axis = -3 - for _ in range( - len(attention_scores.shape) - len(attention_mask.shape) - ): - attention_mask = ops.expand_dims( - attention_mask, axis=mask_expansion_axis - ) - return self._softmax(attention_scores, attention_mask) + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) def _compute_attention(self, query, key, value, attention_mask=None): - attention_scores = ops.einsum("aecd,abcd->acbe", key, query) - - norm_factor = ops.sqrt( - ops.convert_to_tensor(self.attn_head_size, self.compute_dtype) - ) + attention_scores = ops.einsum(self._dot_product_equation, query, key) - attention_scores /= norm_factor + attention_scores = attention_scores / self._norm_factor attention_scores = self._masked_softmax( attention_scores, attention_mask ) attention_scores = ops.cast(attention_scores, self.compute_dtype) attention_output = ops.einsum( - "acbe,aecd->abcd", attention_scores, value + self._combine_equation, attention_scores, value ) - return attention_output, 
attention_scores + return attention_output def get_config(self): config = super().get_config() config.update( { "num_query_heads": self.num_query_heads, - "hidden_dim": self.hidden_dim, + "num_key_value_heads": self.num_key_value_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), - "rope_max_wavelength": self.rope_max_wavelength, - "rope_scaling_factor": self.rope_scaling_factor, - "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, + "dropout": self.dropout, } ) return config diff --git a/keras_nlp/models/llama/llama_backbone.py b/keras_nlp/models/llama/llama_backbone.py index 733d9ef434..b5383d528a 100644 --- a/keras_nlp/models/llama/llama_backbone.py +++ b/keras_nlp/models/llama/llama_backbone.py @@ -11,14 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# import copy + from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import keras from keras_nlp.backend import ops from keras_nlp.layers.modeling.reversible_embedding import ReversibleEmbedding from keras_nlp.models.backbone import Backbone -from keras_nlp.models.llama.llama_decoder import LlamaDecoder + +# from keras_nlp.models.llama.llama_presets import backbone_presets +from keras_nlp.models.llama.llama_decoder import LlamaTransformerDecoder from keras_nlp.models.llama.llama_layernorm import LlamaLayerNorm +# from keras_nlp.utils.python_utils import classproperty + def _llama_kernel_initializer(stddev=0.02): return keras.initializers.RandomNormal(stddev=stddev) @@ -27,41 +34,64 @@ def _llama_kernel_initializer(stddev=0.02): @keras_nlp_export("keras_nlp.models.LlamaBackbone") class LlamaBackbone(Backbone): """ - LLaMA core network with hyperparameters. + The Llama Transformer core architecture with hyperparameters. This network implements a Transformer-based decoder network, - LLaMA, as described in ["LLaMA: Open Foundation and Fine-Tuned Language Models"](https://arxiv.org/abs/2302.13971). + Llama, as described in + ["Llama 7B"](https://arxiv.org/pdf/2310.06825.pdf). + It includes the embedding lookups and transformer layers. The default constructor gives a fully customizable, randomly initialized - LLaMA model with any number of layers, heads, and embedding - dimensions. This backbone also supports LLaMA2 checkpoints. + Llama model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. Args: - vocabulary_size: int. The size of the token vocabulary. - num_layers: int. The number of transformer layers. - num_query_heads: int. The number of attention heads for each transformer. - The hidden size must be divisible by the number of attention heads. - hidden_dim: int. The size of the transformer encoding and pooler layers. - intermediate_dim: int. The output dimension of the first Dense layer in - a two-layer feedforward network for each transformer. - num_key_value_heads: int. This is the number of key_value heads that - should be used to implement Grouped Query Attention. If num_key_value_heads=num_attention_heads, - the model will use Multi Head Attention (MHA), if num_key_value_heads=1 - the model will use Multi Query Attention (MQA) - rope_scaling_factor: float. 
The scaling factor for calculation of rotary - embedding - rope_max_wavelength: int. The maximum angular wavelength of the - sine/cosine curves, for rotary embeddings. - layer_norm_epsilon: float. a value added to the denominator for - numerical stability. - max_sequence_length: int. The maximum sequence length that this encoder - can consume. If `None`, `max_sequence_length` uses the value from - sequence length. This determines the variable shape for positional - embeddings. + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers. + num_query_heads (int): The number of query attention heads for + each transformer. + hidden_dim (int): The size of the transformer encoding and pooling layers. + intermediate_dim (int): The output dimension of the first Dense layer in a + three-layer feedforward network for each transformer. + num_key_value_heads (int): The number of key and value attention heads for + each transformer. + rope_max_wavelength (int, optional): The maximum angular wavelength of the + sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_factor (float, optional): The scaling factor for calculation + of roatary embedding. Defaults to `1.0`. + layer_norm_epsilon (float, optional): Epsilon for the layer normalization + layers in the transformer decoder. Defaults to `1e-6`. dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use for model computations and weights. Note that some computations, such as softmax and layer normalization, will always be done at float32 precision regardless of dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Llama decoder. + model = keras_nlp.models.LlamaBackbone.from_preset("llama7b_base_en") + model(input_data) + + # Randomly initialized Llama decoder with custom config. 
+ model = keras_nlp.models.LlamaBackbone( + vocabulary_size=10, + hidden_dim=512, + num_layers=2, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=1024, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` """ def __init__( @@ -72,10 +102,10 @@ def __init__( hidden_dim, intermediate_dim, num_key_value_heads, - rope_scaling_factor=1.0, rope_max_wavelength=10000, - layer_norm_epsilon=1e-5, - max_sequence_length=4096, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + dropout=0, dtype=None, **kwargs, ): @@ -83,31 +113,31 @@ def __init__( self.token_embedding = ReversibleEmbedding( input_dim=vocabulary_size, output_dim=hidden_dim, - embeddings_initializer=_llama_kernel_initializer(stddev=0.01), tie_weights=False, + embeddings_initializer=_llama_kernel_initializer(stddev=0.01), dtype=dtype, name="token_embedding", ) self.transformer_layers = [] for i in range(num_layers): - layer = LlamaDecoder( + layer = LlamaTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, - rope_scaling_factor=rope_scaling_factor, - max_sequence_length=max_sequence_length, rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, kernel_initializer=_llama_kernel_initializer(stddev=0.02), + dropout=dropout, dtype=dtype, name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) self.layer_norm = LlamaLayerNorm( - dtype=dtype, epsilon=layer_norm_epsilon, - name="layer_norm", + dtype=dtype, + name="sequence_output_layernorm", ) # === Functional Model === @@ -140,8 +170,8 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.num_key_value_heads = num_key_value_heads self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout def get_config(self): config = super().get_config() @@ -155,8 +185,12 @@ def get_config(self): "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, } ) return config + + # @classproperty + # def presets(cls): + # return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/llama/llama_backbone_test.py b/keras_nlp/models/llama/llama_backbone_test.py index efff972c6b..56d8c44bd3 100644 --- a/keras_nlp/models/llama/llama_backbone_test.py +++ b/keras_nlp/models/llama/llama_backbone_test.py @@ -28,7 +28,6 @@ def setUp(self): "num_key_value_heads": 2, "hidden_dim": 8, "intermediate_dim": 8, - "max_sequence_length": 10, } self.input_data = { "token_ids": ops.ones((2, 5), dtype="int32"), diff --git a/keras_nlp/models/llama/llama_decoder.py b/keras_nlp/models/llama/llama_decoder.py index 3b9d6906b8..1ef247c575 100644 --- a/keras_nlp/models/llama/llama_decoder.py +++ b/keras_nlp/models/llama/llama_decoder.py @@ -24,20 +24,20 @@ from keras_nlp.utils.keras_utils import clone_initializer -class LlamaDecoder(keras.layers.Layer): - """Llama decoder block.""" +class LlamaTransformerDecoder(keras.layers.Layer): + """A Transformer decoder layer for the Llama backbone.""" def __init__( self, intermediate_dim, num_query_heads, num_key_value_heads, + rope_max_wavelength=10000, rope_scaling_factor=1.0, - activation="relu", + activation="silu", layer_norm_epsilon=1e-5, 
kernel_initializer="glorot_uniform", - rope_max_wavelength=10000, - max_sequence_length=512, + dropout=0, **kwargs, ): super().__init__(**kwargs) @@ -48,37 +48,50 @@ def __init__( self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor - self.max_sequence_length = max_sequence_length + self.dropout = dropout + self.activation = keras.activations.get(activation) self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.supports_masking = True + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape self.hidden_dim = decoder_sequence_shape[-1] - # Self attention layers. + # Self attention layer. self._self_attention_layer = LlamaAttention( num_query_heads=self.num_query_heads, num_key_value_heads=self.num_key_value_heads, rope_max_wavelength=self.rope_max_wavelength, - max_sequence_length=self.max_sequence_length, rope_scaling_factor=self.rope_scaling_factor, kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, dtype=self.dtype_policy, + name="self_attention", ) self._self_attention_layer.build(decoder_sequence_shape) self._self_attention_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="self_attention_layernorm", ) self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) # Feedforward layers. self._feedforward_intermediate_dense = keras.layers.Dense( self.intermediate_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_intermediate_dense", ) self._feedforward_intermediate_dense.build(decoder_sequence_shape) @@ -86,23 +99,30 @@ def build(self, decoder_sequence_shape): self.intermediate_dim, activation=self.activation, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_gate_dense", ) self._feedforward_gate_dense.build(decoder_sequence_shape) self._feedforward_output_dense = keras.layers.Dense( self.hidden_dim, kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, dtype=self.dtype_policy, + name="feedforward_output_dense", ) - intermediate_shape = list(decoder_sequence_shape) - intermediate_shape[-1] = self.intermediate_dim - self._feedforward_output_dense.build(tuple(intermediate_shape)) + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) self._feedforward_layernorm = LlamaLayerNorm( epsilon=self.layer_norm_epsilon, dtype=self.dtype_policy, + name="feedforward_layernorm", ) self._feedforward_layernorm.build(decoder_sequence_shape) @@ -115,6 +135,7 @@ def call( decoder_attention_mask=None, self_attention_cache=None, self_attention_cache_update_index=None, + training=None, ): self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, @@ -125,10 +146,9 @@ def call( ) residual = decoder_sequence - x = self._self_attention_layernorm( - decoder_sequence, - ) + x = self._self_attention_layernorm(decoder_sequence) + # Self attention block. 
x = self._self_attention_layer( hidden_states=x, attention_mask=self_attention_mask, @@ -139,6 +159,8 @@ def call( if self_attention_cache is not None: x, self_attention_cache = x + x = self._self_attention_dropout(x, training=training) + x = x + residual residual = x @@ -152,7 +174,7 @@ def call( decoder_output = x + residual if self_attention_cache is not None: - return (decoder_output, self_attention_cache) + return decoder_output, self_attention_cache return decoder_output def _compute_self_attention_mask( @@ -160,8 +182,8 @@ def _compute_self_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask, - self_attention_cache=None, - self_attention_cache_update_index=None, + self_attention_cache, + self_attention_cache_update_index, ): decoder_mask = merge_padding_and_attention_mask( decoder_sequence, decoder_padding_mask, decoder_attention_mask @@ -174,16 +196,16 @@ def _compute_self_attention_mask( if self_attention_cache is not None: input_length = ops.shape(self_attention_cache)[2] + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + causal_mask = compute_causal_mask( - batch_size, - input_length, - output_length, - ( - 0 - if self_attention_cache_update_index is None - else self_attention_cache_update_index - ), + batch_size, input_length, output_length, cache_update_index ) + return ( ops.minimum(decoder_mask, causal_mask) if decoder_mask is not None @@ -198,17 +220,16 @@ def get_config(self): config.update( { "intermediate_dim": self.intermediate_dim, - "hidden_dim": self.hidden_dim, "num_query_heads": self.num_query_heads, "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, - "max_sequence_length": self.max_sequence_length, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, "kernel_initializer": keras.initializers.serialize( self.kernel_initializer ), + "dropout": self.dropout, } ) return config From e81daa06fbc12c7b5069315c1609b4a1e267a83c Mon Sep 17 00:00:00 2001 From: Tirth Patel Date: Wed, 13 Mar 2024 15:54:46 -0700 Subject: [PATCH 10/13] Add `LlamaPreprocessor` and `LlamaCausalLMPreprocessor` (#1511) * Add a preprocessor for the Llama backbone * Add causal lm preprocessor for the Llama backbone --- .../llama/llama_causal_lm_preprocessor.py | 185 +++++++++++++++++ .../llama_causal_lm_preprocessor_test.py | 90 +++++++++ keras_nlp/models/llama/llama_preprocessor.py | 191 ++++++++++++++++++ .../models/llama/llama_preprocessor_test.py | 57 ++++++ 4 files changed, 523 insertions(+) create mode 100644 keras_nlp/models/llama/llama_causal_lm_preprocessor.py create mode 100644 keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py create mode 100644 keras_nlp/models/llama/llama_preprocessor.py create mode 100644 keras_nlp/models/llama/llama_preprocessor_test.py diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py new file mode 100644 index 0000000000..a221185582 --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor.py @@ -0,0 +1,185 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +from absl import logging + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import ops +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.utils.keras_utils import ( + convert_inputs_to_list_of_tensor_segments, +) +from keras_nlp.utils.keras_utils import pack_x_y_sample_weight + + +@keras_nlp_export("keras_nlp.models.LlamaCausalLMPreprocessor") +class LlamaCausalLMPreprocessor(LlamaPreprocessor): + """Llama Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_nlp.models.LlamaCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_nlp.models.LlamaCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_nlp.models.LlamaTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_nlp.models.LlamaCausalLMPreprocessor.from_preset( + "llama_base_en" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. + preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. 
+ ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + if y is not None or sample_weight is not None: + logging.warning( + "`LlamaCausalLMPreprocessor` generates `y` and " + "`sample_weight` based on your input data, but your data " + "already contains `y` or `sample_weight`. Your `y` and " + "`sample_weight` will be ignored." + ) + sequence_length = sequence_length or self.sequence_length + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + # The last token does not have a next token, so we truncate it out. + x = { + "token_ids": token_ids[..., :-1], + "padding_mask": padding_mask[..., :-1], + } + # Target `y` will be the next token. + y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:] + return pack_x_y_sample_weight(x, y, sample_weight) + + def generate_preprocess( + self, + x, + sequence_length=None, + ): + """Convert strings to integer token input for generation. + + Similar to calling the layer for training, this method takes in strings + or tensor strings, tokenizes and packs the input, and computes a padding + mask masking all inputs not filled in with a padded value. + + Unlike calling the layer for training, this method does not compute + labels and will never append a `tokenizer.end_token_id` to the end of + the sequence (as generation is expected to continue at the end of the + inputted prompt). + """ + if not self.built: + self.build(None) + + x = convert_inputs_to_list_of_tensor_segments(x)[0] + x = self.tokenizer(x) + token_ids, padding_mask = self.packer( + x, sequence_length=sequence_length, add_end_value=False + ) + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def generate_postprocess( + self, + x, + ): + """Convert integer token output to strings for generation. + + This method reverses `generate_preprocess()`, by first removing all + padding and start/end tokens, and then converting the integer sequence + back to a string. + """ + token_ids, padding_mask = x["token_ids"], x["padding_mask"] + # Convert the inputs to numpy arrays if they aren't a tensor already. + if not isinstance(token_ids, tf.Tensor): + token_ids = ops.convert_to_numpy(token_ids) + # Make sure the numpy array has type `int32` since + # `SentencePieceProcessor.detokenize` only accepts `int32` arrays. + token_ids = token_ids.astype("int32") + if not isinstance(padding_mask, tf.Tensor): + padding_mask = ops.convert_to_numpy(padding_mask) + padding_mask = padding_mask.astype("bool") + # Strip any special tokens during detokenization (e.g. the start and + # end markers). In the future we could make this configurable. 
+ padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id) + padding_mask = padding_mask & ( + token_ids != self.tokenizer.start_token_id + ) + token_ids = tf.ragged.boolean_mask(token_ids, padding_mask) + return self.tokenizer.detokenize(token_ids) diff --git a/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..aa4d155c8c --- /dev/null +++ b/keras_nlp/models/llama/llama_causal_lm_preprocessor_test.py @@ -0,0 +1,90 @@ +# Copyright 2024 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from keras_nlp.models.llama.llama_causal_lm_preprocessor import ( + LlamaCausalLMPreprocessor, +) +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = (["the quick brown fox"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=LlamaCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [[3, 8, 4, 6, 0, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 0, 0, 0, 0]], # Pass through sample_weights. 
+            ),
+        )
+
+    def test_no_start_end_token(self):
+        input_data = ["the quick brown fox"] * 4
+
+        preprocessor = LlamaCausalLMPreprocessor(
+            **self.init_kwargs,
+            add_start_token=False,
+            add_end_token=False,
+        )
+        x, y, sw = preprocessor(input_data)
+        self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4)
+        self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4)
+
+    def test_generate_preprocess(self):
+        input_data = "the quick brown fox"
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_preprocess(input_data)
+        self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0])
+        self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
+
+    def test_generate_postprocess(self):
+        input_data = {
+            "token_ids": [1, 3, 8, 4, 6, 0, 0, 0],
+            "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
+        }
+        preprocessor = LlamaCausalLMPreprocessor(**self.init_kwargs)
+        x = preprocessor.generate_postprocess(input_data)
+        self.assertAllEqual(x, "the quick brown fox")
+
+    @pytest.mark.extra_large
+    def test_all_presets(self):
+        for preset in LlamaCausalLMPreprocessor.presets:
+            self.run_preset_test(
+                cls=LlamaCausalLMPreprocessor,
+                preset=preset,
+                input_data=self.input_data,
+            )
diff --git a/keras_nlp/models/llama/llama_preprocessor.py b/keras_nlp/models/llama/llama_preprocessor.py
new file mode 100644
index 0000000000..580557f50d
--- /dev/null
+++ b/keras_nlp/models/llama/llama_preprocessor.py
@@ -0,0 +1,191 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras_nlp.api_export import keras_nlp_export
+from keras_nlp.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer
+from keras_nlp.models.preprocessor import Preprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
+from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
+
+
+@keras_nlp_export("keras_nlp.models.LlamaPreprocessor")
+class LlamaPreprocessor(Preprocessor):
+    """A Llama preprocessing layer which tokenizes and packs inputs.
+
+    This preprocessing layer will do three things:
+
+    1. Tokenize any number of input segments using the `tokenizer`.
+    2. Pack the inputs together using a `keras_nlp.layers.StartEndPacker`
+       with the appropriate tokens.
+    3. Construct a dictionary with keys `"token_ids"` and `"padding_mask"`
+       that can be passed directly to `keras_nlp.models.LlamaBackbone`.
+
+    This layer can be used directly with `tf.data.Dataset.map` to preprocess
+    string data in the `(x, y, sample_weight)` format used by
+    `keras.Model.fit`.
+
+    Args:
+        tokenizer: A `keras_nlp.models.LlamaTokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Defaults to `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Defaults to `False`.
+
+    Call arguments:
+        x: A tensor of single string sequences, or a tuple of multiple
+            tensor sequences to be packed together. Inputs may be batched or
+            unbatched. For single sequences, raw python inputs will be converted
+            to tensors. For multiple sequences, pass tensors directly.
+        y: Any label data. Will be passed through unaltered.
+        sample_weight: Any label weight data. Will be passed through unaltered.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+
+    Directly calling the preprocessor created with `from_preset()`.
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+
+    # Tokenize and pack a single sentence.
+    preprocessor("The quick brown fox jumped.")
+
+    # Tokenize a batch of single sentences.
+    preprocessor(["The quick brown fox jumped.", "Call me Ishmael."])
+
+    # Preprocess a batch of sentence pairs.
+    # When handling multiple sequences, always convert to tensors first!
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    preprocessor((first, second))
+    ```
+
+    Mapping with `tf.data.Dataset`.
+    ```python
+    preprocessor = keras_nlp.models.LlamaPreprocessor.from_preset(
+        "llama_base_en"
+    )
+    first = tf.constant(["The quick brown fox jumped.", "Call me Ishmael."])
+    second = tf.constant(["The fox tripped.", "Oh look, a whale."])
+    label = tf.constant([1, 1])
+
+    # Map labeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices((first, label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled single sentences.
+    ds = tf.data.Dataset.from_tensor_slices(first)
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map labeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices(((first, second), label))
+    ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
+
+    # Map unlabeled sentence pairs.
+    ds = tf.data.Dataset.from_tensor_slices((first, second))
+
+    # Watch out for tf.data's default unpacking of tuples here!
+    # Best to invoke the `preprocessor` directly in this case.
+    ds = ds.map(
+        lambda first, second: preprocessor(x=(first, second)),
+        num_parallel_calls=tf.data.AUTOTUNE,
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        sequence_length=1024,
+        add_start_token=True,
+        add_end_token=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.tokenizer = tokenizer
+        self.packer = None
+        self.add_start_token = add_start_token
+        self.add_end_token = add_end_token
+        self.sequence_length = sequence_length
+
+    def build(self, input_shape):
+        # Defer packer creation to `build()` so that we can be sure tokenizer
+        # assets have loaded when restoring a saved model.
+ self.packer = StartEndPacker( + start_value=self.tokenizer.start_token_id, + end_value=self.tokenizer.end_token_id, + sequence_length=self.sequence_length, + return_padding_mask=True, + ) + self.built = True + + def get_config(self): + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "add_start_token": self.add_start_token, + "add_end_token": self.add_end_token, + } + ) + return config + + def call( + self, + x, + y=None, + sample_weight=None, + sequence_length=None, + ): + x = convert_inputs_to_list_of_tensor_segments(x) + if len(x) != 1: + raise ValueError( + "Llama requires each input feature to contain only " + f"one segment, but received {len(x)}. If you are using Llama" + " for a multi-segment classification task, please refer to " + "classification models like BERT or RoBERTa." + ) + sequence_length = sequence_length or self.sequence_length + token_ids, padding_mask = self.packer( + self.tokenizer(x[0]), + sequence_length=sequence_length, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + ) + x = { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + return pack_x_y_sample_weight(x, y, sample_weight) + + @property + def sequence_length(self): + """The padded length of model input sequences.""" + return self._sequence_length + + @sequence_length.setter + def sequence_length(self, value): + self._sequence_length = value + if self.packer is not None: + self.packer.sequence_length = value + + @classproperty + def tokenizer_cls(cls): + return LlamaTokenizer diff --git a/keras_nlp/models/llama/llama_preprocessor_test.py b/keras_nlp/models/llama/llama_preprocessor_test.py new file mode 100644 index 0000000000..6807886812 --- /dev/null +++ b/keras_nlp/models/llama/llama_preprocessor_test.py @@ -0,0 +1,57 @@ +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from keras_nlp.models.llama.llama_preprocessor import LlamaPreprocessor +from keras_nlp.models.llama.llama_tokenizer import LlamaTokenizer +from keras_nlp.tests.test_case import TestCase + + +class LlamaPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = LlamaTokenizer( + # Generated using create_llama_test_proto.py + proto=os.path.join(self.get_test_data_dir(), "llama_test_vocab.spm") + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = ( + ["the quick brown fox"], + [1], # Pass through labels. + [1.0], # Pass through sample_weights. + ) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=LlamaPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 0, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 0, 0, 0]], + }, + [1], # Pass through labels. + [1.0], # Pass through sample_weights. 
+            ),
+        )
+
+    def test_errors_for_2d_list_input(self):
+        preprocessor = LlamaPreprocessor(**self.init_kwargs)
+        ambiguous_input = [["one", "two"], ["three", "four"]]
+        with self.assertRaises(ValueError):
+            preprocessor(ambiguous_input)

From f1279010a8e0b9679180d98d7688ca855b232879 Mon Sep 17 00:00:00 2001
From: Tirth Patel
Date: Thu, 14 Mar 2024 11:40:21 -0700
Subject: [PATCH 11/13] Always run the rotary embedding layer in float32 (#1508)

* Always run the rotary embedding layer in float32

* Fix the int32 issue with TensorFlow

* Only run sin/cos embedding compute step in float32

* Avoid start_index from downcasting automatically

* Use stack instead of concatenate
---
 keras_nlp/layers/modeling/rotary_embedding.py | 47 ++++++++++++-------
 keras_nlp/models/gemma/gemma_attention.py     | 43 +++++++----------
 2 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/keras_nlp/layers/modeling/rotary_embedding.py b/keras_nlp/layers/modeling/rotary_embedding.py
index 45f77ce494..b494d559bd 100644
--- a/keras_nlp/layers/modeling/rotary_embedding.py
+++ b/keras_nlp/layers/modeling/rotary_embedding.py
@@ -85,30 +85,42 @@ def __init__(
         self.built = True

     def call(self, inputs, start_index=0):
+        inputs = ops.moveaxis(
+            inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
+        )
         cos_emb, sin_emb = self._compute_cos_sin_embedding(inputs, start_index)
-        return self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        output = self._apply_rotary_pos_emb(inputs, cos_emb, sin_emb)
+        return ops.moveaxis(
+            output, (-1, 1), (self.feature_axis, self.sequence_axis)
+        )

     def _apply_rotary_pos_emb(self, tensor, cos_emb, sin_emb):
-        x1, x2 = ops.split(tensor, 2, axis=self.feature_axis)
-        half_rot_tensor = ops.concatenate((-x2, x1), axis=self.feature_axis)
+        x1, x2 = ops.split(tensor, 2, axis=-1)
+        # Avoid `ops.concatenate` for now, to avoid an obscure bug with XLA
+        # compilation on jax. We should be able to remove this once the
+        # following PR is in all jax releases we care about:
+        # https://github.com/openxla/xla/pull/7875
+        half_rot_tensor = ops.stack((-x2, x1), axis=-2)
+        half_rot_tensor = ops.reshape(half_rot_tensor, ops.shape(tensor))
         return (tensor * cos_emb) + (half_rot_tensor * sin_emb)

     def _compute_cos_sin_embedding(self, inputs, start_index=0):
-        def get_axis(axis):
-            return axis if axis > 0 else len(inputs.shape) + axis
+        start_index = ops.cast(start_index, dtype="float32")

-        feature_axis = get_axis(self.feature_axis)
-        sequence_axis = get_axis(self.sequence_axis)
+        feature_axis = len(inputs.shape) - 1
+        sequence_axis = 1

         rotary_dim = ops.shape(inputs)[feature_axis]
         inverse_freq = self._get_inverse_freq(rotary_dim)

-        seq_len = ops.shape(inputs)[self.sequence_axis]
-        tensor = ops.cast(ops.arange(seq_len), self.compute_dtype) + start_index
+        seq_len = ops.shape(inputs)[sequence_axis]
+        tensor = ops.arange(seq_len, dtype="float32") + start_index

-        tensor = ops.cast(tensor, dtype=inverse_freq.dtype)
         freq = ops.einsum("i,j->ij", tensor, inverse_freq)
-        embedding = ops.concatenate((freq, freq), axis=-1)
+        embedding = ops.stack((freq, freq), axis=-2)
+        embedding = ops.reshape(
+            embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
+        )

         # Reshape the embedding to be broadcastable with input shape.
if feature_axis < sequence_axis: @@ -117,17 +129,16 @@ def get_axis(axis): if axis != sequence_axis and axis != feature_axis: embedding = ops.expand_dims(embedding, axis) - return ops.cos(embedding), ops.sin(embedding) + cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype) + sin_emb = ops.cast(ops.sin(embedding), self.compute_dtype) + return cos_emb, sin_emb def _get_inverse_freq(self, rotary_dim): - freq_range = ops.arange(0, rotary_dim, 2) - freq_range = ops.cast(freq_range, self.compute_dtype) - freq_range = freq_range / ops.cast( - self.scaling_factor, self.compute_dtype - ) + freq_range = ops.arange(0, rotary_dim, 2, dtype="float32") + freq_range = freq_range / ops.cast(self.scaling_factor, "float32") inverse_freq = 1.0 / ( self.max_wavelength - ** (freq_range / ops.cast(rotary_dim, self.compute_dtype)) + ** (freq_range / ops.cast(rotary_dim, "float32")) ) return inverse_freq diff --git a/keras_nlp/models/gemma/gemma_attention.py b/keras_nlp/models/gemma/gemma_attention.py index e01c1f8ce4..4b391264a2 100644 --- a/keras_nlp/models/gemma/gemma_attention.py +++ b/keras_nlp/models/gemma/gemma_attention.py @@ -15,6 +15,7 @@ from keras_nlp.backend import keras from keras_nlp.backend import ops +from keras_nlp.layers.modeling.rotary_embedding import RotaryEmbedding from keras_nlp.utils.keras_utils import clone_initializer @@ -87,28 +88,23 @@ def build(self, inputs_shape): (None, None, self.num_query_heads, self.head_dim) ) self.softmax = keras.layers.Softmax(dtype="float32") + + self.rope_layer = RotaryEmbedding( + max_wavelength=10_000.0, dtype=self.dtype_policy + ) + self.built = True - def _apply_rope(self, x, positions): + def _apply_rope(self, x, start_index): """Rope rotate q or k.""" - # TODO: refactor to use RotaryEmbedding layer? - max_wavelength = 10000 - x_shape = ops.shape(x) - freq_exponents = (2.0 / x_shape[-1]) * ops.arange( - x_shape[-1] // 2, dtype="float32" + x = self.rope_layer(x, start_index=start_index) + # Gemma uses a different layout for positional embeddings. + # The transformation below ensures the embeddings are numerically + # equivalent to the original gemma implementation. + x = ops.reshape( + ops.stack(ops.split(x, 2, axis=-1), axis=-1), ops.shape(x) ) - timescale = max_wavelength**freq_exponents - radians = positions[..., None] / timescale[None, None, :] - radians = radians[..., None, :] - sin = ops.cast(ops.sin(radians), self.compute_dtype) - cos = ops.cast(ops.cos(radians), self.compute_dtype) - x1, x2 = ops.split(x, 2, axis=-1) - # Avoid `ops.concatenate` for now, to avoid a obscure bug with XLA - # compilation on jax. We should be able to remove this once the - # following PR is in all jax releases we care about: - # https://github.com/openxla/xla/pull/7875 - output = ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1) - return ops.reshape(output, x_shape) + return x def _compute_attention( self, @@ -155,19 +151,14 @@ def call( cache_update_index=0, training=False, ): - seq_len = ops.shape(x)[1] - start_index = cache_update_index - positions = ops.arange(seq_len, dtype="float32") - - positions = positions + ops.cast(start_index, "float32") query = self.query_dense(x) - query = self._apply_rope(query, positions) + query = self._apply_rope(query, cache_update_index) if cache is not None: key_cache = cache[:, 0, ...] value_cache = cache[:, 1, ...] 
             key_update = self.key_dense(x)
-            key_update = self._apply_rope(key_update, positions)
+            key_update = self._apply_rope(key_update, cache_update_index)
             value_update = self.value_dense(x)
             start = [0, cache_update_index, 0, 0]
             key = ops.slice_update(key_cache, start, key_update)
@@ -175,7 +166,7 @@ def call(
             cache = ops.stack((key, value), axis=1)
         else:
             key = self.key_dense(x)
-            key = self._apply_rope(key, positions)
+            key = self._apply_rope(key, cache_update_index)
             value = self.value_dense(x)

         attention_vec = self._compute_attention(

From 673c63baf22149132dafaef1652c3815bd99eb2e Mon Sep 17 00:00:00 2001
From: Ramesh Sampath <1437573+sampathweb@users.noreply.github.com>
Date: Thu, 14 Mar 2024 18:36:57 -0500
Subject: [PATCH 12/13] Remove install of Python 3.9 (#1514)

---
 .kokoro/github/ubuntu/gpu/build.sh | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 87cd206495..b8d47dbe9c 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -14,11 +14,8 @@ if [[ -z "${KAGGLE_USERNAME}" ]]; then
 fi

 set -x
-
 cd "${KOKORO_ROOT}/"

-sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
-
 PYTHON_BINARY="/usr/bin/python3.9"

 "${PYTHON_BINARY}" -m venv venv

From 451158027cb79d324584d28d7ea41c00dbb6c77a Mon Sep 17 00:00:00 2001
From: Qianli Scott Zhu
Date: Thu, 14 Mar 2024 16:56:33 -0700
Subject: [PATCH 13/13] Update gemma_backbone.py for sharding config. (#1491)

* Update gemma_backbone.py for sharding config.

* Update unit test and fix format.

* Update sharding spec for gemma based on gemma training.
---
 keras_nlp/models/gemma/gemma_backbone.py      | 31 ++++++++++++++-----
 keras_nlp/models/gemma/gemma_backbone_test.py | 24 +++++++++-----
 2 files changed, 39 insertions(+), 16 deletions(-)

diff --git a/keras_nlp/models/gemma/gemma_backbone.py b/keras_nlp/models/gemma/gemma_backbone.py
index c829aa948f..06f5b0f601 100644
--- a/keras_nlp/models/gemma/gemma_backbone.py
+++ b/keras_nlp/models/gemma/gemma_backbone.py
@@ -194,7 +194,11 @@ def presets(cls):
         return copy.deepcopy(backbone_presets)

     @staticmethod
-    def get_layout_map(device_mesh, model_parallel_dim_name="model"):
+    def get_layout_map(
+        device_mesh,
+        model_parallel_dim_name="model",
+        data_parallel_dim_name="batch",
+    ):
         """Get a `keras.distribution.LayoutMap` for model parallel distribution.

         The returned `LayoutMap` contains the sharding spec for the gemma
@@ -221,6 +225,8 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 distribution.
             model_parallel_dim_name: The axis name of the device mesh, where
                 the weights should be partition on.
+            data_parallel_dim_name: The axis name of the device mesh, where
+                the data should be partitioned on.
         Return:
             `keras.distribution.LayoutMap` that contains the sharding spec of
             all the model weights.
@@ -248,21 +254,30 @@ def get_layout_map(device_mesh, model_parallel_dim_name="model"):
                 f"{model_parallel_dim_name} is not found in the "
                 f"device_mesh.axis_names. {device_mesh.axis_name=}"
             )
+        if data_parallel_dim_name not in device_mesh.axis_names:
+            raise ValueError(
+                f"{data_parallel_dim_name} is not found in the "
+                f"device_mesh.axis_names. {device_mesh.axis_names=}"
+            )
+        # Note that it is possible to further configure the mesh to be 3D,
+        # e.g. (data, seq, model). We leave it as 2D for now for simplicity.
+        data_dim = data_parallel_dim_name
         model_dim = model_parallel_dim_name
-        # The sharding is partition for the hidden_dim of the model.
+ # The sharding config is based on the Gemma team training config. + # See https://arxiv.org/abs/2403.08295 layout_map = keras.distribution.LayoutMap(device_mesh) - layout_map["token_embedding/embeddings"] = (None, model_dim) + layout_map["token_embedding/embeddings"] = (model_dim, data_dim) layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ( - None, model_dim, + data_dim, None, ) layout_map["decoder_block.*attention_output.*kernel"] = ( - None, - None, model_dim, + None, + data_dim, ) - layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None) - layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim) + layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim) + layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim) return layout_map diff --git a/keras_nlp/models/gemma/gemma_backbone_test.py b/keras_nlp/models/gemma/gemma_backbone_test.py index 855d49658b..7b02de2b7a 100644 --- a/keras_nlp/models/gemma/gemma_backbone_test.py +++ b/keras_nlp/models/gemma/gemma_backbone_test.py @@ -106,26 +106,34 @@ def test_distribution(self): for w in model.weights: if "token_embedding/embeddings" in w.path: - self.assertEqual(tuple(w.value.sharding.spec), (None, "model")) + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + ) if "attention/query/kernel" in w.path: self.assertEqual( - tuple(w.value.sharding.spec), (None, "model", None) + tuple(w.value.sharding.spec), ("model", "batch", None) ) if "attention/key/kernel" in w.path: self.assertEqual( - tuple(w.value.sharding.spec), (None, "model", None) + tuple(w.value.sharding.spec), ("model", "batch", None) ) if "attention/value/kernel" in w.path: self.assertEqual( - tuple(w.value.sharding.spec), (None, "model", None) + tuple(w.value.sharding.spec), ("model", "batch", None) ) if "attention/attention_output/kernel" in w.path: self.assertEqual( - tuple(w.value.sharding.spec), (None, None, "model") + tuple(w.value.sharding.spec), ("model", None, "batch") ) if "ffw_gating/kernel" in w.path: - self.assertEqual(tuple(w.value.sharding.spec), ("model", None)) + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) if "ffw_gating_2/kernel" in w.path: - self.assertEqual(tuple(w.value.sharding.spec), ("model", None)) + self.assertEqual( + tuple(w.value.sharding.spec), ("batch", "model") + ) if "ffw_linearl" in w.path: - self.assertEqual(tuple(w.value.sharding.spec), (None, "model")) + self.assertEqual( + tuple(w.value.sharding.spec), ("model", "batch") + )
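The final patch above changes both the signature of `GemmaBackbone.get_layout_map` and the sharding spec it returns. As a hedged illustration (not part of this patch series), the sketch below shows one way the updated layout map might be wired into `keras.distribution.ModelParallel`; the 1x8 mesh shape, the assumption of 8 accelerators on the JAX backend, and the `gemma_2b_en` preset are example choices, not something the diff prescribes.

```python
# Illustrative sketch only: consuming GemmaBackbone.get_layout_map() with the
# new "batch"/"model" mesh axes from the patch above. Assumes 8 accelerators
# (e.g. a JAX TPU/GPU host) and the `gemma_2b_en` preset; adjust the mesh
# shape to match your hardware.
import keras
import keras_nlp

# A 2D mesh: 1-way data parallel ("batch") x 8-way model parallel ("model").
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, 8),
    axis_names=("batch", "model"),
    devices=keras.distribution.list_devices(),
)

# The layout map now shards each weight across both mesh axes, following the
# Gemma training config referenced in the patch.
layout_map = keras_nlp.models.GemmaBackbone.get_layout_map(
    device_mesh,
    model_parallel_dim_name="model",
    data_parallel_dim_name="batch",
)

distribution = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch"
)
keras.distribution.set_distribution(distribution)

# Any Gemma model constructed after this point picks up the sharding spec.
model = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")
```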