Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH - Gets SciKeras script working #394

Open
wants to merge 57 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 46 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
925c960
Gets SciKeras script working
lazarust Oct 10, 2023
3500592
Fixes test_metainfo
lazarust Oct 10, 2023
492a1ca
Updates changes.rst
lazarust Oct 11, 2023
4c91a07
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 11, 2023
101b90c
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 13, 2023
08f5ba7
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 14, 2023
0491a7b
Merge branch 'main' into enh-get-scikeras-working
lazarust Oct 29, 2023
f1b93fe
Update changes.rst
lazarust Oct 29, 2023
bbe6b34
Adds test
lazarust Nov 14, 2023
3b274b4
Loads dumped model in test and checks output
lazarust Nov 18, 2023
e7ab34e
Refactor test_external.py to use dumps instead of dump
lazarust Nov 21, 2023
a1a92cc
Add TensorFlow as a dependent package
lazarust Nov 23, 2023
7b0f21e
WIP Still running into a recursion error
lazarust Nov 24, 2023
3397b17
Refactor imports and update test method
lazarust Nov 24, 2023
49a16f0
WIP Fix get_state function to include module and class
lazarust Nov 29, 2023
d4469f9
Merge branch 'main' into enh-get-scikeras-working
lazarust Dec 5, 2023
3b55471
Merge branch 'main' into enh-get-scikeras-working
lazarust Dec 14, 2023
b51ca0e
Reverts changes from previous implementation
lazarust Jan 4, 2024
dd8d6e1
Merge branch 'main' into enh-get-scikeras-working
lazarust Jan 11, 2024
f68eea6
WIP Fix saving of scikeras models in zip file
lazarust Jan 12, 2024
f304dc7
Fixes lines I missed when reverting previous commits
lazarust Jan 12, 2024
c131ebd
Switch to saving as a `.keras` file
lazarust Jan 17, 2024
846e72e
Updates to use TempFile to save the model
lazarust Jan 24, 2024
7f8593a
Fixes typo in comment
lazarust Jan 24, 2024
208839b
Merge branch 'main' into enh-get-scikeras-working
lazarust Feb 12, 2024
28e9f15
Merge branch 'main' into enh-get-scikeras-working
lazarust Feb 19, 2024
d1d260e
Merge branch 'main' into enh-get-scikeras-working
lazarust Mar 31, 2024
ac11b46
Add support for saving and loading scikeras models by adding _scikera…
lazarust Apr 1, 2024
a9b9dbf
Removes comment that isn't necessary now
lazarust Apr 1, 2024
e5eb579
Adds SciKerasNode
lazarust Apr 1, 2024
0eca68a
Update Keras import to TensorFlow and fix model loading
lazarust Apr 1, 2024
30f8993
Update dependencies for TensorFlow to be included in docs
lazarust Apr 1, 2024
1b1cdff
Adds scikeras to docs dependencies
lazarust Apr 1, 2024
5033112
Updates scikeras to version 0.13
lazarust Apr 17, 2024
6a8e821
Removes default trusted types
lazarust Apr 17, 2024
f119713
Merge branch 'main' into enh-get-scikeras-working
lazarust Apr 24, 2024
ab530f3
Add importing __future__ annotations in _scikeras.py
lazarust Apr 24, 2024
0d8efca
Update TensorFlow version to 2.16.0 in _min_dependencies.py
lazarust Apr 24, 2024
07e0d5a
Update scikeras version to 0.12.0 in _min_dependencies.py
lazarust Apr 24, 2024
cc08530
Update TensorFlow version to 2.13.0 in _min_dependencies.py
lazarust Apr 24, 2024
cc2aead
Update TensorFlow version to 2.12.0 in _min_dependencies.py
lazarust Apr 24, 2024
c96a50a
Merge branch 'main' into enh-get-scikeras-working
lazarust May 3, 2024
91983e8
Moves changes to the correct version
lazarust May 5, 2024
2f3ae7a
Merge branch 'main' into enh-get-scikeras-working
lazarust May 14, 2024
3739f1f
Ignores deprecation warning from protobuf
lazarust May 25, 2024
ce11bf0
Fixes deprecation warning from matplotlib
lazarust May 25, 2024
4c47aaf
Merge branch 'main' into enh-get-scikeras-working
lazarust Jun 9, 2024
12e2108
Merge branch 'main' into enh-get-scikeras-working
lazarust Jul 2, 2024
d677476
Fixes making scikears a hard dependency
lazarust Jul 2, 2024
6d10f71
Adds test for error on untrusted types
lazarust Jul 2, 2024
e307560
Cleans up unneeded ()
lazarust Jul 2, 2024
bb82961
Merge branch 'main' into enh-get-scikeras-working
lazarust Jul 14, 2024
76341fd
Merge branch 'main' into enh-get-scikeras-working
lazarust Aug 1, 2024
fa6b208
use TF directly
adrinjalali Aug 28, 2024
83891ed
Merge remote-tracking branch 'upstream/main' into enh-get-scikeras-wo…
adrinjalali Aug 28, 2024
e9b2dd0
move changelog
adrinjalali Aug 28, 2024
2d92168
add missing file
adrinjalali Aug 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ v0.10
- Removes Pythn 3.8 support and adds Python 3.12 Support :pr:`418` by :user:`Thomas Lazarus <lazarust>`.
- Removes a shortcut to add `sklearn-intelex` as a not dependency.
:pr:`420` by :user:`Thomas Lazarus < lazarust > `.
- Fix dumping Scikeras model failing because of maximum recursion depth. :pr:`388`
by :user:`Thomas Lazarus <lazarust>`.

