Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 9, 2021
1 parent 0f517c5 commit 7d4c987
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_discriminator():
assert out.shape == (2, 2, 1)
assert not torch.isnan(out).any()


def test_sampler():
input_channels = 1
conv_type = "standard"
Expand All @@ -93,16 +94,16 @@ def test_sampler():
input_channels=input_channels,
conv_type=conv_type,
output_channels=context_channels,
)
)
latent_stack = LatentConditioningStack(
shape=(8 * input_channels, output_shape // 32, output_shape // 32),
output_channels=latent_channels,
)
)
sampler = Sampler(
forecast_steps=forecast_steps,
latent_channels=latent_channels,
context_channels=context_channels,
)
)
latent_stack.eval()
conditioning_stack.eval()
sampler.eval()
Expand All @@ -115,7 +116,7 @@ def test_sampler():
# Expand latent dim to match batch size
latent_dim = einops.repeat(
latent_dim, "b c h w -> (repeat b) c h w", repeat=init_states[0].shape[0]
)
)
assert not torch.isnan(latent_dim).any()
hidden_states = [latent_dim] * forecast_steps
assert not all(torch.isnan(hidden_states[i]).any() for i in range(len(hidden_states)))
Expand Down

0 comments on commit 7d4c987

Please sign in to comment.