Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add pool1d, pool2d, pool3d, pad PEs and ops and C++/python tests #210

Merged
merged 3 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cinn/common/ir_util.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -96,9 +97,14 @@ Expr make_const(Type t, T v) {
}

template <typename FuncOp>
Expr FoldExpr(FuncOp funcOp, Expr init_value, const std::vector<Expr> &values) {
Expr FoldExpr(FuncOp func_op, const std::vector<Expr> &values) {
Expr init_value;
for (const Expr &val : values) {
init_value = funcOp(init_value, val);
if (!init_value.defined()) {
init_value = val;
} else {
init_value = func_op(val, init_value);
}
}
return init_value;
}
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/op/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ foreach(cpp ${srcs})
endforeach()

cc_test(test_op_broadcast SRCS op_broadcast_test.cc DEPS core)
cc_test(test_op_nn SRCS op_nn_test.cc DEPS core)
41 changes: 25 additions & 16 deletions cinn/hlir/op/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,23 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
const std::vector<Type> &out_type,
const Target &target) {
framework::CINNCompute add_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of add compute is empty! Please check.\n";
CINNValuePack a = args[0];
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK_GE(a.size(), 2U) << "at least 2 input tensors for add compute\n";
haozech marked this conversation as resolved.
Show resolved Hide resolved
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
auto attr_store = attrs.attr_store;
auto iter = attr_store.find("axis");
ir::Tensor A = A_expr.as_tensor_ref();
ir::Tensor B = B_expr.as_tensor_ref();
Expr axis;
if (iter != attr_store.end()) {
axis = Expr(std::get<int>(iter->second));
bool trans_a;
for (auto &iter : attrs.attr_store) {
if (iter.first == "axis") {
axis = Expr(std::get<int>(iter.second));
} else {
LOG(ERROR) << "unsupported attr_store: " << iter.first << std::endl;
}
}

auto out = pe::Add(A, B, UniqName("C"), axis);
Expand All @@ -42,10 +47,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseAdd(const framework::NodeAttr
});

framework::CINNSchedule add_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK(!args.empty()) << "The input argument of add schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
Expr A [[maybe_unused]] = arg_pack[0];
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
Expand All @@ -59,9 +65,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseMul(const framework::NodeAttr
const std::vector<Type> &out_type,
const Target &target) {
framework::CINNCompute mul_compute([&attrs](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of elementwise_mul compute is empty! Please check.\n";
CINNValuePack a = args[0];
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK_GE(a.size(), 2U) << "at least 2 input tensors for elementwise_mul compute\n";
wenming2014 marked this conversation as resolved.
Show resolved Hide resolved
Expr A_expr = a[0];
Expr B_expr = a[1];
CHECK(A_expr.as_tensor());
CHECK(B_expr.as_tensor());
ir::Tensor A = A_expr.as_tensor_ref();
Expand All @@ -80,10 +88,11 @@ std::shared_ptr<OpStrategy> StrategyForElementwiseMul(const framework::NodeAttr
});

framework::CINNSchedule mul_schedule([](lang::Args args, lang::RetValue *ret) {
CINNValuePack arg_pack = args[0];
Expr A [[maybe_unused]] = arg_pack[0];
CHECK(!args.empty()) << "The input argument of elementwise_mul schedule is empty! Please check.\n";
CINNValuePack arg_pack = args[0];
CHECK_EQ(arg_pack.size(), 2UL);
*ret = arg_pack;
Expr A [[maybe_unused]] = arg_pack[0];
*ret = arg_pack;
});

auto strategy = std::make_shared<framework::OpStrategy>();
Expand Down
Loading