Skip to content

Commit

Permalink
clean and linter
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Nov 7, 2024
1 parent 9e07baf commit b721cb7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 20 deletions.
5 changes: 3 additions & 2 deletions deel/torchlip/modules/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
from torch import nn


class LipResidual(nn.Module):
class LipResidual(nn.Module):
"""
This class is a 1-Lipschitz residual connection
with a learnable parameter alpha that gives a tradeoff
between the input x and the layer output y = l(x)
Args:
"""

def __init__(self):
super().__init__()
self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)
Expand Down
7 changes: 4 additions & 3 deletions deel/torchlip/modules/unconstrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ def __init__(
self.old_padding_mode = padding_mode
if padding_mode.lower() == "symmetric":
# symmetric padding of one pixel can be replaced by replicate
if (isinstance(padding, int) and padding <= 1) or \
(isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1):
self.old_padding_mode = padding_mode = "replicate"
if (isinstance(padding, int) and padding <= 1) or (
isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1
):
self.old_padding_mode = padding_mode = "replicate"
else:
padding_mode = "zeros"
padding = "valid"
Expand Down
20 changes: 7 additions & 13 deletions tests/test_unconstrained_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,18 @@ def compare(x, x_ref, index_x=[], index_x_ref=[]):
np.testing.assert_allclose(x_cropped, np.zeros(x_cropped.shape), 1e-2, 0)
else:
np.testing.assert_allclose(
x_cropped-
x_ref[
x_cropped
- x_ref[
:, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
][:, :, :: index_x_ref[2], :: index_x_ref[5]],
np.zeros(x_cropped.shape),
1e-2,
0,
)
# np.testing.assert_allclose(
# x_cropped,
# x_ref[
# :, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
# ][:, :, :: index_x_ref[2], :: index_x_ref[5]],
# 1e-2,
# 0,
# )


@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect","replicate"]
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate"]
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down Expand Up @@ -167,7 +159,8 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
"padding_tested",
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down Expand Up @@ -240,7 +233,8 @@ def test_predict(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
"padding_tested",
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down
4 changes: 2 additions & 2 deletions tests/utils_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def is_supported_padding(padding):
"reflect",
"circular",
"symmetric",
'replicate'
"replicate",
] # "constant",


Expand All @@ -636,7 +636,7 @@ def pad_input(x, padding, kernel_size):
kernel_size = [kernel_size, kernel_size]
if padding.lower() in ["same", "valid"]:
return x
elif padding.lower() in ["constant", "reflect", "circular",'replicate']:
elif padding.lower() in ["constant", "reflect", "circular", "replicate"]:
p_vert, p_hor = kernel_size[0] // 2, kernel_size[1] // 2
pad_sizes = [
p_hor,
Expand Down

0 comments on commit b721cb7

Please sign in to comment.