Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Mmahbubu/provenance fix grpah replace (#3624)
Browse files Browse the repository at this point in the history
* Fixes bug in provenenace for subgraph replacement

* Updates unit tests for the provenance algorithm fix

* Changes provnance set to ordered set for better consistency in iteration order.
  • Loading branch information
tachyon77 authored and diyessi committed Sep 18, 2019
1 parent 8295b84 commit b8b66de
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
14 changes: 10 additions & 4 deletions src/ngraph/graph_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,20 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
{
auto common_args = ngraph::find_common_args(target, replacement);

auto set_replacement_prov = [replacement](std::shared_ptr<Node> node) {
replacement->merge_provenance_tags_from(node);
std::set<string> removed_subgraph_tags;

auto set_replacement_prov = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
for (auto tag : node->get_provenance_tags())
{
removed_subgraph_tags.insert(tag);
}
};

traverse_nodes({target}, set_replacement_prov, false, common_args);
replacement->add_provenance_tags(removed_subgraph_tags);

auto set_prov_new_nodes = [replacement](std::shared_ptr<Node> node) {
node->merge_provenance_tags_from(replacement);
auto set_prov_new_nodes = [&removed_subgraph_tags](std::shared_ptr<Node> node) {
node->add_provenance_tags(removed_subgraph_tags);
};

traverse_nodes({replacement}, set_prov_new_nodes, false, common_args);
Expand Down
10 changes: 9 additions & 1 deletion src/ngraph/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ void Node::set_placement_index(size_t placement)
m_placement_index = placement;
}

const std::unordered_set<std::string>& Node::get_provenance_tags() const
const std::set<std::string>& Node::get_provenance_tags() const
{
return m_provenance_tags;
}
Expand All @@ -328,6 +328,14 @@ void Node::add_provenance_tag(const std::string& tag)
m_provenance_tags.insert(tag);
}

void Node::add_provenance_tags(const std::set<std::string>& tag_set)
{
for (auto tag : tag_set)
{
add_provenance_tag(tag);
}
}

void Node::remove_provenance_tag(const std::string& tag)
{
m_provenance_tags.erase(tag);
Expand Down
5 changes: 3 additions & 2 deletions src/ngraph/node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,9 @@ namespace ngraph
/// Set device placement
void set_placement_index(size_t placement);

const std::unordered_set<std::string>& get_provenance_tags() const;
const std::set<std::string>& get_provenance_tags() const;
void add_provenance_tag(const std::string& tag);
void add_provenance_tags(const std::set<std::string>& tag_set);
void remove_provenance_tag(const std::string& tag);

// to be used when nodes are replaced
Expand Down Expand Up @@ -426,7 +427,7 @@ namespace ngraph
std::string m_unique_name;
NGRAPH_API
static std::atomic<size_t> m_next_instance_id;
std::unordered_set<std::string> m_provenance_tags;
std::set<std::string> m_provenance_tags;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
std::unordered_map<Node*, autodiff::Adjoints> m_adjoint_map;
Expand Down
8 changes: 4 additions & 4 deletions test/provenance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using namespace std;
using namespace ngraph;
using ::testing::Return;

using ProvSet = std::unordered_set<std::string>;
using ProvSet = std::set<std::string>;

TEST(provenance, provenance)
{
Expand Down Expand Up @@ -231,7 +231,7 @@ TEST(provenance, provenance)
//
// A{tag_a} B{tag_b}
// | |
// E{tag_c, tag_d} |
// E{tag_c} |
// | |
// D{tag_c, tag_d}
//
Expand All @@ -258,7 +258,7 @@ TEST(provenance, provenance)
replace_node(c, d);

EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c"}));
}

//
Expand Down Expand Up @@ -310,7 +310,7 @@ TEST(provenance, provenance)
replace_node(c, d);

EXPECT_EQ(d->get_provenance_tags(), (ProvSet{"tag_c", "tag_d"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_d", "tag_e"}));
EXPECT_EQ(e->get_provenance_tags(), (ProvSet{"tag_c", "tag_e"}));
}

//
Expand Down

0 comments on commit b8b66de

Please sign in to comment.