Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix code styles
Browse files Browse the repository at this point in the history
  • Loading branch information
wenming2014 committed Sep 29, 2020
1 parent de23d1e commit e518d36
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 42 deletions.
78 changes: 39 additions & 39 deletions cinn/frontend/paddle_model_to_program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,23 @@ void PaddleModelToProgram::AddOpMapper_feed() {

void PaddleModelToProgram::AddOpMapper_fetch() {
op_mappers_["fetch"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto output_name = op_desc.Input("X").front();
LOG(INFO) << "detect model output: [" << output_name << "]";
};
}

void PaddleModelToProgram::AddOpMapper_scale() {
op_mappers_["scale"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
auto x = GetVar(utils::TransValidVarName(x_name));
float scale{};
if (op_desc.HasAttr("scale")) { // the old model format
scale = op_desc.GetAttr<float>("scale");
} else { // the newly refactored format
// load scale tensor
CHECK(!op_desc.Input("ScaleTensor").empty());
CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1UL);
auto* scale_tensor_var = scope_->FindVar(op_desc.Input("ScaleTensor").front());
CHECK(scale_tensor_var) << "No scale tensor found in the scope";
auto& scale_tensor = std::get<hlir::framework::Tensor>(*scale_tensor_var);
Expand All @@ -46,7 +46,7 @@ void PaddleModelToProgram::AddOpMapper_scale() {
std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
attrs["scale"] = scale;
auto out = program_->scale(x, attrs);
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
AddVar(utils::TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
Expand All @@ -55,9 +55,9 @@ void PaddleModelToProgram::AddOpMapper_scale() {

void PaddleModelToProgram::AddOpMapper_mul() {
op_mappers_["mul"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Y").empty());
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
auto x = GetVar(utils::TransValidVarName(x_name));
auto y = GetVar(utils::TransValidVarName(y_name));
Expand All @@ -68,7 +68,7 @@ void PaddleModelToProgram::AddOpMapper_mul() {
VLOG(4) << "x shape: " << utils::Join(x->shape, ",");
VLOG(4) << "y shape: " << utils::Join(y->shape, ",");
auto out = program_->mul(x, y, x_num_col_dims, y_num_col_dims);
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
AddVar(utils::TransValidVarName(out_name), out);
var_model_to_program_map_[out_name] = out->id;
Expand All @@ -77,9 +77,9 @@ void PaddleModelToProgram::AddOpMapper_mul() {

void PaddleModelToProgram::AddOpMapper_relu() {
op_mappers_["relu"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->relu(x);
Expand All @@ -91,16 +91,16 @@ void PaddleModelToProgram::AddOpMapper_relu() {

void PaddleModelToProgram::AddOpMapper_softmax() {
op_mappers_["softmax"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
if (op_desc.HasAttr("axis")) {
attrs["axis"] = op_desc.GetAttr<int>("axis");
} else {
attrs["axis"] = int(-1);
attrs["axis"] = static_cast<int>(-1);
}
auto x = GetVar(TransValidVarName(x_name));
auto out = program_->softmax(x, attrs);
Expand All @@ -111,11 +111,11 @@ void PaddleModelToProgram::AddOpMapper_softmax() {

void PaddleModelToProgram::AddOpMapper_elementwise_add() {
op_mappers_["elementwise_add"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Y").empty());
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
int axis = op_desc.GetAttr<int>("axis");

Expand All @@ -130,11 +130,11 @@ void PaddleModelToProgram::AddOpMapper_elementwise_add() {

void PaddleModelToProgram::AddOpMapper_elementwise_mul() {
op_mappers_["elementwise_mul"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Y").empty());
CHECK_EQ(op_desc.Input("Y").size(), 1UL);
auto y_name = op_desc.Input("Y").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();
int axis = op_desc.GetAttr<int>("axis");

Expand All @@ -149,9 +149,9 @@ void PaddleModelToProgram::AddOpMapper_elementwise_mul() {

void PaddleModelToProgram::AddOpMapper_relu6() {
op_mappers_["relu6"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand All @@ -168,11 +168,11 @@ void PaddleModelToProgram::AddOpMapper_relu6() {
}
void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() {
op_mappers_["depthwise_conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("Input").empty());
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK(!op_desc.Input("Filter").empty());
CHECK_EQ(op_desc.Input("Filter").size(), 1UL);
auto y_name = op_desc.Input("Filter").front();
CHECK(!op_desc.Output("Output").empty());
CHECK_EQ(op_desc.Output("Output").size(), 1UL);
auto out_name = op_desc.Output("Output").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand All @@ -195,11 +195,11 @@ void PaddleModelToProgram::AddOpMapper_depthwise_conv2d() {

void PaddleModelToProgram::AddOpMapper_conv2d() {
op_mappers_["conv2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("Input").empty());
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK(!op_desc.Input("Filter").empty());
CHECK_EQ(op_desc.Input("Filter").size(), 1UL);
auto y_name = op_desc.Input("Filter").front();
CHECK(!op_desc.Output("Output").empty());
CHECK_EQ(op_desc.Output("Output").size(), 1UL);
auto out_name = op_desc.Output("Output").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand All @@ -222,9 +222,9 @@ void PaddleModelToProgram::AddOpMapper_conv2d() {

void PaddleModelToProgram::AddOpMapper_pool2d() {
op_mappers_["pool2d"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand Down Expand Up @@ -261,15 +261,15 @@ void PaddleModelToProgram::AddOpMapper_pool2d() {

void PaddleModelToProgram::AddOpMapper_batchnorm() {
op_mappers_["batch_norm"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Input("Scale").empty());
CHECK_EQ(op_desc.Input("Scale").size(), 1UL);
auto scale_name = op_desc.Input("Scale").front();
CHECK(!op_desc.Input("Bias").empty());
CHECK_EQ(op_desc.Input("Bias").size(), 1UL);
auto bias_name = op_desc.Input("Bias").front();
CHECK(!op_desc.Input("Mean").empty());
CHECK_EQ(op_desc.Input("Mean").size(), 1UL);
auto mean_name = op_desc.Input("Mean").front();
CHECK(!op_desc.Input("Variance").empty());
CHECK_EQ(op_desc.Input("Variance").size(), 1UL);
auto variance_name = op_desc.Input("Variance").front();
CHECK(!op_desc.Output("Y").empty());
auto out_name = op_desc.Output("Y").front();
Expand All @@ -291,9 +291,9 @@ void PaddleModelToProgram::AddOpMapper_batchnorm() {

void PaddleModelToProgram::AddOpMapper_sigmoid() {
op_mappers_["sigmoid"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

auto x = GetVar(TransValidVarName(x_name));
Expand All @@ -306,9 +306,9 @@ void PaddleModelToProgram::AddOpMapper_sigmoid() {

void PaddleModelToProgram::AddOpMapper_slice() {
op_mappers_["slice"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("Input").empty());
CHECK_EQ(op_desc.Input("Input").size(), 1UL);
auto x_name = op_desc.Input("Input").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand All @@ -328,9 +328,9 @@ void PaddleModelToProgram::AddOpMapper_slice() {

void PaddleModelToProgram::AddOpMapper_dropout_infer() {
op_mappers_["dropout"] = [&](const paddle::cpp::OpDesc& op_desc) {
CHECK(!op_desc.Input("X").empty());
CHECK_EQ(op_desc.Input("X").size(), 1UL);
auto x_name = op_desc.Input("X").front();
CHECK(!op_desc.Output("Out").empty());
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t> attrs;
Expand Down
1 change: 0 additions & 1 deletion cinn/hlir/pe/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,6 @@ Tensor DropoutInfer(const ir::Tensor &tensor,
const std::string &dropout_implementation,
const std::string &output_name) {
if (dropout_implementation == "downgrade_in_infer") {
LOG(INFO) << "DropoutInfer: tensor's shape:" << tensor->shape;
return Multiply(tensor, Expr(1 - dropout_prob));
} else if (dropout_implementation == "upscale_in_train") {
return Identity(tensor);
Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def setUp(self):
self.target.os = Target.OS.Linux
self.model_dir = model_dir
self.x_shape = [1, 3, 224, 224]
self.target_tensor = 'pool2d_16.tmp_0'
# self.target_tensor = 'save_infer_model/scale_0'
self.target_tensor = 'save_infer_model/scale_0'
self.input_tensor = 'image'

def get_paddle_inference_result(self, model_dir, data):
Expand Down

0 comments on commit e518d36

Please sign in to comment.