diff --git a/nowcasting_gan/common.py b/nowcasting_gan/common.py index 26c0ef2..1c856c4 100644 --- a/nowcasting_gan/common.py +++ b/nowcasting_gan/common.py @@ -10,9 +10,13 @@ class GBlock(torch.nn.Module): """Residual generator block without upsampling""" + def __init__( - self, input_channels: int = 12, output_channels: int = 12, conv_type: str = "standard", - spectral_normalized_eps = 0.0001 + self, + input_channels: int = 12, + output_channels: int = 12, + conv_type: str = "standard", + spectral_normalized_eps=0.0001, ): """ G Block from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf @@ -71,10 +75,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class UpsampleGBlock(torch.nn.Module): """Residual generator block with upsampling""" + def __init__( - self, input_channels: int = 12, output_channels: int = 12, conv_type: str = "standard", - spectral_normalized_eps = 0.0001 - ): + self, + input_channels: int = 12, + output_channels: int = 12, + conv_type: str = "standard", + spectral_normalized_eps=0.0001, + ): """ G Block from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf Args: @@ -95,7 +103,7 @@ def __init__( in_channels=input_channels, out_channels=output_channels, kernel_size=1, - ) + ) self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest") # Upsample 2D conv self.first_conv_3x3 = torch.nn.ConvTranspose2d( @@ -104,10 +112,10 @@ def __init__( kernel_size=3, stride=2, padding=1, - ) + ) self.last_conv_3x3 = conv2d( in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1 - ) + ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Spectrally normalized 1x1 convolution