-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify_mnist.c
160 lines (145 loc) · 5.13 KB
/
classify_mnist.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>
#include "matrix.h"
#include "maxheap.h"
/*
 * Find the k training rows nearest to a test image.
 *
 * Every training row is scored with matrix_row_dist_sq against the test image.
 * A heap bounded to k entries keeps the best candidates seen so far; its root
 * (array[0]) is the entry evicted whenever a strictly smaller distance arrives.
 *
 * train_images_ptr : training images, one per row
 * test_image_ptr   : the row to compare against
 * k                : number of neighbors to report
 * knearest         : output, receives k training row indices; slot 0 holds the
 *                    closest neighbor and slot k-1 the farthest
 *
 * NOTE(review): the drain loop always pops k entries, so this presumably relies
 * on the training set having at least k rows - confirm with callers.
 */
void find_knearest (matrix_type* train_images_ptr, matrix_row_type* test_image_ptr,
                    int k, int* knearest) {
    maxheap_type nearest;
    matrix_row_type candidate_row;
    key_value_type candidate;
    int row;
    int slot;

    maxheap_init (&nearest,k);

    for (row = 0; row < train_images_ptr->num_rows; row++) {
        matrix_get_row(train_images_ptr,&candidate_row,row);
        candidate.key = row;
        candidate.value = matrix_row_dist_sq(test_image_ptr,&candidate_row);

        /* still filling the heap: every candidate is accepted */
        if (nearest.size < k) {
            maxheap_insert(&nearest,candidate);
            continue;
        }

        /* heap is full: only a strictly closer candidate displaces the root */
        if (candidate.value < nearest.array[0].value) {
            maxheap_remove_root(&nearest);
            maxheap_insert(&nearest,candidate);
        }
    }

    /* successive roots are written from the back so knearest ends up closest first */
    slot = k;
    while (slot-- > 0) {
        knearest[slot] = nearest.array[0].key;
        maxheap_remove_root(&nearest);
    }

    /* release the heap storage */
    maxheap_deinit(&nearest);
}
/*
 * Classify a test image given the indices of its k nearest training neighbors.
 *
 * Each of the first k entries of knearest votes for the label of the training
 * row it indexes. The label holding strictly more votes than every other label
 * wins. If two or more labels tie for the most votes, k is reduced by 1 (which
 * drops the last entry of knearest, expected to be the farthest neighbor as
 * produced by find_knearest) and the vote is retaken. A single neighbor cannot
 * tie with anything, so the process always ends by k = 1.
 *
 * train_labels_ptr : training labels, read as data_ptr[row index]
 * num_classes      : number of class labels; the pairwise vote below does not
 *                    need it, it is kept for interface compatibility
 * k                : number of entries of knearest to consider
 * knearest         : indices of the nearest training rows
 *
 * returns the winning label, or -100 when k < 1 (there is nobody to vote)
 */
int classify (matrix_type* train_labels_ptr, int num_classes, int k, int* knearest){
    const int no_prediction = -100;
    (void)num_classes;

    /* labels are read in place, so there is no scratch buffer to allocate or free */
    for (; k >= 1; k--) {
        int best_label = no_prediction;
        int best_votes = 0;
        int tied = 0;

        for (int i = 0; i < k; i++) {
            int label = train_labels_ptr->data_ptr[knearest[i]];

            /* count how many of the k neighbors share this neighbor's label */
            int votes = 0;
            for (int j = 0; j < k; j++) {
                int other = train_labels_ptr->data_ptr[knearest[j]];
                if (other == label) {
                    votes++;
                }
            }

            if (votes > best_votes) {
                /* new outright leader: any earlier tie was at a lower count */
                best_votes = votes;
                best_label = label;
                tied = 0;
            } else if (votes == best_votes && label != best_label) {
                /* a different label matches the leader's vote count */
                tied = 1;
            }
        }

        if (!tied) {
            return best_label;
        }
        /* tie for first place: drop the last neighbor and vote again */
    }

    return no_prediction;
}
/* parse text as a complete base-10 integer */
/* returns 1 and stores the value on success, 0 if text is empty or has trailing characters */
/* out-of-range input saturates to LONG_MAX/LONG_MIN, which the caller's range checks handle */
static int parse_long_arg (const char* text, long* value_ptr) {
    char* end = NULL;
    *value_ptr = strtol(text,&end,10);
    return end != text && *end == '\0';
}

