From 90f64e0278f5f9434449db68968aff755f6c7398 Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Tue, 26 Nov 2024 09:24:40 +0330 Subject: [PATCH] Fix Or pattern behavior (#27721) ### Details: - Fixed Or pattern - Added new unit tests for Or and Optional patterns Or pattern have to point to the node from the selected branch. It also affects Optional pattern behavior as it uses Or pattern inside. ### Tickets: - CVS-157939 --- src/core/src/pattern/op/or.cpp | 2 +- src/core/tests/pattern.cpp | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/src/core/src/pattern/op/or.cpp b/src/core/src/pattern/op/or.cpp index e2c37a322f3a5c..f0aa96120cc2be 100644 --- a/src/core/src/pattern/op/or.cpp +++ b/src/core/src/pattern/op/or.cpp @@ -13,7 +13,7 @@ bool ov::pass::pattern::op::Or::match_value(Matcher* matcher, auto saved = matcher->start_match(); if (matcher->match_value(input_value, graph_value)) { auto& pattern_map = matcher->get_pattern_value_map(); - pattern_map[input_value.get_node_shared_ptr()] = graph_value; + pattern_map[shared_from_this()] = graph_value; return saved.finish(true); } } diff --git a/src/core/tests/pattern.cpp b/src/core/tests/pattern.cpp index 694e291a61326a..050c36b65baad1 100644 --- a/src/core/tests/pattern.cpp +++ b/src/core/tests/pattern.cpp @@ -523,6 +523,63 @@ TEST(pattern, optional_match_node_with_single_input) { } } +TEST(pattern, or_pattern_points_the_selected_branch) { + using namespace ov::op; + using namespace ov::pass::pattern; + + // Graph: + auto model_param = make_shared(); + auto model_sigmoid = make_shared(model_param); + + // Pattern: + auto option_1 = wrap_type(); + auto option_2 = wrap_type(); + auto or_pattern = std::make_shared(ov::OutputVector{option_1, option_2}); + + // Test: + TestMatcher matcher; + EXPECT_TRUE(matcher.match(or_pattern, model_sigmoid)); + + auto pattern_val_mp = matcher.get_pattern_value_map(); + EXPECT_EQ(pattern_val_mp.count(or_pattern), 1); + + // we expect that Or pattern points to the first node of the selected branch + EXPECT_NE(ov::as_type(pattern_val_mp.at(or_pattern).get_node()), nullptr); +} + +TEST(pattern, multiple_optionals_in_row) { + using namespace ov::op; + using namespace ov::pass::pattern; + + // Graph: + Shape shape{1, 2, 3}; + auto model_input_0 = make_shared(element::f32, shape); + auto model_sigmoid = make_shared(model_input_0); + + // Pattern: + auto in = wrap_type(); + auto pattern_convert = optional(in); + auto pattern_relu = optional(pattern_convert); + auto pattern_sigmoid = wrap_type({pattern_relu}); + + // Test: + TestMatcher matcher; + EXPECT_TRUE(matcher.match(pattern_sigmoid, model_sigmoid)); + + auto pattern_val_mp = matcher.get_pattern_value_map(); + + EXPECT_EQ(pattern_val_mp.count(in), 1); + EXPECT_NE(ov::as_type(pattern_val_mp.at(in).get_node()), nullptr); + + // as Convert and Relu ops are not present in the graph, so we expect the optional nodes + // do not point to the graph nodes, in other words, the optional nodes are not in the pattern map. + EXPECT_EQ(pattern_val_mp.count(pattern_convert), 0); + EXPECT_EQ(pattern_val_mp.count(pattern_relu), 0); + + EXPECT_EQ(pattern_val_mp.count(pattern_sigmoid), 1); + EXPECT_NE(ov::as_type(pattern_val_mp.at(pattern_sigmoid).get_node()), nullptr); +} + // match optional nodes with multi input where order in not important TEST(pattern, optional_match_cumulative_node_with_multi_input) { Shape shape{1, 2, 3};