-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathkmeans_for_anchors.py
158 lines (133 loc) · 6.28 KB
/
kmeans_for_anchors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#-------------------------------------------------------------------------------------------------------#
# kmeans虽然会对数据集中的框进行聚类,但是很多数据集由于框的大小相近,聚类出来的9个框相差不大,
# 这样的框反而不利于模型的训练。因为不同的特征层适合不同大小的先验框,shape越小的特征层适合越大的先验框
# 原始网络的先验框已经按大中小比例分配好了,不进行聚类也会有非常好的效果。
#-------------------------------------------------------------------------------------------------------#
import glob
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
def cas_ratio(box,cluster):
ratios_of_box_cluster = box / cluster
ratios_of_cluster_box = cluster / box
ratios = np.concatenate([ratios_of_box_cluster, ratios_of_cluster_box], axis = -1)
return np.max(ratios, -1)
def avg_ratio(box,cluster):
return np.mean([np.min(cas_ratio(box[i],cluster)) for i in range(box.shape[0])])
def kmeans(box,k):
#-------------------------------------------------------------#
# 取出一共有多少框
#-------------------------------------------------------------#
row = box.shape[0]
#-------------------------------------------------------------#
# 每个框各个点的位置
#-------------------------------------------------------------#
distance = np.empty((row,k))
#-------------------------------------------------------------#
# 最后的聚类位置
#-------------------------------------------------------------#
last_clu = np.zeros((row,))
np.random.seed()
#-------------------------------------------------------------#
# 随机选5个当聚类中心
#-------------------------------------------------------------#
cluster = box[np.random.choice(row,k,replace = False)]
iter = 0
while True:
#-------------------------------------------------------------#
# 计算当前框和先验框的宽高比例
#-------------------------------------------------------------#
for i in range(row):
distance[i] = cas_ratio(box[i],cluster)
#-------------------------------------------------------------#
# 取出最小点
#-------------------------------------------------------------#
near = np.argmin(distance,axis=1)
if (last_clu == near).all():
break
#-------------------------------------------------------------#
# 求每一个类的中位点
#-------------------------------------------------------------#
for j in range(k):
cluster[j] = np.median(
box[near == j],axis=0)
last_clu = near
if iter % 5 == 0:
print('iter: {:d}. avg_ratio:{:.2f}'.format(iter, avg_ratio(box,cluster)))
iter += 1
return cluster, near
def load_data(path):
data = []
#-------------------------------------------------------------#
# 对于每一个xml都寻找box
#-------------------------------------------------------------#
for xml_file in tqdm(glob.glob('{}/*xml'.format(path))):
tree = ET.parse(xml_file)
height = int(tree.findtext('./size/height'))
width = int(tree.findtext('./size/width'))
if height<=0 or width<=0:
continue
#-------------------------------------------------------------#
# 对于每一个目标都获得它的宽高
#-------------------------------------------------------------#
for obj in tree.iter('object'):
xmin = int(float(obj.findtext('bndbox/xmin'))) / width
ymin = int(float(obj.findtext('bndbox/ymin'))) / height
xmax = int(float(obj.findtext('bndbox/xmax'))) / width
ymax = int(float(obj.findtext('bndbox/ymax'))) / height
xmin = np.float64(xmin)
ymin = np.float64(ymin)
xmax = np.float64(xmax)
ymax = np.float64(ymax)
# 得到宽高
data.append([xmax-xmin,ymax-ymin])
return np.array(data)
if __name__ == '__main__':
np.random.seed(0)
#-------------------------------------------------------------#
# 运行该程序会计算'./VOCdevkit/VOC2007/Annotations'的xml
# 会生成yolo_anchors.txt
#-------------------------------------------------------------#
input_shape = [640, 640]
anchors_num = 9
#-------------------------------------------------------------#
# 载入数据集,可以使用VOC的xml
#-------------------------------------------------------------#
path = 'VOCdevkit/VOC2007/Annotations'
#-------------------------------------------------------------#
# 载入所有的xml
# 存储格式为转化为比例后的width,height
#-------------------------------------------------------------#
print('Load xmls.')
data = load_data(path)
print('Load xmls done.')
#-------------------------------------------------------------#
# 使用k聚类算法
#-------------------------------------------------------------#
print('K-means boxes.')
cluster, near = kmeans(data, anchors_num)
print('K-means boxes done.')
data = data * np.array([input_shape[1], input_shape[0]])
cluster = cluster * np.array([input_shape[1], input_shape[0]])
#-------------------------------------------------------------#
# 绘图
#-------------------------------------------------------------#
for j in range(anchors_num):
plt.scatter(data[near == j][:,0], data[near == j][:,1])
plt.scatter(cluster[j][0], cluster[j][1], marker='x', c='black')
plt.savefig("kmeans_for_anchors.jpg")
plt.show()
print('Save kmeans_for_anchors.jpg in root dir.')
cluster = cluster[np.argsort(cluster[:, 0] * cluster[:, 1])]
print('avg_ratio:{:.2f}'.format(avg_ratio(data, cluster)))
print(cluster)
f = open("yolo_anchors.txt", 'w')
row = np.shape(cluster)[0]
for i in range(row):
if i == 0:
x_y = "%d,%d" % (cluster[i][0], cluster[i][1])
else:
x_y = ", %d,%d" % (cluster[i][0], cluster[i][1])
f.write(x_y)
f.close()