forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
accuracy_op.cu
79 lines (70 loc) · 2.23 KB
/
accuracy_op.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/accuracy_op.h"
#include "caffe2/utils/GpuAtomics.cuh"
#include "caffe2/utils/math.h"
#include "caffe2/utils/cub_namespace.cuh"
#include <cub/block/block_reduce.cuh>
namespace caffe2 {
namespace {
// Counts the rows whose labelled column ranks within the top `top_k` scores
// of that row and adds the count to *accuracy.
//
// Launch layout: each block handles one row at a time, starting at blockIdx.x
// and advancing by gridDim.x; the threads of the block split that row's D
// columns, starting at threadIdx.x and advancing by blockDim.x. blockDim.x
// must equal CAFFE_CUDA_NUM_THREADS, the block size cub::BlockReduce is
// instantiated with below.
//
// Args:
//   N:         number of rows in Xdata and entries in labelData.
//   D:         number of columns per row.
//   top_k:     a row counts as correct when at most top_k columns rank at or
//              above the label column (the label column itself included).
//   Xdata:     scores, indexed as Xdata[row * D + col].
//   labelData: per-row column index. It is read without validation, so every
//              entry must lie in [0, D) or the kernel reads out of bounds.
//   accuracy:  single float accumulator; each block atomically adds its
//              count to it, so it must hold 0 before the launch.
__global__ void AccuracyKernel(
    const int N,
    const int D,
    const int top_k,
    const float* Xdata,
    const int* labelData,
    float* accuracy) {
  typedef cub::BlockReduce<int, CAFFE_CUDA_NUM_THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int correct = 0;
  for (int row = blockIdx.x; row < N; row += gridDim.x) {
    // Form the row base with 64-bit arithmetic: row * D evaluated in int
    // overflows as soon as the matrix holds more than INT_MAX elements.
    const float* row_data = Xdata + static_cast<int64_t>(row) * D;
    const int label = labelData[row];
    const float label_pred = row_data[label];
    // ngt = number of columns ranked at or above the label column. Ties go to
    // the lower column index, so the label column always counts itself.
    int ngt = 0;
    for (int col = threadIdx.x; col < D; col += blockDim.x) {
      const float pred = row_data[col];
      if (pred > label_pred || (pred == label_pred && col <= label)) {
        ++ngt;
      }
    }
    ngt = BlockReduce(temp_storage).Sum(ngt);
    // The reduced sum is only defined on thread 0, and thread 0 is also the
    // only thread that publishes `correct`, so no other thread needs to
    // evaluate the comparison.
    if (threadIdx.x == 0 && ngt <= top_k) {
      ++correct;
    }
    // temp_storage is reused by the next row's reduction; wait until every
    // thread is done with it.
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    gpu_atomic_add(accuracy, static_cast<float>(correct));
  }
}
// Divides the single float stored at `accuracy` by N, in place.
// The body uses no thread or block index, so it is meant to be launched with
// exactly one thread; additional threads would race on the same location.
__global__ void AccuracyDivideKernel(const int N, float* accuracy) {
  const float divisor = static_cast<float>(N);
  accuracy[0] = accuracy[0] / divisor;
}
} // namespace
// CUDA implementation of the Accuracy operator for float predictions.
//
// Inputs:
//   PREDICTION: 2-D float tensor of shape (N, D).
//   LABEL:      1-D int tensor of length N.
// Output 0: scalar float holding (number of rows counted as correct by
//           AccuracyKernel with top_k_) / N.
//
// Both kernels are enqueued on context_.cuda_stream(); this function does not
// synchronize with the stream. Returns true once the work is enqueued; shape
// mismatches and launch failures are reported through the CAFFE_ENFORCE /
// C10_CUDA_KERNEL_LAUNCH_CHECK macros.
template <>
bool AccuracyOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(PREDICTION);
  auto& label = Input(LABEL);
  CAFFE_ENFORCE_EQ(X.dim(), 2);
  int N = X.dim32(0);
  int D = X.dim32(1);
  CAFFE_ENFORCE_EQ(label.dim(), 1);
  CAFFE_ENFORCE_EQ(label.dim32(0), N);
  auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());
  float* Ydata = Y->template mutable_data<float>();
  // The counting kernel accumulates into Ydata with atomic adds, so the
  // accumulator has to start at zero.
  math::Set<float, CUDAContext>(1, 0, Ydata, &context_);
  // An empty batch would request a grid of zero blocks, which is an invalid
  // launch configuration. There is nothing to count in that case, so skip the
  // launch; the division below then evaluates 0 / 0, i.e. NaN.
  // NOTE(review): presumably this matches the CPU operator's result for
  // N == 0 -- confirm.
  if (N > 0) {
    AccuracyKernel<<<
        std::min(CAFFE_MAXIMUM_NUM_BLOCKS, N),
        CAFFE_CUDA_NUM_THREADS,
        0,
        context_.cuda_stream()>>>(
        N, D, top_k_, X.data<float>(), label.data<int>(), Ydata);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  // Normalize the count on the device with a one-thread kernel on the same
  // stream. It is ordered after the counting kernel and avoids copying the
  // scalar back to the host.
  AccuracyDivideKernel<<<1, 1, 0, context_.cuda_stream()>>>(
      N, Ydata);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return true;
}
// Registers AccuracyOp<float, CUDAContext> as the CUDA implementation of the
// operator named "Accuracy".
REGISTER_CUDA_OPERATOR(Accuracy, AccuracyOp<float, CUDAContext>);
} // namespace caffe2