-
Notifications
You must be signed in to change notification settings - Fork 100
/
Copy pathlibSVM.py
108 lines (102 loc) · 3.32 KB
/
libSVM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import sys
from numpy import *
from svm import *
from os import listdir
from plattSMO import PlattSMO
import pickle
class LibSVM:
    """Multi-class SVM built from one-vs-one binary ``PlattSMO`` classifiers.

    For ``k`` distinct labels, ``k*(k-1)/2`` binary classifiers are trained,
    one per unordered pair of classes (i < j). Prediction is by majority vote
    over those classifiers.
    """

    def __init__(self, data=None, label=None, C=0, toler=0, maxIter=0, **kernelargs):
        """Group the training samples by label; no training happens here.

        :param data: sequence of feature rows (one row per sample).
        :param label: class label for each row of ``data`` (parallel sequence).
        :param C: passed to ``PlattSMO``.
        :param toler: passed to ``PlattSMO``.
        :param maxIter: passed to ``PlattSMO``.
        :param kernelargs: forwarded verbatim to ``PlattSMO``
            (``main`` passes e.g. ``name='rbf', theta=20``).
        """
        # None instead of a shared mutable default list.
        data = [] if data is None else data
        label = [] if label is None else label
        # Sorted distinct labels; a label's position here is its class index.
        self.classlabel = unique(label)
        self.classNum = len(self.classlabel)
        # Number of pairwise classifiers; floor division keeps this an int on Py3.
        self.classfyNum = self.classNum * (self.classNum - 1) // 2
        self.classfy = []
        self.dataSet = {}
        self.kernelargs = kernelargs
        self.C = C
        self.toler = toler
        self.maxIter = maxIter
        for i in range(shape(data)[0]):
            # Bucket a slice of each sample row under its label.
            self.dataSet.setdefault(label[i], []).append(data[i][:])

    def train(self):
        """Fit one ``PlattSMO`` classifier per class pair (i, j), i < j.

        Samples of class ``i`` are labelled +1.0 and samples of class ``j``
        -1.0. Classifiers are appended to ``self.classfy`` in (i, j)
        lexicographic order, which ``predict`` relies on. The grouped training
        data is released afterwards, so ``train`` can run only once.

        :raises RuntimeError: if called after the training data was released.
        """
        if self.dataSet is None:
            raise RuntimeError("LibSVM.train() can only be called once; "
                               "the training data has been released")
        num = self.classNum
        self.classfy = []
        for i in range(num):
            positives = self.dataSet[self.classlabel[i]]
            for j in range(i + 1, num):
                negatives = self.dataSet[self.classlabel[j]]
                data = list(positives) + list(negatives)
                label = [1.0] * len(positives) + [-1.0] * len(negatives)
                svm = PlattSMO(array(data), array(label), self.C, self.toler,
                               self.maxIter, **self.kernelargs)
                svm.smoP()
                self.classfy.append(svm)
        # Drop the raw samples; only the fitted classifiers are kept (and pickled).
        self.dataSet = None

    def predict(self, data, label):
        """Classify each row of ``data`` by one-vs-one majority vote.

        Each misclassified sample is printed as ``"<true> <predicted>"``,
        followed by the overall error rate against ``label`` (skipped for
        empty input).

        :param data: sequence of feature rows.
        :param label: true label for each row, used only for error reporting.
        :returns: list of predicted labels (values taken from ``self.classlabel``).
        """
        m = shape(data)[0]
        num = self.classNum
        predicted = []
        errors = 0
        for n in range(m):
            votes = [0] * num
            index = 0
            # Must iterate pairs in the same (i, j) order that train() used.
            for i in range(num):
                for j in range(i + 1, num):
                    t = self.classfy[index].predict([data[n]])[0]
                    index += 1
                    if t > 0.0:
                        votes[i] += 1
                    else:
                        votes[j] += 1
            # Map the winning class index back to its real label; ties go to
            # the lowest index because list.index returns the first maximum.
            winner = self.classlabel[votes.index(max(votes))]
            predicted.append(winner)
            if winner != label[n]:
                errors += 1
                sys.stdout.write("%s %s\n" % (label[n], winner))
        if m:
            sys.stdout.write("error rate: %s\n" % (float(errors) / m))
        return predicted

    def save(self, filename):
        """Pickle this model (pickle protocol 2) to ``filename``."""
        with open(filename, 'wb') as fw:
            pickle.dump(self, fw, 2)

    @staticmethod
    def load(filename):
        """Load a model previously written by ``save``.

        Security: unpickling can execute arbitrary code, so only load files
        from a trusted source.
        """
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
def loadImage(dir, maps=None):
    """Load every text-encoded image file in directory ``dir``.

    Each file holds a grid of single-character digits, one image row per
    line. It is flattened row by row into a list of floats whose row width
    is taken from the first non-blank line. Blank lines are ignored. The
    label is the part of the file name before the first ``_``. If ``maps``
    is given, that string is looked up in it first. The result is then
    converted with ``float``.

    :param dir: directory containing the image files.
    :param maps: optional mapping from file-name prefix to label value.
    :returns: ``(data, label)``, parallel lists with one entry per file,
        ordered by sorted file name.
    """
    import os

    data = []
    label = []
    # listdir order is arbitrary; sort so results are reproducible.
    for name in sorted(os.listdir(dir)):
        raw = name.split('_')[0]
        if maps is not None:
            raw = maps[raw]
        with open(os.path.join(dir, name)) as f:
            rows = [ln.rstrip('\r\n') for ln in f]
        # Skip blank lines (e.g. a trailing newline) instead of crashing on float('\n').
        rows = [r for r in rows if r.strip()]
        width = len(rows[0].strip()) if rows else 0
        data.append([float(r[j]) for r in rows for j in range(width)])
        label.append(float(raw))
    return data, label
def main(train=False, model_path="svm.txt"):
    """Optionally train the digit classifier, then evaluate it.

    :param train: when true, fit an RBF one-vs-one model on the images in
        ``trainingDigits`` and pickle it to ``model_path`` first. This
        replaces the training code that was previously disabled by being
        turned into a docstring.
    :param model_path: file the model is saved to and loaded from.

    The model at ``model_path`` is then loaded and evaluated on the images
    in ``testDigits``.
    """
    if train:
        data, label = loadImage('trainingDigits')
        svm = LibSVM(data, label, 200, 0.0001, 10000, name='rbf', theta=20)
        svm.train()
        svm.save(model_path)
    svm = LibSVM.load(model_path)
    test, testlabel = loadImage('testDigits')
    svm.predict(test, testlabel)


if __name__ == "__main__":
    # Pass --train to refit and overwrite the saved model before evaluating.
    sys.exit(main(train="--train" in sys.argv[1:]))