Commit ba334b5

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 8, 2021
1 parent c2606fc commit ba334b5
Showing 4 changed files with 57 additions and 35 deletions.
3 changes: 1 addition & 2 deletions import_deepmind_weights.py
@@ -8,5 +8,4 @@
module = tensorflow_hub.load("/home/jacob/256x256/")
print(module)
print(module.signatures)
-sig_model = module.signatures['default']
-
+sig_model = module.signatures["default"]
2 changes: 1 addition & 1 deletion nowcasting_gan/common.py
@@ -322,7 +322,7 @@ def __init__(
self.att_block = SelfAttention2d(
input_dims=output_channels // 4, output_dims=output_channels // 4
)
-#self.att_block_conv = torch.nn.Conv2d(in_channels=output_channels // 16, out_channels=output_channels // 4, kernel_size=(1,1))
+# self.att_block_conv = torch.nn.Conv2d(in_channels=output_channels // 16, out_channels=output_channels // 4, kernel_size=(1,1))
self.l_block4 = LBlock(input_channels=output_channels // 4, output_channels=output_channels)

def forward(self, x: torch.Tensor) -> torch.Tensor:
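This hunk is purely cosmetic (a space after `#`), but for reference, a minimal sketch of how the surrounding attention block is constructed, assuming `SelfAttention2d` is importable from this module and consumes `(B, C, H, W)` tensors (both are assumptions; only the two kwargs appear in the diff):

```python
import torch

from nowcasting_gan.common import SelfAttention2d  # assumed import path

output_channels = 768  # hypothetical width for this stage of the latent stack
att_block = SelfAttention2d(
    input_dims=output_channels // 4, output_dims=output_channels // 4
)
x = torch.randn(1, output_channels // 4, 8, 8)  # assumed (B, C, H, W) layout
print(att_block(x).shape)  # assumption: self-attention preserves the input shape
```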
41 changes: 32 additions & 9 deletions nowcasting_gan/generators.py
@@ -10,6 +10,7 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARN)

+
class NowcastingSampler(torch.nn.Module):
def __init__(
self,
@@ -30,36 +31,44 @@ def __init__(
super().__init__()
self.forecast_steps = forecast_steps
self.convGRU1 = ConvGRU(
-input_channels=latent_channels+context_channels,
+input_channels=latent_channels + context_channels,
hidden_channels=context_channels,
kernel_size=(3, 3),
n_layers=1,
)
-self.gru_conv_1x1 = torch.nn.Conv2d(in_channels=context_channels, out_channels=latent_channels, kernel_size=(1,1))
+self.gru_conv_1x1 = torch.nn.Conv2d(
+    in_channels=context_channels, out_channels=latent_channels, kernel_size=(1, 1)
+)
self.g1 = GBlock(input_channels=latent_channels, output_channels=latent_channels // 2)
self.convGRU2 = ConvGRU(
input_channels=latent_channels // 2 + context_channels // 2,
hidden_channels=context_channels // 2,
kernel_size=(3, 3),
n_layers=1,
)
-self.gru_conv_1x1_2 = torch.nn.Conv2d(in_channels=context_channels // 2, out_channels=latent_channels // 2, kernel_size=(1,1))
+self.gru_conv_1x1_2 = torch.nn.Conv2d(
+    in_channels=context_channels // 2, out_channels=latent_channels // 2, kernel_size=(1, 1)
+)
self.g2 = GBlock(input_channels=latent_channels // 2, output_channels=latent_channels // 4)
self.convGRU3 = ConvGRU(
input_channels=latent_channels // 4 + context_channels // 4,
hidden_channels=context_channels // 4,
kernel_size=(3, 3),
n_layers=1,
)
-self.gru_conv_1x1_3 = torch.nn.Conv2d(in_channels=context_channels // 4, out_channels=latent_channels // 4, kernel_size=(1,1))
+self.gru_conv_1x1_3 = torch.nn.Conv2d(
+    in_channels=context_channels // 4, out_channels=latent_channels // 4, kernel_size=(1, 1)
+)
self.g3 = GBlock(input_channels=latent_channels // 4, output_channels=latent_channels // 8)
self.convGRU4 = ConvGRU(
input_channels=latent_channels // 8 + context_channels // 8,
hidden_channels=context_channels // 8,
kernel_size=(3, 3),
n_layers=1,
)
-self.gru_conv_1x1_4 = torch.nn.Conv2d(in_channels=context_channels // 8, out_channels=latent_channels // 8, kernel_size=(1,1))
+self.gru_conv_1x1_4 = torch.nn.Conv2d(
+    in_channels=context_channels // 8, out_channels=latent_channels // 8, kernel_size=(1, 1)
+)
self.g4 = GBlock(input_channels=latent_channels // 8, output_channels=latent_channels // 16)
self.bn = torch.nn.BatchNorm2d(latent_channels // 16)
self.relu = torch.nn.ReLU()
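
The `__init__` above repeats one pattern four times at halving widths: a ConvGRU whose hidden state is projected from `context_channels` back to `latent_channels` with a 1x1 convolution before a GBlock consumes it. A standalone sketch of that projection step, with made-up channel counts:

```python
import torch

context_channels, latent_channels = 384, 768  # hypothetical widths
gru_conv_1x1 = torch.nn.Conv2d(
    in_channels=context_channels, out_channels=latent_channels, kernel_size=(1, 1)
)
h = torch.randn(2, context_channels, 8, 8)  # a (B, C, H, W) hidden-state slice
print(gru_conv_1x1(h).shape)  # torch.Size([2, 768, 8, 8]) - only channels change
```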
@@ -118,7 +127,12 @@ def forward(
for i in range(self.forecast_steps):
# Start at lowest one and go up, conditioning states
# ConvGRU1
x, init_states[3] = self.stacks[f"forecast_{i}"][0](torch.cat([latent_dim, einops.rearrange(init_states[3], 'b t c h w -> t b c h w')], dim=2), hidden_state=init_states[3])
x, init_states[3] = self.stacks[f"forecast_{i}"][0](
torch.cat(
[latent_dim, einops.rearrange(init_states[3], "b t c h w -> t b c h w")], dim=2
),
hidden_state=init_states[3],
)
# Update for next timestep
logger.debug(f"GRU1 x: {x.shape} hidden: {init_states[3].shape}")
# init_states[3] = torch.squeeze(x, dim=0)
@@ -132,7 +146,10 @@ def forward(
x = torch.unsqueeze(x, dim=1)
logger.debug(f"x: {x.shape} hidden: {init_states[3].shape}")
# ConvGRU2
x, init_states[2] = self.stacks[f"forecast_{i}"][3](torch.cat([x, einops.rearrange(init_states[2], 'b t c h w -> t b c h w')], dim=2), hidden_state=init_states[2])
x, init_states[2] = self.stacks[f"forecast_{i}"][3](
torch.cat([x, einops.rearrange(init_states[2], "b t c h w -> t b c h w")], dim=2),
hidden_state=init_states[2],
)
logger.debug(f"GRU2 x: {x.shape} hidden: {init_states[2].shape}")
# Update for next timestep
# init_states[2] = torch.squeeze(x, dim=0)
@@ -145,7 +162,10 @@ def forward(
# Expand to 5D input
x = torch.unsqueeze(x, dim=1)
# ConvGRU3
x, init_states[1] = self.stacks[f"forecast_{i}"][6](torch.cat([x, einops.rearrange(init_states[1], 'b t c h w -> t b c h w')], dim=2), hidden_state=init_states[1])
x, init_states[1] = self.stacks[f"forecast_{i}"][6](
torch.cat([x, einops.rearrange(init_states[1], "b t c h w -> t b c h w")], dim=2),
hidden_state=init_states[1],
)
logger.debug(f"GRU3 x: {x.shape} hidden: {init_states[1].shape}")
# Update for next timestep
# init_states[1] = torch.squeeze(x, dim=0)
@@ -158,7 +178,10 @@ def forward(
# Expand to 5D input
x = torch.unsqueeze(x, dim=1)
# ConvGRU4
x, init_states[0] = self.stacks[f"forecast_{i}"][9](torch.cat([x, einops.rearrange(init_states[0], 'b t c h w -> t b c h w')], dim=2), hidden_state=init_states[0])
x, init_states[0] = self.stacks[f"forecast_{i}"][9](
torch.cat([x, einops.rearrange(init_states[0], "b t c h w -> t b c h w")], dim=2),
hidden_state=init_states[0],
)
logger.debug(f"GRU4 x: {x.shape} hidden: {init_states[0].shape}")
# Update for next timestep
# init_states[0] = torch.squeeze(x, dim=0)
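Each reflowed call in `forward` wraps the same idiom: `einops.rearrange` swaps the batch and time axes of the stored hidden state so it can be concatenated channel-wise with the time-major latent tensor. A standalone sketch with made-up sizes (the `(b, t, c, h, w)` layout is inferred from the rearrange pattern):

```python
import einops
import torch

state = torch.randn(2, 4, 96, 8, 8)  # assumed (batch, time, channels, h, w)
time_major = einops.rearrange(state, "b t c h w -> t b c h w")
print(time_major.shape)  # torch.Size([4, 2, 96, 8, 8])

latent = torch.randn(4, 2, 96, 8, 8)  # hypothetical time-major latent tensor
stacked = torch.cat([latent, time_major], dim=2)  # concatenate on channel axis
print(stacked.shape)  # torch.Size([4, 2, 192, 8, 8])
```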
46 changes: 23 additions & 23 deletions nowcasting_gan/layers/ConvGRU.py
@@ -5,13 +5,13 @@

class ConvGRUCell(nn.Module):
def __init__(
-        self,
-        input_dim,
-        hidden_dim,
-        kernel_size=(3, 3),
-        bias=True,
-        activation=F.tanh,
-        batchnorm=False,
+    self,
+    input_dim,
+    hidden_dim,
+    kernel_size=(3, 3),
+    bias=True,
+    activation=F.tanh,
+    batchnorm=False,
):
"""
Initialize ConvGRU cell.
@@ -124,17 +124,17 @@ def forward(self, x):

class ConvGRU(nn.Module):
def __init__(
-        self,
-        input_channels,
-        hidden_channels,
-        kernel_size,
-        n_layers,
-        batch_first=True,
-        bias=True,
-        activation=F.tanh,
-        input_p=0.2,
-        hidden_p=0.1,
-        batchnorm=False,
+    self,
+    input_channels,
+    hidden_channels,
+    kernel_size,
+    n_layers,
+    batch_first=True,
+    bias=True,
+    activation=F.tanh,
+    input_p=0.2,
+    hidden_p=0.1,
+    batchnorm=False,
):
super(ConvGRU, self).__init__()

@@ -233,11 +233,11 @@ def get_init_states(self, input):
@staticmethod
def _check_kernel_size_consistency(kernel_size):
if not (
-        isinstance(kernel_size, tuple)
-        or (
-            isinstance(kernel_size, list)
-            and all([isinstance(elem, tuple) for elem in kernel_size])
-        )
+    isinstance(kernel_size, tuple)
+    or (
+        isinstance(kernel_size, list)
+        and all([isinstance(elem, tuple) for elem in kernel_size])
+    )
):
raise ValueError("`kernel_size` must be tuple or list of tuples")

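The `_check_kernel_size_consistency` guard above accepts either one tuple or a list of per-layer tuples. A minimal usage sketch, assuming the constructor arguments shown in this diff and that `forward` takes a batch-first 5D tensor and returns an `(output, hidden_state)` pair, as the unpacking in `generators.py` suggests:

```python
import torch

from nowcasting_gan.layers.ConvGRU import ConvGRU  # module path as in this repo

gru = ConvGRU(
    input_channels=96,
    hidden_channels=48,
    kernel_size=(3, 3),  # a list of tuples, one per layer, should also pass the check
    n_layers=1,
)
x = torch.randn(2, 4, 96, 8, 8)  # assumed (batch, time, channels, h, w), batch_first=True
out, hidden = gru(x)  # assumption: hidden state is initialized internally when omitted
```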
