diff --git a/include/elf_parser.h b/include/elf_parser.h new file mode 100644 index 0000000..6898370 --- /dev/null +++ b/include/elf_parser.h @@ -0,0 +1,157 @@ +#include +#include + +#include +#include +#include +#include + +#include "logger/logger.h" +#include "macro.h" + +namespace hook { + +class HOOK_API CachedSymbolTable { + public: + struct StringRefIterator { + StringRefIterator(const char *str); + + StringRefIterator &operator++() &; + StringRefIterator operator++(int) &; + + adt::StringRef operator*(); + + bool operator==(const StringRefIterator &other) const; + bool operator!=(const StringRefIterator &other) const; + const void *data() const; + + private: + const char *str_; + }; + + template + struct OwnerBuf { + OwnerBuf() = default; + OwnerBuf(const OwnerBuf &other) = delete; + OwnerBuf &operator=(const OwnerBuf &other) = delete; + + OwnerBuf(OwnerBuf &&other) { *this = std::move(other); } + OwnerBuf &operator=(OwnerBuf &&other) { + std::swap(buf_, other.buf_); + std::swap(size_, other.size_); + return *this; + } + + static OwnerBuf alloc(size_t size) { + OwnerBuf buf; + buf.buf_ = reinterpret_cast(malloc(size * sizeof(T))); + buf.size_ = size; + return buf; + } + + T &operator[](size_t index) { + assert(index < size_); + assert(buf_); + return buf_[index]; + } + + T *data() const { return buf_; } + + size_t size() const { return size_; } + + std::tuple release() { + auto result = std::make_tuple(buf_, size_); + buf_ = 0; + size_ = 0; + return result; + } + + template + friend class OwnerBuf; + + template + OwnerBuf cast() { + OwnerBuf result; + auto [buf, size] = release(); + result.buf_ = reinterpret_cast(buf); + result.size_ = size * sizeof(T) / sizeof(N); + return result; + } + + ~OwnerBuf() { + if (buf_) { + free(buf_); +#ifndef NDEBUG + buf_ = nullptr; + size_ = 0; +#endif + } + } + + private: + T *buf_ = nullptr; + size_t size_ = 0; + }; + + StringRefIterator strtab_begin(const char *str) const; + + StringRefIterator strtab_end(const char *str) const; + + std::tuple strtab_range( + const char *str, size_t size) const { + return std::make_tuple(strtab_begin(str), strtab_begin(str + size)); + } + + CachedSymbolTable(const std::string &name, const void *base_address, + const std::vector §ion_names = {}); + + void move_to_section_header(); + + void move_to_section_header(size_t index); + + adt::StringRef getSectionName(size_t index) const; + size_t find_section(adt::StringRef name) const; + + void load_symbol_table(); + void parse_section_header(); + void parse_named_section(); + + template + OwnerBuf load_section_data(adt::StringRef name) { + auto buf = load_section_data(name); + return buf.cast(); + } + + OwnerBuf load_section_data(adt::StringRef name); + + const std::string &lookUpSymbol(const void *func) const; + + const std::unordered_map &getSymbolTable() const { + return symbol_table; + } + + size_t min_addrtess() const { return min_address_; } + size_t max_addrtess() const { return max_address_; } + + const std::vector §ions() const { return sections_; } + + private: + std::string libName; + std::ifstream ifs; + ElfW(Ehdr) elf_header; + std::vector section_header_str; + std::vector sections_; + std::vector strtab; + std::unordered_map symbol_table; + const void *base_address; + size_t min_address_ = -1; + size_t max_address_ = 0; + std::vector section_names; +}; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address); + +CachedSymbolTable *getSymbolTable(const std::string &lib); + +} // namespace hook diff --git a/include/env_util.h b/include/env_util.h index e7f6ab5..92b386f 100644 --- a/include/env_util.h +++ b/include/env_util.h @@ -18,6 +18,8 @@ namespace hook { // if we use the function template we need implement all of the str2value_impl // functions to support more type, but sometimes we need implement the // str2value_impl near the type define + +// TODO: replace C style const char* and length with StringRef template struct str2value_impl { void operator()(T& value, const char* str, size_t len = std::string::npos, @@ -57,7 +59,7 @@ struct str2value_impl> { for (; i < len && str[i] != '\0'; ++i) { if (str[i] == '=') { pair.first = str2value()(str, i); - pair.second = str2value()(str + i + 1); + pair.second = str2value()(str + i + 1, len - i - 1); break; } } diff --git a/include/hook.h b/include/hook.h index 08d7d51..3b61229 100644 --- a/include/hook.h +++ b/include/hook.h @@ -29,6 +29,7 @@ void install_hook(); struct OriginalInfo { const char* libName = nullptr; const void* basePtr = nullptr; + const void* baseHeadPtr = nullptr; void* relaPtr = nullptr; void* oldFuncPtr = nullptr; void** pltTablePtr = nullptr; @@ -40,6 +41,7 @@ struct OriginalInfo { relaPtr = info.relaPtr; oldFuncPtr = info.oldFuncPtr; pltTablePtr = info.pltTablePtr; + baseHeadPtr = info.baseHeadPtr; return *this; } }; @@ -276,6 +278,7 @@ struct HookFeatureBase { void* newFunc; void** oldFunc; std::function filter_; + std::function getNewCallback_; }; using WrapFuncGenerator = std::function; @@ -340,6 +343,12 @@ struct __HookFeature : public HookFeatureBase { return newFuncGenerator(libName, symName.c_str(), newFunc); } + __HookFeature& setGetNewCallback( + const std::function& getNewCallback) { + getNewCallback_ = getNewCallback; + return *this; + } + std::function findUniqueFunc; WrapFuncGenerator newFuncGenerator; }; @@ -372,6 +381,9 @@ struct MemberDetector(*iter) is a pointer and it's point to // std::get<1>(*iter) then there will return nullptr *iter->oldFunc = info.oldFuncPtr; + if (iter->getNewCallback_) { + iter->getNewCallback_(info); + } return iter->getNewFunc(info.libName); } }; diff --git a/include/logger/StringRef.h b/include/logger/StringRef.h index 5d1143c..aa23848 100644 --- a/include/logger/StringRef.h +++ b/include/logger/StringRef.h @@ -24,7 +24,8 @@ class StringRef { StringRef(const char* str) : size_(strlen(str)), str_(str) {} - StringRef(const char* str, size_t size) : size_(size), str_(str) {} + StringRef(const char* str, size_t size) + : size_(size != std::string::npos ? size : strlen(str_)), str_(str) {} StringRef(const_iterator begin, const_iterator end) : size_(std::distance(begin, end)), str_(begin) {} @@ -46,7 +47,9 @@ class StringRef { size_t size() const { return size_; } - std::string str() const { return std::string(str_, str_ + size_); } + std::string str() const { + return size_ != 0 ? std::string(str_, str_ + size_) : std::string(); + } const char* data() const { return str_; } diff --git a/include/logger/logger.h b/include/logger/logger.h index 578bb2f..63a8287 100644 --- a/include/logger/logger.h +++ b/include/logger/logger.h @@ -40,10 +40,13 @@ enum class LogModule { profile, trace, hook, python, memory, debug, last }; class LogModuleHelper { public: static auto& enum_strs() { - static std::array strs = {"PROFILE", "TRACE", "HOOK", - "PYTHON", "MEMORY", "LAST"}; + static std::array strs = { + "PROFILE", "TRACE", "HOOK", "PYTHON", "MEMORY", "DEBUG", "LAST"}; + static_assert(sizeof(strs) / sizeof(const char*) == + static_cast(LogModule::last) + 1); return strs; } + static auto begin() { return enum_strs().begin(); } static auto end() { return enum_strs().end(); } diff --git a/include/util.h b/include/util.h new file mode 100644 index 0000000..d4b3c5e --- /dev/null +++ b/include/util.h @@ -0,0 +1,7 @@ +#include + +namespace hook { + +std::string prettyFormatSize(size_t bytes); + +} // namespace hook diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index e32fba7..d5edaef 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -12,12 +12,14 @@ find_package(PythonLibs REQUIRED) add_library(cuda_mock STATIC + util.cpp env_util.cpp GlobalVarMgr.cpp logger.cpp env_mgr.cpp backtrace.cpp statistic.cpp + elf_parser.cpp hook_context.cpp hook.cpp cuda_op_tracer.cpp diff --git a/lib/elf_parser.cpp b/lib/elf_parser.cpp new file mode 100644 index 0000000..15b6749 --- /dev/null +++ b/lib/elf_parser.cpp @@ -0,0 +1,204 @@ +#include "elf_parser.h" + +#include "util.h" + +namespace hook { + +CachedSymbolTable::StringRefIterator::StringRefIterator(const char *str) + : str_(str) {} + +CachedSymbolTable::StringRefIterator & +CachedSymbolTable::StringRefIterator::operator++() & { + size_t len = strlen(str_); + ++len; + str_ += len; + return *this; +} + +CachedSymbolTable::StringRefIterator +CachedSymbolTable::StringRefIterator::operator++(int) & { + auto ret = StringRefIterator(str_); + ++*this; + return ret; +} + +adt::StringRef CachedSymbolTable::StringRefIterator::operator*() { + return adt::StringRef(str_); +} + +bool CachedSymbolTable::StringRefIterator::operator==( + const StringRefIterator &other) const { + return str_ == other.str_; +} +bool CachedSymbolTable::StringRefIterator::operator!=( + const StringRefIterator &other) const { + return !(*this == other); +} + +const void *CachedSymbolTable::StringRefIterator::data() const { return str_; } + +CachedSymbolTable::CachedSymbolTable( + const std::string &name, const void *base_address, + const std::vector §ion_names) + : libName(name), + ifs(name), + base_address(base_address), + section_names(section_names) { + CHECK(ifs.is_open(), "can't open file:{}", name); + MLOG(DEBUG, INFO) << name << " base address:" << base_address; + ifs.read(reinterpret_cast(&elf_header), sizeof(elf_header)); + parse_named_section(); + parse_section_header(); + load_symbol_table(); + for (size_t i = 0; i < sections_.size(); ++i) { + MLOG(DEBUG, INFO) << "found section:" << i + << " name:" << getSectionName(i) + << " size:" << prettyFormatSize(sections_[i].sh_size); + } +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_begin( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_end( + const char *str) const { + return CachedSymbolTable::StringRefIterator(str); +} + +void CachedSymbolTable::move_to_section_header() { + ifs.seekg(elf_header.e_shoff, std::ios::beg); +} + +void CachedSymbolTable::move_to_section_header(size_t index) { + MLOG(DEBUG, INFO) << "move to:" << index; + move_to_section_header(); + size_t shstr_h_oft = sizeof(ElfW(Shdr)) * index; + ifs.seekg(shstr_h_oft, std::ios::cur); +} + +adt::StringRef CachedSymbolTable::getSectionName(size_t index) const { + return adt::StringRef(§ion_header_str.at(sections_.at(index).sh_name)); +} + +size_t CachedSymbolTable::find_section(adt::StringRef name) const { + size_t index = 0; + for (; index < sections_.size(); index++) { + if (getSectionName(index) == name) { + break; + } + } + return index; +} + +void CachedSymbolTable::load_symbol_table() { + auto symbol_tb = load_section_data(".symtab"); + auto xpu_symbol_tb = load_section_data("XPU_KERNEL"); + auto strtab_buf = load_section_data(".strtab"); +#if 0 + auto [begin, end] = strtab_range(buf.data(), buf.size()); +#endif + for (size_t i = 0; i < symbol_tb.size(); ++i) { + if (strtab_buf.size() <= symbol_tb[i].st_name) { + MLOG(DEBUG, INFO) + << "symbol_tb[" << i << "].st_name(" << symbol_tb[i].st_name + << ") over buf size:" << strtab_buf.size(); + continue; + } + symbol_table.emplace( + symbol_tb[i].st_value, + adt::StringRef(&strtab_buf[symbol_tb[i].st_name]).str()); + if (symbol_tb[i].st_value > max_address_) { + max_address_ = symbol_tb[i].st_value; + } + if (symbol_tb[i].st_value < min_address_) { + min_address_ = symbol_tb[i].st_value; + } + } + + if (xpu_symbol_tb.size()) { + for (size_t i = 0; i < xpu_symbol_tb.size(); ++i) { + if (strtab_buf.size() <= xpu_symbol_tb[i].st_name) { + MLOG(DEBUG, INFO) << "xpu_symbol_tb[" << i << "].st_name(" + << xpu_symbol_tb[i].st_name + << ") over buf size:" << strtab_buf.size(); + continue; + } + symbol_table.emplace( + xpu_symbol_tb[i].st_value, + adt::StringRef(&strtab_buf[xpu_symbol_tb[i].st_name]).str()); + } + } + + MLOG(DEBUG, INFO) << libName << "\naddress range:" << min_address_ << "~" + << max_address_; + std::vector().swap(strtab); +} + +void CachedSymbolTable::parse_section_header() { + CHECK_EQ(sizeof(ElfW(Shdr)), elf_header.e_shentsize); + move_to_section_header(); + sections_.resize(elf_header.e_shnum); + MLOG(DEBUG, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum; + ifs.read(reinterpret_cast(sections_.data()), + sections_.size() * sizeof(sections_[0])); +} + +void CachedSymbolTable::parse_named_section() { + move_to_section_header(elf_header.e_shstrndx); + + ElfW(Shdr) shstr_h; + ifs.read(reinterpret_cast(&shstr_h), sizeof(shstr_h)); + ifs.seekg(shstr_h.sh_offset, std::ios::beg); + + section_header_str.resize(shstr_h.sh_size); + ifs.read(section_header_str.data(), section_header_str.size()); +} + +CachedSymbolTable::OwnerBuf CachedSymbolTable::load_section_data( + adt::StringRef name) { + size_t section_index = find_section(name); + if (section_index >= sections_.size()) { + LOG(WARN) << "can't found secton: " << name; + return {}; + } + ifs.seekg(sections_[section_index].sh_offset, std::ios::beg); + auto result = CachedSymbolTable::OwnerBuf::alloc( + sections_[section_index].sh_size); + ifs.read(reinterpret_cast(result.data()), result.size()); + return result; +} + +const std::string &CachedSymbolTable::lookUpSymbol(const void *func) const { + auto offset = reinterpret_cast(func) - + reinterpret_cast(base_address); + MLOG(DEBUG, INFO) << "lookup address:" << offset; + auto iter = symbol_table.find(offset); + if (iter == symbol_table.end()) { + LOG(WARN) << libName + << "\nnot find launch_async symbol offset:" << offset + << " base address:" << base_address + << " func address:" << func << " range(" << min_address_ + << "~" << max_address_; + } else { + LOG(WARN) << "found launch_async symbol:" << iter->second; + } + static std::string empty(""); + return empty; +} + +static std::unordered_map> + table; + +CachedSymbolTable *createSymbolTable(const std::string &lib, + const void *address) { + auto iter = table.emplace(lib, new CachedSymbolTable(lib, address)); + return iter.first->second.get(); +} + +CachedSymbolTable *getSymbolTable(const std::string &lib) { + return table[lib].get(); +} + +} // namespace hook \ No newline at end of file diff --git a/lib/hook.cpp b/lib/hook.cpp index 6707338..7327ed3 100644 --- a/lib/hook.cpp +++ b/lib/hook.cpp @@ -5,13 +5,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include @@ -26,9 +26,9 @@ // #include "backward.hpp" #include "cuda_types.h" -#include "logger/logger.h" #include "env_mgr.h" #include "env_util.h" +#include "logger/logger.h" #if defined __x86_64__ || defined __x86_64 #define R_JUMP_SLOT R_X86_64_JUMP_SLOT @@ -262,7 +262,8 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { size_t idx = ELF64_R_SYM(plt->r_info); idx = pltTable->dynsym[idx].st_name; - MLOG(HOOK, INFO) << pltTable->symbol_table + idx; //got symbol name from STRTAB + MLOG(HOOK, INFO) << pltTable->symbol_table + + idx; // got symbol name from STRTAB if (!installer.isTargetSymbol(pltTable->symbol_table + idx)) { continue; } @@ -272,7 +273,8 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { if (prot == 0) { return -1; } - if (!(prot & PROT_WRITE)) { //not writable and cannot convert to writable page + if (!(prot & PROT_WRITE)) { // not writable and cannot convert to + // writable page if (mprotect(ALIGN_ADDR(addr), page_size, PROT_READ | PROT_WRITE) != 0) { return -1; @@ -282,6 +284,7 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) { originalInfo.libName = pltTable->lib_name.c_str(); originalInfo.basePtr = pltTable->base_addr; originalInfo.relaPtr = pltTable->rela_plt; + originalInfo.baseHeadPtr = pltTable->base_header_addr; originalInfo.pltTablePtr = reinterpret_cast(addr); originalInfo.oldFuncPtr = reinterpret_cast(*reinterpret_cast(addr)); @@ -322,10 +325,12 @@ int retrieve_dyn_lib(struct dl_phdr_info* info, size_t info_size, void* table) { << reinterpret_cast(info->dlpi_phdr) << " info_size:" << info_size; /* - 遍历info->dlpi_phdr数组中的program header, dlpi_phdr是一个ElfW(Phdr)类型的数组。 + 遍历info->dlpi_phdr数组中的program header, + dlpi_phdr是一个ElfW(Phdr)类型的数组。 info->dlpi_phnum是一个ElfW(Half)类型的变量,表示dlpi_phdr数组的长度,也就是header的个数 - ElfW(Phdr)是segment header,表征了segment的各个属性,它是一个宏,在64位系统下是Elf64_Phdr,定义在/usr/include/elf.h中,其定义如下: + ElfW(Phdr)是segment + header,表征了segment的各个属性,它是一个宏,在64位系统下是Elf64_Phdr,定义在/usr/include/elf.h中,其定义如下: typedef struct { Elf64_Word p_type; // Segment type @@ -341,13 +346,16 @@ int retrieve_dyn_lib(struct dl_phdr_info* info, size_t info_size, void* table) { for (size_t header_index = 0; header_index < info->dlpi_phnum; header_index++) { /* - 如果一个elf文件参与动态链接,那么program header中会出现类型为PT_DYNAMIC的header + 如果一个elf文件参与动态链接,那么program + header中会出现类型为PT_DYNAMIC的header 只有这个段是我们所关心的,其他不关心。 */ if (info->dlpi_phdr[header_index].p_type == PT_DYNAMIC) { /* - info->dlpi_addr: 共享对象的虚拟内存起始地址,相对于进程的地址空间。对于可执行文件,这通常是0。 - info->dlpi_phdr[header_index].p_vaddr 第header_index个segment的虚拟起始地址,相对于info->dlpi_addr + info->dlpi_addr: + 共享对象的虚拟内存起始地址,相对于进程的地址空间。对于可执行文件,这通常是0。 + info->dlpi_phdr[header_index].p_vaddr + 第header_index个segment的虚拟起始地址,相对于info->dlpi_addr 那么二者相加,就是segment的虚拟地址,也就是segment在进程中的实际地址 ElfW(Dyn)宏扩展为,定义在/usr/include/elf.h中 diff --git a/lib/logger.cpp b/lib/logger.cpp index bd01a96..e402a1f 100644 --- a/lib/logger.cpp +++ b/lib/logger.cpp @@ -37,12 +37,14 @@ struct str2value_impl { void operator()(logger::LogLevel& lvl, const char* str, size_t len = std::string::npos) { auto iter = std::find(std::begin(gLoggerLevelStringSet()), - std::end(gLoggerLevelStringSet()), str); + std::end(gLoggerLevelStringSet()), + adt::StringRef(str, len)); if (iter != std::end(gLoggerLevelStringSet())) { lvl = static_cast( std::distance(std::begin(gLoggerLevelStringSet()), iter)); } else { - lvl = static_cast(::atoi(str)); + // default warning + lvl = logger::LogLevel::warning; } } }; @@ -359,18 +361,33 @@ class LogConsumer : public std::enable_shared_from_this { } while (!exit_.load()); } - void sync_pause_loop() { - exit_.store(true); - if (cfg_->mode == LogConfig::kAsync) { - if (th_ && th_->joinable()) th_->join(); - } - flush_queue(); - fwriteString("[LOG END]\n", cfg_->stream); - fflush(cfg_->stream); + void sync_pause_loop(int signum) { + static std::once_flag onceFlag; + std::call_once(onceFlag, [this, signum = signum]() { + exit_.store(true); + if (cfg_->mode == LogConfig::kAsync) { + if (th_ && th_->joinable()) th_->join(); + } + flush_queue(); + switch (signum) { + case SIGSEGV: + fwriteString("[LOG END reason:SIGSEGV]\n", cfg_->stream); + break; + case SIGABRT: + fwriteString("[LOG END reason:SIGABRT]\n", cfg_->stream); + break; + case SIGTERM: + fwriteString("[LOG END reason:SIGTERM]\n", cfg_->stream); + break; + default: + break; + } + fflush(cfg_->stream); + }); } void report_fatal() { - sync_pause_loop(); + sync_pause_loop(SIGUSR1); // write nullptr statement maybe be motion to front int n = 0; *reinterpret_cast(n) = 0; @@ -438,7 +455,7 @@ void destroy_logger(); void core_dump_handler(int signum) { auto consumer = LogStreamCollection::instance().release_consumer(); // TODO: move to destructor - consumer->sync_pause_loop(); + consumer->sync_pause_loop(signum); auto on_exit = LogStreamCollection::instance().on_exit(); if (on_exit) on_exit(); exit(signum); @@ -534,7 +551,7 @@ void destroy_logger() { LogStreamCollection::instance().release_all_stream(); auto consumer = LogStreamCollection::instance().release_consumer(); // TODO: move to destructor - consumer->sync_pause_loop(); + consumer->sync_pause_loop(0); } thread_local std::chrono::high_resolution_clock::duration diff --git a/lib/statistic.cpp b/lib/statistic.cpp index e9fad71..09091c6 100644 --- a/lib/statistic.cpp +++ b/lib/statistic.cpp @@ -6,28 +6,10 @@ #include "hook.h" #include "hook_context.h" +#include "util.h" namespace hook { -namespace { - -std::string prettyFormatSize(size_t bytes) { - const char* sizes[] = {"B", "KB", "MB", "GB"}; - size_t order = 0; - double size = static_cast(bytes); - - while (size >= 1024 && order < sizeof(sizes) / sizeof(sizes[0]) - 1) { - order++; - size /= 1024; - } - - std::ostringstream out; - out << std::fixed << std::setprecision(2) << size << " " << sizes[order]; - return out.str(); -} - -} // namespace - std::string shortLibName(const std::string& full_lib_name) { #if 1 auto pos = full_lib_name.find_last_of('/'); diff --git a/lib/util.cpp b/lib/util.cpp new file mode 100644 index 0000000..601a9e3 --- /dev/null +++ b/lib/util.cpp @@ -0,0 +1,21 @@ +#include +#include +#include + +namespace hook { +std::string prettyFormatSize(size_t bytes) { + const char* sizes[] = {"B", "KB", "MB", "GB"}; + size_t order = 0; + double size = static_cast(bytes); + + while (size >= 1024 && order < sizeof(sizes) / sizeof(sizes[0]) - 1) { + order++; + size /= 1024; + } + + std::ostringstream out; + out << std::fixed << std::setprecision(2) << size << " " << sizes[order]; + return out.str(); +} + +} // namespace hook diff --git a/lib/xpu_mock.cpp b/lib/xpu_mock.cpp index 9fd5d38..09fec27 100644 --- a/lib/xpu_mock.cpp +++ b/lib/xpu_mock.cpp @@ -13,6 +13,7 @@ #include #include "backtrace.h" +#include "elf_parser.h" #include "hook.h" #include "hooks/print_hook.h" #include "logger/StringRef.h" @@ -95,6 +96,8 @@ DEF_FUNCTION_INT(xpu_set_device, int devid) { DEF_FUNCTION_INT(xpu_launch_async, void* func) { // TODO: get symbol name from symbol table + auto libName = hook::HookRuntimeContext::instance().curLibName(); + hook::getSymbolTable(libName)->lookUpSymbol(func); return origin_xpu_launch_async(func); } @@ -163,8 +166,11 @@ DEF_FUNCTION_INT(cudaMemcpy, void* dst, const void* src, size_t count, return origin_cudaMemcpy(dst, src, count, kind); } -#define BUILD_FEATURE(name) \ - hook::FHookFeature(STR_TO_TYPE(#name), &name, &origin_##name) +/// need gcc version > 9.0 +/// #define BUILD_FEATURE(name) hook::FHookFeature(STR_TO_TYPE(#name), &name, +/// &origin_##name) + +#define BUILD_FEATURE(name) hook::HookFeature(#name, &name, &origin_##name) class XpuRuntimeApiHook : public hook::HookInstallerWrap { public: @@ -173,19 +179,30 @@ class XpuRuntimeApiHook : public hook::HookInstallerWrap { !adt::StringRef(name).contain("libcudart.so"); } - hook::FHookFeature symbols[14] = { - BUILD_FEATURE(xpu_malloc), BUILD_FEATURE(xpu_free), - BUILD_FEATURE(xpu_current_device), BUILD_FEATURE(xpu_set_device), - BUILD_FEATURE(xpu_wait), BUILD_FEATURE(xpu_memcpy), - BUILD_FEATURE(xpu_launch_async), BUILD_FEATURE(xpu_stream_create), + // need gcc version > 9.0 + // hook::FHookFeature symbols[14] = { + hook::HookFeature symbols[14] = { + BUILD_FEATURE(xpu_malloc), + BUILD_FEATURE(xpu_free), + BUILD_FEATURE(xpu_current_device), + BUILD_FEATURE(xpu_set_device), + BUILD_FEATURE(xpu_wait), + BUILD_FEATURE(xpu_memcpy), + BUILD_FEATURE(xpu_launch_async) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + }), + BUILD_FEATURE(xpu_stream_create), BUILD_FEATURE(xpu_stream_destroy), - BUILD_FEATURE(cudaMalloc), BUILD_FEATURE(cudaFree), - BUILD_FEATURE(cudaMemcpy), BUILD_FEATURE(cudaSetDevice), + BUILD_FEATURE(cudaMalloc), + BUILD_FEATURE(cudaFree), + BUILD_FEATURE(cudaMemcpy), + BUILD_FEATURE(cudaSetDevice), BUILD_FEATURE(cudaGetDevice), }; - void onSuccess() { LOG(WARN) << "install " << curSymName() << " success"; } + void onSuccess() { LOG(INFO) << "install " << curSymName() << " success"; } }; struct PatchRuntimeHook : public hook::HookInstallerWrap { diff --git a/test/cpp_test/mock_api_test.cpp b/test/cpp_test/mock_api_test.cpp index df17386..8537f14 100644 --- a/test/cpp_test/mock_api_test.cpp +++ b/test/cpp_test/mock_api_test.cpp @@ -7,6 +7,7 @@ #include "GlobalVarMgr.h" #include "cuda_mock.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h" @@ -17,4 +18,4 @@ TEST(MockAnyHook, base) { dh_any_hook_install(); int ret = mock::foo(nullptr); EXPECT_EQ(ret, 0); -} \ No newline at end of file +} diff --git a/test/cpp_test/test_elf_parser.cpp b/test/cpp_test/test_elf_parser.cpp new file mode 100644 index 0000000..cb82911 --- /dev/null +++ b/test/cpp_test/test_elf_parser.cpp @@ -0,0 +1,46 @@ +#include + +#include +#include + +#include "GlobalVarMgr.h" +#include "elf_parser.h" +#include "gtest/gtest.h" +#include "hook.h" +#include "logger/logger_stl.h" + +using namespace hook; + +namespace { +int foo(void*) { return 0; } + +void* org_foo = nullptr; + +class SymbolParserHook : public hook::HookInstallerWrap { + public: + bool targetLib(const char* name) { return adt::StringRef(name) == ""; } + hook::HookFeature symbols[1] = { + HookFeature("_ZN4mock3fooEPv", &foo, &org_foo) + .setGetNewCallback([](const hook::OriginalInfo& info) { + hook::createSymbolTable(info.libName, info.baseHeadPtr); + })}; + + void onSuccess() {} +}; + +} // namespace + +TEST(TestHookWrap, symbol_parser) { + auto sh = std::make_shared(); + sh->install(); +} + +// TEST(MockAnyHook, symbol) { +// hook::CachedSymbolTable ctb("./test/cpp_test/mock_api/libmock_api.so", +// nullptr); auto strtab = ctb.load_section_data(".strtab"); LOG(WARN) << +// "strtab size:" << strtab.size(); + +// for (auto it : ctb.getSymbolTable()) { +// LOG(WARN) << it.second; +// } +// } diff --git a/test/cpp_test/test_hook.cpp b/test/cpp_test/test_hook.cpp index f4ab0a1..6a001f3 100644 --- a/test/cpp_test/test_hook.cpp +++ b/test/cpp_test/test_hook.cpp @@ -4,6 +4,7 @@ #include #include "GlobalVarMgr.h" +#include "elf_parser.h" #include "gtest/gtest.h" #include "hook.h" #include "logger/logger_stl.h"