Skip to content

Commit

Permalink
Merge branch 'develop' into dy2st_pir_api_push_10
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Nov 27, 2023
2 parents 0abeca1 + d93061c commit 7f84a25
Show file tree
Hide file tree
Showing 238 changed files with 8,567 additions and 2,023 deletions.
51 changes: 50 additions & 1 deletion paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ namespace cinn {
namespace dialect {

const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char *ConcatOp::attributes_name[GroupOp::attributes_num] = {"axis"};
const char *ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"};
const char *SplitOp::attributes_name[SplitOp::attributes_num] = {
"num_or_sections", "axis"};

void GroupOp::Build(pir::Builder &builder,
pir::OperationArgument &argument,
Expand Down Expand Up @@ -129,8 +131,55 @@ void ConcatOp::Build(pir::Builder &builder, // NOLINT
"axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis));
}

// Builds a cinn_op.split operation: splits `input` along `axis` into
// len(sections) outputs, where output i has extent sections[i] on `axis`
// and inherits every other dim (plus dtype/layout/lod/offset) from the input.
//
// Args:
//   builder   - IR builder (unused directly; kept for the Build convention).
//   argument  - operation argument being populated (inputs/outputs/attrs).
//   input     - tensor to split; must carry a DenseTensorType.
//   sections  - per-output sizes along `axis`.
//   axis      - split dimension; negative values count from the end.
void SplitOp::Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    pir::Value input,
                    const std::vector<int> &sections,
                    int axis) {
  VLOG(4) << "Start build SplitOp";  // fixed: previously logged "ConcatOp"

  argument.inputs.push_back(input);

  auto input_ele = input.type().dyn_cast<paddle::dialect::DenseTensorType>();

  // Normalize a negative axis to its positive equivalent.
  if (axis < 0) {
    axis += input_ele.dims().size();
  }

  // One output type per section: same as the input except dims[axis].
  std::vector<pir::Attribute> section_attrs;
  section_attrs.reserve(sections.size());
  for (size_t idx = 0; idx < sections.size(); ++idx) {
    auto out_dims = input_ele.dims();
    out_dims[axis] = sections[idx];
    auto out_type =
        paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
                                              input_ele.dtype(),
                                              out_dims,
                                              input_ele.data_layout(),
                                              input_ele.lod(),
                                              input_ele.offset());

    argument.output_types.emplace_back(out_type);

    // Record the section size as an attribute entry (renamed from the
    // misleading `attr_axis`: this is a section extent, not the axis).
    pir::Attribute attr_section =
        pir::Int32Attribute::get(pir::IrContext::Instance(), sections[idx]);
    section_attrs.push_back(attr_section);
  }

  PassStopGradientsDefaultly(argument);

  argument.AddAttribute(
      "num_or_sections",
      pir::ArrayAttribute::get(pir::IrContext::Instance(), section_attrs));

  argument.AddAttribute(
      "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis));
}

} // namespace dialect
} // namespace cinn

IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
20 changes: 20 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,28 @@ class IR_API ConcatOp : public pir::Op<ConcatOp> {
void VerifySig() const {}
};

// CINN dialect split op: splits one input tensor along an axis into
// multiple outputs whose sizes are given by "num_or_sections".
// Attributes (see attributes_name definition in manual_op.cc):
//   - "num_or_sections": array of int32 section sizes, one per output.
//   - "axis": int32 split dimension (normalized to non-negative in Build).
class IR_API SplitOp : public pir::Op<SplitOp> {
public:
using Op::Op;

// Fully-qualified op name used for registration/lookup in the dialect.
static const char *name() { return "cinn_op.split"; }

// Number of entries in attributes_name ("num_or_sections" and "axis").
static constexpr uint32_t attributes_num = 2;

static const char *attributes_name[attributes_num];

// Populates `argument` with the input, per-section output types, and the
// "num_or_sections"/"axis" attributes. `sections` gives each output's
// extent along `axis`; a negative `axis` counts from the last dim.
static void Build(pir::Builder &builder,  // NOLINT
pir::OperationArgument &argument,  // NOLINT
pir::Value input,
const std::vector<int> &sections,
int axis);

// Intentionally empty: no verification is performed for this op yet.
void VerifySig() const {}
};

} // namespace dialect
} // namespace cinn

IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void OperatorDialect::initialize() {
>();
RegisterOp<GroupOp>();
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CUDAJITInfoAttribute>();
}
Expand Down
58 changes: 58 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,63 @@ class ConcatOpPattern
}
};

// Rewrites pd_op.split_with_num into cinn_op.split.
// pd_op.split_with_num takes a runtime axis tensor and a "num" attribute;
// cinn_op.split wants a static axis and explicit section sizes, so this
// pattern only fires when the axis operand is produced by a FullOp whose
// constant value can be read at rewrite time.
class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::SplitWithNumOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
// operand 1 is the axis; match only if it comes from a constant FullOp.
auto axis_gen_op = op->operand_source(1).dyn_cast<pir::OpResult>().owner();
if (auto full_op = axis_gen_op->dyn_cast<paddle::dialect::FullOp>()) {
// FullOp stores its scalar as a float attribute; convert to int axis.
int axis = phi::Scalar(full_op.attribute("value")
.dyn_cast<::pir::FloatAttribute>()
.data())
.to<int>();

auto input_ele = op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>();
// Normalize a negative axis against the input rank.
if (axis < 0) {
axis += input_ele.dims().size();
}
std::vector<int> sections;

auto split_dim = input_ele.dims()[axis];

auto split_num =
op->attribute("num").dyn_cast<::pir::Int32Attribute>().data();
// Ceil-divide so the first num-1 parts are equal; the last part gets
// the remainder (may be smaller when split_dim % split_num != 0).
auto part_ele = (split_dim + split_num - 1) / split_num;

int total_split_num = 0;
for (int i = 0; i < split_num - 1; ++i) {
sections.push_back(part_ele);
total_split_num += part_ele;
}

sections.push_back(split_dim - total_split_num);

auto cinn_split = rewriter.Build<cinn::dialect::SplitOp>(
op->operand_source(0), sections, axis);

// Rewire consumers of the original (single, vector-typed) result to the
// new op's per-section results, erasing each consumer as we go.
// NOTE(review): this assumes every user of result(0) is a builtin
// slice-like op whose result(0) maps 1:1 onto cinn_split's results in
// use-list order — confirm that use order matches the slice index.
int index = 0;
auto orig_out = op.result(0);
for (auto it = orig_out.use_begin(); it != orig_out.use_end();) {
// Advance the iterator before erasing to avoid invalidating it.
auto split_op = (it++)->owner();
rewriter.ReplaceAllUsesWith(split_op->result(0),
cinn_split.result(index++));
rewriter.EraseOp(split_op);
}

rewriter.EraseOp(op);

return true;
}
return false;
}
};

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
Expand Down Expand Up @@ -307,6 +364,7 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns(
ps.Add<ReshapeOpPattern>(context);
ps.Add<ConcatOpPattern>(context);
ps.Add<SliceOpPattern>(context);
ps.Add<SplitWithNumOpPattern>(context);
// ps.Add(UniformOpPattern().Build(context));

return ps;
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.add", "elementwise_add"},
{"pd_op.elementwise_pow", "pow"},
{"pd_op.multiply", "elementwise_mul"},
{"pd_op.split_with_num", "split"},
{"cinn_op.reshape", "reshape"},
{"cinn_op.scale", "scale"},
{"cinn_op.broadcast", "broadcast_to"},
Expand Down
3 changes: 3 additions & 0 deletions paddle/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ set(COMMON_BUILD_TYPE
CACHE INTERNAL "" FORCE)

cc_library(common ${COMMON_BUILD_TYPE} SRCS ${common_srcs})
if(WIN32)
set_property(TARGET common PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
1 change: 1 addition & 0 deletions paddle/fluid/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ add_subdirectory(jit)
add_subdirectory(pir)
add_subdirectory(ir_adaptor)
add_subdirectory(primitive)
add_subdirectory(sub_graph)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
Loading

0 comments on commit 7f84a25

Please sign in to comment.