diff --git a/include/elf_parser.h b/include/elf_parser.h new file mode 100644 index 0000000..9574228 --- /dev/null +++ b/include/elf_parser.h @@ -0,0 +1,76 @@ +#include +#include + +#include +#include +#include +#include + +#include "logger/logger.h" +#include "macro.h" + +namespace hook { + +class HOOK_API CachedSymbolTable { + public: + struct StringRefIterator { + StringRefIterator(const char *str); + + StringRefIterator &operator++() &; + StringRefIterator operator++(int) &; + + adt::StringRef operator*(); + + bool operator==(const StringRefIterator &other) const; + bool operator!=(const StringRefIterator &other) const; + const void *data() const; + + private: + const char *str_; + }; + + StringRefIterator strtab_begin(const char *str) const; + + StringRefIterator strtab_end(const char *str) const; + + CachedSymbolTable(const std::string &name, const void *base_address); + + void move_to_section_header(); + + void move_to_section_header(size_t index); + + adt::StringRef getSectionName(size_t index) const; + size_t find_section(adt::StringRef name) const; + + void load_symbol_table(); + void parse_section(); + void parse_named_section(); + + const std::string &lookUpSymbol(const void *func) const; + + const std::unordered_map &getSymbolTable() const { + return symbol_table; + } + + size_t min_addrtess() const { return min_address_; } + size_t max_addrtess() const { return max_address_; } + + private: + std::string libName; + std::ifstream ifs; + ElfW(Ehdr) elf_header; + std::vector section_header_str; + std::vector sections; + std::vector strtab; + std::unordered_map symbol_table; + const void *base_address; + size_t min_address_ = -1; + size_t max_address_ = 0; +}; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address); + +CachedSymbolTable *getSymbolTable(const std::string &lib); + +} // namespace hook diff --git a/include/env_util.h b/include/env_util.h index e7f6ab5..92b386f 100644 --- a/include/env_util.h +++ b/include/env_util.h @@ -18,6 +18,8 @@ namespace hook { // if we use the function template we need implement all of the str2value_impl // functions to support more type, but sometimes we need implement the // str2value_impl near the type define + +// TODO: replace C style const char* and length with StringRef template struct str2value_impl { void operator()(T& value, const char* str, size_t len = std::string::npos, @@ -57,7 +59,7 @@ struct str2value_impl> { for (; i < len && str[i] != '\0'; ++i) { if (str[i] == '=') { pair.first = str2value()(str, i); - pair.second = str2value()(str + i + 1); + pair.second = str2value()(str + i + 1, len - i - 1); break; } } diff --git a/include/hook.h b/include/hook.h index 08d7d51..3b61229 100644 --- a/include/hook.h +++ b/include/hook.h @@ -29,6 +29,7 @@ void install_hook(); struct OriginalInfo { const char* libName = nullptr; const void* basePtr = nullptr; + const void* baseHeadPtr = nullptr; void* relaPtr = nullptr; void* oldFuncPtr = nullptr; void** pltTablePtr = nullptr; @@ -40,6 +41,7 @@ struct OriginalInfo { relaPtr = info.relaPtr; oldFuncPtr = info.oldFuncPtr; pltTablePtr = info.pltTablePtr; + baseHeadPtr = info.baseHeadPtr; return *this; } }; @@ -276,6 +278,7 @@ struct HookFeatureBase { void* newFunc; void** oldFunc; std::function filter_; + std::function getNewCallback_; }; using WrapFuncGenerator = std::function; @@ -340,6 +343,12 @@ struct __HookFeature : public HookFeatureBase { return newFuncGenerator(libName, symName.c_str(), newFunc); } + __HookFeature& setGetNewCallback( + const std::function& getNewCallback) { + getNewCallback_ = getNewCallback; + return *this; + } + std::function findUniqueFunc; WrapFuncGenerator newFuncGenerator; }; @@ -372,6 +381,9 @@ struct MemberDetector(*iter) is a pointer and it's point to // std::get<1>(*iter) then there will return nullptr *iter->oldFunc = info.oldFuncPtr; + if (iter->getNewCallback_) { + iter->getNewCallback_(info); + } return iter->getNewFunc(info.libName); } }; diff --git a/include/logger/StringRef.h b/include/logger/StringRef.h index 5d1143c..aa23848 100644 --- a/include/logger/StringRef.h +++ b/include/logger/StringRef.h @@ -24,7 +24,8 @@ class StringRef { StringRef(const char* str) : size_(strlen(str)), str_(str) {} - StringRef(const char* str, size_t size) : size_(size), str_(str) {} + StringRef(const char* str, size_t size) + : size_(size != std::string::npos ? size : strlen(str_)), str_(str) {} StringRef(const_iterator begin, const_iterator end) : size_(std::distance(begin, end)), str_(begin) {} @@ -46,7 +47,9 @@ class StringRef { size_t size() const { return size_; } - std::string str() const { return std::string(str_, str_ + size_); } + std::string str() const { + return size_ != 0 ? std::string(str_, str_ + size_) : std::string(); + } const char* data() const { return str_; } diff --git a/include/logger/logger.h b/include/logger/logger.h index 578bb2f..63a8287 100644 --- a/include/logger/logger.h +++ b/include/logger/logger.h @@ -40,10 +40,13 @@ enum class LogModule { profile, trace, hook, python, memory, debug, last }; class LogModuleHelper { public: static auto& enum_strs() { - static std::array strs = {"PROFILE", "TRACE", "HOOK", - "PYTHON", "MEMORY", "LAST"}; + static std::array strs = { + "PROFILE", "TRACE", "HOOK", "PYTHON", "MEMORY", "DEBUG", "LAST"}; + static_assert(sizeof(strs) / sizeof(const char*) == + static_cast(LogModule::last) + 1); return strs; } + static auto begin() { return enum_strs().begin(); } static auto end() { return enum_strs().end(); } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e32fba7..a0bce24 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -18,6 +18,7 @@ add_library(cuda_mock STATIC env_mgr.cpp backtrace.cpp statistic.cpp + elf_parser.cpp hook_context.cpp hook.cpp cuda_op_tracer.cpp diff --git a/lib/elf_parser.cpp b/lib/elf_parser.cpp new file mode 100644 index 0000000..50b09df --- /dev/null +++ b/lib/elf_parser.cpp @@ -0,0 +1,182 @@ +#include "elf_parser.h" +namespace hook { + +CachedSymbolTable::StringRefIterator::StringRefIterator(const char *str) + : str_(str) {} + +CachedSymbolTable::StringRefIterator & +CachedSymbolTable::StringRefIterator::operator++() & { + size_t len = strlen(str_); + ++len; + str_ += len; + return *this; +} + +CachedSymbolTable::StringRefIterator +CachedSymbolTable::StringRefIterator::operator++(int) & { + auto ret = StringRefIterator(str_); + ++*this; + return ret; +} + +adt::StringRef CachedSymbolTable::StringRefIterator::operator*() { + return adt::StringRef(str_); +} + +bool CachedSymbolTable::StringRefIterator::operator==( + const StringRefIterator &other) const { + return str_ == other.str_; +} +bool CachedSymbolTable::StringRefIterator::operator!=( + const StringRefIterator &other) const { + return !(*this == other); +} + +const void *CachedSymbolTable::StringRefIterator::data() const { return str_; } + +CachedSymbolTable::CachedSymbolTable(const std::string &name, + const void *base_address) + : libName(name), ifs(name), base_address(base_address) { + MLOG(DEBUG, INFO) << name << " base address:" << base_address; + ifs.read(reinterpret_cast(&elf_header), sizeof(elf_header)); + parse_named_section(); + parse_section(); + load_symbol_table(); + for (size_t i = 0; i < sections.size(); ++i) { + MLOG(DEBUG, INFO) << "found section:" << i + << " name:" << getSectionName(i); + } +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_begin( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_end( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +void CachedSymbolTable::move_to_section_header() { + ifs.seekg(elf_header.e_shoff, std::ios::beg); +} + +void CachedSymbolTable::move_to_section_header(size_t index) { + MLOG(DEBUG, INFO) << "move to:" << index; + move_to_section_header(); + size_t shstr_h_oft = sizeof(ElfW(Shdr)) * index; + ifs.seekg(shstr_h_oft, std::ios::cur); +} + +adt::StringRef CachedSymbolTable::getSectionName(size_t index) const { + return adt::StringRef(§ion_header_str.at(sections.at(index).sh_name)); +} + +size_t CachedSymbolTable::find_section(adt::StringRef name) const { + size_t index = 0; + for (; index < sections.size(); index++) { + if (getSectionName(index) == name) { + break; + } + } + return index; +} + +void CachedSymbolTable::load_symbol_table() { + size_t symtab_h_index = find_section(".symtab"); + if (symtab_h_index >= sections.size()) { + LOG(WARN) << "can't found secton: .symtab"; + return; + } + ifs.seekg(sections[symtab_h_index].sh_offset, std::ios::beg); + std::vector symbol_tb(sections[symtab_h_index].sh_size / + sizeof(symbol_tb[0])); + ifs.read(reinterpret_cast(symbol_tb.data()), + sections[symtab_h_index].sh_size); + + size_t strtab_h_index = find_section(".strtab"); + if (strtab_h_index >= sections.size()) { + LOG(WARN) << "can't found secton: .strtab"; + return; + } + ifs.seekg(sections[strtab_h_index].sh_offset, std::ios::beg); + std::vector buf(sections[strtab_h_index].sh_size); + ifs.read(buf.data(), buf.size()); + + auto strtab_begin_iter = strtab_begin(buf.data()); + auto strtab_end_iter = strtab_end(buf.data() + buf.size()); + for (StringRefIterator iter = strtab_begin_iter; iter != strtab_end_iter; + ++iter) { + strtab.push_back((*iter).str()); + } + for (size_t i = 0; i < symbol_tb.size() / sizeof(ElfW(Sym)); ++i) { + if (strtab.size() <= symbol_tb[i].st_name) { + LOG(WARN) << "symbol_tb[" << i << "].st_name(" + << symbol_tb[i].st_name + << ") over strtab size:" << strtab.size(); + } + symbol_table.emplace(symbol_tb[i].st_value, + strtab[symbol_tb[i].st_name]); + if (symbol_tb[i].st_value > max_address_) { + max_address_ = symbol_tb[i].st_value; + } + if (symbol_tb[i].st_value < min_address_) { + min_address_ = symbol_tb[i].st_value; + } + } + MLOG(DEBUG, INFO) << libName << "\naddress range:" << min_address_ << "~" + << max_address_; + std::vector().swap(strtab); +} +void CachedSymbolTable::parse_section() { + CHECK_EQ(sizeof(ElfW(Shdr)), elf_header.e_shentsize); + move_to_section_header(); + sections.resize(elf_header.e_shnum); + MLOG(DEBUG, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum; + ifs.read(reinterpret_cast(sections.data()), + sections.size() * sizeof(sections[0])); +} +void CachedSymbolTable::parse_named_section() { + move_to_section_header(elf_header.e_shstrndx); + + ElfW(Shdr) shstr_h; + ifs.read(reinterpret_cast(&shstr_h), sizeof(shstr_h)); + ifs.seekg(shstr_h.sh_offset, std::ios::beg); + + section_header_str.resize(shstr_h.sh_size); + ifs.read(section_header_str.data(), section_header_str.size()); +} + +const std::string &CachedSymbolTable::lookUpSymbol(const void *func) const { + auto offset = reinterpret_cast(func) - + reinterpret_cast(base_address); + MLOG(DEBUG, INFO) << "lookup address:" << offset; + auto iter = symbol_table.find(offset); + if (iter == symbol_table.end()) { + LOG(WARN) << libName + << "\nnot find launch_async symbol offset:" << offset + << " base address:" << base_address + << " func address:" << func << " range(" << min_address_ + << "~" << max_address_; + } else { + LOG(WARN) << "found launch_async symbol:" << iter->second; + } + static std::string empty(""); + return empty; +} + +static std::unordered_map> + table; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address) { + auto iter = table.emplace(lib, new CachedSymbolTable(lib, address)); + return iter.first->second.get(); +} + +CachedSymbolTable *getSymbolTable(const std::string &lib) { + return table[lib].get(); +} + +} // namespace hook \ No newline at end of file diff --git a/lib/hook.cpp b/lib/hook.cpp index 6707338..bfbb680 100644 --- a/lib/hook.cpp +++ b/lib/hook.cpp @@ -282,6 +282,7 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { originalInfo.libName = pltTable->lib_name.c_str(); originalInfo.basePtr = pltTable->base_addr; originalInfo.relaPtr = pltTable->rela_plt; + originalInfo.baseHeadPtr = pltTable->base_header_addr; originalInfo.pltTablePtr = reinterpret_cast(addr); originalInfo.oldFuncPtr = reinterpret_cast(*reinterpret_cast(addr)); diff --git a/lib/logger.cpp b/lib/logger.cpp index bd01a96..5217e00 100644 --- a/lib/logger.cpp +++ b/lib/logger.cpp @@ -37,12 +37,14 @@ struct str2value_impl { void operator()(logger::LogLevel& lvl, const char* str, size_t len = std::string::npos) { auto iter = std::find(std::begin(gLoggerLevelStringSet()), - std::end(gLoggerLevelStringSet()), str); + std::end(gLoggerLevelStringSet()), + adt::StringRef(str, len)); if (iter != std::end(gLoggerLevelStringSet())) { lvl = static_cast( std::distance(std::begin(gLoggerLevelStringSet()), iter)); } else { - lvl = static_cast(::atoi(str)); + // default warning + lvl = logger::LogLevel::warning; } } }; @@ -360,6 +362,7 @@ class LogConsumer : public std::enable_shared_from_this { } void sync_pause_loop() { + if (exit_) return; exit_.store(true); if (cfg_->mode == LogConfig::kAsync) { if (th_ && th_->joinable()) th_->join(); diff --git a/lib/xpu_mock.cpp b/lib/xpu_mock.cpp index 9fd5d38..09fec27 100644 --- a/lib/xpu_mock.cpp +++ b/lib/xpu_mock.cpp @@ -13,6 +13,7 @@ #include #include "backtrace.h" +#include "elf_parser.h" #include "hook.h" #include "hooks/print_hook.h" #include "logger/StringRef.h" @@ -95,6 +96,8 @@ DEF_FUNCTION_INT(xpu_set_device, int devid) { DEF_FUNCTION_INT(xpu_launch_async, void* func) { // TODO: get symbol name from symbol table + auto libName = hook::HookRuntimeContext::instance().curLibName(); + hook::getSymbolTable(libName)->lookUpSymbol(func); return origin_xpu_launch_async(func); } @@ -163,8 +166,11 @@ DEF_FUNCTION_INT(cudaMemcpy, void* dst, const void* src, size_t count, return origin_cudaMemcpy(dst, src, count, kind); } -#define BUILD_FEATURE(name) \ - hook::FHookFeature(STR_TO_TYPE(#name), &name, &origin_##name) +/// need gcc version > 9.0 +/// #define BUILD_FEATURE(name) hook::FHookFeature(STR_TO_TYPE(#name), &name, +/// &origin_##name) + +#define BUILD_FEATURE(name) hook::HookFeature(#name, &name, &origin_##name) class XpuRuntimeApiHook : public hook::HookInstallerWrap { public: @@ -173,19 +179,30 @@ class XpuRuntimeApiHook : public hook::HookInstallerWrap { !adt::StringRef(name).contain("libcudart.so"); } - hook::FHookFeature symbols[14] = { - BUILD_FEATURE(xpu_malloc), BUILD_FEATURE(xpu_free), - BUILD_FEATURE(xpu_current_device), BUILD_FEATURE(xpu_set_device), - BUILD_FEATURE(xpu_wait), BUILD_FEATURE(xpu_memcpy), - BUILD_FEATURE(xpu_launch_async), BUILD_FEATURE(xpu_stream_create), + // need gcc version > 9.0 + // hook::FHookFeature symbols[14] = { + hook::HookFeature symbols[14] = { + BUILD_FEATURE(xpu_malloc), + BUILD_FEATURE(xpu_free), + BUILD_FEATURE(xpu_current_device), + BUILD_FEATURE(xpu_set_device), + BUILD_FEATURE(xpu_wait), + BUILD_FEATURE(xpu_memcpy), + BUILD_FEATURE(xpu_launch_async) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + }), + BUILD_FEATURE(xpu_stream_create), BUILD_FEATURE(xpu_stream_destroy), - BUILD_FEATURE(cudaMalloc), BUILD_FEATURE(cudaFree), - BUILD_FEATURE(cudaMemcpy), BUILD_FEATURE(cudaSetDevice), + BUILD_FEATURE(cudaMalloc), + BUILD_FEATURE(cudaFree), + BUILD_FEATURE(cudaMemcpy), + BUILD_FEATURE(cudaSetDevice), BUILD_FEATURE(cudaGetDevice), }; - void onSuccess() { LOG(WARN) << "install " << curSymName() << " success"; } + void onSuccess() { LOG(INFO) << "install " << curSymName() << " success"; } }; struct PatchRuntimeHook : public hook::HookInstallerWrap { diff --git a/test/cpp_test/mock_api_test.cpp b/test/cpp_test/mock_api_test.cpp index df17386..6940b8e 100644 --- a/test/cpp_test/mock_api_test.cpp +++ b/test/cpp_test/mock_api_test.cpp @@ -7,6 +7,7 @@ #include "GlobalVarMgr.h" #include "cuda_mock.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h" @@ -17,4 +18,12 @@ TEST(MockAnyHook, base) { dh_any_hook_install(); int ret = mock::foo(nullptr); EXPECT_EQ(ret, 0); -} \ No newline at end of file +} + +// TEST(MockAnyHook, symbol) { +// hook::CachedSymbolTable ctb("", nullptr); +// ctb.load_symbol_table(); +// for(auto& it : ctb.getSymbolTable()) { +// LOG(WARN) << it.second; +// } +// } diff --git a/test/cpp_test/test_hook.cpp b/test/cpp_test/test_hook.cpp index f4ab0a1..6892c6d 100644 --- a/test/cpp_test/test_hook.cpp +++ b/test/cpp_test/test_hook.cpp @@ -4,6 +4,7 @@ #include #include "GlobalVarMgr.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h" @@ -152,3 +153,27 @@ TEST(TestHookWrap, feature) { hook::HookRuntimeContext::instance().getCallCount("test1", "memcpy"), 2); } + +namespace { +int foo(void*) { return 0; } + +void* org_foo = nullptr; + +class SymbolParserHook : public hook::HookInstallerWrap { + public: + bool targetLib(const char* name) { return adt::StringRef(name) == ""; } + hook::HookFeature symbols[1] = { + HookFeature("_ZN4mock3fooEPv", &foo, &org_foo) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + })}; + + void onSuccess() {} +}; + +} // namespace + +TEST(TestHookWrap, symbol_parser) { + auto sh = std::make_shared(); + sh->install(); +}