diff --git a/cinn/hlir/framework/graph_compiler.cc b/cinn/hlir/framework/graph_compiler.cc old mode 100755 new mode 100644 index 5fbd469aed..0ec3cb04ab --- a/cinn/hlir/framework/graph_compiler.cc +++ b/cinn/hlir/framework/graph_compiler.cc @@ -310,74 +310,24 @@ std::vector GraphCompiler::GetOpFuncWithIRSchedule( input_output_nodes.push_back(id); } - for (auto& i : GetAllNodeData(node)) { - VLOG(3) << "cinn_inputs.push_back " << i->id(); - cinn_inputs.push_back(common::CINNValue(i->id())); - } - std::vector out_types; std::vector> out_shapes; auto node_datas = GetAllNodeData(node); for (auto node_data : node_datas) { // collect output node data name. - out_types.push_back(type_dict_.at(node_data->id())); - out_shapes.push_back(shape_dict_.at(node_data->id())); - input_output_nodes.push_back(node_data->id()); + std::string out_name = node_data->id(); + VLOG(3) << "cinn_inputs.push_back " << out_name; + cinn_inputs.push_back(common::CINNValue(out_name)); + out_types.push_back(type_dict_.at(out_name)); + out_shapes.push_back(shape_dict_.at(out_name)); + input_output_nodes.push_back(out_name); } - // 2.Call Op's Compute function, using the default stages and LowerVec to get IR tree. auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); - common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs}); - auto all_arg_tensors = tensor_inputs; - - // 3. Collect tensors and arguments - // Add output tensors to all_arg_tensors - for (int i = 0; i < C->size() - 1; i++) { - ir::Expr temp = C[i]; - // checkout whether the tensor is with buffer. - if (!temp.as_tensor_ref()->buffer.defined() || target_ != common::DefaultNVGPUTarget()) { - all_arg_tensors.push_back(temp.as_tensor_ref()); - } - } - - poly::StageMap stages = C.back(); - std::string func_name_prefix = "fn_"; - auto func = lang::LowerVec(func_name_prefix + node->id(), stages, all_arg_tensors, {}, {}, nullptr, target_, true); - - std::vector schedule_inputs; - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - for (int i = 0; i < C->size() - 1; i++) { - ir::Expr temp = C[i]; - schedule_inputs.push_back(common::CINNValue(temp.as_tensor_ref()->name)); - } - - // 4. Call Op's Schedule function, optimizing the IR tree by new IR schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); - - // 5. Optimize the LoweredFunc - VLOG(3) << "expr_pack.size() is : " << expr_pack.size(); - std::vector res; - for (int i = 0; i < expr_pack.size(); i++) { - if (func.size() > expr_pack.size()) { - auto new_args = lang::GetArgs(func[i]->body, input_output_nodes); - func[i]->args = new_args; - } - auto temp_buffers = lang::GetTempBuffers(all_arg_tensors, stages, func[i]->body); - func[i]->temp_bufs = temp_buffers; - func[i]->PrepareBufferCastExprs(); - res.push_back(func[i]); - } - for (int i = 0; i < res.size(); i++) { -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(res[i]->body)); -#endif - res[i] = optim::Optimize(Expr(res[i]), target_, false).as_lowered_func_ref(); - } - // 6. Return the result. + auto res = + GetFuncFromImpl(impl, common::CINNValuePack{cinn_inputs}, tensor_inputs, input_output_nodes, node->id(), target_); return res; } @@ -1547,6 +1497,65 @@ std::vector GraphCompiler::FusedNodeGroupToLoweredFunc( return funcs; } +std::vector GetFuncFromImpl(const std::shared_ptr& impl, + const common::CINNValuePack& cinn_inputs, + std::vector& all_arg_tensors, + const std::vector& input_output_nodes, + const std::string& node_id, + const Target& target) { + // 1.Call Op's Compute function, using the default stages and LowerVec to get IR tree. + common::CINNValuePack C = impl->fcompute(cinn_inputs); + + // 2. Collect tensors and arguments + // Add output tensors to all_arg_tensors + for (int i = 0; i < C->size() - 1; i++) { + ir::Expr temp = C[i]; + // checkout whether the tensor is with buffer. + if (!temp.as_tensor_ref()->buffer.defined() || target != common::DefaultNVGPUTarget()) { + all_arg_tensors.push_back(temp.as_tensor_ref()); + } + } + + poly::StageMap stages = C.back(); + std::string func_name_prefix = "fn_"; + auto func = lang::LowerVec(func_name_prefix + node_id, stages, all_arg_tensors, {}, {}, nullptr, target, true); + + std::vector schedule_inputs; + for (auto& f : func) { + schedule_inputs.push_back(common::CINNValue(f->body)); + } + for (int i = 0; i < C->size() - 1; i++) { + ir::Expr temp = C[i]; + schedule_inputs.push_back(common::CINNValue(temp.as_tensor_ref()->name)); + } + + // 3. Call Op's Schedule function, optimizing the IR tree by new IR schedule + common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); + + // 4. Optimize the LoweredFunc + VLOG(3) << "expr_pack.size() is : " << expr_pack.size(); + std::vector res; + for (int i = 0; i < expr_pack.size(); i++) { + if (func.size() > expr_pack.size()) { + auto new_args = lang::GetArgs(func[i]->body, input_output_nodes); + func[i]->args = new_args; + } + auto temp_buffers = lang::GetTempBuffers(all_arg_tensors, stages, func[i]->body); + func[i]->temp_bufs = temp_buffers; + func[i]->PrepareBufferCastExprs(); + res.push_back(func[i]); + } + for (int i = 0; i < res.size(); i++) { +#ifdef CINN_WITH_CUDA + optim::OptimizeExprGPU(&(res[i]->body)); +#endif + res[i] = optim::Optimize(Expr(res[i]), target, false).as_lowered_func_ref(); + } + + // 5. Return the result. + return res; +} + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/graph_compiler.h b/cinn/hlir/framework/graph_compiler.h old mode 100755 new mode 100644 index 61efb17cd7..4c920adfc4 --- a/cinn/hlir/framework/graph_compiler.h +++ b/cinn/hlir/framework/graph_compiler.h @@ -200,6 +200,14 @@ std::shared_ptr BuildScope(Target target, const std::shared_ptr& graph, std::shared_ptr scope = nullptr); +// Given params, lower the op to LoweredFunc using new IR Schedule +std::vector GetFuncFromImpl(const std::shared_ptr& impl, + const common::CINNValuePack& cinn_inputs, + std::vector& tensor_inputs, + const std::vector& input_output_nodes, + const std::string& node_id, + const Target& target); + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 51f6c3973e..6a9fcc37b4 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -207,7 +207,7 @@ std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &at if (target.arch == Target::Arch::NVGPU) { pe::IRCudaScheduleInjective(ir_sch, out_shape, target); } else if (target.arch == Target::Arch::X86) { - pe::IRScheduleInjectiveCPU(ir_sch, out_shape, target); + pe::IRScheduleInjectiveCPU(ir_sch, out_shape, target, false); } std::vector res{CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; diff --git a/cinn/hlir/op/op_broadcast_test.cc b/cinn/hlir/op/op_broadcast_test.cc old mode 100755 new mode 100644 index ea6057792f..b3d8166937 --- a/cinn/hlir/op/op_broadcast_test.cc +++ b/cinn/hlir/op/op_broadcast_test.cc @@ -17,14 +17,19 @@ #include #include +#include "cinn/backends/codegen_cuda_dev.h" #include "cinn/backends/llvm/execution_engine.h" #include "cinn/cinn.h" #include "cinn/common/test_helper.h" +#include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" #include "cinn/hlir/op/use_ops.h" #include "cinn/hlir/pe/broadcast.h" +#include "cinn/runtime/flags.h" + +DECLARE_bool(cinn_ir_schedule); namespace cinn { namespace hlir { @@ -46,29 +51,45 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { std::vector type{Float(32)}; common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{M.as_int32(), N.as_int32()}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - ASSERT_EQ(rets.size(), 2UL); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - Module::Builder builder("module0", target); - auto func = Lower("add1", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - builder.AddFunction(func); - LOG(INFO) << "func:\n" << func; - ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); + std::string func_name = "add1"; + Module::Builder builder("module0", target); + + if (FLAGS_cinn_ir_schedule) { + std::string out_name = "C"; + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + std::vector input_output_names{"A", "B", out_name}; + + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + + for (auto func : funcs) { + LOG(INFO) << "Test Operator_ElementWise_Add_Test0's Strategy, func is :\n" << func; + builder.AddFunction(func); + } + + } else { + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + ASSERT_EQ(rets.size(), 2UL); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 2UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("fn_" + func_name, rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + builder.AddFunction(func); + } + auto jit = backends::ExecutionEngine::Create({}); auto module = builder.Build(); jit->Link(module); - auto fn = jit->Lookup("add1"); + auto fn = jit->Lookup("fn_" + func_name); CHECK(fn); auto fn_ = reinterpret_cast(fn); cinn_buffer_t *A_buf; @@ -94,7 +115,7 @@ TEST(Operator, Operator_ElementWise_Add_Test0) { ASSERT_NEAR(cd[i], ad[i] + bd[i], 1e-5); } } - +#ifdef CINN_WITH_CUDA TEST(Operator, Operator_ElementWise_Add_Test1) { auto add = Operator::Get("elementwise_add"); Operator temp = *add; @@ -108,25 +129,80 @@ TEST(Operator, Operator_ElementWise_Add_Test1) { attrs.attr_store["axis"] = 1; std::vector inputs{A.tensor(), B.tensor()}; std::vector type{Float(32)}; - common::Target target = common::DefaultHostTarget(); - auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{100, 32}}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - ASSERT_EQ(rets.size(), 2UL); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("add1", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; - + common::Target target = common::DefaultNVGPUTarget(); + auto impl = OpStrategy::SelectImpl(strategy[add](attrs, inputs, type, {{100, 32}}, target)); ASSERT_EQ(impl->name, "strategy.elementwise_add.x86"); ASSERT_EQ(add->description, "elementwise_add function"); + + std::string func_name = "add2"; + Module::Builder builder("module", target); + + if (FLAGS_cinn_ir_schedule) { + std::string out_name = "C"; + common::CINNValuePack cinn_input = + common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B), common::CINNValue(out_name)}}; + std::vector input_output_names{"A", "B", out_name}; + + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + + for (auto func : funcs) { + builder.AddFunction(func); + LOG(INFO) << "Test Operator_ElementWise_Add_Test1's Strategy, func is :\n" << func; + } + } else { + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(A), common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + ASSERT_EQ(rets.size(), 2UL); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 2UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + auto func = Lower("fn_" + func_name, rets.back(), inputs); + LOG(INFO) << "Test Strategy Codegen:\n" << func; + builder.AddFunction(func); + } + + backends::CodeGenCUDA_Dev codegen(target); + + auto module = builder.Build(); + auto source_code = codegen.Compile(module); + LOG(INFO) << "Operator_ElementWise_Add_Test1 source code:\n" << source_code; + + std::string target_code = R"ROC( +extern "C" { + +#include "cinn_cuda_runtime_source.cuh" + +#ifdef __CUDACC_RTC__ +typedef int int32_t; +typedef char int8_t; +#endif + + + +__global__ +void __launch_bounds__(1024) fn_add2(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C) +{ + if (((int)blockIdx.x < 4)) { + if (((int)threadIdx.x < 1024)) { + if ((((1024 * (int)blockIdx.x) + (int)threadIdx.x) < 3200)) { + C[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] = (A[((1024 * (int)blockIdx.x) + (int)threadIdx.x)] + B[((int)threadIdx.x & 31)]); + }; + }; + }; } +} +)ROC"; + if (FLAGS_cinn_ir_schedule) { + ASSERT_EQ(utils::Trim(target_code), source_code); + } +} +#endif + TEST(Operator, Operator_BroadcastTo) { auto broadcast_to = Operator::Get("broadcast_to"); Operator temp = *broadcast_to; @@ -147,20 +223,35 @@ TEST(Operator, Operator_BroadcastTo) { common::Target target = common::DefaultHostTarget(); auto impl = OpStrategy::SelectImpl(strategy[broadcast_to](attrs, inputs, type, {out_shape}, target)); - common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(B)}}; - common::CINNValuePack rets = impl->fcompute(cinn_input); - - ASSERT_EQ(rets.size(), 2UL); - rets = impl->fschedule(rets); - ASSERT_EQ(rets.size(), 2UL); - // the last element is a StageMap - for (int i = 0; i < rets->size() - 1; i++) { - Expr temp = rets[i]; - inputs.push_back(temp.as_tensor_ref()); - } - auto func = Lower("broadcast_to", rets.back(), inputs); - LOG(INFO) << "Test Strategy Codegen:\n" << func; + std::string func_name = "broadcast_to"; + + if (FLAGS_cinn_ir_schedule) { + std::string out_name = "C"; + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(B), common::CINNValue(out_name)}}; + std::vector input_output_names{"B", out_name}; + + auto funcs = framework::GetFuncFromImpl(impl, cinn_input, inputs, input_output_names, func_name, target); + + for (auto func : funcs) { + LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func; + } + } else { + common::CINNValuePack cinn_input = common::CINNValuePack{{common::CINNValue(B)}}; + common::CINNValuePack rets = impl->fcompute(cinn_input); + + ASSERT_EQ(rets.size(), 2UL); + rets = impl->fschedule(rets); + ASSERT_EQ(rets.size(), 2UL); + // the last element is a StageMap + for (int i = 0; i < rets->size() - 1; i++) { + Expr temp = rets[i]; + inputs.push_back(temp.as_tensor_ref()); + } + + auto func = Lower("func" + func_name, rets.back(), inputs); + LOG(INFO) << "Test Operator_BroadcastTo's Strategy, func is :\n" << func; + } } TEST(Operator, Operator_BroadcastTo_0) { diff --git a/cinn/ir/ir_schedule.cc b/cinn/ir/ir_schedule.cc index db268b7fd7..927149103a 100644 --- a/cinn/ir/ir_schedule.cc +++ b/cinn/ir/ir_schedule.cc @@ -167,7 +167,8 @@ void IRSchedule::MutateForType(const Expr& loop, ForType for_type, int factor) { auto* for_node = loop.As(); CHECK(for_node) << "loop param must be For node! Please check."; CHECK(for_node->is_serial()) << "loop is not serial, current forloop type is " - << static_cast(for_node->for_type()); + << static_cast(for_node->for_type()) << ", and it cannot become " + << static_cast(for_type); auto loop_copy = optim::IRCopy(loop); auto* new_for_node = loop_copy.As(); CHECK(new_for_node);