diff --git a/nowcasting_gan/discriminators.py b/nowcasting_gan/discriminators.py index 4995d22..a59f5ac 100644 --- a/nowcasting_gan/discriminators.py +++ b/nowcasting_gan/discriminators.py @@ -7,15 +7,20 @@ class NowcastingDiscriminator(torch.nn.Module): - def __init__(self,input_channels: int = 12, num_spatial_frames: int = 8, - num_temporal_crop_size: int = 128, conv_type: str = 'standard'): + def __init__( + self, + input_channels: int = 12, + num_spatial_frames: int = 8, + num_temporal_crop_size: int = 128, + conv_type: str = "standard", + ): super().__init__() - self.spatial_discriminator = NowcastingSpatialDiscriminator(input_channels = - input_channels, num_timesteps - = num_spatial_frames, conv_type = conv_type) - self.temporal_discriminator = NowcastingTemporalDiscriminator(input_channels = - input_channels, crop_size = - num_temporal_crop_size, conv_type = conv_type) + self.spatial_discriminator = NowcastingSpatialDiscriminator( + input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type + ) + self.temporal_discriminator = NowcastingTemporalDiscriminator( + input_channels=input_channels, crop_size=num_temporal_crop_size, conv_type=conv_type + ) def forward(self, x: torch.Tensor) -> torch.Tensor: spatial_loss = self.spatial_discriminator(x)