Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add pool1d, pool2d, pool3d, pad PEs and ops and C++/python tests (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
wenming2014 authored Sep 14, 2020
1 parent d5b55e5 commit 9f41170
Show file tree
Hide file tree
Showing 18 changed files with 2,085 additions and 114 deletions.
10 changes: 8 additions & 2 deletions cinn/common/ir_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -96,9 +97,14 @@ Expr make_const(Type t, T v) {
}

template <typename FuncOp>
Expr FoldExpr(FuncOp funcOp, Expr init_value, const std::vector<Expr> &values) {
Expr FoldExpr(FuncOp func_op, const std::vector<Expr> &values) {
Expr init_value;
for (const Expr &val : values) {
init_value = funcOp(init_value, val);
if (!init_value.defined()) {
init_value = val;
} else {
init_value = func_op(val, init_value);
}
}
return init_value;
}
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ foreach(cpp ${srcs})
endforeach()

cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core)
cc_test(test_op_nn SRCS op_nn_test.cc DEPS core)
41 changes: 25 additions & 16 deletions cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
const std::vector<Type> &out_type,
const Target &target) {
framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of add compute is empty! Please check.\n";
CINNValuePack a = args[0];
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK_GE(a.size(), 2U) << "at least 2 input tensors for add compute\n";
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
auto attr_store = attrs.attr_store;
auto iter = attr_store.find("axis");
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
Expr axis;
if (iter != attr_store.end()) {
axis = Expr(std::get<int>(iter->second));
bool trans_a;
for (auto &iter : attrs.attr_store) {
if (iter.first == "axis") {
axis = Expr(std::get<int>(iter.second));
} else {
LOG(ERROR) << "unsupported attr_store: " << iter.first << std::endl;
}
}

auto out = pe::Add(A, B, UniqName("C"), axis);
Expand All @@ -42,10 +47,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
});

framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK(!args.empty()) << "The input argument of add schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
Expr A [[maybe_unused]] = arg_pack[0];
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
Expand All @@ -59,9 +65,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseMul(const framework::NodeAttr
const std::vector<Type> &out_type,
const Target &target) {
framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of elementwise_mul compute is empty! Please check.\n";
CINNValuePack a = args[0];
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK_GE(a.size(), 2U) << "at least 2 input tensors for elementwise_mul compute\n";
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
Expand All @@ -80,10 +88,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseMul(const framework::NodeAttr
});

framework::CINNSchedule mul_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK(!args.empty()) << "The input argument of elementwise_mul schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
Expr A [[maybe_unused]] = arg_pack[0];
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
Expand Down
Loading

0 comments on commit 9f41170

Please sign in to comment.