diff --git a/figures/cross_validation_train_test_diagram.png b/figures/cross_validation_train_test_diagram.png index 1e485c737..64bc3622c 100644 Binary files a/figures/cross_validation_train_test_diagram.png and b/figures/cross_validation_train_test_diagram.png differ diff --git a/figures/nested_cross_validation_diagram.png b/figures/nested_cross_validation_diagram.png index 6ad68ec6a..559404ea5 100644 Binary files a/figures/nested_cross_validation_diagram.png and b/figures/nested_cross_validation_diagram.png differ diff --git a/figures/plot_parameter_tuning_cv.py b/figures/plot_parameter_tuning_cv.py index f16a80cd6..5a7caa595 100644 --- a/figures/plot_parameter_tuning_cv.py +++ b/figures/plot_parameter_tuning_cv.py @@ -1,6 +1,6 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib.colors import LinearSegmentedColormap +from matplotlib.colors import ListedColormap from matplotlib.patches import Patch from pathlib import Path from sklearn.model_selection import KFold @@ -9,16 +9,18 @@ FIGURES_FOLDER = Path(__file__).parent plt.style.use(FIGURES_FOLDER / "../python_scripts/matplotlibrc") -colors = ["#009e73ff", "#fd3c06ff", "#0072b2ff"] -cmap_name = "my_list" -cmap_cv = LinearSegmentedColormap.from_list(cmap_name, colors=colors, N=8) +colors_cv = ["#009e73ff", "#fd3c06ff", "white"] +colors_eval = ["#fd3c06ff", "#fd3c06ff", "#0072b2ff"] +cmap_cv = ListedColormap(colors=colors_cv) +cmap_eval = ListedColormap(colors=colors_eval) -def plot_cv_indices(cv, X, y, ax, lw=50): +def plot_cv_indices(cv, X, y, axs): """Create a sample plot for indices of a cross-validation object embeded in a train-test split.""" splits = list(cv.split(X=X, y=y)) n_splits = len(splits) + ax1, ax2 = axs # Generate the training/testing visualizations for each CV split for ii, (train, test) in enumerate(splits): @@ -28,7 +30,7 @@ def plot_cv_indices(cv, X, y, ax, lw=50): indices[X.shape[0] : X.shape[0] + 10] = 2 # Visualize the results - ax.scatter( + ax1.scatter( range(len(indices)), [ii + 0.5] * len(indices), c=indices, @@ -36,112 +38,168 @@ def plot_cv_indices(cv, X, y, ax, lw=50): lw=25, cmap=cmap_cv, ) + ax2.scatter( + range(len(indices)), + [0.5] * len(indices), + c=indices, + marker="_", + lw=25, + cmap=cmap_eval, + ) # Formatting yticklabels = list(range(n_splits)) - ax.set( + ax1.set( yticks=np.arange(n_splits) + 0.5, yticklabels=yticklabels, - xlabel="Sample index", ylabel="CV iteration", ylim=[n_splits + 0.2, -0.2], xlim=[0, 50], ) - ax.set_title( - "{} cross validation inside (non-shuffled-)train-test split".format( - type(cv).__name__ - ) + + ax2.set( + yticks=[0.5], + yticklabels=[], + xlabel="Sample index", + ymargin=10, + ylim=[0.3, 0.7], + xlim=[0, 50], ) - ax.legend( + ax2.set_ylabel("refit +\nevaluation", labelpad=15) + ax2.legend( [ - Patch(color=cmap_cv(0.9)), Patch(color=cmap_cv(0.5)), Patch(color=cmap_cv(0.02)), + Patch(color=cmap_eval(0.9)), + ], + [ + "Training samples", + "Validation samples\n(for hyperparameter\ntuning)", + "Testing samples\n(reserved until\nfinal evaluation)", ], - ["Testing samples", "Training samples", "Validation samples"], - loc=(1.02, 0.7), + loc=(1.02, 1.1), + labelspacing=1, ) - return ax + return n_points = 40 X = np.random.randn(n_points, 10) y = np.random.randn(n_points) -fig, ax = plt.subplots(figsize=(12, 4)) +fig, axs = plt.subplots( + ncols=1, + nrows=2, + sharex=True, + figsize=(12, 5), + gridspec_kw={"height_ratios": [5, 1.5], "hspace": 0}, +) cv = KFold(5) -_ = plot_cv_indices(cv, X, y, ax) +plot_cv_indices(cv, X, y, axs) +plt.suptitle( + "Internal {} cross-validation in GridSearchCV".format( + type(cv).__name__), + y=0.95, + ) plt.tight_layout() fig.savefig(FIGURES_FOLDER / "cross_validation_train_test_diagram.png") -def plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax, lw=50): +def plot_cv_nested_indices(cv_inner, cv_outer, X, y, axs): """Create a sample plot for indices of a nested cross-validation object.""" splits_outer = list(cv_outer.split(X=X, y=y)) n_splits_outer = len(splits_outer) # Generate the training/testing visualizations for each CV split - for ii, (train_outer, test_outer) in enumerate(splits_outer): - + for outer_index, (train_outer, test_outer) in enumerate(splits_outer): splits_inner = list(cv_inner.split(train_outer)) n_splits_inner = len(splits_inner) # Fill in indices with the training/test groups - for jj, (train_inner, test_inner) in enumerate(splits_inner): + for inner_index, (train_inner, test_inner) in enumerate(splits_inner): indices = np.zeros(shape=X.shape[0], dtype=np.int32) indices[train_outer[train_inner]] = 1 indices[test_outer] = 2 # Visualize the results - ax.scatter( + axs[outer_index * 2].scatter( range(len(indices)), - [n_splits_inner * ii + jj + 0.5] * len(indices), + [inner_index + 0.6] * len(indices), c=indices, marker="_", lw=25, cmap=cmap_cv, ) + axs[outer_index*2 + 1].scatter( + range(len(indices)), + [0.5] * len(indices), + c=indices, + marker="_", + lw=25, + cmap=cmap_eval, + ) + axs[outer_index*2 + 1].set( + yticks=[0.5], + yticklabels=["refit +\nevaluation"], + xlabel="Sample index", + ymargin=10, + ylim=[0.3, 0.7], + xlim=[0, 50], + ) + # Formatting - ax.set_title("{} nested cross-validation".format(type(cv_outer).__name__)) - ax1 = ax.twinx() - yticklabels = n_splits_outer * list(range(n_splits_inner)) - ax1.set( - yticks=np.arange(n_splits_outer * n_splits_inner) + 0.3, - yticklabels=yticklabels, - xlabel="Sample index", - ylabel="CV inner iteration", - ylim=[n_splits_outer * n_splits_inner + 0.2, -0.2], - xlim=[0, 50], - ) - yticklabels = list(range(n_splits_outer)) - ax.set( - yticks=n_splits_inner*np.arange(n_splits_outer) + 0.5, - yticklabels=yticklabels, - xlabel="Sample index", - ylabel="CV outer iteration", - ylim=[n_splits_outer * n_splits_inner + 0.2, 0.08], - xlim=[0, 50], - ) - ax.legend( + ax_twin = axs[outer_index * 2].twinx() + yticklabels = list(range(n_splits_inner)) + ax_twin.set( + yticks=np.arange(n_splits_inner) + 0.4, + yticklabels=yticklabels, + xlabel="Sample index", + ylabel="inner iteration", + ylim=[n_splits_inner + 0.2, -0.2], + xlim=[0, 50], + ) + + axs[outer_index * 2].set( + yticks=n_splits_inner * np.arange(n_splits_outer) + 0.5, + yticklabels=[outer_index] * n_splits_outer, + xlabel="Sample index", + ylim=[ n_splits_inner + 0.2, 0.08], + xlim=[0, 50], + ) + + axs[0].legend( [ - Patch(color=cmap_cv(0.9)), Patch(color=cmap_cv(0.5)), Patch(color=cmap_cv(0.02)), + Patch(color=cmap_eval(0.9)), + ], + [ + "Training samples", + "Validation samples\n(for hyperparameter\ntuning)", + "Testing samples\n(reserved until\nevaluation)", ], - ["Testing samples", "Training samples", "Validation samples"], - loc=(1.06, .93), + loc=(1.2, -0.2), + labelspacing=1, ) - return ax + return n_points = 50 X = np.random.randn(n_points, 10) y = np.random.randn(n_points) -fig, ax = plt.subplots(figsize=(12, 12)) +fig, axs = plt.subplots( + ncols=1, + nrows=10, + sharex=True, + figsize=(14, 15), + gridspec_kw={"height_ratios": [5, 1.5] * 5, "hspace": 0}, +) cv_inner = KFold(n_splits=4, shuffle=False) cv_outer = KFold(n_splits=5, shuffle=False) -_ = plot_cv_nested_indices(cv_inner, cv_outer, X, y, ax) +plot_cv_nested_indices(cv_inner, cv_outer, X, y, axs) +plt.suptitle("{} nested cross-validation".format(type(cv_outer).__name__), y=0.97) plt.tight_layout() +fig.text(0.0, 0.5, "outer iteration", va="center", rotation="vertical") fig.savefig(FIGURES_FOLDER / "nested_cross_validation_diagram.png")