Skip to content

Commit

Permalink
feat: dds subscribe more
Browse files Browse the repository at this point in the history
  • Loading branch information
shuai132 committed Nov 11, 2024
1 parent 3128a0e commit ed79324
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 32 deletions.
95 changes: 70 additions & 25 deletions include/asio_net/dds.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ namespace asio_net {
namespace dds {

using rpc_s = std::shared_ptr<rpc_core::rpc>;
using rpc_w = std::weak_ptr<rpc_core::rpc>;
using handle_t = std::function<void(std::string msg)>;
using handle_s = std::shared_ptr<handle_t>;

struct Msg {
std::string topic;
Expand Down Expand Up @@ -41,17 +44,19 @@ class dds_server {
rpc->subscribe("update_topic_list", [this, rpc](const std::vector<std::string>& topic_list) {
update_topic_list(rpc, topic_list);
});
rpc->subscribe("publish", [this](const dds::Msg& msg) {
publish(msg);
rpc->subscribe("publish", [this, rpc_wp = dds::rpc_w(rpc)](const dds::Msg& msg) {
publish(msg, rpc_wp);
});
};
}

void publish(const dds::Msg& msg) {
void publish(const dds::Msg& msg, const dds::rpc_w& from_rpc) {
auto it = topic_rpc_map.find(msg.topic);
if (it != topic_rpc_map.cend()) {
auto from_rpc_sp = from_rpc.lock();
for (const auto& rpc : it->second) {
rpc->cmd("publish")->msg(msg)->call();
if (rpc == from_rpc_sp) continue;
rpc->cmd("publish")->msg(msg)->retry(-1)->call();
}
}
}
Expand Down Expand Up @@ -89,35 +94,65 @@ class dds_client {
public:
explicit dds_client(asio::io_context& io_context) : client(io_context, rpc_config{.rpc = rpc}) {
client.on_open = [&](const std::shared_ptr<rpc_core::rpc>&) {
rpc->cmd("update_topic_list")->msg(topic_list)->retry(-1)->call();
rpc->subscribe("publish", [this](const dds::Msg& msg) {
dispatch_publish(msg);
});
update_topic_list();
};
}

void publish(std::string topic, std::string data) {
void publish(std::string topic, std::string data = "") {
auto msg = dds::Msg{.topic = std::move(topic), .data = std::move(data)};
dispatch_publish(msg);
rpc->cmd("publish")->msg(std::move(msg))->call();
}

void subscribe(std::string topic, std::function<void(std::string msg)> handle) {
auto it = topic_handle_map.find(topic);
if (it == topic_handle_map.cend()) {
topic_handle_map[topic] = std::move(handle);
topic_list.push_back(std::move(topic));
rpc->cmd("update_topic_list")->msg(topic_list)->retry(-1)->call();
uintptr_t subscribe(const std::string& topic, dds::handle_t handle) {
auto it = topic_handles_map.find(topic);
auto handle_sp = std::make_shared<dds::handle_t>(std::move(handle));
auto handle_id = (uintptr_t)handle_sp.get();
if (it == topic_handles_map.cend()) {
topic_handles_map[topic].push_back(std::move(handle_sp));
update_topic_list();
} else {
it->second = std::move(handle);
it->second.push_back(std::move(handle_sp));
}
return handle_id;
}

void unsubscribe(const std::string& topic) {
auto it = topic_handle_map.find(topic);
if (it != topic_handle_map.cend()) {
topic_handle_map.erase(it);
topic_list.erase(std::remove(topic_list.begin(), topic_list.end(), topic), topic_list.end());
rpc->cmd("update_topic_list")->msg(topic_list)->retry(-1)->call();
bool unsubscribe(const std::string& topic) {
auto it = topic_handles_map.find(topic);
if (it != topic_handles_map.cend()) {
topic_handles_map.erase(it);
update_topic_list();
return true;
} else {
return false;
}
}

bool unsubscribe(uintptr_t handle_id) {
auto it = std::find_if(topic_handles_map.begin(), topic_handles_map.end(), [id = handle_id](auto& p) {
auto& vec = p.second;
auto len_before = vec.size();
vec.erase(std::remove_if(vec.begin(), vec.end(),
[id](auto& sp) {
return (uintptr_t)sp.get() == id;
}),
vec.end());
auto len_after = vec.size();
return len_before != len_after;
});
if (it != topic_handles_map.end()) {
ASIO_NET_LOGD("unsubscribe: id: %zu", handle_id);
if (it->second.empty()) {
topic_handles_map.erase(it);
update_topic_list();
}
return true;
} else {
ASIO_NET_LOGD("unsubscribe: no such id: %zu", handle_id);
return false;
}
}

Expand All @@ -132,18 +167,28 @@ class dds_client {

private:
void dispatch_publish(const dds::Msg& msg) {
auto it = topic_handle_map.find(msg.topic);
if (it != topic_handle_map.cend()) {
auto& handle = it->second;
handle(msg.data);
auto it = topic_handles_map.find(msg.topic);
if (it != topic_handles_map.cend()) {
auto& handles = it->second;
for (const auto& handle : handles) {
(*handle)(msg.data);
}
}
}

void update_topic_list() {
std::vector<std::string> topic_list;
topic_list.reserve(topic_handles_map.size());
for (const auto& kv : topic_handles_map) {
topic_list.push_back(kv.first);
}
rpc->cmd("update_topic_list")->msg(topic_list)->retry(-1)->call();
}

private:
std::shared_ptr<rpc_core::rpc> rpc = rpc_core::rpc::create();
rpc_client client;
std::vector<std::string> topic_list;
std::unordered_map<std::string, std::function<void(std::string msg)>> topic_handle_map;
std::unordered_map<std::string, std::vector<std::shared_ptr<dds::handle_t>>> topic_handles_map;
};

} // namespace asio_net
68 changes: 61 additions & 7 deletions test/dds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
#include <cstdlib>
#include <thread>

#include "assert_def.h"
#include "log.h"

using namespace asio_net;

const uint16_t PORT = 6666;

static std::atomic_bool received_flag[3]{};
static std::atomic_int received_all_cnt{0};

static void init_server() {
std::thread([] {
asio::io_context context;
Expand All @@ -23,19 +27,72 @@ static void init_client() {
std::thread([i] {
asio::io_context context;
dds_client client(context);
client.subscribe("topic_all", [=](const std::string& data) {
client.subscribe("topic_all", [i](const std::string& data) {
LOG("client_%d: topic:%s, data:%s", i, "topic_all", data.c_str());
++received_all_cnt;
});
std::string topic_tmp = "topic_" + std::to_string(i);
client.subscribe(topic_tmp, [=](const std::string& data) {
client.subscribe(topic_tmp, [i, topic_tmp](const std::string& data) {
LOG("client_%d: topic:%s, data:%s", i, topic_tmp.c_str(), data.c_str());
ASSERT(data == "to client_" + std::to_string(i));
ASSERT(!received_flag[i]);
received_flag[i] = true;
});
client.open("localhost", PORT);
client.run();
}).detach();
}
}

static const int interval_ms = 1000;
static void interval_check(dds_client& client) {
/// 1. test basic
static bool first_run = true;
static std::atomic_bool received_flag_self{false};
if (first_run) {
first_run = false;
client.subscribe("topic_self", [](const std::string& msg) {
LOG("received: topic_self: %s", msg.c_str());
ASSERT(!received_flag_self);
received_flag_self = true;
});
} else {
// check and reset flag
for (auto& flag : received_flag) {
ASSERT(flag);
flag = false;
}

ASSERT(received_all_cnt == 3);
received_all_cnt = 0;

ASSERT(received_flag_self);
received_flag_self = false;
}

client.publish("topic_self", "to client_self");
client.publish("topic_0", "to client_0");
client.publish("topic_1", "to client_1");
client.publish("topic_2", "to client_2");
client.publish("topic_all", "hello");

/// 2.1 test unsubscribe
client.subscribe("topic_test_0", [](const std::string& msg) {
(void)(msg);
ASSERT(false);
});
client.unsubscribe("topic_test_0");
client.publish("topic_test_0");

/// 2.2
auto id = client.subscribe("topic_test_1", [](const std::string& msg) {
(void)(msg);
ASSERT(false);
});
client.unsubscribe(id);
client.publish("topic_test_1");
}

int main() {
init_server();
init_client();
Expand All @@ -47,13 +104,10 @@ int main() {
std::function<void()> time_task;
asio::steady_timer timer(context);
time_task = [&] {
timer.expires_after(std::chrono::seconds(2));
timer.expires_after(std::chrono::milliseconds(interval_ms));
timer.async_wait([&](std::error_code ec) {
(void)ec;
client.publish("topic_all", "hello");
client.publish("topic_0", "to client_0");
client.publish("topic_1", "to client_1");
client.publish("topic_2", "to client_2");
interval_check(client);
time_task();
});
};
Expand Down

0 comments on commit ed79324

Please sign in to comment.