-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from arthurcgusmao/plot-mcc-f1-curve
Plot MCC-F1 curve
- Loading branch information
Showing
9 changed files
with
584 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
from .mcc_f1_curve import mcc_f1_curve | ||
|
||
from ._plot.mcc_f1_curve import plot_mcc_f1_curve | ||
from ._plot.mcc_f1_curve import MCCF1CurveDisplay |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
"""Code contained in this file was extracted directly from the | ||
`sklearn.metrics._plot.base` module, and is included in this package for | ||
backward compatibility with previous sklearn versions.""" | ||
|
||
import numpy as np | ||
|
||
from sklearn.base import is_classifier | ||
|
||
|
||
def _check_classifier_response_method(estimator, response_method): | ||
"""Return prediction method from the response_method | ||
Parameters | ||
---------- | ||
estimator: object | ||
Classifier to check | ||
response_method: {'auto', 'predict_proba', 'decision_function'} | ||
Specifies whether to use :term:`predict_proba` or | ||
:term:`decision_function` as the target response. If set to 'auto', | ||
:term:`predict_proba` is tried first and if it does not exist | ||
:term:`decision_function` is tried next. | ||
Returns | ||
------- | ||
prediction_method: callable | ||
prediction method of estimator | ||
""" | ||
|
||
if response_method not in ("predict_proba", "decision_function", "auto"): | ||
raise ValueError("response_method must be 'predict_proba', " | ||
"'decision_function' or 'auto'") | ||
|
||
error_msg = "response method {} is not defined in {}" | ||
if response_method != "auto": | ||
prediction_method = getattr(estimator, response_method, None) | ||
if prediction_method is None: | ||
raise ValueError(error_msg.format(response_method, | ||
estimator.__class__.__name__)) | ||
else: | ||
predict_proba = getattr(estimator, 'predict_proba', None) | ||
decision_function = getattr(estimator, 'decision_function', None) | ||
prediction_method = predict_proba or decision_function | ||
if prediction_method is None: | ||
raise ValueError(error_msg.format( | ||
"decision_function or predict_proba", | ||
estimator.__class__.__name__)) | ||
|
||
return prediction_method | ||
|
||
|
||
def _get_response(X, estimator, response_method, pos_label=None): | ||
"""Return response and positive label. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
Input values. | ||
estimator : estimator instance | ||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` | ||
in which the last estimator is a classifier. | ||
response_method: {'auto', 'predict_proba', 'decision_function'} | ||
Specifies whether to use :term:`predict_proba` or | ||
:term:`decision_function` as the target response. If set to 'auto', | ||
:term:`predict_proba` is tried first and if it does not exist | ||
:term:`decision_function` is tried next. | ||
pos_label : str or int, default=None | ||
The class considered as the positive class when computing | ||
the metrics. By default, `estimators.classes_[1]` is | ||
considered as the positive class. | ||
Returns | ||
------- | ||
y_pred: ndarray of shape (n_samples,) | ||
Target scores calculated from the provided response_method | ||
and pos_label. | ||
pos_label: str or int | ||
The class considered as the positive class when computing | ||
the metrics. | ||
""" | ||
classification_error = ( | ||
"{} should be a binary classifier".format(estimator.__class__.__name__) | ||
) | ||
|
||
if not is_classifier(estimator): | ||
raise ValueError(classification_error) | ||
|
||
prediction_method = _check_classifier_response_method( | ||
estimator, response_method) | ||
|
||
y_pred = prediction_method(X) | ||
|
||
if pos_label is not None and pos_label not in estimator.classes_: | ||
raise ValueError( | ||
f"The class provided by 'pos_label' is unknown. Got " | ||
f"{pos_label} instead of one of {estimator.classes_}" | ||
) | ||
|
||
if y_pred.ndim != 1: # `predict_proba` | ||
if y_pred.shape[1] != 2: | ||
raise ValueError(classification_error) | ||
if pos_label is None: | ||
pos_label = estimator.classes_[1] | ||
y_pred = y_pred[:, 1] | ||
else: | ||
class_idx = np.flatnonzero(estimator.classes_ == pos_label) | ||
y_pred = y_pred[:, class_idx] | ||
else: | ||
if pos_label is None: | ||
pos_label = estimator.classes_[1] | ||
elif pos_label == estimator.classes_[0]: | ||
y_pred *= -1 | ||
|
||
return y_pred, pos_label |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
try: | ||
from sklearn.metrics._plot.base import _get_response | ||
except ImportError: | ||
# Function was not present in sklearn.metrics._plot.base before c3f2516 | ||
from .base import _get_response | ||
|
||
from sklearn.utils import check_matplotlib_support | ||
|
||
# from .. import mcc_f1_metric, mcc_f1_curve | ||
from .. import mcc_f1_curve | ||
|
||
|
||
class MCCF1CurveDisplay: | ||
"""MCC-F1 Curve visualization. | ||
It is recommend to use :func:`~mcc_f1.plot_mcc_f1_curve` to create a | ||
visualizer. All parameters are stored as attributes. | ||
Read more in scikit-learn's :ref:`User Guide <visualizations>`. | ||
Parameters | ||
---------- | ||
f1 : ndarray | ||
F1-Score. | ||
mcc : ndarray | ||
Matthews Correlation Coefficient. | ||
mcc_f1 : float, default=None | ||
MCC-F1 metric. If None, the mcc_f1 score is not shown. | ||
estimator_name : str, default=None | ||
Name of estimator. If None, the estimator name is not shown. | ||
pos_label : str or int, default=None | ||
The class considered as the positive class when computing the metrics. | ||
By default, `estimators.classes_[1]` is considered as the positive | ||
class. | ||
.. versionadded:: 0.24 | ||
Attributes | ||
---------- | ||
line_ : matplotlib Artist | ||
MCC-F1 Curve. | ||
ax_ : matplotlib Axes | ||
Axes with MCC-F1 Curve. | ||
figure_ : matplotlib Figure | ||
Figure containing the curve. | ||
Examples | ||
-------- | ||
>>> import matplotlib.pyplot as plt # doctest: +SKIP | ||
>>> import numpy as np | ||
>>> from sklearn import metrics | ||
>>> from mcc_f1 import mcc_f1_curve, mcc_f1_metric, MCCF1CurveDisplay | ||
>>> y = np.array([0, 0, 1, 1]) | ||
>>> pred = np.array([0.1, 0.4, 0.35, 0.8]) | ||
>>> f1, mcc, thresholds = mcc_f1_curve(y, pred) | ||
>>> mcc_f1 = mcc_f1_metric(f1, mcc) | ||
>>> display = MCCF1CurveDisplay(f1=f1, mcc=mcc, mcc_f1=mcc_f1,\ | ||
estimator_name='example estimator') | ||
>>> display.plot() # doctest: +SKIP | ||
>>> plt.show() # doctest: +SKIP | ||
""" | ||
|
||
def __init__(self, *, f1, mcc, | ||
mcc_f1=None, estimator_name=None, pos_label=None): | ||
self.estimator_name = estimator_name | ||
self.f1 = f1 | ||
self.mcc = mcc | ||
self.mcc_f1 = mcc_f1 | ||
self.pos_label = pos_label | ||
|
||
def plot(self, ax=None, *, name=None, **kwargs): | ||
"""Plot visualization | ||
Extra keyword arguments will be passed to matplotlib's ``plot``. | ||
Parameters | ||
---------- | ||
ax : matplotlib axes, default=None | ||
Axes object to plot on. If `None`, a new figure and axes is | ||
created. | ||
name : str, default=None | ||
Name of ROC Curve for labeling. If `None`, use the name of the | ||
estimator. | ||
Returns | ||
------- | ||
display : :class:`~sklearn.metrics.plot.RocCurveDisplay` | ||
Object that stores computed values. | ||
""" | ||
check_matplotlib_support('MCCF1CurveDisplay.plot') | ||
|
||
name = self.estimator_name if name is None else name | ||
|
||
line_kwargs = {} | ||
if self.mcc_f1 is not None and name is not None: | ||
line_kwargs["label"] = f"{name} (MCC-F1 = {self.mcc_f1:0.2f})" | ||
elif self.mcc_f1 is not None: | ||
line_kwargs["label"] = f"MCC-F1 = {self.mcc_f1:0.2f}" | ||
elif name is not None: | ||
line_kwargs["label"] = name | ||
|
||
line_kwargs.update(**kwargs) | ||
|
||
import matplotlib.pyplot as plt | ||
from matplotlib.figure import figaspect | ||
|
||
if ax is None: | ||
fig, ax = plt.subplots(figsize=figaspect(1.)) | ||
|
||
self.line_, = ax.plot(self.f1, self.mcc, **line_kwargs) | ||
info_pos_label = (f" (Positive label: {self.pos_label})" | ||
if self.pos_label is not None else "") | ||
|
||
xlabel = "F1-Score" + info_pos_label | ||
ylabel = "MCC" + info_pos_label | ||
ax.set(xlabel=xlabel, ylabel=ylabel, xlim=(0,1), ylim=(0,1)) | ||
|
||
if "label" in line_kwargs: | ||
ax.legend(loc="lower right") | ||
|
||
self.ax_ = ax | ||
self.figure_ = ax.figure | ||
return self | ||
|
||
|
||
def plot_mcc_f1_curve(estimator, X, y, *, sample_weight=None, | ||
response_method="auto", name=None, ax=None, | ||
pos_label=None, **kwargs): | ||
"""Plot MCC-F1 curve. | ||
Extra keyword arguments will be passed to matplotlib's `plot`. | ||
Read more in scikit-learn's :ref:`User Guide <visualizations>`. | ||
Parameters | ||
---------- | ||
estimator : estimator instance | ||
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` | ||
in which the last estimator is a classifier. | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
Input values. | ||
y : array-like of shape (n_samples,) | ||
Target values. | ||
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
response_method : {'predict_proba', 'decision_function', 'auto'} \ | ||
default='auto' | ||
Specifies whether to use :term:`predict_proba` or | ||
:term:`decision_function` as the target response. If set to 'auto', | ||
:term:`predict_proba` is tried first and if it does not exist | ||
:term:`decision_function` is tried next. | ||
name : str, default=None | ||
Name of MCC-F1 Curve for labeling. If `None`, use the name of the | ||
estimator. | ||
ax : matplotlib axes, default=None | ||
Axes object to plot on. If `None`, a new figure and axes is created. | ||
pos_label : str or int, default=None | ||
The class considered as the positive class when computing the metrics. | ||
By default, `estimators.classes_[1]` is considered as the positive | ||
class. | ||
.. versionadded:: 0.24 | ||
Returns | ||
------- | ||
display : :class:`~sklearn.metrics.MCCF1CurveDisplay` | ||
Object that stores computed values. | ||
See Also | ||
-------- | ||
mcc_f1_metric : Compute the MCC-F1 metric | ||
mcc_f1_curve : Compute the MCC-F1 curve | ||
Examples | ||
-------- | ||
>>> import matplotlib.pyplot as plt # doctest: +SKIP | ||
>>> from sklearn import datasets, metrics, model_selection, svm | ||
>>> from mcc_f1 import plot_mcc_f1_curve | ||
>>> X, y = datasets.make_classification(random_state=0) | ||
>>> X_train, X_test, y_train, y_test = model_selection.train_test_split( | ||
... X, y, random_state=0) | ||
>>> clf = svm.SVC(random_state=0) | ||
>>> clf.fit(X_train, y_train) | ||
SVC(random_state=0) | ||
>>> plot_mcc_f1_curve(clf, X_test, y_test) # doctest: +SKIP | ||
>>> plt.show() # doctest: +SKIP | ||
""" | ||
check_matplotlib_support('plot_mcc_f1_curve') | ||
|
||
y_pred, pos_label = _get_response( | ||
X, estimator, response_method, pos_label=pos_label) | ||
|
||
mcc, f1, _ = mcc_f1_curve(y, y_pred, pos_label=pos_label, | ||
sample_weight=sample_weight) | ||
# mcc_f1 = mcc_f1_metric(f1, mcc) | ||
mcc_f1 = None | ||
|
||
name = estimator.__class__.__name__ if name is None else name | ||
|
||
viz = MCCF1CurveDisplay( | ||
f1=f1, | ||
mcc=mcc, | ||
mcc_f1=mcc_f1, | ||
estimator_name=name, | ||
pos_label=pos_label | ||
) | ||
|
||
return viz.plot(ax=ax, name=name, **kwargs) |
Oops, something went wrong.