diff --git a/build.sh b/build.sh index 5732b86bad..32bc29105a 100755 --- a/build.sh +++ b/build.sh @@ -70,6 +70,8 @@ function prepare_model { tar -xvf ResNet18.tar wget https://paddle-inference-dist.bj.bcebos.com/CINN/MobileNetV2.tar tar -xvf MobileNetV2.tar + wget https://paddle-inference-dist.bj.bcebos.com/CINN/EfficientNet.tar + tar -xvf EfficientNet.tar python $workspace/python/tests/fake_model/naive_mul.py python $workspace/python/tests/fake_model/naive_multi_fc.py python $workspace/python/tests/fake_model/resnet_model.py diff --git a/cinn/frontend/paddle_model_to_program.cc b/cinn/frontend/paddle_model_to_program.cc index 7d7a78a3a4..20398d96e5 100644 --- a/cinn/frontend/paddle_model_to_program.cc +++ b/cinn/frontend/paddle_model_to_program.cc @@ -21,7 +21,7 @@ void PaddleModelToProgram::AddOpMapper_feed() { void PaddleModelToProgram::AddOpMapper_fetch() { op_mappers_["fetch"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto output_name = op_desc.Input("X").front(); LOG(INFO) << "detect model output: [" << output_name << "]"; }; @@ -29,7 +29,7 @@ void PaddleModelToProgram::AddOpMapper_fetch() { void PaddleModelToProgram::AddOpMapper_scale() { op_mappers_["scale"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); auto x = GetVar(utils::TransValidVarName(x_name)); float scale{}; @@ -37,7 +37,7 @@ void PaddleModelToProgram::AddOpMapper_scale() { scale = op_desc.GetAttr("scale"); } else { // the newly refactored format // load scale tensor - CHECK(!op_desc.Input("ScaleTensor").empty()); + CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1UL); auto* scale_tensor_var = scope_->FindVar(op_desc.Input("ScaleTensor").front()); CHECK(scale_tensor_var) << "No scale tensor found in the scope"; auto& scale_tensor = std::get(*scale_tensor_var); @@ -46,7 +46,7 @@ void PaddleModelToProgram::AddOpMapper_scale() { std::unordered_map attrs; attrs["scale"] = scale; auto out = program_->scale(x, attrs); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); AddVar(utils::TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -55,9 +55,9 @@ void PaddleModelToProgram::AddOpMapper_scale() { void PaddleModelToProgram::AddOpMapper_mul() { op_mappers_["mul"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Input("Y").empty()); + CHECK_EQ(op_desc.Input("Y").size(), 1UL); auto y_name = op_desc.Input("Y").front(); auto x = GetVar(utils::TransValidVarName(x_name)); auto y = GetVar(utils::TransValidVarName(y_name)); @@ -68,7 +68,7 @@ void PaddleModelToProgram::AddOpMapper_mul() { VLOG(4) << "x shape: " << utils::Join(x->shape, ","); VLOG(4) << "y shape: " << utils::Join(y->shape, ","); auto out = program_->mul(x, y, x_num_col_dims, y_num_col_dims); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); AddVar(utils::TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -77,9 +77,9 @@ void PaddleModelToProgram::AddOpMapper_mul() { void PaddleModelToProgram::AddOpMapper_relu() { op_mappers_["relu"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); auto x = GetVar(TransValidVarName(x_name)); auto out = program_->relu(x); @@ -91,16 +91,16 @@ void PaddleModelToProgram::AddOpMapper_relu() { void PaddleModelToProgram::AddOpMapper_softmax() { op_mappers_["softmax"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); std::unordered_map attrs; if (op_desc.HasAttr("axis")) { attrs["axis"] = op_desc.GetAttr("axis"); } else { - attrs["axis"] = int(-1); + attrs["axis"] = static_cast(-1); } auto x = GetVar(TransValidVarName(x_name)); auto out = program_->softmax(x, attrs); @@ -111,11 +111,11 @@ void PaddleModelToProgram::AddOpMapper_softmax() { void PaddleModelToProgram::AddOpMapper_elementwise_add() { op_mappers_["elementwise_add"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Input("Y").empty()); + CHECK_EQ(op_desc.Input("Y").size(), 1UL); auto y_name = op_desc.Input("Y").front(); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); int axis = op_desc.GetAttr("axis"); @@ -128,11 +128,30 @@ void PaddleModelToProgram::AddOpMapper_elementwise_add() { }; } +void PaddleModelToProgram::AddOpMapper_elementwise_mul() { + op_mappers_["elementwise_mul"] = [&](const paddle::cpp::OpDesc& op_desc) { + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + CHECK_EQ(op_desc.Input("Y").size(), 1UL); + auto y_name = op_desc.Input("Y").front(); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + int axis = op_desc.GetAttr("axis"); + + auto x = GetVar(TransValidVarName(x_name)); + auto y = GetVar(TransValidVarName(y_name)); + auto out = program_->elementwise_mul(x, y, axis); + + AddVar(TransValidVarName(out_name), out); + var_model_to_program_map_[out_name] = out->id; + }; +} + void PaddleModelToProgram::AddOpMapper_relu6() { op_mappers_["relu6"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); std::unordered_map attrs; @@ -141,7 +160,7 @@ void PaddleModelToProgram::AddOpMapper_relu6() { attrs["threshold"] = op_desc.GetAttr("threshold"); auto x = GetVar(TransValidVarName(x_name)); - auto out = program_->relu6(x, attrs); + auto out = program_->relu6(x); AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -149,11 +168,11 @@ void PaddleModelToProgram::AddOpMapper_relu6() { } void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() { op_mappers_["depthwise_conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("Input").empty()); + CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); - CHECK(!op_desc.Input("Filter").empty()); + CHECK_EQ(op_desc.Input("Filter").size(), 1UL); auto y_name = op_desc.Input("Filter").front(); - CHECK(!op_desc.Output("Output").empty()); + CHECK_EQ(op_desc.Output("Output").size(), 1UL); auto out_name = op_desc.Output("Output").front(); std::unordered_map attrs; @@ -176,11 +195,11 @@ void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() { void PaddleModelToProgram::AddOpMapper_conv2d() { op_mappers_["conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("Input").empty()); + CHECK_EQ(op_desc.Input("Input").size(), 1UL); auto x_name = op_desc.Input("Input").front(); - CHECK(!op_desc.Input("Filter").empty()); + CHECK_EQ(op_desc.Input("Filter").size(), 1UL); auto y_name = op_desc.Input("Filter").front(); - CHECK(!op_desc.Output("Output").empty()); + CHECK_EQ(op_desc.Output("Output").size(), 1UL); auto out_name = op_desc.Output("Output").front(); std::unordered_map attrs; @@ -203,9 +222,9 @@ void PaddleModelToProgram::AddOpMapper_conv2d() { void PaddleModelToProgram::AddOpMapper_pool2d() { op_mappers_["pool2d"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Output("Out").empty()); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); auto out_name = op_desc.Output("Out").front(); std::unordered_map attrs; @@ -217,6 +236,7 @@ void PaddleModelToProgram::AddOpMapper_pool2d() { attrs["stride_size"] = op_desc.GetAttr>("strides"); CHECK(op_desc.HasAttr("paddings")); auto padding_size = op_desc.GetAttr>("paddings"); + if (padding_size.size() == 2) { padding_size.insert(padding_size.begin(), padding_size.front()); padding_size.push_back(padding_size.back()); @@ -228,6 +248,8 @@ void PaddleModelToProgram::AddOpMapper_pool2d() { attrs["exclusive"] = op_desc.GetAttr("exclusive"); CHECK(op_desc.HasAttr("data_format")); attrs["data_format"] = op_desc.GetAttr("data_format"); + CHECK(op_desc.HasAttr("global_pooling")); + attrs["global_pooling"] = op_desc.GetAttr("global_pooling"); auto x = GetVar(TransValidVarName(x_name)); auto out = program_->pool2d(x, attrs); @@ -239,15 +261,15 @@ void PaddleModelToProgram::AddOpMapper_pool2d() { void PaddleModelToProgram::AddOpMapper_batchnorm() { op_mappers_["batch_norm"] = [&](const paddle::cpp::OpDesc& op_desc) { - CHECK(!op_desc.Input("X").empty()); + CHECK_EQ(op_desc.Input("X").size(), 1UL); auto x_name = op_desc.Input("X").front(); - CHECK(!op_desc.Input("Scale").empty()); + CHECK_EQ(op_desc.Input("Scale").size(), 1UL); auto scale_name = op_desc.Input("Scale").front(); - CHECK(!op_desc.Input("Bias").empty()); + CHECK_EQ(op_desc.Input("Bias").size(), 1UL); auto bias_name = op_desc.Input("Bias").front(); - CHECK(!op_desc.Input("Mean").empty()); + CHECK_EQ(op_desc.Input("Mean").size(), 1UL); auto mean_name = op_desc.Input("Mean").front(); - CHECK(!op_desc.Input("Variance").empty()); + CHECK_EQ(op_desc.Input("Variance").size(), 1UL); auto variance_name = op_desc.Input("Variance").front(); CHECK(!op_desc.Output("Y").empty()); auto out_name = op_desc.Output("Y").front(); @@ -267,6 +289,63 @@ void PaddleModelToProgram::AddOpMapper_batchnorm() { }; } +void PaddleModelToProgram::AddOpMapper_sigmoid() { + op_mappers_["sigmoid"] = [&](const paddle::cpp::OpDesc& op_desc) { + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + + auto x = GetVar(TransValidVarName(x_name)); + auto out = program_->sigmoid(x); + + AddVar(TransValidVarName(out_name), out); + var_model_to_program_map_[out_name] = out->id; + }; +} + +void PaddleModelToProgram::AddOpMapper_slice() { + op_mappers_["slice"] = [&](const paddle::cpp::OpDesc& op_desc) { + CHECK_EQ(op_desc.Input("Input").size(), 1UL); + auto x_name = op_desc.Input("Input").front(); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + + std::unordered_map attrs; + CHECK(op_desc.HasAttr("starts")); + attrs["starts"] = op_desc.GetAttr>("starts"); + CHECK(op_desc.HasAttr("ends")); + attrs["ends"] = op_desc.GetAttr>("ends"); + CHECK(op_desc.HasAttr("axes")); + attrs["axes"] = op_desc.GetAttr>("axes"); + auto x = GetVar(TransValidVarName(x_name)); + auto out = program_->slice(x, attrs); + + AddVar(TransValidVarName(out_name), out); + var_model_to_program_map_[out_name] = out->id; + }; +} + +void PaddleModelToProgram::AddOpMapper_dropout_infer() { + op_mappers_["dropout"] = [&](const paddle::cpp::OpDesc& op_desc) { + CHECK_EQ(op_desc.Input("X").size(), 1UL); + auto x_name = op_desc.Input("X").front(); + CHECK_EQ(op_desc.Output("Out").size(), 1UL); + auto out_name = op_desc.Output("Out").front(); + + std::unordered_map attrs; + CHECK(op_desc.HasAttr("dropout_prob")); + attrs["dropout_prob"] = op_desc.GetAttr("dropout_prob"); + CHECK(op_desc.HasAttr("dropout_implementation")); + attrs["dropout_implementation"] = op_desc.GetAttr("dropout_implementation"); + auto x = GetVar(TransValidVarName(x_name)); + auto out = program_->dropout_infer(x, attrs); + + AddVar(TransValidVarName(out_name), out); + var_model_to_program_map_[out_name] = out->id; + }; +} + void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) { const auto& op_type = op_desc.Type(); auto it = op_mappers_.find(op_type); diff --git a/cinn/frontend/paddle_model_to_program.h b/cinn/frontend/paddle_model_to_program.h index cf1c0398c4..62d2b8d7fe 100644 --- a/cinn/frontend/paddle_model_to_program.h +++ b/cinn/frontend/paddle_model_to_program.h @@ -32,12 +32,16 @@ class PaddleModelToProgram { AddOpMapper_scale(); AddOpMapper_relu(); AddOpMapper_elementwise_add(); + AddOpMapper_elementwise_mul(); AddOpMapper_conv2d(); AddOpMapper_batchnorm(); AddOpMapper_pool2d(); AddOpMapper_softmax(); AddOpMapper_relu6(); AddOpMapper_depthwise_conv2d(); + AddOpMapper_sigmoid(); + AddOpMapper_slice(); + AddOpMapper_dropout_infer(); } std::unique_ptr operator()(const std::string& model_dir, bool is_combined); @@ -52,12 +56,16 @@ class PaddleModelToProgram { void AddOpMapper_mul(); void AddOpMapper_relu(); void AddOpMapper_elementwise_add(); + void AddOpMapper_elementwise_mul(); void AddOpMapper_conv2d(); void AddOpMapper_batchnorm(); void AddOpMapper_pool2d(); void AddOpMapper_softmax(); void AddOpMapper_relu6(); void AddOpMapper_depthwise_conv2d(); + void AddOpMapper_sigmoid(); + void AddOpMapper_slice(); + void AddOpMapper_dropout_infer(); // @} const std::unordered_map& var_map() const { return var_map_; } diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc index 7e2ac9196c..3c429c1415 100644 --- a/cinn/frontend/syntax.cc +++ b/cinn/frontend/syntax.cc @@ -103,6 +103,30 @@ Variable Program::softmax(const Variable& a, const std::unordered_map& attr_store) { + Instruction instr("slice", {a}); + for (auto& iter : attr_store) { + instr.SetAttr(iter.first, iter.second); + } + AppendInstruction(instr); + return instr.GetOutput(0); +} + +Variable Program::dropout_infer(const Variable& a, const std::unordered_map& attr_store) { + Instruction instr("dropout_infer", {a}); + for (auto& iter : attr_store) { + instr.SetAttr(iter.first, iter.second); + } + AppendInstruction(instr); + return instr.GetOutput(0); +} + Instruction& Program::operator[](size_t i) { CHECK_LT(i, instrs_.size()); return instrs_[i]; @@ -159,13 +183,20 @@ Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis return instr.GetOutput(0); } +Variable Program::elementwise_mul(const Variable& a, const Variable& b, int axis) { + Instruction instr("elementwise_mul", {a, b}); + instr.SetAttr("axis", axis); + AppendInstruction(instr); + return instr.GetOutput(0); +} + Variable Program::relu(const Variable& a) { Instruction instr("relu", {a}); AppendInstruction(instr); return instr.GetOutput(0); } -Variable Program::relu6(const Variable& a, const std::unordered_map& attr_store) { +Variable Program::relu6(const Variable& a) { Instruction instr("relu6", {a}); AppendInstruction(instr); return instr.GetOutput(0); diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index 062e0e5a71..44f9f162e4 100644 --- a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -174,7 +174,12 @@ struct Program { /** * Add two tensors element-wise. */ - Variable elementwise_add(const Variable& a, const Variable& b, int axis = 0); + Variable elementwise_add(const Variable& a, const Variable& b, int axis = -1); + + /** + * Multiply two tensors element-wise. + */ + Variable elementwise_mul(const Variable& a, const Variable& b, int axis = -1); /** * Apply Rectified Linear Unit on input Variable. @@ -184,7 +189,7 @@ struct Program { * @return The result. */ Variable relu(const Variable& a); - Variable relu6(const Variable& a, const std::unordered_map& attr_store); + Variable relu6(const Variable& a); /** * The convolution2D layer calculates the output based on the input, filter @@ -220,6 +225,13 @@ struct Program { Variable scale(const Variable& a, const std::unordered_map& attr_store); Variable softmax(const Variable& a, const std::unordered_map& attr_store); + + Variable sigmoid(const Variable& a); + + Variable slice(const Variable& a, const std::unordered_map& attr_store); + + Variable dropout_infer(const Variable& a, const std::unordered_map& attr_store); + /** * Get \p i-th instruction. */ diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 231d456dc2..081843aa82 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -44,7 +44,7 @@ std::shared_ptr StrategyForElementwiseAdd(const framework::NodeAttr auto out = pe::Add(A, B, UniqName("C"), axis); - auto stages = CreateStages({out}); + auto stages = CreateStages({A, B, out}); *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); @@ -86,7 +86,7 @@ std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr auto out = pe::Multiply(A, B, UniqName("C"), axis); - auto stages = CreateStages({out}); + auto stages = CreateStages({A, B, out}); *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; }); diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index 53213b8152..b5f825626d 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -599,6 +599,7 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, std::string pool_type = "max"; bool ceil_mode = false; bool exclusive = true; + bool global_pooling = false; std::string data_format = "NCHW"; for (auto &iter : attrs.attr_store) { if (iter.first == "kernel_size") { @@ -615,6 +616,8 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, exclusive = std::get(iter.second); } else if (iter.first == "data_format") { data_format = std::get(iter.second); + } else if (iter.first == "global_pooling") { + global_pooling = std::get(iter.second); } else { LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; } @@ -623,7 +626,29 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, CHECK(!stride_size.empty()) << "stride_size for pool2d is empty. Please check.\n"; CHECK(!padding_size.empty()) << "padding_size for pool2d is empty. Please check.\n"; - auto out = pe::Pool2d(A.as_tensor_ref(), + ir::Tensor A_tensor = A.as_tensor_ref(); + CHECK_EQ(A_tensor->shape.size(), 4U) << "pool2d's input tensor size should be 4. Please check.\n"; + if (global_pooling) { + int height_index = -1; + int width_index = -1; + if (data_format == "NCHW") { + height_index = 2; + width_index = 3; + } else if (data_format == "NHWC") { + height_index = 1; + width_index = 2; + } else if (data_format == "AnyLayout") { + height_index = 2; + width_index = 3; + data_format = "NCHW"; + } else { + LOG(FATAL) << "Only support 'NCHW' or 'NHWC' or 'AnyLayout' data_format.\n"; + } + kernel_size = {A_tensor->shape[height_index].as_int32(), A_tensor->shape[width_index].as_int32()}; + padding_size = {0, 0, 0, 0}; + } + + auto out = pe::Pool2d(A_tensor, kernel_size, stride_size, padding_size, @@ -633,14 +658,15 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, data_format, UniqName("T_Pool2d_out")); - auto stages = CreateStages(out); + auto stages = CreateStages({A_tensor}); CHECK(out.size() == 1U || out.size() == 2U) << "The size of pe::Pool2d's output should be 1 or 2."; - CHECK(!out_type.empty()) << "Output type of Pool2d is empty! Please check.\n"; - out.back()->InitReduction(stages, ir::Zero(out_type[0])); std::vector res; for (auto &t : out) { - res.push_back(CINNValue(Expr(t.get()))); + stages->InsertLazily(t); + res.push_back(CINNValue(t)); } + CHECK(!out_type.empty()) << "Output type of Pool2d is empty! Please check.\n"; + out.back()->InitReduction(stages, ir::Zero(out_type[0])); res.push_back(CINNValue(stages)); *ret = CINNValuePack{res}; }); @@ -668,7 +694,8 @@ std::shared_ptr StrategyForPool2d(const framework::NodeAttr &attrs, std::vector> InferShapeForPool2d(const std::vector> &inputs_shape, const framework::NodeAttr &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; + CHECK(!inputs_shape.empty() && inputs_shape[0].size() == 4) + << "The input's shape size of pool2d should be 4! Please check again."; auto attr_store = attrs.attr_store; std::vector kernel_size; std::vector stride_size; @@ -677,6 +704,7 @@ std::vector> InferShapeForPool2d(const std::vector>(iter.second); @@ -688,17 +716,18 @@ std::vector> InferShapeForPool2d(const std::vector(iter.second); } else if (iter.first == "exclusive") { exclusive = std::get(iter.second); + } else if (iter.first == "global_pooling") { + global_pooling = std::get(iter.second); } else if (iter.first == "data_format") { data_format = std::get(iter.second); } } - CHECK_EQ(kernel_size.size(), 2U) << "kernel size for pool1d should be 2.\n"; - CHECK_EQ(stride_size.size(), 2U) << "stride_size size for pool1d should be 2.\n"; + CHECK_EQ(kernel_size.size(), 2U) << "kernel size for pool2d should be 2.\n"; + CHECK_EQ(stride_size.size(), 2U) << "stride_size size for pool2d should be 2.\n"; std::vector output_shape1 = inputs_shape[0]; - CHECK_EQ(inputs_shape[0].size(), 4U) << "input_shape size for pool2d should be 4.\n"; - int height_axis = -1; - int width_axis = -1; + int height_axis = -1; + int width_axis = -1; if (data_format == "NCHW") { height_axis = 2; width_axis = 3; @@ -713,6 +742,11 @@ std::vector> InferShapeForPool2d(const std::vector StrategyForSlice(const framework::NodeAttr &attrs, CHECK(!args.empty()) << "The input arguments of slice schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of slice schedule is " << arg_pack.size() - << "and it should be equal to 3! Please check."; + << "and it should be equal to 2! Please check."; Expr A [[maybe_unused]] = arg_pack[0]; *ret = arg_pack; }); @@ -1119,6 +1153,76 @@ std::vector InferDtypeForSlice(const std::vector &inputs_type, const return res; } +std::shared_ptr StrategyForDropoutInfer(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + float dropout_prob = 0; + std::string dropout_implementation = "downgrade_in_infer"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "dropout_prob") { + dropout_prob = std::get(iter.second); + } else if (iter.first == "dropout_implementation") { + dropout_implementation = std::get(iter.second); + } else { + LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; + } + } + + framework::CINNCompute dropout_infer_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of dropout_infer compute is empty! Please check."; + CINNValuePack a = args[0]; + CHECK(!a.empty()) << "The input tensors of dropout_infer compute is empty! Please check."; + Expr A_expr = a[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + + auto out = pe::DropoutInfer(A, dropout_prob, dropout_implementation, UniqName("T_dropout_infer_out")); + auto stages = CreateStages({A, out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); + + framework::CINNSchedule dropout_infer_schedule([](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of dropout_infer schedule is empty! Please check."; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of dropout_infer schedule is " << arg_pack.size() + << "and it should be equal to 2! Please check."; + Expr A [[maybe_unused]] = arg_pack[0]; + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(dropout_infer_compute, dropout_infer_schedule, "strategy.dropout_infer.x86", 1); + + return strategy; +} + +std::vector> InferShapeForDropoutInfer(const std::vector> &inputs_shape, + const framework::NodeAttr &attrs) { + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; + float dropout_prob = 0; + std::string dropout_implementation = "downgrade_in_infer"; + for (auto &iter : attrs.attr_store) { + if (iter.first == "dropout_prob") { + dropout_prob = std::get(iter.second); + } else if (iter.first == "dropout_implementation") { + dropout_implementation = std::get(iter.second); + } else { + LOG(ERROR) << "Unsupported attr: " << iter.first << std::endl; + } + } + + std::vector> res{inputs_shape[0]}; + return res; +} + +std::vector InferDtypeForDropoutInfer(const std::vector &inputs_type, const framework::NodeAttr &attrs) { + CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; + std::vector res{inputs_type[0]}; + return res; +} + } // namespace op } // namespace hlir } // namespace cinn @@ -1223,5 +1327,14 @@ CINN_REGISTER_HELPER(nn_ops) { .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForSlice)) .set_support_level(4); + CINN_REGISTER_OP(dropout_infer) + .describe("Downgrade the outcome at inference or keep the same.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForDropoutInfer) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForDropoutInfer)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForDropoutInfer)) + .set_support_level(4); + return true; } diff --git a/cinn/hlir/pe/broadcast.cc b/cinn/hlir/pe/broadcast.cc index e8ca36fb56..64ff58cb1c 100644 --- a/cinn/hlir/pe/broadcast.cc +++ b/cinn/hlir/pe/broadcast.cc @@ -21,13 +21,13 @@ void GetBroadcastShape(const std::vector& shape1, std::vector* common_shape, std::vector* broadcast_flag1, std::vector* broadcast_flag2, + int* axis_offset, const Expr& axis) { CHECK(common_shape); CHECK(broadcast_flag1); CHECK(broadcast_flag2); int size1 = shape1.size(); std::vector shape2_new = shape2; - int axis_offset = -1; if (axis.defined()) { int axis_val = axis.as_int32(); CHECK_GE(axis_val, -1) << "wrong axis: " << axis_val << std::endl; @@ -35,9 +35,9 @@ void GetBroadcastShape(const std::vector& shape1, CHECK_LE(axis_val, int(shape1.size() - shape2.size())) << "wrong axis: " << axis_val << " is not <= " << shape1.size() - shape2.size() << std::endl; if (axis_val >= 0) { - axis_offset = shape1.size() - shape2.size() - axis_val; - for (int i = 1; i <= axis_offset; ++i) { - // specified axis to align, we push the Expr one in tensor B so as to align right with tensor A. + *axis_offset = shape1.size() - shape2.size() - axis_val; + for (int i = 1; i <= *axis_offset; ++i) { + // specified axis to align, we insert Expr one in tensor B so as to align right with tensor A. shape2_new.emplace_back(Expr(1)); common_shape->insert(common_shape->begin(), shape1[size1 - i]); // flag is used to indicate whether to include the indice or not. @@ -50,12 +50,14 @@ void GetBroadcastShape(const std::vector& shape1, int size2 = shape2_new.size(); Expr one(1); int i; - i = axis_offset <= 0 ? 1 : axis_offset + 1; + i = axis_offset <= 0 ? 1 : *axis_offset + 1; for (; i <= std::min(size1, size2); ++i) { + // traverse from right to left to get the output shape and broadcast flag auto* var1 = shape1[size1 - i].As(); auto* var2 = shape2_new[size2 - i].As(); if (MathEqual(shape1[size1 - i], shape2_new[size2 - i])) { common_shape->insert(common_shape->begin(), shape1[size1 - i]); + // broadcast flags are recorded in a reverse order broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(true); } else if (MathEqual(one, shape1[size1 - i])) { @@ -100,6 +102,9 @@ void GetBroadcastShape(const std::vector& shape1, } void GetBroadcastIndice(const std::vector& indice, + const Tensor& tensor_a, + const Tensor& tensor_b, + int axis_offset, std::vector* broadcast_indice1, std::vector* broadcast_indice2, const std::vector& broadcast_flags1, @@ -112,10 +117,18 @@ void GetBroadcastIndice(const std::vector& indice, CHECK_GE(indice.size(), flag_size); for (i = 0; i < flag_size; i++) { if (broadcast_flags1[flag_size - 1 - i]) { + // broadcast indices are added from left to right broadcast_indice1->push_back(indice[i]); + } else { + broadcast_indice1->push_back(Expr(0)); } if (broadcast_flags2[flag_size - 1 - i]) { broadcast_indice2->push_back(indice[i]); + } else if (flag_size - i <= tensor_b->shape.size() + axis_offset && + broadcast_indice2->size() < tensor_b->shape.size()) { + // insert indice 0 when have not yet reached the dimension of tensor. Meanwhile we have to consider the case of + // axis alignment. + broadcast_indice2->push_back(Expr(0)); } } } @@ -132,10 +145,13 @@ Tensor Broadcast(const FuncOp& op, std::vector broadcast_flags2; std::vector broadcast_indice1; std::vector broadcast_indice2; + // the counts of left-shift of tensor b so as to right alignment + int axis_offset = 0; - GetBroadcastShape(a->shape, b->shape, &common_shape, &broadcast_flags1, &broadcast_flags2, axis); + GetBroadcastShape(a->shape, b->shape, &common_shape, &broadcast_flags1, &broadcast_flags2, &axis_offset, axis); auto fn = [&](const std::vector& indice) { - GetBroadcastIndice(indice, &broadcast_indice1, &broadcast_indice2, broadcast_flags1, broadcast_flags2); + GetBroadcastIndice( + indice, a, b, axis_offset, &broadcast_indice1, &broadcast_indice2, broadcast_flags1, broadcast_flags2); return op(a(broadcast_indice1), b(broadcast_indice2)); }; Tensor output = Compute(common_shape, fn, output_name); diff --git a/cinn/hlir/pe/nn.cc b/cinn/hlir/pe/nn.cc index 050de1e80b..ef6957d763 100644 --- a/cinn/hlir/pe/nn.cc +++ b/cinn/hlir/pe/nn.cc @@ -11,7 +11,6 @@ #include "cinn/ir/ir_operators.h" #include "cinn/lang/builtin.h" #include "cinn/lang/compute.h" -#include "cinn/optim/ir_simplify.h" namespace cinn { namespace hlir { @@ -645,6 +644,19 @@ std::vector Pool3d(const Tensor &tensor, tensor, kernel_size, stride_size, padding_size, pool_type, axis, ceil_mode, exclusive, UniqName(output_name)); } +Tensor DropoutInfer(const ir::Tensor &tensor, + float dropout_prob, + const std::string &dropout_implementation, + const std::string &output_name) { + if (dropout_implementation == "downgrade_in_infer") { + return Multiply(tensor, Expr(1 - dropout_prob)); + } else if (dropout_implementation == "upscale_in_train") { + return Identity(tensor); + } else { + LOG(FATAL) << "dropout_implementation attr must be 'downgrade_in_infer' or 'upscale_in_train'\n"; + } +} + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/nn.h b/cinn/hlir/pe/nn.h index ee8aaaf365..d5baab7aff 100644 --- a/cinn/hlir/pe/nn.h +++ b/cinn/hlir/pe/nn.h @@ -297,6 +297,22 @@ std::vector Pool3d(const ir::Tensor &tensor, const std::string &data_format = "NCDHW", const std::string &output_name = UniqName("T_Pool3d_out")); +/** + * @brief Perform dropout in the inference which will downgrade the outcome at inference or keep the same. + * @param tensor The input tensor + * @param dropout_prob float. Probability of setting units to zero. + * @param dropout_implementation ['downgrade_in_infer'(default)|'upscale_in_train'] + * 1. downgrade_in_infer(default), downgrade the outcome at inference + * out = input * (1.0 - dropout_prob) + * 2. upscale_in_train, keep the same + * out = input + * @param output_name the name of the output tensor. + */ +ir::Tensor DropoutInfer(const ir::Tensor &tensor, + float dropout_prob, + const std::string &dropout_implementation = "downgrade_in_infer", + const std::string &output_name = UniqName("T_Dropout_infer_out")); + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/transform.cc b/cinn/hlir/pe/transform.cc index 1f339bd6dd..6ea90a6f2c 100644 --- a/cinn/hlir/pe/transform.cc +++ b/cinn/hlir/pe/transform.cc @@ -2,10 +2,10 @@ #include +#include "cinn/common/cas.h" #include "cinn/common/ir_util.h" #include "cinn/ir/tensor.h" #include "cinn/lang/compute.h" -#include "cinn/optim/ir_simplify.h" namespace cinn { namespace hlir { @@ -57,7 +57,6 @@ void GetMatmulIndice(const std::vector& shape1_new, indice1->emplace_back(indices[i]); } Expr reduce_shape1 = Expr(1); - int count = 1; // A reduce axes for (size_t i = x_num_col_dims; i < shape1_new.size(); i++) { reduce_shape1 = reduce_shape1 * shape1_new[i]; @@ -65,13 +64,12 @@ void GetMatmulIndice(const std::vector& shape1_new, auto k = Var(shape1_new[i], reduce_name); reduce_axes->emplace_back(k); indice1->emplace_back(k); - count++; } Expr reduce_shape2 = Expr(1); // B reduce axes for (size_t i = 0; i < y_num_col_dims; i++) { reduce_shape2 = reduce_shape2 * shape2_new[i]; - optim::Simplify(&reduce_shape2); + reduce_shape2 = common::AutoSimplify(reduce_shape2); indice2->emplace_back((*indice1)[indice1->size() - 1 - i]); } @@ -117,9 +115,9 @@ Tensor Matmul(const Tensor& A, &A_indice, &B_indice, &reduce_axes); - return ReduceSum(A(A_indice) * B(B_indice), Expr()); + return ReduceSum(A(A_indice) * B(B_indice)); }; - return Compute(output_shape, fn, name, reduce_axes, output_shape); + return Compute(output_shape, fn, name, reduce_axes); } } // namespace pe diff --git a/cinn/ir/buffer.cc b/cinn/ir/buffer.cc index 2a402e852a..6c6180fda7 100644 --- a/cinn/ir/buffer.cc +++ b/cinn/ir/buffer.cc @@ -12,7 +12,8 @@ namespace ir { std::string TensorGetBufferName(const _Tensor_ *tensor) { CHECK(!tensor->name.empty()); - CHECK(!utils::Startswith(tensor->name, "_")) << "the name with prefix _ is not allowed for tensor"; + CHECK(!utils::Startswith(tensor->name, "_")) + << "the name with prefix _ is not allowed for tensor. Current tensor's name is: " << tensor->name; return "_" + tensor->name; } std::string BufferGetTensorName(const _Buffer_ *buffer) { diff --git a/cinn/utils/string.cc b/cinn/utils/string.cc index 3ec98af144..955ca61b9e 100644 --- a/cinn/utils/string.cc +++ b/cinn/utils/string.cc @@ -77,6 +77,7 @@ void Replace(std::string *s, const std::string &from, const std::string &to) { std::string TransValidVarName(std::string name) { utils::Replace(&name, ".", "__"); utils::Replace(&name, "/", "___"); + name.erase(0, name.find_first_not_of("_")); return name; } diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index f9adebbb2d..4d331b01c9 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -73,3 +73,8 @@ ADD_TEST(NAME test_cinn_real_mobilenetV2 COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH} python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_mobilenetv2.py "${CMAKE_BINARY_DIR}/thirds/MobileNetV2" ) + +ADD_TEST(NAME test_cinn_real_efficientnet + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_BINARY_DIR}/python:$ENV{PYTHONPATH} + python3 ${CMAKE_CURRENT_SOURCE_DIR}/tests/test_efficientnet.py "${CMAKE_BINARY_DIR}/thirds/EfficientNet" +) diff --git a/python/tests/test_efficientnet.py b/python/tests/test_efficientnet.py new file mode 100644 index 0000000000..6eed869309 --- /dev/null +++ b/python/tests/test_efficientnet.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +import paddle as paddle +import paddle.fluid as fluid +from cinn.frontend import * +from cinn import Target +from cinn.framework import * +import unittest +import cinn +from cinn import runtime +from cinn import ir +from cinn import lang +from cinn.common import * +import numpy as np +import paddle.fluid as fluid +import sys + +model_dir = sys.argv.pop() + + +class TestLoadEfficientNetModel(unittest.TestCase): + def setUp(self): + self.target = Target() + self.target.arch = Target.Arch.X86 + self.target.bits = Target.Bit.k64 + self.target.os = Target.OS.Linux + self.model_dir = model_dir + self.x_shape = [1, 3, 224, 224] + self.target_tensor = 'save_infer_model/scale_0' + self.input_tensor = 'image' + + def get_paddle_inference_result(self, model_dir, data): + config = fluid.core.AnalysisConfig(model_dir + '/__model__', + model_dir + '/params') + config.disable_gpu() + config.switch_ir_optim(False) + self.paddle_predictor = fluid.core.create_paddle_predictor(config) + data = fluid.core.PaddleTensor(data) + results = self.paddle_predictor.run([data]) + get_tensor = self.paddle_predictor.get_output_tensor( + self.target_tensor).copy_to_cpu() + return get_tensor + + def test_model(self): + x_data = np.random.random(self.x_shape).astype("float32") + self.executor = Executor([self.input_tensor], [self.x_shape]) + print("self.mode_dir is:", self.model_dir) + # True means load combined model + self.executor.load_paddle_model(self.model_dir, True) + a_t = self.executor.get_tensor(self.input_tensor) + a_t.from_numpy(x_data) + + out = self.executor.get_tensor(self.target_tensor) + out.from_numpy(np.zeros(out.shape(), dtype='float32')) + + self.executor.run() + + out = out.numpy() + target_result = self.get_paddle_inference_result( + self.model_dir, x_data) + + print("result in test_model: \n") + out = out.reshape(-1) + target_result = target_result.reshape(-1) + for i in range(0, out.shape[0]): + if np.abs(out[i] - target_result[i]) > 1e-3: + print("Error! ", i, "-th data has diff with target data:\n", + out[i], " vs: ", target_result[i], ". Diff is: ", + out[i] - target_result[i]) + self.assertTrue(np.allclose(out, target_result, atol=1e-3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_op_nn.py b/python/tests/test_op_nn.py index 57c8fa2c9d..59eceb147e 100644 --- a/python/tests/test_op_nn.py +++ b/python/tests/test_op_nn.py @@ -565,5 +565,51 @@ def test_op(self): self.to_test_op([[3, 4, 5, 6]], [[3, 3, 1, 2]], "slice", attrs) +class OpTest_dropout_infer_0(SingleOpTester): + def init_testcase(self): + self.attrs = framework.NodeAttr() + self.attrs.attr_store = { + "dropout_prob": 0.2, + "dropout_implementation": "downgrade_in_infer", + } + + def create_target_data(self, inputs_data, attrs): + [X] = inputs_data + assert "dropout_implementation" in self.attrs.attr_store + if self.attrs.attr_store[ + "dropout_implementation"] == "downgrade_in_infer": + return X * (1 - self.attrs.attr_store["dropout_prob"]) + else: + return X + + def test_op(self): + self.init_testcase() + self.to_test_op([[2, 1280, 2, 2]], [[2, 1280, 2, 2]], "dropout_infer", + self.attrs) + + +class OpTest_dropout_infer_1(SingleOpTester): + def init_testcase(self): + self.attrs = framework.NodeAttr() + self.attrs.attr_store = { + "dropout_prob": 0.2, + "dropout_implementation": "upscale_in_train", + } + + def create_target_data(self, inputs_data, attrs): + [X] = inputs_data + assert "dropout_implementation" in self.attrs.attr_store + if self.attrs.attr_store[ + "dropout_implementation"] == "downgrade_in_infer": + return X * (1 - self.attrs.attr_store["dropout_prob"]) + else: + return X + + def test_op(self): + self.init_testcase() + self.to_test_op([[2, 1280, 2, 2]], [[2, 1280, 2, 2]], "dropout_infer", + self.attrs) + + if __name__ == "__main__": unittest.main()