From bd57aeda9e3e13ed514f08d25ad62b9de951642c Mon Sep 17 00:00:00 2001 From: David Landup <60978046+DavidLandup0@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:25:56 +0900 Subject: [PATCH] [T5 1.1] Enable v1.1 Presets (#1948) * enable t5 1.1 weights * add xxl config * consolidate usage of relu vs geglu vs variable activation functions * update preset param counts and xll config * remove commented code * revert to use_gated_activation * add kaggle links * fix comment * remove xl preset * update kerashub links to keras links --- keras_hub/.DS_Store | Bin 0 -> 6148 bytes keras_hub/src/models/t5/t5_backbone.py | 9 +- keras_hub/src/models/t5/t5_presets.py | 32 ++++- tools/.DS_Store | Bin 0 -> 8196 bytes .../convert_t5_checkpoints.py | 110 +++++++++++++++--- 5 files changed, 128 insertions(+), 23 deletions(-) create mode 100644 keras_hub/.DS_Store create mode 100644 tools/.DS_Store diff --git a/keras_hub/.DS_Store b/keras_hub/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4147d62e9eab062b0a418ce3f8fa741f925aacf5 GIT binary patch literal 6148 zcmeHKu};G<5IrXyipo$n#_$RKfJ&hX3u_tJTACt?)G7glRG9J$EUXMnd;u#HpTx|& zJB!2$MPdLUbSK#_Ip5_u&yg=CA~(984~Y6i)J0>Aw$OcIJkG6Uk@su^ojfC_6Uu2p zr!;9ro8vDkz-!mTYE7u58?0Tuev{?dV4N4LQl}mFHb<_CESoKh86vueN5lKO%Dc&=SOS4_GrNO_`-Z z>p&+*0AL4hFakMPbaKLs=Xhr4Pbf~< z;ZG!-T&ifjDPRhO3dDNb=lXxJ`TQRy*_|n13j8YtTqhf6BfOHYt*w{iTAQFv(bzby mQv684g_L5%aw*-a9VB_Bu@60_FMySv@l+4sD@y*U$++dApgiE2btL1lbbg_%;g zo$Ewd@>iCD4*o-gdzhFa=BjQ@|831zv&z*t5A*X6*agYSt7m1zx2B ze18Z~8AFe$qTD)QBn1GLa9IkDkq3xR;4$==DvB${)a=1|SLagltkWc|U{y5WM%~Z3e&5 z=;pZ7`o;X7S>E>_um9X^3-3AO_Tw?<;@&yFdNVCk7vE_gpS?V%ix1$GY0b|2*Wb0v z!Q>{D7FoofHD96B2gj2uWcyn?!Prx`y~A5_RjL(+z7_}3yd)S zexEi%p8u^rO|aDs-1dK75mror#VhbeL)JL||2q5q|KcZR5lsP8U||ZlQlr_}LO7pW yPXy+zIzZh;<-+|^MR9?V-5e KerasHub output:", keras_outputs[0:5]) - print("-> HF output:", hf_outputs[0:5]) + print("-> KerasHub output:", keras_hidden_states[0:5]) + print("-> HF output:", hf_hidden_states[0:5]) np.testing.assert_allclose( - keras_outputs.detach().numpy(), hf_outputs.detach().numpy(), atol=1e-5 + keras_hidden_states.numpy(), + hf_hidden_states.detach().numpy(), + atol=1e-2, ) if keras_model.tie_embedding_weights: @@ -333,7 +405,7 @@ def check_output( print("-> KerasHub logits:", keras_logits[0:5]) print("-> HF logits:", hf_logits[0:5]) np.testing.assert_allclose( - keras_logits.detach().numpy(), hf_logits.detach().numpy(), atol=1e-3 + keras_logits.numpy(), hf_logits.detach().numpy(), atol=1e-1 ) @@ -352,16 +424,18 @@ def main(_): keras_model = convert_checkpoints(hf_model) # Save the model. - model_path = f"./{FLAGS.preset}/model.weights.h5" + model_path = f"./{FLAGS.preset}" + weight_path = os.path.join(model_path, "model.weights.h5") print(f"\n-> Save KerasHub model weights to `{model_path}`.") - keras_model.save_weights(model_path) + keras_model.save_to_preset(model_path) print("-> Print MD5 checksum of the model weights files.") - print(f"`{model_path}` md5sum: ", get_md5_checksum(model_path)) + print(f"`{model_path}` md5sum: ", get_md5_checksum(weight_path)) print(f"-> Param count {count_params(keras_model.weights)}") print("\n-> Convert vocab.") hf_tokenizer = transformers.AutoTokenizer.from_pretrained(hf_id) keras_tokenizer = extract_vocab(hf_tokenizer) + keras_tokenizer.save_to_preset(model_path) check_output( keras_model,