diff --git a/include/graphit/midend/mir.h b/include/graphit/midend/mir.h index 13b886bf..b8533091 100644 --- a/include/graphit/midend/mir.h +++ b/include/graphit/midend/mir.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -54,6 +55,8 @@ namespace graphit { return to(cloneNode()); } + // We use a single map to hold all metadata on the MIR Node + std::unordered_map> metadata_map; protected: template std::shared_ptr self() { @@ -68,6 +71,32 @@ namespace graphit { // as I slowly add in support for copy functionalities return nullptr; }; + public: + // Functions to set and retrieve metadata of different types + template + void setMetadata(std::string mdname, T val) { + typename MIRMetadataImpl::Ptr mdnode = std::make_shared>(val); + metadata_map[mdname] = mdnode; + } + // This function is safe to be called even if the metadata with + // the specified name doesn't exist + template + bool hasMetadata(std::string mdname) { + if (metadata_map.find(mdname) == metadata_map.end()) + return false; + typename MIRMetadata::Ptr mdnode = metadata_map[mdname]; + if (!mdnode->isa()) + return false; + return true; + } + // This function should be called only after confirming that the + // metadata with the given name exists + template + T getMetadata(std::string mdname) { + assert(hasMetadata(mdname)); + typename MIRMetadata::Ptr mdnode = metadata_map[mdname]; + return mdnode->to()->val; + } }; struct Expr : public MIRNode { diff --git a/include/graphit/midend/mir_metadata.h b/include/graphit/midend/mir_metadata.h new file mode 100644 index 00000000..77f51ae2 --- /dev/null +++ b/include/graphit/midend/mir_metadata.h @@ -0,0 +1,46 @@ +#ifndef MIR_METADATA_H +#define MIR_METADATA_H + +#include +#include +namespace graphit { +namespace mir { + +template +class MIRMetadataImpl; + +// The abstract class for the mir metadata +// Different templated metadata types inherit from this type +class MIRMetadata: public std::enable_shared_from_this { +public: + typedef std::shared_ptr Ptr; + virtual ~MIRMetadata() = default; + + + template + bool isa (void) { + if(std::dynamic_pointer_cast>(shared_from_this())) + return true; + return false; + } + template + std::shared_ptr> to(void) { + std::shared_ptr> ret = std::dynamic_pointer_cast>(shared_from_this()); + assert(ret != nullptr); + return ret; + } +}; + +// Templated metadata class for each type +template +class MIRMetadataImpl: public MIRMetadata { +public: + typedef std::shared_ptr> Ptr; + T val; + MIRMetadataImpl(T _val): val(_val) { + } +}; + +} +} +#endif diff --git a/test/c++/midend_test.cpp b/test/c++/midend_test.cpp index 5f1fdb99..89e4105b 100644 --- a/test/c++/midend_test.cpp +++ b/test/c++/midend_test.cpp @@ -110,4 +110,76 @@ TEST_F(MidendTest, SimpleVertexSetDeclAllocWithMain) { "const vertices : vertexset{Vertex} = new vertexset{Vertex}(5);\n" "func main() print 4; end"); EXPECT_EQ (0, basicTest(is)); -} \ No newline at end of file +} + +// Test cases for the MIRMetadata API +TEST_F(MidendTest, SimpleMetadataTest) { + istringstream is("func main() print 4; end"); + EXPECT_EQ(0, basicTest(is)); + EXPECT_EQ(true, mir_context_->isFunction("main")); + + mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main"); + + main_func->setMetadata("basic_boolean_md", true); + main_func->setMetadata("basic_int_md", 42); + EXPECT_EQ(true, main_func->hasMetadata("basic_boolean_md")); + EXPECT_EQ(true, main_func->getMetadata("basic_boolean_md")); + + EXPECT_EQ(true, main_func->hasMetadata("basic_int_md")); + EXPECT_EQ(42, main_func->getMetadata("basic_int_md")); + +} +TEST_F(MidendTest, SimpleMetadataTestNoExist) { + istringstream is("func main() print 4; end"); + EXPECT_EQ(0, basicTest(is)); + EXPECT_EQ(true, mir_context_->isFunction("main")); + + mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main"); + + main_func->setMetadata("basic_int_md", 42); + EXPECT_EQ(false, main_func->hasMetadata("other_int_md")); + EXPECT_EQ(false, main_func->hasMetadata("basic_int_md")); +} + +TEST_F(MidendTest, SimpleMetadataTestString) { + istringstream is("func main() print 4; end"); + EXPECT_EQ(0, basicTest(is)); + EXPECT_EQ(true, mir_context_->isFunction("main")); + + mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main"); + + main_func->setMetadata("basic_str_md", "md value"); + EXPECT_EQ(true, main_func->hasMetadata("basic_str_md")); + EXPECT_EQ("md value", main_func->getMetadata("basic_str_md")); +} + +TEST_F(MidendTest, SimpleMetadataTestMIRNodeAsMD) { + istringstream is("const val:int = 42;\nfunc main() print val; end"); + EXPECT_EQ(0, basicTest(is)); + EXPECT_EQ(true, mir_context_->isFunction("main")); + EXPECT_EQ(1, mir_context_->getConstants().size()); + + mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main"); + mir::VarDecl::Ptr decl = mir_context_->getConstants()[0]; + + main_func->setMetadata("used_var_md", decl); + + EXPECT_EQ(true, main_func->hasMetadata("used_var_md")); + mir::MIRNode::Ptr mdnode = main_func->getMetadata("used_var_md"); + EXPECT_EQ(true, mir::isa(mdnode)); +} + +TEST_F(MidendTest, SimpleMetadataTestMIRNodeVectorAsMD) { + istringstream is("const val:int = 42;\nconst val2: int = 55;\nfunc main() print val + val2; end"); + EXPECT_EQ(0, basicTest(is)); + EXPECT_EQ(true, mir_context_->isFunction("main")); + EXPECT_EQ(2, mir_context_->getConstants().size()); + + mir::FuncDecl::Ptr main_func = mir_context_->getFunction("main"); + std::vector decls = mir_context_->getConstants(); + + main_func->setMetadata>("used_vars_md", decls); + + EXPECT_EQ(true, main_func->hasMetadata>("used_vars_md")); + EXPECT_EQ(2, main_func->getMetadata>("used_vars_md").size()); +}