Skip to content

Commit

Permalink
compat. torch 1.7/1.8 sw integration
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jul 5, 2022
1 parent 24cf761 commit 99b9dd7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
2 changes: 2 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class MetaTensor(MetaObj, torch.Tensor):
not work if `im` is of type `MetaTensor`. This can be resolved with
`torch.jit.trace(net, im.as_tensor())`.
- For pytorch < 1.8, sharing `MetaTensor` instances across processes may not be supported.
- For pytorch < 1.9, next(iter(meta_tensor)) returns a torch.Tensor.
see: https://github.com/pytorch/pytorch/issues/54457
- A warning will be raised if in the constructor `affine` is not `None` and
`meta` already contains the key `affine`.
- You can query whether the `MetaTensor` is a batch with the `is_batch` attribute.
Expand Down
2 changes: 1 addition & 1 deletion monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def sliding_window_inference(
diff = max(roi_size[k - 2] - inputs.shape[k], 0)
half = diff // 2
pad_size.extend([half, diff - half])
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

Expand Down
9 changes: 6 additions & 3 deletions tests/test_integration_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.networks import eval_mode, predict_segmentation
from monai.networks.nets import UNet
from monai.transforms import AddChannel, SaveImage
from monai.utils import set_determinism
from monai.utils import pytorch_after, set_determinism
from tests.utils import DistTestCase, TimedCall, make_nifti_image, skip_if_quick


Expand All @@ -47,8 +47,11 @@ def _sliding_window_processor(_engine, batch):
return predict_segmentation(seg_probs)

def save_func(engine):
for m in engine.state.output:
saver(m)
if pytorch_after(1, 9, 1):
for m in engine.state.output:
saver(m)
else:
saver(engine.state.output[0])

infer_engine = Engine(_sliding_window_processor)
infer_engine.add_event_handler(Events.ITERATION_COMPLETED, save_func)
Expand Down

0 comments on commit 99b9dd7

Please sign in to comment.