From 89205b7fa221c9c50e1ae74fb3df03d12429f110 Mon Sep 17 00:00:00 2001 From: JakobEliasWagner Date: Thu, 19 Sep 2024 20:28:42 +0200 Subject: [PATCH] rm mv to cuda --- src/nos/data/pulsating_sphere.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/nos/data/pulsating_sphere.py b/src/nos/data/pulsating_sphere.py index 13eab36..04e9c05 100644 --- a/src/nos/data/pulsating_sphere.py +++ b/src/nos/data/pulsating_sphere.py @@ -60,8 +60,8 @@ def __init__( def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Get item at idx from dataset.""" - x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda") - x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda") + x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]) + x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]) x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0 v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True) @@ -130,8 +130,8 @@ def __init__( y = y.expand(n_observations, -1, -1) v = torch.cat([top_samples, right_samples, frequency_samples], dim=1).unsqueeze(-1) - self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda") - self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda") + self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]) + self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]) v = v.expand(-1, -1, y.size(-1)) perm = torch.randperm(n_observations)