Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 8, 2021
1 parent 1ea30f5 commit eef675f
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions nowcasting_gan/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ def __init__(
in_channels=context_channels, out_channels=latent_channels, kernel_size=(1, 1)
)
self.g1 = GBlock(input_channels=latent_channels, output_channels=latent_channels)
self.up_g1 = UpsampleGBlock(input_channels = latent_channels, output_channels =
latent_channels // 2)
self.up_g1 = UpsampleGBlock(
input_channels=latent_channels, output_channels=latent_channels // 2
)

self.convGRU2 = ConvGRU(
input_channels=latent_channels // 2 + context_channels // 2,
Expand All @@ -53,8 +54,9 @@ def __init__(
in_channels=context_channels // 2, out_channels=latent_channels // 2, kernel_size=(1, 1)
)
self.g2 = GBlock(input_channels=latent_channels // 2, output_channels=latent_channels // 2)
self.up_g2 = UpsampleGBlock(input_channels = latent_channels // 2, output_channels =
latent_channels // 4)
self.up_g2 = UpsampleGBlock(
input_channels=latent_channels // 2, output_channels=latent_channels // 4
)

self.convGRU3 = ConvGRU(
input_channels=latent_channels // 4 + context_channels // 4,
Expand All @@ -65,8 +67,9 @@ def __init__(
in_channels=context_channels // 4, out_channels=latent_channels // 4, kernel_size=(1, 1)
)
self.g3 = GBlock(input_channels=latent_channels // 4, output_channels=latent_channels // 4)
self.up_g3 = UpsampleGBlock(input_channels = latent_channels // 4, output_channels =
latent_channels // 8)
self.up_g3 = UpsampleGBlock(
input_channels=latent_channels // 4, output_channels=latent_channels // 8
)

self.convGRU4 = ConvGRU(
input_channels=latent_channels // 8 + context_channels // 8,
Expand All @@ -77,15 +80,17 @@ def __init__(
in_channels=context_channels // 8, out_channels=latent_channels // 8, kernel_size=(1, 1)
)
self.g4 = GBlock(input_channels=latent_channels // 8, output_channels=latent_channels // 8)
self.up_g4 = UpsampleGBlock(input_channels = latent_channels // 8, output_channels =
latent_channels // 16)
self.up_g4 = UpsampleGBlock(
input_channels=latent_channels // 8, output_channels=latent_channels // 16
)

self.bn = torch.nn.BatchNorm2d(latent_channels // 16)
self.relu = torch.nn.ReLU()
self.conv_1x1 = spectral_norm(
torch.nn.Conv2d(
in_channels=latent_channels // 16, out_channels=4 * output_channels,
kernel_size=(1,1)
in_channels=latent_channels // 16,
out_channels=4 * output_channels,
kernel_size=(1, 1),
)
)

Expand Down

0 comments on commit eef675f

Please sign in to comment.