Skip to content

Commit

Permalink
Merge pull request #65 from sct-pipeline/jca/cnr
Browse files Browse the repository at this point in the history
Add CNR figure for spine-generic data
  • Loading branch information
jcohenadad authored Jan 25, 2022
2 parents 1fe1bf2 + eaac07e commit f8fd2b8
Showing 1 changed file with 56 additions and 32 deletions.
88 changes: 56 additions & 32 deletions gmchallenge/generate_figure_spinegeneric.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
#!/usr/bin/env python
#
# Generate figures for the spine-generic results

import pandas as pd
import numpy as np
import argparse
import seaborn as sns
import os
import matplotlib.pyplot as plt
import ptitprince as pt
import seaborn as sns
from matplotlib.patches import PathPatch

sns.set(style="whitegrid", font_scale=1)


def get_parameters():
def get_parser():
parser = argparse.ArgumentParser(description='Generate figure for spine generic dataset')
parser.add_argument("-ir", "--path-input-results",
help="Path to results.csv",
Expand All @@ -20,19 +24,15 @@ def get_parameters():
required=True)
parser.add_argument("-o", "--path-output",
help="Path to save images",
required=True,
)
arguments = parser.parse_args()
return arguments
required=True)
return parser


def adjust_box_widths(g, fac):
# From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579

"""
Adjust the widths of a seaborn-generated boxplot.
Source: From https://github.com/mwaskom/seaborn/issues/1076#issuecomment-634541579
"""

# iterating through Axes instances
for ax in g.axes:

Expand All @@ -58,32 +58,50 @@ def adjust_box_widths(g, fac):

# setting new width of median line
for l in ax.lines:
if not l.get_xdata().size == 0:
if np.all(np.equal(l.get_xdata(), [xmin, xmax])):
if not len(l.get_xdata()) == 0:
if np.all(np.equal(l.get_xdata()[0:2], [xmin, xmax])):
l.set_xdata([xmin_new, xmax_new])


def generate_figure(data_in, column, path_output):
# Hue Input for Subgroups
dx = np.ones(len(data_in[column]))
dy = column
dhue = "Manufacturer"
ort = "v"
# dodge blue, limegreen, red
colors = [ "#1E90FF", "#32CD32","#FF0000" ]
pal = colors
sigma = .2
hue = "Manufacturer"
pal = ["#1E90FF", "#32CD32", "#FF0000"]
f, ax = plt.subplots(figsize=(4, 6))

ax = pt.RainCloud(x=dx, y=dy, hue=dhue, data=data_in, palette=pal, bw=sigma,
width_viol=.5, ax=ax, orient=ort, alpha=.4, dodge=True, width_box=.35,
box_showmeans=True,
box_meanprops={"marker":"^", "markerfacecolor":"black", "markeredgecolor":"black", "markersize":"10"},
box_notch=True)
if column == 'CNR_single/t':
coeff = 100
else:
coeff = 1
ax = pt.half_violinplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, bw=.4, cut=0., linewidth=0.,
scale="area", width=.8, inner=None, orient="v", dodge=False, alpha=.4, offset=0.5)
ax = sns.boxplot(x=dx, y=dy, data=data_in*coeff, hue=hue, color="black", palette=pal,
showcaps=True, boxprops={'facecolor': 'none', "zorder": 10}, showmeans=True,
meanprops={"marker": "^", "markerfacecolor": "black", "markeredgecolor": "black",
"markersize": "8"},
showfliers=True, whiskerprops={'linewidth': 2, "zorder": 10},
saturation=1, orient="v", dodge=True)
ax = sns.stripplot(x=dx, y=dy, data=data_in*coeff, hue=hue, palette=pal, edgecolor="white",
size=3, jitter=1, zorder=0, orient="v", dodge=True)
plt.xlim([-1, 0.5])
handles, labels = ax.get_legend_handles_labels()
# The code below doesn't work (the label for CNR is "GEGEGEGEGEGEG...") so i need to hard-code the labels (because
# I don't have time to dig further).
# _ = plt.legend(handles[0:len(labels) // 3], labels[0:len(labels) // 3],
# bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
# title=str(hue))
_ = plt.legend(handles[0:3], ['Philips', 'Siemens', 'GE'],
bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
title=str(hue))
f.gca().invert_xaxis()
#adjust boxplot width
adjust_box_widths(f, 0.4)
plt.xlabel(column)
adjust_box_widths(f, 0.6)
# special hack
if column == 'CNR_single/t':
plt.xlabel('CNR_single/√t')
fname_out = os.path.join(path_output, 'figure_CNR_single_t')
else:
plt.xlabel(column)
fname_out = os.path.join(path_output, 'figure_' + column)
# remove ylabel
plt.ylabel('')
# hide xtick
Expand All @@ -93,11 +111,17 @@ def generate_figure(data_in, column, path_output):
bottom=False,
top=False,
labelbottom=False)
# plt.legend(title="Line", loc='upper left', handles=handles[::-1])
plt.savefig(os.path.join(path_output, 'figure_' + column), bbox_inches='tight', dpi=300)
plt.savefig(fname_out, bbox_inches='tight', dpi=300)


def main(argv=None):
# user params
parser = get_parser()
args = parser.parse_args(argv)
path_input_results = args.path_input_results
path_input_participants = args.path_input_participants
path_output = args.path_output

def main(path_input_results, path_input_participants, path_output):
if not os.path.isdir(path_output):
os.makedirs(path_output)

Expand All @@ -114,8 +138,8 @@ def main(path_input_results, path_input_participants, path_output):

generate_figure(content_results_csv, 'SNR_single', path_output)
generate_figure(content_results_csv, 'Contrast', path_output)
generate_figure(content_results_csv, 'CNR_single/t', path_output)


if __name__ == "__main__":
args = get_parameters()
main(args.path_input_results, args.path_input_participants, args.path_output)
main()

0 comments on commit f8fd2b8

Please sign in to comment.