Skip to content

Commit

Permalink
Torch flip (#2822)
Browse files Browse the repository at this point in the history
torch transforms - RandRicianNoise, StdShiftIntensity, RandStdShiftIntensity
  • Loading branch information
rijobro authored Aug 24, 2021
1 parent 42ad892 commit fab8467
Show file tree
Hide file tree
Showing 10 changed files with 114 additions and 78 deletions.
5 changes: 2 additions & 3 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Hashable, Optional, Tuple
from typing import Hashable, Optional, Tuple

import numpy as np
import torch

from monai.transforms.transform import RandomizableTransform, Transform
Expand Down Expand Up @@ -113,7 +112,7 @@ def pop_transform(self, data: dict, key: Hashable) -> None:
"""Remove most recent transform."""
data[str(key) + InverseKeys.KEY_SUFFIX].pop()

def inverse(self, data: dict) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: dict) -> dict:
"""
Inverse of ``__call__``.
Expand Down
23 changes: 16 additions & 7 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import torch

from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import compute_shape_offset, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.transforms.croppad.array import CenterSpatialCrop
Expand All @@ -45,6 +46,7 @@
issequenceiterable,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.module import look_up_option

nib, _ = optional_import("nibabel")
Expand Down Expand Up @@ -317,17 +319,20 @@ class Flip(Transform):
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
self.spatial_axis = spatial_axis

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
"""

result: np.ndarray = np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))
return result.astype(img.dtype)
if isinstance(img, np.ndarray):
return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)))
else:
return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))


class Resize(Transform):
Expand Down Expand Up @@ -800,11 +805,13 @@ class RandFlip(RandomizableTransform):
spatial_axis: Spatial axes along which to flip over. Default is None.
"""

backend = Flip.backend

def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None:
RandomizableTransform.__init__(self, prob)
self.flipper = Flip(spatial_axis=spatial_axis)

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
Expand All @@ -826,15 +833,17 @@ class RandAxisFlip(RandomizableTransform):
"""

backend = Flip.backend

def __init__(self, prob: float = 0.1) -> None:
RandomizableTransform.__init__(self, prob)
self._axis: Optional[int] = None

def randomize(self, data: np.ndarray) -> None:
def randomize(self, data: NdarrayOrTensor) -> None:
super().randomize(None)
self._axis = self.R.randint(data.ndim - 1)

def __call__(self, img: np.ndarray) -> np.ndarray:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
Expand Down
30 changes: 14 additions & 16 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch

from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.layers import AffineTransform
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
Expand Down Expand Up @@ -1128,6 +1129,8 @@ class Flipd(MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = Flip.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1137,20 +1140,17 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.flipper = Flip(spatial_axis=spatial_axis)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.flipper(d[key])
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
_ = self.get_most_recent_transform(d, key)
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = self.flipper(d[key])
# Remove the applied transform
Expand All @@ -1173,6 +1173,8 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform):
allow_missing_keys: don't raise exception if key is missing.
"""

backend = Flip.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1186,7 +1188,7 @@ def __init__(

self.flipper = Flip(spatial_axis=spatial_axis)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
self.randomize(None)
d = dict(data)
for key in self.key_iterator(d):
Expand All @@ -1195,15 +1197,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
self.push_transform(d, key)
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM]:
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = self.flipper(d[key])
# Remove the applied transform
Expand All @@ -1225,16 +1224,18 @@ class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform):
"""

backend = Flip.backend

def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: bool = False) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
self._axis: Optional[int] = None

def randomize(self, data: np.ndarray) -> None:
def randomize(self, data: NdarrayOrTensor) -> None:
super().randomize(None)
self._axis = self.R.randint(data.ndim - 1)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
self.randomize(data=data[self.keys[0]])
flipper = Flip(spatial_axis=self._axis)

