diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py
index 47f2a43d88..9c06a6d8d3 100644
--- a/keras_nlp/models/task.py
+++ b/keras_nlp/models/task.py
@@ -19,6 +19,7 @@
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import config
 from keras_nlp.backend import keras
+from keras_nlp.models.backbone import Backbone
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.utils.keras_utils import print_msg
 from keras_nlp.utils.pipeline_model import PipelineModel
@@ -229,29 +230,31 @@ def from_preset(
                 f"Received: backbone={kwargs['backbone']}."
             )
 
-        if not get_file(preset, PREPROCESSOR_CONFIG_FILE):
-            # Load tokenizer and create a preprocessor based on that.
-            preprocessor = cls.preprocessor_cls(
-                tokenizer=cls.preprocessor_cls.tokenizer_cls.from_preset(preset)
-            )
-        else:
-            # Load preprocessor from preset.
-            preprocessor_preset_cls = check_config_class(
-                preset, PREPROCESSOR_CONFIG_FILE
-            )
-            if not issubclass(preprocessor_preset_cls, Preprocessor):
+        task = None
+        try:
+            # Task case.
+            task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
+            if not issubclass(task_preset_cls, cls):
                 raise ValueError(
-                    f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
+                    f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
+                    f"which is not a subclass of calling class `{cls.__name__}`. Call "
+                    f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
                 )
-            preprocessor = preprocessor_preset_cls.from_preset(preset)
-
-        # Backbone case.
-        backbone_preset_cls = check_config_class(preset)
-        task_preset_cls = check_config_class(preset, TASK_CONFIG_FILE)
-        try:
-            get_file(preset, TASK_CONFIG_FILE)
+            backbone_config = load_config(preset, CONFIG_FILE)
+            # TODO: this is not really an override! It's an addition! Should I rename this?
+            config_overrides = {"backbone": backbone_config}
+            task = load_serialized_object(
+                preset,
+                TASK_CONFIG_FILE,
+                config_overrides=config_overrides,
+            )
+            if load_weights:
+                task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
+                task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
         except FileNotFoundError:
-            if not issubclass(task_preset_cls, cls):
+            # Backbone case.
+            backbone_preset_cls = check_config_class(preset, CONFIG_FILE)
+            if issubclass(backbone_preset_cls, Backbone):
                 if backbone_preset_cls is not cls.backbone_cls:
                     subclasses = list_subclasses(cls)
                     subclasses = tuple(
@@ -282,33 +285,28 @@ def from_preset(
                     load_weights=load_weights,
                     config_overrides=config_overrides,
                 )
-                return cls(
-                    backbone=backbone,
-                    preprocessor=preprocessor,
-                    **kwargs,
-                )
-        # Load task from preset if it exists.
-        if not issubclass(task_preset_cls, cls):
-            raise ValueError(
-                f"`{TASK_CONFIG_FILE}` has type `{task_preset_cls.__name__}` "
-                f"which is not a subclass of calling class `{cls.__name__}`. Call "
-                f"`from_preset` directly on `{task_preset_cls.__name__}` instead."
+        try:
+            # Load preprocessor from preset.
+            preprocessor_preset_cls = check_config_class(
+                preset, PREPROCESSOR_CONFIG_FILE
             )
 
-        backbone_config = load_config(preset, CONFIG_FILE)
-        # TODO: this is not really an override! It's an addition! Should I rename this?
-        config_overrides = {"backbone": backbone_config}
-        task = load_serialized_object(
-            preset,
-            TASK_CONFIG_FILE,
-            config_overrides=config_overrides,
-        )
-        if load_weights:
-            task.load_weights(get_file(preset, TASK_WEIGHTS_FILE))
-            task.backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
-        # `task.preprocessor` is None before this assignment.
-        task.preprocessor = preprocessor
-        return task
+            if not issubclass(preprocessor_preset_cls, Preprocessor):
+                raise ValueError(
+                    f"`{PREPROCESSOR_CONFIG_FILE}` in `{preset}` should be a subclass of `Preprocessor`."
+                )
+            preprocessor = preprocessor_preset_cls.from_preset(preset)
+        except FileNotFoundError:
+            # Load tokenizer and create a preprocessor based on that.
+            preprocessor = cls.preprocessor_cls(
+                tokenizer=cls.preprocessor_cls.tokenizer_cls.from_preset(preset)
+            )
+
+        if task:
+            # `task.preprocessor` is None before this assignment.
+            task.preprocessor = preprocessor
+            return task
+        return cls(backbone=backbone, preprocessor=preprocessor, **kwargs)
 
     def load_weights(self, filepath):
         """Load only the tasks specific weights not in the backbone."""
diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py
index 1b2458d098..18c321e2b0 100644
--- a/keras_nlp/models/task_test.py
+++ b/keras_nlp/models/task_test.py
@@ -15,12 +15,13 @@
 import pytest
 
 from keras_nlp.backend import keras
+from keras_nlp.models import CausalLM
+from keras_nlp.models import Preprocessor
+from keras_nlp.models import Task
+from keras_nlp.models import Tokenizer
 from keras_nlp.models.bert.bert_classifier import BertClassifier
 from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
-from keras_nlp.models.preprocessor import Preprocessor
-from keras_nlp.models.task import Task
 from keras_nlp.tests.test_case import TestCase
-from keras_nlp.tokenizers.tokenizer import Tokenizer
 
 
 class SimpleTokenizer(Tokenizer):
@@ -51,6 +52,15 @@ def test_preset_accessors(self):
         self.assertContainsSubset(bert_presets, all_presets)
         self.assertContainsSubset(gpt2_presets, all_presets)
 
+    @pytest.mark.large
+    def test_from_preset(self):
+        self.assertIsInstance(
+            CausalLM.from_preset("gpt2_base_en", load_weights=False),
+            GPT2CausalLM,
+        )
+        # TODO: Add a classifier task loading test when there is a classifier
+        # with new design available on Kaggle.
+
     @pytest.mark.large
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):