From 320c9fae6f9251cb1e77b96354d00d31c03ed865 Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Wed, 10 Apr 2024 16:55:42 +0000 Subject: [PATCH] Fixes. --- keras_nlp/models/backbone.py | 2 ++ keras_nlp/models/preprocessor.py | 5 +++-- keras_nlp/models/task.py | 2 ++ keras_nlp/tokenizers/tokenizer.py | 5 ++--- keras_nlp/utils/preset_utils.py | 5 ++++- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 962003d0f6..41e67c2717 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -18,6 +18,7 @@ from keras_nlp.utils.preset_utils import CONFIG_FILE from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import check_keras_version from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import jax_memory_cleanup from keras_nlp.utils.preset_utils import list_presets @@ -218,6 +219,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ + check_keras_version() make_preset_dir(preset) save_serialized_object(self, preset, config_file=CONFIG_FILE) save_weights(self, preset, MODEL_WEIGHTS_FILE) diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index b85d706b90..70150bc6e4 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -17,11 +17,11 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) -from keras_nlp.models import Tokenizer from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import check_keras_version from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses @@ -150,7 +150,7 @@ def from_preset( preset, config_file=TOKENIZER_CONFIG_FILE, ) - if tokenizer_preset_cls is not cls: + if tokenizer_preset_cls is not cls.tokenizer_cls: subclasses = list_subclasses(cls) subclasses = tuple( filter( @@ -192,6 +192,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ + check_keras_version() make_preset_dir(preset) self.tokenizer.save_to_preset(preset) save_serialized_object( diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index f1a14bee61..f50c59753c 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -29,6 +29,7 @@ from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import check_keras_version from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses @@ -345,6 +346,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ + check_keras_version() make_preset_dir(preset) if self.preprocessor is None: raise ValueError( diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 2fc3150349..7874e40282 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -19,8 +19,7 @@ from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.utils.preset_utils import check_config_class -from keras_nlp.utils.preset_utils import get_asset_dir -from keras_nlp.utils.preset_utils import get_file +from keras_nlp.utils.preset_utils import check_keras_version from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_tokenizer @@ -143,6 +142,7 @@ def save_to_preset(self, preset): Args: preset: The path to the local model preset directory. """ + check_keras_version() make_preset_dir(preset) save_tokenizer_assets(self, preset) save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE) @@ -239,4 +239,3 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from config_file=TOKENIZER_CONFIG_FILE, asset_dir=TOKENIZER_ASSET_DIR, ) - tokenizer.load_assets(asset_dir) diff --git a/keras_nlp/utils/preset_utils.py b/keras_nlp/utils/preset_utils.py index e63cf83846..dfc95def36 100644 --- a/keras_nlp/utils/preset_utils.py +++ b/keras_nlp/utils/preset_utils.py @@ -201,13 +201,16 @@ def recursive_pop(config, key): recursive_pop(value, key) -def make_preset_dir(preset): +def check_keras_version(): if not backend_config.keras_3(): raise ValueError( "`save_to_preset` requires Keras 3. Run `pip install -U keras` " "upgrade your Keras version, or see https://keras.io/getting_started/ " "for more info on Keras versions and installation." ) + + +def make_preset_dir(preset): os.makedirs(preset, exist_ok=True)