Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.34】为 Paddle 新增 bitwise_right_shift / bitwise_right_shift_ / bitwise_left_shift / bitwise_left_shift_ API -part #58092

Merged
merged 31 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
813c09d
test
cocoshe Dec 25, 2023
42b3972
fix
cocoshe Dec 25, 2023
198d875
fix
cocoshe Dec 25, 2023
6ab394d
test
cocoshe Dec 27, 2023
78d8fec
update
cocoshe Dec 27, 2023
5d627e1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 27, 2023
994d832
update
cocoshe Dec 27, 2023
7587891
update
cocoshe Dec 27, 2023
aac04c8
test split op
cocoshe Dec 28, 2023
99aef13
codestyle
cocoshe Dec 28, 2023
d1e942d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Dec 28, 2023
af86d96
add register
cocoshe Dec 29, 2023
db05b52
split on win
cocoshe Dec 29, 2023
126e5a5
test bigger shape
cocoshe Dec 30, 2023
f5f51b8
add note
cocoshe Jan 2, 2024
0fd2774
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Jan 2, 2024
d28481d
fix
cocoshe Jan 2, 2024
899d1ff
Update bitwise_functors.h
cocoshe Jan 2, 2024
04c80b2
fix
cocoshe Jan 3, 2024
98b392d
Merge branch 'develop' into bitwise_shift_coco_dev
cocoshe Jan 5, 2024
60f0e92
fix doc
cocoshe Jan 5, 2024
3b1e7e4
fix doc
cocoshe Jan 5, 2024
d21eea0
refactor
cocoshe Jan 8, 2024
28e0e6a
enhence doc
cocoshe Jan 8, 2024
6214243
fix doc
cocoshe Jan 8, 2024
9d50a52
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
c74106d
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
fb5ad3c
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
b4e8edd
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
0ecabc8
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
809fb46
Update python/paddle/tensor/math.py
cocoshe Jan 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,21 @@ void OperatorDialect::initialize() {
// NOTE(chenxi67): GET_OP_LIST is defined in cinn_op.h which is
// generated by op_gen.py, see details in
// paddle/cinn/hlir/dialect/CMakeLists.txt.
#ifdef WIN32
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
RegisterOps<
#define GET_OP_LIST1
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST2
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
#else
RegisterOps<
#define GET_OP_LIST
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op_info.cc" // NOLINT
>();
#endif
RegisterOp<GroupOp>();
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
Expand Down
45 changes: 39 additions & 6 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,20 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
{define_type_id}
"""

CC_OP_INFO_FILE_TEMPLATE = """#ifdef GET_OP_LIST
CC_OP_INFO_FILE_TEMPLATE_PART1 = """#ifdef GET_OP_LIST
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
#undef GET_OP_LIST
{op_declare}
"""

CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 = """#ifdef GET_OP_LIST1
#undef GET_OP_LIST1
{op_declare_first_part}
#elif defined(GET_OP_LIST2)
#undef GET_OP_LIST2
{op_declare_second_part}
"""

CC_OP_INFO_FILE_TEMPLATE_PART2 = """
#else
// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#include "{h_file}"
Expand Down Expand Up @@ -1994,11 +2005,33 @@ def OpGenerator(
op_to_multi_kernels_map_str = ""

if op_info_file is not None:
op_info_str = CC_OP_INFO_FILE_TEMPLATE.format(
op_declare=",".join(op_list_strs).replace("\n", ""),
op_to_multi_kernels_map=op_to_multi_kernels_map_str,
h_file=op_def_h_file[:-4],
)
if sys.platform == "win32":
n = len(op_list_strs) // 2
first_part_op_info = op_list_strs[:n]
second_part_op_info = op_list_strs[n:]
CC_OP_INFO_FILE_TEMPLATE = (
CC_OP_INFO_FILE_TEMPLATE_WIN_PART1
+ CC_OP_INFO_FILE_TEMPLATE_PART2
)
op_info_str = CC_OP_INFO_FILE_TEMPLATE.format(
op_declare_first_part=",".join(first_part_op_info).replace(
"\n", ""
),
op_declare_second_part=",".join(second_part_op_info).replace(
"\n", ""
),
op_to_multi_kernels_map=op_to_multi_kernels_map_str,
h_file=op_def_h_file[:-4],
)
else:
CC_OP_INFO_FILE_TEMPLATE = (
CC_OP_INFO_FILE_TEMPLATE_PART1 + CC_OP_INFO_FILE_TEMPLATE_PART2
)
op_info_str = CC_OP_INFO_FILE_TEMPLATE.format(
op_declare=",".join(op_list_strs).replace("\n", ""),
op_to_multi_kernels_map=op_to_multi_kernels_map_str,
h_file=op_def_h_file[:-4],
)

with open(op_info_file, 'w') as f:
f.write(op_info_str)
Expand Down
13 changes: 12 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,22 @@ void OperatorDialect::initialize() {
// paddle/fluid/pir/dialect/CMakeLists.txt.
// NOTE(Ruting): GET_MANUAL_OP_LIST is defined in manual_op.h.
// use RegisterOps when list has more than two ops.
#ifdef WIN32
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
RegisterOps<
#define GET_OP_LIST
#define GET_OP_LIST1
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();

RegisterOps<
#define GET_OP_LIST2
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();
#else
RegisterOps<
#define GET_OP_LIST
#include "paddle/fluid/pir/dialect/operator/ir/pd_op_info.cc" // NOLINT
>();
#endif
RegisterOps<
#define GET_OP_LIST
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,21 @@ void OneDNNOperatorDialect::initialize() {
// paddle/fluid/pir/dialect/CMakeLists.txt.
// NOTE(Ruting): GET_MANUAL_OP_LIST is defined in manual_op.h.
// use RegisterOps when list has more than two ops.
#ifdef WIN32
cocoshe marked this conversation as resolved.
Show resolved Hide resolved
RegisterOps<
#define GET_OP_LIST1
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST2
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT
>();
#else
RegisterOps<
#define GET_OP_LIST
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.cc" // NOLINT
>();
#endif
}

void OneDNNOperatorDialect::PrintType(pir::Type type, std::ostream &os) const {
Expand Down
40 changes: 40 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,26 @@
backend : x
inplace: (x -> out)

- op : bitwise_left_shift_arithmetic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_left_shift_arithmetic
backend : x
inplace: (x -> out)

- op : bitwise_left_shift_logic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_left_shift_logic
backend : x
inplace: (x -> out)
Copy link
Contributor

@jeff41404 jeff41404 Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the C++ operator development specifications, the naming and parameters of an operator need to be consistent with the Python API, so that the operator can reuse the documentation of the Python API. When a user writes C++ code directly using the operator, it will be relatively simple. So the operator needs to be bitwise_left_shift with is_arithmetic in args; the kernels in paddle/phi/kernels/bitwise_kernel.h should also have is_arithmetic in args. It's best if the Functors in paddle/phi/kernels/funcs/bitwise_functors.h can be unified, but it's okay if they can't be.


- op : bitwise_not
args : (Tensor x)
output : Tensor(out)
Expand All @@ -364,6 +384,26 @@
backend : x
inplace: (x -> out)

- op : bitwise_right_shift_arithmetic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_right_shift_arithmetic
backend : x
inplace: (x -> out)

- op : bitwise_right_shift_logic
args : (Tensor x, Tensor y)
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
func : bitwise_right_shift_logic
backend : x
inplace: (x -> out)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the C++ operator development specifications, the naming and parameters of an operator need to be consistent with the Python API, so that the operator can reuse the documentation of the Python API. When a user writes C++ code directly using the operator, it will be relatively simple. So the operator needs to be bitwise_right_shift with is_arithmetic in args; the kernels in paddle/phi/kernels/bitwise_kernel.h should also have is_arithmetic in args. It's best if the Functors in paddle/phi/kernels/funcs/bitwise_functors.h can be unified, but it's okay if they can't be.

- op : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor(out)
Expand Down
24 changes: 24 additions & 0 deletions paddle/phi/kernels/bitwise_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,28 @@ void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseLeftShiftArithmeticKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseLeftShiftLogicKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseRightShiftArithmeticKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

template <typename T, typename Context>
void BitwiseRightShiftLogicKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);

} // namespace phi
65 changes: 65 additions & 0 deletions paddle/phi/kernels/cpu/bitwise_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL

#define DEFINE_BITWISE_KERNEL_WITH_INVERSE(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Bitwise##op_type##Functor<T> func; \
funcs::InverseBitwise##op_type##Functor<T> inv_func; \
auto x_dims = x.dims(); \
auto y_dims = y.dims(); \
if (x_dims.size() >= y_dims.size()) { \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
dev_ctx, x, y, func, out); \
} else { \
funcs::ElementwiseCompute<funcs::InverseBitwise##op_type##Functor<T>, \
T>(dev_ctx, x, y, inv_func, out); \
} \
}

DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftArithmetic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(LeftShiftLogic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftArithmetic)
DEFINE_BITWISE_KERNEL_WITH_INVERSE(RightShiftLogic)
#undef DEFINE_BITWISE_KERNEL_WITH_INVERSE

template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -97,3 +122,43 @@ PD_REGISTER_KERNEL(bitwise_not,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_arithmetic,
CPU,
ALL_LAYOUT,
phi::BitwiseLeftShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_left_shift_logic,
CPU,
ALL_LAYOUT,
phi::BitwiseLeftShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_arithmetic,
CPU,
ALL_LAYOUT,
phi::BitwiseRightShiftArithmeticKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(bitwise_right_shift_logic,
CPU,
ALL_LAYOUT,
phi::BitwiseRightShiftLogicKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After unification, registration can be reduced to bitwise_right_shift and bitwise_left_shift

Loading