Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Validate pass handling in pass::Manager #26705

Closed
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,6 @@ bool pass::SimplifyShapeOfSubGraph::run_on_model(const std::shared_ptr<Model>& f

REGISTER_PASS(manager, PrepareShapeOpsForEliminationAroundBE)
REGISTER_PASS(manager, AbsSinking)
// FIXME: manager runs Validate based on the last pass, when fixed the following line must be deleted
REGISTER_PASS(manager, Validate)
REGISTER_PASS(manager, SharedOpOptimization)
REGISTER_PASS(manager, EliminateGatherUnsqueeze) // should run after SharedOpOptimization
REGISTER_PASS(manager, NopElimination, m_use_shapes)
Expand Down
2 changes: 1 addition & 1 deletion src/core/include/openvino/pass/manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class OPENVINO_API Manager {
std::string m_name = "UnnamedManager";

private:
bool run_pass(const std::shared_ptr<PassBase>& pass, const std::shared_ptr<Model>& model, bool needs_validate);
bool run_pass(const std::shared_ptr<PassBase>& pass, const std::shared_ptr<Model>& model, bool& needs_validate);
};
} // namespace pass
} // namespace ov
14 changes: 10 additions & 4 deletions src/core/src/pass/manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,18 @@ bool ov::pass::Manager::run_passes(const shared_ptr<ov::Model>& model) {

bool model_changed = false;
bool pass_changed_model = false;
bool needs_validate = false;

profiler.start_timer(m_name);
for (const auto& pass : m_pass_list) {
const auto& pass_name = pass->get_name();

profiler.start_timer(pass_name);
pass_changed_model = run_pass(pass, model, pass_changed_model);
pass_changed_model = run_pass(pass, model, needs_validate);
profiler.stop_timer(pass_name, pass_changed_model);

model_changed = model_changed || pass_changed_model;
needs_validate = needs_validate || pass_changed_model;

profiler.visualize(model, pass_name);
profiler.serialize(model, pass_name);
Expand All @@ -357,7 +359,7 @@ bool ov::pass::Manager::run_passes(const shared_ptr<ov::Model>& model) {

bool ov::pass::Manager::run_pass(const std::shared_ptr<PassBase>& pass,
const std::shared_ptr<Model>& model,
bool needs_validate) {
bool& needs_validate) {
if (m_pass_config->is_disabled(pass->get_type_info())) {
OPENVINO_DEBUG("Pass ", pass->get_name(), " is disabled.");
return false;
Expand All @@ -379,9 +381,13 @@ bool ov::pass::Manager::run_pass(const std::shared_ptr<PassBase>& pass,
// GraphRewrite is a temporary container for MatcherPass to make execution on entire ov::Model
return GraphRewrite(matcher_pass).run_on_model(model);
} else if (auto model_pass = dynamic_pointer_cast<ModelPass>(pass)) {
if (dynamic_pointer_cast<ov::pass::Validate>(model_pass) && !needs_validate) {
return false;
if (dynamic_pointer_cast<ov::pass::Validate>(model_pass)) {
if (!needs_validate) {
return false;
}
needs_validate = false;
}

return model_pass->run_on_model(model);
}
return false;
Expand Down
175 changes: 175 additions & 0 deletions src/core/tests/pass_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#include "openvino/op/matmul.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pass.hpp"
#include "openvino/pass/validate.hpp"

using namespace ov;
using namespace std;
Expand Down Expand Up @@ -77,6 +79,100 @@ bool validate_list(const std::vector<std::shared_ptr<ov::Node>>& nodes) {

} // namespace

class TestMatcherPassTrue : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TestMatcherPassTrue");
TestMatcherPassTrue() : MatcherPass() {
auto any_input = ov::pass::pattern::any_input();
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(any_input, "TestMatcherPassTrue");
this->register_matcher(m, callback);
}
};

class TestMatcherPassFalse : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TestMatcherPassFalse");
TestMatcherPassFalse() : MatcherPass() {
auto any_input = ov::pass::pattern::any_input();
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
return false;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(any_input, "TestMatcherPassFalse");
this->register_matcher(m, callback);
}
};

class TestModelPassTrue : public pass::ModelPass {
public:
OPENVINO_RTTI("TestModelPassTrue");

bool run_on_model(const std::shared_ptr<ov::Model>& f) override {
return true;
}
};

class TestModelPassFalse : public pass::ModelPass {
public:
OPENVINO_RTTI("TestModelPassFalse");

bool run_on_model(const std::shared_ptr<ov::Model>& f) override {
return false;
}
};

class TestValidate : public pass::Validate {
public:
OPENVINO_RTTI("TestValidate");

bool run_on_model(const std::shared_ptr<ov::Model>& f) override {
m_applied = true;
return pass::Validate::run_on_model(f);
}

bool is_applied() const {
return m_applied;
}

private:
bool m_applied = false;
};

class TestValidate2 : public TestValidate {};

class TestManager : public pass::Manager {
public:
bool is_validation_applied() {
bool applied = false;
bool is_init = true;
for (const auto& pass : m_pass_list) {
auto validate_2 = std::dynamic_pointer_cast<TestValidate2>(pass);
auto validate = std::dynamic_pointer_cast<TestValidate>(pass);
if (validate && !validate_2) {
if (is_init) {
is_init = false;
applied = validate->is_applied();
}
applied = applied && validate->is_applied();
}
}
return applied;
}

bool is_2nd_validation_applied() {
for (const auto& pass : m_pass_list) {
if (auto validate = std::dynamic_pointer_cast<TestValidate2>(pass)) {
return validate->is_applied();
}
}
return false;
}
};

TEST(pass_manager, add) {
pass::Manager pass_manager;

Expand All @@ -90,3 +186,82 @@ TEST(pass_manager, add) {
EXPECT_EQ(node_count, sorted.size());
EXPECT_TRUE(validate_list(sorted));
}

TEST(pass_manager, passes_not_applied) {
TestManager pass_manager;
pass_manager.set_per_pass_validation(false);

auto graph = make_test_graph();

pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate>();
const auto res = pass_manager.run_passes(graph);

EXPECT_FALSE(res);
EXPECT_FALSE(pass_manager.is_validation_applied());
}

TEST(pass_manager, model_pass_applied) {
TestManager pass_manager;
pass_manager.set_per_pass_validation(false);

auto graph = make_test_graph();

pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassTrue>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate>();
const auto res = pass_manager.run_passes(graph);

EXPECT_TRUE(res);
EXPECT_TRUE(pass_manager.is_validation_applied());
}

TEST(pass_manager, matcher_pass_applied) {
TestManager pass_manager;
pass_manager.set_per_pass_validation(false);

auto graph = make_test_graph();

pass_manager.register_pass<TestMatcherPassTrue>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate>();
const auto res = pass_manager.run_passes(graph);

EXPECT_TRUE(res);
EXPECT_TRUE(pass_manager.is_validation_applied());
}

TEST(pass_manager, two_validations) {
TestManager pass_manager;
pass_manager.set_per_pass_validation(false);

auto graph = make_test_graph();

pass_manager.register_pass<TestMatcherPassTrue>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate2>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassTrue>();
pass_manager.register_pass<TestMatcherPassFalse>();
pass_manager.register_pass<TestModelPassFalse>();
pass_manager.register_pass<TestValidate>();
const auto res = pass_manager.run_passes(graph);

EXPECT_TRUE(res);
EXPECT_TRUE(pass_manager.is_validation_applied());
EXPECT_FALSE(pass_manager.is_2nd_validation_applied());
}
Loading