diff --git a/src/nos/data/pulsating_sphere.py b/src/nos/data/pulsating_sphere.py index 13eab36..04e9c05 100644 --- a/src/nos/data/pulsating_sphere.py +++ b/src/nos/data/pulsating_sphere.py @@ -60,8 +60,8 @@ def __init__( def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Get item at idx from dataset.""" - x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda") - x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda") + x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]) + x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]) x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0 v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True) @@ -130,8 +130,8 @@ def __init__( y = y.expand(n_observations, -1, -1) v = torch.cat([top_samples, right_samples, frequency_samples], dim=1).unsqueeze(-1) - self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda") - self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda") + self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]) + self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]) v = v.expand(-1, -1, y.size(-1)) perm = torch.randperm(n_observations)