Skip to content

Commit

Permalink
Fix: remove redundant BaseWrapper alias for KerasRegressor
Browse files Browse the repository at this point in the history
  • Loading branch information
RollerKnobster committed Jun 24, 2024
1 parent bede872 commit 06de56f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions gordo/machine/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tensorflow.keras.models
from tensorflow.keras.models import load_model, save_model
from tensorflow.keras.preprocessing.sequence import pad_sequences, TimeseriesGenerator
from scikeras.wrappers import KerasRegressor as BaseWrapper
from scikeras.wrappers import KerasRegressor
import numpy as np
import pandas as pd
import xarray as xr
Expand All @@ -34,7 +34,7 @@
logger = logging.getLogger(__name__)


class KerasBaseEstimator(BaseWrapper, GordoBase):
class KerasBaseEstimator(KerasRegressor, GordoBase):
supported_fit_args = [
"batch_size",
"epochs",
Expand Down Expand Up @@ -269,7 +269,7 @@ def fit(
y = y.values
kwargs.setdefault("verbose", 0)
history = super().fit(X, y, sample_weight=None, **kwargs)
if isinstance(history, BaseWrapper):
if isinstance(history, KerasRegressor):
self.history = history.history_
return self

Expand Down

0 comments on commit 06de56f

Please sign in to comment.