From 15e15c64179ebe0c27564078cf4053af35cf6288 Mon Sep 17 00:00:00 2001
From: Vishwesh
Date: Fri, 9 Sep 2022 10:27:51 -0500
Subject: [PATCH] Label Quality Assessment for Active Learning (#5065)

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Vishwesh Nath
---
 docs/source/metrics.rst                  |   7 ++
 monai/metrics/__init__.py                |   2 +-
 monai/metrics/active_learning_metrics.py |  87 ++++++++++++++-
 tests/test_bundle_verify_net.py          |   1 -
 tests/test_label_quality_score.py        | 130 +++++++++++++++++++++++
 5 files changed, 223 insertions(+), 4 deletions(-)
 create mode 100644 tests/test_label_quality_score.py

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 24db04d062..9ba5fa0607 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -24,6 +24,13 @@ Metrics
 .. autoclass:: VarianceMetric
     :members:
 
+`LabelQualityScore`
+--------------------
+.. autofunction:: label_quality_score
+
+.. autoclass:: LabelQualityScore
+    :members:
+
 `IterationMetric`
 -----------------
 .. autoclass:: IterationMetric
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 8c4e148f5a..ac5de7e71b 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .active_learning_metrics import VarianceMetric, compute_variance
+from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
diff --git a/monai/metrics/active_learning_metrics.py b/monai/metrics/active_learning_metrics.py
index 26c966ef52..ad62935f44 100644
--- a/monai/metrics/active_learning_metrics.py
+++ b/monai/metrics/active_learning_metrics.py
@@ -67,6 +67,41 @@ def __call__(self, y_pred: Any) -> Any:  # type: ignore
         )
 
 
+class LabelQualityScore(Metric):
+    """
+    The assumption is that the DL model makes better predictions than the quality of the provided labels; the
+    absolute difference between prediction and label can therefore be treated as a label quality score.
+
+    It can be combined with variance/uncertainty in active learning frameworks to factor in the quality of the label
+    along with the uncertainty.
+
+    Args:
+        include_background: whether to include the background of the spatial image or channel 0 of the 1-D vector
+        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single
+            scalar value per batch item
+
+    """
+
+    def __init__(self, include_background: bool = True, scalar_reduction: str = "sum") -> None:
+        super().__init__()
+        self.include_background = include_background
+        self.scalar_reduction = scalar_reduction
+
+    def __call__(self, y_pred: Any, y: Any):  # type: ignore
+        """
+        Args:
+            y_pred: Predicted segmentation, typically segmentation model output.
+                It must be a batch-first tensor [B, C, H, W, D].
+            y: Ground truth segmentation, must be of the same shape as ``y_pred``.
+
+        Returns:
+            PyTorch tensor of a scalar label quality score per batch item, or a spatial map of absolute differences
+
+        """
+        return label_quality_score(
+            y_pred=y_pred, y=y, include_background=self.include_background, scalar_reduction=self.scalar_reduction
+        )
+
 def compute_variance(
     y_pred: torch.Tensor,
     include_background: bool = True,
@@ -77,9 +112,10 @@ def compute_variance(
     """
     Args:
         y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for
-            Height, Width & Depth
+            Height, Width & Depth
         include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
-        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image dimensions
+        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
+            dimensions
         scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
         threshold: To avoid NaN's a threshold is used to replace zero's
     Returns:
@@ -123,3 +159,50 @@ def compute_variance(
     elif scalar_reduction == "sum":
         var_sum = torch.sum(variance)
         return var_sum
+
+
+def label_quality_score(
+    y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, scalar_reduction: str = "mean"
+):
+    """
+    The assumption is that the DL model makes better predictions than the quality of the provided labels; the
+    absolute difference between prediction and label can therefore be treated as a label quality score.
+
+    Args:
+        y_pred: Input data of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is
+            channels and H, W, D stand for Height, Width & Depth
+        y: Ground Truth of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is channels
+            and H, W, D stand for Height, Width & Depth
+        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
+        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single scalar
+            value, if set to 'none' a spatial map will be returned
+
+    Returns:
+        A single scalar absolute-difference score per batch item after a sum/mean reduction, or the spatial map of
+        absolute differences
+    """
+
+    # unlike compute_variance, the first dimension here is the batch dimension, so ignore_background applies directly
+    y_pred = y_pred.float()
+    y = y.float()
+
+    if not include_background:
+        y_pred, y = ignore_background(y_pred=y_pred, y=y)
+
+    n_len = len(y_pred.shape)
+    if n_len < 4 and scalar_reduction == "none":
+        warnings.warn("Reduction 'none' requires a 2D/3D image with batch (B) and channel (C) dimensions; returning None")
+        return None
+
+    abs_diff_map = torch.abs(y_pred - y)
+
+    if scalar_reduction == "none":
+        return abs_diff_map
+
+    # reduce the absolute-difference map over all dimensions except the batch dimension
+    if scalar_reduction == "mean":
+        lbl_score_mean = torch.mean(abs_diff_map, dim=list(range(1, n_len)))
+        return lbl_score_mean
+    elif scalar_reduction == "sum":
+        lbl_score_sum = torch.sum(abs_diff_map, dim=list(range(1, n_len)))
+        return lbl_score_sum
diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py
index 04b95a4828..e2be20a32b 100644
--- a/tests/test_bundle_verify_net.py
+++ b/tests/test_bundle_verify_net.py
@@ -36,7 +36,6 @@ def test_verify(self, meta_file, config_file):
         cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"]
         cmd += [meta_file, "--config_file", config_file, "-n", "4", "--any", "16", "--args_file", def_args_file]
         cmd += ["--device", "cpu", "--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
"--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"] - cmd += ["--network_def#_requires_", "$monai.config.print_debug_info()"] command_line_tests(cmd) diff --git a/tests/test_label_quality_score.py b/tests/test_label_quality_score.py new file mode 100644 index 0000000000..db31624a95 --- /dev/null +++ b/tests/test_label_quality_score.py @@ -0,0 +1,130 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.metrics import LabelQualityScore, label_quality_score + +_device = "cuda:0" if torch.cuda.is_available() else "cpu" + +# keep background, 1D Case +TEST_CASE_1 = [ # y_pred (3, 1, 3), expected out (0.0) + { + "y_pred": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device), + "y": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device), + "include_background": True, + "scalar_reduction": "sum", + }, + [0.0, 0.0, 0.0], +] + +# keep background, 2D Case +TEST_CASE_2 = [ # y_pred (1, 1, 2, 2), expected out (0.0) + { + "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "include_background": True, + "scalar_reduction": "sum", + }, + [0.0], +] + +# keep background, 3D Case +TEST_CASE_3 = [ # y_pred (1, 1, 1, 2, 2), expected out (0.0) + { + "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device), + "y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device), + "include_background": True, + "scalar_reduction": "sum", + }, + [0.0], +] + +# keep background, 2D Case +TEST_CASE_4 = [ # y_pred (1, 1, 2, 2), expected out (0.0) + { + "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device), + "include_background": True, + "scalar_reduction": "sum", + }, + [4.0], +] + +TEST_CASE_5 = [ # y_pred (1, 1, 2, 2), expected out (0.0) + { + "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device), + "include_background": True, + "scalar_reduction": "mean", + }, + [1.0], +] + +# Spatial Map Test Case for 3D Case +TEST_CASE_6 = [ # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2 + { + "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device), + "y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device), + "include_background": True, + "scalar_reduction": "none", + }, + [[[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]], +] + +# Spatial Map Test Case for 2D Case +TEST_CASE_7 = [ # y_pred (1, 1, 2, 2) + { + "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device), + "include_background": True, + "scalar_reduction": "none", + 
+    },
+    [[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]],
+]
+
+
+class TestLabelQualityScore(unittest.TestCase):
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+    def test_value(self, input_data, expected_value):
+        result = label_quality_score(**input_data)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+    @parameterized.expand([TEST_CASE_6, TEST_CASE_7])
+    def test_spatial_case(self, input_data, expected_value):
+        result = label_quality_score(**input_data)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+    def test_value_class(self, input_data, expected_value):
+        vals = {}
+        vals["y_pred"] = input_data.pop("y_pred")
+        vals["y"] = input_data.pop("y")
+        lbl_score = LabelQualityScore(**input_data)
+        result = lbl_score(**vals)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+    @parameterized.expand([TEST_CASE_6, TEST_CASE_7])
+    def test_spatial_case_class(self, input_data, expected_value):
+        vals = {}
+        vals["y_pred"] = input_data.pop("y_pred")
+        vals["y"] = input_data.pop("y")
+        lbl_score = LabelQualityScore(**input_data)
+        result = lbl_score(**vals)
+        np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
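
### Example usage

A quick sketch of the new metric for reviewers, mirroring the expected values in `tests/test_label_quality_score.py` (the tensors are toy data, not from the PR):

```python
import torch

from monai.metrics import LabelQualityScore, label_quality_score

# toy 2D batch: predictions and labels of shape [B, C, H, W]
y_pred = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]])
y = torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]])

# functional API: per-batch-item sum of absolute differences -> tensor([4.])
print(label_quality_score(y_pred=y_pred, y=y, scalar_reduction="sum"))

# class API: mean reduction -> tensor([1.])
metric = LabelQualityScore(include_background=True, scalar_reduction="mean")
print(metric(y_pred=y_pred, y=y))

# scalar_reduction="none" returns the spatial map of absolute differences instead
print(label_quality_score(y_pred=y_pred, y=y, scalar_reduction="none"))
```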
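
### Combining with variance for active learning

The class docstring notes that the score can be paired with variance/uncertainty in an active learning loop. Below is a minimal sketch of one such pairing, assuming N stochastic forward passes (e.g. MC dropout) are available; the random tensors and the unweighted sum are illustrative choices, not something this PR prescribes:

```python
import torch

from monai.metrics import compute_variance, label_quality_score

# hypothetical data: N=10 stochastic forward passes in repeat-first layout [N, C, H, W]
repeats = torch.rand(10, 1, 8, 8)
y_pred = repeats.mean(dim=0, keepdim=True)  # consensus prediction [1, C, H, W]
y = (torch.rand(1, 1, 8, 8) > 0.5).float()  # provided (possibly noisy) label

uncertainty = compute_variance(y_pred=repeats, spatial_map=False, scalar_reduction="mean")
quality = label_quality_score(y_pred=y_pred, y=y, scalar_reduction="mean")

# one possible acquisition signal: prioritize samples that are both uncertain
# and poorly labeled (the equal weighting here is an arbitrary choice)
acquisition_score = uncertainty + quality
print(acquisition_score)
```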