Skip to content
This repository has been archived by the owner on Sep 12, 2024. It is now read-only.

Commit

Permalink
add some comment (#196)
Browse files Browse the repository at this point in the history
* refine logger

* add some comment
  • Loading branch information
lipracer authored Jul 19, 2024
1 parent 0aa3e8a commit c09893c
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 48 deletions.
44 changes: 33 additions & 11 deletions include/elf_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,36 @@

namespace hook {

/// @brief elf file parser for find symbol
class HOOK_API CachedSymbolTable {
public:
struct StringRefIterator {
StringRefIterator(const char *str);
// CStrIterator: A class to iterate over C-style strings
struct CStrIterator {
CStrIterator(const char *str);

StringRefIterator &operator++() &;
StringRefIterator operator++(int) &;
// Pre-increment operator
CStrIterator &operator++() &;

// Post-increment operator
CStrIterator operator++(int) &;

// Dereference operator to get a StringRef
adt::StringRef operator*();

bool operator==(const StringRefIterator &other) const;
bool operator!=(const StringRefIterator &other) const;
// Equality comparison operator
bool operator==(const CStrIterator &other) const;

// Inequality comparison operator
bool operator!=(const CStrIterator &other) const;

// Return the raw pointer to the string
const void *data() const;

private:
const char *str_;
};

/// @brief simple buffer manager
/// @tparam T
template <typename T>
struct OwnerBuf {
OwnerBuf() = default;
Expand Down Expand Up @@ -93,20 +105,23 @@ class HOOK_API CachedSymbolTable {
size_t size_ = 0;
};

StringRefIterator strtab_begin(const char *str) const;
CStrIterator strtab_begin(const char *str) const;

StringRefIterator strtab_end(const char *str) const;
CStrIterator strtab_end(const char *str) const;

std::tuple<StringRefIterator, StringRefIterator> strtab_range(
const char *str, size_t size) const {
std::tuple<CStrIterator, CStrIterator> strtab_range(const char *str,
size_t size) const {
return std::make_tuple(strtab_begin(str), strtab_begin(str + size));
}

CachedSymbolTable(const std::string &name, const void *base_address,
const std::vector<std::string> &section_names = {});

/// @brief move to sections header location
void move_to_section_header();

/// @brief move target section location
/// @param index
void move_to_section_header(size_t index);

adt::StringRef getSectionName(size_t index) const;
Expand Down Expand Up @@ -148,9 +163,16 @@ class HOOK_API CachedSymbolTable {
std::vector<std::string> section_names;
};

/// @brief create symbol table
/// @param lib a elf file
/// @param address The address where the elf file is loaded at runtime
/// @return
CachedSymbolTable *createSymbolTable(const std::string &lib,
const void *address);

/// @brief get symbol table
/// @param lib a elf file
/// @return
CachedSymbolTable *getSymbolTable(const std::string &lib);

} // namespace hook
2 changes: 2 additions & 0 deletions include/hook.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ void install_hook();
struct OriginalInfo {
const char* libName = nullptr;
const void* basePtr = nullptr;
/// @brief elf header load pointer at runtime, symbol address = baseHeadPtr
/// + offset(in elf)
const void* baseHeadPtr = nullptr;
void* relaPtr = nullptr;
void* oldFuncPtr = nullptr;
Expand Down
13 changes: 13 additions & 0 deletions include/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@

namespace hook {

/**
* @brief Converts a size in bytes to a human-readable string with appropriate
* units.
*
* This function takes a size in bytes and converts it to a more readable format
* with units such as KB, MB, GB, etc. The result is a string that is easier to
* understand for humans, providing the size in the largest possible unit
* without losing precision.
*
* @param bytes The size in bytes to be converted.
* @return A string representing the size in a human-readable format with
* appropriate units.
*/
std::string prettyFormatSize(size_t bytes);

} // namespace hook
14 changes: 11 additions & 3 deletions include/xpu_mock.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
#pragma once
#include <string>
#include <vector>
#include "logger/StringRef.h"

#include "logger/StringRef.h"

/// @brief replace xpu_set_device or cuda_set_device
void dh_patch_runtime();


/// @brief replace more runtime api to profile api performance and arguments
void __runtimeapi_hook_initialize();

void __print_hook_initialize(std::vector<adt::StringRef> &target_libs, std::vector<adt::StringRef> &target_symbols);
/// @brief config which library and which printf like symbols will be replace
/// @param target_libs
/// @param target_symbols
void __print_hook_initialize(const std::vector<adt::StringRef> &target_libs,
const std::vector<adt::StringRef> &target_symbols);

/// @brief start capture printf output
void __print_hook_start_capture();

/// @brief end capture printf output and return all of printf's output
/// @return
std::string __print_hook_end_capture();
52 changes: 25 additions & 27 deletions lib/elf_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,37 @@

namespace hook {

CachedSymbolTable::StringRefIterator::StringRefIterator(const char *str)
: str_(str) {}
CachedSymbolTable::CStrIterator::CStrIterator(const char *str) : str_(str) {}

CachedSymbolTable::StringRefIterator &
CachedSymbolTable::StringRefIterator::operator++() & {
CachedSymbolTable::CStrIterator &CachedSymbolTable::CStrIterator::operator++()
& {
size_t len = strlen(str_);
++len;
str_ += len;
return *this;
}

CachedSymbolTable::StringRefIterator
CachedSymbolTable::StringRefIterator::operator++(int) & {
auto ret = StringRefIterator(str_);
CachedSymbolTable::CStrIterator CachedSymbolTable::CStrIterator::operator++(
int) & {
auto ret = CStrIterator(str_);
++*this;
return ret;
}

adt::StringRef CachedSymbolTable::StringRefIterator::operator*() {
adt::StringRef CachedSymbolTable::CStrIterator::operator*() {
return adt::StringRef(str_);
}

bool CachedSymbolTable::StringRefIterator::operator==(
const StringRefIterator &other) const {
bool CachedSymbolTable::CStrIterator::operator==(
const CStrIterator &other) const {
return str_ == other.str_;
}
bool CachedSymbolTable::StringRefIterator::operator!=(
const StringRefIterator &other) const {
bool CachedSymbolTable::CStrIterator::operator!=(
const CStrIterator &other) const {
return !(*this == other);
}

const void *CachedSymbolTable::StringRefIterator::data() const { return str_; }
const void *CachedSymbolTable::CStrIterator::data() const { return str_; }

CachedSymbolTable::CachedSymbolTable(
const std::string &name, const void *base_address,
Expand All @@ -45,34 +44,33 @@ CachedSymbolTable::CachedSymbolTable(
base_address(base_address),
section_names(section_names) {
CHECK(ifs.is_open(), "can't open file:{}", name);
MLOG(DEBUG, INFO) << name << " base address:" << base_address;
MLOG(TRACE, INFO) << name << " base address:" << base_address;
ifs.read(reinterpret_cast<char *>(&elf_header), sizeof(elf_header));
parse_named_section();
parse_section_header();
load_symbol_table();
for (size_t i = 0; i < sections_.size(); ++i) {
MLOG(DEBUG, INFO) << "found section:" << i
MLOG(TRACE, INFO) << "found section:" << i
<< " name:" << getSectionName(i)
<< " size:" << prettyFormatSize(sections_[i].sh_size);
}
}

CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_begin(
CachedSymbolTable::CStrIterator CachedSymbolTable::strtab_begin(
const char *str) const {
return CachedSymbolTable::StringRefIterator(str);
return CachedSymbolTable::CStrIterator(str);
}

CachedSymbolTable::StringRefIterator CachedSymbolTable::strtab_end(
CachedSymbolTable::CStrIterator CachedSymbolTable::strtab_end(
const char *str) const {
return CachedSymbolTable::StringRefIterator(str);
return CachedSymbolTable::CStrIterator(str);
}

void CachedSymbolTable::move_to_section_header() {
ifs.seekg(elf_header.e_shoff, std::ios::beg);
}

void CachedSymbolTable::move_to_section_header(size_t index) {
MLOG(DEBUG, INFO) << "move to:" << index;
move_to_section_header();
size_t shstr_h_oft = sizeof(ElfW(Shdr)) * index;
ifs.seekg(shstr_h_oft, std::ios::cur);
Expand Down Expand Up @@ -101,7 +99,7 @@ void CachedSymbolTable::load_symbol_table() {
#endif
for (size_t i = 0; i < symbol_tb.size(); ++i) {
if (strtab_buf.size() <= symbol_tb[i].st_name) {
MLOG(DEBUG, INFO)
MLOG(TRACE, INFO)
<< "symbol_tb[" << i << "].st_name(" << symbol_tb[i].st_name
<< ") over buf size:" << strtab_buf.size();
continue;
Expand All @@ -119,7 +117,7 @@ void CachedSymbolTable::load_symbol_table() {

for (size_t i = 0; i < xpu_symbol_tb.size(); ++i) {
if (strtab_buf.size() <= xpu_symbol_tb[i].st_name) {
MLOG(DEBUG, INFO) << "xpu_symbol_tb[" << i << "].st_name("
MLOG(TRACE, INFO) << "xpu_symbol_tb[" << i << "].st_name("
<< xpu_symbol_tb[i].st_name
<< ") over buf size:" << strtab_buf.size();
continue;
Expand All @@ -129,15 +127,15 @@ void CachedSymbolTable::load_symbol_table() {
adt::StringRef(&strtab_buf[xpu_symbol_tb[i].st_name]).str());
}

MLOG(DEBUG, INFO) << libName << "\naddress range:" << min_address_ << "~"
MLOG(TRACE, INFO) << libName << "\naddress range:" << min_address_ << "~"
<< max_address_;
}

void CachedSymbolTable::parse_section_header() {
CHECK_EQ(sizeof(ElfW(Shdr)), elf_header.e_shentsize);
move_to_section_header();
sections_.resize(elf_header.e_shnum);
MLOG(DEBUG, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum;
MLOG(TRACE, INFO) << "elf_header.e_shnum:" << elf_header.e_shnum;
ifs.read(reinterpret_cast<char *>(sections_.data()),
sections_.size() * sizeof(sections_[0]));
}
Expand Down Expand Up @@ -171,10 +169,10 @@ const std::string &CachedSymbolTable::lookUpSymbol(const void *func) const {
static std::string empty("");
auto offset = reinterpret_cast<const char *>(func) -
reinterpret_cast<const char *>(base_address);
MLOG(DEBUG, INFO) << "lookup address:" << offset;
MLOG(TRACE, INFO) << "lookup address:" << offset;
auto iter = symbol_table.find(offset);
if (iter == symbol_table.end()) {
MLOG(DEBUG, INFO) << libName
MLOG(TRACE, INFO) << libName
<< "\nnot find launch_async symbol offset:" << offset
<< " base address:" << base_address
<< " func address:" << func << " range("
Expand All @@ -197,4 +195,4 @@ CachedSymbolTable *getSymbolTable(const std::string &lib) {
return table[lib].get();
}

} // namespace hook
} // namespace hook
5 changes: 3 additions & 2 deletions lib/hooks/print_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ std::string XpuRuntimePrintfHook::end_capture() {
return str;
}

void __print_hook_initialize(std::vector<adt::StringRef>& target_libs,
std::vector<adt::StringRef>& target_symbols) {
void __print_hook_initialize(
const std::vector<adt::StringRef>& target_libs,
const std::vector<adt::StringRef>& target_symbols) {
XpuRuntimePrintfHook::instance()->setTargetLibs(target_libs);
XpuRuntimePrintfHook::instance()->setTargetSymbols(target_symbols);
XpuRuntimePrintfHook::instance()->install(); // replace plt table
Expand Down
2 changes: 1 addition & 1 deletion src/cuda_mock/cuda_mock_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def dump_to_cache(self):

gProfileDataCollection = ProfileDataCollection("gpu" if is_nvidia_gpu else "xpu")
gDefaultTargetLib = ["libxpucuda.so", "libcuda.so"]
gDefaultTargetSymbols = ["__printf_chk", "printf","fprintf","__fprintf","vfprintf",]
gDefaultTargetSymbols = ["__printf_chk", "printf", "fprintf", "__fprintf", "vfprintf",]
class __XpuRuntimeProfiler:
def __init__(self, target_libs = gDefaultTargetLib, target_symbols = gDefaultTargetSymbols):
print_hook_initialize(target_libs=target_libs, target_symbols=target_symbols)
Expand Down
9 changes: 5 additions & 4 deletions src/cuda_mock_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@
// m.def("xpu_initialize", []() { xpu_dh_initialize(); });
// }

std::vector<adt::StringRef> convert_arg_list_of_str(const char** list_of_str) {
std::vector<adt::StringRef> convert_arg_list_of_str(
const HookString_t* list_of_str) {
std::vector<adt::StringRef> result;
if (list_of_str == nullptr) {
MLOG(PYTHON, ERROR) << "impossible convert_arg_list_of_str";
return result;
}
for (const char** str = list_of_str; *str != nullptr; ++str) {
for (const HookString_t* str = list_of_str; *str != nullptr; ++str) {
result.emplace_back(adt::StringRef(*str));
MLOG(PYTHON, INFO) << "convert_arg_list_of_str convert "
<< result.back() << "to cpp object";
Expand Down Expand Up @@ -91,8 +92,8 @@ HOOK_API void xpu_initialize() { // hooker = "profile"
for print_hook
*/

HOOK_API void print_hook_initialize(const char** target_libs,
const char** target_symbols) {
HOOK_API void print_hook_initialize(const HookString_t* target_libs,
const HookString_t* target_symbols) {
std::vector<adt::StringRef> cpp_target_libs =
convert_arg_list_of_str(target_libs);
std::vector<adt::StringRef> cpp_target_symbols =
Expand Down

0 comments on commit c09893c

Please sign in to comment.