Skip to content

Commit

Permalink
EnsureSameShaped, adds warn option (#6455)
Browse files Browse the repository at this point in the history
Adds an option to hide warning messages of EnsureSameShaped transform.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: myron <[email protected]>
  • Loading branch information
myron authored May 2, 2023
1 parent 01552fd commit 0a00115
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
22 changes: 16 additions & 6 deletions monai/apps/auto3dseg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
allow_missing_keys: bool = False,
source_key: str = "image",
allowed_shape_difference: int = 5,
warn: bool = True,
) -> None:
"""
Args:
Expand All @@ -47,29 +48,38 @@ def __init__(
source_key: key of the item with the reference shape.
allowed_shape_difference: raises error if shapes are different more than this value in any dimension,
otherwise corrects for the shape mismatch using nearest interpolation.
warn: if `True` prints a warning if the label image is resized
"""
super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
self.source_key = source_key
self.allowed_shape_difference = allowed_shape_difference
self.warn = warn

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d = dict(data)
image_shape = d[self.source_key].shape[1:]
for key in self.key_iterator(d):
label_shape = d[key].shape[1:]
if label_shape != image_shape:
filename = ""
if hasattr(d[key], "meta") and isinstance(d[key].meta, Mapping): # type: ignore[attr-defined]
filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ) # type: ignore[attr-defined]

if np.allclose(list(label_shape), list(image_shape), atol=self.allowed_shape_difference):
msg = f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}"
if hasattr(d[key], "meta") and isinstance(d[key].meta, Mapping): # type: ignore[attr-defined]
filename = d[key].meta.get(ImageMetaKey.FILENAME_OR_OBJ) # type: ignore[attr-defined]
msg += f", the metadata was not updated: filename={filename}"
warnings.warn(msg)
if self.warn:
warnings.warn(
f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}"
f", the metadata was not updated {filename}."
)
d[key] = torch.nn.functional.interpolate(
input=d[key].unsqueeze(0),
size=image_shape,
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
).squeeze(0)
else:
raise ValueError(f"The {key} shape {label_shape} is different from the source shape {image_shape}.")
raise ValueError(
f"The {key} shape {label_shape} is different from the source shape {image_shape} {filename}."
)
return d
8 changes: 5 additions & 3 deletions tests/test_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from monai.networks import eval_mode
from monai.networks.nets import DenseNet121, Densenet169, DenseNet264, densenet201
from monai.utils import optional_import
from tests.utils import skip_if_quick, test_script_save
from tests.utils import skip_if_downloading_fails, skip_if_quick, test_script_save

if TYPE_CHECKING:
import torchvision
Expand Down Expand Up @@ -82,7 +82,8 @@ class TestPretrainedDENSENET(unittest.TestCase):
@parameterized.expand([TEST_PRETRAINED_2D_CASE_1, TEST_PRETRAINED_2D_CASE_2])
@skip_if_quick
def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
with skip_if_downloading_fails():
net = model(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)
Expand All @@ -91,7 +92,8 @@ def test_121_2d_shape_pretrain(self, model, input_param, input_shape, expected_s
@skipUnless(has_torchvision, "Requires `torchvision` package.")
def test_pretrain_consistency(self, model, input_param, input_shape):
example = torch.randn(input_shape).to(device)
net = model(**input_param).to(device)
with skip_if_downloading_fails():
net = model(**input_param).to(device)
with eval_mode(net):
result = net.features.forward(example)
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_torchvision_fc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from monai.networks.nets import TorchVisionFCModel, UNet
from monai.networks.utils import look_up_named_module, set_named_module
from monai.utils import min_version, optional_import
from tests.utils import skip_if_downloading_fails

Inception_V3_Weights, has_enum = optional_import("torchvision.models.inception", name="Inception_V3_Weights")

Expand Down Expand Up @@ -176,7 +177,8 @@ def test_without_pretrained(self, input_param, input_shape, expected_shape):
)
@skipUnless(has_tv, "Requires TorchVision.")
def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value):
net = TorchVisionFCModel(**input_param).to(device)
with skip_if_downloading_fails():
net = TorchVisionFCModel(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
value = next(net.features.parameters())[0, 0, 0, 0].item()
Expand Down

0 comments on commit 0a00115

Please sign in to comment.