Skip to content

Commit

Permalink
rm mv to cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobEliasWagner committed Sep 19, 2024
1 parent e4bddf5 commit 89205b7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/nos/data/pulsating_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(

def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get item at idx from dataset."""
x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda")
x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda")
x_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]])
x_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]])
x = ((self.x[idx] - x_min) / x_scale) * 2.0 - 1.0
v_max, _ = torch.max(torch.abs(self.v[idx]), dim=-1, keepdim=True)

Expand Down Expand Up @@ -130,8 +130,8 @@ def __init__(
y = y.expand(n_observations, -1, -1)

v = torch.cat([top_samples, right_samples, frequency_samples], dim=1).unsqueeze(-1)
self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]]).to("cuda")
self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]]).to("cuda")
self.v_min = torch.tensor([[0.0], [-1.0], [0.0], [-1.0], [400.0]])
self.v_scale = torch.tensor([[1.0], [1.0], [1.0], [1.0], [100.0]])
v = v.expand(-1, -1, y.size(-1))

perm = torch.randperm(n_observations)
Expand Down

0 comments on commit 89205b7

Please sign in to comment.