[T5 1.1] Enable v1.1 Presets (#1948)
* enable t5 1.1 weights

* add xxl config

* consolidate usage of relu vs geglu vs variable activation functions

* update preset param counts and xxl config

* remove commented code

* revert to use_gated_activation

* add kaggle links

* fix comment

* remove xl preset

* update kerashub links to keras links
DavidLandup0 authored Oct 30, 2024
1 parent 316775f commit bd57aed
Showing 5 changed files with 128 additions and 23 deletions.
Binary file added keras_hub/.DS_Store
Binary file not shown.
9 changes: 5 additions & 4 deletions keras_hub/src/models/t5/t5_backbone.py
@@ -42,11 +42,12 @@ class T5Backbone(Backbone):
projections in the multi-head attention layers. Defaults to
hidden_dim / num_heads.
dropout: float. Dropout probability for the Transformer layers.
- activation: activation function (or activation string name). The
-     activation to be used in the inner dense blocks of the
-     Transformer layers. Defaults to `"relu"`.
+ activation: string. The activation function to use in the dense blocks
+     of the Transformer Layers.
use_gated_activation: boolean. Whether to use activation gating in
-     the inner dense blocks of the Transformer layers.
+     the inner dense blocks of the Transformer layers. When used with
+     the GELU activation function, this is referred to as GEGLU
+     (gated GLU) from https://arxiv.org/pdf/2002.05202.
The original T5 architecture didn't use gating, but more
recent versions do. Defaults to `True`.
layer_norm_epsilon: float. Epsilon factor to be used in the
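As context for the new `use_gated_activation` documentation above: a minimal sketch, not the actual KerasHub layer code, of how the gated feed-forward block differs from the original ungated T5 block. The function and weight names here are illustrative; with a GELU activation the gated form is the GEGLU block from the linked paper.

    from keras import ops

    def t5_dense_block(x, wi, wo, activation):
        # use_gated_activation=False (original T5): project up, activate, project down.
        return ops.matmul(activation(ops.matmul(x, wi)), wo)

    def t5_gated_dense_block(x, wi_0, wi_1, wo, activation):
        # use_gated_activation=True (T5 1.1): the activated projection is multiplied
        # elementwise by a second, linear "gate" projection before projecting down.
        hidden = ops.multiply(activation(ops.matmul(x, wi_0)), ops.matmul(x, wi_1))
        return ops.matmul(hidden, wo)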
32 changes: 31 additions & 1 deletion keras_hub/src/models/t5/t5_presets.py
@@ -1,4 +1,4 @@
- """XLM-RoBERTa model preset configurations."""
+ """T5 model preset configurations."""

backbone_presets = {
"t5_small_multi": {
@@ -14,6 +14,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_small_multi/2",
},
"t5_1.1_small": {
"metadata": {
"description": (""),
"params": 60511616,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_small",
},
"t5_base_multi": {
"metadata": {
"description": (
@@ -27,6 +37,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_base_multi/2",
},
"t5_1.1_base": {
"metadata": {
"description": (""),
"params": 247577856,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_base",
},
"t5_large_multi": {
"metadata": {
"description": (
@@ -40,6 +60,16 @@
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_large_multi/2",
},
"t5_1.1_large": {
"metadata": {
"description": (""),
"params": 750251008,
"official_name": "T5 1.1",
"path": "t5",
"model_card": "https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md",
},
"kaggle_handle": "kaggle://keras/t5/keras/t5_1.1_large",
},
"flan_small_multi": {
"metadata": {
"description": (
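Once the new presets above are published under their Kaggle handles (note they currently omit the trailing version segment that the existing presets carry), loading one should work like any other T5 preset. A hedged usage sketch, assuming the t5_1.1_base upload is live:

    import keras_hub

    # Assumes the t5_1.1_base preset defined above has been uploaded to Kaggle.
    backbone = keras_hub.models.T5Backbone.from_preset("t5_1.1_base")
    print(backbone.count_params())  # should match the "params" entry: 247577856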
Binary file added tools/.DS_Store
Binary file not shown.
110 changes: 92 additions & 18 deletions tools/checkpoint_conversion/convert_t5_checkpoints.py
@@ -7,19 +7,98 @@
from absl import app
from absl import flags
from checkpoint_conversion_utils import get_md5_checksum
- from keras import ops

import keras_hub

PRESET_MAP = {
"t5_small_multi": "t5-small",
"t5_base_multi": "t5-base",
"t5_large_multi": "t5-large",
"t5_1.1_small": "google/t5-v1_1-small",
"t5_1.1_base": "google/t5-v1_1-base",
"t5_1.1_large": "google/t5-v1_1-large",
"t5_1.1_xl": "google/t5-v1_1-xl",
"t5_1.1_xxl": "google/t5-v1_1-xxl",
"flan_small_multi": "google/flan-t5-small",
"flan_base_multi": "google/flan-t5-base",
"flan_large_multi": "google/flan-t5-large",
}


PARAM_MAP = {
"t5_1.1_small": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 512,
"intermediate_dim": 1024,
"num_layers": 8,
"num_heads": 6,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_base": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 768,
"intermediate_dim": 2048,
"num_layers": 12,
"num_heads": 12,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_large": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 1024,
"intermediate_dim": 2816,
"num_layers": 24,
"num_heads": 16,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_xl": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 2048,
"intermediate_dim": 5120,
"num_layers": 24,
"num_heads": 32,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
"t5_1.1_xxl": {
"trainable": True,
"vocabulary_size": 32128,
"hidden_dim": 4096,
"intermediate_dim": 10240,
"num_layers": 24,
"num_heads": 64,
"activation": "gelu",
"key_value_dim": 64,
"dropout": 0.1,
"use_gated_activation": True,
"layer_norm_epsilon": 1e-6,
"tie_embedding_weights": False,
},
}


FLAGS = flags.FLAGS

flags.DEFINE_string(
@@ -52,9 +131,7 @@ def extract_vocab(hf_tokenizer):


def convert_checkpoints(hf_model):
- keras_hub_model = keras_hub.models.T5Backbone.from_preset(
-     FLAGS.preset, load_weights=False
- )
+ keras_hub_model = keras_hub.models.T5Backbone(**PARAM_MAP[FLAGS.preset])

hf_wts = hf_model.state_dict()
print("Original weights:")
@@ -308,17 +385,12 @@ def check_output(
keras_hidden_states = keras_out["decoder_sequence_output"]
hf_hidden_states = hf_out.decoder_hidden_states[-1]

- keras_outputs = ops.take_along_axis(
-     keras_hidden_states, ops.where(decoder_padding_mask)
- )
- hf_outputs = ops.take_along_axis(
-     hf_hidden_states, ops.where(decoder_padding_mask)
- )
-
- print("-> KerasHub output:", keras_outputs[0:5])
- print("-> HF output:", hf_outputs[0:5])
+ print("-> KerasHub output:", keras_hidden_states[0:5])
+ print("-> HF output:", hf_hidden_states[0:5])
np.testing.assert_allclose(
-     keras_outputs.detach().numpy(), hf_outputs.detach().numpy(), atol=1e-5
+     keras_hidden_states.numpy(),
+     hf_hidden_states.detach().numpy(),
+     atol=1e-2,
)

if keras_model.tie_embedding_weights:
@@ -333,7 +405,7 @@
print("-> KerasHub logits:", keras_logits[0:5])
print("-> HF logits:", hf_logits[0:5])
np.testing.assert_allclose(
-     keras_logits.detach().numpy(), hf_logits.detach().numpy(), atol=1e-3
+     keras_logits.numpy(), hf_logits.detach().numpy(), atol=1e-1
)


@@ -352,16 +424,18 @@ def main(_):
keras_model = convert_checkpoints(hf_model)

# Save the model.
- model_path = f"./{FLAGS.preset}/model.weights.h5"
+ model_path = f"./{FLAGS.preset}"
+ weight_path = os.path.join(model_path, "model.weights.h5")
print(f"\n-> Save KerasHub model weights to `{model_path}`.")
- keras_model.save_weights(model_path)
+ keras_model.save_to_preset(model_path)
print("-> Print MD5 checksum of the model weights files.")
- print(f"`{model_path}` md5sum: ", get_md5_checksum(model_path))
+ print(f"`{model_path}` md5sum: ", get_md5_checksum(weight_path))
print(f"-> Param count {count_params(keras_model.weights)}")

print("\n-> Convert vocab.")
hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_id)
keras_tokenizer = extract_vocab(hf_tokenizer)
+ keras_tokenizer.save_to_preset(model_path)

check_output(
keras_model,
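Since the script now calls save_to_preset for both the backbone and the tokenizer, the output directory is a complete local preset rather than a bare weights file. A minimal sketch of reloading it for a sanity check, assuming the script was run with a --preset flag set to t5_1.1_base (the flag definitions are truncated above, so the exact flag name is an assumption based on FLAGS.preset):

    import keras_hub

    # "./t5_1.1_base" is the local preset directory written by the conversion script.
    backbone = keras_hub.models.T5Backbone.from_preset("./t5_1.1_base")
    tokenizer = keras_hub.tokenizers.T5Tokenizer.from_preset("./t5_1.1_base")

    token_ids = tokenizer("The quick brown fox jumped over the lazy dog.")
    print(token_ids)
    print(backbone.count_params())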
