Skip to content

Commit

Permalink
Normalize intensity (#2831)
Browse files Browse the repository at this point in the history
* all close

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* assert_allclose

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* NormalizeIntensity

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
rijobro authored Aug 24, 2021
1 parent b4def1a commit fe559e5
Showing 4 changed files with 189 additions and 95 deletions.
52 changes: 40 additions & 12 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
@@ -539,10 +539,12 @@ class NormalizeIntensity(Transform):
dtype: output data type, defaults to float32.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(
self,
subtrahend: Union[Sequence, np.ndarray, None] = None,
divisor: Union[Sequence, np.ndarray, None] = None,
subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
divisor: Union[Sequence, NdarrayOrTensor, None] = None,
nonzero: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
@@ -553,26 +555,51 @@ def __init__(
self.channel_wise = channel_wise
self.dtype = dtype

def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray:
slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool)
if not np.any(slices):
@staticmethod
def _mean(x):
if isinstance(x, np.ndarray):
return np.mean(x)
x = torch.mean(x.float())
return x.item() if x.numel() == 1 else x

@staticmethod
def _std(x):
if isinstance(x, np.ndarray):
return np.std(x)
x = torch.std(x.float(), unbiased=False)
return x.item() if x.numel() == 1 else x

def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
img, *_ = convert_data_type(img, dtype=torch.float32)

if self.nonzero:
slices = img != 0
else:
if isinstance(img, np.ndarray):
slices = np.ones_like(img, dtype=bool)
else:
slices = torch.ones_like(img, dtype=torch.bool)
if not slices.any():
return img

_sub = sub if sub is not None else np.mean(img[slices])
if isinstance(_sub, np.ndarray):
_sub = sub if sub is not None else self._mean(img[slices])
if isinstance(_sub, (torch.Tensor, np.ndarray)):
_sub, *_ = convert_to_dst_type(_sub, img)
_sub = _sub[slices]

_div = div if div is not None else np.std(img[slices])
_div = div if div is not None else self._std(img[slices])
if np.isscalar(_div):
if _div == 0.0:
_div = 1.0
elif isinstance(_div, np.ndarray):
elif isinstance(_div, (torch.Tensor, np.ndarray)):
_div, *_ = convert_to_dst_type(_div, img)
_div = _div[slices]
_div[_div == 0.0] = 1.0

img[slices] = (img[slices] - _sub) / _div
return img

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
"""
@@ -583,15 +610,16 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

for i, d in enumerate(img):
img[i] = self._normalize(
img[i] = self._normalize( # type: ignore
d,
sub=self.subtrahend[i] if self.subtrahend is not None else None,
div=self.divisor[i] if self.divisor is not None else None,
)
else:
img = self._normalize(img, self.subtrahend, self.divisor)

return img.astype(self.dtype)
out, *_ = convert_data_type(img, dtype=self.dtype)
return out


class ThresholdIntensity(Transform):
8 changes: 5 additions & 3 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
@@ -612,11 +612,13 @@ class NormalizeIntensityd(MapTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = NormalizeIntensity.backend

def __init__(
self,
keys: KeysCollection,
subtrahend: Optional[np.ndarray] = None,
divisor: Optional[np.ndarray] = None,
subtrahend: Optional[NdarrayOrTensor] = None,
divisor: Optional[NdarrayOrTensor] = None,
nonzero: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
@@ -625,7 +627,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.normalizer(d[key])
139 changes: 90 additions & 49 deletions tests/test_normalize_intensity.py
Original file line number Diff line number Diff line change
@@ -12,70 +12,111 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import NormalizeIntensity
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

TEST_CASES = [
[{"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])],
[
{"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True},
np.array([0.0, 3.0, 0.0, 4.0]),
np.array([0.0, -1.0, 0.0, 1.0]),
],
[{"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])],
[{"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])],
[{"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])],
[
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]},
np.ones((3, 2, 2)),
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]),
],
[
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]},
np.ones((3, 2, 2)),
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]),
],
[
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * -1.0,
],
[
{"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * 0.5,
],
[
{"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * 0.5,
],
]
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append([p, {"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])])
for q in TEST_NDARRAYS:
for u in TEST_NDARRAYS:
TESTS.append(
[
p,
{
"subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])),
"divisor": u(np.array([0.5, 0.5, 0.5, 0.5])),
"nonzero": True,
},
np.array([0.0, 3.0, 0.0, 4.0]),
np.array([0.0, -1.0, 0.0, 1.0]),
]
)
TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])])
TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])])
TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])])
TESTS.append(
[
p,
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]},
np.ones((3, 2, 2)),
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]),
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]},
np.ones((3, 2, 2)),
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]),
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * -1.0,
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * 0.5,
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]},
np.ones((3, 2, 2)),
np.ones((3, 2, 2)) * 0.5,
]
)


class TestNormalizeIntensity(NumpyImageTestCase2D):
def test_default(self):
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_default(self, im_type):
im = im_type(self.imt.copy())
normalizer = NormalizeIntensity()
normalized = normalizer(self.imt.copy())
self.assertTrue(normalized.dtype == np.float32)
normalized = normalizer(im)
self.assertEqual(type(im), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(im.device, normalized.device)
self.assertTrue(normalized.dtype in (np.float32, torch.float32))
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
np.testing.assert_allclose(normalized, expected, rtol=1e-3)
assert_allclose(expected, normalized, rtol=1e-3)

@parameterized.expand(TEST_CASES)
def test_nonzero(self, input_param, input_data, expected_data):
@parameterized.expand(TESTS)
def test_nonzero(self, in_type, input_param, input_data, expected_data):
normalizer = NormalizeIntensity(**input_param)
np.testing.assert_allclose(expected_data, normalizer(input_data))
im = in_type(input_data)
normalized = normalizer(im)
self.assertEqual(type(im), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(im.device, normalized.device)
assert_allclose(expected_data, normalized)

def test_channel_wise(self):
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, im_type):
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])
np.testing.assert_allclose(expected, normalizer(input_data))
normalized = normalizer(input_data)
self.assertEqual(type(input_data), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(input_data.device, normalized.device)
assert_allclose(expected, normalized)

def test_value_errors(self):
input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value_errors(self, im_type):
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1])
with self.assertRaises(ValueError):
normalizer(input_data)
85 changes: 54 additions & 31 deletions tests/test_normalize_intensityd.py
Original file line number Diff line number Diff line change
@@ -12,54 +12,77 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import NormalizeIntensityd
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

TEST_CASE_1 = [
{"keys": ["img"], "nonzero": True},
{"img": np.array([0.0, 3.0, 0.0, 4.0])},
np.array([0.0, -1.0, 0.0, 1.0]),
]

TEST_CASE_2 = [
{
"keys": ["img"],
"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]),
"divisor": np.array([0.5, 0.5, 0.5, 0.5]),
"nonzero": True,
},
{"img": np.array([0.0, 3.0, 0.0, 4.0])},
np.array([0.0, -1.0, 0.0, 1.0]),
]

TEST_CASE_3 = [
{"keys": ["img"], "nonzero": True},
{"img": np.array([0.0, 0.0, 0.0, 0.0])},
np.array([0.0, 0.0, 0.0, 0.0]),
]
TESTS = []
for p in TEST_NDARRAYS:
for q in TEST_NDARRAYS:
TESTS.append(
[
{"keys": ["img"], "nonzero": True},
{"img": p(np.array([0.0, 3.0, 0.0, 4.0]))},
np.array([0.0, -1.0, 0.0, 1.0]),
]
)
TESTS.append(
[
{
"keys": ["img"],
"subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])),
"divisor": q(np.array([0.5, 0.5, 0.5, 0.5])),
"nonzero": True,
},
{"img": p(np.array([0.0, 3.0, 0.0, 4.0]))},
np.array([0.0, -1.0, 0.0, 1.0]),
]
)
TESTS.append(
[
{"keys": ["img"], "nonzero": True},
{"img": p(np.array([0.0, 0.0, 0.0, 0.0]))},
np.array([0.0, 0.0, 0.0, 0.0]),
]
)


class TestNormalizeIntensityd(NumpyImageTestCase2D):
def test_image_normalize_intensityd(self):
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_image_normalize_intensityd(self, im_type):
key = "img"
im = im_type(self.imt)
normalizer = NormalizeIntensityd(keys=[key])
normalized = normalizer({key: self.imt})
normalized = normalizer({key: im})[key]
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
np.testing.assert_allclose(normalized[key], expected, rtol=1e-3)
self.assertEqual(type(im), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(im.device, normalized.device)
assert_allclose(normalized, expected, rtol=1e-3)

@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand(TESTS)
def test_nonzero(self, input_param, input_data, expected_data):
key = "img"
normalizer = NormalizeIntensityd(**input_param)
np.testing.assert_allclose(expected_data, normalizer(input_data)["img"])
normalized = normalizer(input_data)[key]
self.assertEqual(type(input_data[key]), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(input_data[key].device, normalized.device)
assert_allclose(normalized, expected_data)

def test_channel_wise(self):
@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, im_type):
key = "img"
normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True)
input_data = {key: np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])}
input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))}
normalized = normalizer(input_data)[key]
self.assertEqual(type(input_data[key]), type(normalized))
if isinstance(normalized, torch.Tensor):
self.assertEqual(input_data[key].device, normalized.device)
expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])
np.testing.assert_allclose(expected, normalizer(input_data)[key])
assert_allclose(normalized, expected)


if __name__ == "__main__":

0 comments on commit fe559e5

Please sign in to comment.