Skip to content

Commit

Permalink
Stop constantfold_subgraph on nonconstfoldable ShapeOf (openvinotoolk…
Browse files Browse the repository at this point in the history
…it#21171)

* Stop constantfold_subgraph on nonconstfoldable ShapeOf

Ticket: CVS-124628

* pass tensor to Constant constructor

---------

Co-authored-by: Ivan Tikhonov <[email protected]>
  • Loading branch information
mateusztabaka and itikhono authored Nov 27, 2023
1 parent b02ddc5 commit b027766
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/core/src/validation_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,15 @@ std::shared_ptr<ov::op::v0::Constant> ov::util::constantfold_subgraph(const Outp
if (num_inputs == 0)
return nullptr;

if (subgraph_sink.get_tensor().has_and_set_bound()) {
const auto& lower = subgraph_sink.get_tensor().get_lower_value();
return std::make_shared<ov::op::v0::Constant>(lower);
}

if (ov::is_type<op::util::ShapeOfBase>(node) && node->get_input_partial_shape(0).is_dynamic()) {
return nullptr;
}

OutputVector inputs;
inputs.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; i++) {
Expand Down
16 changes: 16 additions & 0 deletions src/core/tests/validation_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,19 @@ TEST(constantfold_subgraph, split) {
auto actual = ret->cast_vector<float>();
ASSERT_EQ(expected, actual);
}

TEST(constantfold_subgraph, shapeof) {
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, 3, -1});
auto shapeof = std::make_shared<ov::op::v3::ShapeOf>(param);
auto zero = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {2});
auto stop = std::make_shared<ov::op::v8::Slice>(shapeof, one /*start*/, two /*stop*/, one /*step*/, zero /*axis*/);
auto slice = std::make_shared<ov::op::v8::Slice>(param, one /*start*/, stop, one /*step*/, one /*axis*/);

auto ret = ov::util::constantfold_subgraph(stop);
ASSERT_NE(ret, nullptr);
auto actual = ret->cast_vector<int64_t>();
std::vector<int64_t> expected{3};
ASSERT_EQ(expected, actual);
}

0 comments on commit b027766

Please sign in to comment.