fix: Use a more flexible approach to check favorability of metrics (#1063)

closes #1061 


![image](https://github.com/user-attachments/assets/98d002c1-874d-4de2-bb26-5fc16838b2f1)

We now use a regular expression to check the score names, which is more flexible
than exact matching. In addition, we take care to test the `neg_` prefix before
the `_error`/`_loss` suffixes: a `neg_` metric is a negated error or loss, so it
follows the "greater is better" convention (e.g. `neg_mean_squared_error`).
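
For reference, the intended behavior on a few names (illustrative doctest-style
session; `_metric_favorability` is the private helper changed below, and the
expected outputs mirror the new test cases):

>>> from skore.item.cross_validation_item import _metric_favorability
>>> _metric_favorability("weighted_f1")  # substring match on "f1"
'greater_is_better'
>>> _metric_favorability("neg_mean_squared_error")  # neg_ checked before _error
'greater_is_better'
>>> _metric_favorability("mean_squared_error")
'lower_is_better'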

---------

Co-authored-by: Auguste Baum <[email protected]>
glemaitre and augustebaum authored Jan 9, 2025
1 parent 3c05371 commit b7ab74a
Showing 2 changed files with 76 additions and 13 deletions.
30 changes: 20 additions & 10 deletions skore/src/skore/item/cross_validation_item.py
@@ -70,19 +70,29 @@ def _metric_favorability(
     metric: str,
 ) -> Literal["greater_is_better", "lower_is_better", "unknown"]:
     greater_is_better_metrics = (
-        "r2",
-        "test_r2",
-        "roc_auc",
-        "recall",
-        "recall_weighted",
+        "accuracy",
+        "balanced_accuracy",
+        "top_k_accuracy",
+        "average_precision",
+        "f1",
+        "precision",
-        "precision_weighted",
-        "roc_auc_ovr_weighted",
+        "recall",
+        "jaccard",
+        "roc_auc",
+        "r2",
     )
-    lower_is_better_metrics = ("fit_time", "score_time")
-
-    if metric.endswith("_score") or metric in greater_is_better_metrics:
+    any_match_greater_is_better = any(
+        re.search(re.escape(pattern), metric) for pattern in greater_is_better_metrics
+    )
+    if (
+        any_match_greater_is_better
+        # other scikit-learn conventions
+        or metric.endswith("_score")  # score: higher is better
+        or metric.startswith("neg_")  # negative loss: negative of lower is better
+    ):
         return "greater_is_better"
 
+    lower_is_better_metrics = ("fit_time", "score_time")
     if (
         metric.endswith("_error")
         or metric.endswith("_loss")
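
A note on the matching step: since every pattern is escaped with `re.escape`,
each check is effectively a literal substring test. A minimal self-contained
sketch of this step (the names here are illustrative, not part of skore's API):

import re

# Illustrative subset of the real tuple defined above.
patterns = ("f1", "roc_auc", "r2")

def matches_any(metric: str) -> bool:
    # re.escape neutralizes regex metacharacters, so re.search(re.escape(p), metric)
    # behaves like a plain substring test: p in metric.
    return any(re.search(re.escape(p), metric) for p in patterns)

print(matches_any("weighted_f1"))  # True: "f1" is a substring
print(matches_any("fit_time"))     # False: no pattern matches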
59 changes: 56 additions & 3 deletions skore/tests/unit/item/test_cross_validation_item.py
@@ -8,6 +8,7 @@
     CrossValidationItem,
     ItemTypeError,
     _hash_numpy,
+    _metric_favorability,
 )
 from skore.sklearn.cross_validation import CrossValidationReporter
 from skore.sklearn.cross_validation.cross_validation_reporter import (
@@ -86,9 +87,11 @@ def test_factory(self, mock_nowstr, reporter):
         assert item.cv_results_serialized == {"test_score": [1, 2, 3]}
         assert item.estimator_info == {
             "name": reporter.estimator.__class__.__name__,
-            "params": {}
-            if isinstance(reporter.estimator, FakeEstimatorNoGetParams)
-            else {"alpha": {"value": "3", "default": True}},
+            "params": (
+                {}
+                if isinstance(reporter.estimator, FakeEstimatorNoGetParams)
+                else {"alpha": {"value": "3", "default": True}}
+            ),
             "module": "tests.unit.item.test_cross_validation_item",
         }
         assert item.X_info == {
@@ -137,3 +140,53 @@ def test_get_serializable_dict(self, monkeypatch, mock_nowstr):
             ],
         }
     ]
+
+    @pytest.mark.parametrize(
+        "metric,expected",
+        [
+            # greater_is_better metrics (exact matches)
+            ("accuracy", "greater_is_better"),
+            ("balanced_accuracy", "greater_is_better"),
+            ("top_k_accuracy", "greater_is_better"),
+            ("average_precision", "greater_is_better"),
+            ("f1", "greater_is_better"),
+            ("precision", "greater_is_better"),
+            ("recall", "greater_is_better"),
+            ("jaccard", "greater_is_better"),
+            ("roc_auc", "greater_is_better"),
+            ("r2", "greater_is_better"),
+            # greater_is_better metrics (pattern matches)
+            ("weighted_f1", "greater_is_better"),
+            ("macro_precision", "greater_is_better"),
+            ("micro_recall", "greater_is_better"),
+            # greater_is_better by convention (_score suffix)
+            ("custom_score", "greater_is_better"),
+            ("validation_score", "greater_is_better"),
+            # greater_is_better by convention (neg_ prefix)
+            ("neg_mean_squared_error", "greater_is_better"),
+            ("neg_log_loss", "greater_is_better"),
+            # the same ones but without the neg_ prefix
+            ("mean_squared_error", "lower_is_better"),
+            ("log_loss", "lower_is_better"),
+            # lower_is_better metrics (exact matches)
+            ("fit_time", "lower_is_better"),
+            ("score_time", "lower_is_better"),
+            # lower_is_better by convention (suffixes)
+            ("mean_squared_error", "lower_is_better"),
+            ("mean_absolute_error", "lower_is_better"),
+            ("binary_crossentropy_loss", "lower_is_better"),
+            ("hinge_loss", "lower_is_better"),
+            ("entropy_deviance", "lower_is_better"),
+            # unknown metrics
+            ("custom_metric", "unknown"),
+            ("undefined", "unknown"),
+            ("", "unknown"),
+        ],
+    )
+    def test_metric_favorability(self, metric, expected):
+        """Test the _metric_favorability function with various metric names.
+        Non-regression test for:
+        https://github.com/probabl-ai/skore/issues/1061
+        """
+        assert _metric_favorability(metric) == expected

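The new cases can be run locally with, e.g.,
`pytest skore/tests/unit/item/test_cross_validation_item.py -k test_metric_favorability`
(assuming a development install of skore).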