Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 8, 2021
1 parent c9a89e8 commit 1f2fc3c
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions nowcasting_gan/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@


class NowcastingDiscriminator(torch.nn.Module):
def __init__(self,input_channels: int = 12, num_spatial_frames: int = 8,
num_temporal_crop_size: int = 128, conv_type: str = 'standard'):
def __init__(
self,
input_channels: int = 12,
num_spatial_frames: int = 8,
num_temporal_crop_size: int = 128,
conv_type: str = "standard",
):
super().__init__()
self.spatial_discriminator = NowcastingSpatialDiscriminator(input_channels =
input_channels, num_timesteps
= num_spatial_frames, conv_type = conv_type)
self.temporal_discriminator = NowcastingTemporalDiscriminator(input_channels =
input_channels, crop_size =
num_temporal_crop_size, conv_type = conv_type)
self.spatial_discriminator = NowcastingSpatialDiscriminator(
input_channels=input_channels, num_timesteps=num_spatial_frames, conv_type=conv_type
)
self.temporal_discriminator = NowcastingTemporalDiscriminator(
input_channels=input_channels, crop_size=num_temporal_crop_size, conv_type=conv_type
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
spatial_loss = self.spatial_discriminator(x)
Expand Down

0 comments on commit 1f2fc3c

Please sign in to comment.