/*
 * Classify a range of MNIST test images with k nearest neighbors.
 *
 * usage : <program> k start_test num_to_test
 *   k           : number of neighbors, 1 <= k <= number of training images
 *   start_test  : index of the first test image, 0 <= start_test < number of test images
 *   num_to_test : how many test images to classify; clamped so the range stays
 *                 inside the test set, negative values classify nothing
 *
 * For each test image one line is printed with its index, its true label, the
 * labels of the k nearest training images and the predicted label.
 *
 * returns 0 on success, 1 on a bad command line or allocation failure
 */
int main (int argc, char** argv) {
    /* the MNIST dataset has 10 class labels, 60000 training and 10000 test images */
    const int num_classes = 10;
    const int num_train = 60000;
    const int num_test = 10000;

    /* get k, start test image, and num_to_test from the command line */
    if (argc != 4) {
        printf ("Command usage : %s %s %s %s\n",argv[0],"k","start_test","num_to_test");
        return 1;
    }
    long k_arg;
    long start_arg;
    long count_arg;
    if (!parse_long_arg(argv[1],&k_arg) || !parse_long_arg(argv[2],&start_arg)
        || !parse_long_arg(argv[3],&count_arg)) {
        fprintf (stderr,"k, start_test and num_to_test must be integers\n");
        return 1;
    }

    /* find_knearest extracts exactly k neighbors, so k cannot exceed the training set */
    if (k_arg < 1 || k_arg > num_train) {
        fprintf (stderr,"k must be between 1 and %d\n",num_train);
        return 1;
    }
    /* start_test is used directly as a row index into the test set */
    if (start_arg < 0 || start_arg >= num_test) {
        fprintf (stderr,"start_test must be between 0 and %d\n",num_test-1);
        return 1;
    }
    /* clamp in long arithmetic so an oversized request cannot overflow */
    if (count_arg > num_test - start_arg) count_arg = num_test - start_arg;
    if (count_arg < 0) count_arg = 0;

    int k = (int)k_arg;
    int start_test = (int)start_arg;
    int num_to_test = (int)count_arg;

    /* buffer for the k nearest training indices, on the heap since k may be large */
    int* knearest = malloc((size_t)k * sizeof *knearest);
    if (knearest == NULL) {
        fprintf (stderr,"out of memory allocating %d neighbor indices\n",k);
        return 1;
    }

    /* read in the mnist training set of 60000 images and labels */
    /* NOTE(review): the last argument (16 / 8) is presumably the file header size in bytes - confirm against matrix_read_bin */
    matrix_type train_images, train_labels;
    matrix_init (&train_images,num_train,784);
    matrix_read_bin(&train_images,"train-images-idx3-ubyte",16);
    matrix_init (&train_labels,num_train,1);
    matrix_read_bin(&train_labels,"train-labels-idx1-ubyte",8);

    /* read in the mnist test set of 10000 images and labels */
    matrix_type test_images, test_labels;
    matrix_init (&test_images,num_test,784);
    matrix_read_bin(&test_images,"t10k-images-idx3-ubyte",16);
    matrix_init (&test_labels,num_test,1);
    matrix_read_bin(&test_labels,"t10k-labels-idx1-ubyte",8);

    /* find the k training images nearest each test image and classify it */
    matrix_row_type test_image;
    for (int i = start_test; i < start_test+num_to_test; i++) {
        matrix_get_row(&test_images,&test_image,i);
        find_knearest (&train_images,&test_image,k,knearest);
        int predicted_label = classify (&train_labels,num_classes,k,knearest);
        printf ("test index : %d, test label : %d, ",
                i,test_labels.data_ptr[i]);
        printf ("training labels : ");
        for (int j = 0; j < k; j++) {
            printf ("%d ",train_labels.data_ptr[knearest[j]]);
        }
        printf (", predicted label : %d",predicted_label);
        printf ("\n");
    }

    /* free up the neighbor buffer and the training and test data sets */
    free(knearest);
    matrix_deinit(&train_images);
    matrix_deinit(&test_images);
    matrix_deinit(&train_labels);
    matrix_deinit(&test_labels);
    return 0;
}