diff --git a/docs/changes.rst b/docs/changes.rst index f5e13aeb..6bf5219e 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -13,6 +13,8 @@ v0.11 ----- - Correctly restore ``default_factory`` when saving and loading a ``defaultdict``. :pr:`433` by `Adrin Jalali`_. +- Fix dumping Scikeras model failing because of maximum recursion depth. :pr:`388` + by :user:`Thomas Lazarus <lazarust>`. v0.10 ----- diff --git a/pyproject.toml b/pyproject.toml index f697b327..ce9dde54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,9 @@ filterwarnings = [ "ignore:DataFrameGroupBy.apply operated on the grouping columns.:DeprecationWarning", # Ignore Pandas 2.2 warning on PyArrow. It might be reverted in a later release. "ignore:\\s*Pyarrow will become a required dependency of pandas.*:DeprecationWarning", + # Ignore Google protobuf deprecation warnings. This should be reverted once tensorflow supports protobuf 5.0+ + "ignore:Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.:DeprecationWarning:importlib.*:", + "ignore:Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new. 
This is deprecated and will no longer be allowed in Python 3.14.:DeprecationWarning:importlib.*:", ] markers = [ "network: marks tests as requiring internet (deselect with '-m \"not network\"')", diff --git a/skops/_min_dependencies.py b/skops/_min_dependencies.py index 58f12c4c..ca5503c0 100644 --- a/skops/_min_dependencies.py +++ b/skops/_min_dependencies.py @@ -36,6 +36,8 @@ "catboost": ("1.0", "tests", None), "fairlearn": ("0.7.0", "docs, tests", None), "rich": ("12", "tests, rich", None), + "scikeras": ("0.12.0", "docs, tests", None), + "tensorflow": ("2.12.0", "docs, tests", None), } diff --git a/skops/io/_keras.py b/skops/io/_keras.py new file mode 100644 index 00000000..01e0dd12 --- /dev/null +++ b/skops/io/_keras.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import io +import os +import tempfile +from typing import Sequence, Type + +from ._audit import Node +from ._protocol import PROTOCOL +from ._utils import Any, LoadContext, SaveContext, get_module + +try: + from tensorflow.keras.models import Model, Sequential, load_model, save_model + + tf_present = True +except ImportError: + tf_present = False + + +def keras_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]: + res = { + "__class__": obj.__class__.__name__, + "__module__": get_module(type(obj)), + "__loader__": "KerasNode", + } + + # Memoize the object and then check if its file name (containing + # the object id) already exists. If it does, there is no need to + # save the object again. Memoization is necessary since for + # ephemeral objects, the same id might otherwise be reused. 
+ obj_id = save_context.memoize(obj) + f_name = f"{obj_id}.keras" + if f_name in save_context.zip_file.namelist(): + return res + + with tempfile.TemporaryDirectory() as temp_dir: + file_name = os.path.join(temp_dir, "model.keras") + save_model(obj, file_name) + save_context.zip_file.write(file_name, f_name) + res.update(file=f_name) + return res + + +class KerasNode(Node): + def __init__( + self, + state: dict[str, Any], + load_context: LoadContext, + trusted: Sequence[str] | None = None, + ) -> None: + if not tf_present: + raise ImportError( + "`tf.keras` is missing and needs to be installed in order to load this" + " object." + ) + super().__init__(state, load_context, trusted) + self.trusted = self._get_trusted(trusted, default=[]) + + self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))} + + def _construct(self): + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, "model.keras") + with open(file_path, "wb") as f: + f.write(self.children["content"].getbuffer()) + model = load_model(file_path, compile=False, safe_mode=True) + return model + + +if tf_present: + GET_STATE_DISPATCH_FUNCTIONS = [ + (Model, keras_get_state), + (Sequential, keras_get_state), + ] + +NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = { + ("KerasNode", PROTOCOL): KerasNode +} diff --git a/skops/io/_persist.py b/skops/io/_persist.py index aaed469c..8e620f0e 100644 --- a/skops/io/_persist.py +++ b/skops/io/_persist.py @@ -15,7 +15,14 @@ # We load the dispatch functions from the corresponding modules and register # them. Old protocols are found in the 'old/' directory, with the protocol # version appended to the corresponding module name. 
-modules = ["._general", "._numpy", "._scipy", "._sklearn", "._quantile_forest"] +modules = [ + "._general", + "._numpy", + "._keras", + "._scipy", + "._sklearn", + "._quantile_forest", +] modules.extend([".old._general_v0", ".old._numpy_v0", ".old._numpy_v1"]) for module_name in modules: # register exposed functions for get_state and get_tree diff --git a/skops/io/tests/test_external.py b/skops/io/tests/test_external.py index d9fa7916..72d8d5ae 100644 --- a/skops/io/tests/test_external.py +++ b/skops/io/tests/test_external.py @@ -18,6 +18,7 @@ from sklearn.datasets import make_classification, make_regression from skops.io import dumps, loads, visualize +from skops.io.exceptions import UntrustedTypesFoundException from skops.io.tests._utils import assert_method_outputs_equal, assert_params_equal # Default settings for generated data @@ -431,3 +432,76 @@ def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method) assert_method_outputs_equal(estimator, loaded, X) visualize(dumped, trusted=trusted) + + +class TestSciKeras: + """Tests for SciKerasRegressor and SciKerasClassifier""" + + @pytest.fixture + def trusted(self): + return [ + "collections.defaultdict", + "keras.src.models.sequential.Sequential", + "numpy.dtype", + "scikeras.utils.transformers.ClassifierLabelEncoder", + "scikeras.utils.transformers.TargetReshaper", + "scikeras.wrappers.KerasClassifier", + ] + + @pytest.fixture(autouse=True) + def capture_stdout(self): + # Mock print and rich.print so that running these tests with pytest -s + # does not spam stdout. Other, more common methods of suppressing + # printing to stdout don't seem to work, perhaps because of pytest. 
+ with patch("builtins.print", Mock()), patch("rich.print", Mock()): + yield + + @pytest.fixture(autouse=True) + def tensorflow(self): + tensorflow = pytest.importorskip("tensorflow") + return tensorflow + + def test_dumping_model(self, tensorflow, trusted): + # This simplifies the basic usage tutorial from https://adriangb.com/scikeras/stable/notebooks/Basic_Usage.html + + n_features_in_ = 20 + model = tensorflow.keras.models.Sequential() + model.add(tensorflow.keras.layers.Input(shape=(n_features_in_,))) + model.add(tensorflow.keras.layers.Dense(1, activation="sigmoid")) + + from scikeras.wrappers import KerasClassifier + + clf = KerasClassifier(model=model, loss="binary_crossentropy") + + X, y = make_classification(1000, 20, n_informative=10, random_state=0) + clf.fit(X, y) + + predictions = clf.predict(X) + + dumped = dumps(clf) + + # Loads returns the Keras model so we need to initialize it as a SciKeras model + clf_new = loads(dumped, trusted=trusted) + new_predictions = clf_new.predict(X) + assert all(new_predictions == predictions) + + def test_dumping_untrusted_model(self, tensorflow): + # This simplifies the basic usage tutorial from https://adriangb.com/scikeras/stable/notebooks/Basic_Usage.html + + n_features_in_ = 20 + model = tensorflow.keras.models.Sequential() + model.add(tensorflow.keras.layers.Input(shape=(n_features_in_,))) + model.add(tensorflow.keras.layers.Dense(1, activation="sigmoid")) + + from scikeras.wrappers import KerasClassifier + + clf = KerasClassifier(model=model, loss="binary_crossentropy") + + X, y = make_classification(1000, 20, n_informative=10, random_state=0) + clf.fit(X, y) + + dumped = dumps(clf) + + # Tries to load the dumped model; this raises an UntrustedTypesFoundException + with pytest.raises(UntrustedTypesFoundException): + loads(dumped)