From 1f8021fd68ab4ea2831951be65106d3d895643e7 Mon Sep 17 00:00:00 2001 From: "long.chen" Date: Wed, 17 Jul 2024 05:42:49 -0400 Subject: [PATCH] add elf parser (#192) --- .github/workflows/clang-format-check.yml | 2 +- include/elf_parser.h | 156 ++++++++++++++++++ include/env_util.h | 63 +++---- include/hook.h | 54 ++++-- include/hook_context.h | 10 ++ include/logger/StringRef.h | 16 +- include/logger/logger.h | 7 +- include/util.h | 7 + lib/CMakeLists.txt | 2 + lib/elf_parser.cpp | 200 +++++++++++++++++++++++ lib/env_util.cpp | 8 +- lib/hook.cpp | 26 ++- lib/logger.cpp | 83 ++++++---- lib/statistic.cpp | 24 +-- lib/util.cpp | 21 +++ lib/xpu_mock.cpp | 42 +++-- test/cpp_test/mock_api_test.cpp | 3 +- test/cpp_test/test_elf_parser.cpp | 48 ++++++ test/cpp_test/test_hook.cpp | 1 + 19 files changed, 645 insertions(+), 128 deletions(-) create mode 100644 include/elf_parser.h create mode 100644 include/util.h create mode 100644 lib/elf_parser.cpp create mode 100644 lib/util.cpp create mode 100644 test/cpp_test/test_elf_parser.cpp diff --git a/.github/workflows/clang-format-check.yml b/.github/workflows/clang-format-check.yml index 81b74ca..52be082 100644 --- a/.github/workflows/clang-format-check.yml +++ b/.github/workflows/clang-format-check.yml @@ -41,7 +41,7 @@ jobs: - name: Install clang-format uses: aminya/setup-cpp@v1 with: - clangformat: 18.1.7 + clangformat: 13.0.1 - name: Run clang-format env: diff --git a/include/elf_parser.h b/include/elf_parser.h new file mode 100644 index 0000000..16c6640 --- /dev/null +++ b/include/elf_parser.h @@ -0,0 +1,156 @@ +#include +#include + +#include +#include +#include +#include + +#include "logger/logger.h" +#include "macro.h" + +namespace hook { + +class HOOK_API CachedSymbolTable { + public: + struct StringRefIterator { + StringRefIterator(const char *str); + + StringRefIterator &operator++() &; + StringRefIterator operator++(int) &; + + adt::StringRef operator*(); + + bool operator==(const StringRefIterator &other) const; + bool operator!=(const StringRefIterator &other) const; + const void *data() const; + + private: + const char *str_; + }; + + template + struct OwnerBuf { + OwnerBuf() = default; + OwnerBuf(const OwnerBuf &other) = delete; + OwnerBuf &operator=(const OwnerBuf &other) = delete; + + OwnerBuf(OwnerBuf &&other) { *this = std::move(other); } + OwnerBuf &operator=(OwnerBuf &&other) { + std::swap(buf_, other.buf_); + std::swap(size_, other.size_); + return *this; + } + + static OwnerBuf alloc(size_t size) { + OwnerBuf buf; + buf.buf_ = reinterpret_cast(malloc(size * sizeof(T))); + buf.size_ = size; + return buf; + } + + T &operator[](size_t index) { + assert(index < size_); + assert(buf_); + return buf_[index]; + } + + T *data() const { return buf_; } + + size_t size() const { return size_; } + + std::tuple release() { + auto result = std::make_tuple(buf_, size_); + buf_ = 0; + size_ = 0; + return result; + } + + template + friend class OwnerBuf; + + template + OwnerBuf cast() { + OwnerBuf result; + auto [buf, size] = release(); + result.buf_ = reinterpret_cast(buf); + result.size_ = size * sizeof(T) / sizeof(N); + return result; + } + + ~OwnerBuf() { + if (buf_) { + free(buf_); +#ifndef NDEBUG + buf_ = nullptr; + size_ = 0; +#endif + } + } + + private: + T *buf_ = nullptr; + size_t size_ = 0; + }; + + StringRefIterator strtab_begin(const char *str) const; + + StringRefIterator strtab_end(const char *str) const; + + std::tuple strtab_range( + const char *str, size_t size) const { + return std::make_tuple(strtab_begin(str), strtab_begin(str + size)); + } + + CachedSymbolTable(const std::string &name, const void *base_address, + const std::vector §ion_names = {}); + + void move_to_section_header(); + + void move_to_section_header(size_t index); + + adt::StringRef getSectionName(size_t index) const; + size_t find_section(adt::StringRef name) const; + + void load_symbol_table(); + void parse_section_header(); + void parse_named_section(); + + template + OwnerBuf load_section_data(adt::StringRef name) { + auto buf = load_section_data(name); + return buf.cast(); + } + + OwnerBuf load_section_data(adt::StringRef name); + + const std::string &lookUpSymbol(const void *func) const; + + const std::unordered_map &getSymbolTable() const { + return symbol_table; + } + + size_t min_addrtess() const { return min_address_; } + size_t max_addrtess() const { return max_address_; } + + const std::vector §ions() const { return sections_; } + + private: + std::string libName; + std::ifstream ifs; + ElfW(Ehdr) elf_header; + std::vector section_header_str; + std::vector sections_; + std::unordered_map symbol_table; + const void *base_address; + size_t min_address_ = -1; + size_t max_address_ = 0; + std::vector section_names; +}; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address); + +CachedSymbolTable *getSymbolTable(const std::string &lib); + +} // namespace hook diff --git a/include/env_util.h b/include/env_util.h index e7f6ab5..ce0532b 100644 --- a/include/env_util.h +++ b/include/env_util.h @@ -18,17 +18,14 @@ namespace hook { // if we use the function template we need implement all of the str2value_impl // functions to support more type, but sometimes we need implement the // str2value_impl near the type define + template struct str2value_impl { - void operator()(T& value, const char* str, size_t len = std::string::npos, + void operator()(T& value, adt::StringRef str, std::enable_if_t::value || std::is_integral::value>* = nullptr) { std::stringstream ss; - if (len != std::string::npos) { - ss << std::string(str, str + len); - } else { - ss << str; - } + ss << str; ss >> value; } }; @@ -36,57 +33,51 @@ struct str2value_impl { // TODO: return bool value to check parse result template <> struct str2value_impl { - void operator()(int& value, const char* cstr, - size_t len = std::string::npos); + void operator()(int& value, adt::StringRef str); }; template struct str2value { - T operator()(const char* str, size_t len = std::string::npos) { + T operator()(adt::StringRef str) { T ret; - str2value_impl()(ret, str, len); + str2value_impl()(ret, str); return ret; } }; template struct str2value_impl> { - void operator()(std::pair& pair, const char* str, - size_t len = std::string::npos) { + void operator()(std::pair& pair, adt::StringRef str) { size_t i = 0; - for (; i < len && str[i] != '\0'; ++i) { + for (; i < str.size() && str[i] != '\0'; ++i) { if (str[i] == '=') { - pair.first = str2value()(str, i); - pair.second = str2value()(str + i + 1); + pair.first = str2value()(str.slice(0, i)); + pair.second = str2value()(str.slice(i + 1)); break; } } - if (i == '\0' || i == len) { - pair.first = str2value()(str, i); - } } }; template struct str2value_impl> { - void operator()(std::vector& vec, const char* str, - size_t len = std::string::npos) { + void operator()(std::vector& vec, adt::StringRef str) { size_t i = 0, j = 0; - for (; j < len && str[j] != '\0'; ++j) { + for (; j < str.size() && str[j] != '\0'; ++j) { if (str[j] == ',') { - vec.push_back(str2value()(str + i, j)); + vec.push_back(str2value()(str.slice(i, j))); i = j + 1; } } - vec.push_back(str2value()(str + i, j - i)); + vec.push_back(str2value()(str.slice(i, j))); } }; template -T get_env_value(const char* str, +T get_env_value(adt::StringRef str, std::__void_t() >> std::declval())>* = nullptr) { - auto env_value_str = std::getenv(str); + auto env_value_str = std::getenv(str.data()); if (!env_value_str) { return {}; } @@ -96,8 +87,8 @@ T get_env_value(const char* str, template typename std::enable_if>::value, T>::type -get_env_value(const char* str) { - auto env_value_str = std::getenv(str); +get_env_value(adt::StringRef str) { + auto env_value_str = std::getenv(str.data()); if (!env_value_str) { return {}; } @@ -106,8 +97,8 @@ get_env_value(const char* str) { template typename std::enable_if>::value, T>::type -get_env_value(const char* str) { - auto env_value_str = std::getenv(str); +get_env_value(adt::StringRef str) { + auto env_value_str = std::getenv(str.data()); if (!env_value_str) { return {}; } @@ -116,8 +107,18 @@ get_env_value(const char* str) { template inline typename std::enable_if::value, T>::type -get_env_value(const char* str) { - return std::getenv(str); +get_env_value(adt::StringRef str) { + return std::getenv(str.data()); +} + +template +inline typename std::enable_if::value, T>::type +get_env_value(adt::StringRef str) { + auto env_value_str = std::getenv(str.data()); + if (!env_value_str) { + return {}; + } + return adt::StringRef(env_value_str); } } // namespace hook \ No newline at end of file diff --git a/include/hook.h b/include/hook.h index 08d7d51..af62a20 100644 --- a/include/hook.h +++ b/include/hook.h @@ -29,6 +29,7 @@ void install_hook(); struct OriginalInfo { const char* libName = nullptr; const void* basePtr = nullptr; + const void* baseHeadPtr = nullptr; void* relaPtr = nullptr; void* oldFuncPtr = nullptr; void** pltTablePtr = nullptr; @@ -40,6 +41,7 @@ struct OriginalInfo { relaPtr = info.relaPtr; oldFuncPtr = info.oldFuncPtr; pltTablePtr = info.pltTablePtr; + baseHeadPtr = info.baseHeadPtr; return *this; } }; @@ -120,20 +122,27 @@ std::string args_to_string(Args... args) { return ss.str(); } -#define IF_ENABLE_LOG_TRACE_AND_ARGS(func) \ - do { \ - int ctrl = enable_log_backtrace(func); \ - if (ctrl) { \ - if (ctrl & 0b10) { \ - MLOG(TRACE, WARN) << func << ": " << args_to_string(args...); \ - } \ - if (ctrl & 0b01) { \ - trace::CallFrames callFrames; \ - callFrames.CollectNative(); \ - callFrames.CollectPython(); \ - MLOG(TRACE, WARN) << func << " with frame:\n" << callFrames; \ - } \ - } \ +#define IF_ENABLE_LOG_TRACE_AND_ARGS(func) \ + do { \ + int ctrl = enable_log_backtrace((func)); \ + if (ctrl) { \ + if (ctrl & 0b10) { \ + auto parser_func = \ + HookRuntimeContext::instance().lookUpArgsParser((func)); \ + MLOG(TRACE, WARN) \ + << func << ": " \ + << (parser_func \ + ? reinterpret_cast( \ + parser_func)(args...) \ + : args_to_string(args...)); \ + } \ + if (ctrl & 0b01) { \ + trace::CallFrames callFrames; \ + callFrames.CollectNative(); \ + callFrames.CollectPython(); \ + MLOG(TRACE, WARN) << func << " with frame:\n" << callFrames; \ + } \ + } \ } while (0) template @@ -276,6 +285,7 @@ struct HookFeatureBase { void* newFunc; void** oldFunc; std::function filter_; + std::function getNewCallback_; }; using WrapFuncGenerator = std::function; @@ -340,6 +350,19 @@ struct __HookFeature : public HookFeatureBase { return newFuncGenerator(libName, symName.c_str(), newFunc); } + __HookFeature& setGetNewCallback( + const std::function& getNewCallback) { + getNewCallback_ = getNewCallback; + return *this; + } + + template + __HookFeature& setArgsParser(std::string (*parser)(Args...)) { + HookRuntimeContext::instance().argsParserMap().emplace( + symName, reinterpret_cast(parser)); + return *this; + } + std::function findUniqueFunc; WrapFuncGenerator newFuncGenerator; }; @@ -372,6 +395,9 @@ struct MemberDetector(*iter) is a pointer and it's point to // std::get<1>(*iter) then there will return nullptr *iter->oldFunc = info.oldFuncPtr; + if (iter->getNewCallback_) { + iter->getNewCallback_(info); + } return iter->getNewFunc(info.libName); } }; diff --git a/include/hook_context.h b/include/hook_context.h index 5355174..d78a1cb 100644 --- a/include/hook_context.h +++ b/include/hook_context.h @@ -179,6 +179,15 @@ class HookRuntimeContext { size_t getCallCount(const std::string& libName, const std::string& symName); + void* lookUpArgsParser(const std::string& name) { + auto iter = args_parser_map_.find(name); + return iter == args_parser_map_.end() ? nullptr : iter->second; + } + + std::unordered_map& argsParserMap() { + return args_parser_map_; + } + struct TypeInfoHash { size_t operator()(const std::type_info* ti) const { return ti->hash_code(); @@ -197,6 +206,7 @@ class HookRuntimeContext { last_index_map_; std::unordered_map, TypeInfoHash> id_map_; + std::unordered_map args_parser_map_; }; template diff --git a/include/logger/StringRef.h b/include/logger/StringRef.h index 5d1143c..ab4c100 100644 --- a/include/logger/StringRef.h +++ b/include/logger/StringRef.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -24,7 +25,8 @@ class StringRef { StringRef(const char* str) : size_(strlen(str)), str_(str) {} - StringRef(const char* str, size_t size) : size_(size), str_(str) {} + StringRef(const char* str, size_t size) + : size_(size != std::string::npos ? size : strlen(str_)), str_(str) {} StringRef(const_iterator begin, const_iterator end) : size_(std::distance(begin, end)), str_(begin) {} @@ -40,13 +42,20 @@ class StringRef { StringRef& operator=(const StringRef& other) = default; StringRef& operator=(StringRef&& other) = default; + char operator[](size_t i) const { + assert(i < size_); + return str_[i]; + } + const CharT* c_str() const { return str_; } bool empty() const { return !size_ || !str_; } size_t size() const { return size_; } - std::string str() const { return std::string(str_, str_ + size_); } + std::string str() const { + return size_ != 0 ? std::string(str_, str_ + size_) : std::string(); + } const char* data() const { return str_; } @@ -85,6 +94,9 @@ class StringRef { return StringRef(this->begin(), this->end() - size); } + StringRef slice(size_t s, size_t e) { return StringRef(str_ + s, e - s); } + StringRef slice(size_t s) { return this->slice(s, this->size()); } + bool startsWith(StringRef prefix) { if (prefix.size() > this->size()) return false; return drop_back(this->size() - prefix.size()) == prefix; diff --git a/include/logger/logger.h b/include/logger/logger.h index 578bb2f..63a8287 100644 --- a/include/logger/logger.h +++ b/include/logger/logger.h @@ -40,10 +40,13 @@ enum class LogModule { profile, trace, hook, python, memory, debug, last }; class LogModuleHelper { public: static auto& enum_strs() { - static std::array strs = {"PROFILE", "TRACE", "HOOK", - "PYTHON", "MEMORY", "LAST"}; + static std::array strs = { + "PROFILE", "TRACE", "HOOK", "PYTHON", "MEMORY", "DEBUG", "LAST"}; + static_assert(sizeof(strs) / sizeof(const char*) == + static_cast(LogModule::last) + 1); return strs; } + static auto begin() { return enum_strs().begin(); } static auto end() { return enum_strs().end(); } diff --git a/include/util.h b/include/util.h new file mode 100644 index 0000000..d4b3c5e --- /dev/null +++ b/include/util.h @@ -0,0 +1,7 @@ +#include + +namespace hook { + +std::string prettyFormatSize(size_t bytes); + +} // namespace hook diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e32fba7..d5edaef 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -12,12 +12,14 @@ find_package(PythonLibs REQUIRED) add_library(cuda_mock STATIC + util.cpp env_util.cpp GlobalVarMgr.cpp logger.cpp env_mgr.cpp backtrace.cpp statistic.cpp + elf_parser.cpp hook_context.cpp hook.cpp cuda_op_tracer.cpp diff --git a/lib/elf_parser.cpp b/lib/elf_parser.cpp new file mode 100644 index 0000000..e1060e2 --- /dev/null +++ b/lib/elf_parser.cpp @@ -0,0 +1,200 @@ +#include "elf_parser.h" + +#include "util.h" + +namespace hook { + +CachedSymbolTable::StringRefIterator::StringRefIterator(const char *str) + : str_(str) {} + +CachedSymbolTable::StringRefIterator & +CachedSymbolTable::StringRefIterator::operator++() & { + size_t len = strlen(str_); + ++len; + str_ += len; + return *this; +} + +CachedSymbolTable::StringRefIterator +CachedSymbolTable::StringRefIterator::operator++(int) & { + auto ret = StringRefIterator(str_); + ++*this; + return ret; +} + +adt::StringRef CachedSymbolTable::StringRefIterator::operator*() { + return adt::StringRef(str_); +} + +bool CachedSymbolTable::StringRefIterator::operator==( + const StringRefIterator &other) const { + return str_ == other.str_; +} +bool CachedSymbolTable::StringRefIterator::operator!=( + const StringRefIterator &other) const { + return !(*this == other); +} + +const void *CachedSymbolTable::StringRefIterator::data() const { return str_; } + +CachedSymbolTable::CachedSymbolTable( + const std::string &name, const void *base_address, + const std::vector §ion_names) + : libName(name), + ifs(name), + base_address(base_address), + section_names(section_names) { + CHECK(ifs.is_open(), "can't open file:{}", name); + MLOG(DEBUG, INFO) << name << " base address:" << base_address; + ifs.read(reinterpret_cast(&elf_header), sizeof(elf_header)); + parse_named_section(); + parse_section_header(); + load_symbol_table(); + for (size_t i = 0; i < sections_.size(); ++i) { + MLOG(DEBUG, INFO) << "found section:" << i + << " name:" << getSectionName(i) + << " size:" << prettyFormatSize(sections_[i].sh_size); + } +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_begin( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_end( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +void CachedSymbolTable::move_to_section_header() { + ifs.seekg(elf_header.e_shoff, std::ios::beg); +} + +void CachedSymbolTable::move_to_section_header(size_t index) { + MLOG(DEBUG, INFO) << "move to:" << index; + move_to_section_header(); + size_t shstr_h_oft = sizeof(ElfW(Shdr)) * index; + ifs.seekg(shstr_h_oft, std::ios::cur); +} + +adt::StringRef CachedSymbolTable::getSectionName(size_t index) const { + return adt::StringRef(§ion_header_str.at(sections_.at(index).sh_name)); +} + +size_t CachedSymbolTable::find_section(adt::StringRef name) const { + size_t index = 0; + for (; index < sections_.size(); index++) { + if (getSectionName(index) == name) { + break; + } + } + return index; +} + +void CachedSymbolTable::load_symbol_table() { + auto symbol_tb = load_section_data(".symtab"); + auto xpu_symbol_tb = load_section_data("XPU_KERNEL"); + auto strtab_buf = load_section_data(".strtab"); +#if 0 + auto [begin, end] = strtab_range(buf.data(), buf.size()); +#endif + for (size_t i = 0; i < symbol_tb.size(); ++i) { + if (strtab_buf.size() <= symbol_tb[i].st_name) { + MLOG(DEBUG, INFO) + << "symbol_tb[" << i << "].st_name(" << symbol_tb[i].st_name + << ") over buf size:" << strtab_buf.size(); + continue; + } + symbol_table.emplace( + symbol_tb[i].st_value, + adt::StringRef(&strtab_buf[symbol_tb[i].st_name]).str()); + if (symbol_tb[i].st_value > max_address_) { + max_address_ = symbol_tb[i].st_value; + } + if (symbol_tb[i].st_value < min_address_) { + min_address_ = symbol_tb[i].st_value; + } + } + + for (size_t i = 0; i < xpu_symbol_tb.size(); ++i) { + if (strtab_buf.size() <= xpu_symbol_tb[i].st_name) { + MLOG(DEBUG, INFO) << "xpu_symbol_tb[" << i << "].st_name(" + << xpu_symbol_tb[i].st_name + << ") over buf size:" << strtab_buf.size(); + continue; + } + symbol_table.emplace( + xpu_symbol_tb[i].st_value, + adt::StringRef(&strtab_buf[xpu_symbol_tb[i].st_name]).str()); + } + + MLOG(DEBUG, INFO) << libName << "\naddress range:" << min_address_ << "~" + << max_address_; +} + +void CachedSymbolTable::parse_section_header() { + CHECK_EQ(sizeof(ElfW(Shdr)), elf_header.e_shentsize); + move_to_section_header(); + sections_.resize(elf_header.e_shnum); + MLOG(DEBUG, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum; + ifs.read(reinterpret_cast(sections_.data()), + sections_.size() * sizeof(sections_[0])); +} + +void CachedSymbolTable::parse_named_section() { + move_to_section_header(elf_header.e_shstrndx); + + ElfW(Shdr) shstr_h; + ifs.read(reinterpret_cast(&shstr_h), sizeof(shstr_h)); + ifs.seekg(shstr_h.sh_offset, std::ios::beg); + + section_header_str.resize(shstr_h.sh_size); + ifs.read(section_header_str.data(), section_header_str.size()); +} + +CachedSymbolTable::OwnerBuf CachedSymbolTable::load_section_data( + adt::StringRef name) { + size_t section_index = find_section(name); + if (section_index >= sections_.size()) { + LOG(INFO) << "can't found secton: " << name; + return {}; + } + ifs.seekg(sections_[section_index].sh_offset, std::ios::beg); + auto result = CachedSymbolTable::OwnerBuf::alloc( + sections_[section_index].sh_size); + ifs.read(reinterpret_cast(result.data()), result.size()); + return result; +} + +const std::string &CachedSymbolTable::lookUpSymbol(const void *func) const { + static std::string empty(""); + auto offset = reinterpret_cast(func) - + reinterpret_cast(base_address); + MLOG(DEBUG, INFO) << "lookup address:" << offset; + auto iter = symbol_table.find(offset); + if (iter == symbol_table.end()) { + MLOG(DEBUG, INFO) << libName + << "\nnot find launch_async symbol offset:" << offset + << " base address:" << base_address + << " func address:" << func << " range(" + << min_address_ << "~" << max_address_; + return empty; + } + return iter->second; +} + +static std::unordered_map> + table; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address) { + auto iter = table.emplace(lib, new CachedSymbolTable(lib, address)); + return iter.first->second.get(); +} + +CachedSymbolTable *getSymbolTable(const std::string &lib) { + return table[lib].get(); +} + +} // namespace hook \ No newline at end of file diff --git a/lib/env_util.cpp b/lib/env_util.cpp index 27a6e07..958fc4d 100644 --- a/lib/env_util.cpp +++ b/lib/env_util.cpp @@ -2,13 +2,7 @@ namespace hook { -void str2value_impl::operator()(int& value, const char* cstr, size_t len) { - adt::StringRef str; - if (len != std::string::npos) { - str = adt::StringRef(cstr, cstr + len); - } else { - str = cstr; - } +void str2value_impl::operator()(int& value, adt::StringRef str) { auto result = str.toIntegral(); value = result.has_value() ? *result : 0; } diff --git a/lib/hook.cpp b/lib/hook.cpp index 6707338..7327ed3 100644 --- a/lib/hook.cpp +++ b/lib/hook.cpp @@ -5,13 +5,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -26,9 +26,9 @@ // #include "backward.hpp" #include "cuda_types.h" -#include "logger/logger.h" #include "env_mgr.h" #include "env_util.h" +#include "logger/logger.h" #if defined __x86_64__ || defined __x86_64 #define R_JUMP_SLOT R_X86_64_JUMP_SLOT @@ -262,7 +262,8 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { size_t idx = ELF64_R_SYM(plt->r_info); idx = pltTable->dynsym[idx].st_name; - MLOG(HOOK, INFO) << pltTable->symbol_table + idx; //got symbol name from STRTAB + MLOG(HOOK, INFO) << pltTable->symbol_table + + idx; // got symbol name from STRTAB if (!installer.isTargetSymbol(pltTable->symbol_table + idx)) { continue; } @@ -272,7 +273,8 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { if (prot == 0) { return -1; } - if (!(prot & PROT_WRITE)) { //not writable and cannot convert to writable page + if (!(prot & PROT_WRITE)) { // not writable and cannot convert to + // writable page if (mprotect(ALIGN_ADDR(addr), page_size, PROT_READ | PROT_WRITE) != 0) { return -1; @@ -282,6 +284,7 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { originalInfo.libName = pltTable->lib_name.c_str(); originalInfo.basePtr = pltTable->base_addr; originalInfo.relaPtr = pltTable->rela_plt; + originalInfo.baseHeadPtr = pltTable->base_header_addr; originalInfo.pltTablePtr = reinterpret_cast(addr); originalInfo.oldFuncPtr = reinterpret_cast(*reinterpret_cast(addr)); @@ -322,10 +325,12 @@ int retrieve_dyn_lib(struct dl_phdr_info* info, size_t info_size, void* table) { << reinterpret_cast(info->dlpi_phdr) << " info_size:" << info_size; /* - 遍历info->dlpi_phdr数组中的program header, dlpi_phdr是一个ElfW(Phdr)类型的数组。 + 遍历info->dlpi_phdr数组中的program header, + dlpi_phdr是一个ElfW(Phdr)类型的数组。 info->dlpi_phnum是一个ElfW(Half)类型的变量,表示dlpi_phdr数组的长度,也就是header的个数 - ElfW(Phdr)是segment header,表征了segment的各个属性,它是一个宏,在64位系统下是Elf64_Phdr,定义在/usr/include/elf.h中,其定义如下: + ElfW(Phdr)是segment + header,表征了segment的各个属性,它是一个宏,在64位系统下是Elf64_Phdr,定义在/usr/include/elf.h中,其定义如下: typedef struct { Elf64_Word p_type; // Segment type @@ -341,13 +346,16 @@ int retrieve_dyn_lib(struct dl_phdr_info* info, size_t info_size, void* table) { for (size_t header_index = 0; header_index < info->dlpi_phnum; header_index++) { /* - 如果一个elf文件参与动态链接,那么program header中会出现类型为PT_DYNAMIC的header + 如果一个elf文件参与动态链接,那么program + header中会出现类型为PT_DYNAMIC的header 只有这个段是我们所关心的,其他不关心。 */ if (info->dlpi_phdr[header_index].p_type == PT_DYNAMIC) { /* - info->dlpi_addr: 共享对象的虚拟内存起始地址,相对于进程的地址空间。对于可执行文件,这通常是0。 - info->dlpi_phdr[header_index].p_vaddr 第header_index个segment的虚拟起始地址,相对于info->dlpi_addr + info->dlpi_addr: + 共享对象的虚拟内存起始地址,相对于进程的地址空间。对于可执行文件,这通常是0。 + info->dlpi_phdr[header_index].p_vaddr + 第header_index个segment的虚拟起始地址,相对于info->dlpi_addr 那么二者相加,就是segment的虚拟地址,也就是segment在进程中的实际地址 ElfW(Dyn)宏扩展为,定义在/usr/include/elf.h中 diff --git a/lib/logger.cpp b/lib/logger.cpp index bd01a96..2391a18 100644 --- a/lib/logger.cpp +++ b/lib/logger.cpp @@ -34,15 +34,15 @@ namespace hook { template <> struct str2value_impl { - void operator()(logger::LogLevel& lvl, const char* str, - size_t len = std::string::npos) { + void operator()(logger::LogLevel& lvl, adt::StringRef str) { auto iter = std::find(std::begin(gLoggerLevelStringSet()), std::end(gLoggerLevelStringSet()), str); if (iter != std::end(gLoggerLevelStringSet())) { lvl = static_cast( std::distance(std::begin(gLoggerLevelStringSet()), iter)); } else { - lvl = static_cast(::atoi(str)); + // default warning + lvl = logger::LogLevel::warning; } } }; @@ -359,18 +359,33 @@ class LogConsumer : public std::enable_shared_from_this { } while (!exit_.load()); } - void sync_pause_loop() { - exit_.store(true); - if (cfg_->mode == LogConfig::kAsync) { - if (th_ && th_->joinable()) th_->join(); - } - flush_queue(); - fwriteString("[LOG END]\n", cfg_->stream); - fflush(cfg_->stream); + void sync_pause_loop(int signum) { + static std::once_flag onceFlag; + std::call_once(onceFlag, [this, signum = signum]() { + exit_.store(true); + if (cfg_->mode == LogConfig::kAsync) { + if (th_ && th_->joinable()) th_->join(); + } + flush_queue(); + switch (signum) { + case SIGSEGV: + fwriteString("[LOG END reason:SIGSEGV]\n", cfg_->stream); + break; + case SIGABRT: + fwriteString("[LOG END reason:SIGABRT]\n", cfg_->stream); + break; + case SIGTERM: + fwriteString("[LOG END reason:SIGTERM]\n", cfg_->stream); + break; + default: + break; + } + fflush(cfg_->stream); + }); } void report_fatal() { - sync_pause_loop(); + sync_pause_loop(SIGUSR1); // write nullptr statement maybe be motion to front int n = 0; *reinterpret_cast(n) = 0; @@ -438,7 +453,7 @@ void destroy_logger(); void core_dump_handler(int signum) { auto consumer = LogStreamCollection::instance().release_consumer(); // TODO: move to destructor - consumer->sync_pause_loop(); + consumer->sync_pause_loop(signum); auto on_exit = LogStreamCollection::instance().on_exit(); if (on_exit) on_exit(); exit(signum); @@ -467,9 +482,30 @@ LogStream& LogStream::instance(const LogConfig& cfg) { void setLoggerLevel( std::array(LogModule::last) + 1>& module_set_, LogLevel& level_) { - auto modules = hook::get_env_value< - std::vector>>( - env_mgr::LOG_LEVEL); + adt::StringRef envValue = + hook::get_env_value(env_mgr::LOG_LEVEL); + adt::StringRef mainLeveleStr, modulesStr; + for (auto m = LogModuleHelper::begin(); m != LogModuleHelper::end(); ++m) { + if (envValue.startsWith(*m)) { + modulesStr = envValue; + break; + } + } + if (modulesStr.empty()) { + size_t index = 0; + for (auto chr : envValue) { + if (chr == ',') { + modulesStr = envValue.slice(index + 1); + mainLeveleStr = envValue.slice(0, index); + break; + } + index++; + } + } + + auto modules = hook::str2value< + std::vector>>()(modulesStr); + std::fill(std::begin(module_set_), std::end(module_set_), logger::LogLevel::warning); for (auto name : LogModuleHelper::enum_strs()) { @@ -483,18 +519,7 @@ void setLoggerLevel( module_set_[IntModule] = iter->second; } } - auto default_lvl_iter = - std::find_if(std::begin(gLoggerLevelStringSet()), - std::end(gLoggerLevelStringSet()), [&](const auto& str) { - return std::find_if(modules.begin(), modules.end(), - [&](const auto& env_v) { - return env_v.first == str; - }) != modules.end(); - }); - if (default_lvl_iter != std::end(gLoggerLevelStringSet())) { - level_ = static_cast(default_lvl_iter - - std::begin(gLoggerLevelStringSet())); - } + level_ = hook::str2value()(mainLeveleStr); } LogStream::LogStream(std::shared_ptr& logConsumer, @@ -534,7 +559,7 @@ void destroy_logger() { LogStreamCollection::instance().release_all_stream(); auto consumer = LogStreamCollection::instance().release_consumer(); // TODO: move to destructor - consumer->sync_pause_loop(); + consumer->sync_pause_loop(0); } thread_local std::chrono::high_resolution_clock::duration diff --git a/lib/statistic.cpp b/lib/statistic.cpp index e9fad71..be5b42d 100644 --- a/lib/statistic.cpp +++ b/lib/statistic.cpp @@ -6,28 +6,10 @@ #include "hook.h" #include "hook_context.h" +#include "util.h" namespace hook { -namespace { - -std::string prettyFormatSize(size_t bytes) { - const char* sizes[] = {"B", "KB", "MB", "GB"}; - size_t order = 0; - double size = static_cast(bytes); - - while (size >= 1024 && order < sizeof(sizes) / sizeof(sizes[0]) - 1) { - order++; - size /= 1024; - } - - std::ostringstream out; - out << std::fixed << std::setprecision(2) << size << " " << sizes[order]; - return out.str(); -} - -} // namespace - std::string shortLibName(const std::string& full_lib_name) { #if 1 auto pos = full_lib_name.find_last_of('/'); @@ -40,8 +22,8 @@ std::string shortLibName(const std::string& full_lib_name) { std::ostream& operator<<(std::ostream& os, const MemoryStatisticCollection::PtrIdentity& id) { - os << "id(" << id.lib << "),devId:" << id.devId - << ",kind:" << id.kind << ")"; + os << "id(" << id.lib << "),devId:" << id.devId << ",kind:" << id.kind + << ")"; return os; } diff --git a/lib/util.cpp b/lib/util.cpp new file mode 100644 index 0000000..601a9e3 --- /dev/null +++ b/lib/util.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +namespace hook { +std::string prettyFormatSize(size_t bytes) { + const char* sizes[] = {"B", "KB", "MB", "GB"}; + size_t order = 0; + double size = static_cast(bytes); + + while (size >= 1024 && order < sizeof(sizes) / sizeof(sizes[0]) - 1) { + order++; + size /= 1024; + } + + std::ostringstream out; + out << std::fixed << std::setprecision(2) << size << " " << sizes[order]; + return out.str(); +} + +} // namespace hook diff --git a/lib/xpu_mock.cpp b/lib/xpu_mock.cpp index 9fd5d38..9cb45f7 100644 --- a/lib/xpu_mock.cpp +++ b/lib/xpu_mock.cpp @@ -13,6 +13,7 @@ #include #include "backtrace.h" +#include "elf_parser.h" #include "hook.h" #include "hooks/print_hook.h" #include "logger/StringRef.h" @@ -94,10 +95,14 @@ DEF_FUNCTION_INT(xpu_set_device, int devid) { } DEF_FUNCTION_INT(xpu_launch_async, void* func) { - // TODO: get symbol name from symbol table return origin_xpu_launch_async(func); } +std::string launch_args_parser(void* func) { + auto libName = hook::HookRuntimeContext::instance().curLibName(); + return hook::getSymbolTable(libName)->lookUpSymbol(func); +} + DEF_FUNCTION_INT(xpu_stream_create, void** pstream) { return origin_xpu_stream_create(pstream); } @@ -163,8 +168,11 @@ DEF_FUNCTION_INT(cudaMemcpy, void* dst, const void* src, size_t count, return origin_cudaMemcpy(dst, src, count, kind); } -#define BUILD_FEATURE(name) \ - hook::FHookFeature(STR_TO_TYPE(#name), &name, &origin_##name) +/// need gcc version > 9.0 +/// #define BUILD_FEATURE(name) hook::FHookFeature(STR_TO_TYPE(#name), &name, +/// &origin_##name) + +#define BUILD_FEATURE(name) hook::HookFeature(#name, &name, &origin_##name) class XpuRuntimeApiHook : public hook::HookInstallerWrap { public: @@ -173,19 +181,31 @@ class XpuRuntimeApiHook : public hook::HookInstallerWrap { !adt::StringRef(name).contain("libcudart.so"); } - hook::FHookFeature symbols[14] = { - BUILD_FEATURE(xpu_malloc), BUILD_FEATURE(xpu_free), - BUILD_FEATURE(xpu_current_device), BUILD_FEATURE(xpu_set_device), - BUILD_FEATURE(xpu_wait), BUILD_FEATURE(xpu_memcpy), - BUILD_FEATURE(xpu_launch_async), BUILD_FEATURE(xpu_stream_create), + // need gcc version > 9.0 + // hook::FHookFeature symbols[14] = { + hook::HookFeature symbols[14] = { + BUILD_FEATURE(xpu_malloc), + BUILD_FEATURE(xpu_free), + BUILD_FEATURE(xpu_current_device), + BUILD_FEATURE(xpu_set_device), + BUILD_FEATURE(xpu_wait), + BUILD_FEATURE(xpu_memcpy), + BUILD_FEATURE(xpu_launch_async) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + }) + .setArgsParser(&launch_args_parser), + BUILD_FEATURE(xpu_stream_create), BUILD_FEATURE(xpu_stream_destroy), - BUILD_FEATURE(cudaMalloc), BUILD_FEATURE(cudaFree), - BUILD_FEATURE(cudaMemcpy), BUILD_FEATURE(cudaSetDevice), + BUILD_FEATURE(cudaMalloc), + BUILD_FEATURE(cudaFree), + BUILD_FEATURE(cudaMemcpy), + BUILD_FEATURE(cudaSetDevice), BUILD_FEATURE(cudaGetDevice), }; - void onSuccess() { LOG(WARN) << "install " << curSymName() << " success"; } + void onSuccess() { LOG(INFO) << "install " << curSymName() << " success"; } }; struct PatchRuntimeHook : public hook::HookInstallerWrap { diff --git a/test/cpp_test/mock_api_test.cpp b/test/cpp_test/mock_api_test.cpp index df17386..8537f14 100644 --- a/test/cpp_test/mock_api_test.cpp +++ b/test/cpp_test/mock_api_test.cpp @@ -7,6 +7,7 @@ #include "GlobalVarMgr.h" #include "cuda_mock.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h" @@ -17,4 +18,4 @@ TEST(MockAnyHook, base) { dh_any_hook_install(); int ret = mock::foo(nullptr); EXPECT_EQ(ret, 0); -} \ No newline at end of file +} diff --git a/test/cpp_test/test_elf_parser.cpp b/test/cpp_test/test_elf_parser.cpp new file mode 100644 index 0000000..cd8029b --- /dev/null +++ b/test/cpp_test/test_elf_parser.cpp @@ -0,0 +1,48 @@ +#include + +#include +#include + +#include "GlobalVarMgr.h" +#include "elf_parser.h" +#include "gtest/gtest.h" +#include "hook.h" +#include "logger/logger_stl.h" + +using namespace hook; + +namespace { +int foo(void*) { return 0; } + +void* org_foo = nullptr; + +class SymbolParserHook : public hook::HookInstallerWrap { + public: + bool targetLib(const char* name) { + return adt::StringRef(name).contain("libmock_api.so"); + } + hook::HookFeature symbols[1] = { + HookFeature("_ZN4mock3fooEPv", &foo, &org_foo) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + })}; + + void onSuccess() {} +}; + +} // namespace + +TEST(TestHookWrap, symbol_parser) { + auto sh = std::make_shared(); + sh->install(); +} + +// TEST(MockAnyHook, symbol) { +// hook::CachedSymbolTable ctb("./test/cpp_test/mock_api/libmock_api.so", +// nullptr); auto strtab = ctb.load_section_data(".strtab"); LOG(WARN) << +// "strtab size:" << strtab.size(); + +// for (auto it : ctb.getSymbolTable()) { +// LOG(WARN) << it.second; +// } +// } diff --git a/test/cpp_test/test_hook.cpp b/test/cpp_test/test_hook.cpp index f4ab0a1..6a001f3 100644 --- a/test/cpp_test/test_hook.cpp +++ b/test/cpp_test/test_hook.cpp @@ -4,6 +4,7 @@ #include #include "GlobalVarMgr.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h"