diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 1582652f53..969f93872f 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -66,6 +66,8 @@ class MetaTensor(MetaObj, torch.Tensor): not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported. + - For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor. + see: https://github.com/pytorch/pytorch/issues/54457 - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 2fa0b79476..5126b23c0a 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -137,7 +137,7 @@ def sliding_window_inference( diff = max(roi_size[k - 2] - inputs.shape[k], 0) half = diff // 2 pad_size.extend([half, diff - half]) - inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index ba1f96c1bc..9b1c7e5200 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -24,7 +24,7 @@ from monai.networks import eval_mode, predict_segmentation from monai.networks.nets import UNet from monai.transforms import AddChannel, SaveImage -from monai.utils import set_determinism +from monai.utils import pytorch_after, set_determinism from tests.utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick @@ -47,8 +47,11 @@ def _sliding_window_processor(_engine, batch): return predict_segmentation(seg_probs) def save_func(engine): - for m in engine.state.output: - saver(m) + if pytorch_after(1, 9, 1): + for m in engine.state.output: + saver(m) + else: + saver(engine.state.output[0]) infer_engine = Engine(_sliding_window_processor) infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)