-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path create_dataset.py
111 lines (85 loc) · 3.88 KB
/
create_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import cv2
import glob
import random
import progressbar
import numpy as np
import matplotlib.pyplot as plt
def rand_color():
    """Return a random colour as a tuple of three ints, each in [0, 255]."""
    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))


def rand_pos(a, b):
    """Return a random ``(x, y)`` pair with both coordinates drawn from [a, b - 1].

    Both coordinates share the same bounds, so this is only suitable for a
    square placement region. Raises ValueError if ``b <= a`` (empty range).
    """
    return (random.randint(a, b - 1), random.randint(a, b - 1))
target_size = 256     # side length (px) of every square crop / output sample
imgs_per_back = 30    # number of samples generated from each background image

# Input pools. Paths are relative to the current working directory, so running
# the script from elsewhere yields empty lists; fail loudly in that case rather
# than silently producing nothing. Sorting makes the processing order
# independent of the filesystem's directory-listing order.
backs = sorted(glob.glob('./dataset/backs/*.png'))
fonts = sorted(glob.glob('./dataset/font_mask/*.png'))
if not backs:
    raise SystemExit('no background images found in ./dataset/backs/*.png')
if not fonts:
    raise SystemExit('no font masks found in ./dataset/font_mask/*.png')

# Output layout: dataset/{train,val}/{I,Itegt,Mm,Msgt}/
for split_name in ('train', 'val'):
    for sub_dir in ('I', 'Itegt', 'Mm', 'Msgt'):
        os.makedirs(f'./dataset/{split_name}/{sub_dir}', exist_ok=True)

# Continue numbering after the samples already on disk so a re-run adds new
# files instead of overwriting old ones (assumes existing files are named
# 0..N-1 -- TODO confirm if files are ever deleted by hand).
t_idx = len(os.listdir('./dataset/train/I'))
v_idx = len(os.listdir('./dataset/val/I'))

bar = progressbar.ProgressBar(maxval=len(backs)*imgs_per_back)
bar.start()
def _binarize(img, thresh=30):
    """Return a uint8 image that is 255 where *img* > *thresh* and 0 elsewhere.

    Faint interpolation residue (values 1..thresh) is zeroed so it cannot leak
    into the ground-truth masks or clear low bits of the background through
    ``255 - font``.
    """
    return np.where(img > thresh, 255, 0).astype(np.uint8)


def _save_sample(split, idx, I, Itegt, Mm, Msgt):
    """Write one sample as ``dataset/<split>/{I,Itegt,Mm,Msgt}/<idx>.png``."""
    for sub_dir, img in (('I', I), ('Itegt', Itegt), ('Mm', Mm), ('Msgt', Msgt)):
        cv2.imwrite(f'dataset/{split}/{sub_dir}/{idx}.png', img)


# Images handled by *this* run. This drives the progress bar; t_idx/v_idx can
# start above 0 (existing files), which would push the bar past maxval.
n_generated = 0
for back in backs:
    back_img = cv2.imread(back)
    if back_img is None:
        # cv2.imread signals failure by returning None instead of raising.
        print(f'warning: cannot read background {back}; skipping')
        n_generated += imgs_per_back
        bar.update(n_generated)
        continue
    bh, bw, _ = back_img.shape
    # Backgrounds smaller than the crop in either dimension are stretched to
    # exactly target_size x target_size (aspect ratio is not preserved).
    if bh < target_size or bw < target_size:
        back_img = cv2.resize(back_img, (target_size, target_size), interpolation=cv2.INTER_CUBIC)
        bh, bw, _ = back_img.shape
    for bi in range(imgs_per_back):
        # Random crop: Itegt is the clean target, I gets the glyphs painted on.
        cx, cy = random.randint(0, bw - target_size), random.randint(0, bh - target_size)
        Itegt = back_img[cy:cy + target_size, cx:cx + target_size, :].copy()
        I = Itegt.copy()
        Mm = np.zeros_like(I)    # union of the rotated glyph bounding areas
        Msgt = np.zeros_like(I)  # union of the glyph strokes
        hist = []  # top-left corners of glyphs already placed in this sample
        n_fonts = random.randint(min(2, len(fonts)), min(4, len(fonts)))
        for font in random.sample(fonts, n_fonts):
            font_img = cv2.imread(font)
            if font_img is None:
                print(f'warning: cannot read font mask {font}; skipping')
                continue
            mask_img = np.full_like(font_img, 255, dtype=np.uint8)
            height, width, _ = font_img.shape
            angle = random.randint(-30, +30)
            fs = random.randint(90, 120)
            ratio = fs / height - 0.2
            matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, ratio)
            # warpAffine's 4th positional parameter is ``dst``; the
            # interpolation mode must be given as ``flags=`` to take effect.
            font_rot = cv2.warpAffine(font_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
            mask_rot = cv2.warpAffine(mask_img, matrix, (width, height), flags=cv2.INTER_CUBIC)
            h, w, _ = font_rot.shape
            # NOTE(review): assumes font masks are no larger than target_size
            # in either dimension; otherwise randint below has an empty range.
            allow = 0
            while True:
                # x is bounded by the glyph width and y by its height, so the
                # paste below always fits inside the canvas.
                sx = random.randint(0, target_size - w)
                sy = random.randint(0, target_size - h)
                # Reject positions too close to earlier glyphs. The threshold
                # shrinks by 5 every attempt, so the loop always terminates.
                too_close = any(
                    (px - sx) ** 2 + (py - sy) ** 2 < (fs * ratio) ** 2 - allow
                    for px, py in hist
                )
                allow += 5
                if not too_close:
                    hist.append([sx, sy])
                    break
            font_in_I = np.zeros_like(I)
            mask_in_I = np.zeros_like(I)
            font_in_I[sy:sy + h, sx:sx + w, :] = font_rot
            mask_in_I[sy:sy + h, sx:sx + w, :] = mask_rot
            font_in_I = _binarize(font_in_I)
            mask_in_I = _binarize(mask_in_I)
            # Clear the glyph pixels, then fill them with one random colour.
            I = cv2.bitwise_and(I, 255 - font_in_I)
            I = cv2.bitwise_or(I, (font_in_I // 255 * rand_color()).astype(np.uint8))
            Mm = cv2.bitwise_or(Mm, mask_in_I)
            Msgt = cv2.bitwise_or(Msgt, font_in_I)
        # The first 80% of the samples per background go to train, the rest to val.
        if bi < imgs_per_back * 0.8:
            _save_sample('train', t_idx, I, Itegt, Mm, Msgt)
            t_idx += 1
        else:
            _save_sample('val', v_idx, I, Itegt, Mm, Msgt)
            v_idx += 1
        n_generated += 1
        bar.update(n_generated)
bar.finish()