This repository has been archived by the owner on Jul 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathkgemm_nt.hpp
201 lines (163 loc) · 5.26 KB
/
kgemm_nt.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
#ifndef KGEMM_NT_HPP
#define KGEMM_NT_HPP 1
#include "kroncommon.hpp"
// -----------------------
// NotransA and TransB case
// C = alpha*A*transpose(B) + beta *C
//
// Indices are 1-based and are mapped to storage offsets by indx2f()
// (declared outside this file):
//   A is read  at rows 1..mm, columns 1..kk, leading dimension ldA
//   B is read  at rows 1..nn, columns 1..kk, leading dimension ldB
//   C is written at rows 1..mm, columns 1..nn, leading dimension ldC
//
// The output is processed in nb-by-nb tiles (nb = 64).  With USE_GPU,
// the threads of a block (blockDim.x must be a multiple of 32 and
// blockDim.y == blockDim.z == 1) each take every blockDim.x-th entry
// of a tile; without USE_GPU a single caller visits every entry.
//
// beta handling:
//   beta == 1 : C(i,j) is updated with atomicAdd( &C(i,j), alpha*cij )
//   beta == 0 : C(i,j) is overwritten, its old value is never read
//   otherwise : C(i,j) = beta*C(i,j) + alpha*cij  (plain read-modify-write)
// -----------------------
template<typename T>
DEVICE_FUNCTION
void kgemm_nt( int const mm, int const nn, int const kk,
               T const alpha,
               T const * const A_, int const ldA,
               T const * const B_, int const ldB,
               T const beta,
               T * C_, int const ldC)
{
#ifdef USE_LAMBDA
    auto min = []( int const x, int const y) {
        return( (x < y) ? x : y );
    };
    auto max = []( int const x, int const y) {
        return( (x > y) ? x : y );
    };
#else
    // function-like macros; removed again by the #undef lines that
    // follow this function
#define min(x,y) (((x) < (y)) ? (x) : (y) )
#define max(x,y) (((x) > (y)) ? (x) : (y) )
#endif

    // edge length of the square tiles of C handled per outer iteration
    int constexpr nb = 2*32;

#ifdef USE_GPU
    // ---------------------------
    // use matlab 1 based indexing
    // ---------------------------
    int constexpr warpsize = 32;
    int const nthreads = blockDim.x;
    expect( blockDim.y == 1);
    expect( blockDim.z == 1);

    // -----------------------------------------
    // reorganize threads as nx_threads by ny_threads
    // -----------------------------------------
    int const nx_threads = warpsize;
    int const ny_threads = max(1,nthreads/nx_threads);
    expect( (nthreads % warpsize) == 0);

    // the ix_/iy_ values are only checked by the expect() calls below;
    // the work loop uses the flat ij_start/ij_size pair
    int const ix_start = ( threadIdx.x % nx_threads ) + 1;
    int const iy_start = (threadIdx.x/nx_threads) + 1;
    int const ix_size = nx_threads;
    int const iy_size = ny_threads;

    int const ij_start = threadIdx.x + 1;
    int const ij_size = nthreads;
#else
    // single caller: start at the first entry and step by one
    int const ix_start = 1;
    int const ix_size = 1;
    int const iy_start = 1;
    int const iy_size = 1;
    int const ij_start = 1;
    int const ij_size = 1;
#endif

    expect( ix_start >= 1);
    expect( iy_start >= 1);
    expect( ix_size >= 1 );
    expect( iy_size >= 1 );

    // ------------------------------------
    // commonly mm is large, but kk, nn are small
    // ------------------------------------

    // 1-based (row,column) accessors for the three matrices
#ifdef USE_LAMBDA
    auto A = [&] (int const ia,
                  int const ja) -> T const & {
        return( A_[ indx2f(ia,ja,ldA) ] );
    };
    auto B = [&] (int const ib,
                  int const jb) -> T const & {
        return( B_[ indx2f(ib,jb,ldB) ] );
    };
    auto C = [&] (int const ic,
                  int const jc) -> T& {
        return( C_[ indx2f(ic,jc,ldC) ] );
    };
#else
#define A(ia,ja) A_[indx2f(ia,ja,ldA)]
#define B(ib,jb) B_[indx2f(ib,jb,ldB)]
#define C(ic,jc) C_[indx2f(ic,jc,ldC)]
#endif

    for(int jstart=1; jstart <= nn; jstart += nb) {
        int const jend = min(nn, jstart + nb-1);
        int const jsize = jend - jstart + 1;

        for(int istart=1; istart <= mm; istart += nb) {
            int const iend = min( mm, istart + nb-1);
            int const isize = iend - istart + 1;

            // ---------------------------
            // perform matrix calculations
            // ---------------------------
            // for(int j=iy_start; j <= jsize; j += iy_size)
            // for(int i=ix_start; i <= isize; i += ix_size) {

            // the pointer path below steps ldA (resp. ldB) elements to
            // move from column k to column k+1 of A (resp. B)
            // NOTE(review): this matches indx2f only if it is a
            // column-major mapping - confirm in kroncommon.hpp
            auto const inc_A = ldA;
            auto const inc_B = ldB;

            // ij0 is a 0-based flat index over the isize-by-jsize tile,
            // with the row index i varying fastest
            for(int ij0=ij_start-1; ij0 < (isize*jsize); ij0 += ij_size) {
                int const i = (ij0 % isize) + 1;
                int const j = (ij0 - (i-1))/isize + 1;
                T cij = 0;

                // compile-time switch between the pointer-stepping dot
                // product and the plain indexed reference version
                bool constexpr use_pointer = true;
                if (use_pointer) {
                    int k = 1;
                    int ia = (istart-1) + i;
                    int ib = (jstart-1) + j;
                    T const * Ap = &(A(ia,k));
                    T const * Bp = &(B(ib,k));

                    // Dot product with a literal trip count, so each
                    // case label gets a loop of known length.  The macro
                    // parameter is deliberately spelled kk: the
                    // definition is kept token-identical so that a
                    // repeated identical definition elsewhere remains a
                    // benign redefinition.
#define case_code(kk) { \
    for(k=0; k < kk; k++) { \
        cij += (*Ap) * (*Bp); \
        Ap += inc_A; \
        Bp += inc_B; \
    }; \
    break; \
}
                    switch(kk) {
                    case 1: case_code(1)
                    case 2: case_code(2)
                    case 3: case_code(3)
                    case 4: case_code(4)
                    case 5: case_code(5)
                    case 6: case_code(6)
                    case 7: case_code(7)
                    case 8: case_code(8)
                    default:
#ifdef USE_GPU
#pragma unroll
#endif
                        for(k=0; k < kk; k++) {
                            cij += (*Ap) * (*Bp);
                            Ap += inc_A;
                            Bp += inc_B;
                        };
                    };
                    // do not let the helper macro escape this header
#undef case_code
                }
                else {
                    for(int k=1; k <= kk; k++) {
                        cij += A( (istart-1) + i, k) *
                               B( (jstart-1) + j, k);
                    };
                };

                // ------------------
                // store results to C
                // ------------------
                int const ic = (istart-1) + i;
                int const jc = (jstart-1) + j;
                if (beta == 1) {
                    // NOTE(review): atomicAdd is not declared in this
                    // file; the atomic update presumably lets several
                    // concurrent calls accumulate into the same C -
                    // confirm with the callers
                    atomicAdd( &(C(ic,jc)), alpha*cij );
                }
                else if (beta == 0) {
                    // old contents of C are not read
                    C(ic,jc) = alpha * cij;
                }
                else {
                    C(ic,jc) = beta * C(ic,jc) + alpha*cij;
                };
            };
        }; // end istart
    }; // end jstart
}
#undef min
#undef max
#undef A
#undef B
#undef C
#endif