Skip to content

Commit

Permalink
Fix test_get_metrics_dict_scaler unit-test
Browse files Browse the repository at this point in the history
  • Loading branch information
koropets committed Jan 24, 2024
1 parent 5e442d0 commit c05475c
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions tests/gordo/builder/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import sklearn.compose
import sklearn.ensemble
from sklearn.base import BaseEstimator
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.multioutput import MultiOutputRegressor
from sklearn.linear_model import LinearRegression
from sklearn import metrics
Expand Down Expand Up @@ -63,6 +63,17 @@ def machine_check(machine: Machine, check_history):
)


class CustomRegressor(BaseEstimator, RegressorMixin):
def __init__(self, multiplier):
self.multiplier = multiplier

def fit(self, X, y):
return self

def predict(self, X):
return X * self.multiplier


@pytest.mark.parametrize("scaler", [None, "sklearn.preprocessing.MinMaxScaler"])
def test_get_metrics_dict_scaler(scaler):
metrics_list = [sklearn.metrics.mean_squared_error]
Expand All @@ -74,12 +85,10 @@ def test_get_metrics_dict_scaler(scaler):
metrics_dict = ModelBuilder.build_metrics_dict(metrics_list, y, scaler=scaler)
metric_func = metrics_dict["mean-squared-error"]

mock_model1 = MagicMock()
mock_model1.predict = lambda _y: _y * [0.8, 1]
mse_feature_one_wrong = metric_func(mock_model1, y, y)
mock_model2 = MagicMock()
mock_model2.predict = lambda _y: _y * [1, 0.8]
mse_feature_two_wrong = metric_func(mock_model2, y, y)
classifier1 = CustomRegressor(np.array([0.8, 1]))
mse_feature_one_wrong = metric_func(classifier1, y, y)
classifier2 = CustomRegressor(np.array([1, 0.8]))
mse_feature_two_wrong = metric_func(classifier2, y, y)

if scaler:
assert np.isclose(mse_feature_one_wrong, mse_feature_two_wrong)
Expand Down

0 comments on commit c05475c

Please sign in to comment.