add support for spectralConv1d
franckma31 committed Oct 28, 2024
1 parent 27bd16c commit 08f82de
Showing 6 changed files with 144 additions and 49 deletions.
1 change: 1 addition & 0 deletions deel/torchlip/modules/__init__.py
@@ -53,6 +53,7 @@
from .activation import MaxMin
from .conv import FrobeniusConv2d
from .conv import SpectralConv2d
from .conv import SpectralConv1d
from .conv import SpectralConvTranspose2d
from .downsampling import InvertibleDownSampling
from .linear import FrobeniusLinear
70 changes: 35 additions & 35 deletions deel/torchlip/modules/conv.py
@@ -50,11 +50,11 @@ def __init__(
bias: bool = True,
padding_mode: str = "zeros",
k_coef_lip: float = 1.0,
niter_spectral: int = DEFAULT_NITER_SPECTRAL,
niter_bjorck: int = DEFAULT_NITER_BJORCK,
eps_spectral: float = DEFAULT_EPS_SPECTRAL,
eps_bjorck: float = DEFAULT_EPS_BJORCK,
):
"""
This class is a Conv2d Layer constrained such that all singular of it's kernel
This class is a Conv1d Layer constrained such that all singular of it's kernel
are 1. The computation based on BjorckNormalizer algorithm. As this is not
enough to ensure 1-Lipschitz a coercive coefficient is applied on the
output.
@@ -79,10 +79,10 @@ def __init__(
bias (bool, optional): If ``True``, adds a learnable bias to the
output.
k_coef_lip: Lipschitz constant to ensure.
niter_spectral: Number of iterations to find the maximum singular value.
niter_bjorck: Number of iterations of the BjorckNormalizer algorithm.
eps_spectral: stopping criterion for the iterative power algorithm.
eps_bjorck: stopping criterion for the Bjorck algorithm.
This documentation reuses the body of the original torch.nn.Conv2d doc.
This documentation reuses the body of the original torch.nn.Conv1d doc.
"""
# if not ((dilation == (1, 1)) or (dilation == [1, 1]) or (dilation == 1)):
# raise RuntimeError("NormalizedConv does not support dilation rate")
@@ -108,11 +108,11 @@ def __init__(
spectral_norm(
self,
name="weight",
n_power_iterations=niter_spectral,
eps=eps_spectral,
)
bjorck_norm(self, name="weight", n_iterations=niter_bjorck)
bjorck_norm(self, name="weight", eps=eps_bjorck)
lconv_norm(self)
self.register_forward_pre_hook(self._hook)
self.apply_lipschitz_factor()

def vanilla_export(self):
layer = torch.nn.Conv1d(
@@ -168,7 +168,7 @@ def __init__(
the input.
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel elements.
Has to be one
groups (int, optional): Number of blocked connections from input
channels to output channels. Has to be one
@@ -290,31 +290,31 @@ def vanilla_export(self):

class SpectralConvTranspose2d(torch.nn.ConvTranspose2d, LipschitzModule):
r"""Applies a 2D transposed convolution operator over an input image
such that all singular values of its kernel are 1.
The computation is the same as for the SpectralConv2d layer.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution.
padding (int or tuple, optional): Zero-padding added to both sides of
the input.
output_padding: only 0 or None are supported
padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
dilation (int or tuple, optional): Spacing between kernel elements.
Has to be one.
groups (int, optional): Number of blocked connections from input
channels to output channels. Has to be one.
bias (bool, optional): If ``True``, adds a learnable bias to the
output.
k_coef_lip: Lipschitz constant to ensure.
eps_spectral: stopping criterion for the iterative power algorithm.
eps_bjorck: stopping criterion for the Bjorck algorithm.
This documentation reuses the body of the original torch.nn.ConvTranspose2d
doc.
"""

def __init__(
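
A minimal usage sketch of the new layer (assuming SpectralConv1d is re-exported at the package root like its 2D counterpart; the eps values below are illustrative, not necessarily the library defaults):

```python
import torch

from deel.torchlip import SpectralConv1d

# 1-Lipschitz 1D convolution: the kernel is spectrally normalized (power
# iteration), orthonormalized (Bjorck), and rescaled by the lconv coefficient.
conv = SpectralConv1d(
    in_channels=1,
    out_channels=2,
    kernel_size=3,
    k_coef_lip=1.0,      # Lipschitz constant to enforce
    eps_spectral=1e-3,   # stopping criterion of the power iteration
    eps_bjorck=1e-3,     # stopping criterion of the Bjorck algorithm
)

x = torch.randn(8, 1, 5)  # (batch, channels, length), as in the new tests
y = conv(x)

# Export a plain torch.nn.Conv1d with the normalized weight baked in.
vanilla = conv.vanilla_export()
```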
20 changes: 11 additions & 9 deletions deel/torchlip/modules/normalization.py
@@ -16,21 +16,22 @@ class LayerCentering(nn.Module):
:math:`\beta` is a learnable bias parameter that can be
applied after the mean subtraction.
Unlike Layer Normalization, this layer is 1-Lipschitz.
This layer uses statistics computed from input data in
both training and evaluation modes.
Args:
size: number of features in the input tensor
dim: dimensions over which to compute the mean
(default ``input.mean((-2, -1))`` for a 4D tensor).
bias: if `True`, adds a learnable bias to the output
of shape (size,). Default: `True`
Shape:
- Input: :math:`(N, size, *)`
- Output: :math:`(N, size, *)` (same shape as input)
"""

def __init__(self, size: int = 1, dim: tuple = [-2, -1], bias=True):
super(LayerCentering, self).__init__()
if bias:
@@ -60,28 +61,29 @@ class BatchCentering(nn.Module):
y = x - \mathrm{E}[x] + \beta
The mean is calculated per-dimension over the mini-batches and
all other dimensions except the feature/channel dimension.
This layer uses statistics computed from input data in
training mode and a constant in evaluation mode computed as
the running mean on training samples.
:math:`\beta` is a learnable parameter vector
of size `C` (where `C` is the number of features or channels of the input)
that can be applied after the mean subtraction.
Unlike Batch Normalization, this layer is 1-Lipschitz.
Args:
size: number of features in the input tensor
dim: dimensions over which to compute the mean
(default ``input.mean((0, -2, -1))`` for a 4D tensor).
momentum: the value used for the running mean computation
bias: if `True`, adds a learnable bias to the output
of shape (size,). Default: `True`
Shape:
- Input: :math:`(N, size, *)`
- Output: :math:`(N, size, *)` (same shape as input)
"""

def __init__(
self,
size: int = 1,
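
The centering operation itself is easy to state in plain torch; a minimal sketch of the idea (not the library classes, which also handle running means and bias shapes):

```python
import torch

def layer_centering(x: torch.Tensor, beta: torch.Tensor, dim=(-2, -1)) -> torch.Tensor:
    # Subtract the per-sample mean over `dim`, then add the learnable bias.
    # The map x -> x - mean(x) is an orthogonal projection, hence 1-Lipschitz.
    return x - x.mean(dim=dim, keepdim=True) + beta

x = torch.randn(4, 3, 8, 8)   # (N, C, H, W)
beta = torch.zeros(3, 1, 1)   # one bias per channel, broadcast over spatial dims
y = layer_centering(x, beta)
assert y.mean(dim=(-2, -1)).abs().max() < 1e-5  # per-sample means are removed
```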
46 changes: 42 additions & 4 deletions deel/torchlip/utils/lconv_norm.py
@@ -24,24 +24,54 @@
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
from typing import Tuple
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


def compute_lconv_coef_1d(
kernel_size: Tuple[int],
input_shape: Tuple[int] = None,
strides: Tuple[int] = (1,),
padding_mode: str = "zeros",
) -> float:
stride = strides[0]
k1 = kernel_size[0]

if (
(padding_mode in ["zeros", "same"])
and (stride == 1)
and (input_shape is not None)
):
# See https://arxiv.org/abs/2006.06520
l = input_shape[-1]
k1_div2 = (k1 - 1) / 2
coefLip = l / (k1 * l - k1_div2 * (k1_div2 + 1))
else:
sn1 = strides[0]
coefLip = 1.0 / np.ceil(k1 / sn1)

return coefLip # type: ignore


def compute_lconv_coef(
kernel_size: Tuple[int, ...],
input_shape: Tuple[int, ...] = None,
strides: Tuple[int, ...] = (1, 1),
padding_mode: str = "zeros",
) -> float:
# See https://arxiv.org/abs/2006.06520
stride = np.prod(strides)
k1, k2 = kernel_size

if stride == 1 and input_shape is not None:
if (
(padding_mode in ["zeros", "same"])
and (stride == 1)
and (input_shape is not None)
):
h, w = input_shape[-2:]
k1_div2 = (k1 - 1) / 2
k2_div2 = (k2 - 1) / 2
@@ -68,7 +98,10 @@ def forward(self, weight: torch.Tensor) -> torch.Tensor:
return weight * self.lconv_coefficient


def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d:
ConvType = Union[torch.nn.Conv2d, torch.nn.Conv1d]


def lconv_norm(module: ConvType, name: str = "weight") -> ConvType:
r"""
Applies Lipschitz normalization to a kernel in the given convolutional layer.
This is implemented via a hook that multiplies the kernel by a value computed
@@ -80,6 +113,7 @@ def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d
Args:
module: Containing module.
name: Name of weight parameter.
The 1D or 2D coefficient is selected automatically from the module type
(Conv1d or Conv2d).
Returns:
The original module with the Lipschitz normalization hook.
@@ -91,7 +125,11 @@ def lconv_norm(module: torch.nn.Conv2d, name: str = "weight") -> torch.nn.Conv2d
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
"""
coefficient = compute_lconv_coef(module.kernel_size, None, module.stride)
onedim = isinstance(module, torch.nn.Conv1d)
if onedim:
coefficient = compute_lconv_coef_1d(module.kernel_size, None, module.stride)
else:
coefficient = compute_lconv_coef(module.kernel_size, None, module.stride)
parametrize.register_parametrization(module, name, _LConvNorm(coefficient))
return module

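
A quick numeric check of compute_lconv_coef_1d for the configuration used in the new tests (kernel_size=3, stride=1, input length 5), worked out by hand from the formula above:

```python
import numpy as np

k1, length, stride = 3, 5, 1

# Known input length (bound from https://arxiv.org/abs/2006.06520):
k1_div2 = (k1 - 1) / 2                                   # = 1.0
coef = length / (k1 * length - k1_div2 * (k1_div2 + 1))  # 5 / (15 - 2) = 5/13
print(coef)  # ~0.3846

# Unknown input length (the path taken by lconv_norm, which passes None):
coef_fallback = 1.0 / np.ceil(k1 / stride)               # 1/3
print(coef_fallback)
```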
54 changes: 53 additions & 1 deletion tests/test_layers.py
@@ -35,6 +35,7 @@
from .utils_framework import (
SpectralLinear,
SpectralConv2d,
SpectralConv1d,
SpectralConvTranspose2d,
FrobeniusLinear,
FrobeniusConv2d,
@@ -277,7 +278,7 @@ def _apply_tests_bank(test_params):
) = train_k_lip_model(**test_params)
print("test mse: %f" % mse)
print(
"empirical lip const: %f ( expected %s )"
"empirical lip const: %f ( expected min data and model %s )"
% (
emp_lip_const,
min(test_params["k_lip_model"], test_params["k_lip_data"]),
@@ -586,6 +587,57 @@ def test_spectralconv2d(test_params):
_apply_tests_bank(test_params)


@pytest.mark.skipif(
hasattr(SpectralConv1d, "unavailable_class"),
reason="SpectralConv1d not available",
)
@pytest.mark.parametrize(
"test_params",
[
dict(
layer_type=SpectralConv1d,
layer_params={
"in_channels": 1,
"out_channels": 2,
"kernel_size": 3,
"bias": False,
},
batch_size=250,
steps_per_epoch=125,
epochs=5,
input_shape=(1, 5),
k_lip_data=1.0,
k_lip_model=1.0,
callbacks=[],
),
dict(
layer_type=SpectralConv1d,
layer_params={"in_channels": 1, "out_channels": 2, "kernel_size": 3},
batch_size=250,
steps_per_epoch=125,
epochs=5,
input_shape=(1, 5),
k_lip_data=5.0,
k_lip_model=1.0,
callbacks=[],
),
dict(
layer_type=SpectralConv1d,
layer_params={"in_channels": 1, "out_channels": 2, "kernel_size": 3},
batch_size=250,
steps_per_epoch=125,
epochs=5,
input_shape=(1, 5),
k_lip_data=1.0,
k_lip_model=5.0,
callbacks=[],
),
],
)
def test_spectralconv1d(test_params):
_apply_tests_bank(test_params)


@pytest.mark.skipif(
hasattr(SpectralConvTranspose2d, "unavailable_class"),
reason="SpectralConvTranspose2d not available",
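
The test bank trains the layer and compares an empirical Lipschitz estimate against min(k_lip_model, k_lip_data). A rough sketch of such an estimate (not the suite's exact procedure):

```python
import torch

def empirical_lip_const(layer: torch.nn.Module, input_shape, n_pairs: int = 1024) -> float:
    # Sample random input pairs and take the largest output/input distance ratio;
    # for a k-Lipschitz layer this should stay below k up to estimation error.
    x1 = torch.randn(n_pairs, *input_shape)
    x2 = torch.randn(n_pairs, *input_shape)
    with torch.no_grad():
        num = (layer(x1) - layer(x2)).flatten(1).norm(dim=1)
    den = (x1 - x2).flatten(1).norm(dim=1)
    return (num / den).max().item()

# e.g. empirical_lip_const(SpectralConv1d(1, 2, 3), (1, 5)) should stay near or below 1.0
```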
2 changes: 2 additions & 0 deletions tests/utils_framework.py
@@ -35,6 +35,7 @@
from deel.torchlip.modules import LipschitzModule as LipschitzLayer
from deel.torchlip.modules import SpectralLinear
from deel.torchlip.modules import SpectralConv2d
from deel.torchlip.modules import SpectralConv1d
from deel.torchlip.modules import SpectralConvTranspose2d
from deel.torchlip.modules import FrobeniusLinear
from deel.torchlip.modules import FrobeniusConv2d
@@ -123,6 +124,7 @@
"tReshape",
"CategoricalHingeLoss",
"process_labels_for_multi_gpu",
"SpectralConv1d",
]


