From d02ba11d8069870d71316a616f047c499627c71c Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Sat, 7 Sep 2024 05:56:46 -0400 Subject: [PATCH] Fix generalized dice computation (#7970) Fixes #7966 ### Description Correct the Generalized Dice Score computation: `compute_generalized_dice` now returns per-class scores and only sums the weighted numerator and denominator across classes when the new `sum_over_classes` argument is set. `GeneralizedDiceScore` validates `reduction` against `MetricReduction`, derives `sum_over_classes` from it, and deprecates the `reduction` argument of `aggregate`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Suraj Pai Signed-off-by: Suraj Pai Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/metrics/generalized_dice.py | 125 +++++++++++------- tests/test_compute_generalized_dice.py | 170 ++++++++++++++++++------- 2 files changed, 201 insertions(+), 94 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index e56bd46592..516021949b 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -14,34 +14,47 @@ import torch from monai.metrics.utils import do_metric_reduction, ignore_background -from monai.utils import MetricReduction, Weight, look_up_option +from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option from .metric import CumulativeIterationMetric class GeneralizedDiceScore(CumulativeIterationMetric): - """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in: + """ + Compute the Generalized Dice Score metric between tensors. + This metric is the complement of the Generalized Dice Loss defined in: Sudre, C. et. al. 
(2017) Generalised Dice overlap as a deep learning - loss function for highly unbalanced segmentations. DLMIA 2017. + loss function for highly unbalanced segmentations. DLMIA 2017. - The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first - or batch-first tensors, i.e., CHW[D] or BCHW[D]. + The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D]. Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. Args: - include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the + include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. - reduction (str, optional): define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform + reduction: Define mode of reduction to the metrics. Available reduction modes: + {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. Raises: - ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}. + ValueError: When the `reduction` is not one of MetricReduction enum. 
""" + @deprecated_arg_default( + "reduction", + old_default=MetricReduction.MEAN_BATCH, + new_default=MetricReduction.MEAN, + since="1.4.0", + replaced="1.5.0", + msg_suffix=( + "Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, " + "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'." + ), + ) def __init__( self, include_background: bool = True, @@ -50,79 +63,90 @@ def __init__( ) -> None: super().__init__() self.include_background = include_background - reduction_options = [ - "none", - "mean_batch", - "sum_batch", - MetricReduction.NONE, - MetricReduction.MEAN_BATCH, - MetricReduction.SUM_BATCH, - ] - self.reduction = reduction - if self.reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") + self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.sum_over_classes = self.reduction in { + MetricReduction.SUM, + MetricReduction.MEAN, + MetricReduction.MEAN_CHANNEL, + MetricReduction.SUM_CHANNEL, + } def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, + y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. 
+ + Returns: + torch.Tensor: Per-image Generalized Dice Score with shape [batch_size, num_classes], or [batch_size, 1] when the configured reduction sums over classes. Raises: - ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. + ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. """ return compute_generalized_dice( - y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type + y_pred=y_pred, + y=y, + include_background=self.include_background, + weight_type=self.weight_type, + sum_over_classes=self.sum_over_classes, ) + @deprecated_arg( + "reduction", + since="1.3.3", + removed="1.7.0", + msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute", + ) def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. - Args: - reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics. - Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}. - Defaults to ``"mean"``. If "none", will not do reduction. + Returns: + torch.Tensor: Aggregated metric value. + + Raises: + ValueError: If the data to aggregate is not a PyTorch Tensor. 
""" data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("The data to aggregate must be a PyTorch Tensor.") - # Validate reduction argument if specified - if reduction is not None: - reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"] - if reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") - # Do metric reduction and return - f, _ = do_metric_reduction(data, reduction or self.reduction) + f, _ = do_metric_reduction(data, self.reduction) return f def compute_generalized_dice( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + weight_type: Weight | str = Weight.SQUARE, + sum_over_classes: bool = False, ) -> torch.Tensor: - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format + y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. - include_background (bool, optional): whether to include score computation on the first channel of the + y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. + include_background: Whether to include score computation on the first channel of the predicted output. Defaults to True. weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. 
Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final computation. Defaults to False. Returns: - torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. + torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes], or [batch_size, 1] when `sum_over_classes` is True. Raises: - ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, + ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, or `y_pred` and `y` don't have the same shape. """ # Ensure tensors have at least 3 dimensions and have the same shape @@ -158,16 +182,21 @@ def compute_generalized_dice( b[infs] = 0 b[infs] = torch.max(b) - # Compute the weighted numerator and denominator, summing along the class axis - numer = 2.0 * (intersection * w).sum(dim=1) - denom = (denominator * w).sum(dim=1) + # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True + if sum_over_classes: + numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) + denom = (denominator * w).sum(dim=1, keepdim=True) + y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + else: + numer = 2.0 * (intersection * w) + denom = denominator * w + y_pred_o = y_pred_o # Compute the score generalized_dice_score = numer / denom # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. 
# Where denom == 0 but the prediction volume is not 0, score is 0 - y_pred_o = y_pred_o.sum(dim=-1) denom_zeros = denom == 0 generalized_dice_score[denom_zeros] = torch.where( (y_pred_o == 0)[denom_zeros], diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index e04444e988..985a01e993 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -22,17 +22,17 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background -TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1) +TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) with compute_generalized_dice { "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), "include_background": True, }, - [0.8], + [[0.8]], ] # remove background -TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -47,32 +47,32 @@ ] ), "include_background": False, + "reduction": "mean_batch", }, - [0.1667, 0.6667], + [0.583333, 0.333333], ] -# should return 0 for both cases -TEST_CASE_3 = [ +TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], ] ), "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], 
[[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), "include_background": True, + "reduction": "mean", }, - [0.0, 0.0], + [0.5454], ] -TEST_CASE_4 = [ - {"include_background": True, "reduction": "mean_batch"}, +TEST_CASE_4 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -83,15 +83,36 @@ "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], ] ), + "include_background": True, + "reduction": "sum", }, - [0.5455], + [1.045455], +] + +TEST_CASE_5 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], ] -TEST_CASE_5 = [ - {"include_background": True, "reduction": "sum_batch"}, +TEST_CASE_6 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] + +TEST_CASE_7 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] + +TEST_CASE_8 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] + +TEST_CASE_9 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -102,61 +123,118 @@ "y": torch.tensor( [ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 
0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), + "include_background": True, + "reduction": "mean_channel", }, - 1.0455, + [0.545455, 0.545455], ] -TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] -TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]] - -TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice + # and (3) with GeneralizedDiceScore "mean_batch" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + }, + [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]], +] -TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes) + # and (2) with GeneralizedDiceScore "mean_channel" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + "sum_over_classes": True, + }, + [[0.545455], [0.545455]], +] class TestComputeGeneralizedDiceScore(unittest.TestCase): - 
@parameterized.expand([TEST_CASE_1]) def test_device(self, input_data, _expected_value): + """ + Test if the result tensor is on the same device as the input tensor. + """ result = compute_generalized_dice(**input_data) np.testing.assert_equal(result.device, input_data["y_pred"].device) - # Functional part tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_value(self, input_data, expected_value): + """ + Test if the computed generalized dice score matches the expected value. + """ result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Functional part tests - @parameterized.expand([TEST_CASE_3]) - def test_nans(self, input_data, expected_value): - result = compute_generalized_dice(**input_data) - self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) - - # Samplewise tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_9]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meandice - vals = {} - vals["y_pred"] = input_data.pop("y_pred") - vals["y"] = input_data.pop("y") + """ + Test if the GeneralizedDiceScore class computes the correct values. 
+ """ + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") generalized_dice_score = GeneralizedDiceScore(**input_data) - generalized_dice_score(**vals) - result = generalized_dice_score.aggregate(reduction="none") + generalized_dice_score(y_pred=y_pred, y=y) + result = generalized_dice_score.aggregate() np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Aggregation tests - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) - def test_nans_class(self, params, input_data, expected_value): - generalized_dice_score = GeneralizedDiceScore(**params) - generalized_dice_score(**input_data) - result = generalized_dice_score.aggregate() + @parameterized.expand([TEST_CASE_10]) + def test_values_compare(self, input_data, expected_value): + """ + Compare the results of compute_generalized_dice function and GeneralizedDiceScore class. + """ + result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_batch") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=0), atol=1e-4) + + @parameterized.expand([TEST_CASE_11]) + def test_values_compare_sum_over_classes(self, input_data, expected_value): + """ + Compare the results when summing over classes between compute_generalized_dice function and GeneralizedDiceScore class. 
+ """ + result = compute_generalized_dice(**input_data) + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + input_data.pop("sum_over_classes") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_channel") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=1), atol=1e-4) + if __name__ == "__main__": unittest.main()