-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmnist.py
77 lines (63 loc) · 2.43 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Code for number recognition of the mnist dataset
with a neural network written from scratch
"""
import numpy as np
import matplotlib.pyplot as plt
from neural import NeuralNetworkClassifier
#=============================================================
# Load the dataset
#=============================================================
# Training split: only the first N rows of the CSV are read, because
# using all 60000 samples is too much.
N = 3000
train_data = np.loadtxt('MNIST_data/mnist_train.csv', max_rows=N, delimiter=',')
# Column 0 is taken as the label, the remaining columns as the input
# features. Inputs are divided by 255 (presumably to map 8-bit pixel
# values into [0, 1]) and transposed so that each column of X is one sample.
X_train, Y_train = train_data[:, 1:].T/255, train_data[:, 0]
# loadtxt parses everything as float; cast the labels to integers in a
# single vectorized step rather than looping over them in Python.
Y_train = Y_train.astype(int)
# Test split: same layout and preprocessing, first M rows only.
M = 1000
test_data = np.loadtxt('MNIST_data/mnist_test.csv', max_rows=M, delimiter=',')
X_test, Y_test = test_data[:, 1:].T/255, test_data[:, 0]
Y_test = Y_test.astype(int)
#=============================================================
# Hyper-parameters and training of the network
#=============================================================
n_epoch = 3000 + 1
lr_rate = 0.01
# NOTE(review): [50, 50] is presumably the list of hidden-layer sizes;
# confirm against NeuralNetworkClassifier in neural.py.
NN = NeuralNetworkClassifier([50, 50], n_epoch, f_act='relu')
result = NN.train(X_train, Y_train, alpha=lr_rate, verbose=True)

# Confusion matrix restricted to the first three quarters of the training
# samples.
# NOTE(review): presumably the last quarter is what train() holds out for
# validation -- verify against the split done inside neural.py.
_n_fit = N - N//4
A = NN.predict(X_train[:, :_n_fit])
NN.confmat(Y_train[:_n_fit], A, plot=True, title='Confusion matrix for train data', k=0)
#plt.savefig("conf_mat_fit_train.pdf")

# Loss histories collected by train(), plotted in the next section.
L_t, L_v = result['train_Loss'], result['valid_Loss']
#=============================================================
# Plot the training and validation loss
#=============================================================
plt.figure(1)
# Epoch axis 1..n_epoch, built once and shared by both curves.
# L_t and L_v are expected to hold one value per epoch.
epochs = np.arange(1, n_epoch + 1)
plt.plot(epochs, L_t, 'b', label='train Loss')
plt.plot(epochs, L_v, 'r', label='validation loss')
# Title spelling fixed (was 'Binay').
plt.title('Binary cross entropy', fontsize=15)
plt.xlabel('epoch', fontsize=15)
plt.ylabel('Loss', fontsize=15)
plt.legend(loc='best')
#plt.savefig("Loss_fit.pdf")
plt.grid()
#=============================================================
# Evaluate the network on the test split
#=============================================================
A = NN.predict(X_test)
# Stored under its own name so that M (the number of test rows loaded
# above) is not overwritten by the confusion matrix.
conf_mat = NN.confmat(Y_test, A, plot=True, title='Confusion matrix for test data', k=3)
#plt.savefig("conf_mat_fit_test.pdf")
acc = NN.accuracy(A, Y_test)
print(f"Accuracy on test set = {acc:.5f}")

# Gallery of the first test samples with predicted and true label.
plt.figure(2, figsize=(16, 10))
# The 5 x 8 grid holds at most 40 images; show fewer if the test split
# is smaller, instead of indexing past the last sample.
n_show = min(40, X_test.shape[1])
for i in range(n_show):
    plt.subplot(5, 8, 1+i)
    # Each column of X_test is one flattened 28x28 image; undo the /255
    # scaling applied at load time before displaying it.
    image = X_test[:, i].reshape((28, 28)) * 255
    plt.title(f'pred={A[i]} label={Y_test[i]}')
    plt.imshow(image)
plt.tight_layout()
#plt.savefig("MNIST.pdf")
plt.show()