Skip to content

Commit

Permalink
fix transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEliasWagner committed Sep 19, 2024
1 parent 89205b7 commit 88d0a31
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/nos/data/pulsating_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso
x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]])
x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]])
x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0
v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True)
v_max, _ = torch.max(torch.abs(self.v[idx]), dim=1, keepdim=True)

return x, x, self.y[idx] * 2.0 - 1.0, self.v[idx] / v_max

Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get item at idx from dataset."""
x = self.x[idx] * 2.0 - 1.0
u_max, _ = torch.max(torch.abs(self.u[idx]), dim=-1, keepdim=True)
u_max, _ = torch.max(torch.abs(self.u[idx]), dim=1, keepdim=True)
u = self.u[idx] / u_max

y = self.y[idx] * 2.0 - 1.0
Expand Down

0 comments on commit 88d0a31

Please sign in to comment.