diff --git a/src/skore/item/sklearn_base_estimator_item.py b/src/skore/item/sklearn_base_estimator_item.py index b81959b08..6100fb456 100644 --- a/src/skore/item/sklearn_base_estimator_item.py +++ b/src/skore/item/sklearn_base_estimator_item.py @@ -25,8 +25,9 @@ class SklearnBaseEstimatorItem(Item): def __init__( self, - estimator_skops, - estimator_html_repr, + estimator_html_repr: str, + estimator_skops: bytes, + estimator_skops_untrusted_types: list[str], created_at: str | None = None, updated_at: str | None = None, ): @@ -35,10 +36,12 @@ def __init__( Parameters ---------- - estimator_skops : Any - The skops representation of the scikit-learn estimator. estimator_html_repr : str The HTML representation of the scikit-learn estimator. + estimator_skops : bytes + The skops representation of the scikit-learn estimator. + estimator_skops_untrusted_types : list[str] + The list of untrusted types in the skops representation. created_at : str, optional The creation timestamp in ISO format. updated_at : str, optional @@ -46,8 +49,9 @@ def __init__( """ super().__init__(created_at, updated_at) - self.estimator_skops = estimator_skops self.estimator_html_repr = estimator_html_repr + self.estimator_skops = estimator_skops + self.estimator_skops_untrusted_types = estimator_skops_untrusted_types @cached_property def estimator(self) -> sklearn.base.BaseEstimator: @@ -61,7 +65,9 @@ def estimator(self) -> sklearn.base.BaseEstimator: """ import skops.io - return skops.io.loads(self.estimator_skops) + return skops.io.loads( + self.estimator_skops, trusted=self.estimator_skops_untrusted_types + ) @classmethod def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorItem: @@ -85,9 +91,16 @@ def factory(cls, estimator: sklearn.base.BaseEstimator) -> SklearnBaseEstimatorI if not isinstance(estimator, sklearn.base.BaseEstimator): raise TypeError(f"Type '{estimator.__class__}' is not supported.") + estimator_html_repr = sklearn.utils.estimator_html_repr(estimator) + estimator_skops = skops.io.dumps(estimator) + estimator_skops_untrusted_types = skops.io.get_untrusted_types( + data=estimator_skops + ) + instance = cls( - estimator_skops=skops.io.dumps(estimator), - estimator_html_repr=sklearn.utils.estimator_html_repr(estimator), + estimator_html_repr=estimator_html_repr, + estimator_skops=estimator_skops, + estimator_skops_untrusted_types=estimator_skops_untrusted_types, ) # add estimator as cached property diff --git a/tests/unit/item/test_sklearn_base_estimator_item.py b/tests/unit/item/test_sklearn_base_estimator_item.py index 8003430bd..6e79f1d38 100644 --- a/tests/unit/item/test_sklearn_base_estimator_item.py +++ b/tests/unit/item/test_sklearn_base_estimator_item.py @@ -11,19 +11,26 @@ def monkeypatch_datetime(self, monkeypatch, MockDatetime): @pytest.mark.order(0) def test_factory(self, monkeypatch, mock_nowstr): - monkeypatch.setattr("skops.io.dumps", lambda _: "") - monkeypatch.setattr( - "sklearn.utils.estimator_html_repr", lambda _: "" - ) - estimator = sklearn.svm.SVC() - estimator_skops = "" estimator_html_repr = "" + estimator_skops = "" + estimator_skops_untrusted_types = "" + + monkeypatch.setattr( + "sklearn.utils.estimator_html_repr", + lambda *args, **kwargs: estimator_html_repr, + ) + monkeypatch.setattr("skops.io.dumps", lambda *args, **kwargs: estimator_skops) + monkeypatch.setattr( + "skops.io.get_untrusted_types", + lambda *args, **kwargs: estimator_skops_untrusted_types, + ) item = SklearnBaseEstimatorItem.factory(estimator) - assert item.estimator_skops == estimator_skops assert item.estimator_html_repr == estimator_html_repr + assert item.estimator_skops == estimator_skops + assert item.estimator_skops_untrusted_types == estimator_skops_untrusted_types assert item.created_at == mock_nowstr assert item.updated_at == mock_nowstr @@ -31,12 +38,15 @@ def test_factory(self, monkeypatch, mock_nowstr): def test_estimator(self, mock_nowstr): estimator = sklearn.svm.SVC() estimator_skops = skops.io.dumps(estimator) - estimator_html_repr = "" + estimator_skops_untrusted_types = skops.io.get_untrusted_types( + data=estimator_skops + ) item1 = SklearnBaseEstimatorItem.factory(estimator) item2 = SklearnBaseEstimatorItem( + estimator_html_repr=None, estimator_skops=estimator_skops, - estimator_html_repr=estimator_html_repr, + estimator_skops_untrusted_types=estimator_skops_untrusted_types, created_at=mock_nowstr, updated_at=mock_nowstr, )