Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sending data cannot be read #1586

Closed
jphz opened this issue Jan 16, 2025 · 0 comments
Closed

sending data cannot be read #1586

jphz opened this issue Jan 16, 2025 · 0 comments

Comments

@jphz
Copy link

jphz commented Jan 16, 2025

I feel like the problem lies with the handshake.

Created 10000 connections:
handshake_complete: logged 10000 times
send_data: logged 10000 times
recv_data: logged only about 9000+ times

// Excerpt from start(): request exactly two bytes from the socket into the
// front of m_headerBuffer and log how many bytes arrived.
asio::async_read(m_stream, asio::buffer(m_headerBuffer.data(), 2),
[this, self = shared_from_this()](auto ec, auto size) {
if (ec) {
onFail(ec, "onRead");
return;
}
LOG_DEBUG("recv_data:{}", size); // reported symptom: on some connections this line is never executed
});

#pragma once

#include <cstdint>
#include <string>
#include <vector>

#include "asio_third_party.h"
#include "connection.h"

namespace River {
/// Connection that performs an HTTP/WebSocket upgrade handshake over an asio
/// TCP socket, either as the connecting side (when built with ip/port) or as
/// the accepting side.
///
/// Instances hand out shared_from_this() to their asynchronous handlers, so
/// they must be owned by a std::shared_ptr before start()/sendMsg() is called.
class WebSocketConnection final
    : public std::enable_shared_from_this<WebSocketConnection>
    , public Connection {
public:
    /// Creates a connection without a remote endpoint; m_client stays false,
    /// so handshake() takes the serverHandshake() path.
    WebSocketConnection(NetServicePtr service, uint8_t type = 0);

    /// Creates a connection that will connect to ip:port; sets m_client so
    /// handshake() takes the clientHandshake() path.
    WebSocketConnection(NetServicePtr service, uint8_t type, const std::string& ip, uint16_t port);

    ~WebSocketConnection() = default;

    /// Spawns the handshake coroutine on the socket's executor and, once the
    /// handshake succeeded, issues the first read.
    void start();

    /// Currently a no-op.
    /// NOTE(review): nothing here closes the socket; confirm whether callers
    /// expect stop() to shut the connection down.
    void stop() override {
    }

    /// Starts an asynchronous write on the socket.
    /// NOTE(review): the visible implementation writes a fixed payload and
    /// does not look at `packet`.
    void sendMsg(PacketPtr packet) override;

    /// Gives access to the underlying TCP socket.
    asio::ip::tcp::socket& stream() {
        return m_stream;
    }

private:
    /// Dispatches to clientHandshake() or serverHandshake() based on m_client.
    asio::awaitable<void> handshake();

    /// Connects to m_ip:m_port, sends the upgrade request and validates the reply.
    asio::awaitable<void> clientHandshake();

    /// Reads the peer's upgrade request and answers with m_response.
    asio::awaitable<void> serverHandshake();

    /// Validates an upgrade request; on success stores the reply in m_response.
    bool parseHandshake(const std::string& request);

    /// Returns a base64-encoded value built from 16 random bytes.
    std::string generateWebSocketKey();

    /// Checks the Sec-WebSocket-Accept value in `response` against `key`.
    bool validateServerResponse(const std::string& response, const std::string& key);

    /// Logs the failure (except for the connect step) and calls onShutdown().
    void onFail(const asio::error_code& ec, std::string_view prefix);

    /// Marks the connection closed, notifies the owning service, closes the socket.
    void onShutdown();

private:
    // Declaration order below is also the member initialisation order.
    std::vector<char> m_headerBuffer{};  // sized to WEBSOCKET_FRAME_HEAD + PACKET_HEAD by the constructors
    asio::ip::tcp::socket m_stream;      // socket bound to a strand created in the constructors
    std::string m_response{};            // HTTP 101 reply text produced by parseHandshake()
    bool m_client = false;               // true only when constructed with an ip/port
};
}  // namespace River

// Shared-ownership handle for WebSocketConnection.
using WebSocketConnectionPtr = std::shared_ptr<River::WebSocketConnection>;

#include "../include/websocket_connection.h"

#include <iostream>
#include <random>
#include <string>

