Skip to content

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
clementchadebec committed Sep 11, 2023
1 parent 0e2a8e2 commit 4ce35b1
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 9 deletions.
18 changes: 18 additions & 0 deletions src/pythae/samplers/base/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
from typing import Any, Dict
from imageio import imwrite

from ...models import BaseAE
Expand Down Expand Up @@ -93,3 +94,20 @@ def save_img(self, img_tensor: torch.Tensor, dir_path: str, img_name: str):

img = img.astype("uint8")
imwrite(os.path.join(dir_path, f"{img_name}"), img)

def _set_inputs_to_device(self, inputs: Dict[str, Any]):

inputs_on_device = inputs

if self.device == "cuda":
cuda_inputs = dict.fromkeys(inputs)

for key in inputs.keys():
if torch.is_tensor(inputs[key]):
cuda_inputs[key] = inputs[key].cuda()

else:
cuda_inputs[key] = inputs[key]
inputs_on_device = cuda_inputs

return inputs_on_device
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def fit(self, train_data: Union[torch.Tensor, np.ndarray, Dataset], **kwargs):

if not isinstance(train_data, Dataset):
data_processor = DataProcessor()
train_data = data_processor.process_data(train_data).to(self.device)
train_data = data_processor.process_data(train_data)
train_dataset = data_processor.to_dataset(train_data)

else:
Expand All @@ -74,11 +74,13 @@ def fit(self, train_data: Union[torch.Tensor, np.ndarray, Dataset], **kwargs):
try:
with torch.no_grad():
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
z_ = self.model(inputs).z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
z_ = self.model(inputs).z.detach()
z.append(z_)

Expand Down
8 changes: 6 additions & 2 deletions src/pythae/samplers/iaf_sampler/iaf_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def fit(
"""
data_processor = DataProcessor()
if not isinstance(train_data, Dataset):
train_data = data_processor.process_data(train_data).to(self.device)
train_data = data_processor.process_data(train_data)
train_dataset = data_processor.to_dataset(train_data)

else:
Expand All @@ -91,12 +91,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand All @@ -109,7 +111,7 @@ def fit(
if eval_data is not None:

if not isinstance(eval_data, Dataset):
eval_data = data_processor.process_data(eval_data).to(self.device)
eval_data = data_processor.process_data(eval_data)
eval_dataset = data_processor.to_dataset(eval_data)

else:
Expand All @@ -127,12 +129,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand Down
8 changes: 6 additions & 2 deletions src/pythae/samplers/maf_sampler/maf_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fit(

data_processor = DataProcessor()
if not isinstance(train_data, Dataset):
train_data = data_processor.process_data(train_data).to(self.device)
train_data = data_processor.process_data(train_data)
train_dataset = data_processor.to_dataset(train_data)

else:
Expand All @@ -92,12 +92,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand All @@ -110,7 +112,7 @@ def fit(
if eval_data is not None:

if not isinstance(eval_data, Dataset):
eval_data = data_processor.process_data(eval_data).to(self.device)
eval_data = data_processor.process_data(eval_data)
eval_dataset = data_processor.to_dataset(eval_data)

else:
Expand All @@ -127,12 +129,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand Down
6 changes: 4 additions & 2 deletions src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def fit(

data_processor = DataProcessor()
if not isinstance(train_data, Dataset):
train_data = data_processor.process_data(train_data).to(self.device)
train_data = data_processor.process_data(train_data)
train_dataset = data_processor.to_dataset(train_data)

else:
Expand All @@ -91,6 +91,7 @@ def fit(

with torch.no_grad():
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
model_output = self.model(inputs)
mean_z = model_output.quantized_indices
z.append(
Expand All @@ -107,7 +108,7 @@ def fit(
if eval_data is not None:

if not isinstance(eval_data, Dataset):
eval_data = data_processor.process_data(eval_data).to(self.device)
eval_data = data_processor.process_data(eval_data)
eval_dataset = data_processor.to_dataset(eval_data)

else:
Expand All @@ -124,6 +125,7 @@ def fit(

with torch.no_grad():
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
model_output = self.model(inputs)
mean_z = model_output.quantized_indices
z.append(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def fit(

data_processor = DataProcessor()
if not isinstance(train_data, Dataset):
train_data = data_processor.process_data(train_data).to(self.device)
train_data = data_processor.process_data(train_data)
train_dataset = data_processor.to_dataset(train_data)

else:
Expand All @@ -171,12 +171,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(train_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand All @@ -189,7 +191,7 @@ def fit(
if eval_data is not None:

if not isinstance(eval_data, Dataset):
eval_data = data_processor.process_data(eval_data).to(self.device)
eval_data = data_processor.process_data(eval_data)
eval_dataset = data_processor.to_dataset(eval_data)

else:
Expand All @@ -207,12 +209,14 @@ def fit(
try:
with torch.no_grad():
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z
z.append(z_)

except RuntimeError:
for _, inputs in enumerate(eval_loader):
inputs = self._set_inputs_to_device(inputs)
encoder_output = self.model(inputs)
z_ = encoder_output.z.detach()
z.append(z_)
Expand Down

0 comments on commit 4ce35b1

Please sign in to comment.