Skip to content

Commit

Permalink
Merge pull request #955 from benmalef/dynunet_set_automaticaly_stride…
Browse files Browse the repository at this point in the history
…s_kernels

Automatic set strides kernels in dynunet
  • Loading branch information
sarthakpati authored Oct 9, 2024
2 parents 1fc8ede + 4c3da18 commit bcffff8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 30 deletions.
69 changes: 51 additions & 18 deletions GANDLF/models/dynunet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,41 @@
import monai.networks.nets.dynunet as dynunet


def get_kernels_strides(sizes, spacings):
"""
More info: https://github.com/Project-MONAI/tutorials/blob/main/modules/dynunet_pipeline/create_network.py#L19
When refering this method for other tasks, please ensure that the patch size for each spatial dimension should
be divisible by the product of all strides in the corresponding dimension.
In addition, the minimal spatial size should have at least one dimension that has twice the size of
the product of all strides. For patch sizes that cannot find suitable strides, an error will be raised.
"""
input_size = sizes
strides, kernels = [], []
while True:
spacing_ratio = [sp / min(spacings) for sp in spacings]
stride = [
2 if ratio <= 2 and size >= 8 else 1
for (ratio, size) in zip(spacing_ratio, sizes)
]
kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
if all(s == 1 for s in stride):
break
for idx, (i, j) in enumerate(zip(sizes, stride)):
assert (
i % j == 0
), f"Patch size is not supported, please try to modify the size {input_size[idx]} in the spatial dimension {idx}."
sizes = [i / j for i, j in zip(sizes, stride)]
spacings = [i * j for i, j in zip(spacings, stride)]
kernels.append(kernel)
strides.append(stride)

strides.insert(0, len(spacings) * [1])
kernels.append(len(spacings) * [3])
return kernels, strides


class dynunet_wrapper(ModelBase):
"""
More info: https://docs.monai.io/en/stable/networks.html#dynunet
Expand All @@ -26,35 +61,33 @@ class dynunet_wrapper(ModelBase):
def __init__(self, parameters: dict):
super(dynunet_wrapper, self).__init__(parameters)

# checking for validation
assert (
"kernel_size" in parameters["model"]
) == True, "\033[0;31m`kernel_size` key missing in parameters"
assert (
"strides" in parameters["model"]
) == True, "\033[0;31m`strides` key missing in parameters"

# defining some defaults
# if not ("upsample_kernel_size" in parameters["model"]):
# parameters["model"]["upsample_kernel_size"] = parameters["model"][
# "strides"
# ][1:]
patch_size = parameters.get("patch_size", None)
spacing = parameters.get(
"spacing_for_internal_computations",
[1.0 for i in range(parameters["model"]["dimension"])],
)
parameters["model"]["kernel_size"] = parameters["model"].get(
"kernel_size", None
)
parameters["model"]["strides"] = parameters["model"].get("strides", None)
if (parameters["model"]["kernel_size"] is None) or (
parameters["model"]["strides"] is None
):
kernel_size, strides = get_kernels_strides(patch_size, spacing)
parameters["model"]["kernel_size"] = kernel_size
parameters["model"]["strides"] = strides

parameters["model"]["filters"] = parameters["model"].get("filters", None)
parameters["model"]["act_name"] = parameters["model"].get(
"act_name", ("leakyrelu", {"inplace": True, "negative_slope": 0.01})
)

parameters["model"]["deep_supervision"] = parameters["model"].get(
"deep_supervision", True
"deep_supervision", False
)

parameters["model"]["deep_supr_num"] = parameters["model"].get(
"deep_supr_num", 1
)

parameters["model"]["res_block"] = parameters["model"].get("res_block", True)

parameters["model"]["trans_bias"] = parameters["model"].get("trans_bias", False)
parameters["model"]["dropout"] = parameters["model"].get("dropout", None)

Expand Down
12 changes: 0 additions & 12 deletions testing/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ def test_train_segmentation_rad_2d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
Expand Down Expand Up @@ -374,12 +368,6 @@ def test_train_segmentation_rad_3d(device):
["acs", "soft", "conv3d"]
)

if model == "dynunet":
# More info: https://github.com/Project-MONAI/MONAI/blob/96bfda00c6bd290297f5e3514ea227c6be4d08b4/tests/test_dynunet.py
parameters["model"]["kernel_size"] = (3, 3, 3, 1)
parameters["model"]["strides"] = (1, 1, 1, 1)
parameters["model"]["deep_supervision"] = False

parameters["model"]["architecture"] = model
parameters["nested_training"]["testing"] = -5
parameters["nested_training"]["validation"] = -5
Expand Down

0 comments on commit bcffff8

Please sign in to comment.