xfc_gemm.cpp
#include <torch/extension.h>

// Forward declaration of the CUDA launcher defined in the companion .cu file.
void xfc_gemm_cuda(
    torch::Tensor mat_in1,
    torch::Tensor mat_in2,
    torch::Tensor mat_out,
    float alpha,
    float beta,
    bool apply_sigmoid);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// Validates that all operands are 2D tensors, then forwards to the CUDA
// implementation. apply_sigmoid requests an element-wise sigmoid on the output.
void xfc_gemm(
    torch::Tensor mat_in1,
    torch::Tensor mat_in2,
    torch::Tensor mat_out,
    float alpha,
    float beta,
    bool apply_sigmoid) {
  //CHECK_INPUT(mat_in1);
  //CHECK_INPUT(mat_in2);
  //CHECK_INPUT(mat_out);
  AT_ASSERTM(mat_in1.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(mat_in2.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(mat_out.dim() == 2, "expected 2D tensor");
  //AT_ASSERTM(mat_in1.scalar_type() == at::ScalarType::Half, "Only HALF is supported");
  //AT_ASSERTM(mat_in2.scalar_type() == at::ScalarType::Half, "Only HALF is supported");
  xfc_gemm_cuda(mat_in1, mat_in2, mat_out, alpha, beta, apply_sigmoid);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("xfc_gemm", &xfc_gemm, "Optimized gemm.");
}
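
For context, a minimal sketch of how this binding might be JIT-compiled and called from Python with torch.utils.cpp_extension.load. The CUDA source file name (xfc_gemm_kernel.cu), the extension module name, the half-precision dtype, and the assumption that alpha/beta follow the usual GEMM convention (out = alpha * A @ B + beta * out) are all assumptions for illustration, not taken from this file.

# Usage sketch; file names, dtype, and GEMM semantics are assumed, not from the source above.
import torch
from torch.utils.cpp_extension import load

xfc = load(
    name="xfc_gemm",
    sources=["xfc_gemm.cpp", "xfc_gemm_kernel.cu"],  # assumed companion .cu file name
    verbose=True,
)

a = torch.randn(128, 256, device="cuda", dtype=torch.half)
b = torch.randn(256, 64, device="cuda", dtype=torch.half)
out = torch.zeros(128, 64, device="cuda", dtype=torch.half)

# mat_out is written in place; apply_sigmoid=False skips the fused activation.
xfc.xfc_gemm(a, b, out, 1.0, 0.0, False)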