modify InvertibleUpDownSampling to use torch PixelShuffle and Unshuffle; modify pytest (warning: only supports 2D inputs and a single-value kernel size)
franckma31 committed Oct 29, 2024
1 parent 852baa5 commit 0fed7bf
Showing 5 changed files with 93 additions and 61 deletions.
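For context: torch.nn.PixelUnshuffle performs the same space-to-depth rearrangement that the custom invertible_downsample implemented, moving each k x k spatial block into k*k channels, and torch.nn.PixelShuffle is its exact inverse. Both are pure permutations of tensor elements, so they preserve norms, which is what lets them back a Lipschitz layer. A minimal standalone sketch of that round trip (my illustration, not code from the commit):

import torch

k = 2
unshuffle = torch.nn.PixelUnshuffle(downscale_factor=k)  # (N, C, H, W) -> (N, C*k*k, H/k, W/k)
shuffle = torch.nn.PixelShuffle(upscale_factor=k)        # (N, C*k*k, H/k, W/k) -> (N, C, H, W)

x = torch.randn(1, 3, 8, 8)
y = unshuffle(x)
assert y.shape == (1, 12, 4, 4)
assert torch.equal(shuffle(y), x)  # exact inverse: a pure element permutation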
28 changes: 20 additions & 8 deletions deel/torchlip/modules/downsampling.py
@@ -32,14 +32,26 @@
 from .module import LipschitzModule


-class InvertibleDownSampling(torch.nn.Module, LipschitzModule):
-    def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0):
-        torch.nn.Module.__init__(self)
+class InvertibleDownSampling(torch.nn.PixelUnshuffle, LipschitzModule):
+    def __init__(self, kernel_size: int, k_coef_lip: float = 1.0):
+        torch.nn.PixelUnshuffle.__init__(self, kernel_size)
         LipschitzModule.__init__(self, k_coef_lip)
         self.kernel_size = kernel_size

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip
-
     def vanilla_export(self):
-        return self
+        if self._coefficient_lip == 1.0:
+            return torch.nn.PixelUnshuffle(self.kernel_size)
+        else:
+            return self
+
+
+# class InvertibleDownSampling(torch.nn.Module, LipschitzModule):
+#     def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0):
+#         torch.nn.Module.__init__(self)
+#         LipschitzModule.__init__(self, k_coef_lip)
+#         self.kernel_size = kernel_size
+
+#     def forward(self, input: torch.Tensor) -> torch.Tensor:
+#         return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip
+
+#     def vanilla_export(self):
+#         return self
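A quick usage sketch of the rewritten class (my example; the import path is an assumption): for kernel_size k it maps (N, C, H, W) to (N, C*k*k, H/k, W/k), and with the default k_coef_lip of 1.0 vanilla_export now returns a bare torch.nn.PixelUnshuffle.

import torch
from deel.torchlip import InvertibleDownSampling  # import path assumed

layer = InvertibleDownSampling(kernel_size=2)
x = torch.randn(8, 3, 32, 32)
y = layer(x)
assert y.shape == (8, 12, 16, 16)  # channels x4, spatial dims halved

exported = layer.vanilla_export()  # k_coef_lip == 1.0, so a plain PixelUnshuffle
assert isinstance(exported, torch.nn.PixelUnshuffle)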
17 changes: 7 additions & 10 deletions deel/torchlip/modules/upsampling.py
@@ -33,16 +33,13 @@
 from .module import LipschitzModule


-class InvertibleUpSampling(torch.nn.Module, LipschitzModule):
-    def __init__(
-        self, kernel_size: Union[int, Tuple[int, ...]], k_coef_lip: float = 1.0
-    ):
-        torch.nn.Module.__init__(self)
+class InvertibleUpSampling(torch.nn.PixelShuffle, LipschitzModule):
+    def __init__(self, kernel_size: int, k_coef_lip: float = 1.0):
+        torch.nn.PixelShuffle.__init__(self, kernel_size)
         LipschitzModule.__init__(self, k_coef_lip)
         self.kernel_size = kernel_size

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.invertible_upsample(input, self.kernel_size) * self._coefficient_lip
-
     def vanilla_export(self):
-        return self
+        if self._coefficient_lip == 1.0:
+            return torch.nn.PixelShuffle(self.kernel_size)
+        else:
+            return self
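The symmetric sketch for the upsampling side (same assumptions as above): InvertibleUpSampling needs the channel count divisible by kernel_size squared, and composing it with InvertibleDownSampling reproduces the input exactly.

import torch
from deel.torchlip import InvertibleDownSampling, InvertibleUpSampling  # paths assumed

down = InvertibleDownSampling(kernel_size=2)
up = InvertibleUpSampling(kernel_size=2)

x = torch.randn(4, 3, 64, 64)
z = up(down(x))           # (4, 12, 32, 32) in between
assert torch.equal(z, x)  # lossless round trip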
4 changes: 2 additions & 2 deletions tests/test_layers.py
@@ -1143,7 +1143,7 @@ def test_callbacks(test_params):
     [
         dict(
             layer_type=InvertibleDownSampling,
-            layer_params={"kernel_size": (2, 3)},
+            layer_params={"kernel_size": 3},
             batch_size=250,
             steps_per_epoch=1,
             epochs=5,
@@ -1164,7 +1164,7 @@ def test_invertibledownsampling(test_params):
     [
         dict(
             layer_type=InvertibleUpSampling,
-            layer_params={"kernel_size": (2, 3)},
+            layer_params={"kernel_size": 3},
             batch_size=250,
             steps_per_epoch=1,
             epochs=5,
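The tuple kernel sizes had to become plain integers because PixelShuffle and PixelUnshuffle take a single scalar factor applied to both spatial dimensions, so a rectangular rearrangement like (2, 3) is no longer expressible:

import torch

unshuffle = torch.nn.PixelUnshuffle(3)  # fine: 3x3 blocks on both H and W
# torch.nn.PixelUnshuffle((2, 3))      # not supported: downscale_factor is a single int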
103 changes: 63 additions & 40 deletions tests/test_updownsampling.py
@@ -29,69 +29,92 @@
 import numpy as np

 from . import utils_framework as uft
-from .utils_framework import invertible_downsample, invertible_upsample
+from .utils_framework import InvertibleDownSampling, InvertibleUpSampling
+
+
+def check_downsample(x, y, kernel_size):
+    shape = uft.get_NCHW(x)
+    index = 0
+    for dx in range(kernel_size):
+        for dy in range(kernel_size):
+            xx = x[:, :, dx::kernel_size, dy::kernel_size]
+            yy = y[:, index :: (kernel_size * kernel_size), :, :]
+            np.testing.assert_array_equal(xx, yy)
+            index += 1


 @pytest.mark.skipif(
-    hasattr(invertible_downsample, "unavailable_class"),
-    reason="invertible_downsample not available",
+    hasattr(InvertibleDownSampling, "unavailable_class"),
+    reason="InvertibleDownSampling not available",
 )
 def test_invertible_downsample():
-    # 1D input
-    x = uft.to_tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]])
-    x = uft.get_instance_framework(
-        invertible_downsample, {"input": x, "kernel_size": (2,)}
-    )
-    assert x.shape == (1, 4, 2)
-
-    # TODO: Check this.
-    np.testing.assert_equal(uft.to_numpy(x), [[[1, 2], [3, 4], [5, 6], [7, 8]]])
+    x = np.arange(32).reshape(1, 2, 4, 4)
+    x = uft.to_tensor(x)
+    dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 2})
+    y = dw_layer(x)
+    assert y.shape == (1, 8, 2, 2)
+    check_downsample(x, y, 2)

-    # 2D input
     x = np.random.rand(10, 1, 128, 128)  # torch.rand(10, 1, 128, 128)
     x = uft.to_tensor(x)
-    assert invertible_downsample(x, (4, 4)).shape == (10, 16, 32, 32)
-
-    x = np.random.rand(10, 4, 64, 64)
-    x = uft.to_tensor(x)
-    assert invertible_downsample(x, (2, 2)).shape == (10, 16, 32, 32)
+    dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4})
+    y = dw_layer(x)
+    assert y.shape == (10, 16, 32, 32)
+    check_downsample(x, y, 4)

-    # 3D input
-    x = np.random.rand(10, 2, 128, 64, 64)
+    x = np.random.rand(10, 4, 64, 64)
     x = uft.to_tensor(x)
-    assert invertible_downsample(x, 2).shape == (10, 16, 64, 32, 32)
+    dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 2})
+    y = dw_layer(x)
+    assert y.shape == (10, 16, 32, 32)
+    check_downsample(x, y, 2)


 @pytest.mark.skipif(
-    hasattr(invertible_upsample, "unavailable_class"),
-    reason="invertible_upsample not available",
+    hasattr(InvertibleUpSampling, "unavailable_class"),
+    reason="InvertibleUpSampling not available",
 )
 def test_invertible_upsample():
-    # 1D input
-    x = uft.to_tensor([[[1, 2], [3, 4], [5, 6], [7, 8]]])
-    x = uft.get_instance_framework(
-        invertible_upsample, {"input": x, "kernel_size": (2,)}
-    )
-
-    assert x.shape == (1, 2, 4)
-
-    # Check output.
-    np.testing.assert_equal(uft.to_numpy(x), [[[1, 2, 3, 4], [5, 6, 7, 8]]])
-
-    # 2D input
     x = np.random.rand(10, 16, 32, 32)
     x = uft.to_tensor(x)
-    y = uft.get_instance_framework(
-        invertible_upsample, {"input": x, "kernel_size": (4, 4)}
-    )
+    dw_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4})
+    y = dw_layer(x)
     assert y.shape == (10, 1, 128, 128)
-    y = uft.get_instance_framework(
-        invertible_upsample, {"input": x, "kernel_size": (2, 2)}
-    )
+    check_downsample(y, x, 4)
+
+    dw_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 2})
+    y = dw_layer(x)
     assert y.shape == (10, 4, 64, 64)
+    check_downsample(y, x, 2)

-    # 3D input
-    x = np.random.rand(10, 16, 64, 32, 32)
-    x = uft.to_tensor(x)
-    y = uft.get_instance_framework(invertible_upsample, {"input": x, "kernel_size": 2})
-    assert y.shape == (10, 2, 128, 64, 64)
+
+@pytest.mark.skipif(
+    hasattr(InvertibleUpSampling, "unavailable_class")
+    or hasattr(InvertibleDownSampling, "unavailable_class"),
+    reason="InvertibleUpSampling not available",
+)
+def test_invertible_upsample_downsample():
+    x = np.random.rand(10, 16, 32, 32)
+    x = uft.to_tensor(x)
+    up_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4})
+    y = up_layer(x)
+
+    dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4})
+    z = dw_layer(y)
+    assert z.shape == x.shape
+    np.testing.assert_array_equal(x, z)
+
+    x = np.random.rand(10, 1, 128, 128)  # torch.rand(10, 1, 128, 128)
+    x = uft.to_tensor(x)
+
+    dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4})
+    y = dw_layer(x)
+    up_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4})
+    z = up_layer(y)
+    assert z.shape == x.shape
+    np.testing.assert_array_equal(x, z)
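check_downsample above pins down the channel-ordering convention: for original channel c and in-block offsets (dx, dy), the value lands in output channel c*k*k + dx*k + dy, so slicing every (k*k)-th channel starting at dx*k + dy recovers the strided view x[:, :, dx::k, dy::k]. A standalone check of the same property directly against PixelUnshuffle (my illustration, mirroring the test):

import torch

k = 2
x = torch.arange(32.0).reshape(1, 2, 4, 4)
y = torch.nn.PixelUnshuffle(k)(x)  # (1, 8, 2, 2)

for dx in range(k):
    for dy in range(k):
        idx = dx * k + dy
        # channels idx, idx + k*k, ... hold the pixels at offset (dx, dy) of each block
        assert torch.equal(y[:, idx :: k * k], x[:, :, dx::k, dy::k])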
2 changes: 1 addition & 1 deletion tests/utils_framework.py
@@ -557,7 +557,7 @@ def to_NCHW_inv(x):


 def get_NCHW(x):
-    return (x.shape[0], x.shape[1], x.shape[2], x.shape[3])
+    return (x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1])


 def scaleAlpha(alpha):
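The negative indices make get_NCHW read the trailing four axes instead of the leading ones, so it keeps working even if a tensor carries extra leading dimensions. A small standalone illustration (the 5-D case is hypothetical):

import numpy as np

def get_NCHW(x):
    return (x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1])

assert get_NCHW(np.zeros((10, 3, 32, 32))) == (10, 3, 32, 32)     # plain NCHW: unchanged
assert get_NCHW(np.zeros((2, 10, 3, 32, 32))) == (10, 3, 32, 32)  # extra leading axis: still trailing NCHW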
