From b4def1a0d2f70ce709a5fc0a0f26a37f8af6f831 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 24 Aug 2021 18:19:16 +0100 Subject: [PATCH] all close (#2829) Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_fill_holes.py | 5 ++-- tests/test_flip.py | 7 ++---- tests/test_flipd.py | 7 ++---- .../test_keep_largest_connected_component.py | 6 ++--- tests/test_label_filter.py | 5 ++-- tests/test_rand_axis_flip.py | 8 ++---- tests/test_rand_axis_flipd.py | 7 ++---- tests/test_rand_flip.py | 7 ++---- tests/test_rand_flipd.py | 7 ++---- tests/utils.py | 25 ++++++------------- 10 files changed, 27 insertions(+), 57 deletions(-) diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 294bbd8c87..6ea83c239b 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import FillHoles -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1_raw = [ [1, 1, 1], @@ -278,10 +278,9 @@ def test_correct_results(self, _, args, input_image, expected): converter = FillHoles(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) else: result = converter(clone(input_image)) - assert allclose(result, expected) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_flip.py b/tests/test_flip.py index bd0162fb8b..404a3def7d 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import Flip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - if isinstance(result, torch.Tensor): - result = result.cpu() - self.assertTrue(np.allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index cec4a99cbf..1676723800 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import Flipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -39,9 +38,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() - assert np.allclose(expected, result) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 670dd2d2ee..527d986614 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) @@ -327,10 +327,10 @@ def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) + else: result = converter(clone(input_image)) - assert allclose(result, expected) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py index 9165fddc40..c699fb31fd 100644 --- a/tests/test_label_filter.py +++ b/tests/test_label_filter.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import LabelFilter -from tests.utils import allclose, clone +from tests.utils import assert_allclose, clone grid_1 = torch.tensor( [ @@ -108,10 +108,9 @@ def test_correct_results(self, _, args, input_image, expected): converter = LabelFilter(**args) if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): result = converter(clone(input_image).cuda()) - assert allclose(result, expected.cuda()) else: result = converter(clone(input_image)) - assert allclose(result, expected) + assert_allclose(result, expected) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index bd53fa1fb0..c05c3a1e0d 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -12,10 +12,9 @@ import unittest import numpy as np -import torch from monai.transforms import RandAxisFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class TestRandAxisFlip(NumpyImageTestCase2D): @@ -23,13 +22,10 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlip(prob=1.0) result = flip(p(self.imt[0])) - if isinstance(result, torch.Tensor): - result = result.cpu() - expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result)) + assert_allclose(np.stack(expected), result) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 518d78dd29..7bef0baa63 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -12,10 +12,9 @@ import unittest import numpy as np -import torch from monai.transforms import RandAxisFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, assert_allclose class TestRandAxisFlip(NumpyImageTestCase3D): @@ -23,13 +22,11 @@ def test_correct_results(self): for p in TEST_NDARRAYS: flip = RandAxisFlipd(keys="img", prob=1.0) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result)) + assert_allclose(np.stack(expected), result) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index c20c13fec5..b3c514cb1f 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandFlip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -40,9 +39,7 @@ def test_correct_results(self, _, spatial_axis): expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) result = flip(im) - if isinstance(result, torch.Tensor): - result = result.cpu() - self.assertTrue(np.allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 42c7dfe4b5..8972024fd8 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -12,11 +12,10 @@ import unittest import numpy as np -import torch from parameterized import parameterized from monai.transforms import RandFlipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] @@ -27,13 +26,11 @@ def test_correct_results(self, _, spatial_axis): for p in TEST_NDARRAYS: flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) result = flip({"img": p(self.imt[0])})["img"] - if isinstance(result, torch.Tensor): - result = result.cpu() expected = [] for channel in self.imt[0]: expected.append(np.flip(channel, spatial_axis)) expected = np.stack(expected) - self.assertTrue(np.allclose(expected, result)) + assert_allclose(expected, result) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index 1148af7551..22720849f1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,7 @@ from monai.config import NdarrayTensor from monai.config.deviceconfig import USE_COMPILED +from monai.config.type_definitions import NdarrayOrTensor from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism from monai.utils.module import version_leq @@ -55,27 +56,17 @@ def clone(data: NdarrayTensor) -> NdarrayTensor: return copy.deepcopy(data) -def allclose(a: NdarrayTensor, b: NdarrayTensor) -> bool: +def assert_allclose(a: NdarrayOrTensor, b: NdarrayOrTensor, *args, **kwargs): """ - Check if all values of two data objects are close. - - Note: - This method also checks that both data objects are either Pytorch Tensors or numpy arrays. + Assert that all values of two data objects are close. Args: - a (NdarrayTensor): Pytorch Tensor or numpy array for comparison - b (NdarrayTensor): Pytorch Tensor or numpy array to compare against - - Returns: - bool: If both data objects are close. + a (NdarrayOrTensor): Pytorch Tensor or numpy array for comparison + b (NdarrayOrTensor): Pytorch Tensor or numpy array to compare against """ - if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): - return np.allclose(a, b) - - if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): - return torch.allclose(a, b) - - return False + a = a.cpu() if isinstance(a, torch.Tensor) else a + b = b.cpu() if isinstance(b, torch.Tensor) else b + np.testing.assert_allclose(a, b, *args, **kwargs) def test_pretrained_networks(network, input_param, device):