Label Quality Assessment for Active Learning (#5065)


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Vishwesh Nath <[email protected]>
finalelement authored Sep 9, 2022
1 parent 05de740 commit 15e15c6
Showing 5 changed files with 223 additions and 4 deletions.
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -24,6 +24,13 @@ Metrics
.. autoclass:: VarianceMetric
:members:

`LabelQualityScore`
--------------------
.. autofunction:: label_quality_score

.. autoclass:: LabelQualityScore
:members:

`IterationMetric`
-----------------
.. autoclass:: IterationMetric
2 changes: 1 addition & 1 deletion monai/metrics/__init__.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from .active_learning_metrics import VarianceMetric, compute_variance
+from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
from .cumulative_average import CumulativeAverage
from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
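Since the new names are re-exported at the package level, here is a quick import check (a sketch, not part of the commit; it assumes a MONAI build that contains this change):

```python
# Sketch, not part of the commit: the new metric is importable next to the existing variance metric.
from monai.metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score

print(LabelQualityScore.__name__, label_quality_score.__module__)
```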
87 changes: 85 additions & 2 deletions monai/metrics/active_learning_metrics.py
@@ -67,6 +67,41 @@ def __call__(self, y_pred: Any) -> Any:  # type: ignore
)


class LabelQualityScore(Metric):
    """
    The assumption is that the DL model makes better predictions than the provided label quality, hence the
    difference can be treated as a label quality score.

    It can be combined with variance/uncertainty for active learning frameworks to factor in label quality
    along with uncertainty.

    Args:
        include_background: whether to include the background of the spatial image or channel 0 of the 1-D vector.
        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used; 'none' returns the
            spatial map of absolute differences.
    """

def __init__(self, include_background: bool = True, scalar_reduction: str = "sum") -> None:
super().__init__()
self.include_background = include_background
self.scalar_reduction = scalar_reduction

    def __call__(self, y_pred: Any, y: Any):  # type: ignore
        """
        Args:
            y_pred: Predicted segmentation, typically segmentation model output. It must be a batch-first
                tensor [B,C,H,W,D].
            y: Ground-truth label of the same shape as ``y_pred``.

        Returns:
            Pytorch tensor of the label quality score per batch element, or a spatial map of absolute
            differences when ``scalar_reduction`` is 'none'.
        """
        return label_quality_score(
            y_pred=y_pred, y=y, include_background=self.include_background, scalar_reduction=self.scalar_reduction
        )


def compute_variance(
y_pred: torch.Tensor,
include_background: bool = True,
@@ -77,9 +112,10 @@ def compute_variance(
"""
Args:
        y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for
-        Height, Width & Depth
+            Height, Width & Depth
        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
-        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image dimensions
+        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
+            dimensions
scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
threshold: To avoid NaN's a threshold is used to replace zero's
Returns:
@@ -123,3 +159,50 @@ def compute_variance(
elif scalar_reduction == "sum":
var_sum = torch.sum(variance)
return var_sum


def label_quality_score(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, scalar_reduction: str = "mean"
):
"""
    The assumption is that the DL model makes better predictions than the provided label quality, hence the
    difference can be treated as a label quality score.
Args:
y_pred: Input data of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is
channels and H, W, D stand for Height, Width & Depth
y: Ground Truth of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is channels
and H, W, D stand for Height, Width & Depth
include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single scalar
value, if set to 'none' a spatial map will be returned
    Returns:
        An absolute-difference score per batch element, reduced by 'sum' or 'mean' over the non-batch dimensions,
        or the spatial map of absolute differences when scalar_reduction is 'none'.
"""

    # ignore_background operates on the channel dimension; here the first dimension is the batch dimension
    # (unlike compute_variance, where it is the number of repeats)
y_pred = y_pred.float()
y = y.float()

if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

n_len = len(y_pred.shape)
if n_len < 4 and scalar_reduction == "none":
        warnings.warn("Reduction set to 'none': returning a spatial map requires a 2D/3D image with batch and channel dimensions")
return None

abs_diff_map = torch.abs(y_pred - y)

    if scalar_reduction == "none":
        return abs_diff_map

    if scalar_reduction == "mean":
        lbl_score_mean = torch.mean(abs_diff_map, dim=list(range(1, n_len)))
        return lbl_score_mean

    if scalar_reduction == "sum":
        lbl_score_sum = torch.sum(abs_diff_map, dim=list(range(1, n_len)))
        return lbl_score_sum
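
For orientation, here is a minimal usage sketch of the new API (not part of the diff); it assumes a MONAI build that includes this commit and uses small made-up tensors:

```python
# Sketch, not part of the commit: exercising the new label quality metric.
import torch

from monai.metrics import LabelQualityScore, label_quality_score

# y_pred / y are batch-first [B, C, H, W]; here B=2, C=1, H=W=2.
y_pred = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [0.0, 1.0]]]])
y = torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[0.0, 0.0], [0.0, 0.0]]]])

# Functional API: one score per batch element (sum/mean of |y_pred - y| over non-batch dims).
print(label_quality_score(y_pred=y_pred, y=y, scalar_reduction="sum"))   # tensor([0., 2.])
print(label_quality_score(y_pred=y_pred, y=y, scalar_reduction="mean"))  # tensor([0.0000, 0.5000])

# scalar_reduction="none" keeps the full absolute-difference map.
diff_map = label_quality_score(y_pred=y_pred, y=y, scalar_reduction="none")
print(diff_map.shape)  # torch.Size([2, 1, 2, 2])

# Class-style API, mirroring the tests added below.
score = LabelQualityScore(include_background=True, scalar_reduction="sum")
print(score(y_pred, y))  # tensor([0., 2.])
```

A low score suggests the provided label agrees with the model's prediction; as the docstring notes, it can be combined with variance/uncertainty from repeated predictions when ranking samples in an active learning loop.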
1 change: 0 additions & 1 deletion tests/test_bundle_verify_net.py
@@ -36,7 +36,6 @@ def test_verify(self, meta_file, config_file):
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"]
cmd += [meta_file, "--config_file", config_file, "-n", "4", "--any", "16", "--args_file", def_args_file]
cmd += ["--device", "cpu", "--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
cmd += ["--network_def#_requires_", "$monai.config.print_debug_info()"]
command_line_tests(cmd)


130 changes: 130 additions & 0 deletions tests/test_label_quality_score.py
@@ -0,0 +1,130 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.metrics import LabelQualityScore, label_quality_score

_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# keep background, 1D Case
TEST_CASE_1 = [ # y_pred (3, 1, 3), expected out (0.0)
{
"y_pred": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),
"y": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),
"include_background": True,
"scalar_reduction": "sum",
},
[0.0, 0.0, 0.0],
]

# keep background, 2D Case
TEST_CASE_2 = [ # y_pred (1, 1, 2, 2), expected out (0.0)
{
"y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
"scalar_reduction": "sum",
},
[0.0],
]

# keep background, 3D Case
TEST_CASE_3 = [ # y_pred (1, 1, 1, 2, 2), expected out (0.0)
{
"y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
"y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
"include_background": True,
"scalar_reduction": "sum",
},
[0.0],
]

# keep background, 2D Case
TEST_CASE_4 = [  # y_pred (1, 1, 2, 2), expected out (4.0)
{
"y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),
"include_background": True,
"scalar_reduction": "sum",
},
[4.0],
]

TEST_CASE_5 = [  # y_pred (1, 1, 2, 2), expected out (1.0)
{
"y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),
"include_background": True,
"scalar_reduction": "mean",
},
[1.0],
]

# Spatial Map Test Case for 3D Case
TEST_CASE_6 = [ # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2
{
"y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
"y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
"include_background": True,
"scalar_reduction": "none",
},
[[[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]],
]

# Spatial Map Test Case for 2D Case
TEST_CASE_7 = [  # y_pred (1, 2, 2, 2), expected out all (0.0) map of 2x2x2
{
"y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
"scalar_reduction": "none",
},
[[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]],
]


class TestLabelQualityScore(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value(self, input_data, expected_value):
result = label_quality_score(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_6, TEST_CASE_7])
def test_spatial_case(self, input_data, expected_value):
result = label_quality_score(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value_class(self, input_data, expected_value):
vals = {}
vals["y_pred"] = input_data.pop("y_pred")
vals["y"] = input_data.pop("y")
comp_var = LabelQualityScore(**input_data)
result = comp_var(**vals)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)

@parameterized.expand([TEST_CASE_6, TEST_CASE_7])
def test_spatial_case_class(self, input_data, expected_value):
vals = {}
vals["y_pred"] = input_data.pop("y_pred")
vals["y"] = input_data.pop("y")
comp_var = LabelQualityScore(**input_data)
result = comp_var(**vals)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)


if __name__ == "__main__":
unittest.main()
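
The new test module is self-contained; a sketch of running it programmatically (assuming it is executed from the repository root, or via `./runtests.sh` as in the checklist above):

```python
# Sketch, not part of the commit: load and run the new tests directly.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_label_quality_score")
unittest.TextTestRunner(verbosity=2).run(suite)
```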
