Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
add elf parser
Browse files Browse the repository at this point in the history
  • Loading branch information
lipracer committed Jul 16, 2024
1 parent 7ae2400 commit f29fc15
Show file tree
Hide file tree
Showing 12 changed files with 352 additions and 18 deletions.
76 changes: 76 additions & 0 deletions include/elf_parser.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include <elf.h>
#include <link.h>

#include <fstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "logger/logger.h"
#include "macro.h"

namespace hook {

class HOOK_API CachedSymbolTable {
public:
struct StringRefIterator {
StringRefIterator(const char *str);

StringRefIterator &operator++() &;
StringRefIterator operator++(int) &;

adt::StringRef operator*();

bool operator==(const StringRefIterator &other) const;
bool operator!=(const StringRefIterator &other) const;
const void *data() const;

private:
const char *str_;
};

StringRefIterator strtab_begin(const char *str) const;

StringRefIterator strtab_end(const char *str) const;

CachedSymbolTable(const std::string &name, const void *base_address);

void move_to_section_header();

void move_to_section_header(size_t index);

adt::StringRef getSectionName(size_t index) const;
size_t find_section(adt::StringRef name) const;

void load_symbol_table();
void parse_section();
void parse_named_section();

const std::string &lookUpSymbol(const void *func) const;

const std::unordered_map<size_t, std::string> &getSymbolTable() const {
return symbol_table;
}

size_t min_addrtess() const { return min_address_; }
size_t max_addrtess() const { return max_address_; }

private:
std::string libName;
std::ifstream ifs;
ElfW(Ehdr) elf_header;
std::vector<char> section_header_str;
std::vector<ElfW(Shdr)> sections;
std::vector<std::string> strtab;
std::unordered_map<size_t, std::string> symbol_table;
const void *base_address;
size_t min_address_ = -1;
size_t max_address_ = 0;
};

// Builds a CachedSymbolTable for the file `lib` with base address `address`
// and offers it to a cache keyed by `lib`. Returns the (non-owning) pointer
// held by the cache entry for `lib`.
CachedSymbolTable *createSymbolTable(const std::string &lib,
const void *address);

// Returns the (non-owning) pointer cached for `lib`; nullptr when no table
// has been created for that key.
CachedSymbolTable *getSymbolTable(const std::string &lib);

} // namespace hook
4 changes: 3 additions & 1 deletion include/env_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ namespace hook {
// if we use the function template we need implement all of the str2value_impl
// functions to support more type, but sometimes we need implement the
// str2value_impl near the type define

// TODO: replace C style const char* and length with StringRef
template <typename T>
struct str2value_impl {
void operator()(T& value, const char* str, size_t len = std::string::npos,
Expand Down Expand Up @@ -57,7 +59,7 @@ struct str2value_impl<std::pair<K, V>> {
for (; i < len && str[i] != '\0'; ++i) {
if (str[i] == '=') {
pair.first = str2value<K>()(str, i);
pair.second = str2value<V>()(str + i + 1);
pair.second = str2value<V>()(str + i + 1, len - i - 1);
break;
}
}
Expand Down
12 changes: 12 additions & 0 deletions include/hook.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void install_hook();
struct OriginalInfo {
const char* libName = nullptr;
const void* basePtr = nullptr;
const void* baseHeadPtr = nullptr;
void* relaPtr = nullptr;
void* oldFuncPtr = nullptr;
void** pltTablePtr = nullptr;
Expand All @@ -40,6 +41,7 @@ struct OriginalInfo {
relaPtr = info.relaPtr;
oldFuncPtr = info.oldFuncPtr;
pltTablePtr = info.pltTablePtr;
baseHeadPtr = info.baseHeadPtr;
return *this;
}
};
Expand Down Expand Up @@ -276,6 +278,7 @@ struct HookFeatureBase {
void* newFunc;
void** oldFunc;
std::function<bool(void)> filter_;
std::function<void(const OriginalInfo&)> getNewCallback_;
};

using WrapFuncGenerator = std::function<void*(const char*, const char*, void*)>;
Expand Down Expand Up @@ -340,6 +343,12 @@ struct __HookFeature : public HookFeatureBase {
return newFuncGenerator(libName, symName.c_str(), newFunc);
}

// Stores `getNewCallback` in getNewCallback_ and returns this feature so
// that setter calls can be chained.
__HookFeature& setGetNewCallback(
    const std::function<void(const OriginalInfo&)>& getNewCallback) {
  this->getNewCallback_ = getNewCallback;
  return *this;
}

std::function<void*(size_t)> findUniqueFunc;
WrapFuncGenerator newFuncGenerator;
};
Expand Down Expand Up @@ -372,6 +381,9 @@ struct MemberDetector<DerivedT,
// TODO: if std::get<2>(*iter) is a pointer and it's point to
// std::get<1>(*iter) then there will return nullptr
*iter->oldFunc = info.oldFuncPtr;
if (iter->getNewCallback_) {
iter->getNewCallback_(info);
}
return iter->getNewFunc(info.libName);
}
};
Expand Down
7 changes: 5 additions & 2 deletions include/logger/StringRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ class StringRef {

StringRef(const char* str) : size_(strlen(str)), str_(str) {}

StringRef(const char* str, size_t size) : size_(size), str_(str) {}
// Views `size` bytes starting at `str`. Passing std::string::npos means
// "up to the terminating NUL", in which case the length is measured here.
// The length is computed from the parameter `str`, not the member `str_`:
// members are initialised in declaration order, so `str_` may not be set
// yet when `size_` is initialised.
StringRef(const char* str, size_t size)
    : size_(size != std::string::npos ? size : strlen(str)), str_(str) {}

StringRef(const_iterator begin, const_iterator end)
: size_(std::distance(begin, end)), str_(begin) {}
Expand All @@ -46,7 +47,9 @@ class StringRef {

size_t size() const { return size_; }

std::string str() const { return std::string(str_, str_ + size_); }
// Owning copy of the viewed bytes; an empty view yields an empty string
// without touching the pointer.
std::string str() const {
  if (size_ == 0) {
    return std::string();
  }
  return std::string(str_, str_ + size_);
}

const char* data() const { return str_; }

Expand Down
7 changes: 5 additions & 2 deletions include/logger/logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ enum class LogModule { profile, trace, hook, python, memory, debug, last };
class LogModuleHelper {
public:
static auto& enum_strs() {
static std::array<const char*, 6> strs = {"PROFILE", "TRACE", "HOOK",
"PYTHON", "MEMORY", "LAST"};
static std::array<const char*, 7> strs = {
"PROFILE", "TRACE", "HOOK", "PYTHON", "MEMORY", "DEBUG", "LAST"};
static_assert(sizeof(strs) / sizeof(const char*) ==
static_cast<size_t>(LogModule::last) + 1);
return strs;
}

static auto begin() { return enum_strs().begin(); }
static auto end() { return enum_strs().end(); }

Expand Down
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(cuda_mock STATIC
env_mgr.cpp
backtrace.cpp
statistic.cpp
elf_parser.cpp
hook_context.cpp
hook.cpp
cuda_op_tracer.cpp
Expand Down
182 changes: 182 additions & 0 deletions lib/elf_parser.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include "elf_parser.h"
namespace hook {

CachedSymbolTable::StringRefIterator::StringRefIterator(const char *str)
: str_(str) {}

// Pre-increment: step over the current string and its terminating NUL so the
// cursor lands on the first byte of the next entry.
CachedSymbolTable::StringRefIterator &
CachedSymbolTable::StringRefIterator::operator++() & {
  str_ += strlen(str_) + 1;
  return *this;
}

// Post-increment: advance the cursor and hand back a cursor at the position
// it had before advancing.
CachedSymbolTable::StringRefIterator
CachedSymbolTable::StringRefIterator::operator++(int) & {
  StringRefIterator previous(str_);
  operator++();
  return previous;
}

// Dereference: a view of the NUL-terminated string at the cursor.
adt::StringRef CachedSymbolTable::StringRefIterator::operator*() {
  adt::StringRef current(str_);
  return current;
}

// Two cursors are equal when they point at the same byte (pointer identity,
// not string content).
bool CachedSymbolTable::StringRefIterator::operator==(
    const StringRefIterator &other) const {
  const bool same_position = (str_ == other.str_);
  return same_position;
}
// Negation of operator==.
bool CachedSymbolTable::StringRefIterator::operator!=(
    const StringRefIterator &other) const {
  return !operator==(other);
}

// Raw pointer to the byte the cursor currently points at.
const void *CachedSymbolTable::StringRefIterator::data() const {
  return str_;
}

// Opens the ELF file `name`, reads its header and then parses, in order, the
// section-name string table, the section headers and the symbol table.
// `base_address` is stored for later offset computation in lookUpSymbol().
//
// If the file cannot be opened or the header cannot be read in full, a
// warning is logged and the object is left with empty tables instead of
// parsing indeterminate header contents.
CachedSymbolTable::CachedSymbolTable(const std::string &name,
                                     const void *base_address)
    : libName(name),
      // ELF is a binary format; never let the stream translate bytes.
      ifs(name, std::ios::in | std::ios::binary),
      base_address(base_address) {
  MLOG(DEBUG, INFO) << name << " base address:" << base_address;
  if (!ifs.is_open()) {
    LOG(WARN) << "can't open elf file:" << name;
    return;
  }
  ifs.read(reinterpret_cast<char *>(&elf_header), sizeof(elf_header));
  if (!ifs) {
    LOG(WARN) << "can't read elf header:" << name;
    return;
  }
  parse_named_section();
  parse_section();
  load_symbol_table();
  for (size_t i = 0; i < sections.size(); ++i) {
    MLOG(DEBUG, INFO) << "found section:" << i
                      << " name:" << getSectionName(i);
  }
}

// Wraps `str` (first byte of a string table) in a cursor.
CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_begin(
    const char *str) const {
  StringRefIterator first(str);
  return first;
}

// Wraps `str` (one past the last byte of a string table) in a cursor that
// serves as the end sentinel.
CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_end(
    const char *str) const {
  StringRefIterator last(str);
  return last;
}

// Seeks the file to the start of the section header table, whose absolute
// offset is recorded in the ELF header as e_shoff.
void CachedSymbolTable::move_to_section_header() {
  const auto table_offset = elf_header.e_shoff;
  ifs.seekg(table_offset, std::ios::beg);
}

// Seeks the file to section header number `index`: first to the start of the
// section header table, then forward by `index` fixed-size entries.
void CachedSymbolTable::move_to_section_header(size_t index) {
  MLOG(DEBUG, INFO) << "move to:" << index;
  move_to_section_header();
  const size_t entry_offset = index * sizeof(ElfW(Shdr));
  ifs.seekg(entry_offset, std::ios::cur);
}

// Name of section `index`: the section header's sh_name is a byte offset into
// the section-name string table. Both lookups are bounds-checked via at().
adt::StringRef CachedSymbolTable::getSectionName(size_t index) const {
  const auto &header = sections.at(index);
  const char *name_start = &section_header_str.at(header.sh_name);
  return adt::StringRef(name_start);
}

// Linear search for the first section called `name`. Returns its index, or
// sections.size() when no section matches.
size_t CachedSymbolTable::find_section(adt::StringRef name) const {
  const size_t section_count = sections.size();
  for (size_t i = 0; i < section_count; ++i) {
    if (getSectionName(i) == name) {
      return i;
    }
  }
  return section_count;
}

// Reads `.symtab` and `.strtab` from the file and fills `symbol_table`
// (symbol value -> name) together with the min/max symbol value.
//
// Each symbol's st_name is a byte offset into the `.strtab` section, so names
// are resolved directly against the raw string-table bytes. Symbols whose
// offset lies outside the table are reported and skipped.
void CachedSymbolTable::load_symbol_table() {
  const size_t symtab_h_index = find_section(".symtab");
  if (symtab_h_index >= sections.size()) {
    LOG(WARN) << "can't found secton: .symtab";
    return;
  }
  const ElfW(Shdr) &symtab_h = sections[symtab_h_index];
  std::vector<ElfW(Sym)> symbol_tb(symtab_h.sh_size / sizeof(ElfW(Sym)));
  ifs.seekg(symtab_h.sh_offset, std::ios::beg);
  // Read whole entries only, so a section size that is not a multiple of the
  // entry size cannot write past the end of the buffer.
  ifs.read(reinterpret_cast<char *>(symbol_tb.data()),
           symbol_tb.size() * sizeof(ElfW(Sym)));

  const size_t strtab_h_index = find_section(".strtab");
  if (strtab_h_index >= sections.size()) {
    LOG(WARN) << "can't found secton: .strtab";
    return;
  }
  const ElfW(Shdr) &strtab_h = sections[strtab_h_index];
  std::vector<char> buf(strtab_h.sh_size);
  ifs.seekg(strtab_h.sh_offset, std::ios::beg);
  ifs.read(buf.data(), buf.size());
  if (!ifs) {
    LOG(WARN) << libName << "\ncan't read .symtab/.strtab content";
    return;
  }
  // Guarantee termination so that building a name from any in-range offset
  // cannot run off the end of the buffer.
  if (buf.empty() || buf.back() != '\0') {
    buf.push_back('\0');
  }

  for (size_t i = 0; i < symbol_tb.size(); ++i) {
    const ElfW(Sym) &sym = symbol_tb[i];
    if (buf.size() <= sym.st_name) {
      LOG(WARN) << "symbol_tb[" << i << "].st_name(" << sym.st_name
                << ") over strtab size:" << buf.size();
      continue;
    }
    symbol_table.emplace(sym.st_value, buf.data() + sym.st_name);
    if (sym.st_value > max_address_) {
      max_address_ = sym.st_value;
    }
    if (sym.st_value < min_address_) {
      min_address_ = sym.st_value;
    }
  }
  MLOG(DEBUG, INFO) << libName << "\naddress range:" << min_address_ << "~"
                    << max_address_;
}
// Loads all e_shnum section headers into `sections`. The on-disk entry size
// recorded in the ELF header must match the in-memory ElfW(Shdr) layout.
void CachedSymbolTable::parse_section() {
  CHECK_EQ(sizeof(ElfW(Shdr)), elf_header.e_shentsize);
  move_to_section_header();
  sections.resize(elf_header.e_shnum);
  MLOG(DEBUG, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum;
  const size_t byte_count = sections.size() * sizeof(ElfW(Shdr));
  char *destination = reinterpret_cast<char *>(sections.data());
  ifs.read(destination, byte_count);
}
// Loads the section-name string table: reads the section header selected by
// e_shstrndx, then copies that section's bytes into `section_header_str`.
void CachedSymbolTable::parse_named_section() {
  move_to_section_header(elf_header.e_shstrndx);

  ElfW(Shdr) name_table_header;
  ifs.read(reinterpret_cast<char *>(&name_table_header),
           sizeof(name_table_header));

  // Jump to the section body and pull it in as raw bytes.
  ifs.seekg(name_table_header.sh_offset, std::ios::beg);
  section_header_str.resize(name_table_header.sh_size);
  ifs.read(section_header_str.data(), section_header_str.size());
}

// Resolves `func` to a symbol name. The key is the byte distance between
// `func` and the base address given at construction.
//
// Returns the cached name when a symbol with exactly that value exists,
// otherwise a reference to a static empty string. The returned reference is
// valid for as long as this table (or, for the empty string, the program)
// lives.
const std::string &CachedSymbolTable::lookUpSymbol(const void *func) const {
  static const std::string empty;
  const auto offset = reinterpret_cast<const char *>(func) -
                      reinterpret_cast<const char *>(base_address);
  MLOG(DEBUG, INFO) << "lookup address:" << offset;
  const auto iter = symbol_table.find(offset);
  if (iter == symbol_table.end()) {
    LOG(WARN) << libName
              << "\nnot find launch_async symbol offset:" << offset
              << " base address:" << base_address
              << " func address:" << func << " range(" << min_address_
              << "~" << max_address_;
    return empty;
  }
  LOG(WARN) << "found launch_async symbol:" << iter->second;
  return iter->second;
}

// File-local cache of symbol tables, keyed by the library string passed to
// createSymbolTable(); the map owns the tables through unique_ptr.
static std::unordered_map<std::string, std::unique_ptr<CachedSymbolTable>>
table;

// Returns the cached symbol table for `lib`, parsing the file and caching the
// result on first use. `address` is the base address handed to the new table.
//
// An existing table is returned as-is without re-parsing the file. A key that
// is present but holds no table is filled in rather than returned as nullptr.
// The returned pointer is non-owning; the cache keeps ownership.
CachedSymbolTable *createSymbolTable(const std::string &lib,
                                     const void *address) {
  const auto cached = table.find(lib);
  if (cached != table.end() && cached->second) {
    return cached->second.get();
  }
  // Construct before touching the map so a throwing constructor leaves the
  // cache unchanged.
  std::unique_ptr<CachedSymbolTable> created(
      new CachedSymbolTable(lib, address));
  CachedSymbolTable *result = created.get();
  table[lib] = std::move(created);
  return result;
}

// Returns the cached symbol table for `lib`, or nullptr when none has been
// created. Uses find() so that a miss does not insert an empty entry into the
// cache.
CachedSymbolTable *getSymbolTable(const std::string &lib) {
  const auto cached = table.find(lib);
  if (cached == table.end()) {
    return nullptr;
  }
  return cached->second.get();
}

} // namespace hook
1 change: 1 addition & 0 deletions lib/hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ int install_hooker(PltTable* pltTable, const hook::HookInstaller& installer) {
originalInfo.libName = pltTable->lib_name.c_str();
originalInfo.basePtr = pltTable->base_addr;
originalInfo.relaPtr = pltTable->rela_plt;
originalInfo.baseHeadPtr = pltTable->base_header_addr;
originalInfo.pltTablePtr = reinterpret_cast<void**>(addr);
originalInfo.oldFuncPtr =
reinterpret_cast<void*>(*reinterpret_cast<size_t*>(addr));
Expand Down
Loading

0 comments on commit f29fc15

Please sign in to comment.