Skip to content

Commit

Permalink
add largeZ model in comparison, but comment it out for now
Browse files Browse the repository at this point in the history
  • Loading branch information
naga-karthik committed Nov 11, 2024
1 parent c5974b0 commit 1e21f2c
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions testing/generate_figures_and_compute_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
"Dataset904_tumMSStitchedStraightRegion": 'Stitched\nStraightened',
# "Dataset910_tumMSChunksPolyNYUAxialRegion": 'Chunks\nPolyNYU',
}
order_datasets_tum_largeZ = {
"Dataset902_tumMSStitchedRegion": 'Stitched\nNative',
"Dataset904_tumMSStitchedStraightRegion": 'Stitched\nStraightened',
# "Dataset910_tumMSChunksPolyNYUAxialRegion": 'Chunks\nPolyNYU',
}

order_datasets_testing_large = {
"DeepSegLesionInference_tumNeuropoly": 'DeepSeg\nLesion',
Expand Down Expand Up @@ -85,9 +90,14 @@ def find_dataset_in_path(path):


def find_model_in_path(path):
match = re.search(r'2d|3d_fullres', path)
match = re.search(r'2d|3d_fullres', path) #|3d_largeZ', path)
if match:
return '2D' if '2d' in match.group(0) else '3D'
if '2d' in match.group(0):
return '2D'
elif '3d_fullres' in match.group(0):
return '3D'
elif '3d_largeZ' in match.group(0):
return '3D_largeZ'


def find_filename_in_path(path):
Expand Down Expand Up @@ -131,7 +141,7 @@ def create_rainplot(args, df, metrics, path_figures, pred_type):
box_showmeans=True, # show mean value inside the boxplots
box_meanprops={'marker': '^', 'markerfacecolor': 'black', 'markeredgecolor': 'black',
'markersize': '6'},
hue_order=['3D', '2D'],
hue_order=['3D', '2D'] #, '3D_largeZ'],
)

# TODO: include mean +- std for each boxplot above the mean value
Expand All @@ -149,7 +159,8 @@ def create_rainplot(args, df, metrics, path_figures, pred_type):
labels[i] = f'{label} Model' + ' ($\it{n}$' + f' = {n})'
# Since the figure contains violionplot + boxplot + scatterplot we are keeping only last two legend entries
handles, labels = handles[-2:], labels[-2:]
ax.legend(handles, labels, fontsize=TICK_FONT_SIZE, loc='lower center', bbox_to_anchor=(0.5, -0.25), ncol=2)
ax.legend(handles, labels, fontsize=TICK_FONT_SIZE, loc='lower center',
bbox_to_anchor=(0.5, -0.25), ncol=len(labels))

# Make legend box's frame color black and remove transparency
legend = ax.get_legend()
Expand Down Expand Up @@ -372,12 +383,14 @@ def compute_kruskal_wallis_test_across_tum_datasets(df_concat, list_of_metrics):

df_stitch_native_2d = df_concat[(df_concat['dataset'] == 'Dataset902_tumMSStitchedRegion') & (df_concat['model'] == '2D')]
df_stitch_native_3d = df_concat[(df_concat['dataset'] == 'Dataset902_tumMSStitchedRegion') & (df_concat['model'] == '3D')]
# df_stitch_native_3d_largeZ = df_concat[(df_concat['dataset'] == 'Dataset902_tumMSStitchedRegion') & (df_concat['model'] == '3D_largeZ')]

df_chunks_straight_2d = df_concat[(df_concat['dataset'] == 'Dataset903_tumMSChunksStraightRegion') & (df_concat['model'] == '2D')]
df_chunks_straight_3d = df_concat[(df_concat['dataset'] == 'Dataset903_tumMSChunksStraightRegion') & (df_concat['model'] == '3D')]

df_stitch_straight_2d = df_concat[(df_concat['dataset'] == 'Dataset904_tumMSStitchedStraightRegion') & (df_concat['model'] == '2D')]
df_stitch_straight_3d = df_concat[(df_concat['dataset'] == 'Dataset904_tumMSStitchedStraightRegion') & (df_concat['model'] == '3D')]
# df_stitch_straight_3d_largeZ = df_concat[(df_concat['dataset'] == 'Dataset904_tumMSStitchedStraightRegion') & (df_concat['model'] == '3D_largeZ')]

# ensure that the two dataframes have the same number of rows
assert len(df_chunks_native_2d) == len(df_stitch_native_2d) == len(df_chunks_straight_2d) == len(df_stitch_straight_2d) == \
Expand All @@ -392,10 +405,12 @@ def compute_kruskal_wallis_test_across_tum_datasets(df_concat, list_of_metrics):
f'{metric}_chunks_native_3d': df_chunks_native_3d[metric].values,
f'{metric}_stitch_native_2d': df_stitch_native_2d[metric].values,
f'{metric}_stitch_native_3d': df_stitch_native_3d[metric].values,
# f'{metric}_stitch_native_3d_largeZ': df_stitch_native_3d_largeZ[metric].values,
f'{metric}_chunks_straight_2d': df_chunks_straight_2d[metric].values,
f'{metric}_chunks_straight_3d': df_chunks_straight_3d[metric].values,
f'{metric}_stitch_straight_2d': df_stitch_straight_2d[metric].values,
f'{metric}_stitch_straight_3d': df_stitch_straight_3d[metric].values,
# f'{metric}_stitch_straight_3d_largeZ': df_stitch_straight_3d_largeZ[metric].values
})

# Print number of subjects
Expand All @@ -404,9 +419,10 @@ def compute_kruskal_wallis_test_across_tum_datasets(df_concat, list_of_metrics):
# Compute Kruskal-Wallis H-test
stat, p = kruskal(
df[metric + '_chunks_native_2d'], df[metric + '_chunks_native_3d'],
df[metric + '_stitch_native_2d'], df[metric + '_stitch_native_3d'],
df[metric + '_stitch_native_2d'], df[metric + '_stitch_native_3d'], #df[metric + '_stitch_native_3d_largeZ'],
df[metric + '_chunks_straight_2d'], df[metric + '_chunks_straight_3d'],
df[metric + '_stitch_straight_2d'], df[metric + '_stitch_straight_3d'])
df[metric + '_stitch_straight_2d'], df[metric + '_stitch_straight_3d'], #df[metric + '_stitch_straight_3d_largeZ']
)
logger.info(f'{metrics_short[metric]}: Kruskal-Wallis H-test: formatted p{format_pvalue(p)}, unformatted p={p:0.6f}')

if p < 0.05:
Expand Down

0 comments on commit 1e21f2c

Please sign in to comment.