From dab5dd66756404b12097621828333559005381de Mon Sep 17 00:00:00 2001 From: Samaneh Saadat Date: Mon, 8 Apr 2024 23:27:10 +0000 Subject: [PATCH] Move saving logic to the base classes' from_preset. --- keras_nlp/models/preprocessor.py | 1 + keras_nlp/models/task.py | 5 +++-- keras_nlp/tokenizers/tokenizer.py | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/keras_nlp/models/preprocessor.py b/keras_nlp/models/preprocessor.py index fc6509954e..b85d706b90 100644 --- a/keras_nlp/models/preprocessor.py +++ b/keras_nlp/models/preprocessor.py @@ -17,6 +17,7 @@ from keras_nlp.layers.preprocessing.preprocessing_layer import ( PreprocessingLayer, ) +from keras_nlp.models import Tokenizer from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index dcc9b6786c..f1a14bee61 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -19,6 +19,7 @@ from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras +from keras_nlp.models import Backbone from keras_nlp.models.preprocessor import Preprocessor from keras_nlp.utils.keras_utils import print_msg from keras_nlp.utils.pipeline_model import PipelineModel @@ -236,7 +237,7 @@ def from_preset( raise ValueError( f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`." ) - preprocessor = preprocessor_preset_cls.from_preset(preset) + preprocessor = Preprocessor.from_preset(preset) # Backbone case. backbone_preset_cls = check_config_class(preset) @@ -270,7 +271,7 @@ def from_preset( config_overrides = {} if "dtype" in kwargs: config_overrides["dtype"] = kwargs.pop("dtype") - backbone = backbone_preset_cls.from_preset( + backbone = Backbone.from_preset( preset, load_weights=load_weights, config_overrides=config_overrides, diff --git a/keras_nlp/tokenizers/tokenizer.py b/keras_nlp/tokenizers/tokenizer.py index 55c5f54d77..2fc3150349 100644 --- a/keras_nlp/tokenizers/tokenizer.py +++ b/keras_nlp/tokenizers/tokenizer.py @@ -19,6 +19,8 @@ from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.utils.preset_utils import check_config_class +from keras_nlp.utils.preset_utils import get_asset_dir +from keras_nlp.utils.preset_utils import get_file from keras_nlp.utils.preset_utils import list_presets from keras_nlp.utils.preset_utils import list_subclasses from keras_nlp.utils.preset_utils import load_tokenizer @@ -237,3 +239,4 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from config_file=TOKENIZER_CONFIG_FILE, asset_dir=TOKENIZER_ASSET_DIR, ) + tokenizer.load_assets(asset_dir)