Skip to content

Commit

Permalink
add python callback to ensure correct shutdown order
Browse files Browse the repository at this point in the history
Signed-off-by: Melody Ren <[email protected]>
  • Loading branch information
melody-ren committed Jan 8, 2025
1 parent 7c8a9ab commit 8d529e8
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
8 changes: 4 additions & 4 deletions libs/qec/include/cudaq/qec/plugin_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <dlfcn.h>
#include <map>
#include <memory>
#include <string>

/// @brief Enum to define different types of plugins
Expand All @@ -22,10 +23,9 @@ enum class PluginType {

/// @brief A struct to store plugin handle with its type
struct PluginHandle {
void *handle; // Pointer to the shared library handle. This is the result of
// dlopen() function.
PluginType type; // Type of the plugin (e.g., decoder, code, etc)
bool is_closed; // Flag indicating if the handle is closed
std::shared_ptr<void> handle; // Pointer to the shared library handle. This is
// the result of dlopen() function.
PluginType type; // Type of the plugin (e.g., decoder, code, etc)
};

/// @brief Function to load plugins from a directory based on type
Expand Down
39 changes: 19 additions & 20 deletions libs/qec/lib/plugin_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,36 @@ static std::map<std::string, PluginHandle> &get_plugin_handles() {

// Function to load plugins from a directory based on their type
void load_plugins(const std::string &plugin_dir, PluginType type) {
if (!fs::exists(plugin_dir)) {
std::cerr << "WARNING: Plugin directory does not exist: " << plugin_dir
<< std::endl;
return;
}
for (const auto &entry : fs::directory_iterator(plugin_dir)) {
if (entry.path().extension() == ".so") {
void *raw_handle = dlopen(entry.path().c_str(), RTLD_NOW);
if (raw_handle) {
// Custom deleter ensures dlclose is called
auto deleter = [](void *h) {
if (h)
dlclose(h);
};

void *handle = dlopen(entry.path().c_str(), RTLD_NOW);

if (!handle) {
get_plugin_handles().emplace(
entry.path().filename().string(),
PluginHandle{std::shared_ptr<void>(raw_handle, deleter), type});
} else {
std::cerr << "ERROR: Failed to load plugin: " << entry.path()
<< " Error: " << dlerror() << std::endl;
} else {
get_plugin_handles().emplace(entry.path().filename().string(),
PluginHandle{handle, type, false});
}
}
}
}

// Function to clean up the plugin handles
void cleanup_plugins(PluginType type) {
for (auto &[key, plugin] : get_plugin_handles()) {
if (plugin.type == type) {
if (plugin.handle && !plugin.is_closed) {
dlclose(plugin.handle);
plugin.is_closed = true;
} else {
std::cerr << "WARNING: Invalid or null handle for plugin: " << key
<< "\n";
}
auto &handles = get_plugin_handles();
auto it = handles.begin();
while (it != handles.end()) {
if (it->second.type == type) {
it = handles.erase(it); // dlclose is handled by the custom deleter
} else {
++it;
}
}
}
8 changes: 8 additions & 0 deletions libs/qec/python/bindings/py_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "common/Logger.h"

#include "cudaq/qec/decoder.h"
#include "cudaq/qec/plugin_loader.h"

#include "type_casters.h"
#include "utils.h"
Expand Down Expand Up @@ -70,6 +71,13 @@ std::unordered_map<std::string, std::function<py::object(
PyDecoderRegistry::registry;

void bindDecoder(py::module &mod) {
// Required by all plugin classes
auto cleanup_callback = []() {
// Change the type to the correct plugin type
cleanup_plugins(PluginType::DECODER);
};
// This ensures the correct shutdown sequence
mod.add_object("_cleanup", py::capsule(cleanup_callback));

auto qecmod = py::hasattr(mod, "qecrt")
? mod.attr("qecrt").cast<py::module_>()
Expand Down

0 comments on commit 8d529e8

Please sign in to comment.