6.py
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(1)


def generate_data(sample_size):
    """Generate training data.

    Since
        f(x) = w^{T} x + b
    can be written as
        f(x) = (w^{T}, b) (x^{T}, 1)^{T},
    we return (x^{T}, 1)^{T} instead of x to keep the SVM
    implementation simple.

    :param sample_size: number of data points in the sample
    :return: a tuple of data points and labels
    """
    x = np.random.normal(size=(sample_size, 3))
    x[:, 2] = 1.  # homogeneous coordinate that absorbs the bias term b
    x[:sample_size // 2, 0] -= 5.
    x[sample_size // 2:, 0] += 5.
    y = np.concatenate([np.ones(sample_size // 2, dtype=np.int64),
                        -np.ones(sample_size // 2, dtype=np.int64)])
    # Shift a few points and flip their labels so the data is not
    # linearly separable.
    x[:3, 1] -= 5.
    y[:3] = -1
    x[-3:, 1] += 5.
    y[-3:] = 1
    return x, y
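

# A minimal illustration of the homogeneous-coordinate trick described in
# the docstring above (the _w_demo/_x_demo names are hypothetical, added
# only as a sketch): with a trailing 1 appended to x, the last component
# of w plays the role of the bias b.
_w_demo = np.array([2., -1., .5])  # (w^{T}, b) with b = .5
_x_demo = np.array([3., 4., 1.])   # (x^{T}, 1)
assert np.isclose(_w_demo.dot(_x_demo), _w_demo[:2].dot(_x_demo[:2]) + .5)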


def svm(x, y, l, lr):
    """Linear SVM trained by stochastic subgradient descent.

    f_w(x) = w^{T} (x^{T}, 1)^{T}

    :param x: data points
    :param y: labels
    :param l: regularization parameter
    :param lr: learning rate
    :return: three-dimensional vector w
    """
    w = np.zeros(3)
    prev_w = w.copy()
    for i in range(10 ** 4):
        # Pick a single training example at random.
        j = np.random.randint(0, len(y))
        xi, yi = x[j], y[j]
        # Subgradient of the hinge loss max(0, 1 - y w^{T} x).
        if 1 - yi * w.dot(xi) >= 0:
            tmp = -yi * xi
        else:
            tmp = 0.
        # Step along the subgradient of the regularized objective,
        # hinge loss + l * ||w||^2.
        w = w - lr * (2 * l * w + tmp)
        # Stop once successive iterates barely move.
        if np.linalg.norm(w - prev_w) < 1e-3:
            break
        prev_w = w.copy()
    return w
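

# A small sanity-check helper (hypothetical, added as a sketch): the
# classifier predicts sign(f_w(x)) = sign(w^{T} x), so training accuracy
# is the fraction of points whose predicted sign matches the label.
def training_accuracy(x, y, w):
    """Fraction of points for which sign(w^{T} x) equals the label."""
    return np.mean(np.sign(x.dot(w)) == y)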


def visualize(x, y, w):
    plt.clf()
    plt.xlim(-10, 10)
    plt.ylim(-10, 10)
    plt.scatter(x[y == 1, 0], x[y == 1, 1])
    plt.scatter(x[y == -1, 0], x[y == -1, 1])
    # Decision boundary w[0] * x0 + w[1] * x1 + w[2] = 0,
    # solved for x1 over the plotted range of x0.
    plt.plot([-10, 10], -(w[2] + np.array([-10, 10]) * w[0]) / w[1])
    plt.savefig('lecture6-h2.png')


x, y = generate_data(200)
w = svm(x, y, l=.1, lr=1.)
visualize(x, y, w)
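
# Example usage of the hypothetical training_accuracy helper defined above;
# a rough way to confirm that the learned w separates the bulk of the data.
print('w =', w)
print('training accuracy =', training_accuracy(x, y, w))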