Skip to content

Commit

Permalink
Forced Fourier class to output contiguous tensors. (#7969)
Browse files Browse the repository at this point in the history
Forced `Fourier` class to output contiguous tensors, which potentially
fixes a performance bottleneck.

### Description

Some transforms, such as `RandKSpaceSpikeNoise`, rely on the `Fourier`
class.
In its current state, the `Fourier` class returns non-contiguous
tensors, which potentially limits performance.
For example, when followed by `RandHistogramShift`, the following
warning occurs:

```
<path_to_monai>/monai/transforms/intensity/array.py:1852: UserWarning: torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible. This message will only appear once per program. (Triggered internally at /opt/conda/conda-bld/pytorch_1716905975447/work/aten/src/ATen/native/BucketizationUtils.h:32.)
  indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
```

A straightforward fix is to force the `Fourier` class to output
contiguous tensors (see commit).
To reproduce, please run:

```
from monai.transforms import RandKSpaceSpikeNoise
from monai.transforms.utils import Fourier
import numpy as np

### TEST WITH TRANSFORMS ###   
t = RandKSpaceSpikeNoise(prob=1)

# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())

a_torch_mod = t(a_torch)
print(a_torch_mod.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])

a_np_mod = t(a_np)  # automatically transformed to torch.tensor
print(a_np_mod.is_contiguous())

### TEST DIRECTLY WITH FOURIER ###
f = Fourier()

# inv_shift_fourier
# for torch tensors
real_torch = torch.randn(1, 128, 128, 128)
im_torch = torch.randn(1, 128, 128, 128)
k_torch = torch.complex(real_torch, im_torch)
print(k_torch.is_contiguous())

out_torch = f.inv_shift_fourier(k_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
real_np = np.random.randn(1, 100, 100, 100)
im_np = np.random.randn(1, 100, 100, 100)
k_np = real_np + 1j * im_np
print(k_np.flags['C_CONTIGUOUS'])

out_np = f.inv_shift_fourier(k_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])

# shift_fourier
# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())

out_torch = f.shift_fourier(a_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])

out_np = f.shift_fourier(a_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])
```

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Bastian Wittmann <[email protected]>.
Signed-off-by: Bastian Wittmann <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
bwittmann and KumoLiu authored Sep 3, 2024
1 parent dbfe418 commit befb5f6
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,14 +1863,15 @@ class Fourier:
"""

@staticmethod
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
"""
Applies fourier transform and shifts the zero-frequency component to the
center of the spectrum. Only the spatial dimensions get transformed.
Args:
x: Image to transform.
spatial_dims: Number of spatial dimensions.
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
Returns
k: K-space data.
Expand All @@ -1885,17 +1886,20 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
else:
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
return k
return ascontiguousarray(k) if as_contiguous else k

@staticmethod
def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor:
def inv_shift_fourier(
k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False
) -> NdarrayOrTensor:
"""
Applies inverse shift and fourier transform. Only the spatial
dimensions are transformed.
Args:
k: K-space data.
spatial_dims: Number of spatial dimensions.
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
Returns:
x: Tensor in image space.
Expand All @@ -1910,7 +1914,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
else:
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
return out
return ascontiguousarray(out) if as_contiguous else out


def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:
Expand Down

0 comments on commit befb5f6

Please sign in to comment.