From 99b9dd7cf0667113e56f0dd90d40699115048d59 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 5 Jul 2022 13:18:14 +0100 Subject: [PATCH] compat. torch 1.7/1.8 sw integration Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 2 ++ monai/inferers/utils.py | 2 +- tests/test_integration_sliding_window.py | 9 ++++++--- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 1582652f53..969f93872f 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -66,6 +66,8 @@ class MetaTensor(MetaObj, torch.Tensor): not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported. + - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor. + see: https://github.com/pytorch/pytorch/issues/54457 - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 2fa0b79476..5126b23c0a 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -137,7 +137,7 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index ba1f96c1bc..9b1c7e5200 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -24,7 +24,7 @@ from monai.networks import eval_mode, predict_segmentation from monai.networks.nets import UNet from monai.transforms import AddChannel, SaveImage -from monai.utils import set_determinism +from monai.utils import pytorch_after, set_determinism from tests.utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick @@ -47,8 +47,11 @@ def _sliding_window_processor(_engine, batch): return predict_segmentation(seg_probs) def save_func(engine): - for m in engine.state.output: - saver(m) + if pytorch_after(1, 9, 1): + for m in engine.state.output: + saver(m) + else: + saver(engine.state.output[0]) infer_engine = Engine(_sliding_window_processor) infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)