-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdataset.py
75 lines (56 loc) · 2.33 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os, cv2
import numpy as np
import torch
from torch.utils.data import Dataset
def mat_to_tensor(mat):
    """Turn a channel-last array into a channel-first ``torch.Tensor``.

    Args:
        mat: Three-dimensional array exposing ``transpose`` with an axis
            tuple (for example a NumPy image laid out as H x W x C).

    Returns:
        The result of ``torch.Tensor(...)`` applied to ``mat`` with its axes
        permuted to ``(2, 0, 1)``.
    """
    # Move the last axis (channels) to the front before building the tensor.
    channel_first = mat.transpose((2, 0, 1))
    return torch.Tensor(channel_first)
def tensor_to_mat(tensor):
    """Convert a channel-first tensor into a channel-last NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion, so it may require gradients or live on another device.

    Args:
        tensor: Either a batched 4-D tensor (axes permuted to
            ``(0, 2, 3, 1)``) or a single 3-D tensor (axes permuted to
            ``(1, 2, 0)``), i.e. the inverse of ``mat_to_tensor``'s layout.

    Returns:
        A NumPy array with the channel axis moved to the last position.

    Raises:
        ValueError: If ``tensor`` is neither 3-D nor 4-D.
    """
    mat = tensor.detach().cpu().numpy()
    if mat.ndim == 4:
        return mat.transpose((0, 2, 3, 1))
    if mat.ndim == 3:
        # Unbatched input: same permutation without the leading batch axis.
        return mat.transpose((1, 2, 0))
    raise ValueError(
        f"tensor_to_mat expects a 3-D or 4-D tensor, got {mat.ndim}-D"
    )
def preprocess_image(img, target_shape: tuple):
    """Resize an image and scale it by 1/255 as ``float32``.

    Args:
        img: Image array accepted by ``cv2.resize``.
        target_shape: Passed straight to ``cv2.resize`` as the output size.
            NOTE(review): OpenCV documents this argument as (width, height);
            confirm callers pass it in that order for non-square sizes.

    Returns:
        The bicubic-resized image as ``float32`` divided by 255. A 2-D
        (single-channel) result gets a trailing axis of length 1 so the
        output always has three dimensions.
    """
    resized = cv2.resize(img, target_shape, interpolation=cv2.INTER_CUBIC)
    scaled = resized.astype(np.float32) / 255.
    if scaled.ndim != 2:
        return scaled
    # Grayscale: add an explicit channel axis.
    return scaled.reshape(*scaled.shape, 1)
def postprocess_image(img):
    """Map an image scaled by 1/255 back to 8-bit pixel values.

    Values are multiplied by 255, rounded to the nearest integer, clipped to
    ``[0, 255]`` and cast to ``uint8``. Rounding (instead of the truncation
    a bare ``astype`` performs) avoids a systematic darkening of up to one
    level and makes ``x / 255 -> postprocess_image`` exact for 8-bit ``x``.

    Args:
        img: Array-like of numeric values, nominally in ``[0, 1]``;
            out-of-range values are clipped.

    Returns:
        A ``numpy.uint8`` array with the same shape as ``img``.
    """
    scaled = np.asarray(img) * 255.0
    # Round before the cast: astype(uint8) alone would truncate toward zero.
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
class CustomDataset(Dataset):
    """Dataset yielding four aligned image tensors per sample.

    Images are read from four sibling directories under
    ``<data_dir>/<set_name>``: ``I``, ``Itegt``, ``Mm`` and ``Msgt``. The file
    names found in ``I`` define the samples; the same name is looked up in
    the other three directories.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        set_name: Name of the split sub-directory (default ``"train"``).
        target_size: Size handed to ``preprocess_image`` for every image.
            NOTE(review): it ends up as the ``cv2.resize`` output size, which
            OpenCV documents as (width, height); confirm for non-square sizes.
    """

    def __init__(self,
                 data_dir,
                 set_name="train",
                 target_size=(256, 256)):
        super().__init__()
        self.root_dir = os.path.join(data_dir, set_name)
        self.target_size = target_size
        self.I_dir = os.path.join(self.root_dir, "I")
        self.Itegt_dir = os.path.join(self.root_dir, "Itegt")
        self.Mm_dir = os.path.join(self.root_dir, "Mm")
        self.Msgt_dir = os.path.join(self.root_dir, "Msgt")
        # os.listdir returns entries in arbitrary order; sort so that an index
        # refers to the same file on every run and platform.
        self.datas = sorted(os.listdir(self.I_dir))

    def __len__(self):
        """Return the number of file names listed in the ``I`` directory."""
        return len(self.datas)

    def _load(self, directory, img_name, grayscale=False):
        """Read one image, preprocess it and return it as a tensor.

        Raises:
            FileNotFoundError: If the file does not exist.
            OSError: If the file exists but ``cv2.imread`` cannot decode it.
        """
        path = os.path.join(directory, img_name)
        if grayscale:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        else:
            img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cv2.resize.
            if not os.path.isfile(path):
                raise FileNotFoundError(f"Image file not found: {path}")
            raise OSError(f"cv2.imread could not decode image: {path}")
        return mat_to_tensor(preprocess_image(img, self.target_size))

    def __getitem__(self, idx):
        """Return the ``(I, Itegt, Mm, Msgt)`` tensors for sample ``idx``.

        ``I`` and ``Itegt`` are read with OpenCV's default flags, ``Mm`` and
        ``Msgt`` as grayscale. Each is passed through ``preprocess_image``
        and ``mat_to_tensor``.
        """
        img_name = self.datas[idx]
        I = self._load(self.I_dir, img_name)
        Itegt = self._load(self.Itegt_dir, img_name)
        Mm = self._load(self.Mm_dir, img_name, grayscale=True)
        Msgt = self._load(self.Msgt_dir, img_name, grayscale=True)
        return I, Itegt, Mm, Msgt
if __name__ == "__main__":
    # Smoke test: load the first training sample and report tensor shapes.
    ds = CustomDataset('dataset', 'train')
    first_sample = ds[0]
    I, Itegt, Mm, Ms = first_sample

    print(f"Dataset length : {len(ds)}")
    print(f"I shape : {I.shape}")
    print(f"Itegt shape : {Itegt.shape}")
    print(f"Mm shape : {Mm.shape}")
    print(f"Ms shape : {Ms.shape}")