Skip to content

Commit

Permalink
Check broadcast node in can_mark_node
Browse files Browse the repository at this point in the history
Signed-off-by: yuan.xiong <[email protected]>
  • Loading branch information
yuanxion committed Dec 13, 2024
1 parent ddc6af4 commit 1fd12ac
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@ void mark_shape_of_subgraphs::look_for_shape_of_subgraph(program_node& node) {
return;
}

// skip mark_node for broadcast node if dependency nodes are data and shape_of
auto& dependencies = node.get_dependencies();
if (node.is_type<broadcast>() && dependencies.size() == 2) {
if (dependencies[0].first->is_type<data>() && dependencies[1].first->is_type<shape_of>())
return;
}

// Check if all dependencies are constant or marked as a part of shape_of subgraph
bool can_execute_in_subgraph = true;
bool has_shape_of_subgraph_dep = false;
Expand Down Expand Up @@ -94,6 +87,13 @@ bool mark_shape_of_subgraphs::can_mark_node(const program_node& node) {
return false;
}

// skip mark_node for broadcast node if dependency nodes are data and shape_of
auto& dependencies = node.get_dependencies();
if (node.is_type<broadcast>() && dependencies.size() == 2) {
if (dependencies[0].first->is_type<data>() && dependencies[1].first->is_type<shape_of>())
return false;
}

return true;
}

Expand Down

0 comments on commit 1fd12ac

Please sign in to comment.