From 0fed7bf5c49386715d8f44f3cf76cf0c237482e7 Mon Sep 17 00:00:00 2001 From: Franck Mamalet <49721198+franckma31@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:45:28 +0100 Subject: [PATCH] modify InvertibleUpDownSampling t o use torch PixelShuffle and Unshuffle; modify pytest (warning only support 2D inputs, and single value kernel size) --- deel/torchlip/modules/downsampling.py | 28 +++++-- deel/torchlip/modules/upsampling.py | 17 ++--- tests/test_layers.py | 4 +- tests/test_updownsampling.py | 103 ++++++++++++++++---------- tests/utils_framework.py | 2 +- 5 files changed, 93 insertions(+), 61 deletions(-) diff --git a/deel/torchlip/modules/downsampling.py b/deel/torchlip/modules/downsampling.py index 7c00fc8..7c3db9c 100644 --- a/deel/torchlip/modules/downsampling.py +++ b/deel/torchlip/modules/downsampling.py @@ -32,14 +32,26 @@ from .module import LipschitzModule -class InvertibleDownSampling(torch.nn.Module, LipschitzModule): - def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0): - torch.nn.Module.__init__(self) +class InvertibleDownSampling(torch.nn.PixelUnshuffle, LipschitzModule): + def __init__(self, kernel_size: int, k_coef_lip: float = 1.0): + torch.nn.PixelUnshuffle.__init__(self, kernel_size) LipschitzModule.__init__(self, k_coef_lip) - self.kernel_size = kernel_size - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip def vanilla_export(self): - return self + if self._coefficient_lip == 1.0: + return torch.nn.PixelUnshuffle(self.kernel_size) + else: + return self + + +# class InvertibleDownSampling(torch.nn.Module, LipschitzModule): +# def __init__(self, kernel_size: Tuple[int, int], k_coef_lip: float = 1.0): +# torch.nn.Module.__init__(self) +# LipschitzModule.__init__(self, k_coef_lip) +# self.kernel_size = kernel_size + +# def forward(self, input: torch.Tensor) -> torch.Tensor: +# return F.invertible_downsample(input, self.kernel_size) * self._coefficient_lip + +# def vanilla_export(self): +# return self diff --git a/deel/torchlip/modules/upsampling.py b/deel/torchlip/modules/upsampling.py index 7f40354..62aeb56 100644 --- a/deel/torchlip/modules/upsampling.py +++ b/deel/torchlip/modules/upsampling.py @@ -33,16 +33,13 @@ from .module import LipschitzModule -class InvertibleUpSampling(torch.nn.Module, LipschitzModule): - def __init__( - self, kernel_size: Union[int, Tuple[int, ...]], k_coef_lip: float = 1.0 - ): - torch.nn.Module.__init__(self) +class InvertibleUpSampling(torch.nn.PixelShuffle, LipschitzModule): + def __init__(self, kernel_size: int, k_coef_lip: float = 1.0): + torch.nn.PixelShuffle.__init__(self, kernel_size) LipschitzModule.__init__(self, k_coef_lip) - self.kernel_size = kernel_size - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.invertible_upsample(input, self.kernel_size) * self._coefficient_lip def vanilla_export(self): - return self + if self._coefficient_lip == 1.0: + return torch.nn.PixelShuffle(self.kernel_size) + else: + return self diff --git a/tests/test_layers.py b/tests/test_layers.py index 7fbf288..a69f202 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1143,7 +1143,7 @@ def test_callbacks(test_params): [ dict( layer_type=InvertibleDownSampling, - layer_params={"kernel_size": (2, 3)}, + layer_params={"kernel_size": 3}, batch_size=250, steps_per_epoch=1, epochs=5, @@ -1164,7 +1164,7 @@ def test_invertibledownsampling(test_params): [ dict( layer_type=InvertibleUpSampling, - layer_params={"kernel_size": (2, 3)}, + layer_params={"kernel_size": 3}, batch_size=250, steps_per_epoch=1, epochs=5, diff --git a/tests/test_updownsampling.py b/tests/test_updownsampling.py index 7f0dedb..ab085b1 100644 --- a/tests/test_updownsampling.py +++ b/tests/test_updownsampling.py @@ -29,69 +29,92 @@ import numpy as np from . import utils_framework as uft -from .utils_framework import invertible_downsample, invertible_upsample +from .utils_framework import InvertibleDownSampling, InvertibleUpSampling + + +def check_downsample(x, y, kernel_size): + shape = uft.get_NCHW(x) + index = 0 + for dx in range(kernel_size): + for dy in range(kernel_size): + xx = x[:, :, dx::kernel_size, dy::kernel_size] + yy = y[:, index :: (kernel_size * kernel_size), :, :] + np.testing.assert_array_equal(xx, yy) + index += 1 @pytest.mark.skipif( - hasattr(invertible_downsample, "unavailable_class"), - reason="invertible_downsample not available", + hasattr(InvertibleDownSampling, "unavailable_class"), + reason="InvertibleDownSampling not available", ) def test_invertible_downsample(): - # 1D input - x = uft.to_tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]]) - x = uft.get_instance_framework( - invertible_downsample, {"input": x, "kernel_size": (2,)} - ) - assert x.shape == (1, 4, 2) - # TODO: Check this. - np.testing.assert_equal(uft.to_numpy(x), [[[1, 2], [3, 4], [5, 6], [7, 8]]]) + x = np.arange(32).reshape(1, 2, 4, 4) + x = uft.to_tensor(x) + dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 2}) + y = dw_layer(x) + assert y.shape == (1, 8, 2, 2) + check_downsample(x, y, 2) # 2D input x = np.random.rand(10, 1, 128, 128) # torch.rand(10, 1, 128, 128) x = uft.to_tensor(x) - assert invertible_downsample(x, (4, 4)).shape == (10, 16, 32, 32) - x = np.random.rand(10, 4, 64, 64) - x = uft.to_tensor(x) - assert invertible_downsample(x, (2, 2)).shape == (10, 16, 32, 32) + dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4}) + y = dw_layer(x) + assert y.shape == (10, 16, 32, 32) + check_downsample(x, y, 4) - # 3D input - x = np.random.rand(10, 2, 128, 64, 64) + x = np.random.rand(10, 4, 64, 64) x = uft.to_tensor(x) - assert invertible_downsample(x, 2).shape == (10, 16, 64, 32, 32) + dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 2}) + y = dw_layer(x) + assert y.shape == (10, 16, 32, 32) + check_downsample(x, y, 2) @pytest.mark.skipif( - hasattr(invertible_upsample, "unavailable_class"), - reason="invertible_upsample not available", + hasattr(InvertibleUpSampling, "unavailable_class"), + reason="InvertibleUpSampling not available", ) def test_invertible_upsample(): - # 1D input - x = uft.to_tensor([[[1, 2], [3, 4], [5, 6], [7, 8]]]) - x = uft.get_instance_framework( - invertible_upsample, {"input": x, "kernel_size": (2,)} - ) - - assert x.shape == (1, 2, 4) - - # Check output. - np.testing.assert_equal(uft.to_numpy(x), [[[1, 2, 3, 4], [5, 6, 7, 8]]]) # 2D input x = np.random.rand(10, 16, 32, 32) x = uft.to_tensor(x) - y = uft.get_instance_framework( - invertible_upsample, {"input": x, "kernel_size": (4, 4)} - ) + dw_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4}) + y = dw_layer(x) assert y.shape == (10, 1, 128, 128) - y = uft.get_instance_framework( - invertible_upsample, {"input": x, "kernel_size": (2, 2)} - ) + check_downsample(y, x, 4) + + dw_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 2}) + y = dw_layer(x) assert y.shape == (10, 4, 64, 64) + check_downsample(y, x, 2) - # 3D input - x = np.random.rand(10, 16, 64, 32, 32) + +@pytest.mark.skipif( + hasattr(InvertibleUpSampling, "unavailable_class") + or hasattr(InvertibleDownSampling, "unavailable_class"), + reason="InvertibleUpSampling not available", +) +def test_invertible_upsample_downsample(): + x = np.random.rand(10, 16, 32, 32) x = uft.to_tensor(x) - y = uft.get_instance_framework(invertible_upsample, {"input": x, "kernel_size": 2}) - assert y.shape == (10, 2, 128, 64, 64) + up_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4}) + y = up_layer(x) + + dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4}) + z = dw_layer(y) + assert z.shape == x.shape + np.testing.assert_array_equal(x, z) + + x = np.random.rand(10, 1, 128, 128) # torch.rand(10, 1, 128, 128) + x = uft.to_tensor(x) + + dw_layer = uft.get_instance_framework(InvertibleDownSampling, {"kernel_size": 4}) + y = dw_layer(x) + up_layer = uft.get_instance_framework(InvertibleUpSampling, {"kernel_size": 4}) + z = up_layer(y) + assert z.shape == x.shape + np.testing.assert_array_equal(x, z) diff --git a/tests/utils_framework.py b/tests/utils_framework.py index d9e8944..035fbd0 100644 --- a/tests/utils_framework.py +++ b/tests/utils_framework.py @@ -557,7 +557,7 @@ def to_NCHW_inv(x): def get_NCHW(x): - return (x.shape[0], x.shape[1], x.shape[2], x.shape[3]) + return (x.shape[-4], x.shape[-3], x.shape[-2], x.shape[-1]) def scaleAlpha(alpha):