Skip to content

Commit

Permalink
Fix/add tests (#53)
Browse files Browse the repository at this point in the history
* rm untested feature

* add tests; rm deprecated features

* fix formatting
  • Loading branch information
JakobEliasWagner authored Aug 19, 2024
1 parent f2a2b09 commit 291a1a3
Show file tree
Hide file tree
Showing 15 changed files with 341 additions and 99 deletions.
4 changes: 0 additions & 4 deletions src/nos/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from .self_supervised_dataset import (
SelfSupervisedDataset,
)
from .transmission_loss import (
TLDataset,
TLDatasetCompact,
Expand All @@ -13,5 +10,4 @@
"TLDatasetCompact",
"TLDatasetCompactExp",
"TLDatasetCompactWave",
"SelfSupervisedDataset",
]
62 changes: 0 additions & 62 deletions src/nos/data/self_supervised_dataset.py

This file was deleted.

16 changes: 1 addition & 15 deletions src/nos/physics/helmholtz_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Te
ks = k.squeeze() ** 2
ks = ks.reshape(-1, 1, 1)
ks = ks.expand(v.size(0), 1, 1)
lpl = []
for dim in range(v.size(-1)):
lpl.append(self.laplace(y, v[:, :, dim]))
lpl = torch.stack(lpl, dim=-1)
lpl = self.laplace(y, v)
return lpl + ks * v


Expand All @@ -31,14 +28,3 @@ def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Te
residual = self.pde(y, v, k)
residual = residual**2
return torch.mean(residual)


class HelmholtzDomainMedianSE(nn.Module):
def __init__(self):
super().__init__()
self.pde = HelmholtzDomainResidual()

def forward(self, y: torch.Tensor, v: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
residual = self.pde(y, v, k)
residual = residual**2
return torch.median(residual)
2 changes: 1 addition & 1 deletion src/nos/physics/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ def forward(self, x: torch.Tensor, u: torch.Tensor, y: torch.Tensor = None) -> t
for dim in range(x.size(-1)):
second_derivatives.append(Grad()(x, derivative[:, :, dim])[:, :, dim])
second_derivatives = torch.stack(second_derivatives, dim=-1)
return torch.sum(second_derivatives, dim=-1)
return torch.sum(second_derivatives, dim=-1, keepdim=True)
3 changes: 1 addition & 2 deletions src/nos/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
)
from .log10_scale import (
Log10Scale,
SymmetricLog10Scale,
)
from .median_peak_scaler_without_shift import (
MedianPeak,
Expand All @@ -15,4 +14,4 @@
QuantileScaler,
)

__all__ = ["MinMaxScale", "Log10Scale", "SymmetricLog10Scale", "QuantileScaler", "CenterQuantileScaler", "MedianPeak"]
__all__ = ["MinMaxScale", "Log10Scale", "QuantileScaler", "CenterQuantileScaler", "MedianPeak"]
12 changes: 0 additions & 12 deletions src/nos/transforms/log10_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,3 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:

def undo(self, tensor: torch.Tensor) -> torch.Tensor:
return 10**tensor


class SymmetricLog10Scale(Transform):
def __init__(self):
super().__init__()

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
out = torch.log10(torch.abs(tensor) + 1)
return out * torch.sign(tensor)

def undo(self, tensor: torch.Tensor) -> torch.Tensor:
return 10**tensor
6 changes: 3 additions & 3 deletions src/nos/transforms/median_peak_scaler_without_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def __init__(
src: torch.Tensor,
):
super().__init__()
peaks, _ = torch.max(torch.abs(src), dim=1)
medians, _ = torch.median(peaks, dim=0)
self.medians = nn.Parameter(medians.view(1, -1))
peaks, _ = torch.max(torch.abs(src), dim=1, keepdim=True)
medians, _ = torch.median(peaks, dim=0, keepdim=True)
self.medians = nn.Parameter(medians)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor / self.medians
Expand Down
Empty file added tests/physics/__init__.py
Empty file.
99 changes: 99 additions & 0 deletions tests/physics/test_helmholtz_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch

from nos.physics import (
HelmholtzDomainMSE,
HelmholtzDomainResidual,
)


@pytest.fixture(scope="function")
def wave_1d():
n_obs = 13
n_eval = 31

x = torch.rand(n_obs, n_eval, 1)
x.requires_grad = True

ks = 10 * torch.rand(n_obs).reshape(-1, 1, 1)

u = torch.sin(ks * x)

return x, u, ks


@pytest.fixture(scope="function")
def wave_1d_wrong():
n_obs = 13
n_eval = 31

x = torch.rand(n_obs, n_eval, 1)
x.requires_grad = True

ks = 10 * torch.rand(n_obs).reshape(-1, 1, 1)

u = torch.sin(ks * 1.1 * x)

return x, u, ks


class TestHelmholtzDomainResidual:
def test_can_initialize(self):
res = HelmholtzDomainResidual()

assert isinstance(res, HelmholtzDomainResidual)

def test_can_forward(self, wave_1d):
x, u, ks = wave_1d

res = HelmholtzDomainResidual()
res_val = res(x, u, ks)

assert isinstance(res_val, torch.Tensor)

def test_forward_correct(self, wave_1d):
x, u, ks = wave_1d

res = HelmholtzDomainResidual()
res_val = res(x, u, ks)

assert torch.allclose(res_val, torch.zeros(res_val.shape), atol=1e-5)

def test_forward_wrong(self, wave_1d_wrong):
x, u, ks = wave_1d_wrong

res = HelmholtzDomainResidual()
res_val = res(x, u, ks)

assert not torch.any(torch.isclose(res_val, torch.zeros(res_val.shape), atol=1e-5))


class TestHelmholtzDomainMSE:
def test_can_initialize(self):
res = HelmholtzDomainMSE()

assert isinstance(res, HelmholtzDomainMSE)

def test_can_forward(self, wave_1d):
x, u, ks = wave_1d

res = HelmholtzDomainMSE()
res_val = res(x, u, ks)

assert isinstance(res_val, torch.Tensor)

def test_forward_correct(self, wave_1d):
x, u, ks = wave_1d

res = HelmholtzDomainMSE()
res_val = res(x, u, ks)

assert res_val < 1e-8

def test_forward_wrong(self, wave_1d_wrong):
x, u, ks = wave_1d_wrong

res = HelmholtzDomainMSE()
res_val = res(x, u, ks)

assert res_val > 1.0
50 changes: 50 additions & 0 deletions tests/physics/test_laplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
import torch

from nos.physics import (
Laplace,
)


@pytest.fixture(scope="function")
def derivative_pair():
xx, yy = torch.meshgrid(
torch.linspace(-1, 1, 100),
torch.linspace(-1, 1, 100),
)
x = torch.stack([xx.flatten(), yy.flatten()], dim=-1)
x = x.unsqueeze(0) # 1 observation
x.requires_grad = True
u = torch.sin(x[:, :, 0]).unsqueeze(-1)
return x, u


class TestLaplace:
def test_can_initialize(self):
_ = Laplace()

assert True

def test_can_forward(self, derivative_pair):
x, u = derivative_pair

laplace = Laplace()
lpl_u = laplace(x, u)

assert isinstance(lpl_u, torch.Tensor)

def test_forward_shape_correct(self, derivative_pair):
x, u = derivative_pair

laplace = Laplace()
lpl_u = laplace(x, u)

assert u.shape == lpl_u.shape

def test_forward_correct(self, derivative_pair):
x, u = derivative_pair

laplace = Laplace()
lpl_u = laplace(x, u)

assert torch.allclose(lpl_u, -u)
50 changes: 50 additions & 0 deletions tests/physics/test_weight_scheduler_lin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch

from nos.physics import (
WeightSchedulerLinear,
)


class TestWeightSchedulerLin:
def test_can_initialize(self):
scheduler = WeightSchedulerLinear(1, 2)

assert isinstance(scheduler, WeightSchedulerLinear)

def test_data_weight_correct(self):
scheduler = WeightSchedulerLinear(1, 2)

epochs = torch.arange(0, 100)
weight = scheduler._get_data_weight(epochs)

assert isinstance(weight, torch.Tensor)
assert torch.allclose(weight, torch.ones(weight.shape))

def test_pde_weight_correct(self):
scheduler = WeightSchedulerLinear(10, 20)

epochs = torch.arange(0, 100)
weight = scheduler._get_pde_weight(epochs)

assert torch.allclose(weight[0:10], torch.zeros(10))
assert torch.allclose(weight[20:], 1e-3 * torch.ones(80))
assert torch.allclose(weight[10:21], 1e-3 * (epochs[10:21] - 10.0) / 10.0)

def test_forward_weight_correct(self):
scheduler = WeightSchedulerLinear(10, 20)

epochs = torch.arange(0, 100)
weight = scheduler(epochs)

assert torch.allclose(weight[0:10, 1], torch.zeros(10))
assert torch.allclose(weight[20:, 1], 1e-3 * torch.ones(80))
assert torch.allclose(weight[10:21, 1], 1e-3 * (epochs[10:21] - 10.0) / 10.0)

assert torch.allclose(weight[:, 0], torch.ones(100))

def test_forward_float_weight_correct(self):
scheduler = WeightSchedulerLinear(10, 20)

weight = scheduler(15.0)

assert torch.allclose(weight, torch.tensor([1.0, 5e-4]))
36 changes: 36 additions & 0 deletions tests/transforms/test_center_quantile_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch

from nos.transforms import (
CenterQuantileScaler,
)


class TestCenterQuantileScaler:
x = torch.rand(7, 89, 13)
y = torch.rand(17, 97, 13)

def test_can_init(self):
transform = CenterQuantileScaler(self.x)

assert isinstance(transform, CenterQuantileScaler)

def test_can_forward(self):
transform = CenterQuantileScaler(self.x)

out = transform(self.y)

assert isinstance(out, torch.Tensor)

def test_forward_centered(self):
transform = CenterQuantileScaler(self.x)

out = transform(self.y)

assert torch.isclose(torch.median(out), torch.zeros(1), atol=1e-1)

def test_can_undo(self):
transform = CenterQuantileScaler(self.x)

out = transform.undo(self.y)

assert isinstance(out, torch.Tensor)
Loading

0 comments on commit 291a1a3

Please sign in to comment.