
Commit ae8366c

clean and linter

franckma31 committed Nov 27, 2024
1 parent d19ad7d commit ae8366c
Showing 7 changed files with 17 additions and 40 deletions.
15 changes: 0 additions & 15 deletions deel/torchlip/modules/downsampling.py
@@ -24,11 +24,8 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Tuple

import torch

from .. import functional as F
from .module import LipschitzModule


@@ -43,15 +40,3 @@ def vanilla_export(self):
else:
return self


# class InvertibleDownSampling(torch.nn.Module, LipschitzModule):
# def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0):
# torch.nn.Module.__init__(self)
# LipschitzModule.__init__(self, k_coef_lip)
# self.kernel_size = kernel_size

# def forward(self, input: torch.Tensor) -> torch.Tensor:
# return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip

# def vanilla_export(self):
# return self
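Context (not part of the commit): the removed, commented-out InvertibleDownSampling folded spatial blocks into channels via F.invertible_downsample. A minimal sketch of the same idea, assuming the block layout of torch.nn.PixelUnshuffle (an assumption; the exact channel ordering of F.invertible_downsample may differ):

```python
import torch

# A 2x2 block of each channel is folded into 4 channels,
# so 3 -> 12 channels and 8x8 -> 4x4 spatially.
x = torch.randn(1, 3, 8, 8)
down = torch.nn.PixelUnshuffle(downscale_factor=2)
y = down(x)
print(y.shape)  # torch.Size([1, 12, 4, 4])

# The rearrangement is a permutation of entries, hence invertible and
# norm-preserving, i.e. 1-Lipschitz with equality.
assert torch.allclose(x.norm(), y.norm())
```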
7 changes: 3 additions & 4 deletions deel/torchlip/modules/residual.py
@@ -24,21 +24,20 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Tuple
from typing import Union

import torch
from torch import nn


class LipResidual(nn.Module):
class LipResidual(nn.Module):
"""
This class is a 1-Lipschitz residual connection
With a learnable parameter alpha that give a tradeoff
between the x and the layer y=l(x)
Args:
"""

def __init__(self):
super().__init__()
self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)
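Context (not part of the commit): the diff shows only the constructor of LipResidual. One common way such a residual keeps the 1-Lipschitz property is a sigmoid-gated convex combination of x and l(x). A hypothetical sketch, assuming that form for the forward pass:

```python
import torch
from torch import nn

class LipResidualSketch(nn.Module):
    """Hypothetical sketch of a 1-Lipschitz residual connection.

    The forward pass is not shown in the diff; a sigmoid-gated convex
    combination is one standard way to keep the result 1-Lipschitz
    when the wrapped layer is itself 1-Lipschitz.
    """

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer
        self.alpha = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = torch.sigmoid(self.alpha)  # tradeoff in (0, 1)
        # Convex combination of the identity and the layer output:
        # Lip(t * id + (1 - t) * l) <= t + (1 - t) * Lip(l) <= 1.
        return t * x + (1 - t) * self.layer(x)
```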
7 changes: 4 additions & 3 deletions deel/torchlip/modules/unconstrained.py
@@ -156,9 +156,10 @@ def __init__(
self.old_padding_mode = padding_mode
if padding_mode.lower() == "symmetric":
# symmetric padding of one pixel can be replaced by replicate
if (isinstance(padding, int) and padding <= 1) or \
(isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1):
self.old_padding_mode = padding_mode = "replicate"
if (isinstance(padding, int) and padding <= 1) or (
isinstance(padding, tuple) and padding[0] <= 1 and padding[1] <= 1
):
self.old_padding_mode = padding_mode = "replicate"
else:
padding_mode = "zeros"
padding = "valid"
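A quick numerical check of the comment above (a sketch, not part of the commit): for a one-pixel border, symmetric padding (mirror including the edge) and replicate padding (repeat the edge) produce identical values; numpy provides a symmetric mode to compare against:

```python
import numpy as np
import torch
import torch.nn.functional as F

x = torch.arange(9.0).reshape(1, 1, 3, 3)

# "replicate" repeats the border pixel; "symmetric" mirrors including
# the border pixel, so for a single-pixel pad the results coincide.
replicated = F.pad(x, (1, 1, 1, 1), mode="replicate")
symmetric = torch.from_numpy(np.pad(x[0, 0].numpy(), 1, mode="symmetric"))

assert torch.equal(replicated[0, 0], symmetric)
```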
2 changes: 0 additions & 2 deletions deel/torchlip/modules/upsampling.py
@@ -24,8 +24,6 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Tuple
from typing import Union

import torch

20 changes: 7 additions & 13 deletions tests/test_unconstrained_layers.py
@@ -46,26 +46,18 @@ def compare(x, x_ref, index_x=[], index_x_ref=[]):
np.testing.assert_allclose(x_cropped, np.zeros(x_cropped.shape), 1e-2, 0)
else:
np.testing.assert_allclose(
x_cropped-
x_ref[
x_cropped
- x_ref[
:, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
][:, :, :: index_x_ref[2], :: index_x_ref[5]],
np.zeros(x_cropped.shape),
1e-2,
0,
)
# np.testing.assert_allclose(
# x_cropped,
# x_ref[
# :, :, index_x_ref[0] : index_x_ref[1], index_x_ref[3] : index_x_ref[4]
# ][:, :, :: index_x_ref[2], :: index_x_ref[5]],
# 1e-2,
# 0,
# )
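An aside on the rewritten assertion (editorial note, not part of the commit): np.testing.assert_allclose checks |actual - desired| <= atol + rtol * |desired|, so when desired is an array of zeros the rtol term contributes nothing and atol alone sets the tolerance. A small sketch of the difference:

```python
import numpy as np

a = np.array([1.000, 2.000])
b = np.array([1.005, 2.010])

# Tolerance is atol + rtol * |desired|. Comparing directly, rtol
# scales with |b|, so differences of 0.5% pass:
np.testing.assert_allclose(a, b, rtol=1e-2, atol=0)

# Against a zero "desired", the rtol term vanishes; only atol is left,
# so atol must cover the largest allowed difference:
np.testing.assert_allclose(a - b, np.zeros_like(a), rtol=1e-2, atol=0.02)
```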


@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect","replicate"]
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate"]
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
@@ -167,7 +159,8 @@ def test_padding(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
"padding_tested",
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
Expand Down Expand Up @@ -240,7 +233,8 @@ def test_predict(padding_tested, input_shape, batch_size, kernel_size, filters):
reason="PadConv2d not available",
)
@pytest.mark.parametrize(
"padding_tested", ["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"]
"padding_tested",
["circular", "constant", "symmetric", "reflect", "replicate", "same", "valid"],
)
@pytest.mark.parametrize(
"input_shape, batch_size, kernel_size, filters",
1 change: 0 additions & 1 deletion tests/test_updownsampling.py
@@ -33,7 +33,6 @@


def check_downsample(x, y, kernel_size):
shape = uft.get_NCHW(x)
index = 0
for dx in range(kernel_size):
for dy in range(kernel_size):
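For context (not part of the commit), a minimal round trip of the kind these up/downsampling tests exercise, assuming the torch.nn.PixelShuffle/PixelUnshuffle block layout (the repository's own layout may differ):

```python
import torch

k = 2
x = torch.randn(4, 3, 8, 8)
down = torch.nn.PixelUnshuffle(k)   # (4, 3, 8, 8) -> (4, 12, 4, 4)
up = torch.nn.PixelShuffle(k)       # (4, 12, 4, 4) -> (4, 3, 8, 8)

# Both are pure permutations of entries, so the round trip is exact.
assert torch.equal(up(down(x)), x)
```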
5 changes: 3 additions & 2 deletions tests/utils_framework.py
Expand Up @@ -127,6 +127,7 @@
"CategoricalHingeLoss",
"process_labels_for_multi_gpu",
"SpectralConv1d",
"LipResidual",
]


@@ -625,7 +626,7 @@ def is_supported_padding(padding):
"reflect",
"circular",
"symmetric",
'replicate'
"replicate",
] # "constant",


@@ -635,7 +636,7 @@ def pad_input(x, padding, kernel_size):
kernel_size = [kernel_size, kernel_size]
if padding.lower() in ["same", "valid"]:
return x
elif padding.lower() in ["constant", "reflect", "circular",'replicate']:
elif padding.lower() in ["constant", "reflect", "circular", "replicate"]:
p_vert, p_hor = kernel_size[0] // 2, kernel_size[1] // 2
pad_sizes = [
p_hor,
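For reference (a usage sketch, not part of the commit; the full pad_sizes list is truncated above): torch.nn.functional.pad takes pad sizes pairwise from the last dimension inward, which is why pad_input builds its list starting with the horizontal sizes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 4, 4)
p_vert, p_hor = 1, 2  # e.g. a 3x5 kernel: kernel_size[i] // 2 per side

# Pad sizes apply pairwise from the last dimension inward:
# (left, right, top, bottom) for an NCHW tensor.
padded = F.pad(x, [p_hor, p_hor, p_vert, p_vert], mode="circular")
print(padded.shape)  # torch.Size([1, 1, 6, 8])
```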