#include "../include/comm_func.h"
#include "../include/crypto.h"
#include "../include/thread_pool.h"

using namespace River;

// Builds a connection with no remote endpoint configured. m_client keeps its
// in-class default (false). The socket is bound to a fresh strand created on
// the thread pool's executor.
WebSocketConnection::WebSocketConnection(NetServicePtr service, uint8_t type)
    : Connection(service, type)
    , m_stream(asio::make_strand(THREAD_POOL.getExecutor())) {
    // Size the header scratch buffer once, up front.
    const auto headerBytes = WEBSOCKET_FRAME_HEAD + PACKET_HEAD;
    m_headerBuffer.resize(headerBytes);
}

// Builds a connection that will connect out to ip:port.
//
// The initialiser list follows the members' declaration order (m_stream is
// declared before m_client). Members are always initialised in declaration
// order regardless of how the list is written, so listing them any other way
// only produces a -Wreorder warning and misleads the reader.
WebSocketConnection::WebSocketConnection(NetServicePtr service, uint8_t type, const std::string& ip, uint16_t port)
    : Connection(service, type)
    , m_stream(asio::make_strand(THREAD_POOL.getExecutor()))
    , m_client(true) {
    // m_ip / m_port are not members declared by this class, so they are
    // assigned here rather than in the initialiser list.
    m_ip = ip;
    m_port = port;
    m_headerBuffer.resize(WEBSOCKET_FRAME_HEAD + PACKET_HEAD);
}

// Launches the connection's start-up sequence as a detached coroutine on the
// socket's executor: run the handshake, and if the connection was not closed
// by it, (accepting side only) record the peer address and send the initial
// message, then post a single two-byte read.
//
// The coroutine lambda and the read handler each hold a shared_ptr to this
// object so it stays alive while they are pending.
// RIVER_START_TRY / RIVER_END_TRY are project macros; presumably they wrap the
// body in a try/catch (confirm against their definition).
void WebSocketConnection::start() {
    asio::co_spawn(
        m_stream.get_executor(),
        [this, self = shared_from_this()]() -> asio::awaitable<void> {
            RIVER_START_TRY
            co_await handshake();

            // The handshake failure paths go through onFail(), which ends in
            // onShutdown() setting CS_CLOSE.
            const bool closedDuringHandshake = (m_state == CS_CLOSE);
            if (closedDuringHandshake) {
                co_return;
            }

            const bool acceptingSide = !m_client;
            if (acceptingSide) {
                auto [peerIp, peerPort] = CommFunc::getRemoteIPAndPort(stream());
                m_ip = peerIp;
                m_port = peerPort;
                sendMsg(nullptr);
            }

            LOG_DEBUG("handshake_complete");

            // Completion handler for the two-byte read issued below.
            auto onHeaderRead = [this, self = shared_from_this()](auto ec, auto size) {
                if (!ec) {
                    LOG_DEBUG("recv_data:{}", size);
                    return;
                }
                onFail(ec, "onRead");
            };
            asio::async_read(m_stream, asio::buffer(m_headerBuffer.data(), 2), onHeaderRead);

            RIVER_END_TRY
        },
        asio::detached);
}

// Starts an asynchronous write of a fixed payload on the socket.
//
// NOTE(review): `packet` is not used; the function always sends the
// placeholder text below. Presumably this is a reproduction stub, so confirm
// before relying on it to deliver real packets.
//
// The payload is a function-local static so the buffer handed to async_write
// outlives the operation; it is const because nothing may modify it while
// writes are in flight. The completion handler holds a shared_ptr to keep this
// object alive until the write finishes.
void WebSocketConnection::sendMsg(PacketPtr packet) {
    static const std::string frame = "hello world";
    asio::async_write(m_stream, asio::buffer(frame), [this, self = shared_from_this()](auto ec, auto size) {
        if (ec) {
            // Prefix previously read "onWrire", which made write failures hard to grep for.
            onFail(ec, "onWrite");
            return;
        }
        LOG_DEBUG("send_data:{}", size);
    });
}

