diff --git a/gmchallenge/generate_figure_spinegeneric.py b/gmchallenge/generate_figure_spinegeneric.py index b0a2064..f2f0cf2 100644 --- a/gmchallenge/generate_figure_spinegeneric.py +++ b/gmchallenge/generate_figure_spinegeneric.py @@ -1,16 +1,20 @@ +#!/usr/bin/env python +# +# Generate figures for the spine-generic results + import pandas as pd import numpy as np import argparse -import seaborn as sns import os import matplotlib.pyplot as plt import ptitprince as pt +import seaborn as sns from matplotlib.patches import PathPatch sns.set(style="whitegrid", font_scale=1) -def get_parameters(): +def get_parser(): parser = argparse.ArgumentParser(description='Generate figure for spine generic dataset') parser.add_argument("-ir", "--path-input-results", help="Path to results.csv", @@ -20,19 +24,15 @@ def get_parameters(): required=True) parser.add_argument("-o", "--path-output", help="Path to save images", - required=True, - ) - arguments = parser.parse_args() - return arguments + required=True) + return parser def adjust_box_widths(g, fac): - # From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579 - """ Adjust the widths of a seaborn-generated boxplot. + Source: From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579 """ - # iterating through Axes instances for ax in g.axes: @@ -58,32 +58,50 @@ def adjust_box_widths(g, fac): # setting new width of median line for l in ax.lines: - if not l.get_xdata().size == 0: - if np.all(np.equal(l.get_xdata(), [xmin, xmax])): + if not len(l.get_xdata()) == 0: + if np.all(np.equal(l.get_xdata()[0:2], [xmin, xmax])): l.set_xdata([xmin_new, xmax_new]) def generate_figure(data_in, column, path_output): - # Hue Input for Subgroups dx = np.ones(len(data_in[column])) dy = column - dhue = "Manufacturer" - ort = "v" - # dodge blue, limegreen, red - colors = [ "#1E90FF", "#32CD32","#FF0000" ] - pal = colors - sigma = .2 + hue = "Manufacturer" + pal = ["#1E90FF", "#32CD32", "#FF0000"] f, ax = plt.subplots(figsize=(4, 6)) - - ax = pt.RainCloud(x=dx, y=dy, hue=dhue, data=data_in, palette=pal, bw=sigma, - width_viol=.5, ax=ax, orient=ort, alpha=.4, dodge=True, width_box=.35, - box_showmeans=True, - box_meanprops={"marker":"^", "markerfacecolor":"black", "markeredgecolor":"black", "markersize":"10"}, - box_notch=True) + if column == 'CNR_single/t': + coeff = 100 + else: + coeff = 1 + ax = pt.half_violinplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, bw=.4, cut=0., linewidth=0., + scale="area", width=.8, inner=None, orient="v", dodge=False, alpha=.4, offset=0.5) + ax = sns.boxplot(x=dx, y=dy, data=data_in*coeff, hue=hue, color="black", palette=pal, + showcaps=True, boxprops={'facecolor': 'none', "zorder": 10}, showmeans=True, + meanprops={"marker": "^", "markerfacecolor": "black", "markeredgecolor": "black", + "markersize": "8"}, + showfliers=True, whiskerprops={'linewidth': 2, "zorder": 10}, + saturation=1, orient="v", dodge=True) + ax = sns.stripplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, edgecolor="white", + size=3, jitter=1, zorder=0, orient="v", dodge=True) + plt.xlim([-1, 0.5]) + handles, labels = ax.get_legend_handles_labels() + # The code below doesn't work (the label for CNR is "GEGEGEGEGEGEG...") so i need to hard-code the labels (because + # I don't have time to dig further). + # _ = plt.legend(handles[0:len(labels) // 3], labels[0:len(labels) // 3], + # bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., + # title=str(hue)) + _ = plt.legend(handles[0:3], ['Philips', 'Siemens', 'GE'], + bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., + title=str(hue)) f.gca().invert_xaxis() - #adjust boxplot width - adjust_box_widths(f, 0.4) - plt.xlabel(column) + adjust_box_widths(f, 0.6) + # special hack + if column == 'CNR_single/t': + plt.xlabel('CNR_single/√t') + fname_out = os.path.join(path_output, 'figure_CNR_single_t') + else: + plt.xlabel(column) + fname_out = os.path.join(path_output, 'figure_' + column) # remove ylabel plt.ylabel('') # hide xtick @@ -93,11 +111,17 @@ def generate_figure(data_in, column, path_output): bottom=False, top=False, labelbottom=False) - # plt.legend(title="Line", loc='upper left', handles=handles[::-1]) - plt.savefig(os.path.join(path_output, 'figure_' + column), bbox_inches='tight', dpi=300) + plt.savefig(fname_out, bbox_inches='tight', dpi=300) + +def main(argv=None): + # user params + parser = get_parser() + args = parser.parse_args(argv) + path_input_results = args.path_input_results + path_input_participants = args.path_input_participants + path_output = args.path_output -def main(path_input_results, path_input_participants, path_output): if not os.path.isdir(path_output): os.makedirs(path_output) @@ -114,8 +138,8 @@ def main(path_input_results, path_input_participants, path_output): generate_figure(content_results_csv, 'SNR_single', path_output) generate_figure(content_results_csv, 'Contrast', path_output) + generate_figure(content_results_csv, 'CNR_single/t', path_output) if __name__ == "__main__": - args = get_parameters() - main(args.path_input_results, args.path_input_participants, args.path_output) + main()