Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and jacobbieker committed Feb 4, 2022
1 parent c73e2ed commit 9aa59c5
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
55 changes: 28 additions & 27 deletions dgmr/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
spatial_loss = self.spatial_discriminator(x)
temporal_loss = self.temporal_discriminator(x)

return torch.cat([spatial_loss, temporal_loss], dim = 1)
return torch.cat([spatial_loss, temporal_loss], dim=1)


class TemporalDiscriminator(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
Expand Down Expand Up @@ -132,13 +133,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class SpatialDiscriminator(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
self,
input_channels: int = 12,
num_timesteps: int = 8,
num_layers: int = 4,
conv_type: str = "standard",
**kwargs
):
self,
input_channels: int = 12,
num_timesteps: int = 8,
num_layers: int = 4,
conv_type: str = "standard",
**kwargs
):
"""
Spatial discriminator from Skillful Nowcasting, see https://arxiv.org/pdf/2104.00954.pdf
Expand All @@ -161,30 +162,30 @@ def __init__(
self.num_timesteps = num_timesteps
# First step is mean pooling 2x2 to reduce input by half
self.mean_pool = torch.nn.AvgPool2d(2)
self.space2depth = PixelUnshuffle(downscale_factor = 2)
self.space2depth = PixelUnshuffle(downscale_factor=2)
internal_chn = 24
self.d1 = DBlock(
input_channels = 4 * input_channels,
output_channels = 2 * internal_chn * input_channels,
first_relu = False,
conv_type = conv_type,
)
input_channels=4 * input_channels,
output_channels=2 * internal_chn * input_channels,
first_relu=False,
conv_type=conv_type,
)
self.intermediate_dblocks = torch.nn.ModuleList()
for _ in range(num_layers):
internal_chn *= 2
self.intermediate_dblocks.append(
DBlock(
input_channels = internal_chn * input_channels,
output_channels = 2 * internal_chn * input_channels,
conv_type = conv_type,
)
input_channels=internal_chn * input_channels,
output_channels=2 * internal_chn * input_channels,
conv_type=conv_type,
)
self.d6 = DBlock(
input_channels = 2 * internal_chn * input_channels,
output_channels = 2 * internal_chn * input_channels,
keep_same_output = True,
conv_type = conv_type,
)
self.d6 = DBlock(
input_channels=2 * internal_chn * input_channels,
output_channels=2 * internal_chn * input_channels,
keep_same_output=True,
conv_type=conv_type,
)

# Spectrally normalized linear layer for binary classification
self.fc = spectral_norm(torch.nn.Linear(2 * internal_chn * input_channels, 1))
Expand All @@ -193,7 +194,7 @@ def __init__(

def forward(self, x: torch.Tensor) -> torch.Tensor:
# x should be the chosen 8 or so
idxs = torch.randint(low = 0, high = x.size()[1], size = (self.num_timesteps,))
idxs = torch.randint(low=0, high=x.size()[1], size=(self.num_timesteps,))
representations = []
for idx in idxs:
rep = self.mean_pool(x[:, idx, :, :, :]) # 128x128
Expand All @@ -203,7 +204,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for d in self.intermediate_dblocks:
rep = d(rep)
rep = self.d6(rep) # 2x2
rep = torch.sum(F.relu(rep), dim = [2, 3])
rep = torch.sum(F.relu(rep), dim=[2, 3])
rep = self.bn(rep)
rep = self.fc(rep)
"""
Expand All @@ -223,7 +224,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
representations.append(rep)

# The representations are summed together before the ReLU
x = torch.stack(representations, dim = 1)
x = torch.stack(representations, dim=1)
# Should be [Batch, N, 1]
x = torch.sum(x, keepdim = True, dim = 1)
x = torch.sum(x, keepdim=True, dim=1)
return x
1 change: 0 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def test_load_dgmr_from_hf():
model = DGMR().from_pretrained("openclimatefix/dgmr")



def test_train_dgmr():
forecast_steps = 8

Expand Down

0 comments on commit 9aa59c5

Please sign in to comment.