Skip to content

Commit

Permalink
For backward compatibility, allow loading preprocessor if preprocessor.json doesn't exist.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 11, 2024
1 parent a465d2b commit 7b95077
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 81 deletions.
74 changes: 31 additions & 43 deletions keras_nlp/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
Expand All @@ -20,14 +21,13 @@
from keras_nlp.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_config
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.python_utils import classproperty
Expand Down Expand Up @@ -138,51 +138,39 @@ def from_preset(
"`keras_nlp.models.BertPreprocessor.from_preset()`."
)

# For backward compatibility, if preset doesn't have `preprocessor.json`
# `from_preset` creates a preprocessor based on `tokenizer.json`.
try:
# `preprocessor.json` exists.
get_file(preset, PREPROCESSOR_CONFIG_FILE)
except FileNotFoundError:
raise FileNotFoundError(
f"Preset directory `{preset}` should contain preprocessor "
"config `{PREPROCESSOR_CONFIG_FILE}`."
tokenizer_config = load_config(preset, TOKENIZER_CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"tokenizer": tokenizer_config}
preprocessor = load_serialized_object(
preset,
PREPROCESSOR_CONFIG_FILE,
config_overrides=config_overrides,
)

tokenizer_preset_cls = check_config_class(
preset,
config_file=TOKENIZER_CONFIG_FILE,
)
if tokenizer_preset_cls is not cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.tokenizer_cls == tokenizer_preset_cls,
subclasses,
)
for asset in preprocessor.tokenizer.file_assets:
get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
tokenizer_asset_dir = get_asset_dir(
preset,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{tokenizer_preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
raise ValueError(
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)
tokenizer_config = load_config(preset, TOKENIZER_CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"tokenizer": tokenizer_config}
preprocessor = load_serialized_object(
preset,
PREPROCESSOR_CONFIG_FILE,
config_overrides=config_overrides,
)
preprocessor.tokenizer = load_tokenizer(
preset,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
preprocessor.tokenizer.load_assets(tokenizer_asset_dir)
except FileNotFoundError:
# `preprocessor.json` doesn't exist.
tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)
for asset in tokenizer.file_assets:
get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
tokenizer_asset_dir = get_asset_dir(
preset,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(tokenizer_asset_dir)
preprocessor = cls(tokenizer=tokenizer)

return preprocessor

Expand Down
34 changes: 11 additions & 23 deletions keras_nlp/tokenizers/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.layers.preprocessing.preprocessing_layer import (
Expand All @@ -20,9 +21,11 @@
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_tokenizer
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_tokenizer_assets
Expand Down Expand Up @@ -213,29 +216,14 @@ class like `keras_nlp.models.Tokenizer.from_preset()`, or from
f"a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{preset_cls.__name__}` instead."
)
if preset_cls is not cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
filter(
lambda x: x.tokenizer_cls == preset_cls,
subclasses,
)
)
if len(subclasses) == 0:
raise ValueError(
f"No registered subclass of `{cls.__name__}` can load "
f"a `{preset_cls.__name__}`."
)
if len(subclasses) > 1:
names = ", ".join(f"`{x.__name__}`" for x in subclasses)
raise ValueError(
f"Ambiguous call to `{cls.__name__}.from_preset()`. "
f"Found multiple possible subclasses {names}. "
"Please call `from_preset` on a subclass directly."
)

return load_tokenizer(

tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)
for asset in tokenizer.file_assets:
get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
tokenizer_asset_dir = get_asset_dir(
preset,
config_file=TOKENIZER_CONFIG_FILE,
asset_dir=TOKENIZER_ASSET_DIR,
)
tokenizer.load_assets(tokenizer_asset_dir)
return tokenizer
15 changes: 0 additions & 15 deletions keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,18 +436,3 @@ def jax_memory_cleanup(layer):
for weight in layer.weights:
if getattr(weight, "_value", None) is not None:
weight._value.delete()


def load_tokenizer(
    preset, config_file=TOKENIZER_CONFIG_FILE, asset_dir=TOKENIZER_ASSET_DIR
):
    """Deserialize a tokenizer from a preset and load its file assets.

    Reconstructs the tokenizer object from `config_file` inside `preset`,
    fetches each asset the tokenizer declares, then points the tokenizer
    at the local asset directory so it can load them.
    """
    tokenizer = load_serialized_object(preset, config_file)
    # Ensure every declared asset file is present locally before loading.
    for asset_name in tokenizer.file_assets:
        get_file(preset, os.path.join(asset_dir, asset_name))
    local_asset_dir = get_asset_dir(preset, config_file, asset_dir)
    tokenizer.load_assets(local_asset_dir)
    return tokenizer

0 comments on commit 7b95077

Please sign in to comment.