diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 5d6b4d87fd..58f3526086 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -9,9 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Optional, Tuple +from typing import Hashable, Optional, Tuple -import numpy as np import torch from monai.transforms.transform import RandomizableTransform, Transform @@ -113,7 +112,7 @@ def pop_transform(self, data: dict, key: Hashable) -> None: """Remove most recent transform.""" data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 86f0e84249..a7d93f88f3 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -20,6 +20,7 @@ import torch from monai.config import USE_COMPILED, DtypeLike +from monai.config.type_definitions import NdarrayOrTensor from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.transforms.croppad.array import CenterSpatialCrop @@ -45,6 +46,7 @@ issequenceiterable, optional_import, ) +from monai.utils.enums import TransformBackends from monai.utils.module import look_up_option nib, _ = optional_import("nibabel") @@ -317,17 +319,20 @@ class Flip(Transform): """ + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - - result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) - return result.astype(img.dtype) + if isinstance(img, np.ndarray): + return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))) + else: + return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) class Resize(Transform): @@ -800,11 +805,13 @@ class RandFlip(RandomizableTransform): spatial_axis: Spatial axes along which to flip over. Default is None. """ + backend = Flip.backend + def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -826,15 +833,17 @@ class RandAxisFlip(RandomizableTransform): """ + backend = Flip.backend + def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: np.ndarray) -> np.ndarray: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d953fd63ea..b0558a6556 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -23,6 +23,7 @@ import torch from monai.config import DtypeLike, KeysCollection +from monai.config.type_definitions import NdarrayOrTensor from monai.networks.layers import AffineTransform from monai.networks.layers.simplelayers import GaussianFilter from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad @@ -1128,6 +1129,8 @@ class Flipd(MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Flip.backend + def __init__( self, keys: KeysCollection, @@ -1137,20 +1140,17 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): self.push_transform(d, key) d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): _ = self.get_most_recent_transform(d, key) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1173,6 +1173,8 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. """ + backend = Flip.backend + def __init__( self, keys: KeysCollection, @@ -1186,7 +1188,7 @@ def __init__( self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: self.randomize(None) d = dict(data) for key in self.key_iterator(d): @@ -1195,15 +1197,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = self.flipper(d[key]) # Remove the applied transform @@ -1225,16 +1224,18 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): """ + backend = Flip.backend + def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None - def randomize(self, data: np.ndarray) -> None: + def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) self._axis = self.R.randint(data.ndim - 1) - def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: self.randomize(data=data[self.keys[0]]) flipper = Flip(spatial_axis=self._axis) @@ -1245,16 +1246,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda self.push_transform(d, key, extra_info={"axis": self._axis}) return d - def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) if transform[InverseKeys.DO_TRANSFORM]: flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) - # Might need to convert to numpy - if isinstance(d[key], torch.Tensor): - d[key] = torch.Tensor(d[key]).cpu().numpy() # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform diff --git a/tests/test_flip.py b/tests/test_flip.py index fe169c4da8..bd0162fb8b 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -12,10 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Flip -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -31,12 +32,17 @@ def test_invalid_inputs(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): - flip = Flip(spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) - expected = np.stack(expected) - self.assertTrue(np.allclose(expected, flip(self.imt[0]))) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + flip = Flip(spatial_axis=spatial_axis) + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + result = flip(im) + if isinstance(result, torch.Tensor): + result = result.cpu() + self.assertTrue(np.allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index b8996dee42..cec4a99cbf 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -12,10 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import Flipd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -31,13 +32,16 @@ def test_invalid_cases(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): - flip = Flipd(keys="img", spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) - expected = np.stack(expected) - res = flip({"img": self.imt[0]}) - assert np.allclose(expected, res["img"]) + for p in TEST_NDARRAYS: + flip = Flipd(keys="img", spatial_axis=spatial_axis) + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + result = flip({"img": p(self.imt[0])})["img"] + if isinstance(result, torch.Tensor): + result = result.cpu() + assert np.allclose(expected, result) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py index 0bc2eb130e..bd53fa1fb0 100644 --- a/tests/test_rand_axis_flip.py +++ b/tests/test_rand_axis_flip.py @@ -12,20 +12,24 @@ import unittest import numpy as np +import torch from monai.transforms import RandAxisFlip -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D class TestRandAxisFlip(NumpyImageTestCase2D): def test_correct_results(self): - flip = RandAxisFlip(prob=1.0) - result = flip(self.imt[0]) - - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result)) + for p in TEST_NDARRAYS: + flip = RandAxisFlip(prob=1.0) + result = flip(p(self.imt[0])) + if isinstance(result, torch.Tensor): + result = result.cpu() + + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, flip._axis)) + self.assertTrue(np.allclose(np.stack(expected), result)) if __name__ == "__main__": diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py index 154d7813cb..518d78dd29 100644 --- a/tests/test_rand_axis_flipd.py +++ b/tests/test_rand_axis_flipd.py @@ -12,20 +12,24 @@ import unittest import numpy as np +import torch from monai.transforms import RandAxisFlipd -from tests.utils import NumpyImageTestCase3D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D class TestRandAxisFlip(NumpyImageTestCase3D): def test_correct_results(self): - flip = RandAxisFlipd(keys="img", prob=1.0) - result = flip({"img": self.imt[0]}) - - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, flip._axis)) - self.assertTrue(np.allclose(np.stack(expected), result["img"])) + for p in TEST_NDARRAYS: + flip = RandAxisFlipd(keys="img", prob=1.0) + result = flip({"img": p(self.imt[0])})["img"] + if isinstance(result, torch.Tensor): + result = result.cpu() + + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, flip._axis)) + self.assertTrue(np.allclose(np.stack(expected), result)) if __name__ == "__main__": diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py index b7a019136c..c20c13fec5 100644 --- a/tests/test_rand_flip.py +++ b/tests/test_rand_flip.py @@ -12,10 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandFlip -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -31,12 +32,17 @@ def test_invalid_inputs(self, _, spatial_axis, raises): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): - flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) - expected = np.stack(expected) - self.assertTrue(np.allclose(expected, flip(self.imt[0]))) + for p in TEST_NDARRAYS: + im = p(self.imt[0]) + flip = RandFlip(prob=1.0, spatial_axis=spatial_axis) + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + result = flip(im) + if isinstance(result, torch.Tensor): + result = result.cpu() + self.assertTrue(np.allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py index 7bbd15f04c..42c7dfe4b5 100644 --- a/tests/test_rand_flipd.py +++ b/tests/test_rand_flipd.py @@ -12,10 +12,11 @@ import unittest import numpy as np +import torch from parameterized import parameterized from monai.transforms import RandFlipd -from tests.utils import NumpyImageTestCase2D +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])] @@ -23,13 +24,16 @@ class TestRandFlipd(NumpyImageTestCase2D): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, spatial_axis): - flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) - res = flip({"img": self.imt[0]}) - expected = [] - for channel in self.imt[0]: - expected.append(np.flip(channel, spatial_axis)) - expected = np.stack(expected) - self.assertTrue(np.allclose(expected, res["img"])) + for p in TEST_NDARRAYS: + flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis) + result = flip({"img": p(self.imt[0])})["img"] + if isinstance(result, torch.Tensor): + result = result.cpu() + expected = [] + for channel in self.imt[0]: + expected.append(np.flip(channel, spatial_axis)) + expected = np.stack(expected) + self.assertTrue(np.allclose(expected, result)) if __name__ == "__main__": diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 66d7627971..a07d59703d 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -25,6 +25,7 @@ from monai.transforms.croppad.dictionary import SpatialPadd from monai.transforms.spatial.dictionary import Rand2DElasticd, RandFlipd, Spacingd from monai.utils import optional_import, set_determinism +from tests.utils import TEST_NDARRAYS if TYPE_CHECKING: import tqdm @@ -40,7 +41,7 @@ class TestTestTimeAugmentation(unittest.TestCase): @staticmethod - def get_data(num_examples, input_size, include_label=True): + def get_data(num_examples, input_size, data_type=np.asarray, include_label=True): custom_create_test_image_2d = partial( create_test_image_2d, *input_size, rad_max=7, num_seg_classes=1, num_objs=1 ) @@ -48,10 +49,10 @@ def get_data(num_examples, input_size, include_label=True): for _ in range(num_examples): im, label = custom_create_test_image_2d() d = {} - d["image"] = im + d["image"] = data_type(im) d["image_meta_dict"] = {"affine": np.eye(4)} if include_label: - d["label"] = label + d["label"] = data_type(label) d["label_meta_dict"] = {"affine": np.eye(4)} data.append(d) return data[0] if num_examples == 1 else data @@ -142,9 +143,10 @@ def test_fail_random_but_not_invertible(self): TestTimeAugmentation(transforms, None, None, None) def test_single_transform(self): - transforms = RandFlipd(["image", "label"], prob=1.0) - tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x) - tta(self.get_data(1, (20, 20))) + for p in TEST_NDARRAYS: + transforms = RandFlipd(["image", "label"], prob=1.0) + tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x) + tta(self.get_data(1, (20, 20), data_type=p)) def test_image_no_label(self): transforms = RandFlipd(["image"], prob=1.0)