-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API -part #58092
Changes from 22 commits
813c09d
42b3972
198d875
6ab394d
78d8fec
5d627e1
994d832
7587891
aac04c8
99aef13
d1e942d
af86d96
db05b52
126e5a5
f5f51b8
0fd2774
d28481d
899d1ff
04c80b2
98b392d
60f0e92
3b1e7e4
d21eea0
28e0e6a
6214243
9d50a52
c74106d
fb5ad3c
b4e8edd
0ecabc8
809fb46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -343,6 +343,26 @@ | |
backend : x | ||
inplace: (x -> out) | ||
|
||
- op : bitwise_left_shift_arithmetic | ||
args : (Tensor x, Tensor y) | ||
output : Tensor(out) | ||
infer_meta : | ||
func : ElementwiseInferMeta | ||
kernel : | ||
func : bitwise_left_shift_arithmetic | ||
backend : x | ||
inplace: (x -> out) | ||
|
||
- op : bitwise_left_shift_logic | ||
args : (Tensor x, Tensor y) | ||
output : Tensor(out) | ||
infer_meta : | ||
func : ElementwiseInferMeta | ||
kernel : | ||
func : bitwise_left_shift_logic | ||
backend : x | ||
inplace: (x -> out) | ||
|
||
- op : bitwise_not | ||
args : (Tensor x) | ||
output : Tensor(out) | ||
|
@@ -364,6 +384,26 @@ | |
backend : x | ||
inplace: (x -> out) | ||
|
||
- op : bitwise_right_shift_arithmetic | ||
args : (Tensor x, Tensor y) | ||
output : Tensor(out) | ||
infer_meta : | ||
func : ElementwiseInferMeta | ||
kernel : | ||
func : bitwise_right_shift_arithmetic | ||
backend : x | ||
inplace: (x -> out) | ||
|
||
- op : bitwise_right_shift_logic | ||
args : (Tensor x, Tensor y) | ||
output : Tensor(out) | ||
infer_meta : | ||
func : ElementwiseInferMeta | ||
kernel : | ||
func : bitwise_right_shift_logic | ||
backend : x | ||
inplace: (x -> out) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to the C++ operator development specifications, the naming and parameters of operator need to be consistent with Python API, so that operator can reuse documentation of Python API. When user writes C++ code directly using operator, it will be relatively simple. So the operator needs to be |
||
- op : bitwise_xor | ||
args : (Tensor x, Tensor y) | ||
output : Tensor(out) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,6 +40,31 @@ DEFINE_BITWISE_KERNEL(Or) | |
DEFINE_BITWISE_KERNEL(Xor) | ||
#undef DEFINE_BITWISE_KERNEL | ||
|
||
#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \ | ||
template <typename T, typename Context> \ | ||
void Bitwise##op_type##Kernel(const Context& dev_ctx, \ | ||
const DenseTensor& x, \ | ||
const DenseTensor& y, \ | ||
DenseTensor* out) { \ | ||
funcs::Bitwise##op_type##Functor<T> func; \ | ||
funcs::InverseBitwise##op_type##Functor<T> inv_func; \ | ||
auto x_dims = x.dims(); \ | ||
auto y_dims = y.dims(); \ | ||
if (x_dims.size() >= y_dims.size()) { \ | ||
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \ | ||
dev_ctx, x, y, func, out); \ | ||
} else { \ | ||
funcs::ElementwiseCompute<funcs::InverseBitwise##op_type##Functor<T>, \ | ||
T>(dev_ctx, x, y, inv_func, out); \ | ||
} \ | ||
} | ||
|
||
DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftArithmetic) | ||
DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftLogic) | ||
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftArithmetic) | ||
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftLogic) | ||
#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE | ||
|
||
template <typename T, typename Context> | ||
void BitwiseNotKernel(const Context& dev_ctx, | ||
const DenseTensor& x, | ||
|
@@ -97,3 +122,43 @@ PD_REGISTER_KERNEL(bitwise_not, | |
int16_t, | ||
int, | ||
int64_t) {} | ||
|
||
PD_REGISTER_KERNEL(bitwise_left_shift_arithmetic, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::BitwiseLeftShiftArithmeticKernel, | ||
uint8_t, | ||
int8_t, | ||
int16_t, | ||
int, | ||
int64_t) {} | ||
|
||
PD_REGISTER_KERNEL(bitwise_left_shift_logic, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::BitwiseLeftShiftLogicKernel, | ||
uint8_t, | ||
int8_t, | ||
int16_t, | ||
int, | ||
int64_t) {} | ||
|
||
PD_REGISTER_KERNEL(bitwise_right_shift_arithmetic, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::BitwiseRightShiftArithmeticKernel, | ||
uint8_t, | ||
int8_t, | ||
int16_t, | ||
int, | ||
int64_t) {} | ||
|
||
PD_REGISTER_KERNEL(bitwise_right_shift_logic, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::BitwiseRightShiftLogicKernel, | ||
uint8_t, | ||
int8_t, | ||
int16_t, | ||
int, | ||
int64_t) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After unification, registration can be reduced to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to the C++ operator development specifications, the naming and parameters of operator need to be consistent with Python API, so that operator can reuse documentation of Python API. When user writes C++ code directly using operator, it will be relatively simple. So the operator needs to be
bitwise_left_shift
withis_arithmetic
in args, kernels in file paddle/phi/kernels/bitwise_kernel.h should also haveis_arithmetic
in args. and it's best if Functor in paddle/phi/kernels/funcs/bitwise_functors.h can be unified, but it's okay if it can't be unified.