From a8a804947d2e5baec795798b739ccafb4b752e19 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Sun, 5 May 2024 21:22:42 +0200 Subject: [PATCH 01/18] Update cyclegan.py Removed pool layers --- deeplay/models/discriminators/cyclegan.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deeplay/models/discriminators/cyclegan.py b/deeplay/models/discriminators/cyclegan.py index 9262ce19..89469eaf 100644 --- a/deeplay/models/discriminators/cyclegan.py +++ b/deeplay/models/discriminators/cyclegan.py @@ -11,7 +11,10 @@ @ConvolutionalEncoder2d.register_style def cyclegan_discriminator(encoder: ConvolutionalEncoder2d): encoder[..., "layer"].configure(kernel_size=4, padding=1) - encoder["blocks", 1:-1].all.normalized(nn.InstanceNorm2d, mode="insert", after="layer") + encoder["blocks", 1:-1].all.normalized( + nn.InstanceNorm2d, mode="insert", after="layer" + ) + encoder["blocks", :].all.remove("pool", allow_missing=True) encoder["blocks", :-1].configure("activation", nn.LeakyReLU, negative_slope=0.2) encoder["blocks", :-2].configure(stride=2) From ce45c0b3e60002f834cccc3b77ce1abd31c08529 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Sun, 5 May 2024 21:24:22 +0200 Subject: [PATCH 02/18] Update cyclegan.py Instance norm should not be there in the last layer --- deeplay/models/generators/cyclegan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deeplay/models/generators/cyclegan.py b/deeplay/models/generators/cyclegan.py index a8fd96d4..b6eb40a9 100644 --- a/deeplay/models/generators/cyclegan.py +++ b/deeplay/models/generators/cyclegan.py @@ -27,7 +27,9 @@ def cyclegan_resnet_encoder(encoder: ConvolutionalEncoder2d): @ConvolutionalDecoder2d.register_style def cyclegan_resnet_decoder(decoder: ConvolutionalDecoder2d): - decoder.normalized(Layer(nn.InstanceNorm2d)) + decoder["blocks", :-1].all.normalized( + nn.InstanceNorm2d, mode="insert", after="layer" + ) decoder.blocks.configure(order=["layer", "normalization", "activation"]) decoder.blocks[:-1].configure( "layer", nn.ConvTranspose2d, stride=2, output_padding=1 From e553bed2f6dc991631a74cd8fe87c5422d666817 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Sun, 5 May 2024 21:25:16 +0200 Subject: [PATCH 03/18] Update dcgan.py Spelling mistakes!!! (Because of which activations were not being applied properly) --- deeplay/models/discriminators/dcgan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplay/models/discriminators/dcgan.py b/deeplay/models/discriminators/dcgan.py index 51ac0739..4f4cb504 100644 --- a/deeplay/models/discriminators/dcgan.py +++ b/deeplay/models/discriminators/dcgan.py @@ -13,7 +13,7 @@ def dcgan_discriminator(encoder: ConvolutionalEncoder2d): encoder.blocks[-1].configure("layer", padding=0) encoder["blocks", :].all.remove("pool", allow_missing=True) encoder["blocks", 1:-1].all.normalized() - encoder["block", :-1].all.configure("actication", nn.LeakyReLU, negative_slope=0.2) + encoder["blocks", :-1].all.configure("activation", nn.LeakyReLU, negative_slope=0.2) encoder.blocks[-1].activation.configure(nn.Sigmoid) init = Normal( From ef67856cb53dcb2844da0d51acedb3ad923c7e33 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Sun, 5 May 2024 21:25:55 +0200 Subject: [PATCH 04/18] Update dcgan.py Changed the dimensions of the hidden channels --- deeplay/models/generators/dcgan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeplay/models/generators/dcgan.py b/deeplay/models/generators/dcgan.py index 48045f40..14ebe65e 100644 --- a/deeplay/models/generators/dcgan.py +++ b/deeplay/models/generators/dcgan.py @@ -86,11 +86,11 @@ def __init__( class_conditioned_model: bool = False, embedding_dim: int = 100, num_classes: int = 10, - output_channels=None + output_channels=None, ): if output_channels is not None: out_channels = output_channels - + self.latent_dim = latent_dim self.output_channels = out_channels self.class_conditioned_model = class_conditioned_model @@ -104,10 +104,10 @@ def __init__( super().__init__( in_channels=in_channels, hidden_channels=[ + features_dim * 16, features_dim * 8, features_dim * 4, features_dim * 2, - features_dim * 1, ], out_channels=out_channels, out_activation=Layer(nn.Tanh), From 668719135bd9c1c9b5eca70f00197c477945edc7 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Sun, 5 May 2024 22:29:58 +0200 Subject: [PATCH 05/18] Update dcgan.py Embedding layer missing from weight initialization. --- deeplay/models/discriminators/dcgan.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deeplay/models/discriminators/dcgan.py b/deeplay/models/discriminators/dcgan.py index 4f4cb504..87854fc6 100644 --- a/deeplay/models/discriminators/dcgan.py +++ b/deeplay/models/discriminators/dcgan.py @@ -20,6 +20,7 @@ def dcgan_discriminator(encoder: ConvolutionalEncoder2d): targets=( nn.Conv2d, nn.BatchNorm2d, + nn.Embedding, nn.Linear, ), mean=0, From 2d12b6249757c42b17a09f40bd006e5705ff4763 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Mon, 6 May 2024 11:17:41 +0200 Subject: [PATCH 06/18] Delete generator.py --- deeplay/models/generators/generator.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 deeplay/models/generators/generator.py diff --git a/deeplay/models/generators/generator.py b/deeplay/models/generators/generator.py deleted file mode 100644 index e69de29b..00000000 From 09582276b7ebf6ad9e7511d8c886af9929f66c6c Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Mon, 6 May 2024 11:19:33 +0200 Subject: [PATCH 07/18] Update cyclegan.py --- deeplay/models/generators/cyclegan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deeplay/models/generators/cyclegan.py b/deeplay/models/generators/cyclegan.py index b6eb40a9..ee5963b0 100644 --- a/deeplay/models/generators/cyclegan.py +++ b/deeplay/models/generators/cyclegan.py @@ -69,7 +69,7 @@ class CycleGANResnetGenerator(ConvolutionalEncoderDecoder2d): Examples -------- - >>> generator = CycleGANGenerator(in_channels=1, out_channels=3) + >>> generator = CycleGANResnetGenerator(in_channels=1, out_channels=3) >>> generator.build() >>> x = torch.randn(1, 1, 256, 256) >>> y = generator(x) From 3731d79a3d2b36c95a715ff8783e7a9ada41fcbb Mon Sep 17 00:00:00 2001 From: Benjamin Midtvedt Date: Mon, 6 May 2024 18:17:19 +0200 Subject: [PATCH 08/18] Allow to set tensors parameter to control what is initialized --- deeplay/initializers/constant.py | 5 +--- deeplay/initializers/initializer.py | 16 +++++------- deeplay/initializers/kaiming.py | 15 +++++++++--- deeplay/initializers/normal.py | 7 ++---- deeplay/module.py | 38 +++++++++++++++++------------ 5 files changed, 43 insertions(+), 38 deletions(-) diff --git a/deeplay/initializers/constant.py b/deeplay/initializers/constant.py index 7ead6a99..ae5359d9 100644 --- a/deeplay/initializers/constant.py +++ b/deeplay/initializers/constant.py @@ -21,8 +21,5 @@ def __init__( self.weight = weight self.bias = bias - def initialize_weight(self, tensor): + def initialize_tensor(self, tensor, name): tensor.data.fill_(self.weight) - - def initialize_bias(self, tensor): - tensor.data.fill_(self.bias) diff --git a/deeplay/initializers/initializer.py b/deeplay/initializers/initializer.py index e42a974c..de8d2cc1 100644 --- a/deeplay/initializers/initializer.py +++ b/deeplay/initializers/initializer.py @@ -3,15 +3,11 @@ class Initializer: def __init__(self, targets): self.targets = targets - def initialize(self, module): + def initialize(self, module, tensors=("weight", "bias")): if isinstance(module, self.targets): - if hasattr(module, "weight") and module.weight is not None: - self.initialize_weight(module.weight) - if hasattr(module, "bias") and module.bias is not None: - self.initialize_bias(module.bias) + for tensor in tensors: + if hasattr(module, tensor) and getattr(module, tensor) is not None: + self.initialize_tensor(getattr(module, tensor), name=tensor) - def initialize_weight(self, tensor): - pass - - def initialize_bias(self, tensor): - pass + def initialize_tensor(self, tensor, name): + raise NotImplementedError diff --git a/deeplay/initializers/kaiming.py b/deeplay/initializers/kaiming.py index 9c006dd2..95368388 100644 --- a/deeplay/initializers/kaiming.py +++ b/deeplay/initializers/kaiming.py @@ -24,13 +24,20 @@ def __init__( targets: Tuple[Type[nn.Module], ...] = _kaiming_default_targets, mode: str = "fan_out", nonlinearity: str = "relu", + fill_bias: bool = True, + bias: float = 0.0, ): super().__init__(targets) self.mode = mode self.nonlinearity = nonlinearity + self.fill_bias = fill_bias + self.bias = bias - def initialize_weight(self, tensor): - nn.init.kaiming_normal_(tensor, mode=self.mode, nonlinearity=self.nonlinearity) + def initialize_tensor(self, tensor, name): - def initialize_bias(self, tensor): - tensor.data.fill_(0.0) + if name == "bias" and self.fill_bias: + tensor.data.fill_(self.bias) + else: + nn.init.kaiming_normal_( + tensor, mode=self.mode, nonlinearity=self.nonlinearity + ) diff --git a/deeplay/initializers/normal.py b/deeplay/initializers/normal.py index 928bc516..e160c318 100644 --- a/deeplay/initializers/normal.py +++ b/deeplay/initializers/normal.py @@ -28,8 +28,5 @@ def __init__( self.mean = mean self.std = std - def initialize_bias(self, tensor): - tensor.data.fill_(self.mean) - - def initialize_weight(self, tensor): - tensor.data.normal_(self.mean, self.std) + def initialize_tensor(self, tensor, name): + tensor.data.normal_(mean=self.mean, std=self.std) diff --git a/deeplay/module.py b/deeplay/module.py index a73dac5c..b60ef718 100644 --- a/deeplay/module.py +++ b/deeplay/module.py @@ -935,7 +935,12 @@ def predict( if not isinstance(item, torch.Tensor): if isinstance(item, np.ndarray): batch[i] = torch.from_numpy(item).to(device) - if batch[i].dtype in [torch.float64, torch.float32, torch.float16, torch.float]: + if batch[i].dtype in [ + torch.float64, + torch.float32, + torch.float16, + torch.float, + ]: if hasattr(self, "dtype"): batch[i] = batch[i].to(self.dtype) else: @@ -1017,16 +1022,20 @@ def log_tensor(self, name: str, tensor: torch.Tensor): """ self.logs[name] = tensor - def initialize(self, initializer): + def initialize( + self, initializer, tensors: Union[str, Tuple[str, ...]] = ("weight", "bias") + ): + if isinstance(tensors, str): + tensors = (tensors,) for module in self.modules(): if isinstance(module, DeeplayModule): - module._initialize_after_build(initializer) + module._initialize_after_build(initializer, tensors) else: - initializer.initialize(module) + initializer.initialize(module, tensors) @after_build - def _initialize_after_build(self, initializer): - initializer.initialize(self) + def _initialize_after_build(self, initializer, tensors: Tuple[str, ...]): + initializer.initialize(self, tensors) @after_build def _validate_after_build(self): @@ -1208,13 +1217,10 @@ def _give_user_configuration(self, receiver: "DeeplayModule", name) -> bool: # break # else: # ... - # if config_before[key].value != config_after[key].value: - # any_change = True - # break - - + # if config_before[key].value != config_after[key].value: + # any_change = True + # break - # self._user_config._detached_configurations += ( # receiver._user_config._detached_configurations # ) @@ -1515,7 +1521,9 @@ def filter(self, func: Callable[[str, nn.Module], bool]) -> "Selection": return Selection(self.model[0], new_selections) - def hasattr(self, attr: str, strict=True, include_layer_classtype: bool = True) -> "Selection": + def hasattr( + self, attr: str, strict=True, include_layer_classtype: bool = True + ) -> "Selection": """Filter the selection based on whether the modules have a certain attribute. Note, for layers, the attribute is checked in the layer's classtype @@ -1545,14 +1553,14 @@ def _filter_fn(name: str, module: nn.Module): if include_layer_classtype and isinstance(module, Layer): return hasattr(module.classtype, attr) return False - if strict: from deeplay.list import ReferringLayerList + if isinstance(getattr(module, attr), ReferringLayerList): return False return True - + return self.filter(_filter_fn) def isinstance( From fdc21a48fd9770f4af76dfcfa20fb9a8154eee61 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Tue, 7 May 2024 15:16:30 +0200 Subject: [PATCH 09/18] Update dcgan.py Normalization is applied only to the weights and not biases --- deeplay/models/discriminators/dcgan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deeplay/models/discriminators/dcgan.py b/deeplay/models/discriminators/dcgan.py index 87854fc6..2052c09a 100644 --- a/deeplay/models/discriminators/dcgan.py +++ b/deeplay/models/discriminators/dcgan.py @@ -16,7 +16,7 @@ def dcgan_discriminator(encoder: ConvolutionalEncoder2d): encoder["blocks", :-1].all.configure("activation", nn.LeakyReLU, negative_slope=0.2) encoder.blocks[-1].activation.configure(nn.Sigmoid) - init = Normal( + initializer = Normal( targets=( nn.Conv2d, nn.BatchNorm2d, @@ -26,7 +26,7 @@ def dcgan_discriminator(encoder: ConvolutionalEncoder2d): mean=0, std=0.02, ) - encoder.initialize(init) + encoder.initialize(initializer, tensors="weight") class DCGANDiscriminator(ConvolutionalEncoder2d): From 3124490859c4c518966ca5fc4dcb043bcc3329e7 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Tue, 7 May 2024 15:17:54 +0200 Subject: [PATCH 10/18] Update dcgan.py Normalization is now applied only to the weights and not biases. Removed nn.Linear layer from the normalization as it does not exist in dcgan generator. --- deeplay/models/generators/dcgan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeplay/models/generators/dcgan.py b/deeplay/models/generators/dcgan.py index 14ebe65e..4337997f 100644 --- a/deeplay/models/generators/dcgan.py +++ b/deeplay/models/generators/dcgan.py @@ -13,17 +13,17 @@ def dcgan_generator(generator: ConvolutionalDecoder2d): "layer", nn.ConvTranspose2d, kernel_size=4, stride=2, padding=1 ).remove("upsample", allow_missing=True) generator.blocks[0].layer.configure(stride=1, padding=0) - init = Normal( + + initializer = Normal( targets=( nn.ConvTranspose2d, nn.BatchNorm2d, nn.Embedding, - nn.Linear, ), mean=0, std=0.02, ) - generator.initialize(init) + generator.initialize(initializer, tensors="weight") class DCGANGenerator(ConvolutionalDecoder2d): From 50880f4288ff4dd30b56f72bd28a248be2e1101a Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:35:54 +0200 Subject: [PATCH 11/18] Update dcgan.py --- deeplay/models/discriminators/dcgan.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deeplay/models/discriminators/dcgan.py b/deeplay/models/discriminators/dcgan.py index 2052c09a..404cc169 100644 --- a/deeplay/models/discriminators/dcgan.py +++ b/deeplay/models/discriminators/dcgan.py @@ -135,9 +135,10 @@ def forward(self, x, y=None): ) if self.class_conditioned_model: - assert ( - y is not None - ), "Class label y must be provided for class-conditional discriminator" + if y is None: + raise ValueError( + "Class label y must be provided for class-conditional discriminator" + ) y = self.label_embedding(y) y = y.view(-1, 1, 64, 64) From 6097f3ece46275b2549d3bfb6f165863611c209b Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:36:09 +0200 Subject: [PATCH 12/18] Update cyclegan.py --- deeplay/models/generators/cyclegan.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deeplay/models/generators/cyclegan.py b/deeplay/models/generators/cyclegan.py index ee5963b0..31376030 100644 --- a/deeplay/models/generators/cyclegan.py +++ b/deeplay/models/generators/cyclegan.py @@ -21,8 +21,9 @@ def cyclegan_resnet_encoder(encoder: ConvolutionalEncoder2d): encoder.strided(2) encoder.normalized(Layer(nn.InstanceNorm2d)) encoder.blocks.configure(order=["layer", "normalization", "activation"]) - encoder.blocks[0].prepend(Layer(nn.ReflectionPad2d, 3)) - encoder.blocks[0].configure("layer", kernel_size=7, stride=1, padding=0) + encoder.blocks[0].configure( + "layer", kernel_size=7, stride=1, padding=3, padding_mode="reflect" + ) @ConvolutionalDecoder2d.register_style @@ -34,8 +35,9 @@ def cyclegan_resnet_decoder(decoder: ConvolutionalDecoder2d): decoder.blocks[:-1].configure( "layer", nn.ConvTranspose2d, stride=2, output_padding=1 ) - decoder.blocks[-1].configure(kernel_size=7, stride=1, padding=0) - decoder.blocks[-1].prepend(Layer(nn.ReflectionPad2d, 3)) + decoder.blocks[-1].configure( + "layer", kernel_size=7, stride=1, padding=3, padding_mode="reflect" + ) @ConvolutionalNeuralNetwork.register_style From ae68cb5a579a8e7811fae5cb8a90f8ddca0aceea Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:36:22 +0200 Subject: [PATCH 13/18] Update dcgan.py --- deeplay/models/generators/dcgan.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/deeplay/models/generators/dcgan.py b/deeplay/models/generators/dcgan.py index 4337997f..4ab9fd06 100644 --- a/deeplay/models/generators/dcgan.py +++ b/deeplay/models/generators/dcgan.py @@ -60,11 +60,11 @@ class DCGANGenerator(ConvolutionalDecoder2d): Examples -------- - >>> generator = DCGAN_Generator(latent_dim=100, output_channels=1, class_conditioned_model=False) + >>> generator = DCGANGenerator(latent_dim=100, output_channels=1, class_conditioned_model=False) >>> generator.build() >>> batch_size = 16 >>> input = torch.randn([batch_size, 100, 1, 1]) - >>> output = generator(input) + >>> output = generator(x=input, y=None) Return Values ------------- @@ -81,7 +81,7 @@ class DCGANGenerator(ConvolutionalDecoder2d): def __init__( self, latent_dim: int = 100, - features_dim: int = 128, + features_dim: int = 64, out_channels: int = 1, class_conditioned_model: bool = False, embedding_dim: int = 100, @@ -122,9 +122,10 @@ def __init__( def forward(self, x, y=None): if self.class_conditioned_model: - assert ( - y is not None - ), "Class label y must be provided for class-conditional generator" + if y is None: + raise ValueError( + "Class label y must be provided for class-conditional generator" + ) y = self.label_embedding(y) y = y.view(-1, self.embedding_dim, 1, 1) From 699e0b8dd2604478a7063325709d78388c967d6f Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:36:32 +0200 Subject: [PATCH 14/18] Create test_cyclegan.py --- .../models/discriminators/test_cyclegan.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 deeplay/tests/models/discriminators/test_cyclegan.py diff --git a/deeplay/tests/models/discriminators/test_cyclegan.py b/deeplay/tests/models/discriminators/test_cyclegan.py new file mode 100644 index 00000000..f6a44364 --- /dev/null +++ b/deeplay/tests/models/discriminators/test_cyclegan.py @@ -0,0 +1,33 @@ +import unittest + +import torch +import torch.nn as nn + +from deeplay.models.discriminators.cyclegan import CycleGANDiscriminator + + +class TestCycleGANDiscriminator(unittest.TestCase): + + def test_discriminator_defaults(self): + + discriminator = CycleGANDiscriminator().build() + + self.assertEqual(len(discriminator.blocks), 5) + self.assertTrue( + all( + isinstance(discriminator.blocks.normalization[i], nn.InstanceNorm2d) + for i in range(1, 4) + ) + ) + self.assertTrue( + all( + isinstance(discriminator.blocks.activation[i], nn.LeakyReLU) + for i in range(4) + ) + ) + self.assertTrue(isinstance(discriminator.blocks[-1].activation, nn.Sigmoid)) + + # Test on a batch of 2 + x = torch.rand(2, 1, 256, 256) + output = discriminator(x) + self.assertEqual(output.shape, (2, 1, 30, 30)) From 81075f57806fa63526305ee48c8184a6938d0960 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:36:42 +0200 Subject: [PATCH 15/18] Update test_dcgan.py --- .../tests/models/discriminators/test_dcgan.py | 94 ++++++++++++++++--- 1 file changed, 82 insertions(+), 12 deletions(-) diff --git a/deeplay/tests/models/discriminators/test_dcgan.py b/deeplay/tests/models/discriminators/test_dcgan.py index 2c694dcc..9adba055 100644 --- a/deeplay/tests/models/discriminators/test_dcgan.py +++ b/deeplay/tests/models/discriminators/test_dcgan.py @@ -6,19 +6,89 @@ from deeplay.models.discriminators.dcgan import DCGANDiscriminator -class TestDCGANGenerator(unittest.TestCase): +class TestDCGANDiscriminator(unittest.TestCase): + ... - def test_init(self): - discr = DCGANDiscriminator().build() - data = torch.randn(1, 1, 64, 64) - output = discr(data) + def test_discriminator_defaults(self): - self.assertEqual(output.shape, (1, 1, 1, 1)) + discriminator = DCGANDiscriminator() + discriminator.build() - def test_conditioned(self): - discr = DCGANDiscriminator(class_conditioned_model=True).build() - data = torch.randn(1, 1, 64, 64) - labels = torch.randint(0, 10, (1,)) - output = discr(data, labels) + self.assertEqual(len(discriminator.blocks), 5) + self.assertEqual( + [discriminator.blocks[i].layer.kernel_size for i in range(5)], [(4, 4)] * 5 + ) - self.assertEqual(output.shape, (1, 1, 1, 1)) + self.assertEqual( + [discriminator.blocks[i].layer.stride for i in range(5)], [(2, 2)] * 5 + ) + + self.assertEqual( + [discriminator.blocks[i].layer.padding for i in range(4)], [(1, 1)] * 4 + ) + self.assertEqual(discriminator.blocks[-1].layer.padding, (0, 0)) + + self.assertTrue( + all( + isinstance(discriminator.blocks[i].activation, nn.LeakyReLU) + for i in range(4) + ) + ) + self.assertTrue(isinstance(discriminator.blocks[-1].activation, nn.Sigmoid)) + + self.assertTrue( + all( + isinstance(discriminator.blocks[1:-1].normalization[i], nn.BatchNorm2d) + for i in range(3) + ) + ) + + self.assertTrue(isinstance(discriminator.label_embedding, nn.Identity)) + + # Test on a batch of 2 + x = torch.rand(2, 1, 64, 64) + output = discriminator(x, y=None) + self.assertEqual(output.shape, (2, 1, 1, 1)) + + def test_conditional_discriminator_defaults(self): + + discriminator = DCGANDiscriminator(class_conditioned_model=True) + discriminator.build() + + self.assertTrue( + isinstance(discriminator.label_embedding.embedding, nn.Embedding) + ) + self.assertTrue(isinstance(discriminator.label_embedding.layer, nn.Linear)) + self.assertTrue( + isinstance(discriminator.label_embedding.activation, nn.LeakyReLU) + ) + + self.assertTrue(discriminator.label_embedding.embedding.num_embeddings, 10) + self.assertTrue(discriminator.label_embedding.layer.in_features, 100) + self.assertTrue(discriminator.label_embedding.layer.out_features, 64 * 64) + + # Test on a batch of 2 + x = torch.rand(2, 1, 64, 64) + y = torch.randint(0, 10, (2,)) + output = discriminator(x, y) + self.assertEqual(output.shape, (2, 1, 1, 1)) + + def test_weight_initialization(self): + + generator = DCGANDiscriminator() + generator.build() + + for m in generator.modules(): + if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)): + self.assertAlmostEqual(m.weight.data.mean().item(), 0.0, places=2) + self.assertAlmostEqual(m.weight.data.std().item(), 0.02, places=2) + + def test_weight_initialization_conditional(self): + + generator = DCGANDiscriminator(class_conditioned_model=True) + generator.build() + + for m in generator.modules(): + if isinstance(m, (nn.Conv2d, nn.BatchNorm2d, nn.Embedding, nn.Linear)): + self.assertAlmostEqual(m.weight.data.mean().item(), 0.0, places=2) + self.assertAlmostEqual(m.weight.data.std().item(), 0.02, places=2) From 4c7ab94c40a8857f173711fed25bfdb29c5e23ea Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:36:55 +0200 Subject: [PATCH 16/18] Create __init__.py --- deeplay/tests/models/generators/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 deeplay/tests/models/generators/__init__.py diff --git a/deeplay/tests/models/generators/__init__.py b/deeplay/tests/models/generators/__init__.py new file mode 100644 index 00000000..e69de29b From 368d1493d9ccff972585ce124c9c1b6ac65dcc76 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:37:08 +0200 Subject: [PATCH 17/18] Update test_cyclegan.py --- .../tests/models/generators/test_cyclegan.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/deeplay/tests/models/generators/test_cyclegan.py b/deeplay/tests/models/generators/test_cyclegan.py index 84f351b6..b4baa804 100644 --- a/deeplay/tests/models/generators/test_cyclegan.py +++ b/deeplay/tests/models/generators/test_cyclegan.py @@ -10,6 +10,40 @@ class TestCycleGANResnetGenerator(unittest.TestCase): def test_init(self): generator = CycleGANResnetGenerator().build() + + # Encoder + self.assertEqual(len(generator.encoder.blocks), 3) + self.assertTrue( + all( + isinstance(generator.encoder.blocks.normalization[i], nn.InstanceNorm2d) + for i in range(3) + ) + ) + self.assertTrue( + all( + isinstance(generator.encoder.blocks.activation[i], nn.ReLU) + for i in range(3) + ) + ) + + # Decoder + self.assertEqual(len(generator.decoder.blocks), 3) + self.assertTrue( + all( + isinstance( + generator.decoder.blocks[:-1].normalization[i], nn.InstanceNorm2d + ) + for i in range(2) + ) + ) + self.assertTrue( + all( + isinstance(generator.decoder.blocks.activation[i], nn.ReLU) + for i in range(2) + ) + ) + self.assertTrue(isinstance(generator.decoder.blocks[-1].activation, nn.Tanh)) + data = torch.randn(1, 1, 32, 32) output = generator(data) From 69e652be67c1db33400546fe7743a0f888175243 Mon Sep 17 00:00:00 2001 From: Harshith Bachimanchi <62615092+HarshithBachimanchi@users.noreply.github.com> Date: Fri, 10 May 2024 17:37:20 +0200 Subject: [PATCH 18/18] Update test_dcgan.py --- deeplay/tests/models/generators/test_dcgan.py | 80 ++++++++++++++++--- 1 file changed, 69 insertions(+), 11 deletions(-) diff --git a/deeplay/tests/models/generators/test_dcgan.py b/deeplay/tests/models/generators/test_dcgan.py index 6b6974fb..bbe62662 100644 --- a/deeplay/tests/models/generators/test_dcgan.py +++ b/deeplay/tests/models/generators/test_dcgan.py @@ -7,18 +7,76 @@ class TestDCGANGenerator(unittest.TestCase): + ... - def test_init(self): - generator = DCGANGenerator().build() - data = torch.randn(1, 100, 1, 1) - output = generator(data) + def test_generator_defaults(self): - self.assertEqual(output.shape, (1, 1, 64, 64)) + generator = DCGANGenerator() + generator.build() - def test_conditioned(self): - generator = DCGANGenerator(class_conditioned_model=True).build() - data = torch.randn(1, 100, 1, 1) - labels = torch.randint(0, 10, (1,)) - output = generator(data, labels) + self.assertEqual(len(generator.blocks), 5) + self.assertEqual( + [generator.blocks[i].layer.kernel_size for i in range(5)], [(4, 4)] * 5 + ) - self.assertEqual(output.shape, (1, 1, 64, 64)) + self.assertEqual(generator.blocks[0].layer.stride, (1, 1)) + self.assertEqual( + [generator.blocks[i].layer.stride for i in range(1, 5)], [(2, 2)] * 4 + ) + + self.assertEqual(generator.blocks[0].layer.padding, (0, 0)) + self.assertEqual(generator.blocks[-1].layer.padding, (1, 1)) + + self.assertTrue( + all(isinstance(generator.blocks.activation[i], nn.ReLU) for i in range(4)) + ) + self.assertTrue(isinstance(generator.blocks[-1].activation, nn.Tanh)) + + self.assertTrue( + all( + isinstance(generator.blocks[:-1].normalization[i], nn.BatchNorm2d) + for i in range(4) + ) + ) + + self.assertTrue(isinstance(generator.label_embedding, nn.Identity)) + + # Test on a batch of 2 + x = torch.rand(2, 100, 1, 1) + output = generator(x, y=None) + self.assertEqual(output.shape, (2, 1, 64, 64)) + + def test_conditional_generator_defaults(self): + + generator = DCGANGenerator(class_conditioned_model=True) + generator.build() + + self.assertTrue(isinstance(generator.label_embedding, nn.Embedding)) + self.assertEqual(generator.label_embedding.num_embeddings, 10) + self.assertEqual(generator.label_embedding.embedding_dim, 100) + + # Test on a batch of 2 + x = torch.rand(2, 100, 1, 1) + y = torch.randint(0, 10, (2,)) + output = generator(x, y) + self.assertEqual(output.shape, (2, 1, 64, 64)) + + def test_weight_initialization(self): + + generator = DCGANGenerator() + generator.build() + + for m in generator.modules(): + if isinstance(m, (nn.ConvTranspose2d, nn.BatchNorm2d)): + self.assertAlmostEqual(m.weight.data.mean().item(), 0.0, places=2) + self.assertAlmostEqual(m.weight.data.std().item(), 0.02, places=2) + + def test_weight_initialization_conditional(self): + + generator = DCGANGenerator(class_conditioned_model=True) + generator.build() + + for m in generator.modules(): + if isinstance(m, (nn.ConvTranspose2d, nn.BatchNorm2d, nn.Embedding)): + self.assertAlmostEqual(m.weight.data.mean().item(), 0.0, places=2) + self.assertAlmostEqual(m.weight.data.std().item(), 0.02, places=2)