-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_data.py
64 lines (49 loc) · 2.24 KB
/
train_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import glob
import os
from random import randrange

import torch.utils.data as data
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
class TrainData(data.Dataset):
    """Dataset of paired hazy / clear images that yields random aligned crops.

    Expected layout under ``train_data_dir``::

        haze/<gt_name>_<anything>.<ext>   hazy inputs
        clear/<gt_name>.jpg or .png       matching ground-truth images

    The ground-truth name of a hazy file is the part of its file name before
    the first underscore.

    Args:
        crop_size: ``(crop_width, crop_height)`` of the patch to sample.
        train_data_dir: Root directory of the training data. A trailing
            path separator is accepted but not required.
    """

    def __init__(self, crop_size, train_data_dir):
        super().__init__()

        # Sorted so that index -> sample mapping is reproducible; glob order
        # is otherwise filesystem-dependent.
        haze_pattern = os.path.join(train_data_dir, 'haze', '*')
        haze_names = sorted(
            {os.path.basename(path) for path in glob.glob(haze_pattern)}
        )
        gt_names = [name.split('_')[0] for name in haze_names]

        self.haze_names = haze_names
        self.gt_names = gt_names
        self.crop_size = crop_size
        self.train_data_dir = train_data_dir

    def get_images(self, index):
        """Load pair ``index`` and return a random crop of both images.

        The same crop window is applied to the hazy and the clear image.

        Returns:
            ``(haze, gt)`` tensors. ``haze`` goes through ``ToTensor`` and
            ``Normalize`` with mean 0.5 / std 0.5 per channel; ``gt`` goes
            through ``ToTensor`` only.

        Raises:
            Exception: If the hazy image is smaller than the crop size, or
                if either resulting tensor does not have 3 channels.
            OSError: If the hazy image, or both the ``.jpg`` and ``.png``
                ground-truth candidates, cannot be opened.
        """
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        haze_img = Image.open(
            os.path.join(self.train_data_dir, 'haze', haze_name)
        )

        # Ground truth may be stored as .jpg or .png; try .jpg first.
        gt_stem = os.path.join(self.train_data_dir, 'clear', gt_name)
        try:
            gt_img = Image.open(gt_stem + '.jpg').convert('RGB')
        except OSError:
            gt_img = Image.open(gt_stem + '.png').convert('RGB')

        width, height = haze_img.size
        if width < crop_width or height < crop_height:
            raise Exception('Bad image size: {}'.format(gt_name))

        # --- x,y coordinate of left-top corner --- #
        x = randrange(0, width - crop_width + 1)
        y = randrange(0, height - crop_height + 1)
        crop_box = (x, y, x + crop_width, y + crop_height)
        haze_crop_img = haze_img.crop(crop_box)
        gt_crop_img = gt_img.crop(crop_box)

        # --- Transform to tensor --- #
        transform_haze = Compose(
            [ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        )
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_crop_img)
        gt = transform_gt(gt_crop_img)

        # --- Check the channel is 3 or not --- #
        if haze.shape[0] != 3 or gt.shape[0] != 3:
            raise Exception('Bad image channel: {}'.format(gt_name))

        return haze, gt

    def __getitem__(self, index):
        """Return the ``(haze, gt)`` crop pair for ``index``."""
        return self.get_images(index)

    def __len__(self):
        """Return the number of hazy images found at construction time."""
        return len(self.haze_names)