From cbca7d79aba307f6641a456cbc26704971318483 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Mon, 19 Jun 2023 19:21:52 +0800 Subject: [PATCH 01/11] delete useless code in OpLower --- cinn/hlir/framework/op_lowering.cc | 697 +---------------------------- cinn/hlir/framework/op_lowering.h | 24 +- 2 files changed, 10 insertions(+), 711 deletions(-) mode change 100755 => 100644 cinn/hlir/framework/op_lowering.h diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 9a213242a1..96e7a45fb6 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -29,20 +29,15 @@ namespace framework { using common::bfloat16; using common::float16; -using framework::Graph; using framework::Node; using framework::NodeData; using framework::OpPatternKind; using framework::shape_t; using framework::StrategyFunction; -using common::GraphEdge; -using common::GraphNode; using common::Type; using namespace lang; -using Comparator = Graph::Group::SharedGroupComparator; -using Hasher = Graph::Group::SharedGroupHasher; using cinn::hlir::op::ExternalApiRegistry; OpLowerer::OpLowerer(const absl::flat_hash_map& type_dict, @@ -59,9 +54,9 @@ std::vector OpLowerer::Lower(GroupPtr& group) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOp(&OpLowerer::IRElementwiseCompute, &OpLowerer::IRElementwiseSchedule, group); + return IRLowerOp(&OpLowerer::IRElementwiseCompute, group); case framework::kReduction: - return IRLowerOp(&OpLowerer::IRReduceCompute, &OpLowerer::IRReduceSchedule, group); + return IRLowerOp(&OpLowerer::IRReduceCompute, group); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: @@ -96,9 +91,7 @@ std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { } } -std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, - IRScheduleFunction schedule, - GroupPtr& group) { +std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, GroupPtr& group) { poly::StageMap stages; std::vector arg_tensors; std::unordered_map tensor_map; @@ -316,49 +309,6 @@ std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, return ast_exprs; } -void OpLowerer::IRElementwiseSchedule(ir::IRSchedule& ir_sch, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - Node*&, - Node*&) { - VLOG(2) << "IRElementwiseSchedule Group : " << sub_group->group_id; - auto master_node = *group->master_nodes.begin(); - auto manster_tensor = tensor_map[GetNodeData(master_node)->id()]; - - for (int idx = sub_group->nodes.size() - 1; idx >= 0; --idx) { - auto node = sub_group->nodes[idx]; - auto node_tensor = tensor_map[GetNodeData(node)->id()]; - - VLOG(3) << "Schedule node -> " << node->id() << " var : " << node_tensor->name; - if (group->master_nodes.count(node)) { - continue; - } - - if (IsConstOp(node) && !group->output_nodes.count(node)) { - ir_sch.ComputeInline(ir_sch.GetBlock(node_tensor->name)); - continue; - } - - // if node is fringe node or internal node, fringe node is output node of sub-graph - if (group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node)) { - // internal node use buffer - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - - auto node_block = ir_sch.GetBlock(node_tensor->name); - auto master_loops = 
ir_sch.GetLoops(manster_tensor->name); - ir_sch.SimpleComputeAt(node_block, master_loops.back()); - continue; - } - - // others elemenwise internal node use compute-inline - ir_sch.ComputeInline(ir_sch.GetBlock(node_tensor->name)); - } -} - std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, std::vector& func_args, std::unordered_map& tensor_map, @@ -438,645 +388,6 @@ std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, return ast_exprs; } -void OpLowerer::IRReduceSchedule(ir::IRSchedule& ir_sch, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - Node*& master, - Node*& reducer) { - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - auto OrderAssignReduce = [this](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& axes, - const bool just_reorder = false) { - // reorder none-last reduce axis to last. - // like: shape = [16,16,16,16,16],axes = [1,3] -> new order = [0, 2, 4, 1, 3]. - std::vector order; - int n_out_dims = ir_sch.GetLoops(block_name).size(); - for (int idx = 0; idx < n_out_dims; ++idx) { - if (std::find(axes.begin(), axes.end(), idx) == axes.end()) { - order.push_back(idx); - } - } - for (auto axis : axes) { - order.push_back(axis); - } - ir_sch.Reorder(ir_sch.GetBlock(block_name), order); - - if (just_reorder) { - return; - } - // fuse others none-reduce axis. - int last_dimension_num = n_out_dims - axes.back() - 1; - int index = n_out_dims - last_dimension_num - axes.size(); - - // fuse last_dimension_num - 1 times - for (auto idx = index; idx < index + last_dimension_num - 1; ++idx) { - ir_sch.Fuse(block_name, {index, index + 1}); - } - - auto loops = ir_sch.GetLoops(block_name); - auto psize = ir::GetLoopExtent(loops[index]); - if (psize > this->target_.max_num_threads()) { - for (int idx = this->target_.max_num_threads(); idx > 0; --idx) { - if (psize % idx == 0) { - ir_sch.Split(loops[index], {-1, idx}); - break; - } - CHECK_GT(idx, 1); - } - } - - // fuse index - 1 times - for (int idx = 0; idx < index - 1; ++idx) { - ir_sch.Fuse(block_name, {0, 1}); - } - }; - - auto WithoutLastDimInReduce = [](const std::vector& inshape, std::vector& axes) { - // if last axis is in reduce. - axes = axes.empty() ? inshape : axes; - if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || - std::find(axes.begin(), axes.end(), -1) != axes.end()) { - return false; - } - - int sum_last_axes = 1; - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - sum_last_axes *= inshape[idx]; - } - - if (sum_last_axes > 1) { - return true; - } else { - return false; - } - }; - - auto ScheduleAssignReduceWithoutLast = [this, OrderAssignReduce](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& inshape, - std::vector& axes) { - axes = axes.empty() ? 
inshape : axes; - int lane = 1; - int max_num_threads = this->target_.max_num_threads(); - for (int idx = axes.back() + 1; idx < inshape.size(); ++idx) { - lane *= inshape[idx]; - } - CHECK_LE(lane, max_num_threads / 2) << "Parallel threads must less equal max_num_threads/2 on gpu!"; - int pos = 0; - int index = axes.size() - 1; - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - pos = axes[index + 1]; - break; - } - - lane *= inshape[axes[index]]; - if (lane > max_num_threads / 2) { - pos = axes[index]; - break; - } - - if (index == 0) { - pos = axes[0]; - } - } - - if (lane > max_num_threads / 2) { - int prefix = inshape[axes[index]]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - ir_sch.Split(block_name, axes[index], {-1, idx}); - break; - } - CHECK_GT(idx - 1, (max_num_threads / 2) / tail) << "idx should greater than (max_num_threads / 2) / tail."; - } - } - - // insert 1 - for (int idx = 0; idx < axes.size() - 1 - index; ++idx) { - auto loops = ir_sch.GetLoops(block_name); - ir_sch.Split(block_name, pos, {-1, ir::GetLoopExtent(loops[pos])}); - } - OrderAssignReduce(ir_sch, block_name, axes); - // return insert 1 - int start_index = ir_sch.GetLoops(block_name).size() - axes.size(); - for (int idx = 0; idx < axes.size(); ++idx) { - auto loops = ir_sch.GetLoops(block_name); - if (ir::GetLoopExtent(loops[start_index]) == 1) { - ir_sch.Fuse({loops[start_index - 1], loops[start_index]}); - } else { - ++start_index; - } - } - }; - - auto ScheduleAssignReduceWithLast = [this, OrderAssignReduce](ir::IRSchedule& ir_sch, - const std::string& block_name, - const std::vector& inshape, - std::vector& axes) { - // find first reduce and second reduce axis. - axes = axes.empty() ? inshape : axes; - int lane = 1; - int index = static_cast(axes.size()) - 1; - auto max_num_threads = this->target_.max_num_threads(); - for (; index >= 0; --index) { - if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { - break; - } - lane *= inshape[axes[index]]; - if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) << "Error! lane is less equal than max_num_threads, Please check!"; - } - if (lane >= max_num_threads / 2) { - if (lane <= max_num_threads) { - --index; - } - break; - } - } - std::vector first_axes(axes.begin(), axes.begin() + index + 1); - if (lane > max_num_threads) { - // last reduce axis size > 1024 - if (index == static_cast(axes.size()) - 1) { - int idx = max_num_threads; - do { - if (lane % idx == 0) { - ir_sch.Split(block_name, axes[index], {-1, idx}); - break; - } - --idx; - } while (idx >= max_num_threads / 2); - // if can't be divide by(1024, 512), it's shouldn't be fused. - CHECK_GE(idx, max_num_threads / 2) << "Check bounds exist, can't fuse!"; - } else { - int axis = axes[index]; - int prefix = inshape[axis]; - int tail = lane / prefix; - for (int idx = max_num_threads / tail; idx > (max_num_threads / 2) / tail; --idx) { - if (prefix % idx == 0) { - ir_sch.Split(block_name, axis, {-1, idx}); - break; - } - CHECK_GT(idx, (max_num_threads / 2) / tail) << "Error, it's shouldn't fuse!"; - } - } - OrderAssignReduce(ir_sch, block_name, first_axes); - } else { - int fuse_times = axes.size() - (index + 1) - 1; - for (int idx = 0; idx < fuse_times; ++idx) { - ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); - } - OrderAssignReduce(ir_sch, block_name, first_axes, true); - // fuse axis before reduce to bind blockidx. 
- for (int idx = 0; idx < (inshape.size() - axes.size()) - 1; ++idx) { - ir_sch.Fuse(block_name, {0, 1}); - } - } - }; - - if (master == nullptr && reducer == nullptr) { - auto blocks = ir_sch.GetAllBlocks(); - for (int idx = blocks.size() - 1; idx >= 0; --idx) { - auto block = blocks[idx]; - CHECK(block->as()); - CHECK(block->as()->schedule_block->as()); - if (!tensor_map.count(block->as()->schedule_block->as()->name)) { - continue; - } - - for (auto node : group->master_nodes) { - if (GetNodeData(node)->id() == - block->as()->schedule_block->as()->name) { - if (op_pattern_dict[node->op()] != framework::kReduction) { - master = node; - break; - } - - if (op_pattern_dict[node->op()] == framework::kReduction && master) { - reducer = node; - break; - } - } - } - - if (master && reducer) { - break; - } - } - CHECK((master && reducer) || (!master && !reducer)) << "Can't find Master reducer!"; - if (!master && !reducer) { - master = *group->master_nodes.begin(); - reducer = *group->master_nodes.begin(); - } - - // do master schedule. - if (op_pattern_dict[master->op()] != framework::kReduction) { - VLOG(2) << "Do Master Schedule : " << master->id(); - auto master_data = GetNodeData(master); - CHECK(master_data); - CHECK(tensor_map.count(master_data->id())); - auto master_tensor = tensor_map[master_data->id()]; - auto loops = ir_sch.GetLoops(master_tensor->name); - if (op_pattern_dict[master->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - - auto reducer_data = GetNodeData(reducer); - auto reducer_tensor = tensor_map[reducer_data->id()]; - auto rloops = ir_sch.GetLoops(reducer_tensor->name); - - // assign master loops to reducer loops without reduce axis. - int extend = 1; - std::vector factors; - auto sloops = ir_sch.GetLoops(master_tensor->name); - for (auto& loop : rloops) { - // without last reduce axis, so check loop extend. - extend *= loop.As()->extent.as_int32(); - if (extend > sloops.back().As()->extent.as_int32()) { - break; - } - CHECK_LE(extend, sloops.back().As()->extent.as_int32()); - factors.push_back(loop.As()->extent.as_int32()); - } - ir_sch.Split(sloops.back(), factors); - - auto nloops = ir_sch.GetLoops(master_tensor->name); - CHECK_GE(rloops.size(), nloops.size()); - for (int idx = 0; idx < nloops.size(); ++idx) { - nloops[idx].As()->set_bind_info(rloops[idx].As()->bind_info()); - } - } - // do reducer schedule. 
- { - auto reducer_data = GetNodeData(reducer); - auto reducer_tensor = tensor_map[reducer_data->id()]; - CHECK(reducer->attrs.attr_store.count("dim")); - auto reducer_axes = absl::get>(reducer->attrs.attr_store.at("dim")); - CHECK(reducer->inlinks_in_order().size()); - CHECK(this->shape_dict_.count(reducer->inlinks_in_order()[0]->source()->id())); - auto reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - - if (reducer_axes.empty()) { - for (int i = 0; i < reducer_shape.size(); ++i) { - reducer_axes.emplace_back(i); - } - } - - bool without_last_dim = WithoutLastDimInReduce(reducer_shape, reducer_axes); - - std::unordered_set visited_nodes; - for (auto node : group->master_nodes) { - VLOG(2) << "Schedule reduce node -> " << node->id(); - if (op_pattern_dict[node->op()] != framework::kReduction) { - continue; - } - auto node_data = GetNodeData(node); - auto node_tensor = tensor_map[node_data->id()]; - - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - if (node == reducer) { - continue; - } - auto node_shape = this->shape_dict_.at(node->inlinks_in_order()[0]->source()->id()); - if (without_last_dim) { - VLOG(2) << "Reduce Schedule WithoutLastDimInReduce"; - // find a shape to do simple compute at. - auto tmp_reducer = reducer; - auto tmp_reducer_shape = reducer_shape; - if (node_shape != reducer_shape) { - // try to find the same shape reduce from visited_nodes - for (auto visited : visited_nodes) { - auto shape = this->shape_dict_.at(visited->inlinks_in_order()[0]->source()->id()); - if (shape == node_shape) { - tmp_reducer = visited; - tmp_reducer_shape = shape; - break; - } - } - } - visited_nodes.insert(node); - auto tmp_reducer_data = GetNodeData(tmp_reducer); - auto tmp_reducer_tensor = tensor_map[tmp_reducer_data->id()]; - - // using block shuffle reduce. 
- if (tensor_map.count(reducer_data->id() + "_1")) { - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - auto node_0_block = ir_sch.GetBlock(node_0_tensor->name); - - auto tmp_reducer_0_tensor = tensor_map[tmp_reducer_data->id() + "_0"]; - auto tmp_reducer_0_loops = ir_sch.GetLoops(tmp_reducer_0_tensor->name); - - if (tmp_reducer_shape == node_shape) { - ir_sch.SimpleComputeAt(node_0_block, tmp_reducer_0_loops.back()); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_0_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_0_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_0_tensor->name)[loop_depth - 1]); - } else { - if (tmp_reducer_0_tensor->shape.back() == node_0_tensor->shape.back()) { - int num_reduce_axis = tmp_reducer_0_tensor->reduce_axis.size(); - CHECK_GE(static_cast(tmp_reducer_0_loops.size()) - num_reduce_axis - 1, 0); - ir_sch.SimpleComputeAt(node_0_block, - tmp_reducer_0_loops[tmp_reducer_0_loops.size() - num_reduce_axis - 1]); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_0_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_0_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_0_tensor->name)[loop_depth - 1]); - } else { - CHECK_GE(static_cast(tmp_reducer_0_loops.size()), 2); - ir_sch.SimpleComputeAt(node_0_block, tmp_reducer_0_loops[0]); - } - } - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } else { - if (tmp_reducer_shape == node_shape) { - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } else { - int num_reduce_axis = tmp_reducer_tensor->reduce_axis.size(); - auto tmp_reducer_loops = ir_sch.GetLoops(tmp_reducer_tensor->name); - CHECK_GE(static_cast(tmp_reducer_loops.size()) - num_reduce_axis - 1, 0); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name), - tmp_reducer_loops[tmp_reducer_loops.size() - num_reduce_axis - 1]); - } - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_tensor->name)[loop_depth - 1]); - } - } else { - VLOG(2) << "Reduce Schedule WithLastDimInReduce"; - // if with column reduce behind. 
- if (tensor_map.count(node_data->id() + "_1")) { - auto reducer_1_tensor = tensor_map[reducer_data->id() + "_1"]; - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - - auto node_1_tensor = tensor_map[node_data->id() + "_1"]; - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - - auto node_block_1 = ir_sch.GetBlock(node_1_tensor->name); - auto node_block_0 = ir_sch.GetBlock(node_0_tensor->name); - auto node_block = ir_sch.GetBlock(node_tensor->name); - - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(reducer_tensor->name).back()); - ir_sch.SimpleComputeAt(node_block_0, ir_sch.GetLoops(reducer_0_tensor->name).back()); - ir_sch.SimpleComputeAt(node_block_1, ir_sch.GetLoops(reducer_1_tensor->name).back()); - // init compute at reduce - int loop_depth = ir_sch.GetLoops(node_1_tensor->name + "__reduce_init").size(); - ir_sch.SimpleComputeAt(ir_sch.GetBlock(node_1_tensor->name + "__reduce_init"), - ir_sch.GetLoops(node_1_tensor->name)[loop_depth - 1]); - } else if (tensor_map.count(node_data->id() + "_0")) { - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - auto node_0_tensor = tensor_map[node_data->id() + "_0"]; - - auto node_0_block = ir_sch.GetBlock(node_0_tensor->name); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(reducer_tensor->name).back()); - ir_sch.SimpleComputeAt(node_0_block, ir_sch.GetLoops(reducer_0_tensor->name).back()); - } else { - LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; - } - } - } - - if (without_last_dim) { - if (tensor_map.count(reducer_data->id() + "_1")) { - auto reducer_tensor = tensor_map[GetNodeData(reducer)->id()]; - auto reducer_loops = ir_sch.GetLoops(reducer_tensor->name); - ir_sch.SyncThreads(reducer_loops[0], false); - } - } - } - } - - // master node - auto master_data = GetNodeData(master); - CHECK(master_data); - CHECK(tensor_map.count(master_data->id())); - auto master_tensor = tensor_map[master_data->id()]; - auto master_shape = this->shape_dict_.at(master_data->id()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); - - // reducer node - auto reducer_data = GetNodeData(reducer); - CHECK(reducer_data); - CHECK(reducer->inlinks_in_order().size()); - CHECK(this->shape_dict_.count(reducer->inlinks_in_order()[0]->source()->id())); - auto reducer_shape = this->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); - auto reduce_size = std::accumulate(reducer_shape.begin(), reducer_shape.end(), 1, std::multiplies()); - - CHECK(reducer->attrs.attr_store.count("dim")); - auto reducer_axes = absl::get>(reducer->attrs.attr_store.at("dim")); - if (reducer_axes.empty()) { - for (int i = 0; i < reducer_shape.size(); ++i) { - reducer_axes.emplace_back(i); - } - } - - VLOG(2) << "master node : " << master->id() << " ,reducer node : " << reducer->id(); - for (int idx = sub_group->nodes.size() - 1; idx >= 0; --idx) { - auto node = sub_group->nodes[idx]; - - if (node == master) { - continue; - } - if (op_pattern_dict[node->op()] == framework::kReduction) { - continue; - } - auto node_data = GetNodeData(node); - auto node_tensor = tensor_map[node_data->id()]; - - VLOG(3) << "Schedule node -> " << node->id() << " var : " << node_tensor->name; - // for x86 schedule. 
- if (this->target_ == common::DefaultHostTarget()) { - LOG(FATAL) << "X86 Not implemented"; - } - - bool dont_compute_inline = - group->output_nodes.count(node) || group->internal_nodes.count(node) || sub_group->internal_nodes.count(node); - if (!dont_compute_inline) { - auto consumers = GetConsumers(node); - for (auto& consumer : consumers) { - if (op_pattern_dict[consumer->op()] == framework::kReduction) { - dont_compute_inline = true; - break; - } - } - } - - // if is const op, do compute inline. - if (IsConstOp(node) && !group->output_nodes.count(node)) { - dont_compute_inline = false; - } - - // if node is internal node or output, try to copy schedule from fellow node - if (dont_compute_inline) { - VLOG(2) << "Reduce Schedule for Elementwise Type"; - // if node is not output node, set buffer. - if (!group->output_nodes.count(node)) { - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SetBuffer(node_block, "local", true); - } - // node is after reduce - auto node_shape = this->shape_dict_.at(node_data->id()); - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - if (node_shape == master_shape || node_size == master_size) { - VLOG(2) << "Do Elementwise Type After Reduce!"; - auto loops = ir_sch.GetLoops(node_tensor->name); - // flat loop and tensor shape - if (op_pattern_dict[master->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - // split loop to assign master loop - std::vector factors; - auto mloops = ir_sch.GetLoops(master_tensor->name); - for (auto& loop : mloops) { - factors.push_back(loop.As()->extent.as_int32()); - } - loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(loops.back(), factors); - // note do simple compute at - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, mloops.back()); - continue; - } - // do elementwise flat - auto loops = ir_sch.GetLoops(node_tensor->name); - if (op_pattern_dict[node->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - // node is before reduce. - if (WithoutLastDimInReduce(reducer_shape, reducer_axes)) { - VLOG(2) << "Reduce Schedule for WithoutLastDimInReduce"; - // find a shape to do simple compute at. - auto tmp_reducer = reducer; - auto tmp_reducer_shape = reducer_shape; - auto tmp_reducer_size = std::accumulate(reducer_shape.begin(), reducer_shape.end(), 1, std::multiplies()); - // node shape. 
- auto node_shape = this->shape_dict_.at(node_data->id()); - if (node_shape != tmp_reducer_shape && node_size != reduce_size) { - // try to find the same shape reduce from visited_nodes - for (auto rnode : group->master_nodes) { - if (op_pattern_dict[rnode->op()] != framework::kReduction) { - continue; - } - auto shape = this->shape_dict_.at(rnode->inlinks_in_order()[0]->source()->id()); - auto size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - if (shape == node_shape || size == node_size) { - tmp_reducer = rnode; - tmp_reducer_size = size; - tmp_reducer_shape = shape; - break; - } - } - } - // do split - CHECK(node_shape == tmp_reducer_shape || node_size == tmp_reducer_size); - - auto loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(loops.back(), tmp_reducer_shape); - - auto tmp_reducer_data = GetNodeData(tmp_reducer); - auto tmp_reducer_tensor = tensor_map[tmp_reducer_data->id()]; - // if used block shuffle reduce - if (tensor_map.count(tmp_reducer_data->id() + "_1")) { - ScheduleAssignReduceWithoutLast(ir_sch, node_tensor->name, tmp_reducer_shape, reducer_axes); - auto tmp_reducer_tensor_0 = tensor_map[tmp_reducer_data->id() + "_0"]; - auto tmp_reducer_loops_0 = ir_sch.GetLoops(tmp_reducer_tensor_0->name); - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < tmp_reducer_loops_0.size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), tmp_reducer_loops_0.size()) - << "node loops and reduce loops must be equal!"; - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, tmp_reducer_loops_0.back()); - } else { - OrderAssignReduce(ir_sch, node_tensor->name, reducer_axes); - - auto node_block = ir_sch.GetBlock(node_tensor->name); - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < ir_sch.GetLoops(tmp_reducer_tensor->name).size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), ir_sch.GetLoops(tmp_reducer_tensor->name).size()) - << "node loop size and reduce loop size must be equal!"; - ir_sch.SimpleComputeAt(node_block, ir_sch.GetLoops(tmp_reducer_tensor->name).back()); - } - } else { - VLOG(2) << "Reduce Schedule for WithLastDimInReduce"; - if (tensor_map.count(reducer_data->id() + "_1")) { - { - auto node_loops = ir_sch.GetLoops(node_tensor->name); - ir_sch.Split(node_loops.back(), reducer_shape); - } - - ScheduleAssignReduceWithLast(ir_sch, node_tensor->name, reducer_shape, reducer_axes); - auto reducer_1_tensor = tensor_map[reducer_data->id() + "_1"]; - auto reducer_1_block = ir_sch.GetBlock(reducer_1_tensor->name); - auto reducer_1_loops = ir_sch.GetLoops(reducer_1_block); - - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (ir_sch.GetLoops(node_tensor->name).size() < ir_sch.GetLoops(reducer_1_block).size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), ir_sch.GetLoops(reducer_1_block).size()) - << "node loop size and reduce loop size must be equal!" 
<< ir_sch.GetModule().GetExprs().at(0); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, reducer_1_loops.back()); - } else { - auto reducer_0_tensor = tensor_map[reducer_data->id() + "_0"]; - auto reducer_0_block = ir_sch.GetBlock(reducer_0_tensor->name); - auto reducer_0_loops = ir_sch.GetLoops(reducer_0_block); - { - auto node_loops = ir_sch.GetLoops(node_tensor->name); - std::vector factors; - for (auto& loop : reducer_0_loops) { - factors.push_back(loop.As()->extent.as_int32()); - } - ir_sch.Split(node_loops.back(), factors); - } - - auto node_loops = ir_sch.GetLoops(node_tensor->name); - if (node_loops.size() < reducer_0_loops.size()) { - ir_sch.Split(node_tensor->name, 0, {-1, ir::GetLoopExtent(node_loops[0])}); - } - CHECK_EQ(ir_sch.GetLoops(node_tensor->name).size(), reducer_0_loops.size()) - << "node loop size and reduce loop size must be equal!" << ir_sch.GetModule().GetExprs().at(0); - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.SimpleComputeAt(node_block, reducer_0_loops.back()); - } - } - continue; - } - - // others elemenwise internal node use compute-inline - VLOG(2) << "Do Elementwise ComputeInline!"; - auto loops = ir_sch.GetLoops(node_tensor->name); - if (op_pattern_dict[node->op()] == framework::kElementWise) { - ir_sch.FlattenLoops(loops, true); - } else { - ir_sch.FlattenLoops(loops, false); - } - auto node_block = ir_sch.GetBlock(node_tensor->name); - ir_sch.ComputeInline(node_block); - } -} - std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, bool apply_impl_schedule) { VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; // get input tensor and output tensor @@ -1201,7 +512,7 @@ std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, boo } } -// do compute +// group schedule void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map) { diff --git a/cinn/hlir/framework/op_lowering.h b/cinn/hlir/framework/op_lowering.h old mode 100755 new mode 100644 index 6e291afeb6..072f296b91 --- a/cinn/hlir/framework/op_lowering.h +++ b/cinn/hlir/framework/op_lowering.h @@ -45,12 +45,6 @@ typedef std::vector (OpLowerer::*IRComputeFunction)(poly::StageMap&, const GroupPtr&, const GroupPtr&, bool); -typedef void (OpLowerer::*IRScheduleFunction)(ir::IRSchedule& ir_sch, - std::unordered_map&, - const GroupPtr&, - const GroupPtr&, - Node*&, - Node*&); class OpLowerer { public: @@ -61,27 +55,21 @@ class OpLowerer { std::vector LowerWithoutSchedule(GroupPtr& group); private: - std::vector IRLowerOp(IRComputeFunction, IRScheduleFunction, GroupPtr&); + std::vector IRLowerOp(IRComputeFunction, GroupPtr&); std::vector IRLowerNonFusibleOp(GroupPtr&, bool); std::vector IRLowerOpWithoutSchedule(IRComputeFunction, GroupPtr&); -#define DEFINE_IR_COMPUTE_SCHDULE(type) \ +#define DEFINE_IR_COMPUTE(type) \ std::vector IR##type##Compute(poly::StageMap& stages, \ std::vector& func_args, \ std::unordered_map& tensor_map, \ const GroupPtr& group, \ const GroupPtr& sub_group, \ - bool apply_impl_schedule = false); \ - void IR##type##Schedule(ir::IRSchedule& ir_sch, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group, \ - Node*& first, \ - Node*& second); + bool apply_impl_schedule = false); // compute and schedule - DEFINE_IR_COMPUTE_SCHDULE(Elementwise); - DEFINE_IR_COMPUTE_SCHDULE(Reduce); - DEFINE_IR_COMPUTE_SCHDULE(OutEWiseFusable); + DEFINE_IR_COMPUTE(Elementwise); + DEFINE_IR_COMPUTE(Reduce); + 
DEFINE_IR_COMPUTE(OutEWiseFusable); void IRSchedule(ir::IRSchedule& ir_sch, const GroupPtr& group, From fbac9612c39858742d2206a53ff37acf355fce39 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 20 Jun 2023 20:05:37 +0800 Subject: [PATCH 02/11] merge IRLowerOp to LowerGroup --- cinn/hlir/framework/op_lowering.cc | 98 ++++++++++++++++++++++++++++-- cinn/hlir/framework/op_lowering.h | 3 + 2 files changed, 97 insertions(+), 4 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 96e7a45fb6..b30096fe30 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -54,9 +54,9 @@ std::vector OpLowerer::Lower(GroupPtr& group) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOp(&OpLowerer::IRElementwiseCompute, group); + return LowerGroup(&OpLowerer::IRElementwiseCompute, group, true); case framework::kReduction: - return IRLowerOp(&OpLowerer::IRReduceCompute, group); + return LowerGroup(&OpLowerer::IRReduceCompute, group, true); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: @@ -76,9 +76,9 @@ std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return IRLowerOpWithoutSchedule(&OpLowerer::IRElementwiseCompute, group); + return LowerGroup(&OpLowerer::IRElementwiseCompute, group, false); case framework::kReduction: - return IRLowerOpWithoutSchedule(&OpLowerer::IRReduceCompute, group); + return LowerGroup(&OpLowerer::IRReduceCompute, group, false); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: @@ -91,6 +91,96 @@ std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { } } +std::vector OpLowerer::LowerGroup(IRComputeFunction compute, + GroupPtr& group, + bool with_group_schedule) { + poly::StageMap stages; + std::vector arg_tensors; + std::unordered_map tensor_map; + // do compute. + VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); + std::vector ast_exprs; + if (group->fused_sub_groups.size() == 0) { + ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ false); + } else { + for (auto& sub_group : group->fused_sub_groups) { + auto exprs = + (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ false); + ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); + } + } + ir::ModuleExpr mod_expr(ast_exprs); + ir::IRSchedule ir_sch(mod_expr); + ir_sch.MergeExprs(); + VLOG(3) << "Before group lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + if (with_group_schedule) { + // do schedule. + IRSchedule(ir_sch, group, tensor_map); + VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + } + + // function args + group->input_names.clear(); + std::vector func_args; + for (auto& args : arg_tensors) { + // input node data name. + group->input_names.push_back(args->name); + // input args + func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); + } + + group->output_names.clear(); + for (auto& node : group->output_nodes) { + // output node data name. + for (auto node_data : GetAllNodeData(node)) { + group->output_names.push_back(node_data->id()); + } + // collect all output tensor. 
+ std::string post = ""; + std::string prefix = GetNodeData(node)->id(); + for (int idx = 0; idx < 1; ++idx) { + CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; + if (!tensor_map.count(prefix + post)) { + break; + } + auto tensor = tensor_map[prefix + post]; + arg_tensors.push_back(tensor); + // output args + func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); + // update post + post = "_" + std::to_string(idx); + } + } + + std::unordered_set args_map; + for (auto arg : func_args) { + args_map.insert(arg.name()); + } + + for (auto& tensor : tensor_map) { + if (args_map.count("_" + tensor.first)) { + continue; + } + arg_tensors.push_back(tensor.second); + // use the underlying tensor name to be consistent with the argument name in the lowered function + group->output_names.push_back(tensor.second->name); + func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput); + } + + auto func_body = ir_sch.GetModule().GetExprs().at(0); +#ifdef CINN_WITH_CUDA + optim::OptimizeExprGPU(&(func_body)); +#endif + + auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); + auto func = + ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); + func->PrepareBufferCastExprs(); + func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); + + return {func}; +} + std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, GroupPtr& group) { poly::StageMap stages; std::vector arg_tensors; diff --git a/cinn/hlir/framework/op_lowering.h b/cinn/hlir/framework/op_lowering.h index 072f296b91..c947c982f3 100644 --- a/cinn/hlir/framework/op_lowering.h +++ b/cinn/hlir/framework/op_lowering.h @@ -58,6 +58,9 @@ class OpLowerer { std::vector IRLowerOp(IRComputeFunction, GroupPtr&); std::vector IRLowerNonFusibleOp(GroupPtr&, bool); std::vector IRLowerOpWithoutSchedule(IRComputeFunction, GroupPtr&); + + std::vector LowerGroup(IRComputeFunction, GroupPtr&, bool); + #define DEFINE_IR_COMPUTE(type) \ std::vector IR##type##Compute(poly::StageMap& stages, \ std::vector& func_args, \ From 51928663756ffe6bf12e74fe85a4ecba15040df4 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Mon, 26 Jun 2023 11:25:32 +0800 Subject: [PATCH 03/11] Organize the code for OpLowerer --- .../auto_gen_rule/auto_inline_test.cc | 2 +- .../search_space/auto_gen_rule/test_helper.cc | 6 +- cinn/auto_schedule/task/tune_task.cc | 2 +- .../tests/performance_comparison_test.cc | 2 +- cinn/hlir/framework/op_lowering.cc | 669 ++++++------------ cinn/hlir/framework/op_lowering.h | 75 +- cinn/hlir/framework/op_lowering_util.cc | 10 +- cinn/hlir/framework/op_lowering_util.h | 4 +- 8 files changed, 258 insertions(+), 512 deletions(-) diff --git a/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc b/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc index a8d8ee9f9d..c984c7d5ba 100644 --- a/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc +++ b/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc @@ -156,7 +156,7 @@ TEST(AutoInline, AddReluInline) { auto op_lowerer = std::make_unique(dtype_dict, shape_dict, target); EXPECT_EQ(graph->fusion_groups.size(), 1UL); - std::vector funcs = op_lowerer->LowerWithoutSchedule(graph->fusion_groups[0]); + std::vector funcs = op_lowerer->Lower(graph->fusion_groups[0], false, false); VLOG(6) << "Expr before auto inline: " << funcs[0]->body; diff --git a/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc 
b/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc index 9ad001a23b..699f57137d 100644 --- a/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc +++ b/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc @@ -58,11 +58,7 @@ ir::IRSchedule TestAutoGenRuleBase::MakeIRSchedule(const frontend::Program& test auto& shape_dict = graph->GetMutableAttrs>("infershape"); hlir::framework::OpLowerer op_lowerer(dtype_dict, shape_dict, target_); - if (apply_manual_schedule) { - lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front()); - } else { - lowered_funcs_ = op_lowerer.LowerWithoutSchedule(graph->fusion_groups.front()); - } + lowered_funcs_ = op_lowerer.Lower(graph->fusion_groups.front(), apply_manual_schedule, apply_manual_schedule); CHECK(!lowered_funcs_.empty()) << "lowered_funcs_ is empty"; std::vector bodys; diff --git a/cinn/auto_schedule/task/tune_task.cc b/cinn/auto_schedule/task/tune_task.cc index 80998c3825..baeabc6346 100644 --- a/cinn/auto_schedule/task/tune_task.cc +++ b/cinn/auto_schedule/task/tune_task.cc @@ -37,7 +37,7 @@ void TuneTask::Initialize(const absl::flat_hash_maplowered_funcs = op_lowerer->LowerWithoutSchedule(subgraph); + this->lowered_funcs = op_lowerer->Lower(subgraph, false, false); this->output_names = GetOutputNamesFromLoweredFunc(this->lowered_funcs); this->serialized_key = SerializeToString(shape_dict, dtype_dict); } diff --git a/cinn/auto_schedule/tests/performance_comparison_test.cc b/cinn/auto_schedule/tests/performance_comparison_test.cc index 35a1e58063..baae8a1c44 100644 --- a/cinn/auto_schedule/tests/performance_comparison_test.cc +++ b/cinn/auto_schedule/tests/performance_comparison_test.cc @@ -130,7 +130,7 @@ class PerformanceTester : public ::testing::Test { compile_options.groups = graph->fusion_groups; for (auto group : graph->fusion_groups) { - compile_options.lowered_funcs.push_back(op_lowerer->LowerWithoutSchedule(group)); + compile_options.lowered_funcs.push_back(op_lowerer->Lower(group, false, false)); } VLOG(3) << "===========================No Schedule LoweredFunc Begin==========================="; diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index b30096fe30..1c86fc4e76 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -45,7 +45,7 @@ OpLowerer::OpLowerer(const absl::flat_hash_map& type_dict, const Target& target) : type_dict_(type_dict), shape_dict_(shape_dict), target_(target) {} -std::vector OpLowerer::Lower(GroupPtr& group) { +std::vector OpLowerer::Lower(GroupPtr& group, bool apply_op_schedule, bool apply_group_schedule) { VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; group->input_names.clear(); group->output_names.clear(); @@ -54,13 +54,15 @@ std::vector OpLowerer::Lower(GroupPtr& group) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: - return LowerGroup(&OpLowerer::IRElementwiseCompute, group, true); + return LowerGroup( + group, apply_op_schedule, apply_group_schedule, &OpLowerer::ElementwiseScheduleDetermineFunction); case framework::kReduction: - return LowerGroup(&OpLowerer::IRReduceCompute, group, true); + return LowerGroup(group, apply_op_schedule, apply_group_schedule, &OpLowerer::ReduceScheduleDetermineFunction); case framework::kOutFusible: LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; case framework::kNonFusible: - return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true); + return LowerGroup( 
+ group, apply_op_schedule, apply_group_schedule, &OpLowerer::NonFusibleScheduleDetermineFunction); default: LOG(FATAL) << "Group Pattern Kind Is Unknown!"; } @@ -69,543 +71,275 @@ std::vector OpLowerer::Lower(GroupPtr& group) { } } -std::vector OpLowerer::LowerWithoutSchedule(GroupPtr& group) { - VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind; - if (FLAGS_cinn_ir_schedule) { - switch (group->op_pattern_kind) { - case framework::kElementWise: - case framework::kBroadcast: - case framework::kInjective: - return LowerGroup(&OpLowerer::IRElementwiseCompute, group, false); - case framework::kReduction: - return LowerGroup(&OpLowerer::IRReduceCompute, group, false); - case framework::kOutFusible: - LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; - case framework::kNonFusible: - return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ false); - default: - LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!"; - } - } else { - LOG(FATAL) << "Previous IR Schedule Is Not Implemented!"; +bool OpLowerer::ElementwiseScheduleDetermineFunction(Node* node) { return true; } + +bool OpLowerer::ReduceScheduleDetermineFunction(Node* node) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + if (op_pattern_dict[node->op()] == framework::kReduction) { + return true; } + return false; } -std::vector OpLowerer::LowerGroup(IRComputeFunction compute, - GroupPtr& group, - bool with_group_schedule) { - poly::StageMap stages; - std::vector arg_tensors; - std::unordered_map tensor_map; - // do compute. +bool OpLowerer::NonFusibleScheduleDetermineFunction(Node* node) { return true; } + +std::vector OpLowerer::LowerGroup(GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + ScheduleDetermineFunction schedule_determine_func) { + // 1.Do compute, lower and schedule for each op. VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); - std::vector ast_exprs; - if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ false); - } else { - for (auto& sub_group : group->fused_sub_groups) { - auto exprs = - (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ false); - ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); - } + std::vector nodes = group->CollectNodes(); + if (nodes.size() == 1 && nodes[0]->op()->name == "custom_call") { + return LowerCustomCall(group); } - ir::ModuleExpr mod_expr(ast_exprs); + std::vector group_func_arg_tensors; + std::unordered_map tensor_map; + bool do_op_schedule = apply_group_schedule || apply_op_schedule; + std::vector func_bodies = + LowerOps(nodes, &group_func_arg_tensors, &tensor_map, do_op_schedule, schedule_determine_func); + + // 2.Do group schedule. + ir::ModuleExpr mod_expr(func_bodies); ir::IRSchedule ir_sch(mod_expr); ir_sch.MergeExprs(); - VLOG(3) << "Before group lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - if (with_group_schedule) { - // do schedule. 
- IRSchedule(ir_sch, group, tensor_map); + VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + if (apply_group_schedule && nodes.size() > 1) { + DoGroupSchedule(ir_sch, group, tensor_map); VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); } - // function args - group->input_names.clear(); - std::vector func_args; - for (auto& args : arg_tensors) { - // input node data name. - group->input_names.push_back(args->name); - // input args - func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); - } + // 3.Do post-processing, + // including preparing function args and temporary variables, + // applying low-level optimization passes, etc. + return PostProcess(&ir_sch, group, tensor_map, &group_func_arg_tensors, do_op_schedule); +} - group->output_names.clear(); - for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } - // collect all output tensor. - std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0; idx < 1; ++idx) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - if (!tensor_map.count(prefix + post)) { - break; - } - auto tensor = tensor_map[prefix + post]; - arg_tensors.push_back(tensor); - // output args - func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); - // update post - post = "_" + std::to_string(idx); +std::vector OpLowerer::LowerCustomCall(GroupPtr& group) { + std::vector nodes = group->CollectNodes(); + CHECK_EQ(nodes.size(), 1); + Node* node = nodes[0]; + std::vector op_func_arg_tensors; + std::unordered_map tensor_map; + for (auto& node_data : GetInputNodeData(node)) { + CHECK(node_data); + ir::Tensor tensor; + if (!tensor_map.count(node_data->id())) { + tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_); + // record tensor. + tensor_map[node_data->id()] = tensor; + // input name. 
+ group->input_names.push_back(node_data->id()); + } else { + tensor = tensor_map[node_data->id()]; } + op_func_arg_tensors.push_back(tensor); } - std::unordered_set args_map; - for (auto arg : func_args) { - args_map.insert(arg.name()); - } - - for (auto& tensor : tensor_map) { - if (args_map.count("_" + tensor.first)) { - continue; - } - arg_tensors.push_back(tensor.second); - // use the underlying tensor name to be consistent with the argument name in the lowered function - group->output_names.push_back(tensor.second->name); - func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput); + std::vector out_types; + std::vector> out_shapes; + auto node_datas = GetAllNodeData(node); + for (auto node_data : node_datas) { + group->output_names.push_back(node_data->id()); + out_types.push_back(this->type_dict_.at(node_data->id())); + out_shapes.push_back(this->shape_dict_.at(node_data->id())); } - - auto func_body = ir_sch.GetModule().GetExprs().at(0); -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); -#endif - - auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = - ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); - func->PrepareBufferCastExprs(); - func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - - return {func}; -} - -std::vector OpLowerer::IRLowerOp(IRComputeFunction compute, GroupPtr& group) { - poly::StageMap stages; - std::vector arg_tensors; - std::unordered_map tensor_map; - // do compute. - VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); - std::vector ast_exprs; - if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ true); + auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); + auto impl = OpStrategy::SelectImpl( + cinn_strategy[node->op()](node->attrs, op_func_arg_tensors, out_types, out_shapes, target_)); + std::string external_api; + if (node->attrs.attr_store.count("custom_call")) { + external_api = absl::get(node->attrs.attr_store.at("custom_call")); } else { - for (auto& sub_group : group->fused_sub_groups) { - auto exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ true); - ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); - } + external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_); } - ir::ModuleExpr mod_expr(ast_exprs); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - - Node* first = nullptr; - Node* second = nullptr; - - VLOG(3) << "Before IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - // do schedule. - IRSchedule(ir_sch, group, tensor_map); - VLOG(3) << "After IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - // function args + std::vector compute_args = {common::CINNValue(group->GetFuncName()), + common::CINNValue(external_api)}; + common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{compute_args}); + CHECK_EQ(pack.size(), 1UL); + // reset input names as extern api input args can't be remove duplicate. group->input_names.clear(); - std::vector func_args; - for (auto& args : arg_tensors) { - // input node data name. 
- group->input_names.push_back(args->name); - // input args - func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); + for (auto& inode : node->inlinks_in_order()) { + group->input_names.push_back(inode->source()->as()->id()); } - - group->output_names.clear(); - for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } - // collect all output tensor. - std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0; idx < 1; ++idx) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - if (!tensor_map.count(prefix + post)) { - break; - } - auto tensor = tensor_map[prefix + post]; - arg_tensors.push_back(tensor); - // output args - func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); - // update post - post = "_" + std::to_string(idx); - } - } - auto func_body = ir_sch.GetModule().GetExprs().at(0); -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); -#endif - - auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = - ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); - func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; + return {pack[0].operator ir::Expr().as_lowered_func_ref()}; } -std::vector OpLowerer::IRLowerOpWithoutSchedule(IRComputeFunction compute, GroupPtr& group) { - poly::StageMap stages; - std::vector arg_tensors; - std::unordered_map tensor_map; - // do compute. - VLOG(3) << "group->fused_sub_groups.size() is : " << group->fused_sub_groups.size(); - std::vector ast_exprs; - if (group->fused_sub_groups.size() == 0) { - ast_exprs = (this->*compute)(stages, arg_tensors, tensor_map, group, group, /*apply_impl_schedule = */ false); - } else { - for (auto& sub_group : group->fused_sub_groups) { - auto exprs = - (this->*compute)(stages, arg_tensors, tensor_map, group, sub_group, /*apply_impl_schedule = */ false); - ast_exprs.insert(ast_exprs.end(), exprs.begin(), exprs.end()); - } - } - ir::ModuleExpr mod_expr(ast_exprs); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - - VLOG(3) << "After IRLowerOp compute, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); +std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map, + std::vector* group_func_arg_tensors, + bool done_op_schedule) { // function args group->input_names.clear(); - std::vector func_args; - for (auto& args : arg_tensors) { + std::vector group_func_args; + for (auto& arg_tensor : *group_func_arg_tensors) { // input node data name. - group->input_names.push_back(args->name); + group->input_names.push_back(arg_tensor->name); // input args - func_args.emplace_back(args->buffer, ir::Argument::IO::kInput); + group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput); } group->output_names.clear(); for (auto& node : group->output_nodes) { - // output node data name. - for (auto node_data : GetAllNodeData(node)) { - group->output_names.push_back(node_data->id()); - } // collect all output tensor. 
- std::string post = ""; - std::string prefix = GetNodeData(node)->id(); - for (int idx = 0; idx < 1; ++idx) { - CHECK(tensor_map.count(prefix)) << "Can't find output tensor " << prefix; - if (!tensor_map.count(prefix + post)) { - break; - } - auto tensor = tensor_map[prefix + post]; - arg_tensors.push_back(tensor); + for (auto node_data : GetAllNodeData(node)) { + std::string output_node_data_name = node_data->id(); + group->output_names.push_back(output_node_data_name); + CHECK(tensor_map.count(output_node_data_name)) << "Can't find output tensor " << output_node_data_name; + auto tensor = tensor_map.at(output_node_data_name); + // output arg tensors + group_func_arg_tensors->push_back(tensor); // output args - func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); - // update post - post = "_" + std::to_string(idx); + group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); } } - std::unordered_set args_map; - for (auto arg : func_args) { - args_map.insert(arg.name()); - } + if (!done_op_schedule) { + std::unordered_set args_set; + for (auto arg : group_func_args) { + args_set.insert(arg.name()); + } - for (auto& tensor : tensor_map) { - if (args_map.count("_" + tensor.first)) { - continue; + for (auto& tensor_pair : tensor_map) { + if (args_set.count("_" + tensor_pair.first)) { + continue; + } + group_func_arg_tensors->push_back(tensor_pair.second); + // use the underlying tensor name to be consistent with the argument name in the lowered function + group->output_names.push_back(tensor_pair.second->name); + group_func_args.emplace_back(tensor_pair.second->buffer, ir::Argument::IO::kOutput); } - arg_tensors.push_back(tensor.second); - // use the underlying tensor name to be consistent with the argument name in the lowered function - group->output_names.push_back(tensor.second->name); - func_args.emplace_back(tensor.second->buffer, ir::Argument::IO::kOutput); } - auto func_body = ir_sch.GetModule().GetExprs().at(0); + auto func_body = ir_sch->GetModule().GetExprs().at(0); #ifdef CINN_WITH_CUDA optim::OptimizeExprGPU(&(func_body)); #endif - auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body); - auto func = - ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers); + poly::StageMap stages; + auto temp_buffers = lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); + auto func = ir::_LoweredFunc_::Make( + group->GetFuncName(), group_func_args, ir_sch->GetModule().GetExprs().at(0), temp_buffers); func->PrepareBufferCastExprs(); func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); - return {func}; } -std::vector OpLowerer::IRElementwiseCompute(poly::StageMap& stages, - std::vector& func_tensors, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { - VLOG(2) << "ElementwiseCompute Group : " << sub_group->group_id; +std::vector OpLowerer::LowerOps(const std::vector& nodes, + std::vector* group_func_arg_tensors, + std::unordered_map* tensor_map, + bool apply_op_schedule, + ScheduleDetermineFunction schedule_determine_func) { auto& strategy = Operator::GetAttrs("CINNStrategy"); - - std::vector ast_exprs; - for (auto& node : sub_group->nodes) { - VLOG(4) << "Lower op: " << node->op()->name; - auto node_data = GetNodeData(node); - CHECK_EQ(GetAllNodeData(node).size(), 1U); - std::vector cinn_inputs; - std::vector tensor_inputs = - std::move(CollectInputTensor(node, func_tensors, tensor_map, 
this->type_dict_, this->shape_dict_)); - for (auto& tensor : tensor_inputs) { - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - // set tensor name = node data name - cinn_inputs.push_back(common::CINNValue(node_data->id())); - + std::vector func_bodies; + for (Node* node : nodes) { + // 1. select Op impl std::vector out_types; std::vector> out_shapes; - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - auto impl = - OpStrategy::SelectImpl(strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, this->target_)); - // do compute - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - CHECK_EQ(pack.size(), 2U); - - Expr expr = pack[0]; - poly::StageMap node_stages = pack.back(); - tensor_inputs.push_back(expr.as_tensor_ref()); - tensor_map[node_data->id()] = expr.as_tensor_ref(); - - auto func = lang::LowerVec("fn_" + node->id(), node_stages, tensor_inputs, {}, {}, nullptr, this->target_, true); - CHECK_EQ(func.size(), 1); - - if (apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); - - CHECK_EQ(expr_pack.size(), 1); - Expr ast_expr = expr_pack[0]; - ast_exprs.push_back(ast_expr); + std::vector node_datas = GetAllNodeData(node); + for (const auto& node_data : node_datas) { + out_types.push_back(this->type_dict_.at(node_data->id())); + out_shapes.push_back(this->shape_dict_.at(node_data->id())); + } + std::vector op_func_arg_tensors = + std::move(CollectInputTensor(node, group_func_arg_tensors, tensor_map, this->type_dict_, this->shape_dict_)); + auto op_impl = OpStrategy::SelectImpl( + strategy[node->op()](node->attrs, op_func_arg_tensors, out_types, out_shapes, this->target_)); + + // 2. perform the lower process of Op + std::vector funcs = DoOpLower(node, op_impl, tensor_map, &op_func_arg_tensors); + + if (apply_op_schedule && (this->*schedule_determine_func)(node)) { + // 3. 
perform the schedule of Op + func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs)); } else { - ast_exprs.push_back(func[0]->body); - } - } - - return ast_exprs; -} - -std::vector OpLowerer::IRReduceCompute(poly::StageMap& stages, - std::vector& func_args, - std::unordered_map& tensor_map, - const GroupPtr& group, - const GroupPtr& sub_group, - bool apply_impl_schedule) { - VLOG(2) << "ReduceCompute Group : " << sub_group->group_id; - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - std::vector ast_exprs; - for (auto& node : sub_group->nodes) { - auto node_data = GetNodeData(node); - VLOG(3) << "In ReduceCompute, process node: " << node->id() << " with op type: " << node->op()->name; - - std::vector cinn_inputs; - std::vector tensor_inputs = - std::move(CollectInputTensor(node, func_args, tensor_map, this->type_dict_, this->shape_dict_)); - for (auto& tensor : tensor_inputs) { - cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); - } - cinn_inputs.push_back(common::CINNValue(node_data->id())); - - std::vector out_types; - std::vector> out_shapes; - - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); - - auto impl = - OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, tensor_inputs, out_types, out_shapes, target_)); - // do compute - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - - CHECK_GE(pack.size(), 2UL); - CHECK_LE(pack.size(), 5UL); - poly::StageMap tmp_stages = pack.back(); - - std::string post = ""; - for (int idx = 0; idx < pack.size() - 1; ++idx) { - Expr expr = pack[idx]; - tensor_map[node_data->id() + post] = expr.as_tensor_ref(); - // As op may has more than 1 output tensor, using id + "_0"/"_1" as key. - post = "_" + std::to_string(idx); - - // Insert outout tensors - if (!expr.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { - tensor_inputs.push_back(expr.as_tensor_ref()); - } - } - auto func = lang::LowerVec("fn_" + node->id(), tmp_stages, tensor_inputs, {}, {}, nullptr, this->target_, true); - - // node is kReduction - if (op_pattern_dict[node->op()] == framework::kReduction && apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); + for (const ir::LoweredFunc& func : funcs) { + func_bodies.push_back(func->body); } - // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); - // ast tree after schedule. - Expr ast_expr = expr_pack[0]; - ast_exprs.push_back(ast_expr); - } else if (group->master_nodes.count(node)) { - // as master node should copy transform from reducer, left it to reduce schedule. 
- ast_exprs.push_back(func[0]->body); - } else { - ast_exprs.push_back(func[0]->body); } } - return ast_exprs; + return func_bodies; } -std::vector OpLowerer::IRLowerNonFusibleOp(GroupPtr& group, bool apply_impl_schedule) { - VLOG(3) << "LowerNonFusibleOp Group : " << group->group_id; - // get input tensor and output tensor - CHECK(group->nodes.size() || group->fused_sub_groups.size()); - auto& cinn_strategy = Operator::GetAttrs("CINNStrategy"); - auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); - - auto node = group->fused_sub_groups.size() ? group->fused_sub_groups[0]->nodes.front() : group->nodes.front(); - VLOG(3) << "GetOpFunc of op " << node->id(); - std::vector inputs; +std::vector OpLowerer::DoOpLower(Node* node, + std::shared_ptr op_impl, + std::unordered_map* tensor_map, + std::vector* op_func_arg_tensors) { + VLOG(4) << "Do lower with Compute, op: " << node->op()->name; std::vector cinn_inputs; - - std::vector args; - std::unordered_map tensor_map; - for (auto& node_data : GetInputNodeData(node)) { - CHECK(node_data); - ir::Tensor tensor; - if (!tensor_map.count(node_data->id())) { - tensor = GetTensor(node_data, this->type_dict_, this->shape_dict_); - // record tensor. - tensor_map[node_data->id()] = tensor; - // input name. - group->input_names.push_back(node_data->id()); - // input type. - args.emplace_back(tensor->buffer, ir::Argument::IO::kInput); - } else { - tensor = tensor_map[node_data->id()]; - } - inputs.push_back(tensor); - cinn_inputs.push_back(common::CINNValue(tensor)); + for (const ir::Tensor& tensor : *op_func_arg_tensors) { + cinn_inputs.push_back(common::CINNValue(ir::Expr(tensor))); } - - std::vector out_types; - std::vector> out_shapes; - auto node_datas = GetAllNodeData(node); - for (auto node_data : node_datas) { - VLOG(3) << "cinn_inputs.push_back " << node_data->id(); - group->output_names.push_back(node_data->id()); - out_types.push_back(this->type_dict_.at(node_data->id())); - out_shapes.push_back(this->shape_dict_.at(node_data->id())); + // set tensor name = node data name + std::vector node_datas = GetAllNodeData(node); + for (const NodeData* node_data : node_datas) { cinn_inputs.push_back(common::CINNValue(node_data->id())); } - auto impl = OpStrategy::SelectImpl(cinn_strategy[node->op()](node->attrs, inputs, out_types, out_shapes, target_)); - // if node op is custom_call, apply custom_call compute. - if (node->op()->name == "custom_call") { - std::string external_api; - if (node->attrs.attr_store.count("custom_call")) { - external_api = absl::get(node->attrs.attr_store.at("custom_call")); + // 1.Do compute + common::CINNValuePack pack = op_impl->fcompute(common::CINNValuePack{cinn_inputs}); + + poly::StageMap tmp_stages = pack.back(); + std::string post = ""; + for (int idx = 0; idx < pack.size() - 1; ++idx) { + Expr expr = pack[idx]; + // Insert the output tensor defined by Compute into the tensor_map + if (pack.size() - 1 > node_datas.size()) { + // Some nodes may output multiple temp tensors in their Compute definition, + // but only one output node_data in the graph, and we use id + "_0"/"_1" as key. + (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref(); + post = "_" + std::to_string(idx); } else { - external_api = ExternalApiRegistry::Global()->GetExternalApi(node, target_); + // If the number of output tensors defined by Compute is same with the output node_data on the graph, then there + // is a one-to-one correspondence. 
CHECK_EQ(node_datas[idx]->id(), expr.as_tensor_ref()->name); + (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref(); } - std::vector compute_args = {common::CINNValue(group->GetFuncName()), - common::CINNValue(external_api)}; - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{compute_args}); - CHECK_EQ(pack.size(), 1UL); - // reset input names as extern api input args can't be remove duplicate. - group->input_names.clear(); - for (auto& inode : node->inlinks_in_order()) { - group->input_names.push_back(inode->source()->as()->id()); + + // Insert output tensors into function arg + if (!expr.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { + op_func_arg_tensors->push_back(expr.as_tensor_ref()); + expr.as_tensor_ref()->WithBuffer(); } - return {pack[0].operator ir::Expr().as_lowered_func_ref()}; } - common::CINNValuePack pack = impl->fcompute(common::CINNValuePack{cinn_inputs}); - for (int i = 0; i < pack->size() - 1; i++) { - ir::Expr temp = pack[i]; - // checkout whether the tensor is with buffer. - if (!temp.as_tensor_ref()->buffer.defined() || this->target_ != common::DefaultNVGPUTarget()) { - inputs.push_back(temp.as_tensor_ref()); - temp.as_tensor_ref()->WithBuffer(); - args.emplace_back(temp.as_tensor_ref()->buffer, ir::Argument::IO::kOutput); - } + // 2.Do lower + std::vector funcs = + lang::LowerVec("fn_" + node->id(), tmp_stages, *op_func_arg_tensors, {}, {}, nullptr, this->target_, true); + VLOG(4) << "Lower op: " << node->op()->name << ", get " << funcs.size() << " LoweredFunc:\n"; + + op_func_arg_tensors->clear(); + for (int idx = 0; idx < pack.size() - 1; ++idx) { + CHECK(pack[idx].is_tensor()); + op_func_arg_tensors->push_back(pack[idx].operator ir::Expr().as_tensor_ref()); } - poly::StageMap stages = pack.back(); - auto func = lang::LowerVec(group->GetFuncName(), stages, inputs, {}, {}, nullptr, this->target_, true); + return funcs; +} - if (apply_impl_schedule) { - std::vector schedule_inputs; - // collect tensor - for (int idx = 0; idx < pack.size() - 1; ++idx) { - CHECK(pack[idx].is_tensor()); - schedule_inputs.push_back(common::CINNValue(pack[idx])); - } - for (auto& f : func) { - schedule_inputs.push_back(common::CINNValue(f->body)); - } - // do ast tree schedule - common::CINNValuePack expr_pack = impl->fschedule(common::CINNValuePack{schedule_inputs}); - - ir::Expr func_body = expr_pack[0]; - std::vector input_output_nodes(group->input_names); - input_output_nodes.insert(input_output_nodes.end(), group->output_names.begin(), group->output_names.end()); - VLOG(6) << "func.size() = " << func.size() << ", expr_pack.size() = " << expr_pack.size(); - VLOG(6) << "args.size() = " << args.size() << ", input_output_nodes.size() = " << input_output_nodes.size(); - if (args.size() > input_output_nodes.size()) { - args = lang::GetArgs(func_body, input_output_nodes); - } - std::vector res; - for (int i = 0; i < expr_pack.size(); i++) { - ir::Expr func_body = expr_pack[0]; -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(func_body)); -#endif - auto temp_buffers = lang::GetTempBuffers(inputs, stages, func_body); - auto function = ir::_LoweredFunc_::Make(group->GetFuncName(), args, func_body, temp_buffers); - res.push_back(function); - } - for (auto& i : res) { - i = optim::Optimize(Expr(i), target_, false).as_lowered_func_ref(); - } - return res; - } else { - for (auto& f : func) { -#ifdef CINN_WITH_CUDA - optim::OptimizeExprGPU(&(f->body)); -#endif - f = optim::Optimize(Expr(f), target_, false).as_lowered_func_ref(); - } - 
return func; +ir::Expr OpLowerer::DoOpSchedule(std::shared_ptr op_impl, + const std::vector& op_func_arg_tensors, + const std::vector& lowered_funcs) { + std::vector schedule_inputs; + // collect tensors + for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) { + schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor)); } + // collect bodies to be scheduled + for (const ir::LoweredFunc& func : lowered_funcs) { + schedule_inputs.push_back(common::CINNValue(func->body)); + } + // do schedule on AST + common::CINNValuePack expr_pack = op_impl->fschedule(common::CINNValuePack{schedule_inputs}); + + return expr_pack[0].operator ir::Expr(); } // group schedule -void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_map& tensor_map) { +ir::Expr OpLowerer::DoGroupSchedule(ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map) { // topological order. auto nodes_set = group->NodeSet(); auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); @@ -745,6 +479,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); SyncThreadWithShared(ir_sch, group, nodes_inline, nodes_set, this->shape_dict_, tensor_map); VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + return ir_sch.GetModule().GetExprs().at(0); } } // namespace framework diff --git a/cinn/hlir/framework/op_lowering.h b/cinn/hlir/framework/op_lowering.h index c947c982f3..e7258c13d1 100644 --- a/cinn/hlir/framework/op_lowering.h +++ b/cinn/hlir/framework/op_lowering.h @@ -39,45 +39,60 @@ using GroupPtr = std::shared_ptr; using common::Target; class OpLowerer; -typedef std::vector (OpLowerer::*IRComputeFunction)(poly::StageMap&, - std::vector&, - std::unordered_map&, - const GroupPtr&, - const GroupPtr&, - bool); + +typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*); class OpLowerer { public: OpLowerer(const absl::flat_hash_map&, const absl::flat_hash_map&, const Target&); - std::vector Lower(GroupPtr& group); - std::vector LowerWithoutSchedule(GroupPtr& group); - - private: - std::vector IRLowerOp(IRComputeFunction, GroupPtr&); - std::vector IRLowerNonFusibleOp(GroupPtr&, bool); - std::vector IRLowerOpWithoutSchedule(IRComputeFunction, GroupPtr&); - - std::vector LowerGroup(IRComputeFunction, GroupPtr&, bool); -#define DEFINE_IR_COMPUTE(type) \ - std::vector IR##type##Compute(poly::StageMap& stages, \ - std::vector& func_args, \ - std::unordered_map& tensor_map, \ - const GroupPtr& group, \ - const GroupPtr& sub_group, \ - bool apply_impl_schedule = false); + /** + * @brief Lower a group to CINN IR + * @param apply_op_schedule Whether to schedule at Op level. + * @param apply_group_schedule Whether to schedule at group level. 
+ */ + std::vector Lower(GroupPtr& group, bool apply_op_schedule = true, bool apply_group_schedule = true); - // compute and schedule - DEFINE_IR_COMPUTE(Elementwise); - DEFINE_IR_COMPUTE(Reduce); - DEFINE_IR_COMPUTE(OutEWiseFusable); - - void IRSchedule(ir::IRSchedule& ir_sch, - const GroupPtr& group, - const std::unordered_map& tensor_map); + private: + std::vector LowerGroup(GroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + ScheduleDetermineFunction schedule_determine_func); + + std::vector LowerCustomCall(GroupPtr& group); + + std::vector PostProcess(ir::IRSchedule* ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map, + std::vector* group_func_arg_tensors, + bool done_op_schedule); + + std::vector LowerOps(const std::vector& nodes, + std::vector* group_func_arg_tensors, + std::unordered_map* tensor_map, + bool apply_op_schedule, + ScheduleDetermineFunction schedule_determine_func); + + std::vector DoOpLower(Node* node, + std::shared_ptr op_impl, + std::unordered_map* tensor_map, + std::vector* op_func_arg_tensors); + + ir::Expr DoOpSchedule(std::shared_ptr op_impl, + const std::vector& op_func_arg_tensors, + const std::vector& lowered_funcs); + + ir::Expr DoGroupSchedule(ir::IRSchedule& ir_sch, + const GroupPtr& group, + const std::unordered_map& tensor_map); + + inline bool ReduceScheduleDetermineFunction(Node* node); + inline bool ElementwiseScheduleDetermineFunction(Node* node); + inline bool NonFusibleScheduleDetermineFunction(Node* node); + private: Target target_; const absl::flat_hash_map& type_dict_; const absl::flat_hash_map& shape_dict_; diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 8220c58600..1f61b651c6 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -77,8 +77,8 @@ ir::Tensor GetTensor(const NodeData* node_data, } std::vector CollectInputTensor(const Node* node, - std::vector& func_args, - std::unordered_map& tensor_map, + std::vector* func_args, + std::unordered_map* tensor_map, const absl::flat_hash_map& type_dict, const absl::flat_hash_map& shape_dict) { std::vector tensors; @@ -86,10 +86,10 @@ std::vector CollectInputTensor(const Node* node, for (auto& node_data : GetInputNodeData(node)) { CHECK(node_data); auto tensor = GetTensor(node_data, type_dict, shape_dict); - if (!tensor_map.count(node_data->id())) { - tensor_map[node_data->id()] = tensor; + if (!tensor_map->count(node_data->id())) { + (*tensor_map)[node_data->id()] = tensor; // record func input args - func_args.push_back(tensor); + func_args->push_back(tensor); } tensors.push_back(tensor); } diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index f081411ec0..33cdd45b2d 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -29,8 +29,8 @@ ir::Tensor GetTensor(const NodeData* node_data, const absl::flat_hash_map& shape_dict); std::vector CollectInputTensor(const Node* node, - std::vector& func_args, - std::unordered_map& tensor_map, + std::vector* func_args, + std::unordered_map* tensor_map, const absl::flat_hash_map& type_dict, const absl::flat_hash_map& shape_dict); From a9d4627ba20763109b2a8172f3685bcad672c634 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Mon, 26 Jun 2023 14:51:15 +0800 Subject: [PATCH 04/11] Add some annotation for OpLowerer --- cinn/hlir/framework/op_lowering.cc | 19 ++++++---- cinn/hlir/framework/op_lowering.h | 59 
+++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 1c86fc4e76..e4244fa686 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -170,7 +170,7 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, const std::unordered_map& tensor_map, std::vector* group_func_arg_tensors, bool done_op_schedule) { - // function args + // 1.Prepare function args group->input_names.clear(); std::vector group_func_args; for (auto& arg_tensor : *group_func_arg_tensors) { @@ -217,11 +217,14 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, optim::OptimizeExprGPU(&(func_body)); #endif + // 2.Prepare temp buffers poly::StageMap stages; auto temp_buffers = lang::GetTempBuffers(*group_func_arg_tensors, stages, func_body); - auto func = ir::_LoweredFunc_::Make( + // 3.Building LoweredFunc + auto func = ir::_LoweredFunc_::Make( group->GetFuncName(), group_func_args, ir_sch->GetModule().GetExprs().at(0), temp_buffers); func->PrepareBufferCastExprs(); + // 4.Apply low level pass func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); return {func}; } @@ -234,7 +237,7 @@ std::vector OpLowerer::LowerOps(const std::vector& nodes, auto& strategy = Operator::GetAttrs("CINNStrategy"); std::vector func_bodies; for (Node* node : nodes) { - // 1. select Op impl + // 1.Select Op impl std::vector out_types; std::vector> out_shapes; std::vector node_datas = GetAllNodeData(node); @@ -247,11 +250,11 @@ std::vector OpLowerer::LowerOps(const std::vector& nodes, auto op_impl = OpStrategy::SelectImpl( strategy[node->op()](node->attrs, op_func_arg_tensors, out_types, out_shapes, this->target_)); - // 2. perform the lower process of Op + // 2.Perform the lower process of Op std::vector funcs = DoOpLower(node, op_impl, tensor_map, &op_func_arg_tensors); if (apply_op_schedule && (this->*schedule_determine_func)(node)) { - // 3. perform the schedule of Op + // 3.Perform the schedule of Op func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs)); } else { for (const ir::LoweredFunc& func : funcs) { @@ -322,15 +325,15 @@ ir::Expr OpLowerer::DoOpSchedule(std::shared_ptr op_imp const std::vector& op_func_arg_tensors, const std::vector& lowered_funcs) { std::vector schedule_inputs; - // collect tensors + // 1.Collect tensors for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) { schedule_inputs.push_back(common::CINNValue(op_func_arg_tensor)); } - // collect bodies to be scheduled + // 2.Collect bodies to be scheduled for (const ir::LoweredFunc& func : lowered_funcs) { schedule_inputs.push_back(common::CINNValue(func->body)); } - // do schedule on AST + // 3.Do schedule on AST common::CINNValuePack expr_pack = op_impl->fschedule(common::CINNValuePack{schedule_inputs}); return expr_pack[0].operator ir::Expr(); diff --git a/cinn/hlir/framework/op_lowering.h b/cinn/hlir/framework/op_lowering.h index e7258c13d1..9956d2dc83 100644 --- a/cinn/hlir/framework/op_lowering.h +++ b/cinn/hlir/framework/op_lowering.h @@ -49,45 +49,102 @@ class OpLowerer { const Target&); /** - * @brief Lower a group to CINN IR + * @brief Lower a group to CINN IR. + * @param group The group to be lowered. * @param apply_op_schedule Whether to schedule at Op level. * @param apply_group_schedule Whether to schedule at group level. + * @return The lowered funcs. 
*/ std::vector Lower(GroupPtr& group, bool apply_op_schedule = true, bool apply_group_schedule = true); private: + /** + * @brief Lower a group to CINN IR. + * @param group The group to be lowered. + * @param apply_op_schedule Whether to schedule at Op level. + * @param apply_group_schedule Whether to schedule at group level. + * @param schedule_determine_func Function used to determine which Ops to schedule. + * @return The lowered funcs. + */ std::vector LowerGroup(GroupPtr& group, bool apply_op_schedule, bool apply_group_schedule, ScheduleDetermineFunction schedule_determine_func); + /** + * @brief Lower a group composed of CustomCall Op. + * @param group The group to be lowered. + * @return The lowered funcs. + */ std::vector LowerCustomCall(GroupPtr& group); + /** + * @brief Post processing, including preparing function args and temporary variables, + * applying low-level optimization passes, etc. + * @param group The group to be lowered. + * @param tensor_map All tensors used for calculating the group. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @param done_op_schedule Mark whether the Op level schedule has been applied. + * @return The lowered funcs after the post processing. + */ std::vector PostProcess(ir::IRSchedule* ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map, std::vector* group_func_arg_tensors, bool done_op_schedule); + /** + * @brief Lower an Op set to CINN IR. + * Compute, Lower and optional Schedule will be performed one by one for each Op. + * @param nodes The Op nodes to be lowered. + * @param group_func_arg_tensors Tensors used as the group function arguments. + * @param tensor_map All tensors used for calculating the group. + * @param apply_op_schedule Whether to schedule at Op level. + * @param schedule_determine_func Function used to determine which Ops to schedule. + * @return The lowered func bodies of Op set. + */ std::vector LowerOps(const std::vector& nodes, std::vector* group_func_arg_tensors, std::unordered_map* tensor_map, bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func); + /** + * @brief Lower an Op to CINN IR. The Compute and Lower processes will be called sequentially. + * @param node The Op node to be lowered. + * @param op_impl The Op implementation defining Compute and Schedule. + * @param tensor_map All tensors used for calculating the group. + * @param op_func_arg_tensors Tensors used as the Op function arguments. + * @return The lowered func bodies of the Op node. + */ std::vector DoOpLower(Node* node, std::shared_ptr op_impl, std::unordered_map* tensor_map, std::vector* op_func_arg_tensors); + /** + * @brief Apply schedule on an Op. + * @param op_impl The Op implementation defining Compute and Schedule. + * @param op_func_arg_tensors Tensors used as the Op function arguments. + * @param lowered_funcs The lowered funcs of an Op to be scheduled. + * @return The lowered func body after schedule of the Op. + */ ir::Expr DoOpSchedule(std::shared_ptr op_impl, const std::vector& op_func_arg_tensors, const std::vector& lowered_funcs); + /** + * @brief Apply schedule on a group. + * @param ir_sch The IRSchedule containing the entire group's lowered func bodies. + * @param group The group to be scheduled. + * @param tensor_map All tensors used for calculating the group. + * @return The lowered func body after schedule of the group. 
+ */ ir::Expr DoGroupSchedule(ir::IRSchedule& ir_sch, const GroupPtr& group, const std::unordered_map& tensor_map); + // Functions used to determine which Ops to schedule at op level, define a policy for each type of group. inline bool ReduceScheduleDetermineFunction(Node* node); inline bool ElementwiseScheduleDetermineFunction(Node* node); inline bool NonFusibleScheduleDetermineFunction(Node* node); From bf07739cf40d73b0ac73e5df64790857c36f275c Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 27 Jun 2023 15:12:13 +0800 Subject: [PATCH 05/11] fix mlt unittest --- .../auto_gen_rule/multi_level_tiling_test.cc | 121 +++++++++--------- 1 file changed, 59 insertions(+), 62 deletions(-) diff --git a/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc b/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc index 91ddf361da..47eabab56d 100644 --- a/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc +++ b/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc @@ -361,9 +361,9 @@ TEST_F(TestMultiLevelTiling, ReduceSum) { TEST_F(TestMultiLevelTiling, Pool2d) { default_input_names = {"input"}; - default_output_names = {"var_0"}; - std::vector input_shape{2, 8, 16, 16}; - std::vector output_shape{2, 8, 8, 8}; + default_output_names = {"var_0", "pad_temp_0"}; + std::vector> input_shapes{{2, 8, 16, 16}}; + std::vector> output_shapes{{2, 8, 8, 8}, {2, 8, 18, 18}}; std::string pooling_type = "max"; std::vector ksize{3, 3}; std::vector strides{2, 2}; @@ -374,7 +374,7 @@ TEST_F(TestMultiLevelTiling, Pool2d) { std::string data_format = "NCHW"; bool adaptive = false; std::string padding_algorithm = "EXPLICIT"; - frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shape}}, + frontend::Program pool2d_program = tests::OpBuilder("pool2d").Build({{"input", input_shapes[0]}}, {{"pool_type", pooling_type}, {"kernel_size", ksize}, {"stride_size", strides}, @@ -411,85 +411,82 @@ TEST_F(TestMultiLevelTiling, Pool2d) { { ScheduleBlock(root) { - serial for (i, 0, 2) { - serial for (j, 0, 8) + serial for (i, 0, 2) { - serial for (k, 0, 18) + serial for (j, 0, 8) { - serial for (a, 0, 18) + serial for (k, 0, 18) { - ScheduleBlock(pad_temp_0) + serial for (a, 0, 18) { - i0, i1, i2, i3 = axis.bind(i, j, k, a) - pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) + ScheduleBlock(pad_temp_0) + { + i0, i1, i2, i3 = axis.bind(i, j, k, a) + { + pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f) + } + } } } } } - } - } -} -} // end Expr 0 -Expr 1 { -{ - ScheduleBlock(root_0) - { - { - thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16) { - thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4) + thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16) { - serial for (i_1, 0, 1) + thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4) { - serial for (j_1, 0, 4) + serial for (i_1, 0, 1) { - serial for (k_1, 0, 1) + serial for (j_1, 0, 4) { - serial for (a_1, 0, 4) + serial for (k_1, 0, 1) { - ScheduleBlock(var_0__reduce_init) + serial for (a_1, 0, 4) { - i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) + 
ScheduleBlock(var_0__reduce_init) { - var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f + i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)) + { + var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f + } } } } } } - } - { - serial for (kernel_idx, 0, 3) { - serial for (kernel_idx_0, 0, 3) + serial for (kernel_idx, 0, 3) { - serial for (ax0_ax1_ax2_ax3_fused, 0, 28) + serial for (kernel_idx_0, 0, 3) { - ScheduleBlock(pad_temp_0_shared_temp_buffer) + serial for (ax0_ax1_ax2_ax3_fused, 0, 28) { - v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) - attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0) + ScheduleBlock(pad_temp_0_shared_temp_buffer) { - pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] + v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0))) + attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0) + { + pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3] + } } } - } - serial for (i_1, 0, 1) - { - serial for (j_1, 0, 4) + serial for (i_1, 0, 1) { - serial for (k_1, 0, 1) + serial for (j_1, 0, 4) { - serial for (a_1, 0, 4) + serial for (k_1, 0, 1) { - ScheduleBlock(var_0_local_temp_buffer) + serial for (a_1, 0, 4) { - i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) - read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) - write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) + ScheduleBlock(var_0_local_temp_buffer) { - var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), 
((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) + i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0) + read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)]) + write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)]) + { + var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))]) + } } } } @@ -497,21 +494,21 @@ Expr 1 { } } } - } - serial for (ax0_0, 0, 1) - { - serial for (ax1_0, 0, 4) + serial for (ax0_0, 0, 1) { - serial for (ax2_0, 0, 1) + serial for (ax1_0, 0, 4) { - serial for (ax3_0, 0, 4) + serial for (ax2_0, 0, 1) { - ScheduleBlock(var_0) + serial for (ax3_0, 0, 4) { - v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) - attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) + ScheduleBlock(var_0) { - var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] + v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0)) + attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0) + { + var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3] + } } } } @@ -524,7 +521,7 @@ Expr 1 { } } } -} // end Expr 1 +} // end Expr 0 )ROC"; ASSERT_EQ(ir, expected_ir); @@ -539,8 +536,8 @@ Expr 1 { BuildIRModule(MakeIRSchedule(pool2d_program, fixed_rand_seed, /* apply_manual_schedule*/ true))), default_input_names, default_output_names, - {input_shape}, - {output_shape}, + input_shapes, + output_shapes, target_); } From 4ca825124783c4329c4295e4cd01a4d3020a051f Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 27 Jun 2023 15:23:29 +0800 Subject: [PATCH 06/11] polish annotations --- cinn/hlir/framework/op_lowering.cc | 4 
+++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index e4244fa686..fc5e2f1778 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -296,7 +296,7 @@ std::vector OpLowerer::DoOpLower(Node* node, post = "_" + std::to_string(idx); } else { // If the number of output tensors defined by Compute is same with the output node_data on the graph, then there - // is a one-to-one correspondence. CHECK_EQ(node_datas[idx]->id(), expr.as_tensor_ref()->name); + // is a one-to-one correspondence. (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref(); } @@ -324,6 +324,7 @@ std::vector OpLowerer::DoOpLower(Node* node, ir::Expr OpLowerer::DoOpSchedule(std::shared_ptr op_impl, const std::vector& op_func_arg_tensors, const std::vector& lowered_funcs) { + VLOG(4) << "Do op schedule"; std::vector schedule_inputs; // 1.Collect tensors for (const ir::Tensor& op_func_arg_tensor : op_func_arg_tensors) { @@ -335,6 +336,7 @@ ir::Expr OpLowerer::DoOpSchedule(std::shared_ptr op_imp } // 3.Do schedule on AST common::CINNValuePack expr_pack = op_impl->fschedule(common::CINNValuePack{schedule_inputs}); + VLOG(4) << "After op schedule: " << expr_pack[0].operator ir::Expr(); return expr_pack[0].operator ir::Expr(); } From bae73d418500f11aa5e335262394f30128111776 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 27 Jun 2023 15:28:31 +0800 Subject: [PATCH 07/11] fix duplicate parameter --- cinn/hlir/framework/op_lowering.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index fc5e2f1778..4e95fd413e 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -202,7 +202,7 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, } for (auto& tensor_pair : tensor_map) { - if (args_set.count("_" + tensor_pair.first)) { + if (args_set.count("_" + tensor_pair.second->name)) { continue; } group_func_arg_tensors->push_back(tensor_pair.second); From 697c4fd9729e75db49343c2616809f824266ea79 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 27 Jun 2023 17:21:00 +0800 Subject: [PATCH 08/11] fix args not match node_data bug --- cinn/hlir/framework/op_lowering.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 4e95fd413e..4466296f3d 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -173,11 +173,13 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, // 1.Prepare function args group->input_names.clear(); std::vector group_func_args; + std::unordered_set arg_name_set; for (auto& arg_tensor : *group_func_arg_tensors) { // input node data name. 
group->input_names.push_back(arg_tensor->name);
     // input args
     group_func_args.emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput);
+    arg_name_set.insert(arg_tensor->buffer->name);
   }
 
   group->output_names.clear();
@@ -186,12 +188,19 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch,
     for (auto node_data : GetAllNodeData(node)) {
       std::string output_node_data_name = node_data->id();
       group->output_names.push_back(output_node_data_name);
-      CHECK(tensor_map.count(output_node_data_name)) << "Can't find output tensor " << output_node_data_name;
+      // CHECK(tensor_map.count(output_node_data_name)) << "Can't find output tensor " << output_node_data_name;
+      if (tensor_map.count(output_node_data_name) == 0) {
+        continue;
+      }
       auto tensor = tensor_map.at(output_node_data_name);
+      if (arg_name_set.count(tensor->buffer->name) != 0) {
+        continue;
+      }
       // output arg tensors
       group_func_arg_tensors->push_back(tensor);
       // output args
       group_func_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
+      arg_name_set.insert(tensor->buffer->name);
     }
   }
 
@@ -295,8 +304,8 @@ std::vector OpLowerer::DoOpLower(Node* node,
       (*tensor_map)[node_datas[0]->id() + post] = expr.as_tensor_ref();
       post = "_" + std::to_string(idx);
     } else {
-      // If the number of output tensors defined by Compute is same with the output node_data on the graph, then there
-      // is a one-to-one correspondence.
+      // If the number of output tensors defined by Compute is less than or equal to the number of output node_data
+      // on the graph, they correspond one to one, and any redundant output node_data is left without a tensor.
       (*tensor_map)[node_datas[idx]->id()] = expr.as_tensor_ref();
     }

From cb5ae68769d58ffdb4b9e78dfa2237661c064506 Mon Sep 17 00:00:00 2001
From: BiynXu <244524405@qq.com>
Date: Tue, 27 Jun 2023 19:58:34 +0800
Subject: [PATCH 09/11] fix logic to determine group schedule

---
 cinn/hlir/framework/op_lowering.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc
index 4466296f3d..e4617b2d7d 100644
--- a/cinn/hlir/framework/op_lowering.cc
+++ b/cinn/hlir/framework/op_lowering.cc
@@ -104,7 +104,7 @@ std::vector OpLowerer::LowerGroup(GroupPtr& group,
   ir::IRSchedule ir_sch(mod_expr);
   ir_sch.MergeExprs();
   VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
-  if (apply_group_schedule && nodes.size() > 1) {
+  if (apply_group_schedule && func_bodies.size() > 1) {
     DoGroupSchedule(ir_sch, group, tensor_map);
     VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
   }

From 3b5e8c339b14405c5419bb867a3f9381382ba525 Mon Sep 17 00:00:00 2001
From: BiynXu <244524405@qq.com>
Date: Tue, 27 Jun 2023 21:18:15 +0800
Subject: [PATCH 10/11] fix logic to determine group schedule

---
 cinn/hlir/framework/op_lowering.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc
index e4617b2d7d..3b931b383e 100644
--- a/cinn/hlir/framework/op_lowering.cc
+++ b/cinn/hlir/framework/op_lowering.cc
@@ -104,7 +104,7 @@ std::vector OpLowerer::LowerGroup(GroupPtr& group,
   ir::IRSchedule ir_sch(mod_expr);
   ir_sch.MergeExprs();
   VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
-  if (apply_group_schedule && func_bodies.size() > 1) {
+  if (apply_group_schedule) {
     DoGroupSchedule(ir_sch, group, tensor_map);
     VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0);
   }

From a30cecd04a40f9c908257ec50b9e136d790ca494
Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Wed, 28 Jun 2023 13:06:22 +0800 Subject: [PATCH 11/11] fix x86 segfault --- cinn/hlir/framework/op_lowering.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 3b931b383e..e2796190e5 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -232,7 +232,9 @@ std::vector OpLowerer::PostProcess(ir::IRSchedule* ir_sch, // 3.Building LoweredFunc auto func = ir::_LoweredFunc_::Make( group->GetFuncName(), group_func_args, ir_sch->GetModule().GetExprs().at(0), temp_buffers); - func->PrepareBufferCastExprs(); + if (!done_op_schedule) { + func->PrepareBufferCastExprs(); + } // 4.Apply low level pass func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); return {func};
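
The refactor above routes every group through a single LowerOps loop that decides per node, via (this->*schedule_determine_func)(node), whether DoOpSchedule is applied, with one ScheduleDetermineFunction policy per group pattern kind declared in op_lowering.h. The standalone C++ sketch below illustrates only that member-function-pointer dispatch pattern; MiniLowerer, ReducePolicy, ElementwisePolicy and the toy Node struct are invented names for illustration and are not part of the CINN sources touched by these patches.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-in; the real lowering works on CINN graph nodes and LoweredFuncs.
struct Node {
  std::string op_name;
};

class MiniLowerer {
 public:
  // Same shape as: typedef bool (OpLowerer::*ScheduleDetermineFunction)(Node*);
  typedef bool (MiniLowerer::*ScheduleDetermineFunction)(Node*);

  // Walks the ops of a group and asks the policy whether to apply the
  // op-level schedule, mirroring the apply_op_schedule branch in LowerOps.
  void LowerOps(const std::vector<Node*>& nodes,
                bool apply_op_schedule,
                ScheduleDetermineFunction schedule_determine_func) {
    for (Node* node : nodes) {
      if (apply_op_schedule && (this->*schedule_determine_func)(node)) {
        std::cout << "apply op schedule to: " << node->op_name << "\n";
      } else {
        std::cout << "keep unscheduled body of: " << node->op_name << "\n";
      }
    }
  }

  // One policy per group pattern kind, analogous to the three
  // *ScheduleDetermineFunction members declared in op_lowering.h.
  bool ReducePolicy(Node* node) { return node->op_name == "reduce_sum"; }
  bool ElementwisePolicy(Node*) { return false; }
};

int main() {
  Node a{"elementwise_add"};
  Node b{"reduce_sum"};
  std::vector<Node*> nodes = {&a, &b};
  MiniLowerer lowerer;
  // A reduction-style group schedules only the reducer at op level.
  lowerer.LowerOps(nodes, /*apply_op_schedule=*/true, &MiniLowerer::ReducePolicy);
  // An elementwise-style group defers all scheduling to the group schedule.
  lowerer.LowerOps(nodes, /*apply_op_schedule=*/true, &MiniLowerer::ElementwisePolicy);
  return 0;
}

Keeping the per-pattern decision in a small policy member function, instead of separate compute/schedule entry points, is what lets the public Lower() interface expose only the apply_op_schedule and apply_group_schedule switches.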