Skip to content

Commit

Permalink
Fix: pulsating sphere transformation (#69)
Browse files Browse the repository at this point in the history
* add pulsating sphere datasets

* add dataset

* change test use small dataset

* rm mv to cuda

* fix transformation
  • Loading branch information
JakobEliasWagner authored Sep 19, 2024
1 parent 16539c0 commit 3354ccb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/nos/data/pulsating_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso
x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]])
x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]])
x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0
v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True)
v_max, _ = torch.max(torch.abs(self.v[idx]), dim=1, keepdim=True)

return x, x, self.y[idx] * 2.0 - 1.0, self.v[idx] / v_max

Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get item at idx from dataset."""
x = self.x[idx] * 2.0 - 1.0
u_max, _ = torch.max(torch.abs(self.u[idx]), dim=-1, keepdim=True)
u_max, _ = torch.max(torch.abs(self.u[idx]), dim=1, keepdim=True)
u = self.u[idx] / u_max

y = self.y[idx] * 2.0 - 1.0
Expand Down

0 comments on commit 3354ccb

Please sign in to comment.