void WebSocketConnection::onShutdown() {
    if (m_state == CS_CLOSE) {
        return;
    }

    m_state = CS_CLOSE;
    RIVER_START_TRY
    if (auto service = m_service.lock()) {
        service->onDisconnect(shared_from_this());
    }
    if (m_stream.is_open()) {
        asio::error_code ignore{};
        m_stream.shutdown(asio::socket_base::shutdown_both, ignore);
        m_stream.close(ignore);
    }
    RIVER_END_TRY
}

// Runs the handshake that matches this connection's role: the accepting side
// awaits serverHandshake(), the connecting side awaits clientHandshake().
asio::awaitable<void> WebSocketConnection::handshake() {
    if (!m_client) {
        co_await serverHandshake();
        co_return;
    }
    co_await clientHandshake();
}

// Client side of the WebSocket opening handshake: connect to m_ip:m_port, send
// the HTTP upgrade request, read the response headers and validate them.
// Every failure is reported through onFail(), which closes the connection;
// callers detect that via m_state.
//
// The response is read up to and including the blank line ("\r\n\r\n") and
// never a byte further. The previous implementation issued one
// async_read_some() into a 4096-byte scratch buffer, which had two problems:
//   * a short read could return only part of the headers, so validation
//     failed even though the server answered correctly;
//   * when the server's first frame arrived together with the 101 response,
//     those extra bytes landed in the scratch buffer and were thrown away, so
//     the next read on the socket waited forever for data already consumed.
// NOTE: asio::ip::make_address() throws on a malformed m_ip, as before.
asio::awaitable<void> WebSocketConnection::clientHandshake() {
    asio::error_code ec{};
    const asio::ip::tcp::endpoint ep(asio::ip::make_address(m_ip), m_port);
    co_await m_stream.async_connect(ep, asio::redirect_error(asio::use_awaitable, ec));
    if (ec) {
        onFail(ec, "clientHandshake async_connect");
        co_return;
    }

    const std::string key = generateWebSocketKey();
    const std::string request =
        "GET / HTTP/1.1\r\n"
        "Host: " +
        m_ip + ":" + std::to_string(m_port) +
        "\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Key: " +
        key +
        "\r\n"
        "Sec-WebSocket-Version: 13\r\n\r\n";

    co_await asio::async_write(m_stream, asio::buffer(request), asio::redirect_error(asio::use_awaitable, ec));
    if (ec) {
        onFail(ec, "clientHandshake request");
        co_return;
    }

    constexpr std::string_view kHeaderEnd = "\r\n\r\n";
    constexpr std::size_t kMaxResponse = 4096;  // same cap the old scratch buffer imposed

    std::string response;
    response.reserve(512);

    // True when `response` currently ends with `tail`.
    const auto endsWith = [&response](std::string_view tail) {
        return response.size() >= tail.size() &&
               response.compare(response.size() - tail.size(), tail.size(), tail) == 0;
    };

    while (!endsWith(kHeaderEnd)) {
        // `matched` is the length of the longest suffix of what we have that is
        // a prefix of the terminator. The terminator therefore cannot complete
        // in fewer than (4 - matched) further bytes, so reading exactly that
        // many can never run past it. There is no member buffer that could
        // carry over-read bytes to the next read, hence the exact reads; the
        // cost is a few dozen small reads, paid once per connection.
        std::size_t matched = kHeaderEnd.size() - 1;
        while (matched > 0 && !endsWith(kHeaderEnd.substr(0, matched))) {
            --matched;
        }
        const std::size_t want = kHeaderEnd.size() - matched;

        if (response.size() + want > kMaxResponse) {
            onFail(ec, "clientHandshake response too large or malformed");
            co_return;
        }

        const std::size_t offset = response.size();
        response.resize(offset + want);
        co_await asio::async_read(m_stream, asio::buffer(response.data() + offset, want),
                                  asio::redirect_error(asio::use_awaitable, ec));
        if (ec) {
            onFail(ec, "clientHandshake response");
            co_return;
        }
    }

    if (!validateServerResponse(response, key)) {
        // ec carries no error here; the prefix is what identifies the failure.
        onFail(ec, "clientHandshake invalid Sec-WebSocket-Accept");
        co_return;
    }
}

