Skip to content

Commit

Permalink
all close (#2829)
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Brown <[email protected]>
  • Loading branch information
rijobro authored Aug 24, 2021
1 parent fab8467 commit b4def1a
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 57 deletions.
5 changes: 2 additions & 3 deletions tests/test_fill_holes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from parameterized import parameterized

from monai.transforms import FillHoles
from tests.utils import allclose, clone
from tests.utils import assert_allclose, clone

grid_1_raw = [
[1, 1, 1],
Expand Down Expand Up @@ -278,10 +278,9 @@ def test_correct_results(self, _, args, input_image, expected):
converter = FillHoles(**args)
if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():
result = converter(clone(input_image).cuda())
assert allclose(result, expected.cuda())
else:
result = converter(clone(input_image))
assert allclose(result, expected)
assert_allclose(result, expected)

@parameterized.expand(INVALID_CASES)
def test_raise_exception(self, _, args, input_image, expected_error):
Expand Down
7 changes: 2 additions & 5 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Flip
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis):
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip(im)
if isinstance(result, torch.Tensor):
result = result.cpu()
self.assertTrue(np.allclose(expected, result))
assert_allclose(expected, result)


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions tests/test_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Flipd
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -39,9 +38,7 @@ def test_correct_results(self, _, spatial_axis):
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip({"img": p(self.imt[0])})["img"]
if isinstance(result, torch.Tensor):
result = result.cpu()
assert np.allclose(expected, result)
assert_allclose(expected, result)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions tests/test_keep_largest_connected_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from parameterized import parameterized

from monai.transforms import KeepLargestConnectedComponent
from tests.utils import allclose, clone
from tests.utils import assert_allclose, clone

grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])
grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])
Expand Down Expand Up @@ -327,10 +327,10 @@ def test_correct_results(self, _, args, input_image, expected):
converter = KeepLargestConnectedComponent(**args)
if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():
result = converter(clone(input_image).cuda())
assert allclose(result, expected.cuda())

else:
result = converter(clone(input_image))
assert allclose(result, expected)
assert_allclose(result, expected)

@parameterized.expand(INVALID_CASES)
def test_raise_exception(self, _, args, input_image, expected_error):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_label_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from parameterized import parameterized

from monai.transforms import LabelFilter
from tests.utils import allclose, clone
from tests.utils import assert_allclose, clone

grid_1 = torch.tensor(
[
Expand Down Expand Up @@ -108,10 +108,9 @@ def test_correct_results(self, _, args, input_image, expected):
converter = LabelFilter(**args)
if isinstance(input_image, torch.Tensor) and torch.cuda.is_available():
result = converter(clone(input_image).cuda())
assert allclose(result, expected.cuda())
else:
result = converter(clone(input_image))
assert allclose(result, expected)
assert_allclose(result, expected)

@parameterized.expand(INVALID_CASES)
def test_raise_exception(self, _, args, input_image, expected_error):
Expand Down
8 changes: 2 additions & 6 deletions tests/test_rand_axis_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@
import unittest

import numpy as np
import torch

from monai.transforms import RandAxisFlip
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


class TestRandAxisFlip(NumpyImageTestCase2D):
def test_correct_results(self):
for p in TEST_NDARRAYS:
flip = RandAxisFlip(prob=1.0)
result = flip(p(self.imt[0]))
if isinstance(result, torch.Tensor):
result = result.cpu()

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result))
assert_allclose(np.stack(expected), result)


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions tests/test_rand_axis_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,21 @@
import unittest

import numpy as np
import torch

from monai.transforms import RandAxisFlipd
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, assert_allclose


class TestRandAxisFlip(NumpyImageTestCase3D):
def test_correct_results(self):
for p in TEST_NDARRAYS:
flip = RandAxisFlipd(keys="img", prob=1.0)
result = flip({"img": p(self.imt[0])})["img"]
if isinstance(result, torch.Tensor):
result = result.cpu()

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result))
assert_allclose(np.stack(expected), result)


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions tests/test_rand_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import RandFlip
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis):
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip(im)
if isinstance(result, torch.Tensor):
result = result.cpu()
self.assertTrue(np.allclose(expected, result))
assert_allclose(expected, result)


if __name__ == "__main__":
Expand Down
7 changes: 2 additions & 5 deletions tests/test_rand_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import RandFlipd
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose

VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])]

Expand All @@ -27,13 +26,11 @@ def test_correct_results(self, _, spatial_axis):
for p in TEST_NDARRAYS:
flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis)
result = flip({"img": p(self.imt[0])})["img"]
if isinstance(result, torch.Tensor):
result = result.cpu()
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
self.assertTrue(np.allclose(expected, result))
assert_allclose(expected, result)


if __name__ == "__main__":
Expand Down
25 changes: 8 additions & 17 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from monai.config import NdarrayTensor
from monai.config.deviceconfig import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data import create_test_image_2d, create_test_image_3d
from monai.utils import ensure_tuple, optional_import, set_determinism
from monai.utils.module import version_leq
Expand All @@ -55,27 +56,17 @@ def clone(data: NdarrayTensor) -> NdarrayTensor:
return copy.deepcopy(data)


def allclose(a: NdarrayTensor, b: NdarrayTensor) -> bool:
def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs):
"""
Check if all values of two data objects are close.
Note:
This method also checks that both data objects are either Pytorch Tensors or numpy arrays.
Assert that all values of two data objects are close.
Args:
a (NdarrayTensor): Pytorch Tensor or numpy array for comparison
b (NdarrayTensor): Pytorch Tensor or numpy array to compare against
Returns:
bool: If both data objects are close.
a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison
b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against
"""
if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
return np.allclose(a, b)

if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
return torch.allclose(a, b)

return False
a = a.cpu() if isinstance(a, torch.Tensor) else a
b = b.cpu() if isinstance(b, torch.Tensor) else b
np.testing.assert_allclose(a, b, *args, **kwargs)


def test_pretrained_networks(network, input_param, device):
Expand Down

0 comments on commit b4def1a

Please sign in to comment.