Skip to content

Commit

Permalink
add support for replicate padding + trick to change symmetric+1 paddi…
Browse files Browse the repository at this point in the history
…ng in replicate
  • Loading branch information
franckma31 committed Nov 7, 2024
1 parent a84ee72 commit 9e07baf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
9 changes: 7 additions & 2 deletions deel/torchlip/modules/unconstrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,13 @@ def __init__(
self.old_padding = padding
self.old_padding_mode = padding_mode
if padding_mode.lower() == "symmetric":
padding_mode = "zeros"
padding = "valid"
# symmetric padding of one pixel can be replaced by replicate
if (isinstance(padding, int) and padding <= 1) or \
(isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1):
self.old_padding_mode = padding_mode = "replicate"
else:
padding_mode = "zeros"
padding = "valid"

super(PadConv2d, self).__init__(
in_channels=in_channels,
Expand Down
26 changes: 22 additions & 4 deletions tests/test_unconstrained_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,26 @@ def compare(x, x_ref, index_x=[], index_x_ref=[]):
np.testing.assert_allclose(x_cropped, np.zeros(x_cropped.shape), 1e-2, 0)
else:
np.testing.assert_allclose(
x_cropped,
x_cropped-
x_ref[
:, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
][:, :, :: index_x_ref[2], :: index_x_ref[5]],
np.zeros(x_cropped.shape),
1e-2,
0,
)
# np.testing.assert_allclose(
# x_cropped,
# x_ref[
# :, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
# ][:, :, :: index_x_ref[2], :: index_x_ref[5]],
# 1e-2,
# 0,
# )


@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect"]
"padding_tested", ["circular", "constant", "symmetric", "reflect","replicate"]
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down Expand Up @@ -90,15 +99,19 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
right_x_pad = [p_vert, -p_vert, 1, -p_hor, x_pad_NCHW[3], 1, "right"]
all_x = [0, x_NCHW[2], 1, 0, x_NCHW[3], 1]
upper_x = [0, p_vert, 1, 0, x_NCHW[3], 1]
upper_x_first = [0, 1, 1, 0, x_NCHW[3], 1]
upper_x_rev = [0, p_vert, -1, 0, x_NCHW[3], 1]
upper_x_refl = [1, p_vert + 1, -1, 0, x_NCHW[3], 1]
lower_x = [-p_vert, x_NCHW[2], 1, 0, x_NCHW[3], 1]
lower_x_last = [-1, x_NCHW[2], 1, 0, x_NCHW[3], 1]
lower_x_rev = [-p_vert, x_NCHW[2], -1, 0, x_NCHW[3], 1]
lower_x_refl = [-p_vert - 1, x_NCHW[2] - 1, -1, 0, x_NCHW[3], 1]
left_x = [0, x_NCHW[2], 1, 0, p_hor, 1]
left_x_first = [0, x_NCHW[2], 1, 0, 1, 1]
left_x_rev = [0, x_NCHW[2], 1, 0, p_hor, -1]
left_x_refl = [0, x_NCHW[2], 1, 1, p_hor + 1, -1]
right_x = [0, x_NCHW[2], 1, -p_hor, x_NCHW[3], 1]
right_x_last = [0, x_NCHW[2], 1, -1, x_NCHW[3], 1]
right_x_rev = [0, x_NCHW[2], 1, -p_hor, x_NCHW[3], -1]
right_x_refl = [0, x_NCHW[2], 1, -p_hor - 1, x_NCHW[3] - 1, -1]
zero_pad = [None, None, None, None]
Expand All @@ -108,30 +121,35 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
"constant": [center_x_pad, all_x],
"symmetric": [center_x_pad, all_x],
"reflect": [center_x_pad, all_x],
"replicate": [center_x_pad, all_x],
},
{
"circular": [upper_x_pad, lower_x],
"constant": [upper_x_pad, zero_pad],
"symmetric": [upper_x_pad, upper_x_rev],
"reflect": [upper_x_pad, upper_x_refl],
"replicate": [upper_x_pad, upper_x_first],
},
{
"circular": [lower_x_pad, upper_x],
"constant": [lower_x_pad, zero_pad],
"symmetric": [lower_x_pad, lower_x_rev],
"reflect": [lower_x_pad, lower_x_refl],
"replicate": [lower_x_pad, lower_x_last],
},
{
"circular": [left_x_pad, right_x],
"constant": [left_x_pad, zero_pad],
"symmetric": [left_x_pad, left_x_rev],
"reflect": [left_x_pad, left_x_refl],
"replicate": [left_x_pad, left_x_first],
},
{
"circular": [right_x_pad, left_x],
"constant": [right_x_pad, zero_pad],
"symmetric": [right_x_pad, right_x_rev],
"reflect": [right_x_pad, right_x_refl],
"replicate": [right_x_pad, right_x_last],
},
]

Expand All @@ -149,7 +167,7 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "same", "valid"]
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down Expand Up @@ -222,7 +240,7 @@ def test_predict(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "same", "valid"]
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down
3 changes: 2 additions & 1 deletion tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def is_supported_padding(padding):
"reflect",
"circular",
"symmetric",
'replicate'
] # "constant",


Expand All @@ -635,7 +636,7 @@ def pad_input(x, padding, kernel_size):
kernel_size = [kernel_size, kernel_size]
if padding.lower() in ["same", "valid"]:
return x
elif padding.lower() in ["constant", "reflect", "circular"]:
elif padding.lower() in ["constant", "reflect", "circular",'replicate']:
p_vert, p_hor = kernel_size[0] // 2, kernel_size[1] // 2
pad_sizes = [
p_hor,
Expand Down

0 comments on commit 9e07baf

Please sign in to comment.