-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add a classification report example * add an example for multiclass * finish the example * Use signature instead of poping kwargs * Solve the issue with the doc * Correct mispealing * Add readme for dataset examples
- Loading branch information
Showing
7 changed files
with
192 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
""" | ||
============================================= | ||
Multiclass classification with under-sampling | ||
============================================= | ||
Some balancing methods allow for balancing datasets with multiple classes. | ||
We provide an example to illustrate the use of those methods which do | ||
not differ from the binary case. | ||
""" | ||
|
||
from sklearn.datasets import load_iris | ||
from sklearn.svm import LinearSVC | ||
from sklearn.model_selection import train_test_split | ||
|
||
from imblearn.under_sampling import NearMiss | ||
from imblearn.pipeline import make_pipeline | ||
from imblearn.metrics import classification_report_imbalanced | ||
|
||
print(__doc__)

RANDOM_STATE = 42

# Load the iris dataset (three classes, samples ordered by class label)
iris = load_iris()

# Make the dataset imbalanced by discarding the first 25 samples, i.e. half
# of the first class. The slice is left open-ended on purpose: stopping at
# ``-1`` would also throw away the last sample of the third class.
iris.data = iris.data[25:]
iris.target = iris.target[25:]

# Hold out a test set; the seed keeps the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=RANDOM_STATE)

# Create a pipeline: under-sample with NearMiss-2, then fit a linear SVM.
# The sampler is only applied while fitting, never at prediction time.
pipeline = make_pipeline(NearMiss(version=2, random_state=RANDOM_STATE),
                         LinearSVC(random_state=RANDOM_STATE))
pipeline.fit(X_train, y_train)

# Classify and report the results
print(classification_report_imbalanced(y_test, pipeline.predict(X_test)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
.. _dataset_examples: | ||
|
||
Dataset examples | ||
----------------------- | ||
|
||
Examples concerning the :mod:`imblearn.datasets` module. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
""" | ||
============================================= | ||
Evaluate classification by compiling a report | ||
============================================= | ||
Specific metrics have been developed to evaluate classifiers which have been | ||
trained using imbalanced data. `imblearn` provides a classification | ||
report similar to `sklearn`, with additional metrics specific to imbalanced | ||
learning problems. | ||
""" | ||
|
||
from sklearn import datasets | ||
from sklearn.svm import LinearSVC | ||
from sklearn.model_selection import train_test_split | ||
|
||
from imblearn import over_sampling as os | ||
from imblearn import pipeline as pl | ||
from imblearn.metrics import classification_report_imbalanced | ||
|
||
print(__doc__)

RANDOM_STATE = 42

# Synthetic binary problem: 5000 samples, 20 features, with class weights of
# 0.1 and 0.9 so that the first class is the minority.
X, y = datasets.make_classification(
    n_samples=5000, n_features=20, n_informative=10, n_redundant=1,
    n_classes=2, n_clusters_per_class=4, weights=[0.1, 0.9],
    class_sep=2, flip_y=0, random_state=RANDOM_STATE)

# Keep a held-out test set for the final report
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=RANDOM_STATE)

# Over-sample with SMOTE, then fit a linear SVM on the balanced data.
# NOTE: ``os`` is this file's alias for ``imblearn.over_sampling`` (not the
# standard-library module) and ``pl`` is ``imblearn.pipeline``.
smote = os.SMOTE(random_state=RANDOM_STATE)
svm = LinearSVC(random_state=RANDOM_STATE)
pipeline = pl.make_pipeline(smote, svm)
pipeline.fit(X_train, y_train)

# Predict on the untouched test set and print the imbalanced-aware report
y_pred_bal = pipeline.predict(X_test)
print(classification_report_imbalanced(y_test, y_pred_bal))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
""" | ||
======================================= | ||
Metrics specific to imbalanced learning | ||
======================================= | ||
Specific metrics have been developed to evaluate classifiers which | ||
have been trained using imbalanced data. `imblearn` provides mainly | ||
two additional metrics which are not implemented in `sklearn`: (i) | ||
geometric mean and (ii) index balanced accuracy. | ||
""" | ||
|
||
from sklearn import datasets | ||
from sklearn.svm import LinearSVC | ||
from sklearn.model_selection import train_test_split | ||
|
||
from imblearn import over_sampling as os | ||
from imblearn import pipeline as pl | ||
from imblearn.metrics import (geometric_mean_score, | ||
make_index_balanced_accuracy) | ||
|
||
print(__doc__)

RANDOM_STATE = 42

# Generate a binary dataset whose classes have weights 0.1 and 0.9.
# ``n_classes`` matches the two weights given: requesting three classes with
# only these two weights would leave the third class with an inferred weight
# of zero, i.e. an empty class.
X, y = datasets.make_classification(n_classes=2, class_sep=2,
                                    weights=[0.1, 0.9], n_informative=10,
                                    n_redundant=1, flip_y=0, n_features=20,
                                    n_clusters_per_class=4, n_samples=5000,
                                    random_state=RANDOM_STATE)

# Over-sample with SMOTE before fitting a linear SVM.
# NOTE: ``os`` is this file's alias for ``imblearn.over_sampling`` (not the
# standard-library module) and ``pl`` is ``imblearn.pipeline``.
pipeline = pl.make_pipeline(os.SMOTE(random_state=RANDOM_STATE),
                            LinearSVC(random_state=RANDOM_STATE))

# Split the data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=RANDOM_STATE)

# Train the classifier with balancing
pipeline.fit(X_train, y_train)

# Test the classifier and get the prediction
y_pred_bal = pipeline.predict(X_test)

###############################################################################
# The geometric mean corresponds to the square root of the product of the
# sensitivity and specificity. Combining the two metrics should account for
# the balancing of the dataset.

print('The geometric mean is {}'.format(
    geometric_mean_score(y_test, y_pred_bal)))

###############################################################################
# The index balanced accuracy can transform any metric to be used in
# imbalanced learning problems. It is applied below to the geometric mean
# for two values of its weighting factor ``alpha``.

for alpha in (0.1, 0.5):
    # Wrap the plain scorer so that it returns its index-balanced variant
    geo_mean = make_index_balanced_accuracy(alpha=alpha, squared=True)(
        geometric_mean_score)

    print('The IBA using alpha = {} and the geometric mean: {}'.format(
        alpha, geo_mean(y_test, y_pred_bal)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters