Forced `Fourier` class to output contiguous tensors (#7969)
Forced `Fourier` class to output contiguous tensors, which potentially fixes a performance bottleneck.

### Description

Some transforms, such as `RandKSpaceSpikeNoise`, rely on the `Fourier` class. In its current state, the `Fourier` class returns non-contiguous tensors, which potentially limits performance. For example, when followed by `RandHistogramShift`, the following warning occurs:

```
<path_to_monai>/monai/transforms/intensity/array.py:1852: UserWarning: torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible. This message will only appear once per program. (Triggered internally at /opt/conda/conda-bld/pytorch_1716905975447/work/aten/src/ATen/native/BucketizationUtils.h:32.)
  indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
```

A straightforward fix is to force the `Fourier` class to output contiguous tensors (see commit).
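As background, a minimal self-contained sketch of how such non-contiguous outputs can arise and how forcing contiguity resolves them. The `.real`-view explanation below is an assumption based on typical inverse-FFT code (an inverse FFT returns a complex tensor, and `.real` is a strided view of it), not something stated in the PR itself:

```python
import numpy as np
import torch

# Assumption: taking `.real` of a complex FFT result yields a strided
# view over the interleaved real/imaginary storage, not a contiguous copy.
k_torch = torch.complex(torch.randn(4, 8, 8), torch.randn(4, 8, 8))
out_torch = torch.fft.ifftn(k_torch, dim=(-2, -1)).real
print(out_torch.is_contiguous())  # False: view with doubled stride

# The same holds for NumPy arrays.
k_np = np.random.randn(4, 8, 8) + 1j * np.random.randn(4, 8, 8)
out_np = np.fft.ifftn(k_np, axes=(-2, -1)).real
print(out_np.flags['C_CONTIGUOUS'])  # False

# Forcing a contiguous copy, in the spirit of this commit's fix:
print(out_torch.contiguous().is_contiguous())              # True
print(np.ascontiguousarray(out_np).flags['C_CONTIGUOUS'])  # True
```

Downstream consumers such as `torch.searchsorted()` then operate on contiguous memory and skip the extra data copy that triggered the warning above.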
To reproduce, please run:

```python
import numpy as np
import torch  # needed for the torch-tensor tests below

from monai.transforms import RandKSpaceSpikeNoise
from monai.transforms.utils import Fourier

### TEST WITH TRANSFORMS ###
t = RandKSpaceSpikeNoise(prob=1)

# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
a_torch_mod = t(a_torch)
print(a_torch_mod.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
a_np_mod = t(a_np)  # automatically transformed to torch.Tensor
print(a_np_mod.is_contiguous())

### TEST DIRECTLY WITH FOURIER ###
f = Fourier()

# inv_shift_fourier
# for torch tensors
real_torch = torch.randn(1, 128, 128, 128)
im_torch = torch.randn(1, 128, 128, 128)
k_torch = torch.complex(real_torch, im_torch)
print(k_torch.is_contiguous())
out_torch = f.inv_shift_fourier(k_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
real_np = np.random.randn(1, 100, 100, 100)
im_np = np.random.randn(1, 100, 100, 100)
k_np = real_np + 1j * im_np
print(k_np.flags['C_CONTIGUOUS'])
out_np = f.inv_shift_fourier(k_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])

# shift_fourier
# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
out_torch = f.shift_fourier(a_torch, spatial_dims=3)
print(out_torch.is_contiguous())

# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
out_np = f.shift_fourier(a_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])
```

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Bastian Wittmann <[email protected]>
Co-authored-by: YunLiu <[email protected]>