From 0fb67d969a785f03807f1b78e208489dd8ee96cb Mon Sep 17 00:00:00 2001 From: hadign20 Date: Wed, 9 Oct 2024 16:54:00 -0400 Subject: [PATCH] added a function to save heatmaps --- nnunetv2/inference/export_prediction.py | 48 +++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/nnunetv2/inference/export_prediction.py b/nnunetv2/inference/export_prediction.py index f5cdb958d..b8e94d1d7 100644 --- a/nnunetv2/inference/export_prediction.py +++ b/nnunetv2/inference/export_prediction.py @@ -4,6 +4,7 @@ import numpy as np import torch +import matplotlib.pyplot as plt from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle @@ -12,6 +13,47 @@ from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + +def save_heatmap(probabilities, output_path): + # Check if the probabilities array is 4D and handle each channel and slice separately + if probabilities.ndim == 4: # Shape: (num_channels, num_slices, width, height) + num_channels = probabilities.shape[0] + num_slices = probabilities.shape[1] + output_dir = os.path.dirname(output_path) + + # Create a directory for the channel and slice images if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + for channel in range(num_channels): + for slice_idx in range(num_slices): + slice_output_path = os.path.join(output_path, f"channel_{channel}_slice_{slice_idx}.png") + plt.figure(figsize=(10, 8)) + plt.imshow(probabilities[channel, slice_idx], cmap='hot', interpolation='nearest', origin='lower') + plt.colorbar() + plt.savefig(slice_output_path) + plt.close() + elif probabilities.ndim == 3: # Handle the 3D case separately (num_slices, width, height) + num_slices = probabilities.shape[0] + output_dir = os.path.dirname(output_path) + + os.makedirs(output_dir, exist_ok=True) + + for i in range(num_slices): + slice_output_path = os.path.join(output_path, f"slice_{i}.png") + plt.figure(figsize=(10, 8)) + plt.imshow(probabilities[i], cmap='hot', interpolation='nearest', origin='lower') + plt.colorbar() + plt.savefig(slice_output_path) + plt.close() + else: + # If the array is already 2D, save it directly + plt.figure(figsize=(10, 8)) + plt.imshow(probabilities, cmap='hot', interpolation='nearest', origin='lower') + plt.colorbar() + plt.savefig(output_path) + plt.close() + + def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray], plans_manager: PlansManager, configuration_manager: ConfigurationManager, @@ -97,6 +139,12 @@ def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, tor segmentation_final, probabilities_final = ret np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final) save_pickle(properties_dict, output_file_truncated + '.pkl') + + # Save heatmaps of the probabilities + heatmap_output_path = f"{output_file_truncated}_heatmap" + os.makedirs(heatmap_output_path, exist_ok=True) + save_heatmap(probabilities_final, heatmap_output_path) + del probabilities_final, ret else: segmentation_final = ret