From 59403651fbeac30785522ecf1872708ce87019fa Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 25 Apr 2023 13:21:32 +0000 Subject: [PATCH 01/12] feat(fuse): add reduce group fuse with broadcast --- cinn/hlir/framework/op_lowering.cc | 2 +- cinn/hlir/framework/op_lowering_util.cc | 64 ++++++++------- cinn/hlir/framework/op_lowering_util.h | 3 +- cinn/hlir/pass/fusion_helper_base.h | 11 +++ cinn/hlir/pass/fusion_merge_pass.cc | 22 +++-- cinn/hlir/pass/fusion_merge_pass_util.h | 105 ++++++++++++++++++++++++ 6 files changed, 171 insertions(+), 36 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index aa6f3143da..b3b33f6d5f 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -1326,7 +1326,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map); + SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group); VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); } diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 22bd056f14..1a2edfabea 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -593,10 +593,13 @@ bool CanbeInline(Node* node, return false; } else { auto node_shape = GetOutputShape(node, shape_dict); - auto last_shape = GetOutputShape(laster, shape_dict); - if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != - std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { - return true; + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + for (auto consumer : consumers) { + auto consumer_shape = GetOutputShape(consumer, shape_dict); 
+ auto consumer_size = std::accumulate(consumer_shape.begin(), consumer_shape.end(), 1, std::multiplies()); + if (node_size != consumer_size) { + return true; + } } return false; @@ -1228,7 +1231,7 @@ void LoopComputeAt(ir::IRSchedule& ir_sch, auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); if (!group->output_nodes.count(node)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); - ir_sch.SetBuffer(block, "local", true); + ir_sch.SetBuffer(block, "local"); } if (op_pattern_dict[node->op()] == framework::kReduction) { @@ -1285,11 +1288,14 @@ std::unordered_map GetNodeDataSet(const std::unordered_s return node_data_set; } -Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set) { +std::unordered_set GetMasters(Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set) { // find consumer std::unordered_set visited; std::queue candidates; candidates.push(node); + std::unordered_set masters; while (!candidates.empty()) { auto candidate = candidates.front(); @@ -1304,19 +1310,20 @@ Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const candidates.push(consumer); visited.insert(consumer); } else { - return consumer; + masters.insert(consumer); } } } - return nullptr; + return masters; } void SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { + const std::unordered_map& tensor_map, + const GroupPtr& group) { auto exprs_inorder = ir_sch.GetAllBlocks(); auto node_data_set = GetNodeDataSet(nodes_set); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); @@ -1353,34 +1360,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, auto node = node_data->source_node.get(); auto node_shape = shape_dict.at(node_data->id()); - auto master = GetMaster(node, nodes_inline, nodes_set); - if (!master) { + auto masters = 
GetMasters(node, nodes_inline, nodes_set); + if (masters.empty()) { continue; } - auto master_data = GetNodeData(master); - auto master_shape = shape_dict.at(master_data->id()); - if (op_pattern_dict[master->op()] == framework::kReduction) { - master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); - } + bool do_set_buffer_to_shared = false; + for (auto master : masters) { + auto master_data = GetNodeData(master); + auto master_shape = shape_dict.at(master_data->id()); + if (op_pattern_dict[master->op()] == framework::kReduction) { + master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + } - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); - if (node_size == master_size) { - continue; + if (node_size != master_size) { + if (check_sync_mark(idx, master_data->id())) { + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SyncThreads(loops.back(), false); + sync_mark.insert(master_data->id()); + } + do_set_buffer_to_shared = true; + } } - - { + if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) { auto block = ir_sch.GetBlock(node_data->id()); ir_sch.SetBuffer(block, "shared", true); } - - if (check_sync_mark(idx, master_data->id())) { - auto loops = ir_sch.GetLoops(master_data->id()); - ir_sch.SyncThreads(loops.back(), false); - sync_mark.insert(master_data->id()); - } } } diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index 03e18912e4..f5e7267f43 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -86,7 +86,8 @@ void 
SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map); + const std::unordered_map& tensor_map, + const GroupPtr& group); } // namespace framework } // namespace hlir diff --git a/cinn/hlir/pass/fusion_helper_base.h b/cinn/hlir/pass/fusion_helper_base.h index d4a0fc56ce..09e1bac7a9 100644 --- a/cinn/hlir/pass/fusion_helper_base.h +++ b/cinn/hlir/pass/fusion_helper_base.h @@ -112,6 +112,17 @@ class FusionHelperBase { return producer_node; } + std::vector GetConsumerNode(const Node* node) const { + std::vector consumer_nodes; + auto node_data = GetNodeData(node); + for (auto& link : node_data->outlinks()) { + auto consumer = link->sink()->safe_as(); + CHECK(consumer); + consumer_nodes.push_back(consumer); + } + return consumer_nodes; + } + bool WithoutLastDimInReduce(const std::vector& inshape, const std::vector& axes) const { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 88f54dc566..148c82946b 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -43,7 +43,7 @@ using ConditionFunction = std::functionoutputs) { fusion_groups_ = graph->fusion_groups; // init fusion relation. InitFusionRelation(); @@ -56,6 +56,7 @@ class FusionMergePassHelper : public FusionHelperBase { GroupList operator()() { // run fusion merge untill no update. 
DoFusionMerge(); + AddGlobalOutputNodesToGroups(); for (auto& group : fusion_groups_) { VLOG(3) << "Fusion Group -> " << group->group_id; for (auto& sub_group : group->fused_sub_groups) { @@ -72,6 +73,18 @@ class FusionMergePassHelper : public FusionHelperBase { } private: + void AddGlobalOutputNodesToGroups() { + for (auto group : fusion_groups_) { + for (const auto& output_node_data : graph_output_node_data_) { + Node* node = output_node_data->source_node.get(); + std::unordered_set node_set = group->NodeSet(); + if (node_set.find(node) != node_set.end()) { + group->output_nodes.insert(node); + } + } + } + } + void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; while (DoHorizontalFusion()) { @@ -617,10 +630,6 @@ class FusionMergePassHelper : public FusionHelperBase { void RecomputeWithCostModel(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind == framework::kReduction) { - CHECK_EQ(fusionable_consumers.size(), 1) << "Find more than one consumer can fuse to " << producer->group_id; - } - // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; @@ -976,7 +985,7 @@ class FusionMergePassHelper : public FusionHelperBase { relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. {OpPatternKind::kElementWise, reduce_fuse_elementwise}, // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, // reduce and injective op must be horizontal relation. {OpPatternKind::kInjective, horizontal_with_injective}, // reduce and reduce must be horizontal relation. 
@@ -985,6 +994,7 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; + const std::vector& graph_output_node_data_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index a37126e764..fa7305e2b8 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -277,6 +277,111 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { return elementwise_fuse_reduce(helper, first, second); } +CONDITION_FUNC(reduce_fuse_broadcast) { + if (helper->target_ == common::DefaultHostTarget()) { + return true; + } + // if same shape with horizontal relation + if (is_same_size(helper, first, second)) { + return true; + } + + // Traversing all reducers in all producers requires two types of conditions to be met. + // The first type is the condition that the reducer itself needs to meet, + // and the second type is the condition that the relationship between each reducer and its consumers with type of + // Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after broadcast as + // before reduce. 
+ for (auto& node_in_master : first->master_nodes) { + if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) { + continue; + } + Node* reducer = node_in_master; + // First type conditions + // Get some reduce information + auto reducer_input_shape = helper->GetNodeInputShape(reducer); + auto reducer_output_shape = helper->GetNodeDataShape(reducer); + auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(reducer->attrs.attr_store.at("keep_dim")); + for (auto& axis : reduce_axes) { + if (axis == -1) { + axis = reducer_input_shape.size() - 1; + } + } + // Check if the reduce axes are continuous + int reduce_size = reducer_input_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= reducer_input_shape[idx - 1]; + } + // Check if the reduce size exceeds the hardware limit + if (reduce_size > helper->target_.max_num_threads()) { + return false; + } + + // Second type conditions + // Find direct or indirect consumers with type of Broadcast in the second group + auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set { + std::queue candidates; + std::unordered_set visited_set; + std::unordered_set broadcasters; + candidates.push(producer); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : helper->GetConsumerNode(candidate)) { + if (helper->GetOpKind(consumer) == OpPatternKind::kBroadcast && + second->NodeSet().find(consumer) != second->NodeSet().end()) { + broadcasters.insert(consumer); + } else if (!visited_set.count(consumer)) { + visited_set.insert(consumer); + candidates.push(consumer); + } + } + } + + return broadcasters; + }; + + // Check if each broadcast node meets the conditions + std::unordered_set broadcasters_in_consumers = find_broadcasters_in_descendants(reducer); + for (auto broadcaster : 
broadcasters_in_consumers) { + auto broadcaster_output_shape = absl::get>(broadcaster->attrs.attr_store.at("out_shape")); + auto broadcast_axes = absl::get>(broadcaster->attrs.attr_store.at("broadcast_axes")); + for (auto& axis : broadcast_axes) { + if (axis == -1) { + axis = broadcaster_output_shape.size() - 1; + } + } + + if (reducer_input_shape != broadcaster_output_shape) { + return false; + } + + if (keep_dim) { + continue; + } else { + // if reducer_output_shape = [1] + if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) { + continue; + } + // check union [reduce_axes, broadcast_axes] = reducer_input_shape + for (int idx = 0; idx < reducer_input_shape.size(); ++idx) { + if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + return false; + } + } + } + } + } + + return true; +} + CONDITION_FUNC(reduce_fuse_reduce) { // check reduce horizontal with reduce. 
if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) { From c15f034a9647cdeff3057f14ff434d6c6df3a2b0 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Wed, 26 Apr 2023 17:04:02 +0800 Subject: [PATCH 02/12] fix(fuse): fix mean_op unittests --- cinn/hlir/framework/op_lowering_util.cc | 7 +++---- cinn/hlir/pass/fusion_merge_pass.cc | 16 +--------------- cinn/hlir/pass/fusion_merge_pass_util.h | 5 +---- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 1a2edfabea..53dc2735eb 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -165,11 +165,10 @@ bool IsConstOp(const framework::Node* node) { } std::vector GetInputShape(const Node* node, const absl::flat_hash_map& shape_dict) { - auto producers = GetProducers(node); - CHECK(producers.size()); + auto input_data = GetInputNodeData(node); + CHECK(input_data.size()); - auto producer_data = GetNodeData(producers.front()); - return shape_dict.at(producer_data->id()); + return shape_dict.at(input_data.front()->id()); } std::vector GetOutputShape(const Node* node, const absl::flat_hash_map& shape_dict) { diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 148c82946b..c7c5a850bc 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -43,7 +43,7 @@ using ConditionFunction = std::functionoutputs) { + FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { fusion_groups_ = graph->fusion_groups; // init fusion relation. InitFusionRelation(); @@ -56,7 +56,6 @@ class FusionMergePassHelper : public FusionHelperBase { GroupList operator()() { // run fusion merge untill no update. 
DoFusionMerge(); - AddGlobalOutputNodesToGroups(); for (auto& group : fusion_groups_) { VLOG(3) << "Fusion Group -> " << group->group_id; for (auto& sub_group : group->fused_sub_groups) { @@ -73,18 +72,6 @@ class FusionMergePassHelper : public FusionHelperBase { } private: - void AddGlobalOutputNodesToGroups() { - for (auto group : fusion_groups_) { - for (const auto& output_node_data : graph_output_node_data_) { - Node* node = output_node_data->source_node.get(); - std::unordered_set node_set = group->NodeSet(); - if (node_set.find(node) != node_set.end()) { - group->output_nodes.insert(node); - } - } - } - } - void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; while (DoHorizontalFusion()) { @@ -994,7 +981,6 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; - const std::vector& graph_output_node_data_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index fa7305e2b8..b9b7417416 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -278,9 +278,6 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { } CONDITION_FUNC(reduce_fuse_broadcast) { - if (helper->target_ == common::DefaultHostTarget()) { - return true; - } // if same shape with horizontal relation if (is_same_size(helper, first, second)) { return true; @@ -316,7 +313,7 @@ CONDITION_FUNC(reduce_fuse_broadcast) { reduce_size *= reducer_input_shape[idx - 1]; } // Check if the reduce size exceeds the hardware limit - if (reduce_size > helper->target_.max_num_threads()) { + if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) { return false; } From 9ea72b09559c698a87185753562df85479dd4332 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 25 Apr 2023 13:21:32 +0000 Subject: [PATCH 03/12] feat(fuse): add reduce group fuse with 
broadcast --- cinn/hlir/framework/op_lowering.cc | 2 +- cinn/hlir/framework/op_lowering_util.cc | 64 ++++++++------- cinn/hlir/framework/op_lowering_util.h | 3 +- cinn/hlir/pass/fusion_helper_base.h | 11 +++ cinn/hlir/pass/fusion_merge_pass.cc | 22 +++-- cinn/hlir/pass/fusion_merge_pass_util.h | 105 ++++++++++++++++++++++++ 6 files changed, 171 insertions(+), 36 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index 3a00034c27..c8148e8af1 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -1326,7 +1326,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } VLOG(3) << "Before Sync IRLowerOp schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map); + SyncThreadWithShared(ir_sch, nodes_inline, nodes_set, this->shape_dict_, tensor_map, group); VLOG(4) << "After IRSchedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); } diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 9680a37f36..d762906438 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -593,10 +593,13 @@ bool CanbeInline(Node* node, return false; } else { auto node_shape = GetOutputShape(node, shape_dict); - auto last_shape = GetOutputShape(laster, shape_dict); - if (std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) != - std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { - return true; + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + for (auto consumer : consumers) { + auto consumer_shape = GetOutputShape(consumer, shape_dict); + auto consumer_size = std::accumulate(consumer_shape.begin(), consumer_shape.end(), 1, std::multiplies()); + if (node_size != consumer_size) { + return true; + } } return false; @@ -1228,7 +1231,7 @@ 
void LoopComputeAt(ir::IRSchedule& ir_sch, auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); if (!group->output_nodes.count(node)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); - ir_sch.SetBuffer(block, "local", true); + ir_sch.SetBuffer(block, "local"); } if (op_pattern_dict[node->op()] == framework::kReduction) { @@ -1285,11 +1288,14 @@ std::unordered_map GetNodeDataSet(const std::unordered_s return node_data_set; } -Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set) { +std::unordered_set GetMasters(Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set) { // find consumer std::unordered_set visited; std::queue candidates; candidates.push(node); + std::unordered_set masters; while (!candidates.empty()) { auto candidate = candidates.front(); @@ -1304,19 +1310,20 @@ Node* GetMaster(Node* node, const std::unordered_set& nodes_inline, const candidates.push(consumer); visited.insert(consumer); } else { - return consumer; + masters.insert(consumer); } } } - return nullptr; + return masters; } void SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map) { + const std::unordered_map& tensor_map, + const GroupPtr& group) { auto exprs_inorder = ir_sch.GetAllBlocks(); auto node_data_set = GetNodeDataSet(nodes_set); auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); @@ -1353,34 +1360,35 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, auto node = node_data->source_node.get(); auto node_shape = shape_dict.at(node_data->id()); - auto master = GetMaster(node, nodes_inline, nodes_set); - if (!master) { + auto masters = GetMasters(node, nodes_inline, nodes_set); + if (masters.empty()) { continue; } - auto master_data = GetNodeData(master); - auto master_shape = shape_dict.at(master_data->id()); - if 
(op_pattern_dict[master->op()] == framework::kReduction) { - master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); - } + bool do_set_buffer_to_shared = false; + for (auto master : masters) { + auto master_data = GetNodeData(master); + auto master_shape = shape_dict.at(master_data->id()); + if (op_pattern_dict[master->op()] == framework::kReduction) { + master_shape = shape_dict.at(master->inlinks_in_order()[0]->source()->id()); + } - auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); - if (node_size == master_size) { - continue; + if (node_size != master_size) { + if (check_sync_mark(idx, master_data->id())) { + auto loops = ir_sch.GetLoops(master_data->id()); + ir_sch.SyncThreads(loops.back(), false); + sync_mark.insert(master_data->id()); + } + do_set_buffer_to_shared = true; + } } - - { + if (do_set_buffer_to_shared && group->output_nodes.find(node) == group->output_nodes.end()) { auto block = ir_sch.GetBlock(node_data->id()); ir_sch.SetBuffer(block, "shared", true); } - - if (check_sync_mark(idx, master_data->id())) { - auto loops = ir_sch.GetLoops(master_data->id()); - ir_sch.SyncThreads(loops.back(), false); - sync_mark.insert(master_data->id()); - } } } diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index 03e18912e4..f5e7267f43 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -86,7 +86,8 @@ void SyncThreadWithShared(ir::IRSchedule& ir_sch, const std::unordered_set& nodes_inline, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict, - const std::unordered_map& tensor_map); + const 
std::unordered_map& tensor_map, + const GroupPtr& group); } // namespace framework } // namespace hlir diff --git a/cinn/hlir/pass/fusion_helper_base.h b/cinn/hlir/pass/fusion_helper_base.h index 94ef3460aa..7658cc0792 100644 --- a/cinn/hlir/pass/fusion_helper_base.h +++ b/cinn/hlir/pass/fusion_helper_base.h @@ -112,6 +112,17 @@ class FusionHelperBase { return producer_node; } + std::vector GetConsumerNode(const Node* node) const { + std::vector consumer_nodes; + auto node_data = GetNodeData(node); + for (auto& link : node_data->outlinks()) { + auto consumer = link->sink()->safe_as(); + CHECK(consumer); + consumer_nodes.push_back(consumer); + } + return consumer_nodes; + } + bool WithoutLastDimInReduce(const std::vector& inshape, const std::vector& axes) const { // if last axis is in reduce. if (std::find(axes.begin(), axes.end(), inshape.size() - 1) != axes.end() || diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 88f54dc566..148c82946b 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -43,7 +43,7 @@ using ConditionFunction = std::functionoutputs) { fusion_groups_ = graph->fusion_groups; // init fusion relation. InitFusionRelation(); @@ -56,6 +56,7 @@ class FusionMergePassHelper : public FusionHelperBase { GroupList operator()() { // run fusion merge untill no update. 
DoFusionMerge(); + AddGlobalOutputNodesToGroups(); for (auto& group : fusion_groups_) { VLOG(3) << "Fusion Group -> " << group->group_id; for (auto& sub_group : group->fused_sub_groups) { @@ -72,6 +73,18 @@ class FusionMergePassHelper : public FusionHelperBase { } private: + void AddGlobalOutputNodesToGroups() { + for (auto group : fusion_groups_) { + for (const auto& output_node_data : graph_output_node_data_) { + Node* node = output_node_data->source_node.get(); + std::unordered_set node_set = group->NodeSet(); + if (node_set.find(node) != node_set.end()) { + group->output_nodes.insert(node); + } + } + } + } + void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; while (DoHorizontalFusion()) { @@ -617,10 +630,6 @@ class FusionMergePassHelper : public FusionHelperBase { void RecomputeWithCostModel(const GroupPtr& producer, std::unordered_set& fusionable_consumers) { - if (producer->op_pattern_kind == framework::kReduction) { - CHECK_EQ(fusionable_consumers.size(), 1) << "Find more than one consumer can fuse to " << producer->group_id; - } - // if is const op if (is_const_group(this, producer)) { std::unordered_set candidates; @@ -976,7 +985,7 @@ class FusionMergePassHelper : public FusionHelperBase { relation.vertical_relation = {// reduce and elementwise can be horizontal/vertical relation. {OpPatternKind::kElementWise, reduce_fuse_elementwise}, // reduce and broadcast op must be horizontal relation. - {OpPatternKind::kBroadcast, is_same_size}, + {OpPatternKind::kBroadcast, reduce_fuse_broadcast}, // reduce and injective op must be horizontal relation. {OpPatternKind::kInjective, horizontal_with_injective}, // reduce and reduce must be horizontal relation. 
@@ -985,6 +994,7 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; + const std::vector& graph_output_node_data_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index a37126e764..fa7305e2b8 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -277,6 +277,111 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { return elementwise_fuse_reduce(helper, first, second); } +CONDITION_FUNC(reduce_fuse_broadcast) { + if (helper->target_ == common::DefaultHostTarget()) { + return true; + } + // if same shape with horizontal relation + if (is_same_size(helper, first, second)) { + return true; + } + + // Traversing all reducers in all producers requires two types of conditions to be met. + // The first type is the condition that the reducer itself needs to meet, + // and the second type is the condition that the relationship between each reducer and its consumers with type of + // Broadcast needs to meet. It is required that each consumer of type Broadcast meet the same shape after broadcast as + // before reduce. 
+ for (auto& node_in_master : first->master_nodes) { + if (helper->GetOpKind(node_in_master) != OpPatternKind::kReduction) { + continue; + } + Node* reducer = node_in_master; + // First type conditions + // Get some reduce information + auto reducer_input_shape = helper->GetNodeInputShape(reducer); + auto reducer_output_shape = helper->GetNodeDataShape(reducer); + auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim")); + auto keep_dim = absl::get(reducer->attrs.attr_store.at("keep_dim")); + for (auto& axis : reduce_axes) { + if (axis == -1) { + axis = reducer_input_shape.size() - 1; + } + } + // Check if the reduce axes are continuous + int reduce_size = reducer_input_shape.back(); + for (auto idx = reduce_axes.size() - 1; idx >= 1; --idx) { + if (reduce_axes[idx] != reduce_axes[idx - 1] + 1) { + return false; + } + reduce_size *= reducer_input_shape[idx - 1]; + } + // Check if the reduce size exceeds the hardware limit + if (reduce_size > helper->target_.max_num_threads()) { + return false; + } + + // Second type conditions + // Find direct or indirect consumers with type of Broadcast in the second group + auto find_broadcasters_in_descendants = [&](const Node* producer) -> std::unordered_set { + std::queue candidates; + std::unordered_set visited_set; + std::unordered_set broadcasters; + candidates.push(producer); + + while (!candidates.empty()) { + auto candidate = candidates.front(); + candidates.pop(); + + for (auto consumer : helper->GetConsumerNode(candidate)) { + if (helper->GetOpKind(consumer) == OpPatternKind::kBroadcast && + second->NodeSet().find(consumer) != second->NodeSet().end()) { + broadcasters.insert(consumer); + } else if (!visited_set.count(consumer)) { + visited_set.insert(consumer); + candidates.push(consumer); + } + } + } + + return broadcasters; + }; + + // Check if each broadcast node meets the conditions + std::unordered_set broadcasters_in_consumers = find_broadcasters_in_descendants(reducer); + for (auto broadcaster : 
broadcasters_in_consumers) { + auto broadcaster_output_shape = absl::get>(broadcaster->attrs.attr_store.at("out_shape")); + auto broadcast_axes = absl::get>(broadcaster->attrs.attr_store.at("broadcast_axes")); + for (auto& axis : broadcast_axes) { + if (axis == -1) { + axis = broadcaster_output_shape.size() - 1; + } + } + + if (reducer_input_shape != broadcaster_output_shape) { + return false; + } + + if (keep_dim) { + continue; + } else { + // if reducer_output_shape = [1] + if (reducer_output_shape.size() == 1 && reducer_output_shape[0] == 1) { + continue; + } + // check union [reduce_axes, broadcast_axes] = reducer_input_shape + for (int idx = 0; idx < reducer_input_shape.size(); ++idx) { + if (!(std::find(broadcast_axes.begin(), broadcast_axes.end(), idx) == broadcast_axes.end()) ^ + std::find(reduce_axes.begin(), reduce_axes.end(), idx) == reduce_axes.end()) { + return false; + } + } + } + } + } + + return true; +} + CONDITION_FUNC(reduce_fuse_reduce) { // check reduce horizontal with reduce. 
if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) { From ae6efa7c36d74d4408c018074fb00e0ee8b2af01 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Wed, 26 Apr 2023 17:04:02 +0800 Subject: [PATCH 04/12] fix(fuse): fix mean_op unittests --- cinn/hlir/framework/op_lowering_util.cc | 7 +++---- cinn/hlir/pass/fusion_merge_pass.cc | 16 +--------------- cinn/hlir/pass/fusion_merge_pass_util.h | 5 +---- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index d762906438..d0946ce90d 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -165,11 +165,10 @@ bool IsConstOp(const framework::Node* node) { } std::vector GetInputShape(const Node* node, const absl::flat_hash_map& shape_dict) { - auto producers = GetProducers(node); - CHECK(producers.size()); + auto input_data = GetInputNodeData(node); + CHECK(input_data.size()); - auto producer_data = GetNodeData(producers.front()); - return shape_dict.at(producer_data->id()); + return shape_dict.at(input_data.front()->id()); } std::vector GetOutputShape(const Node* node, const absl::flat_hash_map& shape_dict) { diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 148c82946b..c7c5a850bc 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -43,7 +43,7 @@ using ConditionFunction = std::functionoutputs) { + FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) { fusion_groups_ = graph->fusion_groups; // init fusion relation. InitFusionRelation(); @@ -56,7 +56,6 @@ class FusionMergePassHelper : public FusionHelperBase { GroupList operator()() { // run fusion merge untill no update. 
DoFusionMerge(); - AddGlobalOutputNodesToGroups(); for (auto& group : fusion_groups_) { VLOG(3) << "Fusion Group -> " << group->group_id; for (auto& sub_group : group->fused_sub_groups) { @@ -73,18 +72,6 @@ class FusionMergePassHelper : public FusionHelperBase { } private: - void AddGlobalOutputNodesToGroups() { - for (auto group : fusion_groups_) { - for (const auto& output_node_data : graph_output_node_data_) { - Node* node = output_node_data->source_node.get(); - std::unordered_set node_set = group->NodeSet(); - if (node_set.find(node) != node_set.end()) { - group->output_nodes.insert(node); - } - } - } - } - void DoFusionMerge() { VLOG(3) << "DoFusionMerge...!"; while (DoHorizontalFusion()) { @@ -994,7 +981,6 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList fusion_groups_; - const std::vector& graph_output_node_data_; std::unordered_map fusion_groups_index_; std::unordered_map> input_to_consumers_; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index fa7305e2b8..b9b7417416 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -278,9 +278,6 @@ CONDITION_FUNC(injective_horizontal_with_reduce) { } CONDITION_FUNC(reduce_fuse_broadcast) { - if (helper->target_ == common::DefaultHostTarget()) { - return true; - } // if same shape with horizontal relation if (is_same_size(helper, first, second)) { return true; @@ -316,7 +313,7 @@ CONDITION_FUNC(reduce_fuse_broadcast) { reduce_size *= reducer_input_shape[idx - 1]; } // Check if the reduce size exceeds the hardware limit - if (reduce_size > helper->target_.max_num_threads()) { + if (helper->target_ == common::DefaultNVGPUTarget() && reduce_size > helper->target_.max_num_threads()) { return false; } From 12cd9f9bd9149867c93881d6679e561c21a7ec9b Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Thu, 4 May 2023 15:19:47 +0800 Subject: [PATCH 05/12] fix(fuse): Modify topological order with 
priority --- cinn/hlir/framework/op_lowering.cc | 2 +- cinn/hlir/framework/op_lowering_util.cc | 99 +++++++++++++++++++++++++ cinn/hlir/framework/op_lowering_util.h | 4 + 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index c8148e8af1..fe9b2b00dc 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -1207,7 +1207,7 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, // topological order. auto nodes_set = group->NodeSet(); auto v_consumers = BuildVirtualConsumer(group, this->shape_dict_); - auto nodes_in_order = TopologicalOrder(group, v_consumers); + auto nodes_in_order = BFSTopologicalOrderWithPriority(group, v_consumers, this->shape_dict_); // find reducer. std::unordered_set nodes_inline; auto greducer = FindGlobalReducer(nodes_in_order); diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 6b2c5976e5..37d2ea6fd0 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -310,6 +310,19 @@ std::vector FindConsumers(Node* node, return consumers; } +std::vector FindProducers(Node* node, + const std::unordered_set& nodes_set, + const std::unordered_map& virtual_consumers) { + auto producers = GetProducersInSet(node, nodes_set); + for (const auto& iter : virtual_consumers) { + if (iter.second == node) { + producers.push_back(iter.first); + } + } + + return producers; +} + std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers) { std::vector nodes_in_order; std::unordered_set nodes_set = group->NodeSet(); @@ -335,6 +348,92 @@ std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_ return nodes_in_order; } +std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict) { + struct NodeWithPriority { + Node* 
node; + int priority; + }; + + struct Comparator { + bool operator()(const NodeWithPriority& lhs, const NodeWithPriority& y) { return lhs.priority < y.priority; } + }; + + std::vector nodes_in_order; + std::unordered_set visited; + std::unordered_set nodes_set = group->NodeSet(); + std::unordered_map degree_map; + std::priority_queue, Comparator> priority_candidates; + std::vector visited_numel; + + // Calculate the priority of a node. + // The smaller the value, the higher the priority. + // Prioritize the same shape before considering OpPattern + auto PriorityFunc = [&visited_numel, &shape_dict](const Node* node) -> int { + auto node_shape = GetOutputShape(node, shape_dict); + int numel = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); + int index = -1; + for (int i = 0; i < visited_numel.size(); ++i) { + if (numel == visited_numel[i]) { + index = i; + break; + } + } + if (index == -1) { + index = visited_numel.size(); + visited_numel.push_back(numel); + } + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + return index * 10 + static_cast(op_pattern_dict[node->op()]); + }; + + for (Node* node : nodes_set) { + auto consumers = FindConsumers(node, nodes_set, virtual_consumers); + // Some nodes may have multiple edges between them, resulting in duplicates in the consumer. + // We only need to calculate once. + std::unordered_set consumers_without_duplicate(consumers.begin(), consumers.end()); + degree_map[node] = consumers_without_duplicate.size(); + if (degree_map.at(node) == 0) { + priority_candidates.push(NodeWithPriority{node, PriorityFunc(node)}); + } + } + + // Nested BFS, outer layer traverses priority, inner layer performs BFS on current priority. 
+ while (!priority_candidates.empty()) { + Node* cur_priority_node = priority_candidates.top().node; + priority_candidates.pop(); + + std::queue bfs_queue; + bfs_queue.push(cur_priority_node); + visited.insert(cur_priority_node); + while (!bfs_queue.empty()) { + Node* cur = bfs_queue.front(); + bfs_queue.pop(); + + nodes_in_order.push_back(cur); + auto producers = FindProducers(cur, nodes_set, virtual_consumers); + for (Node* node : producers) { + --degree_map[node]; + // Ensure that each node is accessed only once and maintain topological order. + if (visited.count(node) != 0 || degree_map[node] != 0) { + continue; + } + // Perform BFS access to the current priority producers + int node_priority = PriorityFunc(node); + if (node_priority == PriorityFunc(cur_priority_node)) { + bfs_queue.push(node); + visited.insert(node); + } else { + priority_candidates.push(NodeWithPriority{node, node_priority}); + } + } + } + } + + return nodes_in_order; +} + bool WithoutLastDimInReduce(const std::vector& shape, const std::vector& axes) { if (axes.empty()) { return false; diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index f5e7267f43..8687ec7fbe 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -49,6 +49,10 @@ std::vector GetConsumersInSet(const Node* node, const std::unordered_set< std::vector TopologicalOrder(const GroupPtr& group, const std::unordered_map& virtual_consumers); +std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, + const std::unordered_map& virtual_consumers, + const absl::flat_hash_map& shape_dict); + Node* FindGlobalReducer(const std::vector& nodes_in_order); Node* FindNearestReducer(const Node* node, const std::unordered_set& nodes_set); From 6cb0331467d042e5f9d213c85426363ea347a621 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Fri, 5 May 2023 11:09:54 +0800 Subject: [PATCH 06/12] fix(fuse): fix bugs in topo order with priority --- 
cinn/hlir/framework/op_lowering_util.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 37d2ea6fd0..7f23038095 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -357,7 +357,7 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, }; struct Comparator { - bool operator()(const NodeWithPriority& lhs, const NodeWithPriority& y) { return lhs.priority < y.priority; } + bool operator()(const NodeWithPriority& lhs, const NodeWithPriority& rhs) { return lhs.priority > rhs.priority; } }; std::vector nodes_in_order; @@ -413,7 +413,8 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, nodes_in_order.push_back(cur); auto producers = FindProducers(cur, nodes_set, virtual_consumers); - for (Node* node : producers) { + std::unordered_set producers_without_duplicate(producers.begin(), producers.end()); + for (Node* node : producers_without_duplicate) { --degree_map[node]; // Ensure that each node is accessed only once and maintain topological order. 
if (visited.count(node) != 0 || degree_map[node] != 0) { @@ -421,7 +422,7 @@ std::vector BFSTopologicalOrderWithPriority(const GroupPtr& group, } // Perform BFS access to the current priority producers int node_priority = PriorityFunc(node); - if (node_priority == PriorityFunc(cur_priority_node)) { + if (node_priority <= PriorityFunc(cur_priority_node)) { bfs_queue.push(node); visited.insert(node); } else { From 241b715b696d0945ef4246d119282500860fe039 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Sat, 6 May 2023 14:36:32 +0800 Subject: [PATCH 07/12] fix(fuse): fix bug in CanbeInline condition --- cinn/hlir/framework/op_lowering_util.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index 7f23038095..bb25dc49f8 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -686,6 +686,11 @@ bool CanbeInline(Node* node, } } + auto last_shape = GetOutputShape(laster, shape_dict); + if (node_size != std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { + return true; + } + return false; } } From b905cefecf2a5544e43a1ea59c048d3540dcbace Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Sat, 6 May 2023 15:37:46 +0800 Subject: [PATCH 08/12] fix(fuse): fix CanbeInline condition --- cinn/hlir/framework/op_lowering.cc | 3 ++- cinn/hlir/framework/op_lowering_util.cc | 16 ++++++---------- cinn/hlir/framework/op_lowering_util.h | 6 +++++- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/cinn/hlir/framework/op_lowering.cc b/cinn/hlir/framework/op_lowering.cc index fe9b2b00dc..2a1ae9711c 100644 --- a/cinn/hlir/framework/op_lowering.cc +++ b/cinn/hlir/framework/op_lowering.cc @@ -1225,8 +1225,9 @@ void OpLowerer::IRSchedule(ir::IRSchedule& ir_sch, } } + auto masters = GetMasters(node, nodes_inline, nodes_set); // node can be inline. 
- if (CanbeInline(node, consumers, reducer, nodes_in_order.front(), group, nodes_set, this->shape_dict_)) { + if (CanbeInline(node, consumers, reducer, masters, group, nodes_set, this->shape_dict_)) { auto block = ir_sch.GetBlock(GetNodeData(node)->id()); ir::ComputeInlineChecker checker(ir_sch, block); if (!checker.Check()) { diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index bb25dc49f8..01650c3446 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -635,7 +635,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, bool CanbeInline(Node* node, const std::vector consumers, const Node* reducer, - const Node* laster, + const std::unordered_set masters, const GroupPtr& group, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict) { @@ -678,19 +678,15 @@ bool CanbeInline(Node* node, } else { auto node_shape = GetOutputShape(node, shape_dict); auto node_size = std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()); - for (auto consumer : consumers) { - auto consumer_shape = GetOutputShape(consumer, shape_dict); - auto consumer_size = std::accumulate(consumer_shape.begin(), consumer_shape.end(), 1, std::multiplies()); - if (node_size != consumer_size) { + + for (auto master : masters) { + auto master_shape = GetOutputShape(master, shape_dict); + auto master_size = std::accumulate(master_shape.begin(), master_shape.end(), 1, std::multiplies()); + if (node_size != master_size) { return true; } } - auto last_shape = GetOutputShape(laster, shape_dict); - if (node_size != std::accumulate(last_shape.begin(), last_shape.end(), 1, std::multiplies())) { - return true; - } - return false; } } diff --git a/cinn/hlir/framework/op_lowering_util.h b/cinn/hlir/framework/op_lowering_util.h index 8687ec7fbe..db92b74c68 100644 --- a/cinn/hlir/framework/op_lowering_util.h +++ b/cinn/hlir/framework/op_lowering_util.h @@ -60,7 +60,7 @@ Node* 
FindNearestReducer(const Node* node, const std::unordered_set& node bool CanbeInline(Node* node, const std::vector consumers, const Node* reducer, - const Node* laster, + const std::unordered_set masters, const GroupPtr& group, const std::unordered_set& nodes_set, const absl::flat_hash_map& shape_dict); @@ -72,6 +72,10 @@ Node* GetMasterToComputeAt(Node* node, const std::unordered_map& virtual_consumers, const absl::flat_hash_map& shape_dict); +std::unordered_set GetMasters(Node* node, + const std::unordered_set& nodes_inline, + const std::unordered_set& nodes_set); + void LoopAssignReduce(ir::IRSchedule& ir_sch, const Node* node, const Node* reducer, From 5586b2d52a6d935527a7126135f658c13b68ba17 Mon Sep 17 00:00:00 2001 From: BiynXu <244524405@qq.com> Date: Tue, 25 Apr 2023 03:14:04 +0000 Subject: [PATCH 09/12] feat(fuse): support vertical reduce fuse reduce --- cinn/hlir/pass/fusion_merge_pass.cc | 23 ++++++++++++++++------- cinn/hlir/pass/fusion_merge_pass_util.h | 4 ---- cinn/hlir/pass/op_fusion_pass.cc | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index c7c5a850bc..475b9e9869 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -814,14 +814,23 @@ class FusionMergePassHelper : public FusionHelperBase { auto& consumers = input_consumers.second; std::unordered_set updated_consumers; for (auto& consumer : consumers) { - // if group is sub group - if (consumer->belong_groups.size()) { - // inset belong group to consumers. 
- for (auto& belong_group : consumer->belong_groups) { - updated_consumers.insert(belong_group); + std::queue fused_groups; + fused_groups.push(consumer); + while (!fused_groups.empty()) { + auto& cur = fused_groups.front(); + fused_groups.pop(); + // if group is sub group + if (cur->belong_groups.empty()) { + updated_consumers.insert(cur); + } else { + for (auto& belong_group : cur->belong_groups) { + if (belong_group->group_id == cur->group_id) { + updated_consumers.insert(cur); + } else { + fused_groups.push(belong_group); + } + } } - } else { - updated_consumers.insert(consumer); } } consumers = updated_consumers; diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index 0057d39b12..82bbabd20f 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -388,10 +388,6 @@ CONDITION_FUNC(reduce_fuse_broadcast) { } CONDITION_FUNC(reduce_fuse_reduce) { - // check reduce horizontal with reduce. - if (!horizontal_relation(helper, first, second, framework::OpPatternKind::kReduction)) { - return false; - } if (!limit_args(helper, first, second)) { return false; } diff --git a/cinn/hlir/pass/op_fusion_pass.cc b/cinn/hlir/pass/op_fusion_pass.cc index 026f2c6195..021e66e9d3 100644 --- a/cinn/hlir/pass/op_fusion_pass.cc +++ b/cinn/hlir/pass/op_fusion_pass.cc @@ -267,7 +267,7 @@ class OpFusionPassHelper : public FusionHelperBase { // producer -> fusion relation.fusion_op_kind = { // horizontal or vertical relation(Reduce + Elementwise*), check without last dimension in reduce. - {framework::kElementWise, without_last_dimension_in_reduce}, + {framework::kElementWise, is_same_size}, // must be horizontal relation, check with same output shape and without last dimension in reduce. {framework::kBroadcast, reduce_fuse_broadcast}, // must be horizontal relation and with same reduce attr. 
From 7e752a4e74bcf40321e06d8c4fdb42fc41c03564 Mon Sep 17 00:00:00 2001 From: 6clc Date: Fri, 5 May 2023 17:41:12 +0800 Subject: [PATCH 10/12] feat(fusion): reduce fusion reduce --- cinn/hlir/pass/fusion_merge_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cinn/hlir/pass/fusion_merge_pass.cc b/cinn/hlir/pass/fusion_merge_pass.cc index 475b9e9869..0121f8f056 100644 --- a/cinn/hlir/pass/fusion_merge_pass.cc +++ b/cinn/hlir/pass/fusion_merge_pass.cc @@ -825,7 +825,7 @@ class FusionMergePassHelper : public FusionHelperBase { } else { for (auto& belong_group : cur->belong_groups) { if (belong_group->group_id == cur->group_id) { - updated_consumers.insert(cur); + updated_consumers.insert(belong_group); } else { fused_groups.push(belong_group); } From b5c7bbf1b0c82a0d1b10921919654e16562bd497 Mon Sep 17 00:00:00 2001 From: 6clc Date: Mon, 8 May 2023 11:16:43 +0800 Subject: [PATCH 11/12] test(fusion): fix fusion_merge_pass according to new fusion rule --- cinn/hlir/pass/fusion_merge_pass_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cinn/hlir/pass/fusion_merge_pass_test.cc b/cinn/hlir/pass/fusion_merge_pass_test.cc index e834da510c..544f86019c 100755 --- a/cinn/hlir/pass/fusion_merge_pass_test.cc +++ b/cinn/hlir/pass/fusion_merge_pass_test.cc @@ -401,7 +401,7 @@ TEST(FusionMergePass, Reduce_Test_2) { auto graph = std::make_shared(program, target); hlir::framework::ApplyPass(graph.get(), "OpFusionPass"); - CHECK_EQ(graph->fusion_groups.size(), 4); + CHECK_EQ(graph->fusion_groups.size(), 3); hlir::framework::ApplyPass(graph.get(), "FusionMergePass"); CHECK_EQ(graph->fusion_groups.size(), 2); } From 389e2b56ab8b2bbc274d2bb0769e5be4f745312f Mon Sep 17 00:00:00 2001 From: 6clc Date: Mon, 24 Apr 2023 11:00:45 +0800 Subject: [PATCH 12/12] feat(reduction): support warp reduce (#1354) * feature(reduction): support warp reduce --- cinn/common/target.cc | 25 ++++++++- cinn/common/target.h | 4 ++ cinn/frontend/net_builder.cc | 9 
+++- cinn/hlir/framework/op_lowering_test.cc | 69 +++++++++++++++++++++++++ cinn/hlir/framework/op_lowering_util.cc | 21 ++++++-- cinn/hlir/op/reduction_test.cc | 47 +++++++++++++++++ cinn/hlir/pe/reduction.cc | 21 ++++++-- 7 files changed, 187 insertions(+), 9 deletions(-) diff --git a/cinn/common/target.cc b/cinn/common/target.cc index c2b26601b8..8646bd4d7e 100644 --- a/cinn/common/target.cc +++ b/cinn/common/target.cc @@ -11,13 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#include "cinn/common/target.h" +#ifdef CINN_WITH_CUDA +#include +#include +#endif #include #include +#include "cinn/common/target.h" #include "cinn/runtime/cinn_runtime.h" namespace cinn { @@ -49,6 +52,24 @@ int Target::max_num_threads() const { return 1024; } +int Target::get_multi_processor_count() const { + CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! Cannot get multi processor count"; + int num_sm = 0; +#ifdef CINN_WITH_CUDA + cudaDeviceGetAttribute(&num_sm, cudaDeviceAttr::cudaDevAttrMultiProcessorCount, 0); +#endif + return num_sm; +} + +int Target::get_max_threads_per_sm() const { + CHECK(arch == Arch::NVGPU) << "The target is not NVGPU! 
Cannot get max threads per stream processor"; + int max_thread = 0; +#ifdef CINN_WITH_CUDA + cudaDeviceGetAttribute(&max_thread, cudaDeviceAttr::cudaDevAttrMaxThreadsPerMultiProcessor, 0); +#endif + return max_thread; +} + std::vector Target::get_target_libs() const { return libs; } int Target::get_target_bits() const { diff --git a/cinn/common/target.h b/cinn/common/target.h index 33dacefe29..f9fe56efa7 100755 --- a/cinn/common/target.h +++ b/cinn/common/target.h @@ -80,6 +80,10 @@ struct Target { int max_num_threads() const; + int get_multi_processor_count() const; + + int get_max_threads_per_sm() const; + int get_target_bits() const; std::vector get_target_libs() const; diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc index 53563ca8f2..121844f798 100644 --- a/cinn/frontend/net_builder.cc +++ b/cinn/frontend/net_builder.cc @@ -115,7 +115,14 @@ Variable NetBuilder::Reduce(const std::string& op_type, const Variable& x, const return Reshape(x, new_shape); } } - return CustomInstr(op_type, {x}, {{"dim", dim}, {"keep_dim", keep_dim}}).front(); + // Convert the negative dim to a positive number + std::vector reduce_dim(dim.begin(), dim.end()); + for (int i = 0; i < dim.size(); i++) { + if (reduce_dim[i] < 0) { + reduce_dim[i] = x->shape.size() + reduce_dim[i]; + } + } + return CustomInstr(op_type, {x}, {{"dim", reduce_dim}, {"keep_dim", keep_dim}}).front(); } #define NETBUILDER_UNARY_OP_DEF(func_name__, op_type__) \ diff --git a/cinn/hlir/framework/op_lowering_test.cc b/cinn/hlir/framework/op_lowering_test.cc index 3b3601055a..336d46ee9b 100644 --- a/cinn/hlir/framework/op_lowering_test.cc +++ b/cinn/hlir/framework/op_lowering_test.cc @@ -1171,6 +1171,75 @@ TEST(OP_LOWERING, Reduce_Fusion_Test_21) { Compile(net_builder); } +TEST(OpFusionPass, Block_Reduce_Fuse_Broadcast) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int 
warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold - 10; + int w = 256; + NetBuilder net_builder("Block_Reduce_Fuse_Broadcast"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1}, true); + auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Block_Reduce_Fuse_Elementwise) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold - 10; + int w = 256; + NetBuilder net_builder("Block_Reduce_Fuse_Elementwise"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + } + + Compile(net_builder); +} +TEST(OpFusionPass, Warp_Reduce_Fuse_Broadcast) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold + 10; + int w = 256; + NetBuilder net_builder("Warp_Reduce_Fuse_Broadcast"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.ReduceSum(A, {1}, true); + auto C = net_builder.BroadcastTo(B, {h, w}, {0, 1}); + } + + Compile(net_builder); +} + +TEST(OpFusionPass, Warp_Reduce_Fuse_Elementwise) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + int h = warp_reduce_threshold + 10; + int w = 256; + NetBuilder 
net_builder("Warp_Reduce_Fuse_Elementwise"); + // create model + { + auto A = net_builder.CreateInput(Float(32), {h, w}, "A"); + auto B = net_builder.CreateInput(Float(32), {h}, "B"); + auto C = net_builder.ReduceSum(A, {1}, true); + auto D = net_builder.Add(B, C); + } + + Compile(net_builder); +} + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/op_lowering_util.cc b/cinn/hlir/framework/op_lowering_util.cc index e193b44970..ef5ea83f40 100644 --- a/cinn/hlir/framework/op_lowering_util.cc +++ b/cinn/hlir/framework/op_lowering_util.cc @@ -574,10 +574,25 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, const std::vector& inshape, const std::vector& axes, const common::Target& target) { + // If the number of current device SM is smaller than the number of SM + // required by Warp Reduce, the performance of Warp Reduce is better. + // Otherwise, use Block Reduce. + auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int need_reduce_last_count = 1; + for (int i = 0; i < inshape.size(); i++) { + if (find(axes.begin(), axes.end(), i) == axes.end()) { + need_reduce_last_count *= inshape[i]; + } + } + int warp_reduce_need_sm_count = ceil((need_reduce_last_count * 32) / float(target.get_max_threads_per_sm())); + // Set Num_max_threads to 32 is Warp Reduce + if (target.get_multi_processor_count() < warp_reduce_need_sm_count) { + max_num_threads = 32; + } // find first reduce and second reduce axis. 
- int lane = 1; - int index = static_cast(axes.size()) - 1; - auto max_num_threads = target.max_num_threads(); + int lane = 1; + int index = static_cast(axes.size()) - 1; + for (; index >= 0; --index) { if (index + 1 < axes.size() && axes[index] != axes[index + 1] - 1) { break; diff --git a/cinn/hlir/op/reduction_test.cc b/cinn/hlir/op/reduction_test.cc index b1986be20f..870dda7d5d 100644 --- a/cinn/hlir/op/reduction_test.cc +++ b/cinn/hlir/op/reduction_test.cc @@ -465,6 +465,53 @@ TEST(Operator, Operator_Reduction_Case_11) { GenReduceCode(shape, dim, "Operator_Reduction_Case_11"); } +TEST(Operator, Operator_Reduction_Case_Warp_Reduce) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {warp_reduce_threshold + 10, 256}; + std::vector dim = {1}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce"); + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Block_Reduce) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {warp_reduce_threshold - 10, 33}; + std::vector dim = {1}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce"); + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {(warp_reduce_threshold + 32) / 2, 2, 10, 256}; + std::vector dim = 
{2, 3}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Warp_Reduce_Case_1"); + CHECK(res.second.find("threadIdx.x < 32") != std::string::npos); +} + +TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) { + int sm_count = common::DefaultNVGPUTarget().get_multi_processor_count(); + int max_threads_per_sm = common::DefaultNVGPUTarget().get_max_threads_per_sm(); + int warp_reduce_threshold = sm_count * max_threads_per_sm / 32; + + std::vector shape = {(warp_reduce_threshold - 32) / 2, 2, 10, 33}; + std::vector dim = {2, 3}; + + auto res = GenReduceCode(shape, dim, "Operator_Reduction_Case_Block_Reduce_Case_2"); + CHECK(res.second.find("threadIdx.x < 32") == std::string::npos); +} } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/reduction.cc b/cinn/hlir/pe/reduction.cc index d1f58a0bab..0ea9abe849 100644 --- a/cinn/hlir/pe/reduction.cc +++ b/cinn/hlir/pe/reduction.cc @@ -684,10 +684,25 @@ std::vector TwoStepBlockReduceInternal(const ir::Tensor& A, BlockReduceFunc block_reduce_func, ir::Expr initial) { CHECK(!WithoutLastDimInReduce(A->shape, axes)) << "Can't find last axis in reduce!"; + // If the number of current device SM is smaller than the number of SM + // required by Warp Reduce, the performance of Warp Reduce is better. + // Otherwise, use Block Reduce. 
+ auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int need_reduce_last_count = 1; + for (int i = 0; i < A->shape.size(); i++) { + if (find(axes.begin(), axes.end(), i) == axes.end()) { + need_reduce_last_count *= A->shape[i].as_int32(); + } + } + int warp_reduce_need_sm_count = + ceil((need_reduce_last_count * 32) / float(common::DefaultNVGPUTarget().get_max_threads_per_sm())); + // Set Num_max_threads to 32 is Warp Reduce + if (common::DefaultNVGPUTarget().get_multi_processor_count() < warp_reduce_need_sm_count) { + max_num_threads = 32; + } - int lane = A->shape[axes.back()].as_int32(); - int index = static_cast(axes.size()) - 2; - auto max_num_threads = common::DefaultNVGPUTarget().max_num_threads(); + int lane = A->shape[axes.back()].as_int32(); + int index = static_cast(axes.size()) - 2; for (; index >= 0; --index) { if (lane >= max_num_threads / 2) { break;