-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlab_utils.py
254 lines (209 loc) · 9.07 KB
/
lab_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import cv2
from IPython.display import display, Image
import matplotlib.pyplot as plt
import numpy as np
def lrSchedule(base_lr, iter, iters, epoch=0, time=0, step=(30, 60, 90), target_lr=0.0, mode='cosine', per_epoch=0):
    """Compute the learning rate for one iteration of a schedule.

    Annealing modes decay from ``base_lr`` (or a stage peak) down to ``target_lr``:

    * ``'cosine'``: a single cosine from ``base_lr`` to ``target_lr`` over ``iters``.
    * ``'cycleCosine'``: ``time`` cosine restarts whose peaks shrink linearly from
      ``base_lr`` towards ``target_lr``.
    * ``'hhh'``: fixed epoch stages [0, 200), [200, 300), [300, 350), [350, ...) whose
      peaks are 4/4, 3/4, 2/4 and 1/4 of the ``base_lr - target_lr`` range.
    * ``'step'``: placeholder; returns ``target_lr`` if it is non-zero, else ``base_lr``.

    :param base_lr: peak learning rate.
    :param iter: current global iteration (0-based).
    :param iters: total number of iterations; raised to ``iter`` if smaller.
    :param epoch: current epoch (used by ``'step'`` and ``'hhh'``).
    :param time: number of restarts for ``'cycleCosine'`` (0 means 4).
    :param step: epochs inspected by ``'step'`` (currently has no effect).
    :param target_lr: learning-rate floor of the annealing modes.
    :param mode: schedule name; unrecognised names behave like ``'step'``.
    :param per_epoch: iterations per epoch; must be positive for ``'hhh'``.
    :return: the learning rate.
    :raises ValueError: if ``mode == 'hhh'`` and ``per_epoch`` is not positive.
    """
    # 'step' and unrecognised modes return this value unchanged.
    lr = target_lr if target_lr else base_lr
    # Keep the cosine argument within [0, pi]; the floor of 1 avoids 0 / 0.
    iters = max(iters, iter, 1)
    if mode == 'cosine':
        # Updated every iteration; annealing starts from target_lr (not base_lr),
        # otherwise target_lr == 0 would start the schedule at 2 * base_lr.
        lr = target_lr + (base_lr - target_lr) * (1 + np.cos(np.pi * iter / iters)) / 2.0
    elif mode == 'step':
        # Intended to change every epoch listed in `step`; not implemented.
        if epoch in step:
            pass
    elif mode == 'cycleCosine':
        time = time if time else 4
        cycle_len = max(iters // time, 1)
        cur_iter = iter % cycle_len
        # Clamp so leftover iterations after the last full cycle stay at target_lr
        # instead of dropping below it.
        cur_time = min(iter // cycle_len, time)
        cur_base = (base_lr - target_lr) / time * (time - cur_time) + target_lr
        lr = target_lr + (cur_base - target_lr) * (1 + np.cos(np.pi * cur_iter / cycle_len)) * 0.5
    elif mode == 'hhh':
        if per_epoch <= 0:
            raise ValueError("mode 'hhh' requires a positive per_epoch")
        # Stage table: (first epoch, length in epochs, peak in quarters of the range).
        if epoch < 200:
            start, length, quarters = 0, 200, 4
        elif epoch < 300:
            start, length, quarters = 200, 100, 3
        elif epoch < 350:
            start, length, quarters = 300, 50, 2
        else:
            start, length, quarters = 350, 50, 1
        cur_lr = (base_lr - target_lr) / 4 * quarters + target_lr
        # Iterations elapsed since the first iteration of the current stage.
        cur_iter = iter - per_epoch * start
        # The stage period is `length * per_epoch` iterations (parenthesised on purpose).
        lr = target_lr + (cur_lr - target_lr) * (1 + np.cos(np.pi * cur_iter / (length * per_epoch))) * 0.5
    return lr
class CycleLR:
    """Linear warm-up followed by repeating cosine-annealing cycles.

    Call ``step()`` once per iteration; the first call is iteration 0.
    """

    def __init__(self, warm_epoch, all_epoch, target_lr, iters_epoch, period=4):
        """
        :param warm_epoch: number of warm-up epochs (0 disables warm-up).
        :param all_epoch: total number of training epochs.
        :param target_lr: learning rate reached at the end of warm-up and at each cycle start.
        :param iters_epoch: iterations per epoch.
        :param period: number of cosine cycles after warm-up.
        """
        self.iter = -1  # incremented before use, so the first step() is iteration 0
        self.warmIters = iters_epoch * warm_epoch
        self.target_lr = target_lr
        # Length of one cosine cycle, clamped to 1 so the modulo in step() is defined
        # even when there are fewer post-warm-up epochs than cycles.
        self.step_iters = max((all_epoch - warm_epoch) // period * iters_epoch, 1)

    def step(self):
        """Advance one iteration and return its learning rate."""
        self.iter += 1
        # Without warm-up (warmIters == 0) warm() would divide by zero, so go
        # straight to the cosine phase.
        if self.warmIters > 0 and self.iter <= self.warmIters:
            return self.warm()
        cur_iter = (self.iter - self.warmIters) % self.step_iters
        return self.cosine(cur_iter)

    def warm(self):
        """Linear ramp from 0 to target_lr across the warm-up iterations."""
        return self.iter / self.warmIters * self.target_lr

    def cosine(self, cur_iter):
        """Cosine decay from target_lr towards 0 at position ``cur_iter`` of a cycle."""
        return self.target_lr * (1 + np.cos(np.pi * cur_iter / self.step_iters)) * 0.5
# Landmark-index polylines: each inner list is one component, drawn by joining
# consecutive indices into the 98-point landmark array used in this file
# (presumably the WFLW layout - verify). drawLine also closes components 9 and 12.
components = [
    list(range(33)),
    [76, 87, 86, 85, 84, 83, 82],
    [88, 95, 94, 93, 92],
    [88, 89, 90, 91, 92],
    [76, 77, 78, 79, 80, 81, 82],
    list(range(55, 60)),
    list(range(51, 55)),
    [60, 67, 66, 65, 64],
    list(range(60, 65)),
    list(range(33, 42)),
    [68, 75, 74, 73, 72],
    list(range(68, 73)),
    list(range(42, 51)),
]
def getBBox(points):
    """Return the box enclosing every box in a flat coordinate sequence.

    :param points: flat sequence ``x1, y1, x2, y2[, x1, y1, x2, y2, ...]``.
    :return: ``np.array([min x1, min y1, max x2, max y2])``.
    """
    left, top, right, bottom = (points[offset::4] for offset in range(4))
    return np.array([min(left), min(top), max(right), max(bottom)])
def enlargeBBox(bbox, factor=0.05):
    """Grow a box outward on every side by ``factor`` of its size.

    :param bbox: ``(x1, y1, x2, y2)``.
    :param factor: fraction of the width added to the left and right edges, and
        fraction of the height added to the top and bottom edges.
    :return: ``np.array([x1, y1, x2, y2])`` of the enlarged box.
    """
    x1, y1, x2, y2 = bbox
    pad_x = (x2 - x1) * factor
    pad_y = (y2 - y1) * factor
    return np.array([x1 - pad_x, y1 - pad_y, x2 + pad_x, y2 + pad_y])
def show(img_path):
    """Render the image file at ``img_path`` in the current IPython output."""
    picture = Image(img_path)
    display(picture)
def landmark(path, points, detection=False, root='WFLW_crop/'):
    """Plot an image with its 98 landmarks (and optionally a box) via matplotlib.

    :param path: image path relative to ``root``.
    :param points: flat sequence ``[x0, y0, ..., x97, y97]``; when ``detection`` is
        true, the next four values are a box ``x0, y0, x1, y1``.
    :param detection: also draw the box, colour (0, 0, 255) in OpenCV's BGR order.
    :param root: prefix joined to ``path``; the default keeps the original location.
    :raises FileNotFoundError: if OpenCV cannot read the image.
    """
    full_path = root + path
    img = cv2.imread(full_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {full_path}")
    if detection:
        x0, y0, x1, y1 = (int(v) for v in points[98 * 2:98 * 2 + 4])
        cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 2)
    for i in range(0, 98 * 2, 2):
        # OpenCV drawing functions require integer pixel coordinates.
        cv2.circle(img, (int(points[i]), int(points[i + 1])), 1, (23, 25, 0), 1)
    # OpenCV loads BGR; reverse the channel axis for matplotlib's RGB.
    plt.imshow(img[:, :, ::-1])
def _component_segments(com):
    """Yield the (start, end) landmark-index pairs that make up component ``com``."""
    indices = components[com]
    for p1, p2 in zip(indices[:-1], indices[1:]):
        yield p1, p2
    # Components 9 and 12 are closed contours: join the last point back to the first.
    if com in (9, 12):
        yield indices[0], indices[-1]


def _landmark_xy(points, idx):
    """Return landmark ``idx`` of a flat ``[x0, y0, x1, y1, ...]`` list as integer pixels."""
    # OpenCV drawing functions reject non-integer coordinates.
    return int(points[idx * 2]), int(points[idx * 2 + 1])


def drawLine(img, points, color=(25, 100, 0), return_image=True, thickness=1):
    """Draw every landmark component as a polyline.

    :param img: image to draw on (``return_image=True``) or whose height and width
        size the output masks (``return_image=False``).
    :param points: flat landmark list (196 landmark values followed by 4 detection-box
        values); components 9 and 12 are drawn closed.
    :param color: line colour for 3-D (multi-channel) images; 2-D images use (255,).
    :param return_image: if true, draw on ``img`` in place and return it; otherwise
        return a new uint8 array (H, W, 13) holding one 0/255 line mask per component.
    :param thickness: line thickness when drawing on ``img``; masks always use 2 px.
    :return: ``img`` or the (H, W, 13) mask stack.
    """
    if return_image:
        # The grey-image fallback: the old `len(img.shape)` test was always true.
        color = color if img.ndim == 3 else (255,)
        for com in range(len(components)):
            for p1, p2 in _component_segments(com):
                cv2.line(img, _landmark_xy(points, p1), _landmark_xy(points, p2), color, thickness)
        return img

    height, width = img.shape[:2]
    masks = np.zeros((height, width, len(components)), dtype=np.uint8)
    for com in range(len(components)):
        # Draw into a contiguous 2-D buffer, then copy it into the channel slice.
        mask = np.zeros((height, width), dtype=np.uint8)
        for p1, p2 in _component_segments(com):
            # Every segment, including the closing one, is 2 px wide in mask mode.
            cv2.line(mask, _landmark_xy(points, p1), _landmark_xy(points, p2), (255,), 2)
        masks[:, :, com] = mask
    return masks
def drawDistanceImg(img, points, color=(25, 100, 0)):
    """Per-component Euclidean distance-to-line maps.

    :param img: image whose height and width size the output.
    :param points: flat landmark list accepted by ``drawLine``.
    :param color: forwarded to ``drawLine`` (its mask mode does not use it).
    :return: float32 array (H, W, 13); channel ``i`` holds, for every pixel, the
        distance to the nearest line pixel of component ``i``.
    """
    masks = drawLine(img, points, color=color, return_image=False)
    assert masks.shape[2] == 13, "This is not the 13 components heatmap!"
    # Keep distances in float32: writing them back into the uint8 mask stack
    # truncated fractional distances and overflowed for distances above 255.
    dist = np.empty(masks.shape, dtype=np.float32)
    for i in range(13):
        # Lines are 255 in the mask; invert so they become the zero pixels that
        # distanceTransform measures distance to.
        dist[:, :, i] = cv2.distanceTransform(255 - masks[:, :, i], cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    return dist
def drawGaussianHeatmap(img, points, color=(25, 100, 0), sigma=4):
    """Build per-component Gaussian heatmaps around the landmark boundary lines.

    :param img: image whose height and width size the heatmaps.
    :param points: flat landmark list accepted by ``drawDistanceImg``.
    :param color: forwarded to ``drawDistanceImg``.
    :param sigma: Gaussian width in pixels; pixels at distance >= ``3 * sigma`` are zeroed.
    :return: float array (H, W, C); each channel is min-max normalised to [0, 1],
        and a constant channel becomes all zeros.
    """
    # Square in float: an integer distance map (e.g. uint8) would overflow for
    # distances >= 16, which falls inside the 3-sigma window once sigma >= 6.
    dist_img = np.asarray(drawDistanceImg(img, points, color=color), dtype=np.float64)
    heatmap = np.exp(-dist_img ** 2 / (2.0 * sigma ** 2))
    # Truncate the Gaussian at 3 sigma.
    heatmap = np.where(dist_img < 3.0 * sigma, heatmap, 0.0)
    lo = heatmap.min(axis=(0, 1), keepdims=True)
    hi = heatmap.max(axis=(0, 1), keepdims=True)
    span = hi - lo
    # Constant channels (span == 0) become zeros instead of dividing by zero.
    safe_span = np.where(span > 0, span, 1.0)
    return np.where(span > 0, (heatmap - lo) / safe_span, 0.0)
def drawPoint(img, points, color=(25, 100, 0)):
    """Draw the first 98 landmarks of ``points`` onto ``img`` and return it.

    :param img: RGB image; drawn on in place.
    :param points: flat list ``[x0, y0, x1, y1, ...]`` with at least 196 values.
    :param color: circle colour.
    :return: ``img`` with a radius-2 circle at each landmark.
    """
    for i in range(0, 98 * 2, 2):
        # OpenCV rejects non-integer centres (e.g. predicted float landmarks).
        cv2.circle(img, (int(points[i]), int(points[i + 1])), 2, color, 2)
    return img
def figPicture(data, heatmap, widehatHeatmap, pred_real_landmarks, pred_fake_landmarks):
    """Tile a grey image and two 13-channel heatmaps (with landmarks drawn) into one figure.

    :param data: [3, 256, 256] RGB image; resized to 64x64 and averaged to grey.
    :param heatmap: [13, 64, 64]; ``pred_real_landmarks`` are drawn on a copy of it.
    :param widehatHeatmap: [13, 64, 64]; ``pred_fake_landmarks`` are drawn on a copy of it.
    :param pred_real_landmarks: flat ``[x0, y0, ...]`` landmarks, presumably in
        64x64 heatmap coordinates - TODO confirm.
    :param pred_fake_landmarks: same layout as ``pred_real_landmarks``.
    :return: 2-D array [64 * 4, 64 * 7]. Rows 1-2 hold ``data`` and the 13
        ``heatmap`` channels; rows 3-4 hold ``data`` and the 13 ``widehatHeatmap``
        channels. Everything is multiplied by 256.
    """
    data = np.mean(cv2.resize(np.moveaxis(data, 0, 2), (64, 64)), axis=2, keepdims=False, dtype=np.float32)
    # Draw on private C-contiguous copies so the caller's heatmaps are not modified.
    heatmap = np.array(heatmap, order='C')
    widehatHeatmap = np.array(widehatHeatmap, order='C')
    for com, indices in enumerate(components):
        for p in indices:
            # OpenCV drawing functions require integer pixel coordinates.
            real_xy = (int(pred_real_landmarks[p * 2]), int(pred_real_landmarks[p * 2 + 1]))
            fake_xy = (int(pred_fake_landmarks[p * 2]), int(pred_fake_landmarks[p * 2 + 1]))
            heatmap[com] = cv2.circle(heatmap[com], real_xy, 2, 1, 2)
            widehatHeatmap[com] = cv2.circle(widehatHeatmap[com], fake_xy, 2, 1, 2)
    heatmap = np.moveaxis(heatmap, 0, 2)
    widehatHeatmap = np.moveaxis(widehatHeatmap, 0, 2)
    line1 = np.concatenate([data] + [heatmap[..., i] for i in range(6)], axis=1)
    line2 = np.concatenate([heatmap[..., i + 6] for i in range(7)], axis=1)
    line3 = np.concatenate([data] + [widehatHeatmap[..., i] for i in range(6)], axis=1)
    line4 = np.concatenate([widehatHeatmap[..., i + 6] for i in range(7)], axis=1)
    fig = np.concatenate([line1, line2, line3, line4], axis=0)
    # NOTE(review): a value of 1.0 maps to 256, which overflows uint8 if the
    # caller casts the result; confirm 256 (rather than 255) is intended.
    return fig * 256
def plot_land(data, landmarks):
    """Return a copy of ``data`` with a circle of value 1 drawn at every landmark.

    :param data: (256, 256) single-channel image; not modified.
    :param landmarks: (196,) flat array ``[x0, y0, x1, y1, ...]``.
    :return: the annotated copy, with the same shape as ``data``.
    """
    canvas = data.copy()
    for k in range(landmarks.shape[0] // 2):
        centre = (int(landmarks[2 * k]), int(landmarks[2 * k + 1]))
        canvas = cv2.circle(canvas, centre, 2, 1, 2)
    return canvas
def figPicture2Land2(data, heatmap, widehatHeatmap, pred_fake_landmarks, pred_real_landmarks, landmarks):
    """Place three grey copies of ``data`` side by side, each with one landmark set drawn.

    :param data: [3, 256, 256] RGB image, averaged over channels to grey.
    :param heatmap: not used by this function.
    :param widehatHeatmap: not used by this function.
    :param pred_fake_landmarks: (196,) landmarks drawn on the left panel.
    :param pred_real_landmarks: (196,) landmarks drawn on the middle panel.
    :param landmarks: (196,) reference landmarks (presumably ground truth) drawn on
        the right panel.
    :return: array of shape (256, 256 * 3).
    """
    grey = np.moveaxis(data, 0, 2).mean(axis=2, dtype=np.float32)
    panels = [plot_land(grey, marks) for marks in (pred_fake_landmarks, pred_real_landmarks, landmarks)]
    return np.concatenate(panels, axis=1)