Skip to content

Commit

Permalink
Add Kaggle upload validation tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed Mar 25, 2024
1 parent 8c189ce commit 0248b1f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
2 changes: 1 addition & 1 deletion keras_nlp/utils/preset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _validate_backbone(preset):
weights_path = os.path.join(preset, config["weights"])
if not os.path.exists(weights_path):
raise FileNotFoundError(
f"The weights file doesn't exist in preset directory `{preset}`."
f"The weights file is missing from the preset directory `{preset}`."
)
else:
raise ValueError(
Expand Down
53 changes: 52 additions & 1 deletion keras_nlp/utils/preset_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from absl.testing import parameterized

from keras_nlp import upload_preset
from keras_nlp.models import DistilBertBackbone
from keras_nlp.models import DistilBertTokenizer
from keras_nlp.models.albert.albert_classifier import AlbertClassifier
from keras_nlp.models.backbone import Backbone
from keras_nlp.models.bert.bert_classifier import BertClassifier
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.task import Task
from keras_nlp.tests.test_case import TestCase
from keras_nlp.utils.preset_utils import CONFIG_FILE
from keras_nlp.utils.preset_utils import TOKENIZER_CONFIG_FILE
from keras_nlp.utils.preset_utils import check_preset_class
from keras_nlp.utils.preset_utils import load_from_preset
from keras_nlp.utils.preset_utils import save_to_preset
Expand Down Expand Up @@ -116,4 +120,51 @@ def test_upload_empty_preset(self):
with self.assertRaises(FileNotFoundError):
upload_preset(uri, empty_preset)

# TODO: add more tests to cover various invalid scenarios such as invalid JSON, missing files, etc.
@parameterized.parameters(
    TOKENIZER_CONFIG_FILE,
    CONFIG_FILE,
    "model.weights.h5",
)
@pytest.mark.keras_3_only
@pytest.mark.large
def test_upload_with_missing_file(self, missing_file):
    # Each parameter names one file that is deleted from an otherwise
    # complete local preset before the upload is attempted.

    # Build the backbone and tokenizer from the published preset so that
    # there is something to save locally.
    source_preset = "distil_bert_base_en"
    components = (
        DistilBertBackbone.from_preset(source_preset),
        DistilBertTokenizer.from_preset(source_preset),
    )

    # Write both components into the same directory under the test's
    # temporary folder.
    preset_dir = os.path.join(self.get_temp_dir(), "distil_bert_preset")
    for component in components:
        component.save_to_preset(preset_dir)

    # Remove the file under test to simulate an incomplete preset.
    os.remove(os.path.join(preset_dir, missing_file))

    # The upload must be rejected with an error that mentions the
    # missing file.
    with self.assertRaisesRegex(FileNotFoundError, "is missing"):
        upload_preset("kaggle://test/test/test", preset_dir)

@parameterized.parameters(TOKENIZER_CONFIG_FILE, CONFIG_FILE)
@pytest.mark.keras_3_only
@pytest.mark.large
def test_upload_with_invalid_json(self, json_file):
    # Each parameter names one JSON config file whose content is replaced
    # with non-JSON text before the upload is attempted.

    # Build the backbone and tokenizer from the published preset so that
    # there is something to save locally.
    source_preset = "distil_bert_base_en"
    components = (
        DistilBertBackbone.from_preset(source_preset),
        DistilBertTokenizer.from_preset(source_preset),
    )

    # Write both components into the same directory under the test's
    # temporary folder.
    preset_dir = os.path.join(self.get_temp_dir(), "distil_bert_preset")
    for component in components:
        component.save_to_preset(preset_dir)

    # Overwrite the chosen config file with text that is not valid JSON.
    corrupted_path = os.path.join(preset_dir, json_file)
    with open(corrupted_path, "w") as handle:
        handle.write("Invalid!")

    # The upload must be rejected with an error that reports the invalid
    # JSON file.
    with self.assertRaisesRegex(ValueError, "is an invalid json"):
        upload_preset("kaggle://test/test/test", preset_dir)

0 comments on commit 0248b1f

Please sign in to comment.