Add several new layers to support ResNet-like architectures #23

Merged
merged 27 commits on Jan 14, 2025
Changes from 21 commits
Commits
27 commits
73467a2
update groupsort and groupsort2 activations to be consistent with deel-lip (…
Oct 16, 2024
c8bdf64
change to_NCHW to process numpy arrays
Oct 16, 2024
c38ebb6
Add HouseHolder activation and adapt pytest
Oct 19, 2024
fc0a3ba
update shape information in test_activation
Oct 19, 2024
d0820eb
add ScaledGlobalL2NormPool2d class
Oct 20, 2024
4eba8f2
add SpectralConvTranspose2d layer and update test for tf compatibility
Oct 21, 2024
30922f9
cleaning
Oct 21, 2024
01d8b4a
layer and batch centering
Jun 17, 2024
defbcf6
update layer and batch centering + pytests
franckma31 Oct 22, 2024
ad4fc38
simplify householder reshape computation + correct test
franckma31 Oct 22, 2024
7305a91
clean normalization
franckma31 Oct 22, 2024
493156a
add description comments on added layers
franckma31 Oct 22, 2024
184bdd4
spectral conv 1d
Jun 17, 2024
60961af
add support for spectralConv1d
franckma31 Oct 28, 2024
4ac8ff6
add padconv support + tests
franckma31 Oct 28, 2024
c8ca131
modify InvertibleUpDownSampling to use torch PixelShuffle and Unshuf…
franckma31 Oct 29, 2024
7227db4
add a LipResidual layer and tests
franckma31 Nov 4, 2024
d19ad7d
add support for replicate padding + trick to change symmetric+1 paddi…
franckma31 Nov 7, 2024
ccdc261
clean and linter
franckma31 Nov 7, 2024
1fd81f4
update for compatibility with tensorflow tests
franckma31 Dec 2, 2024
58d8fba
updates on docs and notebook + resolve issue #17
franckma31 Dec 2, 2024
c6bf80f
update vanilla export with PixelUnshuffle variable in InvertibleUp(Do…
franckma31 Dec 9, 2024
16b0059
modify vanilla_model function to support parametrization
franckma31 Dec 9, 2024
34105b9
update to solve comment https://github.com/deel-ai/deel-torchlip/pull…
franckma31 Dec 9, 2024
f5e28ce
use LPPool2d in ScaledL2NormPool2d + pooling tests
franckma31 Dec 9, 2024
a8721bb
use lp_lppool2d in ScaledAdaptativeL2NormPool2d ( https://github.com/…
franckma31 Dec 9, 2024
76a6e69
remove same case for padding_mode to solve https://github.com/deel-ai…
franckma31 Dec 10, 2024
71 changes: 59 additions & 12 deletions deel/torchlip/functional.py
@@ -212,30 +212,38 @@ def max_min(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    return torch.cat((F.relu(input), F.relu(-input)), dim=dim)


-def group_sort(input: torch.Tensor, group_size: Optional[int] = None) -> torch.Tensor:
+def group_sort(
+    input: torch.Tensor, group_size: Optional[int] = None, dim: int = 1
+) -> torch.Tensor:
    r"""
    Applies GroupSort activation on the given tensor.

    See Also:
        :py:func:`group_sort_2`
        :py:func:`full_sort`
    """
-    if group_size is None or group_size > input.shape[1]:
-        group_size = input.shape[1]
-
-    if input.shape[1] % group_size != 0:
+    if group_size is None or group_size > input.shape[dim]:
+        group_size = input.shape[dim]
+
+    if input.shape[dim] % group_size != 0:
        raise ValueError("The input size must be a multiple of the group size.")

-    fv = input.reshape([-1, group_size])
+    new_shape = (
+        input.shape[:dim]
+        + (input.shape[dim] // group_size, group_size)
+        + input.shape[dim + 1 :]
+    )
    if group_size == 2:
-        sfv = torch.chunk(fv, 2, 1)
-        b = sfv[0]
-        c = sfv[1]
-        newv = torch.cat((torch.min(b, c), torch.max(b, c)), dim=1)
-        newv = newv.reshape(input.shape)
-        return newv
+        resh_input = input.view(new_shape)
+        a, b = (
+            torch.min(resh_input, dim + 1, keepdim=True)[0],
+            torch.max(resh_input, dim + 1, keepdim=True)[0],
+        )
+        return torch.cat([a, b], dim=dim + 1).view(input.shape)
+    fv = input.reshape(new_shape)

-    return torch.sort(fv)[0].reshape(input.shape)
+    return torch.sort(fv, dim=dim + 1)[0].reshape(input.shape)


def group_sort_2(input: torch.Tensor) -> torch.Tensor:
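For readers skimming the diff, here is a minimal usage sketch of the updated `group_sort` with its new `dim` argument (illustrative only, not part of the change; the tensor shape and group size are arbitrary):

```python
import torch
from deel.torchlip import functional as F

x = torch.randn(8, 16, 32, 32)            # (batch, channels, H, W)
y = F.group_sort(x, group_size=4, dim=1)  # sorts values within each group of 4 channels
assert y.shape == x.shape                 # shape is preserved
```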
@@ -568,3 +576,42 @@ def process_labels_for_multi_gpu(labels: torch.Tensor) -> torch.Tensor:
    # Since element-wise KR terms are averaged by loss reduction later on, it is needed
    # to multiply by batch_size here.
    return torch.where(labels > 0, pos_factor, neg_factor)


class SymmetricPad(torch.nn.Module):
    """
    Pads a 2D tensor symmetrically.

    Args:
        pad (tuple): A tuple (pad_left, pad_right, pad_top, pad_bottom) specifying
            the number of pixels to pad on each side. (or single int if
            common padding).

        onedim: False for conv2d, True for conv1d.

    """

    def __init__(self, pad, onedim=False):
        super().__init__()
        self.onedim = onedim
        num_dim = 2 if onedim else 4
        if isinstance(pad, int):
            self.pad = (pad,) * num_dim
        else:
            self.pad = torch.nn.modules.utils._reverse_repeat_tuple(pad, 2)
        assert len(self.pad) == num_dim, f"Pad must be a tuple of {num_dim} integers"

    def forward(self, x):

        # Horizontal padding
        left = x[:, ..., : self.pad[0]].flip(dims=[-1])
        right = x[:, ..., -self.pad[1] :].flip(dims=[-1])
        x = torch.cat([left, x, right], dim=-1)
        if self.onedim:
            return x
        # Vertical padding
        top = x[:, :, : self.pad[2], :].flip(dims=[-2])
        bottom = x[:, :, -self.pad[3] :, :].flip(dims=[-2])
        x = torch.cat([top, x, bottom], dim=-2)

        return x
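A quick behavioural sketch of the new `SymmetricPad` (not part of the diff; the integer padding and input shape below are arbitrary):

```python
import torch
from deel.torchlip.functional import SymmetricPad

pad = SymmetricPad(1)                        # same symmetric padding on every side
x = torch.arange(16.0).reshape(1, 1, 4, 4)   # (batch, channels, H, W)
y = pad(x)                                   # border rows/columns are mirrored outward
print(y.shape)                               # torch.Size([1, 1, 6, 6])
```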
9 changes: 9 additions & 0 deletions deel/torchlip/modules/__init__.py
@@ -48,10 +48,13 @@
from .activation import FullSort
from .activation import GroupSort
from .activation import GroupSort2
from .activation import HouseHolder
from .activation import LPReLU
from .activation import MaxMin
from .conv import FrobeniusConv2d
from .conv import SpectralConv2d
from .conv import SpectralConv1d
from .conv import SpectralConvTranspose2d
from .downsampling import InvertibleDownSampling
from .linear import FrobeniusLinear
from .linear import SpectralLinear
@@ -72,4 +75,10 @@
from .pooling import ScaledAdaptiveAvgPool2d
from .pooling import ScaledAvgPool2d
from .pooling import ScaledL2NormPool2d
from .pooling import ScaledGlobalL2NormPool2d
from .upsampling import InvertibleUpSampling
from .normalization import LayerCentering
from .normalization import BatchCentering
from .unconstrained import PadConv2d
from .unconstrained import PadConv1d
from .residual import LipResidual
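Taken together, the new exports can be combined into a small Lipschitz model straight from `deel.torchlip`. A rough sketch follows; the constructor arguments are assumed to mirror the usual `torch.nn` signatures and are not taken from this diff:

```python
import torch
from deel import torchlip

# Assumed signatures: SpectralConv2d/SpectralLinear follow nn.Conv2d/nn.Linear,
# HouseHolder takes the (even) channel count, the global pool takes no arguments.
model = torch.nn.Sequential(
    torchlip.SpectralConv2d(3, 16, kernel_size=3),
    torchlip.HouseHolder(16),
    torchlip.ScaledGlobalL2NormPool2d(),
    torch.nn.Flatten(),
    torchlip.SpectralLinear(16, 10),
)
```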
53 changes: 53 additions & 0 deletions deel/torchlip/modules/activation.py
Expand Up @@ -33,6 +33,7 @@

import torch
import torch.nn as nn
import numpy as np

from .. import functional as F
from .module import LipschitzModule
@@ -211,3 +212,55 @@ def vanilla_export(self):
        layer = LPReLU(num_parameters=self.num_parameters)
        layer.weight.data = self.weight.data
        return layer


class HouseHolder(nn.Module, LipschitzModule):
    def __init__(self, channels, k_coef_lip: float = 1.0, theta_initializer=None):
        """
        Householder activation:
        [this review](https://openreview.net/pdf?id=tD7eCtaSkR)
        Adapted from [this repository](https://github.com/singlasahil14/SOC)
        """
        nn.Module.__init__(self)
        LipschitzModule.__init__(self, k_coef_lip)
        assert (channels % 2) == 0
        eff_channels = channels // 2

        if isinstance(theta_initializer, float):
            coef_theta = theta_initializer
        else:
            coef_theta = 0.5 * np.pi
        self.theta = nn.Parameter(
            coef_theta * torch.ones(eff_channels), requires_grad=True
        )
        if theta_initializer is not None:
            if isinstance(theta_initializer, str):
                name2init = {
                    "zeros": torch.nn.init.zeros_,
                    "ones": torch.nn.init.ones_,
                    "normal": torch.nn.init.normal_,
                }
                assert (
                    theta_initializer in name2init
                ), f"Unknown initializer {theta_initializer}"
                name2init[theta_initializer](self.theta)
            elif isinstance(theta_initializer, float):
                pass
            else:
                raise ValueError(f"Unknown initializer {theta_initializer}")

    def forward(self, z, axis=1):
        theta_shape = (1, -1) + (1,) * (len(z.shape) - 2)
        theta = self.theta.to(z.device).view(theta_shape)
        x, y = z.split(z.shape[axis] // 2, axis)
        selector = (x * torch.sin(0.5 * theta)) - (y * torch.cos(0.5 * theta))

        a_2 = x * torch.cos(theta) + y * torch.sin(theta)
        b_2 = x * torch.sin(theta) - y * torch.cos(theta)

        a = x * (selector <= 0) + a_2 * (selector > 0)
        b = y * (selector <= 0) + b_2 * (selector > 0)
        return torch.cat([a, b], dim=axis)

    def vanilla_export(self):
        return self
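As a quick illustration (not part of the PR), the new activation drops into a model like any other module; the shapes below are arbitrary, and `channels` must be even:

```python
import torch
from deel import torchlip

act = torchlip.HouseHolder(channels=16)
x = torch.randn(4, 16, 8, 8)   # (batch, channels, H, W)
y = act(x)                     # pairwise Householder rotation, norm-preserving
assert y.shape == x.shape
```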