// Server side of the opening handshake: accumulate bytes from the socket until
// the blank line that ends the HTTP headers shows up, hand the request (up to
// and including that blank line) to parseHandshake(), then write the reply
// stored in m_response. Any failure goes through onFail() and ends the
// coroutine; the request is limited to the 4096-byte receive buffer.
asio::awaitable<void> WebSocketConnection::serverHandshake() {
    std::vector<char> inbox(4096);
    std::size_t filled = 0;
    asio::error_code ec{};

    for (;;) {
        const std::size_t got = co_await m_stream.async_read_some(
            asio::buffer(inbox.data() + filled, inbox.size() - filled),
            asio::redirect_error(asio::use_awaitable, ec));
        if (ec) {
            onFail(ec, "serverHandshake read");
            co_return;
        }
        if (got == 0) {
            onFail(ec, "serverHandshake n is 0");
            co_return;
        }
        filled += got;

        // Search everything received so far; with fewer than four bytes the
        // search simply finds nothing.
        const std::string_view received(inbox.data(), filled);
        const std::size_t blankLine = received.find("\r\n\r\n");
        if (blankLine != std::string_view::npos) {
            // Only the headers plus the terminating blank line are parsed.
            const std::string request(received.substr(0, blankLine + 4));
            if (parseHandshake(request)) {
                break;
            }
            onFail(ec, "serverHandshake parseHandshake");
            co_return;
        }

        // Buffer is full and still no end of headers.
        if (filled >= inbox.size()) {
            onFail(ec, "serverHandshake handshake request too large or malformed");
            co_return;
        }
    }

    co_await asio::async_write(m_stream, asio::buffer(m_response), asio::redirect_error(asio::use_awaitable, ec));
    if (ec) {
        onFail(ec, "serverHandshake response");
        co_return;
    }
}

// Validates a WebSocket upgrade request and, on success, stores the matching
// "101 Switching Protocols" reply in m_response.
//
// @param request  Request text; must contain the blank line ("\r\n\r\n") that
//                 ends the headers.
// @return true when the request carries an Upgrade header mentioning
//         "websocket", a Connection header mentioning "upgrade" and a
//         non-empty Sec-WebSocket-Key; false otherwise (m_response untouched).
//
// Header names and the Upgrade/Connection values are compared without regard
// to case, and Connection may list several tokens (for example
// "keep-alive, Upgrade"). The previous exact-substring checks rejected such
// requests. Headers are located at the start of a line, so a request must
// begin with a request line as usual.
bool WebSocketConnection::parseHandshake(const std::string& request) {
    const std::size_t blankLine = request.find("\r\n\r\n");
    if (blankLine == std::string::npos) {
        return false;
    }

    // Keep the CRLF that ends the last header so every header line, including
    // the final one, is terminated by "\r\n".
    const std::string headers = request.substr(0, blankLine + 2);

    // ASCII-only lower-casing (deliberately not locale dependent). The copy
    // has the same length, so offsets found in it are valid in `headers` too.
    std::string lowered = headers;
    for (char& ch : lowered) {
        if (ch >= 'A' && ch <= 'Z') {
            ch = static_cast<char>(ch - 'A' + 'a');
        }
    }

    // Finds the header `name` (lower-case, without the colon) at the start of a
    // line and reports its value as the half-open range [begin, end).
    const auto findValue = [&lowered](std::string_view name, std::size_t& begin, std::size_t& end) {
        std::string lineStart = "\r\n";
        lineStart += name;
        lineStart += ':';
        const std::size_t at = lowered.find(lineStart);
        if (at == std::string::npos) {
            return false;
        }
        begin = at + lineStart.size();
        end = lowered.find("\r\n", begin);
        return end != std::string::npos;
    };

    const std::string_view loweredView(lowered);
    std::size_t begin = 0;
    std::size_t end = 0;

    if (!findValue("upgrade", begin, end) ||
        loweredView.substr(begin, end - begin).find("websocket") == std::string_view::npos) {
        return false;
    }
    if (!findValue("connection", begin, end) ||
        loweredView.substr(begin, end - begin).find("upgrade") == std::string_view::npos) {
        return false;
    }
    if (!findValue("sec-websocket-key", begin, end)) {
        return false;
    }

    // The key itself is case-sensitive, so take it from the original text.
    std::string key = headers.substr(begin, end - begin);
    CommFunc::trim(key);
    if (key.empty()) {
        return false;
    }

    // Accept value = base64(SHA-1(key + fixed GUID)).
    std::string acceptKey = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    uint8_t sha1[20];
    SHA1(reinterpret_cast<uint8_t*>(acceptKey.data()), acceptKey.size(), sha1);
    const std::string secWebsocketAccept =
        Crypto::base64Encode(std::string_view{reinterpret_cast<char*>(sha1), 20});

    m_response =
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        "Sec-WebSocket-Accept: " +
        secWebsocketAccept + "\r\n\r\n";
    return true;
}

