diff --git a/src/nos/data/pulsating_sphere.py b/src/nos/data/pulsating_sphere.py index 04e9c05..bba1c3e 100644 --- a/src/nos/data/pulsating_sphere.py +++ b/src/nos/data/pulsating_sphere.py @@ -63,7 +63,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tenso x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]) x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]) x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0 - v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True) + v_max, _ = torch.max(torch.abs(self.v[idx]), dim=1, keepdim=True) return x, x, self.y[idx] * 2.0 - 1.0, self.v[idx] / v_max @@ -142,7 +142,7 @@ def __init__( def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Get item at idx from dataset.""" x = self.x[idx] * 2.0 - 1.0 - u_max, _ = torch.max(torch.abs(self.u[idx]), dim=-1, keepdim=True) + u_max, _ = torch.max(torch.abs(self.u[idx]), dim=1, keepdim=True) u = self.u[idx] / u_max y = self.y[idx] * 2.0 - 1.0