We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
I suspect the problem lies in the handshake.
Created 10,000 connections: handshake_complete was logged 10,000 times and send_data 10,000 times, but recv_data only about 9,000+ times.
asio::async_read(m_stream, asio::buffer(m_headerBuffer.data(), 2), [this, self = shared_from_this()](auto ec, auto size) { if (ec) { onFail(ec, "onRead"); return; } LOG_DEBUG("recv_data:{}", size); // this not execute });
#pragma once #include <cstdint> #include <string> #include <vector> #include "asio_third_party.h" #include "connection.h" namespace River { class WebSocketConnection final : public std::enable_shared_from_this<WebSocketConnection> , public Connection { public: WebSocketConnection(NetServicePtr service, uint8_t type = 0); WebSocketConnection(NetServicePtr service, uint8_t type, const std::string& ip, uint16_t port); ~WebSocketConnection() = default; void start(); void stop() override { } void sendMsg(PacketPtr packet) override; asio::ip::tcp::socket& stream() { return m_stream; } private: asio::awaitable<void> handshake(); asio::awaitable<void> clientHandshake(); asio::awaitable<void> serverHandshake(); bool parseHandshake(const std::string& request); std::string generateWebSocketKey(); bool validateServerResponse(const std::string& response, const std::string& key); void onFail(const asio::error_code& ec, std::string_view prefix); void onShutdown(); private: std::vector<char> m_headerBuffer{}; asio::ip::tcp::socket m_stream; std::string m_response{}; bool m_client = false; }; } // namespace River using WebSocketConnectionPtr = std::shared_ptr<River::WebSocketConnection>; #include "../include/websocket_connection.h" #include <iostream> #include <random> #include <string> #include "../include/comm_func.h" #include "../include/crypto.h" #include "../include/thread_pool.h" using namespace River; WebSocketConnection::WebSocketConnection(NetServicePtr service, uint8_t type) : Connection(service, type) , m_stream(asio::make_strand(THREAD_POOL.getExecutor())) { m_headerBuffer.resize(WEBSOCKET_FRAME_HEAD + PACKET_HEAD); } WebSocketConnection::WebSocketConnection(NetServicePtr service, uint8_t type, const std::string& ip, uint16_t port) : Connection(service, type) , m_client(true) , m_stream(asio::make_strand(THREAD_POOL.getExecutor())) { m_ip = ip; m_port = port; m_headerBuffer.resize(WEBSOCKET_FRAME_HEAD + PACKET_HEAD); } void WebSocketConnection::start() { asio::co_spawn( 
m_stream.get_executor(), [this, self = shared_from_this()]() -> asio::awaitable<void> { RIVER_START_TRY co_await handshake(); if (m_state == CS_CLOSE) { co_return; } if (!m_client) { auto [ip, port] = CommFunc::getRemoteIPAndPort(stream()); m_ip = ip; m_port = port; sendMsg(nullptr); } LOG_DEBUG("handshake_complete"); asio::async_read(m_stream, asio::buffer(m_headerBuffer.data(), 2), [this, self = shared_from_this()](auto ec, auto size) { if (ec) { onFail(ec, "onRead"); return; } LOG_DEBUG("recv_data:{}", size); }); RIVER_END_TRY }, asio::detached); } void WebSocketConnection::sendMsg(PacketPtr packet) { static std::string frame = "hello world"; asio::async_write(m_stream, asio::buffer(frame), [this, self = shared_from_this()](auto ec, auto size) { if (ec) { onFail(ec, "onWrire"); return; } LOG_DEBUG("send_data:{}", size); }); } void WebSocketConnection::onShutdown() { if (m_state == CS_CLOSE) { return; } m_state = CS_CLOSE; RIVER_START_TRY if (auto service = m_service.lock()) { service->onDisconnect(shared_from_this()); } if (m_stream.is_open()) { asio::error_code ignore{}; m_stream.shutdown(asio::socket_base::shutdown_both, ignore); m_stream.close(ignore); } RIVER_END_TRY } asio::awaitable<void> WebSocketConnection::handshake() { if (m_client) { co_await clientHandshake(); } else { co_await serverHandshake(); } } asio::awaitable<void> WebSocketConnection::clientHandshake() { asio::error_code ec{}; asio::ip::tcp::endpoint ep(asio::ip::address(asio::ip::make_address(m_ip)), m_port); co_await m_stream.async_connect(ep, asio::redirect_error(asio::use_awaitable, ec)); if (ec) { onFail(ec, "clientHandshake async_connect"); co_return; } std::string key = generateWebSocketKey(); std::string request = "GET / HTTP/1.1\r\n" "Host: " + m_ip + ":" + std::to_string(m_port) + "\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Key: " + key + "\r\n" "Sec-WebSocket-Version: 13\r\n\r\n"; co_await asio::async_write(m_stream, asio::buffer(request), 
asio::redirect_error(asio::use_awaitable, ec)); if (ec) { onFail(ec, "clientHandshake request"); co_return; } std::vector<char> response(4096); std::size_t n = co_await m_stream.async_read_some(asio::buffer(response), asio::redirect_error(asio::use_awaitable, ec)); if (ec) { onFail(ec, "clientHandshake response"); co_return; } std::string responseStr(response.begin(), response.begin() + n); if (!validateServerResponse(responseStr, key)) { onFail(ec, "clientHandshake invalid Sec-WebSocket-Accept"); co_return; } } asio::awaitable<void> WebSocketConnection::serverHandshake() { std::vector<char> data(4096); std::size_t totalRead = 0; bool handshakeComplete = false; asio::error_code ec{}; while (!handshakeComplete) { auto n = co_await m_stream.async_read_some(asio::buffer(data.data() + totalRead, data.size() - totalRead), asio::redirect_error(asio::use_awaitable, ec)); if (ec) { onFail(ec, "serverHandshake read"); co_return; } if (n == 0) { onFail(ec, "serverHandshake n is 0"); co_return; } totalRead += n; if (totalRead >= 4) { std::string request(data.begin(), data.begin() + totalRead); std::size_t headerEnd = request.find("\r\n\r\n"); if (headerEnd != std::string::npos) { handshakeComplete = true; request.resize(headerEnd + 4); if (!parseHandshake(request)) { onFail(ec, "serverHandshake parseHandshake"); co_return; } break; } } if (totalRead >= data.size()) { onFail(ec, "serverHandshake handshake request too large or malformed"); co_return; } } co_await asio::async_write(m_stream, asio::buffer(m_response), asio::redirect_error(asio::use_awaitable, ec)); if (ec) { onFail(ec, "serverHandshake response"); co_return; } } bool WebSocketConnection::parseHandshake(const std::string& request) { std::string key; bool valid = false; std::size_t pos = request.find("\r\n\r\n"); if (pos != std::string::npos) { std::string headers = request.substr(0, pos); std::size_t upgradePos = headers.find("Upgrade: websocket"); std::size_t connectionPos = headers.find("Connection: Upgrade"); 
std::size_t secWebsocketKeyPos = headers.find("Sec-WebSocket-Key:"); if (upgradePos != std::string::npos && connectionPos != std::string::npos && secWebsocketKeyPos != std::string::npos) { std::size_t keyStart = secWebsocketKeyPos + 18; std::size_t keyEnd = headers.find("\r\n", keyStart); key = headers.substr(keyStart, keyEnd - keyStart); CommFunc::trim(key); valid = true; } } if (valid) { std::string acceptKey = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; uint8_t sha1[20]; SHA1((uint8_t*)acceptKey.c_str(), acceptKey.size(), sha1); std::string secWebsocketAccept = Crypto::base64Encode(std::string_view{(char*)sha1, 20}); m_response = "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: " + secWebsocketAccept + "\r\n\r\n"; } return valid; } std::string WebSocketConnection::generateWebSocketKey() { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dis(0, 255); std::vector<uint8_t> key(16); for (auto& byte : key) { byte = static_cast<uint8_t>(dis(gen)); } return Crypto::base64Encode(std::string_view{(char*)key.data(), key.size()}); } bool WebSocketConnection::validateServerResponse(const std::string& response, const std::string& key) { std::string acceptKey = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; uint8_t sha1[20]; SHA1((uint8_t*)acceptKey.c_str(), acceptKey.size(), sha1); std::string expected_accept = Crypto::base64Encode(std::string_view{(char*)sha1, 20}); std::size_t pos = response.find("Sec-WebSocket-Accept: "); if (pos == std::string::npos) return false; pos += 22; std::size_t end = response.find("\r\n", pos); std::string actualAccept = response.substr(pos, end - pos); return actualAccept == expected_accept; } void WebSocketConnection::onFail(const asio::error_code& ec, std::string_view prefix) { if (prefix != "clientHandshake async_connect") { LOG_ERROR("WebSocketConnection::{} code:{} what:{}", prefix, ec.value(), ec.message()); } onShutdown(); }
The text was updated successfully, but these errors were encountered:
No branches or pull requests
I suspect the problem lies in the handshake.
Created 10,000 connections:
handshake_complete: logged 10,000 times
send_data: logged 10,000 times
recv_data: logged only about 9,000+ times
// Wait for exactly two bytes (the start of the next frame header).
// NOTE(review): async_read only counts bytes still pending on the socket; if the
// handshake's async_read_some already pulled them into its own buffer, this
// completion handler never runs.
asio::async_read(m_stream, asio::buffer(m_headerBuffer.data(), 2),
                 [this, self = shared_from_this()](auto error, auto bytes) {
                     if (!error) {
                         LOG_DEBUG("recv_data:{}", bytes); // reported as never executing for some connections
                         return;
                     }
                     onFail(error, "onRead");
                 });
The text was updated successfully, but these errors were encountered: