
Address reviews.
SamanehSaadat committed Apr 11, 2024
1 parent f60c23b commit 988f47e
Showing 5 changed files with 30 additions and 73 deletions.
7 changes: 1 addition & 6 deletions keras_nlp/models/backbone.py
@@ -18,16 +18,13 @@
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import jax_memory_cleanup
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_metadata
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_weights
from keras_nlp.utils.python_utils import classproperty


@@ -219,10 +216,8 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
save_serialized_object(self, preset, config_file=CONFIG_FILE)
save_weights(self, preset, MODEL_WEIGHTS_FILE)
self.save_weights(get_file(preset, MODEL_WEIGHTS_FILE))
save_metadata(self, preset)

def enable_lora(self, rank):
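For context, `Backbone.save_to_preset` after this change reduces to three calls: the Keras-version check and preset-directory creation move into `save_serialized_object` (see the preset_utils.py hunk below), and the standalone `save_weights` helper is replaced by the model's own `save_weights` pointed at a path resolved by `get_file`. A rough sketch of the resulting method body, assuming the imports shown above:

def save_to_preset(self, preset):
    """Save backbone to a local preset directory."""
    # Serializing the config now also runs the Keras-version check and
    # creates the preset directory (see save_serialized_object below).
    save_serialized_object(self, preset, config_file=CONFIG_FILE)
    # Write weights directly via keras.Model.save_weights, at the path
    # get_file resolves inside the preset.
    self.save_weights(get_file(preset, MODEL_WEIGHTS_FILE))
    # Record the preset metadata alongside the config.
    save_metadata(self, preset)
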
22 changes: 4 additions & 18 deletions keras_nlp/models/preprocessor.py
@@ -22,14 +22,12 @@
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import check_file_exists
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_config
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.python_utils import classproperty

@@ -164,19 +162,11 @@ def from_preset(
"Please call `from_preset` on a subclass directly."
)

# For backward compatibility, if preset doesn't have `preprocessor.json`
# `from_preset` creates a preprocessor based on `tokenizer.json`.
try:
# `preprocessor.json` exists.
# TODO: che
if check_file_exists(preset, PREPROCESSOR_CONFIG_FILE):
get_file(preset, PREPROCESSOR_CONFIG_FILE)
tokenizer_config = load_config(preset, TOKENIZER_CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"tokenizer": tokenizer_config}
preprocessor = load_serialized_object(
preset,
PREPROCESSOR_CONFIG_FILE,
config_overrides=config_overrides,
)
for asset in preprocessor.tokenizer.file_assets:
get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
@@ -186,8 +176,7 @@
asset_dir=TOKENIZER_ASSET_DIR,
)
preprocessor.tokenizer.load_assets(tokenizer_asset_dir)
except FileNotFoundError:
# `preprocessor.json` doesn't exist.
else:
tokenizer = load_serialized_object(preset, TOKENIZER_CONFIG_FILE)
for asset in tokenizer.file_assets:
get_file(preset, os.path.join(TOKENIZER_ASSET_DIR, asset))
@@ -207,12 +196,9 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
self.tokenizer.save_to_preset(preset)
save_serialized_object(
self,
preset,
config_file=PREPROCESSOR_CONFIG_FILE,
config_to_skip=["tokenizer"],
)
self.tokenizer.save_to_preset(preset)
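
The net effect on loading: `from_preset` now branches on an explicit file-existence probe rather than a try/except around the whole block, keeping backward compatibility with presets that ship only a `tokenizer.json`. A hypothetical round trip (class and preset names are illustrative of the existing API, not part of this diff):

from keras_nlp.models import BertPreprocessor

# Presets with a preprocessor.json take the if branch above; older
# presets with only a tokenizer.json fall through to the else.
preprocessor = BertPreprocessor.from_preset("bert_base_en_uncased")
x = preprocessor(["The quick brown fox."])

# Saving writes the tokenizer (config plus assets) and a
# preprocessor.json whose config skips the "tokenizer" key.
preprocessor.save_to_preset("./my_preset")
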
43 changes: 13 additions & 30 deletions keras_nlp/models/task.py
@@ -29,15 +29,12 @@
from keras_nlp.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import check_file_exists
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_config
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_weights
from keras_nlp.utils.python_utils import classproperty


@@ -231,7 +228,7 @@ def from_preset(
)

task = None
try:
if check_file_exists(preset, TASK_CONFIG_FILE):
# Task case.
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
if not issubclass(task_preset_cls, cls):
@@ -240,18 +237,11 @@
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)
backbone_config = load_config(preset, CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"backbone": backbone_config}
task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides=config_overrides,
)
task = load_serialized_object(preset, TASK_CONFIG_FILE)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.load_task_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
except FileNotFoundError:
else:
# Backbone case.
backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
if issubclass(backbone_preset_cls, Backbone):
@@ -286,7 +276,7 @@ def from_preset(
config_overrides=config_overrides,
)

try:
if check_file_exists(preset, PREPROCESSOR_CONFIG_FILE):
# Load preprocessor from preset.
preprocessor_preset_cls = check_config_class(
preset, PREPROCESSOR_CONFIG_FILE
@@ -296,7 +286,7 @@
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)
except FileNotFoundError:
else:
# Load tokenizer and create a preprocessor based on that.
preprocessor = cls.preprocessor_cls(
tokenizer=cls.preprocessor_cls.tokenizer_cls.from_preset(preset)
@@ -308,20 +298,20 @@
return task
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

def load_weights(self, filepath):
def load_task_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
"The filename must end in `.weights.h5`. Received: filepath={filepath}"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
keras.saving.save_weights(
keras.saving.load_weights(
self,
filepath,
objects_to_skip=backbone_layer_ids,
)

def save_weights(self, filepath):
def save_task_weights(self, filepath):
"""Save only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
@@ -348,24 +338,17 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
if self.preprocessor is None:
raise ValueError(
"Cannot save `task` to preset: `Preprocessor` is not initialized."
)

save_serialized_object(self, preset, config_file=TASK_CONFIG_FILE)
self.save_task_weights(get_file(preset, TASK_WEIGHTS_FILE))

self.preprocessor.save_to_preset(preset)
self.backbone.save_to_preset(preset)

save_serialized_object(
self,
preset,
config_file=TASK_CONFIG_FILE,
config_to_skip=["preprocessor", "backbone"],
)
save_weights(self, preset, TASK_WEIGHTS_FILE)

@property
def layers(self):
# Remove preprocessor from layers so it does not show up in the summary.
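Renaming `load_weights`/`save_weights` to `load_task_weights`/`save_task_weights` keeps the task from shadowing `keras.Model`'s built-in weight methods, and `from_preset` now restores the task head and the backbone separately. A hypothetical round trip (class, preset, and path names are illustrative):

from keras_nlp.models import BertClassifier

classifier = BertClassifier.from_preset("bert_base_en_uncased", num_classes=2)
# save_to_preset writes task.json plus the task-only weights file, then
# delegates to the preprocessor's and backbone's own save_to_preset.
classifier.save_to_preset("./my_classifier_preset")

# Reloading finds task.json, so load_task_weights() restores the head
# while backbone.load_weights() restores the shared weights.
restored = BertClassifier.from_preset("./my_classifier_preset")
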
7 changes: 1 addition & 6 deletions keras_nlp/tokenizers/tokenizer.py
@@ -20,13 +20,11 @@
from keras_nlp.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_config_class
from keras_nlp.utils.preset_utils import check_keras_version
from keras_nlp.utils.preset_utils import get_asset_dir
from keras_nlp.utils.preset_utils import get_file
from keras_nlp.utils.preset_utils import list_presets
from keras_nlp.utils.preset_utils import list_subclasses
from keras_nlp.utils.preset_utils import load_serialized_object
from keras_nlp.utils.preset_utils import make_preset_dir
from keras_nlp.utils.preset_utils import save_serialized_object
from keras_nlp.utils.preset_utils import save_tokenizer_assets
from keras_nlp.utils.python_utils import classproperty
@@ -145,11 +143,8 @@ def save_to_preset(self, preset):
Args:
preset: The path to the local model preset directory.
"""
check_keras_version()
make_preset_dir(preset)
save_tokenizer_assets(self, preset)
save_serialized_object(self, preset, config_file=TOKENIZER_CONFIG_FILE)
# save_to_preset(self, preset, config_filename=TOKENIZER_CONFIG_FILE)
save_tokenizer_assets(self, preset)

def call(self, inputs, *args, training=None, **kwargs):
return self.tokenize(inputs, *args, **kwargs)
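With the version check and directory creation folded into `save_serialized_object`, the tokenizer's `save_to_preset` is reduced to writing its config and then its file assets. Illustrative usage (class and preset names are examples, not part of this diff):

from keras_nlp.models import BertTokenizer

tokenizer = BertTokenizer.from_preset("bert_base_en_uncased")
# Writes tokenizer.json, then the tokenizer's file assets under the
# preset's tokenizer asset directory.
tokenizer.save_to_preset("./my_tokenizer_preset")
restored = BertTokenizer.from_preset("./my_tokenizer_preset")
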
24 changes: 11 additions & 13 deletions keras_nlp/utils/preset_utils.py
@@ -179,6 +179,14 @@ def get_file(preset, path):
)


def check_file_exists(preset, path):
try:
get_file(preset, path)
except FileNotFoundError:
return False
return True


def get_tokenizer(layer):
"""Get the tokenizer from any KerasNLP model or layer."""
# Avoid circular import.
@@ -227,6 +235,8 @@ def save_serialized_object(
config_file=CONFIG_FILE,
config_to_skip=[],
):
check_keras_version()
make_preset_dir(preset)
config_path = os.path.join(preset, config_file)
config = keras.saving.serialize_keras_object(layer)
config_to_skip += ["compile_config", "build_config"]
@@ -251,13 +261,6 @@ def save_metadata(layer, preset):
metadata_file.write(json.dumps(metadata, indent=4))


def save_weights(layer, preset, weights_file):
if not hasattr(layer, "save_weights"):
raise ValueError(f"`save_weights` hasn't been defined for `{layer}`.")
weights_path = os.path.join(preset, weights_file)
layer.save_weights(weights_path)


def _validate_tokenizer(preset, allow_incomplete=False):
config_path = get_file(preset, TOKENIZER_CONFIG_FILE)
if not os.path.exists(config_path):
@@ -282,12 +285,7 @@ def _validate_tokenizer(preset, allow_incomplete=False):
)
layer = keras.saving.deserialize_keras_object(config)

if not config["assets"]:
raise ValueError(
f"Tokenizer config file {config_path} is missing `asset`."
)

for asset in config["assets"]:
for asset in layer.file_assets:
asset_path = os.path.join(preset, asset)
if not os.path.exists(asset_path):
raise FileNotFoundError(
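The new `check_file_exists` converts `get_file`'s FileNotFoundError into a boolean, which is what lets the call sites above swap try/except control flow for plain if/else. A minimal standalone sketch of the pattern; the stub `get_file` below only mimics the real resolver's local-directory behavior:

import os

def get_file(preset, path):
    # Stub: the real resolver also handles remote preset handles, but
    # like it, raises FileNotFoundError when the file is absent.
    full_path = os.path.join(preset, path)
    if not os.path.exists(full_path):
        raise FileNotFoundError(f"`{path}` doesn't exist in preset `{preset}`.")
    return full_path

def check_file_exists(preset, path):
    try:
        get_file(preset, path)
    except FileNotFoundError:
        return False
    return True

# Usage: branch on optional preset files instead of wrapping whole
# blocks in try/except.
if check_file_exists("./my_preset", "task.json"):
    print("Found a task config; load the full task.")
else:
    print("No task config; fall back to the backbone.")
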
