Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API -part #58092

Merged
merged 31 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
813c09d
test
cocoshe Dec 25, 2023
42b3972
fix
cocoshe Dec 25, 2023
198d875
fix
cocoshe Dec 25, 2023
6ab394d
test
cocoshe Dec 27, 2023
78d8fec
update
cocoshe Dec 27, 2023
5d627e1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 27, 2023
994d832
update
cocoshe Dec 27, 2023
7587891
update
cocoshe Dec 27, 2023
aac04c8
test split op
cocoshe Dec 28, 2023
99aef13
codestyle
cocoshe Dec 28, 2023
d1e942d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 28, 2023
af86d96
add register
cocoshe Dec 29, 2023
db05b52
split on win
cocoshe Dec 29, 2023
126e5a5
test bigger shape
cocoshe Dec 30, 2023
f5f51b8
add note
cocoshe Jan 2, 2024
0fd2774
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Jan 2, 2024
d28481d
fix
cocoshe Jan 2, 2024
899d1ff
Update bitwise_functors.h
cocoshe Jan 2, 2024
04c80b2
fix
cocoshe Jan 3, 2024
98b392d
Merge branch 'develop' into bitwise_shift_coco_dev
cocoshe Jan 5, 2024
60f0e92
fix doc
cocoshe Jan 5, 2024
3b1e7e4
fix doc
cocoshe Jan 5, 2024
d21eea0
refactor
cocoshe Jan 8, 2024
28e0e6a
enhence doc
cocoshe Jan 8, 2024
6214243
fix doc
cocoshe Jan 8, 2024
9d50a52
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
c74106d
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
fb5ad3c
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
b4e8edd
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
0ecabc8
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
809fb46
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,26 @@
backend : x
inplace: (x -> out)

- op : bitwise_left_shift_arithmetic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_left_shift_arithmetic
backend : x
inplace: (x -> out)

- op : bitwise_left_shift_logic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_left_shift_logic
backend : x
inplace: (x -> out)
Copy link
Contributor

@jeff41404 jeff41404 Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the C++ operator development specifications, the naming and parameters of operator need to be consistent with Python API, so that operator can reuse documentation of Python API. When user writes C++ code directly using operator, it will be relatively simple. So the operator needs to be bitwise_left_shift with is_arithmetic in args, kernels in file paddle/phi/kernels/bitwise_kernel.h should also haveis_arithmetic in args. and it's best if Functor in paddle/phi/kernels/funcs/bitwise_functors.h can be unified, but it's okay if it can't be unified.


- op : bitwise_not
args : (Tensor x)
output : Tensor(out)
Expand All @@ -364,6 +384,26 @@
backend : x
inplace: (x -> out)

- op : bitwise_right_shift_arithmetic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_right_shift_arithmetic
backend : x
inplace: (x -> out)

- op : bitwise_right_shift_logic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_right_shift_logic
backend : x
inplace: (x -> out)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the C++ operator development specifications, the naming and parameters of operator need to be consistent with Python API, so that operator can reuse documentation of Python API. When user writes C++ code directly using operator, it will be relatively simple. So the operator needs to be bitwise_right_shift with is_arithmetic in args, kernels in file paddle/phi/kernels/bitwise_kernel.h should also haveis_arithmetic in args. and it's best if Functor in paddle/phi/kernels/funcs/bitwise_functors.h can be unified, but it's okay if it can't be unified.

- op : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down
24 changes: 24 additions & 0 deletions paddle/phi/kernels/bitwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,28 @@ void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseLeftShiftArithmeticKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseLeftShiftLogicKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseRightShiftArithmeticKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseRightShiftLogicKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

} // namespace phi
65 changes: 65 additions & 0 deletions paddle/phi/kernels/cpu/bitwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL

#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Bitwise##op_type##Functor<T> func; \
funcs::InverseBitwise##op_type##Functor<T> inv_func; \
auto x_dims = x.dims(); \
auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
dev_ctx, x, y, func, out); \
} else { \
funcs::ElementwiseCompute<funcs::InverseBitwise##op_type##Functor<T>, \
T>(dev_ctx, x, y, inv_func, out); \
} \
}

DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftArithmetic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftLogic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftArithmetic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftLogic)
#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE

template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -97,3 +122,43 @@ PD_REGISTER_KERNEL(bitwise_not,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_arithmetic,
CPU,
ALL_LAYOUT,
phi::BitwiseLeftShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_logic,
CPU,
ALL_LAYOUT,
phi::BitwiseLeftShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_arithmetic,
CPU,
ALL_LAYOUT,
phi::BitwiseRightShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_logic,
CPU,
ALL_LAYOUT,
phi::BitwiseRightShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After unification, registration can be reduced to bitwise_right_shift and bitwise_left_shift

159 changes: 159 additions & 0 deletions paddle/phi/kernels/funcs/bitwise_functors.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,164 @@ struct BitwiseNotFunctor<bool> {
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};

template <typename T>
struct BitwiseLeftShiftArithmeticFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
if (b < static_cast<T>(0)) return static_cast<T>(0);
return a << b;
}
};

