diff --git a/monai/apps/auto3dseg/transforms.py b/monai/apps/auto3dseg/transforms.py
index 0bb65edd13..bb755aa78c 100644
--- a/monai/apps/auto3dseg/transforms.py
+++ b/monai/apps/auto3dseg/transforms.py
@@ -39,6 +39,7 @@ def __init__(
         allow_missing_keys: bool = False,
         source_key: str = "image",
         allowed_shape_difference: int = 5,
+        warn: bool = True,
     ) -> None:
         """
         Args:
@@ -47,11 +48,14 @@
             source_key: key of the item with the reference shape.
             allowed_shape_difference: raises error if shapes are different more than this value in any dimension,
                 otherwise corrects for the shape mismatch using nearest interpolation.
+            warn: if `True` prints a warning if the label image is resized
+
         """
         super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
         self.source_key = source_key
         self.allowed_shape_difference = allowed_shape_difference
+        self.warn = warn
 
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
         d = dict(data)
@@ -59,17 +63,23 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
         for key in self.key_iterator(d):
             label_shape = d[key].shape[1:]
             if label_shape != image_shape:
+                filename = ""
+                if hasattr(d[key], "meta") and isinstance(d[key].meta, Mapping):  # type: ignore[attr-defined]
+                    filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ)  # type: ignore[attr-defined]
+
                 if np.allclose(list(label_shape), list(image_shape), atol=self.allowed_shape_difference):
-                    msg = f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}"
-                    if hasattr(d[key], "meta") and isinstance(d[key].meta, Mapping):  # type: ignore[attr-defined]
-                        filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ)  # type: ignore[attr-defined]
-                        msg += f", the metadata was not updated: filename={filename}"
-                    warnings.warn(msg)
+                    if self.warn:
+                        warnings.warn(
+                            f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}"
+                            f", the metadata was not updated {filename}."
+                        )
                     d[key] = torch.nn.functional.interpolate(
                         input=d[key].unsqueeze(0),
                         size=image_shape,
                         mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
                     ).squeeze(0)
                 else:
-                    raise ValueError(f"The {key} shape {label_shape} is different from the source shape {image_shape}.")
+                    raise ValueError(
+                        f"The {key} shape {label_shape} is different from the source shape {image_shape} {filename}."
+                    )
         return d
diff --git a/tests/test_densenet.py b/tests/test_densenet.py
index 8354237a25..1b44baf0c2 100644
--- a/tests/test_densenet.py
+++ b/tests/test_densenet.py
@@ -21,7 +21,7 @@
 from monai.networks import eval_mode
 from monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201
 from monai.utils import optional_import
-from tests.utils import skip_if_quick, test_script_save
+from tests.utils import skip_if_downloading_fails, skip_if_quick, test_script_save
 
 if TYPE_CHECKING:
     import torchvision
@@ -82,7 +82,8 @@ class TestPretrainedDENSENET(unittest.TestCase):
     @parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])
     @skip_if_quick
     def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
-        net = model(**input_param).to(device)
+        with skip_if_downloading_fails():
+            net = model(**input_param).to(device)
         with eval_mode(net):
             result = net.forward(torch.randn(input_shape).to(device))
         self.assertEqual(result.shape, expected_shape)
@@ -91,7 +92,8 @@ def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_s
     @skipUnless(has_torchvision, "Requires `torchvision` package.")
     def test_pretrain_consistency(self, model, input_param, input_shape):
         example = torch.randn(input_shape).to(device)
-        net = model(**input_param).to(device)
+        with skip_if_downloading_fails():
+            net = model(**input_param).to(device)
         with eval_mode(net):
             result = net.features.forward(example)
         torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py
index 5f92a1f8b4..e913b2b9b1 100644
--- a/tests/test_torchvision_fc_model.py
+++ b/tests/test_torchvision_fc_model.py
@@ -21,6 +21,7 @@
 from monai.networks.nets import TorchVisionFCModel, UNet
 from monai.networks.utils import look_up_named_module, set_named_module
 from monai.utils import min_version, optional_import
+from tests.utils import skip_if_downloading_fails
 
 Inception_V3_Weights, has_enum = optional_import("torchvision.models.inception", name="Inception_V3_Weights")
@@ -176,7 +177,8 @@ def test_without_pretrained(self, input_param, input_shape, expected_shape):
     )
     @skipUnless(has_tv, "Requires TorchVision.")
     def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value):
-        net = TorchVisionFCModel(**input_param).to(device)
+        with skip_if_downloading_fails():
+            net = TorchVisionFCModel(**input_param).to(device)
         with eval_mode(net):
             result = net.forward(torch.randn(input_shape).to(device))
         value = next(net.features.parameters())[0, 0, 0, 0].item()
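
Usage note (not part of the patch): a minimal sketch of how the new `warn` flag on `EnsureSameShaped` behaves; the dictionary keys and tensor shapes below are illustrative assumptions, not taken from the PR.

    import torch

    from monai.apps.auto3dseg.transforms import EnsureSameShaped

    # the label is 2 voxels short of the image in the last spatial dim: within
    # the default allowed_shape_difference of 5, so it is resized with nearest
    # interpolation instead of raising a ValueError
    data = {"image": torch.zeros(1, 64, 64, 64), "label": torch.zeros(1, 64, 64, 62)}

    # warn=False suppresses the resize warning that was previously unconditional
    transform = EnsureSameShaped(keys="label", source_key="image", warn=False)
    out = transform(data)
    assert out["label"].shape[1:] == out["image"].shape[1:]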
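Similarly, a sketch of the test pattern the patch applies: constructing a pretrained network inside the `skip_if_downloading_fails` context manager (imported from `tests.utils`, as in the patch) so that a flaky weight download skips the test rather than failing CI. The test class name, model parameters, and shapes here are assumptions for illustration.

    import unittest

    import torch

    from monai.networks import eval_mode
    from monai.networks.nets import DenseNet121
    from tests.utils import skip_if_downloading_fails


    class TestPretrainedDownload(unittest.TestCase):
        def test_pretrained_forward(self):
            # if fetching the pretrained weights errors out (e.g. network
            # outage), the test is skipped instead of reported as a failure
            with skip_if_downloading_fails():
                net = DenseNet121(spatial_dims=2, in_channels=3, out_channels=2, pretrained=True)
            with eval_mode(net):
                result = net(torch.randn(1, 3, 64, 64))
            self.assertEqual(result.shape, (1, 2))


    if __name__ == "__main__":
        unittest.main()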