diff --git a/GANDLF/cli/generate_metrics.py b/GANDLF/cli/generate_metrics.py index cf64f991f..7a0e2c5ca 100644 --- a/GANDLF/cli/generate_metrics.py +++ b/GANDLF/cli/generate_metrics.py @@ -1,3 +1,4 @@ +import sys import yaml from pprint import pprint import pandas as pd @@ -258,11 +259,14 @@ def __fix_2d_tensor(input_tensor): gt_image_infill, output_infill ).item() - # PSNR - similar to pytorch PeakSignalNoiseRatio until 4 digits after decimal point overall_stats_dict[current_subject_id]["psnr"] = peak_signal_noise_ratio( gt_image_infill, output_infill ).item() + overall_stats_dict[current_subject_id]["psnr_eps"] = peak_signal_noise_ratio( + gt_image_infill, output_infill, epsilon=sys.float_info.epsilon + ).item() + pprint(overall_stats_dict) if outputfile is not None: with open(outputfile, "w") as outfile: diff --git a/GANDLF/metrics/synthesis.py b/GANDLF/metrics/synthesis.py index 3321b121a..fd3590a58 100644 --- a/GANDLF/metrics/synthesis.py +++ b/GANDLF/metrics/synthesis.py @@ -1,4 +1,3 @@ -import sys import SimpleITK as sitk import PIL.Image import numpy as np @@ -8,6 +7,7 @@ MeanSquaredError, MeanSquaredLogError, MeanAbsoluteError, + PeakSignalNoiseRatio, ) from GANDLF.utils import get_image_from_tensor @@ -25,7 +25,7 @@ def structural_similarity_index(target, prediction, mask=None) -> torch.Tensor: torch.Tensor: The structural similarity index. """ ssim = StructuralSimilarityIndexMeasure(return_full_image=True) - _, ssim_idx_full_image = ssim(target, prediction) + _, ssim_idx_full_image = ssim(preds=prediction, target=target) mask = torch.ones_like(ssim_idx_full_image) if mask is None else mask try: ssim_idx = ssim_idx_full_image[mask] @@ -45,23 +45,30 @@ def mean_squared_error(target, prediction) -> torch.Tensor: prediction (torch.Tensor): The prediction tensor. """ mse = MeanSquaredError() - return mse(target, prediction) + return mse(preds=prediction, target=target) -def peak_signal_noise_ratio(target, prediction) -> torch.Tensor: +def peak_signal_noise_ratio(target, prediction, data_range=None, epsilon=None) -> torch.Tensor: """ Computes the peak signal to noise ratio between the target and prediction. Args: target (torch.Tensor): The target tensor. prediction (torch.Tensor): The prediction tensor. + data_range (float, optional): If not None, this data range is used as enumerator instead of computing it from the given data. Defaults to None. + epsilon (float, optional): If not None, this epsilon is added to the denominator of the fraction to avoid infinity as output. Defaults to None. """ - mse = mean_squared_error(target, prediction) - return ( - 10.0 - * torch.log10((torch.max(target) - torch.min(target)) ** 2) - / (mse + sys.float_info.epsilon) - ) + + if epsilon == None: + psnr = PeakSignalNoiseRatio(data_range=data_range) + return psnr(preds=prediction, target=target) + else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0' + mse = mean_squared_error(target, prediction) + if data_range == None: #compute data_range like torchmetrics if not given + min_v = 0 if torch.min(target) > 0 else torch.min(target) #look at this line + max_v = torch.max(target) + data_range = max_v - min_v + return 10.0 * torch.log10((data_range ** 2) / (mse + epsilon)) def mean_squared_log_error(target, prediction) -> torch.Tensor: @@ -73,7 +80,7 @@ def mean_squared_log_error(target, prediction) -> torch.Tensor: prediction (torch.Tensor): The prediction tensor. """ mle = MeanSquaredLogError() - return mle(target, prediction) + return mle(preds=prediction, target=target) def mean_absolute_error(target, prediction) -> torch.Tensor: @@ -85,7 +92,7 @@ def mean_absolute_error(target, prediction) -> torch.Tensor: prediction (torch.Tensor): The prediction tensor. """ mae = MeanAbsoluteError() - return mae(target, prediction) + return mae(preds=prediction, target=target) def _get_ncc_image(target, prediction) -> sitk.Image: