Skip to content

Commit

Permalink
Fix(requirements): use scikeras to bring back keras wrappers
Browse files Browse the repository at this point in the history
* Add `scikeras~=0.13.0` dependency
* Use `scikeras.wrappers.KerasRegressor` instead of removed `tensorflow.keras.wrappers.scikit_learn import KerasRegressor`
  • Loading branch information
RollerKnobster committed Jun 24, 2024
1 parent 4064a8e commit 695c592
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
2 changes: 1 addition & 1 deletion gordo/machine/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tensorflow.keras.models
from tensorflow.keras.models import load_model, save_model
from tensorflow.keras.preprocessing.sequence import pad_sequences, TimeseriesGenerator
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor as BaseWrapper
from scikeras.wrappers import KerasRegressor as BaseWrapper
from tensorflow.keras.callbacks import History
import numpy as np
import pandas as pd
Expand Down
7 changes: 6 additions & 1 deletion requirements/full_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ joblib==1.4.2
jsonpickle==3.2.2
# via azureml-core
keras==3.3.3
# via tensorflow
# via
# scikeras
# tensorflow
kiwisolver==1.4.5
# via matplotlib
knack==0.11.0
Expand Down Expand Up @@ -382,10 +384,13 @@ requests-oauthlib==2.0.0
# via msrest
rich==13.7.1
# via keras
scikeras==0.13.0
# via -r requirements.in
scikit-learn==1.5.0
# via
# gordo-core
# mlflow
# scikeras
scipy==1.13.1
# via
# catboost
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ h5py~=3.1
jinja2~=3.1
python-dateutil~=2.8
tensorflow~=2.16.0
scikeras~=0.13.0
Flask>=2.2.5,<3.0.0
simplejson~=3.17
catboost~=1.2.5
Expand Down
3 changes: 1 addition & 2 deletions tests/gordo/machine/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score, TimeSeriesSplit

from tensorflow.keras.wrappers.scikit_learn import KerasRegressor as BaseWrapper
from scikeras.wrappers import KerasRegressor as BaseWrapper
from tensorflow.keras.callbacks import EarlyStopping

from tests.utils import get_model
Expand Down

0 comments on commit 695c592

Please sign in to comment.