Skip to content

Commit

Permalink
Fix RPC queries being doubly wrapped (#987)
Browse files Browse the repository at this point in the history
In some cases RPC queries that are wrapped with RpcDestFlags, RpcDestActor,
or RpcDestActorFlags are wrapped one more time with RpcDestActorFlags.
This PR fixes this behaviour.
  • Loading branch information
apolyakov authored May 3, 2024
1 parent 3d1f228 commit 31a0030
Show file tree
Hide file tree
Showing 14 changed files with 974 additions and 114 deletions.
103 changes: 83 additions & 20 deletions common/rpc-headers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,91 @@
// Copyright (c) 2020 LLC «V Kontakte»
// Distributed under the GPL v3 License, see LICENSE.notice.txt

#include "common/algorithms/find.h"
#include "common/rpc-headers.h"
#include "common/algorithms/find.h"
#include "common/tl/constants/common.h"

size_t fill_extra_headers_if_needed(RpcExtraHeaders &extra_headers, uint32_t function_magic, int actor_id, bool ignore_answer) {
size_t extra_headers_size = 0;
bool need_actor = actor_id != 0 && vk::none_of_equal(function_magic, TL_RPC_DEST_ACTOR, TL_RPC_DEST_ACTOR_FLAGS);
bool need_flags = ignore_answer && vk::none_of_equal(function_magic, TL_RPC_DEST_FLAGS, TL_RPC_DEST_ACTOR_FLAGS);

if (need_actor && need_flags) {
extra_headers.rpc_dest_actor_flags.op = TL_RPC_DEST_ACTOR_FLAGS;
extra_headers.rpc_dest_actor_flags.actor_id = actor_id;
extra_headers.rpc_dest_actor_flags.flags = vk::tl::common::rpc_invoke_req_extra_flags::no_result;
extra_headers_size = sizeof(extra_headers.rpc_dest_actor_flags);
} else if (need_actor) {
extra_headers.rpc_dest_actor.op = TL_RPC_DEST_ACTOR;
extra_headers.rpc_dest_actor.actor_id = actor_id;
extra_headers_size = sizeof(extra_headers.rpc_dest_actor);
} else if (need_flags) {
extra_headers.rpc_dest_flags.op = TL_RPC_DEST_FLAGS;
extra_headers.rpc_dest_flags.flags = vk::tl::common::rpc_invoke_req_extra_flags::no_result;
extra_headers_size = sizeof(extra_headers.rpc_dest_flags);

RegularizeWrappersReturnT regularize_wrappers(const char *rpc_payload, std::int32_t actor_id, bool ignore_result) {
static_assert(sizeof(RpcDestActorFlagsHeaders) >= sizeof(RpcDestActorHeaders));
static_assert(sizeof(RpcDestActorFlagsHeaders) >= sizeof(RpcDestFlagsHeaders));

const auto cur_wrapper{*reinterpret_cast<const RpcExtraHeaders *>(rpc_payload)};
const auto function_magic{*reinterpret_cast<const std::uint32_t *>(rpc_payload)};

if (actor_id == 0 && !ignore_result && vk::none_of_equal(function_magic, TL_RPC_DEST_ACTOR, TL_RPC_DEST_FLAGS, TL_RPC_DEST_ACTOR_FLAGS)) {
return {std::nullopt, 0, std::nullopt, nullptr};
}

RpcExtraHeaders extra_headers{};
const std::size_t new_wrapper_size{sizeof(RpcDestActorFlagsHeaders)};
std::size_t cur_wrapper_size{0};
std::int32_t cur_wrapper_actor_id{0};
bool cur_wrapper_ignore_result{false};

switch (function_magic) {
case TL_RPC_DEST_ACTOR_FLAGS:
cur_wrapper_size = sizeof(RpcDestActorFlagsHeaders);
cur_wrapper_actor_id = cur_wrapper.rpc_dest_actor_flags.actor_id;
cur_wrapper_ignore_result = static_cast<bool>(cur_wrapper.rpc_dest_actor_flags.flags & vk::tl::common::rpc_invoke_req_extra_flags::no_result);

extra_headers.rpc_dest_actor_flags.op = TL_RPC_DEST_ACTOR_FLAGS;
extra_headers.rpc_dest_actor_flags.actor_id = actor_id != 0 ? actor_id : cur_wrapper.rpc_dest_actor_flags.actor_id;
if (ignore_result) {
extra_headers.rpc_dest_actor_flags.flags = cur_wrapper.rpc_dest_actor_flags.flags | vk::tl::common::rpc_invoke_req_extra_flags::no_result;
} else {
extra_headers.rpc_dest_actor_flags.flags = cur_wrapper.rpc_dest_actor_flags.flags & ~vk::tl::common::rpc_invoke_req_extra_flags::no_result;
}

break;
case TL_RPC_DEST_ACTOR:
cur_wrapper_size = sizeof(RpcDestActorHeaders);
cur_wrapper_actor_id = cur_wrapper.rpc_dest_actor.actor_id;

extra_headers.rpc_dest_actor_flags.op = TL_RPC_DEST_ACTOR_FLAGS;
extra_headers.rpc_dest_actor_flags.actor_id = actor_id != 0 ? actor_id : cur_wrapper.rpc_dest_actor.actor_id;
extra_headers.rpc_dest_actor_flags.flags = ignore_result ? vk::tl::common::rpc_invoke_req_extra_flags::no_result : 0x0;

break;
case TL_RPC_DEST_FLAGS:
cur_wrapper_size = sizeof(RpcDestFlagsHeaders);
cur_wrapper_ignore_result = static_cast<bool>(cur_wrapper.rpc_dest_flags.flags & vk::tl::common::rpc_invoke_req_extra_flags::no_result);

extra_headers.rpc_dest_actor_flags.op = TL_RPC_DEST_ACTOR_FLAGS;
extra_headers.rpc_dest_actor_flags.actor_id = actor_id;
if (ignore_result) {
extra_headers.rpc_dest_actor_flags.flags = cur_wrapper.rpc_dest_flags.flags | vk::tl::common::rpc_invoke_req_extra_flags::no_result;
} else {
extra_headers.rpc_dest_actor_flags.flags = cur_wrapper.rpc_dest_flags.flags & ~vk::tl::common::rpc_invoke_req_extra_flags::no_result;
}

break;
default:
// we don't have a cur_wrapper, but we do have 'actor_id' or 'ignore_result' set
extra_headers.rpc_dest_actor_flags.op = TL_RPC_DEST_ACTOR_FLAGS;
extra_headers.rpc_dest_actor_flags.actor_id = actor_id;
extra_headers.rpc_dest_actor_flags.flags = ignore_result ? vk::tl::common::rpc_invoke_req_extra_flags::no_result : 0x0;

break;
}
return extra_headers_size;

decltype(RegularizeWrappersReturnT{}.opt_actor_id_warning_info) opt_actor_id_warning{};
if (actor_id != 0 && cur_wrapper_actor_id != 0) {
opt_actor_id_warning.emplace("inaccurate use of 'actor_id': '%d' was passed into RPC connection constructor, "
"but '%d' was already set in RpcDestActor or RpcDestActorFlags\n",
actor_id, cur_wrapper_actor_id);
}

const char *opt_ignore_result_warning_msg{nullptr};
if (!ignore_result && cur_wrapper_ignore_result) {
opt_ignore_result_warning_msg = "inaccurate use of 'ignore_answer': 'false' was passed into TL query function (e.g., rpc_tl_query), "
"but 'true' was already set in RpcDestFlags or RpcDestActorFlags\n";
}

return {
std::pair<RpcExtraHeaders, std::size_t>{extra_headers, new_wrapper_size},
cur_wrapper_size,
std::move(opt_actor_id_warning),
opt_ignore_result_warning_msg,
};
}
19 changes: 17 additions & 2 deletions common/rpc-headers.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

#pragma once

#include <cstddef>
#include <cstdint>
#include <optional>
#include <tuple>
#include <utility>

#pragma pack(push, 1)

Expand Down Expand Up @@ -44,4 +46,17 @@ struct RpcHeaders {

#pragma pack(pop)

size_t fill_extra_headers_if_needed(RpcExtraHeaders &extra_headers, uint32_t function_magic, int actor_id, bool ignore_answer);
struct RegularizeWrappersReturnT {
/// Optionally contains a new wrapper and its size
std::optional<std::pair<RpcExtraHeaders, std::size_t>> opt_new_wrapper;
/// The size of a wrapper found in rpc payload (0 if there is no one)
std::size_t cur_wrapper_size;
/// Optionally contains a tuple of <format string, current wrapper's actor_id, new actor_id>.
/// If not std::nullopt, can be used to warn about actor_id redefinition, for example,
/// 'php_warning(format_str, current_wrapper_actor_id, new_actor_id)'
std::optional<std::tuple<const char *, std::int32_t, std::int32_t>> opt_actor_id_warning_info;
/// Optionally contains a string. If not nullptr, can be used to warn about inaccurate usage of 'ignore_result'.
const char *opt_ignore_result_warning_msg;
};

RegularizeWrappersReturnT regularize_wrappers(const char *rpc_payload, std::int32_t actor_id, bool ignore_result);
45 changes: 32 additions & 13 deletions runtime/rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "runtime/rpc.h"

#include <cstdarg>
#include <cstring>
#include <chrono>

#include "common/rpc-error-codes.h"
Expand Down Expand Up @@ -679,22 +680,40 @@ int64_t rpc_send_impl(const class_instance<C$RpcConnection> &conn, double timeou
store_int(-1); // reserve for crc32
php_assert (data_buf.size() % sizeof(int) == 0);

const char *rpc_payload_start = data_buf.c_str() + sizeof(RpcHeaders);
size_t rpc_payload_size = data_buf.size() - sizeof(RpcHeaders);
uint32_t function_magic = CurrentProcessingQuery::get().get_last_stored_tl_function_magic();
RpcExtraHeaders extra_headers{};
size_t extra_headers_size = fill_extra_headers_if_needed(extra_headers, function_magic, conn.get()->actor_id, ignore_answer);
const auto [opt_new_wrapper, cur_wrapper_size, opt_actor_id_warning_info, opt_ignore_result_warning_msg]{
regularize_wrappers(data_buf.c_str() + sizeof(RpcHeaders), conn.get()->actor_id, ignore_answer)};

const auto request_size = static_cast<size_t>(data_buf.size() + extra_headers_size);
char *p = static_cast<char *>(dl::allocate(request_size));
if (opt_actor_id_warning_info.has_value()) {
const auto [msg, cur_wrapper_actor_id, new_wrapper_actor_id]{opt_actor_id_warning_info.value()};
php_warning(msg, cur_wrapper_actor_id, new_wrapper_actor_id);
}
if (opt_ignore_result_warning_msg != nullptr) {
php_warning("%s", opt_ignore_result_warning_msg);
}

char *request_buf{nullptr};
std::size_t request_size{0};

// Memory will look like this:
// 'request_buf' will look like this:
// [ RpcHeaders (reserved in f$rpc_clean) ] [ RpcExtraHeaders (optional) ] [ payload ]
memcpy(p, data_buf.c_str(), sizeof(RpcHeaders));
memcpy(p + sizeof(RpcHeaders), &extra_headers, extra_headers_size);
memcpy(p + sizeof(RpcHeaders) + extra_headers_size, rpc_payload_start, rpc_payload_size);
if (opt_new_wrapper.has_value()) {
const auto [new_wrapper, new_wrapper_size]{opt_new_wrapper.value()};
request_size = data_buf.size() - cur_wrapper_size + new_wrapper_size;
request_buf = static_cast<char *>(dl::allocate(request_size));

std::memcpy(request_buf, data_buf.c_str(), sizeof(RpcHeaders));
std::memcpy(request_buf + sizeof(RpcHeaders), &new_wrapper, new_wrapper_size);
std::memcpy(request_buf + sizeof(RpcHeaders) + new_wrapper_size,
data_buf.c_str() + sizeof(RpcHeaders) + cur_wrapper_size,
data_buf.size() - sizeof(RpcHeaders) - cur_wrapper_size);
} else {
request_size = data_buf.size();
request_buf = static_cast<char *>(dl::allocate(request_size));

std::memcpy(request_buf, data_buf.c_str(), request_size);
}

slot_id_t q_id = rpc_send_query(conn.get()->host_num, p, static_cast<int>(request_size), timeout_convert_to_ms(timeout));
slot_id_t q_id = rpc_send_query(conn.get()->host_num, request_buf, static_cast<int>(request_size), timeout_convert_to_ms(timeout));

// request's statistics
req_extra_info = rpc_request_extra_info_t{request_size};
Expand Down Expand Up @@ -739,7 +758,7 @@ int64_t rpc_send_impl(const class_instance<C$RpcConnection> &conn, double timeou
double send_timestamp = std::chrono::duration<double>{std::chrono::system_clock::now().time_since_epoch()}.count();

cur->resumable_id = register_forked_resumable(new rpc_resumable(q_id));
cur->function_magic = function_magic;
cur->function_magic = CurrentProcessingQuery::get().get_last_stored_tl_function_magic();
cur->actor_or_port = conn.get()->actor_id > 0 ? conn.get()->actor_id : -conn.get()->port;
cur->timer = nullptr;

Expand Down
3 changes: 1 addition & 2 deletions runtime/typed_rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ array<int64_t> f$typed_rpc_tl_query(const class_instance<C$RpcConnection> &conne
static_assert(std::is_same_v<KphpRpcRequest, R>, "Unexpected type");

if (ignore_answer && need_responses_extra_info) {
php_warning(
"Both $ignore_answer and $need_responses_extra_info are 'true'. Can't collect metrics for ignored answers");
php_warning("Both $ignore_answer and $need_responses_extra_info are 'true'. Can't collect metrics for ignored answers");
}

size_t bytes_sent = 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
<?php

/**
* AUTOGENERATED, DO NOT EDIT! If you want to modify it, check tl schema.
*
* This autogenerated code represents tl class for typed RPC API.
*/

namespace VK\TL\_common\Functions;

use VK\TL;

/**
* @kphp-tl-class
*/
class rpcDestActor implements TL\RpcFunction {

/** @var int */
public $actor_id = 0;

/** @var TL\RpcFunction */
public $query = null;

/** Allows kphp implicitly load function result class */
private const RESULT = TL\_common\Functions\rpcDestActor_result::class;

/**
* @param int $actor_id
* @param TL\RpcFunction $query
*/
public function __construct($actor_id = 0, $query = null) {
$this->actor_id = $actor_id;
$this->query = $query;
}

/**
* @param TL\RpcFunctionReturnResult $function_return_result
* @return TL\RpcFunctionReturnResult
*/
public static function functionReturnValue($function_return_result) {
if ($function_return_result instanceof rpcDestActor_result) {
return $function_return_result->value;
}
warning('Unexpected result type in functionReturnValue: ' . ($function_return_result ? get_class($function_return_result) : 'null'));
return (new rpcDestActor_result())->value;
}

/**
* @kphp-inline
*
* @param TL\RpcResponse $response
* @return TL\RpcFunctionReturnResult
*/
public static function result(TL\RpcResponse $response) {
return self::functionReturnValue($response->getResult());
}

/**
* @kphp-inline
*
* @return string
*/
public function getTLFunctionName() {
return 'rpcDestActor';
}

}

/**
* @kphp-tl-class
*/
class rpcDestActor_result implements TL\RpcFunctionReturnResult {

/** @var TL\RpcFunctionReturnResult */
public $value = null;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
<?php

/**
* AUTOGENERATED, DO NOT EDIT! If you want to modify it, check tl schema.
*
* This autogenerated code represents tl class for typed RPC API.
*/

namespace VK\TL\_common\Functions;

use VK\TL;

/**
* @kphp-tl-class
*/
class rpcDestActorFlags implements TL\RpcFunction {

/** @var int */
public $actor_id = 0;

/** @var int */
public $flags = 0;

/** @var TL\_common\Types\rpcInvokeReqExtra */
public $extra = null;

/** @var TL\RpcFunction */
public $query = null;

/** Allows kphp implicitly load function result class */
private const RESULT = TL\_common\Functions\rpcDestActorFlags_result::class;

/**
* @param int $actor_id
* @param int $flags
* @param TL\_common\Types\rpcInvokeReqExtra $extra
* @param TL\RpcFunction $query
*/
public function __construct($actor_id = 0, $flags = 0, $extra = null, $query = null) {
$this->actor_id = $actor_id;
$this->flags = $flags;
$this->extra = $extra;
$this->query = $query;
}

/**
* @param TL\RpcFunctionReturnResult $function_return_result
* @return TL\RpcFunctionReturnResult
*/
public static function functionReturnValue($function_return_result) {
if ($function_return_result instanceof rpcDestActorFlags_result) {
return $function_return_result->value;
}
warning('Unexpected result type in functionReturnValue: ' . ($function_return_result ? get_class($function_return_result) : 'null'));
return (new rpcDestActorFlags_result())->value;
}

/**
* @kphp-inline
*
* @param TL\RpcResponse $response
* @return TL\RpcFunctionReturnResult
*/
public static function result(TL\RpcResponse $response) {
return self::functionReturnValue($response->getResult());
}

/**
* @kphp-inline
*
* @return string
*/
public function getTLFunctionName() {
return 'rpcDestActorFlags';
}

}

/**
* @kphp-tl-class
*/
class rpcDestActorFlags_result implements TL\RpcFunctionReturnResult {

/** @var TL\RpcFunctionReturnResult */
public $value = null;

}
Loading

0 comments on commit 31a0030

Please sign in to comment.