Support dropout op in pir (PaddlePaddle#58773)
* [CINN+PIR] Support DoGroupSchedule for PIRCompiler

fix compilation problem

* fix conflict

* support cinn broadcast code gen

* fix op fusion pass bug

* using output_ops to parse function arguments

* update

* fix unittest

* remove VLOG(1)

* ignore some UT and add FIXME

* update

* remove args limit

* fix bug and remove useless code

* update

* update

* fix bug

* update

* fix bug

* update

* update

* update

* update

* remove useless code

* merge layer norm manual

* remove useless code

* remove useless code

* remove useless code

---------

Co-authored-by: Aurelius84 <[email protected]>
2 people authored and SecretXV committed Nov 28, 2023
1 parent 0f94876 commit 2ba5706
Showing 7 changed files with 173 additions and 1 deletion.
10 changes: 10 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/ops.yaml
@@ -23,3 +23,13 @@
func : ReduceInferMeta
kernel :
func : frobenius_norm

- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
output : Tensor(out)
infer_meta :
func : CreateVecShapeInferMeta
param : [shape, dtype]
kernel :
func : full_int_array
param : [shape, dtype]
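
For orientation, a minimal sketch of how the attributes declared by this yaml entry surface on the lowered operation; a hypothetical helper, mirroring the attribute-access calls used in op_mapper.cc later in this diff:

// Illustration only: read the shape attribute of a cinn_op.uniform_random
// back out of the IR. The attribute name follows the yaml args above.
std::vector<int64_t> GetUniformShape(const ::pir::Operation& op) {
  std::vector<int64_t> shape;
  auto attr_vec =
      op.attributes().at("shape").dyn_cast<::pir::ArrayAttribute>().AsVector();
  for (auto& elem : attr_vec) {
    shape.push_back(elem.dyn_cast<::pir::Int64Attribute>().data());
  }
  return shape;
}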
@@ -42,11 +42,13 @@ std::unordered_map<std::string, OpPatternKind> OpKindMap = {
{"pd_op.exp", OpPatternKind::kElementWise},
{"pd_op.sin", OpPatternKind::kElementWise},
{"pd_op.cos", OpPatternKind::kElementWise},
{"pd_op.cast", OpPatternKind::kElementWise},
{"pd_op.greater_than", OpPatternKind::kElementWise},
{"pd_op.sum", OpPatternKind::kReduction},
{"cinn_op.reduce_sum", OpPatternKind::kReduction},
{"cinn_op.reduce_max", OpPatternKind::kReduction},
{"cinn_op.broadcast", OpPatternKind::kBroadcast},
{"cinn_op.uniform_random", OpPatternKind::kElementWise}};

OpPatternKind GetOpKind(const std::string& op_name) {
auto found_it = OpKindMap.find(op_name);
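The diff truncates the body of GetOpKind here; a plausible completion of the remainder, assuming a hard failure on unregistered ops (the exact error handling is an assumption, not part of this commit):

  // Assumed: every op that reaches fusion must be registered in OpKindMap.
  CHECK(found_it != OpKindMap.end())
      << "op " << op_name << " is not registered in OpKindMap";
  return found_it->second;
}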
48 changes: 48 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
@@ -78,12 +78,60 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
}
};

class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
// Source Pattern
pir::drr::SourcePattern pattern = ctx->SourcePattern();
const auto &full_int_array =
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pattern.Attr("axis_info")},
{"dtype", pattern.Attr("dtype_2")},
{"place", pattern.Attr("place_2")}});

const auto &min_full = pattern.Op(paddle::dialect::FullOp::name(),
{{"shape", pattern.Attr("shape1")},
{"value", pattern.Attr("min_value")},
{"dtype", pattern.Attr("dtype_min")},
{"place", pattern.Attr("place_min")}});

const auto &max_full = pattern.Op(paddle::dialect::FullOp::name(),
{{"shape", pattern.Attr("shape2")},
{"value", pattern.Attr("max_value")},
{"dtype", pattern.Attr("dtype_max")},
{"place", pattern.Attr("place_max")}});

const auto &pd_uniform =
pattern.Op(paddle::dialect::UniformOp::name(),
{{"dtype", pattern.Attr("uniform_dtype")},
{"place", pattern.Attr("uniform_place")},
{"seed", pattern.Attr("seed")}});
pattern.Tensor("ret") =
pd_uniform(full_int_array(), min_full(), max_full());
// int64_t[] shape, float min, float max, int seed, DataType dtype, int
// diag_num, int diag_step, float diag_val)
// Result patterns
pir::drr::ResultPattern res = pattern.ResultPattern();
const auto &cinn_uniform =
res.Op(cinn::dialect::UniformRandomOp::name(),
{{"shape", pattern.Attr("axis_info")},
{"min", pattern.Attr("min_value")},
{"max", pattern.Attr("max_value")},
{"seed", pattern.Attr("seed")},
{"dtype", pattern.Attr("uniform_dtype")},
{"diag_num", pattern.Attr("seed")},
{"diag_step", pattern.Attr("seed")},
{"diag_val", pattern.Attr("min_value")}});
res.Tensor("ret") = cinn_uniform();
}
};
PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::Pass("pd_to_cinn_pass", 1) {}

bool PdOpToCinnOpPass::Initialize(pir::IrContext *context) {
pir::RewritePatternSet ps(context);
ps.Add(SumOpPattern().Build(context));
ps.Add(MaxOpPattern().Build(context));
// ps.Add(UniformOpPattern().Build(context));

patterns_ = ::pir::FrozenRewritePatternSet(std::move(ps));
return true;
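In effect, UniformOpPattern folds the three constant-producing inputs of pd_op.uniform (the full_int_array shape and the two full scalars for min/max) into compile-time attributes of a single cinn_op.uniform_random. Roughly, in hypothetical textual IR (illustration only):

// before:
//   %shape = pd_op.full_int_array {value = [...]}
//   %min   = pd_op.full {value = min}
//   %max   = pd_op.full {value = max}
//   %ret   = pd_op.uniform(%shape, %min, %max) {dtype, place, seed}
// after:
//   %ret   = cinn_op.uniform_random {shape, min, max, seed, dtype, ...}

The result pattern reuses matched attributes (seed for diag_num and diag_step, min_value for diag_val), apparently as placeholders for arguments the dropout path never uses; note also that the pattern registration is left commented out in Initialize above.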
15 changes: 15 additions & 0 deletions paddle/cinn/hlir/framework/pir/op_mapper.cc
@@ -38,6 +38,20 @@ void AppendAttrForReduceOp(const ::pir::Operation& op,
attrs["dim"] = dim;
}

void AppendAttrForUniformOp(const ::pir::Operation& op,
utils::AttributeMap& attrs) { // NOLINT
auto attr = op.attributes().at("shape");
auto attr_vec = attr.dyn_cast<::pir::ArrayAttribute>().AsVector();

std::vector<int> shape;
for (auto vec_element : attr_vec) {
shape.push_back(vec_element.dyn_cast<::pir::Int64Attribute>().data());
}

attrs["shape"] = shape;
attrs["dtype"] = "float32";
}

void AppendAttrForBoadcastToOp(const ::pir::Operation& op,
utils::AttributeMap& attrs) { // NOLINT
auto axes_attr = op.attributes().at("broadcast_axes");
@@ -81,6 +95,7 @@ void OpMapper::RegisterMapRules() {
REGISTER_ATTR_RULE(ReduceMaxOp, AppendAttrForReduceOp);
REGISTER_ATTR_RULE(ReduceSumOp, AppendAttrForReduceOp);
REGISTER_ATTR_RULE(BroadcastOp, AppendAttrForBoadcastToOp);
REGISTER_ATTR_RULE(UniformRandomOp, AppendAttrForUniformOp);
}

} // namespace pir
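REGISTER_ATTR_RULE itself is presumably defined in the matching op_mapper.h; a plausible sketch of what each registration amounts to (the member name attr_funcs_ is an assumption):

// Hypothetical expansion: key the attribute-append rule by the op's name.
#define REGISTER_ATTR_RULE(OP, func) \
  attr_funcs_[cinn::dialect::OP::name()] = func;

With the new rule installed, lowering a cinn_op.uniform_random picks up AppendAttrForUniformOp, which copies the shape attribute and pins dtype to "float32".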
11 changes: 11 additions & 0 deletions paddle/fluid/pir/drr/api/drr_pattern_context.cc
@@ -123,6 +123,17 @@ Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const {
return out;
}

Tensor& Op::operator()(const Tensor& arg0,
const Tensor& arg1,
const Tensor& arg2) const {
std::vector<const Tensor*> inputs{&arg0, &arg1, &arg2};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
prefix + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
std::vector<const Tensor*> outputs{&out};
pattern_graph_->AddOpCall(std::make_shared<OpCall>(this, inputs, outputs));
return out;
}

Tensor& Op::operator()() const {
std::vector<const Tensor*> inputs{};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
3 changes: 3 additions & 0 deletions paddle/fluid/pir/drr/api/drr_pattern_context.h
@@ -152,6 +152,9 @@ class Op {

Tensor& operator()(const Tensor& arg) const;
Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
Tensor& operator()(const Tensor& arg0,
const Tensor& arg1,
const Tensor& arg2) const;
void operator()(const std::vector<const Tensor*>& args,
const std::vector<const Tensor*>& outputs) const;
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const
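This three-argument overload is what allows a DRR pattern to call an op with three tensor inputs; UniformOpPattern in pd_to_cinn_pass.cc above exercises it:

// From UniformOpPattern: pd_op.uniform takes shape, min, and max tensors.
pattern.Tensor("ret") = pd_uniform(full_int_array(), min_full(), max_full());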
83 changes: 83 additions & 0 deletions test/cpp/pir/cinn/pir_all_path_test.cc
@@ -222,3 +222,86 @@ TEST(GroupOp, TestBuildLayerNorm) {
// auto out_tensor =
// executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
}

std::shared_ptr<::pir::Program> BuildDropOutProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

auto x =
builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({128, 128, 768}),
1.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto prob = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({1}),
0.5,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto random = builder
.Build<paddle::dialect::UniformOp>(
std::vector<int64_t>({128, 128, 768}),
phi::DataType::FLOAT32,
0.0,
1.0,
0,
phi::GPUPlace())
.result(0);

auto mask =
builder.Build<paddle::dialect::GreaterThanOp>(random, prob).result(0);
auto mask1 =
builder.Build<paddle::dialect::CastOp>(mask, phi::DataType::FLOAT32)
.result(0);
auto mul = builder.Build<paddle::dialect::MultiplyOp>(x, mask1).result(0);
auto neg_prob =
builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({1}),
0.5,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);
auto out = builder.Build<paddle::dialect::DivideOp>(mul, neg_prob).result(0);

builder.Build<paddle::dialect::FetchOp>(out, "out", 0);
return program;
}
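
The program above encodes inverted dropout: out = x * 1[u > p] / (1 - p), where u ~ U(0, 1) and p = 0.5, so the final divisor of 0.5 is exactly 1 - p. A scalar sketch of the same computation (illustration only, not part of the commit):

// Per-element view of the graph built above.
float DropoutScalar(float x, float u, float p /* = 0.5f */) {
  float mask = (u > p) ? 1.0f : 0.0f;  // pd_op.greater_than + pd_op.cast
  return x * mask / (1.0f - p);        // pd_op.multiply + pd_op.divide
}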

TEST(GroupOp, TestBuildDropout) {
// Step 1: Construct pir::Program
::pir::IrContext* ctx = ::pir::IrContext::Instance();
std::shared_ptr<::pir::Program> program = BuildDropOutProgram();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

cinn::dialect::ir::PdOp2CinnOpConverter(program.get());

pir::PassManager pm(ctx);
pm.AddPass(
std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
pm.AddPass(pir::CreateBuildCinnPass());
CHECK_EQ(pm.Run(program.get()), true);

auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get());

paddle::platform::Place place = paddle::platform::CUDAPlace(0);

auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(res.get(), place);

paddle::framework::Scope exe_scope;

paddle::framework::InterpreterCore executor(
place, {"out@fetch"}, kernel_program->block(), &exe_scope);

executor.Run({}, true);

auto out_tensor =
executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
}
