Skip to content

Commit

Permalink
feat: support ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
shuai132 committed Sep 13, 2023
1 parent 716a1b2 commit 1c097e3
Show file tree
Hide file tree
Showing 23 changed files with 615 additions and 70 deletions.
17 changes: 17 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.5)
project(asio_net CXX)

option(ASIO_NET_ENABLE_RPC "" ON)
option(ASIO_NET_ENABLE_SSL "" OFF)

option(ASIO_NET_BUILD_TEST "" OFF)

Expand All @@ -17,6 +18,9 @@ add_compile_options(-Wall)

add_library(${PROJECT_NAME} INTERFACE)
target_include_directories(${PROJECT_NAME} INTERFACE .)
if (ASIO_NET_ENABLE_SSL)
target_compile_definitions(${PROJECT_NAME} INTERFACE -DASIO_NET_ENABLE_SSL)
endif ()

if (ASIO_NET_ENABLE_RPC)
target_include_directories(${PROJECT_NAME} INTERFACE
Expand All @@ -25,11 +29,17 @@ if (ASIO_NET_ENABLE_RPC)
)
endif ()

if (ASIO_NET_ENABLE_SSL)
find_package(OpenSSL 1.1.0 REQUIRED)
target_link_libraries(${PROJECT_NAME} INTERFACE OpenSSL::SSL)
endif ()

