This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add dropout_infer op, PE and tests, and EfficientNet model #242

Merged · 2 commits · Sep 29, 2020
2 changes: 2 additions & 0 deletions build.sh
@@ -70,6 +70,8 @@ function prepare_model {
tar -xvf ResNet18.tar
wget https://paddle-inference-dist.bj.bcebos.com/CINN/MobileNetV2.tar
tar -xvf MobileNetV2.tar
wget https://paddle-inference-dist.bj.bcebos.com/CINN/EfficientNet.tar
tar -xvf EfficientNet.tar
python $workspace/python/tests/fake_model/naive_mul.py
python $workspace/python/tests/fake_model/naive_multi_fc.py
python $workspace/python/tests/fake_model/resnet_model.py
141 changes: 110 additions & 31 deletions cinn/frontend/paddle_model_to_program.cc
@@ -21,23 +21,23 @@ void PaddleModelToProgram::AddOpMapper_feed() {

void PaddleModelToProgram::AddOpMapper_fetch() {
op_mappers_["fetch"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto output_name = op_desc.Input("X").front();
LOG(INFO) << "detect model output: [" << output_name << "]";
};
}

void PaddleModelToProgram::AddOpMapper_scale() {
op_mappers_["scale"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
auto x = GetVar(utils::TransValidVarName(x_name));
float scale{};
if (op_desc.HasAttr("scale")) { // the old model format
scale = op_desc.GetAttr<float>("scale");
} else { // the newly refactored format
// load scale tensor
CHECK(!op_desc.Input("ScaleTensor").empty());
CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1UL);
auto* scale_tensor_var = scope_->FindVar(op_desc.Input("ScaleTensor").front());
CHECK(scale_tensor_var) << "No scale tensor found in the scope";
auto& scale_tensor = std::get<hlir::framework::Tensor>(*scale_tensor_var);
@@ -46,7 +46,7 @@ void PaddleModelToProgram::AddOpMapper_scale() {
std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
attrs["scale"] = scale;
auto out = program_->scale(x, attrs);
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
AddVar(utils::TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
@@ -55,9 +55,9 @@

void PaddleModelToProgram::AddOpMapper_mul() {
op_mappers_["mul"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Y").empty());
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
auto x = GetVar(utils::TransValidVarName(x_name));
auto y = GetVar(utils::TransValidVarName(y_name));
@@ -68,7 +68,7 @@ void PaddleModelToProgram::AddOpMapper_mul() {
VLOG(4) << "x shape: " << utils::Join(x->shape, ",");
VLOG(4) << "y shape: " << utils::Join(y->shape, ",");
auto out = program_->mul(x, y, x_num_col_dims, y_num_col_dims);
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
AddVar(utils::TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
@@ -77,9 +77,9 @@

void PaddleModelToProgram::AddOpMapper_relu() {
op_mappers_["relu"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->relu(x);
@@ -91,16 +91,16 @@ void PaddleModelToProgram::AddOpMapper_relu() {

void PaddleModelToProgram::AddOpMapper_softmax() {
op_mappers_["softmax"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
if (op_desc.HasAttr("axis")) {
attrs["axis"] = op_desc.GetAttr<int>("axis");
} else {
attrs["axis"] = int(-1);
attrs["axis"] = static_cast<int>(-1);
}
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->softmax(x, attrs);
@@ -111,11 +111,11 @@ void PaddleModelToProgram::AddOpMapper_softmax() {

void PaddleModelToProgram::AddOpMapper_elementwise_add() {
op_mappers_["elementwise_add"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Y").empty());
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
int axis = op_desc.GetAttr<int>("axis");

@@ -128,11 +128,30 @@ void PaddleModelToProgram::AddOpMapper_elementwise_add() {
};
}

void PaddleModelToProgram::AddOpMapper_elementwise_mul() {
op_mappers_["elementwise_mul"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
int axis = op_desc.GetAttr<int>("axis");

auto x = GetVar(TransValidVarName(x_name));
auto y = GetVar(TransValidVarName(y_name));
auto out = program_->elementwise_mul(x, y, axis);

AddVar(TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
};
}

void PaddleModelToProgram::AddOpMapper_relu6() {
op_mappers_["relu6"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
@@ -141,19 +160,19 @@ void PaddleModelToProgram::AddOpMapper_relu6() {
attrs["threshold"] = op_desc.GetAttr<float>("threshold");

auto x = GetVar(TransValidVarName(x_name));
- auto out = program_->relu6(x, attrs);
+ auto out = program_->relu6(x);

AddVar(TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
};
}
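Note: relu6 clamps activations into [0, 6]; with the refactored `Program::relu6` (see syntax.cc below) the cap is fixed rather than taken from the `threshold` attribute. A scalar sketch of the activation itself, for illustration only:

```cpp
#include <algorithm>
#include <iostream>

// relu6(x) = min(max(x, 0), 6): ReLU with a fixed saturation cap of 6.
float Relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }

int main() {
  std::cout << Relu6(-1.0f) << " " << Relu6(3.5f) << " " << Relu6(9.0f);  // prints: 0 3.5 6
}
```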
void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() {
op_mappers_["depthwise_conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("Input").empty());
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK(!op_desc.Input("Filter").empty());
CHECK_EQ(op_desc.Input("Filter").size(), 1UL);
auto y_name = op_desc.Input("Filter").front();
CHECK(!op_desc.Output("Output").empty());
CHECK_EQ(op_desc.Output("Output").size(), 1UL);
auto out_name = op_desc.Output("Output").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
@@ -176,11 +195,11 @@ void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() {

void PaddleModelToProgram::AddOpMapper_conv2d() {
op_mappers_["conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("Input").empty());
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK(!op_desc.Input("Filter").empty());
CHECK_EQ(op_desc.Input("Filter").size(), 1UL);
auto y_name = op_desc.Input("Filter").front();
CHECK(!op_desc.Output("Output").empty());
CHECK_EQ(op_desc.Output("Output").size(), 1UL);
auto out_name = op_desc.Output("Output").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
@@ -203,9 +222,9 @@ void PaddleModelToProgram::AddOpMapper_conv2d() {

void PaddleModelToProgram::AddOpMapper_pool2d() {
op_mappers_["pool2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
@@ -217,6 +236,7 @@ void PaddleModelToProgram::AddOpMapper_pool2d() {
attrs["stride_size"] = op_desc.GetAttr<std::vector<int>>("strides");
CHECK(op_desc.HasAttr("paddings"));
auto padding_size = op_desc.GetAttr<std::vector<int>>("paddings");

if (padding_size.size() == 2) {
padding_size.insert(padding_size.begin(), padding_size.front());
padding_size.push_back(padding_size.back());
@@ -228,6 +248,8 @@
attrs["exclusive"] = op_desc.GetAttr<bool>("exclusive");
CHECK(op_desc.HasAttr("data_format"));
attrs["data_format"] = op_desc.GetAttr<std::string>("data_format");
CHECK(op_desc.HasAttr("global_pooling"));
attrs["global_pooling"] = op_desc.GetAttr<bool>("global_pooling");

auto x = GetVar(TransValidVarName(x_name));
auto out = program_->pool2d(x, attrs);
@@ -239,15 +261,15 @@ void PaddleModelToProgram::AddOpMapper_pool2d() {

void PaddleModelToProgram::AddOpMapper_batchnorm() {
op_mappers_["batch_norm"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Scale").empty());
CHECK_EQ(op_desc.Input("Scale").size(), 1UL);
auto scale_name = op_desc.Input("Scale").front();
CHECK(!op_desc.Input("Bias").empty());
CHECK_EQ(op_desc.Input("Bias").size(), 1UL);
auto bias_name = op_desc.Input("Bias").front();
CHECK(!op_desc.Input("Mean").empty());
CHECK_EQ(op_desc.Input("Mean").size(), 1UL);
auto mean_name = op_desc.Input("Mean").front();
CHECK(!op_desc.Input("Variance").empty());
CHECK_EQ(op_desc.Input("Variance").size(), 1UL);
auto variance_name = op_desc.Input("Variance").front();
CHECK(!op_desc.Output("Y").empty());
auto out_name = op_desc.Output("Y").front();
@@ -267,6 +289,63 @@ void PaddleModelToProgram::AddOpMapper_batchnorm() {
};
}

void PaddleModelToProgram::AddOpMapper_sigmoid() {
op_mappers_["sigmoid"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

auto x = GetVar(TransValidVarName(x_name));
auto out = program_->sigmoid(x);

AddVar(TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
};
}
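Note: sigmoid squashes any real input into (0, 1). A scalar sketch of the function the new instruction names, for illustration only:

```cpp
#include <cmath>
#include <iostream>

// sigmoid(x) = 1 / (1 + e^-x)
float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  std::cout << Sigmoid(0.0f) << " " << Sigmoid(4.0f);  // prints: 0.5 0.982014
}
```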

void PaddleModelToProgram::AddOpMapper_slice() {
op_mappers_["slice"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
CHECK(op_desc.HasAttr("starts"));
attrs["starts"] = op_desc.GetAttr<std::vector<int>>("starts");
CHECK(op_desc.HasAttr("ends"));
attrs["ends"] = op_desc.GetAttr<std::vector<int>>("ends");
CHECK(op_desc.HasAttr("axes"));
attrs["axes"] = op_desc.GetAttr<std::vector<int>>("axes");
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->slice(x, attrs);

AddVar(TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
};
}
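Note: `starts`/`ends`/`axes` follow Paddle's slice convention: each entry of `axes` names a dimension, and the matching `starts`/`ends` pair is clamped to that dimension's extent, with negative indices counting from the end. A standalone sketch of the resulting shape computation (an illustration of the convention, not CINN's kernel):

```cpp
#include <algorithm>
#include <cassert>
#include <iostream>
#include <vector>

// For each entry of `axes`, clamp the matching starts/ends pair into [0, dim]
// (negative values count from the end) and shrink that dimension accordingly.
std::vector<int> SliceOutShape(std::vector<int> shape,
                               const std::vector<int>& axes,
                               const std::vector<int>& starts,
                               const std::vector<int>& ends) {
  assert(axes.size() == starts.size() && axes.size() == ends.size());
  for (size_t i = 0; i < axes.size(); ++i) {
    int dim   = shape[axes[i]];
    int start = starts[i] < 0 ? starts[i] + dim : starts[i];
    int end   = ends[i] < 0 ? ends[i] + dim : ends[i];
    start = std::max(0, std::min(start, dim));
    end   = std::max(0, std::min(end, dim));
    shape[axes[i]] = std::max(0, end - start);
  }
  return shape;
}

int main() {
  // Keep channels 0..1 of a [2, 3, 224, 224] tensor (axis 1).
  auto out = SliceOutShape({2, 3, 224, 224}, /*axes=*/{1}, /*starts=*/{0}, /*ends=*/{2});
  for (int d : out) std::cout << d << " ";  // prints: 2 2 224 224
}
```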

void PaddleModelToProgram::AddOpMapper_dropout_infer() {
op_mappers_["dropout"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
CHECK(op_desc.HasAttr("dropout_prob"));
attrs["dropout_prob"] = op_desc.GetAttr<float>("dropout_prob");
CHECK(op_desc.HasAttr("dropout_implementation"));
attrs["dropout_implementation"] = op_desc.GetAttr<std::string>("dropout_implementation");
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->dropout_infer(x, attrs);

AddVar(TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
};
}
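Note: at inference time dropout applies no random mask, so the op reduces to a rescale or a pass-through depending on `dropout_implementation`. A minimal sketch of the semantics the new op targets, assuming Paddle's two documented modes (`downgrade_in_infer` rescales by `1 - dropout_prob`; `upscale_in_train` already rescaled during training and is the identity here):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Inference-time dropout: either rescale by the keep probability or pass through.
std::vector<float> DropoutInfer(const std::vector<float>& x,
                                float dropout_prob,
                                const std::string& dropout_implementation) {
  std::vector<float> out(x);
  if (dropout_implementation == "downgrade_in_infer") {
    for (auto& v : out) v *= 1.0f - dropout_prob;  // scale down by keep prob
  }
  // "upscale_in_train": nothing to do at inference.
  return out;
}

int main() {
  auto out = DropoutInfer({1.0f, 2.0f, 4.0f}, 0.5f, "downgrade_in_infer");
  for (float v : out) std::cout << v << " ";  // prints: 0.5 1 2
}
```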

void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) {
const auto& op_type = op_desc.Type();
auto it = op_mappers_.find(op_type);
8 changes: 8 additions & 0 deletions cinn/frontend/paddle_model_to_program.h
@@ -32,12 +32,16 @@ class PaddleModelToProgram {
AddOpMapper_scale();
AddOpMapper_relu();
AddOpMapper_elementwise_add();
AddOpMapper_elementwise_mul();
AddOpMapper_conv2d();
AddOpMapper_batchnorm();
AddOpMapper_pool2d();
AddOpMapper_softmax();
AddOpMapper_relu6();
AddOpMapper_depthwise_conv2d();
AddOpMapper_sigmoid();
AddOpMapper_slice();
AddOpMapper_dropout_infer();
}

std::unique_ptr<Program> operator()(const std::string& model_dir, bool is_combined);
@@ -52,12 +56,16 @@ class PaddleModelToProgram {
void AddOpMapper_mul();
void AddOpMapper_relu();
void AddOpMapper_elementwise_add();
void AddOpMapper_elementwise_mul();
void AddOpMapper_conv2d();
void AddOpMapper_batchnorm();
void AddOpMapper_pool2d();
void AddOpMapper_softmax();
void AddOpMapper_relu6();
void AddOpMapper_depthwise_conv2d();
void AddOpMapper_sigmoid();
void AddOpMapper_slice();
void AddOpMapper_dropout_infer();
// @}

const std::unordered_map<std::string, Variable>& var_map() const { return var_map_; }
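Note: the mappers registered above fire when a Paddle model is loaded through `operator()(model_dir, is_combined)`. A rough driver sketch; the scope type, constructor arguments, and `is_combined` value are assumptions for illustration, not taken from this diff:

```cpp
#include <memory>

#include "cinn/frontend/paddle_model_to_program.h"

int main() {
  // Assumed setup: a scope to hold the model's persistent tensors, and a
  // converter bound to it (constructor signature is an assumption).
  cinn::hlir::framework::Scope scope;
  cinn::frontend::PaddleModelToProgram to_program(&scope);

  // "EfficientNet" is the model fetched in build.sh above; is_combined=false
  // assumes the separate-params layout of that tarball.
  auto program = to_program("./EfficientNet", /*is_combined=*/false);

  // Every op with a registered mapper (conv2d, sigmoid, slice, dropout,
  // elementwise_mul, ...) is now an instruction in *program.
}
```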
33 changes: 32 additions & 1 deletion cinn/frontend/syntax.cc
@@ -103,6 +103,30 @@ Variable Program::softmax(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
return instr.GetOutput(1);
}

Variable Program::sigmoid(const Variable& a) {
Instruction instr("sigmoid", {a});
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable Program::slice(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
Instruction instr("slice", {a});
for (auto& iter : attr_store) {
instr.SetAttr(iter.first, iter.second);
}
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable Program::dropout_infer(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
Instruction instr("dropout_infer", {a});
for (auto& iter : attr_store) {
instr.SetAttr(iter.first, iter.second);
}
AppendInstruction(instr);
return instr.GetOutput(0);
}

Instruction& Program::operator[](size_t i) {
CHECK_LT(i, instrs_.size());
return instrs_[i];
@@ -159,13 +183,20 @@ Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) {
return instr.GetOutput(0);
}

Variable Program::elementwise_mul(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_mul", {a, b});
instr.SetAttr("axis", axis);
AppendInstruction(instr);
return instr.GetOutput(0);
}
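Note: `axis` follows Paddle's elementwise broadcasting rule: `y`'s shape is aligned against a contiguous span of `x`'s dimensions starting at `axis`, with `axis == -1` meaning trailing alignment, and `y` is repeated along the remaining dimensions. A standalone sketch of that rule (an illustration of the convention, not CINN's kernel):

```cpp
#include <iostream>
#include <vector>

// Multiply x by y, where y's shape matches x's dims [axis, axis + rank of y).
std::vector<float> ElementwiseMul(const std::vector<float>& x, const std::vector<int>& x_shape,
                                  const std::vector<float>& y, const std::vector<int>& y_shape,
                                  int axis) {
  if (axis == -1) axis = static_cast<int>(x_shape.size() - y_shape.size());
  int y_numel = 1;
  for (int d : y_shape) y_numel *= d;
  int inner = 1;  // product of x's dims below the span y occupies
  for (size_t i = axis + y_shape.size(); i < x_shape.size(); ++i) inner *= x_shape[i];
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * y[(i / inner) % y_numel];
  return out;
}

int main() {
  // x: [2, 3], y: [3], axis = 1 -> each row of x is scaled elementwise by y.
  auto out = ElementwiseMul({1, 2, 3, 4, 5, 6}, {2, 3}, {10, 100, 1000}, {3}, 1);
  for (float v : out) std::cout << v << " ";  // prints: 10 200 3000 40 500 6000
}
```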

Variable Program::relu(const Variable& a) {
Instruction instr("relu", {a});
AppendInstruction(instr);
return instr.GetOutput(0);
}

- Variable Program::relu6(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
+ Variable Program::relu6(const Variable& a) {
Instruction instr("relu6", {a});
AppendInstruction(instr);
return instr.GetOutput(0);