v0.9
----
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ filterwarnings = [
"ignore:DataFrameGroupBy.apply operated on the grouping columns.:DeprecationWarning",
# Ignore Pandas 2.2 warning on PyArrow. It might be reverted in a later release.
"ignore:\\s*Pyarrow will become a required dependency of pandas.*:DeprecationWarning",
# Ignore Google protobuf deprecation warnings. This should be reverted once tensorflow supports protobuf 5.0+
"ignore:Type google._upb._message.MessageMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.:DeprecationWarning:importlib.*:",
"ignore:Type google._upb._message.ScalarMapContainer uses PyType_Spec with a metaclass that has custom tp_new. This is deprecated and will no longer be allowed in Python 3.14.:DeprecationWarning:importlib.*:",
]
markers = [
"network: marks tests as requiring internet (deselect with '-m \"not network\"')",
Expand Down
2 changes: 2 additions & 0 deletions skops/_min_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
"catboost": ("1.0", "tests", None),
"fairlearn": ("0.7.0", "docs, tests", None),
"rich": ("12", "tests, rich", None),
"scikeras": ("0.12.0", "docs, tests", None),
"tensorflow": ("2.12.0", "docs, tests", None),
}


Expand Down
2 changes: 1 addition & 1 deletion skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ def add_permutation_importances(
_, ax = plt.subplots()
ax.boxplot(
x=permutation_importances.importances[sorted_importances_idx].T,
labels=columns[sorted_importances_idx],
tick_labels=columns[sorted_importances_idx],
vert=False,
)
ax.set_title(plot_name)
Expand Down
9 changes: 8 additions & 1 deletion skops/io/_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
# We load the dispatch functions from the corresponding modules and register
# them. Old protocols are found in the 'old/' directory, with the protocol
# version appended to the corresponding module name.
modules = ["._general", "._numpy", "._scipy", "._sklearn", "._quantile_forest"]
modules = [
"._general",
"._numpy",
"._scikeras",
"._scipy",
"._sklearn",
"._quantile_forest",
]
modules.extend([".old._general_v0", ".old._numpy_v0"])
for module_name in modules:
# register exposed functions for get_state and get_tree
Expand Down
63 changes: 63 additions & 0 deletions skops/io/_scikeras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import io
import os
import tempfile
from typing import Sequence, Type

import tensorflow as tf
from scikeras.wrappers import KerasClassifier, KerasRegressor
lazarust marked this conversation as resolved.
Show resolved Hide resolved

from ._audit import Node
from ._protocol import PROTOCOL
from ._utils import Any, LoadContext, SaveContext, get_module


def scikeras_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
res = {
"__class__": obj.__class__.__name__,
"__module__": get_module(type(obj)),
"__loader__": "SciKerasNode",
}

obj_id = save_context.memoize(obj)
f_name = f"{obj_id}.keras"

with tempfile.TemporaryDirectory() as temp_dir:
file_name = os.path.join(temp_dir, "model.keras")
obj.model.save(file_name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so we only save the model attribute? This sounds odd.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding of https://keras.io/guides/serialization_and_saving/, it seems that Keras is compressing all the pieces of the models into the .keras file. Should I change the name of the file to make it more clear?

save_context.zip_file.write(file_name, f_name)

res.update(type="scikeras", file=f_name)
return res


class SciKerasNode(Node):
def __init__(
self,
state: dict[str, Any],
load_context: LoadContext,
trusted: bool | Sequence[str] = False,
) -> None:
super().__init__(state, load_context, trusted)
self.trusted = self._get_trusted(trusted, default=[])

self.children = {"content": io.BytesIO(load_context.src.read(state["file"]))}

def _construct(self):
with tempfile.TemporaryDirectory() as temp_dir:
file_path = os.path.join(temp_dir, "model.keras")
with open(file_path, "wb") as f:
f.write(self.children["content"].getbuffer())
model = tf.keras.models.load_model(file_path, compile=False)
return model


GET_STATE_DISPATCH_FUNCTIONS = [
(KerasClassifier, scikeras_get_state),
(KerasRegressor, scikeras_get_state),
]

NODE_TYPE_MAPPING: dict[tuple[str, int], Type[Node]] = {
("SciKerasNode", PROTOCOL): SciKerasNode
}
49 changes: 48 additions & 1 deletion skops/io/tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
with a range of hyperparameters.

"""

from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -427,3 +426,51 @@ def test_quantile_forest(self, quantile_forest, regr_data, trusted, tree_method)
assert_method_outputs_equal(estimator, loaded, X)

visualize(dumped, trusted=trusted)


class TestSciKeras:
"""Tests for SciKerasRegressor and SciKerasClassifier"""

@pytest.fixture
def trusted(self):
return ["scikeras.wrappers.KerasClassifier"]

@pytest.fixture(autouse=True)
def capture_stdout(self):
# Mock print and rich.print so that running these tests with pytest -s
# does not spam stdout. Other, more common methods of suppressing
# printing to stdout don't seem to work, perhaps because of pytest.
with patch("builtins.print", Mock()), patch("rich.print", Mock()):
yield

@pytest.fixture(autouse=True)
def tensorflow(self):
tensorflow = pytest.importorskip("tensorflow")
return tensorflow

def test_dumping_model(self, tensorflow, trusted):
# This simplifies the basic usage tutorial from https://adriangb.com/scikeras/stable/notebooks/Basic_Usage.html

n_features_in_ = 20
model = tensorflow.keras.models.Sequential()
model.add(tensorflow.keras.layers.Input(shape=(n_features_in_,)))
model.add(tensorflow.keras.layers.Dense(1, activation="sigmoid"))

from scikeras.wrappers import KerasClassifier

clf = KerasClassifier(model=model, loss="binary_crossentropy")

X, y = make_classification(1000, 20, n_informative=10, random_state=0)
clf.fit(X, y)

predictions = clf.predict(X)

dumped = dumps(clf)

# Loads returns the Keras model so we need to initialize it as a SciKeras model
new_clf_model = loads(dumped, trusted=trusted)

clf_new = KerasClassifier(new_clf_model)
clf_new.initialize(X, y)
new_preidctions = clf_new.predict(X)
assert all(new_preidctions == predictions)
Loading