template <typename T>
struct InverseBitwiseLeftShiftArithmeticFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a >= static_cast<T>(sizeof(T) * 8)) return static_cast<T>(0);
if (a < static_cast<T>(0)) return static_cast<T>(0);
return b << a;
}
};

template <typename T>
struct BitwiseLeftShiftLogicFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
return a << b;
}
};

template <typename T>
struct InverseBitwiseLeftShiftLogicFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
return b << a;
}
};

template <typename T>
struct BitwiseRightShiftArithmeticFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(-(a >> (sizeof(T) * 8 - 1) & 1));
return a >> b;
}
};

template <typename T>
struct InverseBitwiseRightShiftArithmeticFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a < static_cast<T>(0) || a >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(-(b >> (sizeof(T) * 8 - 1) & 1));
return b >> a;
}
};

template <>
struct BitwiseRightShiftArithmeticFunctor<uint8_t> {
HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
if (b >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
return static_cast<uint8_t>(0);
return a >> b;
}
};

template <>
struct InverseBitwiseRightShiftArithmeticFunctor<uint8_t> {
inline HOSTDEVICE uint8_t operator()(const uint8_t a, const uint8_t b) const {
if (a >= static_cast<uint8_t>(sizeof(uint8_t) * 8))
return static_cast<uint8_t>(0);
return b >> a;
}
};

template <typename T>
struct BitwiseRightShiftLogicFunctor {
HOSTDEVICE T operator()(const T a, const T b) const {
if (b >= static_cast<T>(sizeof(T) * 8) || b < static_cast<T>(0))
return static_cast<T>(0);
return a >> b;
}
};

template <typename T>
struct InverseBitwiseRightShiftLogicFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (a >= static_cast<T>(sizeof(T) * 8) || a < static_cast<T>(0))
return static_cast<T>(0);
return b >> a;
}
};

template <typename T>
HOSTDEVICE T logic_shift_func(const T a, const T b) {
if (b < static_cast<T>(0) || b >= static_cast<T>(sizeof(T) * 8))
return static_cast<T>(0);
T t = static_cast<T>(sizeof(T) * 8 - 1);
T mask = (((a >> t) << t) >> b) << 1;
return (a >> b) ^ mask;
}

// signed int8
template <>
struct BitwiseRightShiftLogicFunctor<int8_t> {
HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
return logic_shift_func<int8_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int8_t> {
inline HOSTDEVICE int8_t operator()(const int8_t a, const int8_t b) const {
return logic_shift_func<int8_t>(b, a);
}
};

// signed int16
template <>
struct BitwiseRightShiftLogicFunctor<int16_t> {
HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
return logic_shift_func<int16_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int16_t> {
inline HOSTDEVICE int16_t operator()(const int16_t a, const int16_t b) const {
return logic_shift_func<int16_t>(b, a);
}
};

// signed int32
template <>
struct BitwiseRightShiftLogicFunctor<int> {
HOSTDEVICE int operator()(const int a, const int b) const {
return logic_shift_func<int32_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int> {
inline HOSTDEVICE int operator()(const int a, const int b) const {
return logic_shift_func<int32_t>(b, a);
}
};

// signed int64
template <>
struct BitwiseRightShiftLogicFunctor<int64_t> {
HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
return logic_shift_func<int64_t>(a, b);
}
};

template <>
struct InverseBitwiseRightShiftLogicFunctor<int64_t> {
inline HOSTDEVICE int64_t operator()(const int64_t a, const int64_t b) const {
return logic_shift_func<int64_t>(b, a);
}
};

} // namespace funcs
} // namespace phi
44 changes: 44 additions & 0 deletions paddle/phi/kernels/kps/bitwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ namespace phi {
DEFINE_BITWISE_KERNEL(And)
DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
DEFINE_BITWISE_KERNEL(LeftShiftArithmetic)
DEFINE_BITWISE_KERNEL(LeftShiftLogic)
DEFINE_BITWISE_KERNEL(RightShiftArithmetic)
DEFINE_BITWISE_KERNEL(RightShiftLogic)
#undef DEFINE_BITWISE_KERNEL

template <typename T, typename Context>
Expand Down Expand Up @@ -112,4 +116,44 @@ PD_REGISTER_KERNEL(bitwise_not,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_arithmetic,
KPS,
ALL_LAYOUT,
phi::BitwiseLeftShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_logic,
KPS,
ALL_LAYOUT,
phi::BitwiseLeftShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_arithmetic,
KPS,
ALL_LAYOUT,
phi::BitwiseRightShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_logic,
KPS,
ALL_LAYOUT,
phi::BitwiseRightShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

#endif
Loading