if (ASIO_NET_BUILD_TEST)
message(STATUS "ASIO_PATH: $ENV{ASIO_PATH}")
include_directories($ENV{ASIO_PATH})
link_libraries(${PROJECT_NAME})
link_libraries(pthread)
add_definitions(-DOPENSSL_PEM_PATH=\"${CMAKE_CURRENT_LIST_DIR}/test/ssl/\")

# for github actions/ci
if (ASIO_NET_DISABLE_ON_DATA_PRINT)
Expand All @@ -50,6 +60,10 @@ if (ASIO_NET_BUILD_TEST)
add_executable(${PROJECT_NAME}_test_server_discovery test/server_discovery.cpp)
add_executable(${PROJECT_NAME}_test_domain_tcp test/domain_tcp.cpp)
add_executable(${PROJECT_NAME}_test_domain_udp test/domain_udp.cpp)
if (ASIO_NET_ENABLE_SSL)
add_executable(${PROJECT_NAME}_test_tcp_ssl_c test/tcp_ssl_c.cpp)
add_executable(${PROJECT_NAME}_test_tcp_ssl_s test/tcp_ssl_s.cpp)
endif ()

if (ASIO_NET_ENABLE_RPC)
add_compile_definitions(RPC_CORE_LOG_SHOW_DEBUG)
Expand All @@ -60,5 +74,8 @@ if (ASIO_NET_BUILD_TEST)
add_executable(${PROJECT_NAME}_test_rpc_s test/rpc_s.cpp)
add_executable(${PROJECT_NAME}_test_rpc_c test/rpc_c.cpp)
add_executable(${PROJECT_NAME}_test_domain_rpc test/domain_rpc.cpp)
if (ASIO_NET_ENABLE_SSL)
add_executable(${PROJECT_NAME}_test_rpc_ssl test/rpc_ssl.cpp)
endif ()
endif ()
endif ()
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ and [rpc_core](https://github.com/shuai132/rpc_core)
* TCP/UDP support, rely on: [asio](http://think-async.com/Asio/)
* RPC support, rely on: [rpc_core](https://github.com/shuai132/rpc_core)
* Service discovery based on UDP multicast
* Domain socket and rpc support, compatible with IPv6
* Support IPv6 and SSL (with OpenSSL)
* Domain socket and rpc support
* Comprehensive unittests
* Automatic reconnection

Expand Down
49 changes: 49 additions & 0 deletions detail/socket_type.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "asio.hpp"
#include "log.h"
#include "noncopyable.hpp"

#ifdef ASIO_NET_ENABLE_SSL
#include "asio/ssl.hpp"
#endif

namespace asio_net {
namespace detail {

enum class socket_type {
normal,
domain,
ssl,
};

template <socket_type T>
struct socket_impl;

template <>
struct socket_impl<socket_type::normal> {
using socket = asio::ip::tcp::socket;
using endpoint = asio::ip::tcp::endpoint;
using resolver = asio::ip::tcp::resolver;
using acceptor = asio::ip::tcp::acceptor;
};

template <>
struct socket_impl<socket_type::domain> {
using socket = asio::local::stream_protocol::socket;
using endpoint = asio::local::stream_protocol::endpoint;
using acceptor = asio::local::stream_protocol::acceptor;
};

#ifdef ASIO_NET_ENABLE_SSL
template <>
struct socket_impl<socket_type::ssl> {
using socket = asio::ssl::stream<asio::ip::tcp::socket>;
using endpoint = asio::ip::tcp::endpoint;
using resolver = asio::ip::tcp::resolver;
using acceptor = asio::ip::tcp::acceptor;
};
#endif

} // namespace detail
} // namespace asio_net
26 changes: 14 additions & 12 deletions detail/tcp_channel_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@
#include "log.h"
#include "message.hpp"
#include "noncopyable.hpp"
#include "socket_type.hpp"
#include "type.h"

namespace asio_net {
namespace detail {

template <typename T>
template <socket_type T>
class tcp_channel_t : private noncopyable {
using socket = typename T::socket;
using endpoint = typename T::endpoint;

public:
tcp_channel_t(socket& socket, const config& config) : socket_(socket), config_(config) {
tcp_channel_t(typename socket_impl<T>::socket& socket, const config& config) : socket_(socket), config_(config) {
ASIO_NET_LOGD("tcp_channel: %p", this);
}

Expand All @@ -30,14 +28,18 @@ class tcp_channel_t : private noncopyable {
void init_socket() {
if (config_.socket_send_buffer_size != UINT32_MAX) {
asio::socket_base::send_buffer_size option(config_.socket_send_buffer_size);
socket_.set_option(option);
get_socket().set_option(option);
}
if (config_.socket_recv_buffer_size != UINT32_MAX) {
asio::socket_base::receive_buffer_size option(config_.socket_recv_buffer_size);
socket_.set_option(option);
get_socket().set_option(option);
}
}

inline auto& get_socket() const {
return socket_.lowest_layer();
}

public:
/**
* async send message
Expand All @@ -60,14 +62,14 @@ class tcp_channel_t : private noncopyable {
}

bool is_open() const {
return socket_.is_open();
return get_socket().is_open();
}

endpoint local_endpoint() {
typename socket_impl<T>::endpoint local_endpoint() {
return socket_.local_endpoint();
}

endpoint remote_endpoint() {
typename socket_impl<T>::endpoint remote_endpoint() {
return socket_.remote_endpoint();
}

Expand Down Expand Up @@ -187,7 +189,7 @@ class tcp_channel_t : private noncopyable {

if (!is_open()) return;
asio::error_code ec;
socket_.close(ec);
get_socket().close(ec);
if (ec) {
ASIO_NET_LOGW("do_close: %s", ec.message().c_str());
}
Expand All @@ -202,7 +204,7 @@ class tcp_channel_t : private noncopyable {
}

private:
socket& socket_;
typename socket_impl<T>::socket& socket_;
const config& config_;
detail::message read_msg_;
uint32_t send_buffer_now_ = 0;
Expand Down
89 changes: 66 additions & 23 deletions detail/tcp_client_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
namespace asio_net {
namespace detail {

template <typename T>
template <socket_type T>
class tcp_client_t : public tcp_channel_t<T> {
using socket = typename T::socket;
using endpoint = typename T::endpoint;

public:
explicit tcp_client_t(asio::io_context& io_context, config config = {})
: tcp_channel_t<T>(socket_, config_), io_context_(io_context), socket_(io_context), config_(config) {
config_.init();
}

#ifdef ASIO_NET_ENABLE_SSL
explicit tcp_client_t(asio::io_context& io_context, asio::ssl::context& ssl_context, config config = {})
: tcp_channel_t<T>(socket_, config_), io_context_(io_context), socket_(io_context, ssl_context), config_(config) {
config_.init();
}
#endif

/**
* connect to server
*
Expand Down Expand Up @@ -49,7 +53,7 @@ class tcp_client_t : public tcp_channel_t<T> {

void set_reconnect(uint32_t ms) {
reconnect_ms_ = ms;
reconnect_timer_ = std::make_unique<asio::steady_timer>(socket_.get_executor());
reconnect_timer_ = std::make_unique<asio::steady_timer>(io_context_);
}

void cancel_reconnect() {
Expand Down Expand Up @@ -83,30 +87,36 @@ class tcp_client_t : public tcp_channel_t<T> {

private:
void do_open(const std::string& host, uint16_t port) {
static_assert(std::is_same<asio::ip::tcp, T>::value, "");
auto resolver = std::make_unique<typename T::resolver>(socket_.get_executor());
static_assert(T == socket_type::normal || T == socket_type::ssl, "");
auto resolver = std::make_unique<typename socket_impl<T>::resolver>(io_context_);
auto rp = resolver.get();
rp->async_resolve(typename T::resolver::query(host, std::to_string(port)),
[this, resolver = std::move(resolver)](const std::error_code& ec, const typename T::resolver::results_type& endpoints) {
if (!ec) {
asio::async_connect(socket_, endpoints, [this](const std::error_code& ec, const endpoint&) {
async_connect_handler(ec);
});
} else {
if (on_open_failed) on_open_failed(ec);
check_reconnect();
}
});
rp->async_resolve(
typename socket_impl<T>::resolver::query(host, std::to_string(port)),
[this, resolver = std::move(resolver)](const std::error_code& ec, const typename socket_impl<T>::resolver::results_type& endpoints) {
if (!ec) {
asio::async_connect(tcp_channel_t<T>::get_socket(), endpoints,
[this](const std::error_code& ec, const typename socket_impl<T>::endpoint&) {
async_connect_handler<T>(ec);
});
} else {
if (on_open_failed) on_open_failed(ec);
check_reconnect();
}
});
}

void do_open(const std::string& endpoint) {
static_assert(std::is_same<T, asio::local::stream_protocol>::value, "");
socket_.async_connect(typename T::endpoint(endpoint), [this](const std::error_code& ec) {
async_connect_handler(ec);
static_assert(T == socket_type::domain, "");
socket_.async_connect(typename socket_impl<T>::endpoint(endpoint), [this](const std::error_code& ec) {
async_connect_handler<socket_type::domain>(ec);
});
}

void async_connect_handler(const std::error_code& ec) {
template <socket_type>
void async_connect_handler(const std::error_code& ec);

template <>
void async_connect_handler<socket_type::normal>(const std::error_code& ec) {
if (!ec) {
this->init_socket();
tcp_channel_t<T>::on_close = [this] {
Expand All @@ -124,6 +134,39 @@ class tcp_client_t : public tcp_channel_t<T> {
}
}

template <>
inline void async_connect_handler<socket_type::domain>(const std::error_code& ec) {
async_connect_handler<socket_type::normal>(ec);
}

#ifdef ASIO_NET_ENABLE_SSL
template <>
void async_connect_handler<socket_type::ssl>(const std::error_code& ec) {
if (!ec) {
this->init_socket();
socket_.async_handshake(asio::ssl::stream_base::client, [this](const std::error_code& error) {
if (!error) {
tcp_channel_t<T>::on_close = [this] {
tcp_client_t::on_close();
check_reconnect();
};
if (on_open) on_open();
if (reconnect_timer_) {
reconnect_timer_->cancel();
}
this->do_read_start();
} else {
if (on_open_failed) on_open_failed(error);
check_reconnect();
}
});
} else {
if (on_open_failed) on_open_failed(ec);
check_reconnect();
}
}
#endif

public:
std::function<void()> on_open;
std::function<void(std::error_code)> on_open_failed;
Expand All @@ -134,7 +177,7 @@ class tcp_client_t : public tcp_channel_t<T> {

private:
asio::io_context& io_context_;
socket socket_;
typename socket_impl<T>::socket socket_;
config config_;
std::unique_ptr<asio::steady_timer> reconnect_timer_;
uint32_t reconnect_ms_ = 0;
Expand Down
Loading

0 comments on commit 1c097e3

Please sign in to comment.