// Produces the value for the Sec-WebSocket-Key request header: 16 random
// bytes, base64-encoded. A Mersenne Twister seeded once from
// std::random_device supplies the bytes.
std::string WebSocketConnection::generateWebSocketKey() {
    std::random_device seedSource;
    std::mt19937 engine(seedSource());
    std::uniform_int_distribution<> byteRange(0, 255);

    std::vector<uint8_t> nonce(16);
    for (std::size_t i = 0; i < nonce.size(); ++i) {
        nonce[i] = static_cast<uint8_t>(byteRange(engine));
    }

    return Crypto::base64Encode(std::string_view{reinterpret_cast<char*>(nonce.data()), nonce.size()});
}

// Checks a server's reply to our upgrade request.
//
// @param response  Response text starting at the status line.
// @param key       The Sec-WebSocket-Key value that was sent in the request.
// @return true only when the status line is "HTTP/1.1 101 ..." and the
//         Sec-WebSocket-Accept header equals base64(SHA-1(key + fixed GUID)).
//
// Compared with the earlier version this also rejects non-101 replies, finds
// the header regardless of the case of its name, and tolerates any amount of
// spaces/tabs around the value instead of requiring exactly one space.
bool WebSocketConnection::validateServerResponse(const std::string& response, const std::string& key) {
    // Anything other than 101 means the server did not switch protocols.
    if (response.compare(0, 12, "HTTP/1.1 101") != 0) {
        return false;
    }

    // ASCII-only lower-cased copy for the name lookup; same length as the
    // original, so offsets are interchangeable.
    std::string lowered = response;
    for (char& ch : lowered) {
        if (ch >= 'A' && ch <= 'Z') {
            ch = static_cast<char>(ch - 'A' + 'a');
        }
    }

    const std::string_view headerName = "\r\nsec-websocket-accept:";
    const std::size_t at = lowered.find(headerName);
    if (at == std::string::npos) {
        return false;
    }
    const std::size_t valueBegin = at + headerName.size();
    std::size_t valueEnd = response.find("\r\n", valueBegin);
    if (valueEnd == std::string::npos) {
        valueEnd = response.size();
    }

    // The accept value is case-sensitive base64, so read it from the original
    // text and strip optional whitespace on both sides.
    std::string_view actualAccept = std::string_view(response).substr(valueBegin, valueEnd - valueBegin);
    const std::size_t first = actualAccept.find_first_not_of(" \t");
    if (first == std::string_view::npos) {
        return false;
    }
    const std::size_t last = actualAccept.find_last_not_of(" \t");
    actualAccept = actualAccept.substr(first, last - first + 1);

    std::string acceptKey = key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
    uint8_t sha1[20];
    SHA1(reinterpret_cast<uint8_t*>(acceptKey.data()), acceptKey.size(), sha1);
    const std::string expectedAccept =
        Crypto::base64Encode(std::string_view{reinterpret_cast<char*>(sha1), 20});

    return actualAccept == expectedAccept;
}

// Common failure path: log which step failed together with the error code and
// message, then shut the connection down. Failures of the connect step are
// not logged but still trigger the shutdown.
void WebSocketConnection::onFail(const asio::error_code& ec, std::string_view prefix) {
    const bool connectStep = (prefix == "clientHandshake async_connect");
    if (!connectStep) {
        LOG_ERROR("WebSocketConnection::{} code:{} what:{}", prefix, ec.value(), ec.message());
    }
    onShutdown();
}
@jphz jphz closed this as completed Jan 17, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant