diff --git a/nowcasting_gan/generators.py b/nowcasting_gan/generators.py index 0732a96..ee06fdd 100644 --- a/nowcasting_gan/generators.py +++ b/nowcasting_gan/generators.py @@ -41,8 +41,9 @@ def __init__( in_channels=context_channels, out_channels=latent_channels, kernel_size=(1, 1) ) self.g1 = GBlock(input_channels=latent_channels, output_channels=latent_channels) - self.up_g1 = UpsampleGBlock(input_channels = latent_channels, output_channels = - latent_channels // 2) + self.up_g1 = UpsampleGBlock( + input_channels=latent_channels, output_channels=latent_channels // 2 + ) self.convGRU2 = ConvGRU( input_channels=latent_channels // 2 + context_channels // 2, @@ -53,8 +54,9 @@ def __init__( in_channels=context_channels // 2, out_channels=latent_channels // 2, kernel_size=(1, 1) ) self.g2 = GBlock(input_channels=latent_channels // 2, output_channels=latent_channels // 2) - self.up_g2 = UpsampleGBlock(input_channels = latent_channels // 2, output_channels = - latent_channels // 4) + self.up_g2 = UpsampleGBlock( + input_channels=latent_channels // 2, output_channels=latent_channels // 4 + ) self.convGRU3 = ConvGRU( input_channels=latent_channels // 4 + context_channels // 4, @@ -65,8 +67,9 @@ def __init__( in_channels=context_channels // 4, out_channels=latent_channels // 4, kernel_size=(1, 1) ) self.g3 = GBlock(input_channels=latent_channels // 4, output_channels=latent_channels // 4) - self.up_g3 = UpsampleGBlock(input_channels = latent_channels // 4, output_channels = - latent_channels // 8) + self.up_g3 = UpsampleGBlock( + input_channels=latent_channels // 4, output_channels=latent_channels // 8 + ) self.convGRU4 = ConvGRU( input_channels=latent_channels // 8 + context_channels // 8, @@ -77,15 +80,17 @@ def __init__( in_channels=context_channels // 8, out_channels=latent_channels // 8, kernel_size=(1, 1) ) self.g4 = GBlock(input_channels=latent_channels // 8, output_channels=latent_channels // 8) - self.up_g4 = UpsampleGBlock(input_channels = latent_channels // 8, output_channels = - latent_channels // 16) + self.up_g4 = UpsampleGBlock( + input_channels=latent_channels // 8, output_channels=latent_channels // 16 + ) self.bn = torch.nn.BatchNorm2d(latent_channels // 16) self.relu = torch.nn.ReLU() self.conv_1x1 = spectral_norm( torch.nn.Conv2d( - in_channels=latent_channels // 16, out_channels=4 * output_channels, - kernel_size=(1,1) + in_channels=latent_channels // 16, + out_channels=4 * output_channels, + kernel_size=(1, 1), ) )