Expand All @@ -1245,16 +1246,13 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
self.push_transform(d, key, extra_info={"axis": self._axis})
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Check if random transform was actually performed (based on `prob`)
if transform[InverseKeys.DO_TRANSFORM]:
flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"])
# Might need to convert to numpy
if isinstance(d[key], torch.Tensor):
d[key] = torch.Tensor(d[key]).cpu().numpy()
# Inverse is same as forward
d[key] = flipper(d[key])
# Remove the applied transform
Expand Down
20 changes: 13 additions & 7 deletions tests/test_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Flip
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -31,12 +32,17 @@ def test_invalid_inputs(self, _, spatial_axis, raises):

@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, spatial_axis):
flip = Flip(spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
self.assertTrue(np.allclose(expected, flip(self.imt[0])))
for p in TEST_NDARRAYS:
im = p(self.imt[0])
flip = Flip(spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip(im)
if isinstance(result, torch.Tensor):
result = result.cpu()
self.assertTrue(np.allclose(expected, result))


if __name__ == "__main__":
Expand Down
20 changes: 12 additions & 8 deletions tests/test_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import Flipd
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -31,13 +32,16 @@ def test_invalid_cases(self, _, spatial_axis, raises):

@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, spatial_axis):
flip = Flipd(keys="img", spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
res = flip({"img": self.imt[0]})
assert np.allclose(expected, res["img"])
for p in TEST_NDARRAYS:
flip = Flipd(keys="img", spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip({"img": p(self.imt[0])})["img"]
if isinstance(result, torch.Tensor):
result = result.cpu()
assert np.allclose(expected, result)


if __name__ == "__main__":
Expand Down
20 changes: 12 additions & 8 deletions tests/test_rand_axis_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@
import unittest

import numpy as np
import torch

from monai.transforms import RandAxisFlip
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D


class TestRandAxisFlip(NumpyImageTestCase2D):
def test_correct_results(self):
flip = RandAxisFlip(prob=1.0)
result = flip(self.imt[0])

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result))
for p in TEST_NDARRAYS:
flip = RandAxisFlip(prob=1.0)
result = flip(p(self.imt[0]))
if isinstance(result, torch.Tensor):
result = result.cpu()

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result))


if __name__ == "__main__":
Expand Down
20 changes: 12 additions & 8 deletions tests/test_rand_axis_flipd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@
import unittest

import numpy as np
import torch

from monai.transforms import RandAxisFlipd
from tests.utils import NumpyImageTestCase3D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D


class TestRandAxisFlip(NumpyImageTestCase3D):
def test_correct_results(self):
flip = RandAxisFlipd(keys="img", prob=1.0)
result = flip({"img": self.imt[0]})

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result["img"]))
for p in TEST_NDARRAYS:
flip = RandAxisFlipd(keys="img", prob=1.0)
result = flip({"img": p(self.imt[0])})["img"]
if isinstance(result, torch.Tensor):
result = result.cpu()

expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, flip._axis))
self.assertTrue(np.allclose(np.stack(expected), result))


if __name__ == "__main__":
Expand Down
20 changes: 13 additions & 7 deletions tests/test_rand_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import RandFlip
from tests.utils import NumpyImageTestCase2D
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D

INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

Expand All @@ -31,12 +32,17 @@ def test_invalid_inputs(self, _, spatial_axis, raises):

@parameterized.expand(VALID_CASES)
def test_correct_results(self, _, spatial_axis):
flip = RandFlip(prob=1.0, spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
self.assertTrue(np.allclose(expected, flip(self.imt[0])))
for p in TEST_NDARRAYS:
im = p(self.imt[0])
flip = RandFlip(prob=1.0, spatial_axis=spatial_axis)
expected = []
for channel in self.imt[0]:
expected.append(np.flip(channel, spatial_axis))
expected = np.stack(expected)
result = flip(im)
if isinstance(result, torch.Tensor):
result = result.cpu()
self.assertTrue(np.allclose(expected, result))


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit fab8467

Please sign in to comment.