diff --git a/skore/src/skore/item/cross_validation_item.py b/skore/src/skore/item/cross_validation_item.py index 205949744..e36692a17 100644 --- a/skore/src/skore/item/cross_validation_item.py +++ b/skore/src/skore/item/cross_validation_item.py @@ -70,19 +70,29 @@ def _metric_favorability( metric: str, ) -> Literal["greater_is_better", "lower_is_better", "unknown"]: greater_is_better_metrics = ( - "r2", - "test_r2", - "roc_auc", - "recall", - "recall_weighted", + "accuracy", + "balanced_accuracy", + "top_k_accuracy", + "average_precision", + "f1", "precision", - "precision_weighted", - "roc_auc_ovr_weighted", + "recall", + "jaccard", + "roc_auc", + "r2", ) - lower_is_better_metrics = ("fit_time", "score_time") - - if metric.endswith("_score") or metric in greater_is_better_metrics: + any_match_greater_is_better = any( + re.search(re.escape(pattern), metric) for pattern in greater_is_better_metrics + ) + if ( + any_match_greater_is_better + # other scikit-learn conventions + or metric.endswith("_score") # score: higher is better + or metric.startswith("neg_") # negative loss: negative of lower is better + ): return "greater_is_better" + + lower_is_better_metrics = ("fit_time", "score_time") if ( metric.endswith("_error") or metric.endswith("_loss") diff --git a/skore/tests/unit/item/test_cross_validation_item.py b/skore/tests/unit/item/test_cross_validation_item.py index 3c3e557bf..ad8f64b49 100644 --- a/skore/tests/unit/item/test_cross_validation_item.py +++ b/skore/tests/unit/item/test_cross_validation_item.py @@ -8,6 +8,7 @@ CrossValidationItem, ItemTypeError, _hash_numpy, + _metric_favorability, ) from skore.sklearn.cross_validation import CrossValidationReporter from skore.sklearn.cross_validation.cross_validation_reporter import ( @@ -86,9 +87,11 @@ def test_factory(self, mock_nowstr, reporter): assert item.cv_results_serialized == {"test_score": [1, 2, 3]} assert item.estimator_info == { "name": reporter.estimator.__class__.__name__, - "params": {} - if isinstance(reporter.estimator, FakeEstimatorNoGetParams) - else {"alpha": {"value": "3", "default": True}}, + "params": ( + {} + if isinstance(reporter.estimator, FakeEstimatorNoGetParams) + else {"alpha": {"value": "3", "default": True}} + ), "module": "tests.unit.item.test_cross_validation_item", } assert item.X_info == { @@ -137,3 +140,53 @@ def test_get_serializable_dict(self, monkeypatch, mock_nowstr): ], } ] + + @pytest.mark.parametrize( + "metric,expected", + [ + # greater_is_better metrics (exact matches) + ("accuracy", "greater_is_better"), + ("balanced_accuracy", "greater_is_better"), + ("top_k_accuracy", "greater_is_better"), + ("average_precision", "greater_is_better"), + ("f1", "greater_is_better"), + ("precision", "greater_is_better"), + ("recall", "greater_is_better"), + ("jaccard", "greater_is_better"), + ("roc_auc", "greater_is_better"), + ("r2", "greater_is_better"), + # greater_is_better metrics (pattern matches) + ("weighted_f1", "greater_is_better"), + ("macro_precision", "greater_is_better"), + ("micro_recall", "greater_is_better"), + # greater_is_better by convention (_score suffix) + ("custom_score", "greater_is_better"), + ("validation_score", "greater_is_better"), + # greater_is_better by convention (neg_ prefix) + ("neg_mean_squared_error", "greater_is_better"), + ("neg_log_loss", "greater_is_better"), + # the same one but without the neg_ prefix + ("mean_squared_error", "lower_is_better"), + ("log_loss", "lower_is_better"), + # lower_is_better metrics (exact matches) + ("fit_time", "lower_is_better"), + ("score_time", "lower_is_better"), + # lower_is_better by convention (suffixes) + ("mean_squared_error", "lower_is_better"), + ("mean_absolute_error", "lower_is_better"), + ("binary_crossentropy_loss", "lower_is_better"), + ("hinge_loss", "lower_is_better"), + ("entropy_deviance", "lower_is_better"), + # unknown metrics + ("custom_metric", "unknown"), + ("undefined", "unknown"), + ("", "unknown"), + ], + ) + def test_metric_favorability(self, metric, expected): + """Test the _metric_favorability function with various metric names. + + Non-regression test for: + https://github.com/probabl-ai/skore/issues/1061 + """ + assert _metric_favorability(metric) == expected