-
Notifications
You must be signed in to change notification settings - Fork 82
/
cnn_mnist.f90
66 lines (51 loc) · 1.78 KB
/
cnn_mnist.f90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
!> Train a small convolutional network on MNIST with neural-fortran, reporting
!> validation accuracy after every epoch and test accuracy at the end.
program cnn_mnist

  use nf, only: network, sgd, &
    input, conv2d, maxpool2d, flatten, dense, reshape, &
    load_mnist, label_digits, softmax, relu

  implicit none

  type(network) :: net

  real, allocatable :: training_images(:,:), training_labels(:)
  real, allocatable :: validation_images(:,:), validation_labels(:)
  real, allocatable :: testing_images(:,:), testing_labels(:)

  ! Target matrices produced by label_digits (one column per sample, as
  ! indexed by accuracy below). They do not change between epochs, so they
  ! are computed once instead of inside the training loop.
  real, allocatable :: training_targets(:,:)
  real, allocatable :: validation_targets(:,:)
  real, allocatable :: testing_targets(:,:)

  integer :: n

  integer, parameter :: num_epochs = 10
  integer, parameter :: batch_size = 128
  real, parameter :: learning_rate = 3.

  call load_mnist(training_images, training_labels, &
                  validation_images, validation_labels, &
                  testing_images, testing_labels)

  training_targets = label_digits(training_labels)
  validation_targets = label_digits(validation_labels)
  testing_targets = label_digits(testing_labels)

  ! 784 inputs are reshaped to a 1x28x28 image, passed through two
  ! conv2d/maxpool2d stages, flattened, and mapped to 10 softmax outputs.
  net = network([ &
    input(784), &
    reshape([1, 28, 28]), &
    conv2d(filters=8, kernel_size=3, activation=relu()), &
    maxpool2d(pool_size=2), &
    conv2d(filters=16, kernel_size=3, activation=relu()), &
    maxpool2d(pool_size=2), &
    flatten(), &
    dense(10, activation=softmax()) &
  ])

  call net % print_info()

  epochs: do n = 1, num_epochs

    ! Train for a single epoch per iteration so validation accuracy can be
    ! reported between epochs.
    call net % train( &
      training_images, &
      training_targets, &
      batch_size=batch_size, &
      epochs=1, &
      optimizer=sgd(learning_rate=learning_rate) &
    )

    ! f6.2 leaves room for 100.00 (f5.2 would print asterisks);
    ! i0 avoids overflow if num_epochs is raised to 100 or more.
    print '(a,i0,a,f6.2,a)', 'Epoch ', n, ' done, Accuracy: ', &
      accuracy(net, validation_images, validation_targets) * 100, ' %'

  end do epochs

  print '(a,f6.2,a)', 'Testing accuracy: ', &
    accuracy(net, testing_images, testing_targets) * 100, '%'

contains

  !> Fraction of samples whose predicted class matches the target class.
  !>
  !> The predicted class is the index of the largest network output for a
  !> column of x; the target class is the index of the largest entry of the
  !> corresponding column of y.
  !>
  !> net: network to evaluate (intent(in out), presumably because
  !>      network % predict needs a modifiable object; confirm against nf).
  !> x:   input samples, one per column.
  !> y:   targets, one column per sample (same column count as x).
  !> Returns a value in [0, 1]; returns 0 when x has no samples.
  function accuracy(net, x, y) result(res)
    type(network), intent(in out) :: net
    real, intent(in) :: x(:,:), y(:,:)
    real :: res
    integer :: i, num_samples, num_correct

    num_samples = size(x, dim=2)

    ! An empty data set would otherwise produce 0/0 = NaN below.
    if (num_samples == 0) then
      res = 0.
      return
    end if

    num_correct = 0
    do i = 1, num_samples
      if (maxloc(net % predict(x(:,i)), dim=1) == maxloc(y(:,i), dim=1)) &
        num_correct = num_correct + 1
    end do

    res = real(num_correct) / num_samples
  end function accuracy

end program cnn_mnist