Skip to content

Commit

Permalink
Fix a Task bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Apr 10, 2024
1 parent f772e41 commit a465d2b
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 48 deletions.
88 changes: 43 additions & 45 deletions keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import config
from keras_nlp.backend import keras
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.utils.keras_utils import print_msg
from keras_nlp.utils.pipeline_model import PipelineModel
Expand Down Expand Up @@ -229,29 +230,31 @@ def from_preset(
f"Received: backbone={kwargs['backbone']}."
)

if not get_file(preset, PREPROCESSOR_CONFIG_FILE):
# Load tokenizer and create a preprocessor based on that.
preprocessor = cls.preprocessor_cls(
tokenizer=cls.preprocessor_clstokenizer_cls.from_preset(preset)
)
else:
# Load preprocessor from preset.
preprocessor_preset_cls = check_config_class(
preset, PREPROCESSOR_CONFIG_FILE
)
if not issubclass(preprocessor_preset_cls, Preprocessor):
task = None
try:
# Task case.
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
if not issubclass(task_preset_cls, cls):
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)

# Backbone case.
backbone_preset_cls = check_config_class(preset)
task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
try:
get_file(preset, TASK_CONFIG_FILE)
backbone_config = load_config(preset, CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"backbone": backbone_config}
task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides=config_overrides,
)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
except FileNotFoundError:
if not issubclass(task_preset_cls, cls):
# Backbone case.
backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
if issubclass(backbone_preset_cls, Backbone):
if backbone_preset_cls is not cls.backbone_cls:
subclasses = list_subclasses(cls)
subclasses = tuple(
Expand Down Expand Up @@ -282,33 +285,28 @@ def from_preset(
load_weights=load_weights,
config_overrides=config_overrides,
)
return cls(
backbone=backbone,
preprocessor=preprocessor,
**kwargs,
)

# Load task from preset if it exists.
if not issubclass(task_preset_cls, cls):
raise ValueError(
f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
f"which is not a subclass of calling class `{cls.__name__}`. Call "
f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
try:
# Load preprocessor from preset.
preprocessor_preset_cls = check_config_class(
preset, PREPROCESSOR_CONFIG_FILE
)
backbone_config = load_config(preset, CONFIG_FILE)
# TODO: this is not really an override! It's an addition! Should I rename this?
config_overrides = {"backbone": backbone_config}
task = load_serialized_object(
preset,
TASK_CONFIG_FILE,
config_overrides=config_overrides,
)
if load_weights:
task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
# `task.preprocessor` is None before this assignment.
task.preprocessor = preprocessor
return task
if not issubclass(preprocessor_preset_cls, Preprocessor):
raise ValueError(
f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
)
preprocessor = preprocessor_preset_cls.from_preset(preset)
except FileNotFoundError:
# Load tokenizer and create a preprocessor based on that.
preprocessor = cls.preprocessor_cls(
tokenizer=cls.preprocessor_cls.tokenizer_cls.from_preset(preset)
)

if task:
# `task.preprocessor` is None before this assignment.
task.preprocessor = preprocessor
return task
return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)

def load_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
Expand Down
16 changes: 13 additions & 3 deletions keras_nlp/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import pytest

from keras_nlp.backend import keras
from keras_nlp.models import CausalLM
from keras_nlp.models import Preprocessor
from keras_nlp.models import Task
from keras_nlp.models import Tokenizer
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.task import Task
from keras_nlp.tests.test_case import TestCase
from keras_nlp.tokenizers.tokenizer import Tokenizer


class SimpleTokenizer(Tokenizer):
Expand Down Expand Up @@ -51,6 +52,15 @@ def test_preset_accessors(self):
self.assertContainsSubset(bert_presets, all_presets)
self.assertContainsSubset(gpt2_presets, all_presets)

@pytest.mark.large
def test_from_preset(self):
self.assertIsInstance(
CausalLM.from_preset("gpt2_base_en", load_weights=False),
GPT2CausalLM,
)
# TODO: Add a classifier task loading test when there is a classifier
# with new design available on Kaggle.

@pytest.mark.large
def test_from_preset_errors(self):
with self.assertRaises(ValueError):
Expand Down

0 comments on commit a465d2b

Please sign in to comment.