diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b78b6aed..31e98ebcf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Next release +- refactor(rpc): re-worked rpc tower server and added proper websocket support - fix(network): added the FGW and gateway url to the chain config - fix(block_hash): block hash mismatch on transaction with an empty signature - feat: declare v0, l1 handler support added diff --git a/Cargo.lock b/Cargo.lock index 473090db1..0db7ba339 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5702,6 +5702,7 @@ dependencies = [ "anyhow", "blockifier", "lazy_static", + "log", "mp-utils", "primitive-types", "rstest 0.18.2", diff --git a/crates/client/rpc/src/RPC.md b/crates/client/rpc/src/RPC.md new file mode 100644 index 000000000..f10d9504f --- /dev/null +++ b/crates/client/rpc/src/RPC.md @@ -0,0 +1,94 @@ +# RPC + +_This section consists of a brief overview of RPC handling architecture inside +of Madara, as its structure can be quite confusing at first._ + +## Properties + +Madara RPC has the folliwing properties: + +**Each RPC category is independent** and decoupled from the rest, so `trace` +methods exist in isolation from `read` methods for example. + +**RPC methods are versioned**. It is therefore possible for a user to call +_different versions_ of the same RPC method. This is mostly present for ease of +development of new RPC versions, but also serves to assure a level of backwards +compatibility. To select a specific version of an rpc method, you will need to +append `/rcp/v{version}` to the rpc url you are connecting to. + +**RPC versions are grouped under the `Starknet` struct**. This serves as a +common point of implementation for all RPC methods across all versions, and is +also the point of interaction between RPC methods and the node backend. + +> [!NOTE] +> All of this is regrouped as an RPC _service_. + +## Implementation details + +There are **two** main parts to the implementation of RPC methods in Madara. + +### Jsonrpsee implementation + +> [!NOTE] > `jsonrpsee` is a library developed by Parity which is used to implement JSON +> RPC APIs through simple macro logic. + +Each RPC version is defined under the `version` folder using the +`versioned_starknet_rpc` macro. This just serves to rename the trait it is +defined on and all jsonrpsee `#[method]` definitions to include the version +name. The latter is especially important as it avoids name clashes when merging +multiple `RpcModule`s from different versions together. + +#### Renaming + +```rust +#[versioned_starknet_rpc("V0_7_1)")] +trait yourTrait { + #[method(name = "foo")] + async fn foo(); +} +``` + +Will become + +```rust +#[jsonrpsee::proc_macros::rpc(server, namespace = "starknet")] +trait yourTraitV0_7_1 { + #[method(name = "V0_7_1_foo")] + async fn foo(); +} +``` + +### Implementation as a service + +> [!IMPORTANT] +> This is where the RPC server is set up and where RPC calls are actually +> parsed, validated, routed and handled. + +`RpcService` is responsible for starting the rpc service, and hence the rpc +server. This is done with tower in the following steps: + +- RPC apis are built and combined into a single `RpcModule` using + `versioned_rpc_api`, and all other configurations are loaded. + +- Request filtering middleware is set up. This includes host origin filtering + and CORS filtering. + +> [!NOTE] +> Rpc middleware will apply to both websocket and http rpc requests, which is +> why we do not apply versioning in the http middleware. + +- Request constraints are set, such as the maximum number of connections and + request / response size constraints. + +- Additional service layers are added on each rpc call inside `service_fn`. + These are composed into versioning, rate limiting (which is optional) and + metrics layers. Importantly, version filtering with `RpcMiddlewareServiceVersion` + will transforms rpc methods request with header `/rpc/v{version}` and a json rpc + body with a `{method}` field into the correct `starknet_{version}_{method}` rpc + method call, as this is how we version them internally with jsonrpsee. + +> [!NOTE] +> The `starknet` prefix comes from the secondary macro expansion of +> `#[rpc(server, namespace = "starknet)]` + +- Finally, the RPC service is added to tower as `RpcServiceBuilder`. Note that diff --git a/crates/client/rpc/src/lib.rs b/crates/client/rpc/src/lib.rs index f63ff9e8e..fbe64f795 100644 --- a/crates/client/rpc/src/lib.rs +++ b/crates/client/rpc/src/lib.rs @@ -4,7 +4,6 @@ mod constants; mod errors; -mod macros; pub mod providers; #[cfg(test)] pub mod test_utils; @@ -99,14 +98,25 @@ pub fn versioned_rpc_api( write: bool, trace: bool, internal: bool, + ws: bool, ) -> anyhow::Result> { let mut rpc_api = RpcModule::new(()); - merge_rpc_versions!( - rpc_api, starknet, read, write, trace, internal, - v0_7_1, // We can add new versions by adding the version module below - // , v0_8_0 (for example) - ); + if read { + rpc_api.merge(versions::v0_7_1::StarknetReadRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if write { + rpc_api.merge(versions::v0_7_1::StarknetWriteRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if trace { + rpc_api.merge(versions::v0_7_1::StarknetTraceRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if internal { + rpc_api.merge(versions::v0_7_1::MadaraWriteRpcApiV0_7_1Server::into_rpc(starknet.clone()))?; + } + if ws { + // V0.8.0 ... + } Ok(rpc_api) } diff --git a/crates/client/rpc/src/macros.rs b/crates/client/rpc/src/macros.rs deleted file mode 100644 index cb46b06dc..000000000 --- a/crates/client/rpc/src/macros.rs +++ /dev/null @@ -1,21 +0,0 @@ -#[macro_export] -macro_rules! merge_rpc_versions { - ($rpc_api:expr, $starknet:expr, $read:expr, $write:expr, $trace:expr, $internal:expr, $($version:ident),+ $(,)?) => { - $( - paste::paste! { - if $read { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $write { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $trace { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - if $internal { - $rpc_api.merge(versions::[<$version>]::[]::into_rpc($starknet.clone()))?; - } - } - )+ - }; -} diff --git a/crates/client/rpc/src/versions/v0_7_1/api.rs b/crates/client/rpc/src/versions/v0_7_1/api.rs index 40dcb202c..2c5c2aa45 100644 --- a/crates/client/rpc/src/versions/v0_7_1/api.rs +++ b/crates/client/rpc/src/versions/v0_7_1/api.rs @@ -1,5 +1,4 @@ use jsonrpsee::core::RpcResult; -use jsonrpsee::proc_macros::rpc; use starknet_core::types::{ BlockHashAndNumber, BlockId, BroadcastedDeclareTransaction, BroadcastedDeployAccountTransaction, BroadcastedInvokeTransaction, BroadcastedTransaction, ContractClass, DeclareTransactionResult, diff --git a/crates/node/src/cli/rpc.rs b/crates/node/src/cli/rpc.rs index 48d18d8a5..63ded0ea5 100644 --- a/crates/node/src/cli/rpc.rs +++ b/crates/node/src/cli/rpc.rs @@ -4,7 +4,6 @@ use std::num::NonZeroU32; use std::str::FromStr; use clap::ValueEnum; -use ip_network::IpNetwork; use jsonrpsee::server::BatchRequestConfig; /// Available RPC methods. @@ -99,21 +98,6 @@ pub struct RpcParams { #[arg(env = "MADARA_RPC_RATE_LIMIT", long)] pub rpc_rate_limit: Option, - /// Disable RPC rate limiting for certain ip addresses or ranges. - /// - /// Each IP address must be in the following notation: `1.2.3.4/24`. - #[arg(env = "MADARA_RPC_RATE_LIMIT_WHITELISTED_IPS", long, num_args = 1..)] - pub rpc_rate_limit_whitelisted_ips: Vec, - - /// Trust proxy headers for disable rate limiting. - /// - /// When using a reverse proxy setup, the real requester IP is usually added to the headers as `X-Real-IP` or `X-Forwarded-For`. - /// By default, the RPC server will not trust these headers. - /// - /// This is currently only useful for rate-limiting reasons. - #[arg(env = "MADARA_RPC_RATE_LIMIT_TRUST_PROXY_HEADERS", long)] - pub rpc_rate_limit_trust_proxy_headers: bool, - /// Set the maximum RPC request payload size for both HTTP and WebSockets in megabytes. #[arg(env = "MADARA_RPC_MAX_REQUEST_SIZE", long, default_value_t = RPC_DEFAULT_MAX_REQUEST_SIZE_MB)] pub rpc_max_request_size: u32, @@ -147,14 +131,25 @@ pub struct RpcParams { #[arg(env = "MADARA_RPC_MAX_BATCH_REQUEST_LEN", long, conflicts_with_all = &["rpc_disable_batch_requests"], value_name = "LEN")] pub rpc_max_batch_request_len: Option, - /// Specify browser *origins* allowed to access the HTTP & WebSocket RPC servers. + /// Specify browser *origins* allowed to access the HTTP & WebSocket RPC + /// servers. /// /// For most purposes, an origin can be thought of as just `protocol://domain`. - /// By default, only browser requests from localhost will work. + /// Default behavior depends on `rpc_external`: + /// + /// - If rpc_external is set, CORS will default to allow all incoming + /// addresses. + /// - If rpc_external is not set, CORS will default to allow only + /// connections from `localhost`. /// - /// This argument is a comma separated list of origins, or the special `all` value. + /// > If the rpcs are permissive, the same will be true for core, and + /// > vise-versa. /// - /// Learn more about CORS and web security at . + /// This argument is a comma separated list of origins, or the special `all` + /// value. + /// + /// Learn more about CORS and web security at + /// . #[arg(env = "MADARA_RPC_CORS", long, value_name = "ORIGINS")] pub rpc_cors: Option, } @@ -162,12 +157,16 @@ pub struct RpcParams { impl RpcParams { pub fn cors(&self) -> Option> { let cors = self.rpc_cors.clone().unwrap_or_else(|| { - Cors::List(vec![ - "http://localhost:*".into(), - "http://127.0.0.1:*".into(), - "https://localhost:*".into(), - "https://127.0.0.1:*".into(), - ]) + if self.rpc_external { + Cors::All + } else { + Cors::List(vec![ + "http://localhost:*".into(), + "http://127.0.0.1:*".into(), + "https://localhost:*".into(), + "https://127.0.0.1:*".into(), + ]) + } }); match cors { diff --git a/crates/node/src/service/rpc.rs b/crates/node/src/service/rpc.rs index 7254a437f..3f464dde9 100644 --- a/crates/node/src/service/rpc.rs +++ b/crates/node/src/service/rpc.rs @@ -46,8 +46,8 @@ impl RpcService { (true, false) } }; - let (read, write, trace, internal) = (rpcs, rpcs, rpcs, node_operator); - let starknet = Starknet::new(Arc::clone(db.backend()), chain_config.clone(), add_txs_method_provider.clone()); + let (read, write, trace, internal, ws) = (rpcs, rpcs, rpcs, node_operator, rpcs); + let starknet = Starknet::new(Arc::clone(db.backend()), chain_config.clone(), add_txs_method_provider); let metrics = RpcMetrics::register(metrics_handle)?; Ok(Self { @@ -59,12 +59,10 @@ impl RpcService { max_payload_out_mb: config.rpc_max_response_size, max_subs_per_conn: config.rpc_max_subscriptions_per_connection, message_buffer_capacity: config.rpc_message_buffer_capacity_per_connection, - rpc_api: versioned_rpc_api(&starknet, read, write, trace, internal)?, + rpc_api: versioned_rpc_api(&starknet, read, write, trace, internal, ws)?, metrics, cors: config.cors(), rate_limit: config.rpc_rate_limit, - rate_limit_whitelisted_ips: config.rpc_rate_limit_whitelisted_ips.clone(), - rate_limit_trust_proxy_headers: config.rpc_rate_limit_trust_proxy_headers, }), server_handle: None, }) diff --git a/crates/node/src/service/rpc/middleware.rs b/crates/node/src/service/rpc/middleware.rs index 75a6a4d7e..f92f9e2f6 100644 --- a/crates/node/src/service/rpc/middleware.rs +++ b/crates/node/src/service/rpc/middleware.rs @@ -1,10 +1,7 @@ //! JSON-RPC specific middleware. -use std::future::Future; use std::num::NonZeroU32; -use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; use std::time::{Duration, Instant}; use futures::future::{BoxFuture, FutureExt}; @@ -12,128 +9,158 @@ use governor::clock::{Clock, DefaultClock, QuantaClock}; use governor::middleware::NoOpMiddleware; use governor::state::{InMemoryState, NotKeyed}; use governor::{Jitter, Quota, RateLimiter}; -use hyper::{Body, Response}; use jsonrpsee::server::middleware::rpc::RpcServiceT; -use jsonrpsee::types::{ErrorObject, Request}; -use jsonrpsee::MethodResponse; -use serde_json::{json, Value}; -use tower::{Layer, Service}; -use mp_chain_config::{RpcVersion, RpcVersionError}; +use mp_chain_config::RpcVersion; -pub use super::metrics::{Metrics, RpcMetrics}; +pub use super::metrics::Metrics; /// Rate limit middleware #[derive(Debug, Clone)] pub struct RateLimit { - pub(crate) inner: Arc>, + pub(crate) limiter: Arc>, pub(crate) clock: QuantaClock, } impl RateLimit { pub fn new(max_burst: NonZeroU32) -> Self { let clock = QuantaClock::default(); - Self { inner: Arc::new(RateLimiter::direct_with_clock(Quota::per_minute(max_burst), &clock)), clock } + Self { limiter: Arc::new(RateLimiter::direct_with_clock(Quota::per_minute(max_burst), &clock)), clock } } } const MAX_JITTER: Duration = Duration::from_millis(50); const MAX_RETRIES: usize = 10; -#[derive(Debug, Clone, Default)] -pub struct MiddlewareLayer { - rate_limit: Option, - metrics: Option, +#[derive(Debug, Clone)] +pub struct RpcMiddlewareLayerRateLimit { + rate_limit: RateLimit, } -impl MiddlewareLayer { - pub fn new() -> Self { - Self::default() +impl RpcMiddlewareLayerRateLimit { + pub fn new(n: NonZeroU32) -> Self { + Self { rate_limit: RateLimit::new(n) } } +} + +impl tower::Layer for RpcMiddlewareLayerRateLimit { + type Service = RpcMiddlewareServiceRateLimit; + + fn layer(&self, inner: S) -> Self::Service { + RpcMiddlewareServiceRateLimit { inner, rate_limit: self.rate_limit.clone() } + } +} + +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceRateLimit { + inner: S, + rate_limit: RateLimit, +} + +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceRateLimit +where + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, +{ + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; + + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); + let rate_limit = self.rate_limit.clone(); + + async move { + let mut attempts = 0; + let jitter = Jitter::up_to(MAX_JITTER); + let mut rate_limited = false; + + loop { + if attempts >= MAX_RETRIES { + return jsonrpsee::MethodResponse::error( + req.id, + jsonrpsee::types::ErrorObject::owned(-32099, "RPC rate limit exceeded", None::<()>), + ); + } + + if let Err(rejected) = rate_limit.limiter.check() { + tokio::time::sleep(jitter + rejected.wait_time_from(rate_limit.clock.now())).await; + rate_limited = true; + } else { + break; + } + + attempts += 1; + } - /// Enable new rate limit middleware enforced per minute. - pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self { - Self { rate_limit: Some(RateLimit::new(n)), metrics: self.metrics } + // This should be ok as a way to flag rate limited requests as the + // JSON RPC spec discourages the use of NULL as an id in a _request_ + // since it is used for _responses_ with an unknown id. + if rate_limited { + req.id = jsonrpsee::types::Id::Null; + } + + inner.call(req).await + } + .boxed() } +} +#[derive(Debug, Clone)] +pub struct RpcMiddlewareLayerMetrics { + metrics: Metrics, +} + +impl RpcMiddlewareLayerMetrics { /// Enable metrics middleware. - pub fn with_metrics(self, metrics: Metrics) -> Self { - Self { rate_limit: self.rate_limit, metrics: Some(metrics) } + pub fn new(metrics: Metrics) -> Self { + Self { metrics } } /// Register a new websocket connection. pub fn ws_connect(&self) { - if let Some(m) = self.metrics.as_ref() { - m.ws_connect() - } + self.metrics.ws_connect() } /// Register that a websocket connection was closed. pub fn ws_disconnect(&self, now: Instant) { - if let Some(m) = self.metrics.as_ref() { - m.ws_disconnect(now) - } + self.metrics.ws_disconnect(now) } } -impl tower::Layer for MiddlewareLayer { - type Service = Middleware; +impl tower::Layer for RpcMiddlewareLayerMetrics { + type Service = RpcMiddlewareServiceMetrics; - fn layer(&self, service: S) -> Self::Service { - Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() } + fn layer(&self, inner: S) -> Self::Service { + RpcMiddlewareServiceMetrics { inner, metrics: self.metrics.clone() } } } -pub struct Middleware { - service: S, - rate_limit: Option, - metrics: Option, +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceMetrics { + inner: S, + metrics: Metrics, } -impl<'a, S> RpcServiceT<'a> for Middleware +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceMetrics where - S: Send + Sync + RpcServiceT<'a> + Clone + 'static, + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, { - type Future = BoxFuture<'a, MethodResponse>; + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; - fn call(&self, req: Request<'a>) -> Self::Future { - let now = Instant::now(); - - if let Some(m) = self.metrics.as_ref() { - m.on_call(&req) - } - - let service = self.service.clone(); - let rate_limit = self.rate_limit.clone(); + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); let metrics = self.metrics.clone(); async move { - let mut is_rate_limited = false; - - if let Some(limit) = rate_limit.as_ref() { - let mut attempts = 0; - let jitter = Jitter::up_to(MAX_JITTER); - - loop { - if attempts >= MAX_RETRIES { - return MethodResponse::error( - req.id, - ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>), - ); - } - - if let Err(rejected) = limit.inner.check() { - tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now())).await; - } else { - break; - } - - is_rate_limited = true; - attempts += 1; - } - } + let is_rate_limited = if matches!(req.id, jsonrpsee::types::params::Id::Null) { + req.id = jsonrpsee::types::params::Id::Number(1); + true + } else { + false + }; + + let now = std::time::Instant::now(); - let rp = service.call(req.clone()).await; + metrics.on_call(&req); + let rp = inner.call(req.clone()).await; let method = req.method_name(); let status = rp.as_error_code().unwrap_or(200); @@ -149,9 +176,7 @@ where "{method} {status} {res_len} - {response_time:?}", ); - if let Some(m) = metrics.as_ref() { - m.on_response(&req, &rp, is_rate_limited, now) - } + metrics.on_response(&req, &rp, is_rate_limited, now); rp } @@ -159,139 +184,60 @@ where } } -#[derive(Clone)] -pub struct VersionMiddleware { +#[derive(Debug, Clone)] +pub struct RpcMiddlewareServiceVersion { inner: S, + path: String, } -#[derive(thiserror::Error, Debug)] -enum VersionMiddlewareError { - #[error("Failed to read request body: {0}")] - BodyReadError(#[from] hyper::Error), - #[error("Failed to parse JSON: {0}")] - JsonParseError(#[from] serde_json::Error), - #[error("Invalid URL format")] - InvalidUrlFormat, - #[error("Invalid version specified")] - InvalidVersion, - #[error("Unsupported version specified")] - UnsupportedVersion, - #[error("Invalid method format. Namespace required: {0}")] - InvalidMethodFormat(String), - #[error("Missing method in RPC request")] - MissingMethod, -} - -impl From for VersionMiddlewareError { - fn from(e: RpcVersionError) -> Self { - match e { - RpcVersionError::InvalidNumber(_) => Self::InvalidVersion, - RpcVersionError::InvalidPathSupplied => Self::InvalidUrlFormat, - RpcVersionError::InvalidVersion => Self::InvalidVersion, - RpcVersionError::TooManyComponents(_) => Self::InvalidVersion, - RpcVersionError::UnsupportedVersion => Self::UnsupportedVersion, - } - } -} - -impl VersionMiddleware { - pub fn new(inner: S) -> Self { - Self { inner } - } -} - -#[derive(Clone)] -pub struct VersionMiddlewareLayer; - -impl Layer for VersionMiddlewareLayer { - type Service = VersionMiddleware; - - fn layer(&self, inner: S) -> Self::Service { - VersionMiddleware::new(inner) +impl RpcMiddlewareServiceVersion { + pub fn new(inner: S, path: String) -> Self { + Self { inner, path } } } -impl Service> for VersionMiddleware +impl<'a, S> RpcServiceT<'a> for RpcMiddlewareServiceVersion where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, + S: Send + Sync + Clone + RpcServiceT<'a> + 'static, { - type Response = S::Response; - type Error = S::Error; - type Future = Pin> + Send + 'static>>; + type Future = BoxFuture<'a, jsonrpsee::MethodResponse>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } + fn call(&self, mut req: jsonrpsee::types::Request<'a>) -> Self::Future { + let inner = self.inner.clone(); + let path = self.path.clone(); - fn call(&mut self, mut req: hyper::Request) -> Self::Future { - let mut inner = self.inner.clone(); - - Box::pin(async move { - match add_rpc_version_to_method(&mut req).await { - Ok(()) => inner.call(req).await, - Err(e) => { - let error = match e { - VersionMiddlewareError::InvalidUrlFormat => { - ErrorObject::owned(-32600, "Invalid URL format. Use /rpc/v{version}", None::<()>) - } - VersionMiddlewareError::InvalidVersion => { - ErrorObject::owned(-32600, "Invalid RPC version specified", None::<()>) - } - VersionMiddlewareError::UnsupportedVersion => { - ErrorObject::owned(-32601, "Unsupported RPC version specified", None::<()>) - } - _ => ErrorObject::owned(-32603, "Internal error", None::<()>), - }; - - let body = json!({ - "jsonrpc": "2.0", - "error": error, - "id": null - }) - .to_string(); - - Ok(Response::builder() - .header("Content-Type", "application/json") - .body(Body::from(body)) - .unwrap_or_else(|_| Response::new(Body::from("Internal server error")))) - } + async move { + if req.method == "rpc_methods" { + return inner.call(req).await; } - }) - } -} -async fn add_rpc_version_to_method(req: &mut hyper::Request) -> Result<(), VersionMiddlewareError> { - let path = req.uri().path().to_string(); - let version = RpcVersion::from_request_path(&path)?; - - let whole_body = hyper::body::to_bytes(req.body_mut()).await?; - let json: Value = serde_json::from_slice(&whole_body)?; - - // in case of batched requests, the request is an array of JSON-RPC requests - let mut batched_request = false; - let mut items = if let Value::Array(items) = json { - batched_request = true; - items - } else { - vec![json] - }; - - for item in items.iter_mut() { - if let Some(method) = item.get_mut("method").as_deref().and_then(Value::as_str) { - let new_method = if let Some((prefix, suffix)) = method.split_once('_') { - format!("{}_{}_{}", prefix, version.name(), suffix) - } else { - return Err(VersionMiddlewareError::InvalidMethodFormat(method.to_string())); + let Ok(version) = RpcVersion::from_request_path(&path) else { + return jsonrpsee::MethodResponse::error( + req.id, + jsonrpsee::types::ErrorObject::owned( + jsonrpsee::types::error::PARSE_ERROR_CODE, + jsonrpsee::types::error::PARSE_ERROR_MSG, + None::<()>, + ), + ); }; - item["method"] = Value::String(new_method); - } else { - return Err(VersionMiddlewareError::MissingMethod); - } - } - let response = if batched_request { serde_json::to_vec(&items)? } else { serde_json::to_vec(&items[0])? }; - *req.body_mut() = Body::from(response); + let Some(method_without_namespace) = req.method.strip_prefix("starknet_") else { + return jsonrpsee::MethodResponse::error( + req.id(), + jsonrpsee::types::ErrorObject::owned( + jsonrpsee::types::error::METHOD_NOT_FOUND_CODE, + jsonrpsee::types::error::METHOD_NOT_FOUND_MSG, + Some(req.method_name()), + ), + ); + }; - Ok(()) + let method_new = format!("starknet_{}_{}", version.name(), method_without_namespace); + req.method = jsonrpsee::core::Cow::from(method_new); + + inner.call(req).await + } + .boxed() + } } diff --git a/crates/node/src/service/rpc/server.rs b/crates/node/src/service/rpc/server.rs index 379244203..4ab03657a 100644 --- a/crates/node/src/service/rpc/server.rs +++ b/crates/node/src/service/rpc/server.rs @@ -2,31 +2,20 @@ #![allow(clippy::borrow_interior_mutable_const)] use std::convert::Infallible; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use std::num::NonZeroU32; -use std::str::FromStr; use std::time::Duration; use anyhow::Context; -use forwarded_header_value::ForwardedHeaderValue; -use hyper::header::{HeaderName, HeaderValue}; -use hyper::server::conn::AddrStream; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, StatusCode}; -use ip_network::IpNetwork; -use jsonrpsee::core::id_providers::RandomStringIdProvider; -use jsonrpsee::server::middleware::http::HostFilterLayer; -use jsonrpsee::server::middleware::rpc::RpcServiceBuilder; -use jsonrpsee::server::{stop_channel, ws, BatchRequestConfig, PingConfig, StopHandle, TowerServiceBuilder}; -use jsonrpsee::{Methods, RpcModule}; -use tokio::net::TcpListener; use tokio::task::JoinSet; use tower::Service; -use tower_http::cors::{AllowOrigin, CorsLayer}; use mp_utils::wait_or_graceful_shutdown; -use super::middleware::{Metrics, MiddlewareLayer, RpcMetrics, VersionMiddlewareLayer}; +use crate::service::rpc::middleware::{RpcMiddlewareLayerRateLimit, RpcMiddlewareServiceVersion}; + +use super::metrics::RpcMetrics; +use super::middleware::{Metrics, RpcMiddlewareLayerMetrics}; const MEGABYTE: u32 = 1024 * 1024; @@ -41,23 +30,19 @@ pub struct ServerConfig { pub max_payload_out_mb: u32, pub metrics: RpcMetrics, pub message_buffer_capacity: u32, - pub rpc_api: RpcModule<()>, + pub rpc_api: jsonrpsee::RpcModule<()>, /// Batch request config. - pub batch_config: BatchRequestConfig, + pub batch_config: jsonrpsee::server::BatchRequestConfig, /// Rate limit calls per minute. pub rate_limit: Option, - /// Disable rate limit for certain ips. - pub rate_limit_whitelisted_ips: Vec, - /// Trust proxy headers for rate limiting. - pub rate_limit_trust_proxy_headers: bool, } #[derive(Debug, Clone)] struct PerConnection { - methods: Methods, - stop_handle: StopHandle, + methods: jsonrpsee::Methods, + stop_handle: jsonrpsee::server::StopHandle, metrics: RpcMetrics, - service_builder: TowerServiceBuilder, + service_builder: jsonrpsee::server::TowerServiceBuilder, } /// Start RPC server listening on given address. @@ -77,102 +62,77 @@ pub async fn start_server( message_buffer_capacity, rpc_api, rate_limit, - rate_limit_whitelisted_ips, - rate_limit_trust_proxy_headers, } = config; - let std_listener = TcpListener::bind(addr) + let listener = tokio::net::TcpListener::bind(addr) .await - .and_then(|a| a.into_std()) - .with_context(|| format!("binding to address: {addr}"))?; - let local_addr = std_listener.local_addr().ok(); - let host_filter = host_filtering(cors.is_some(), local_addr); + .with_context(|| format!("Binding TCP listener to address: {addr}"))?; + let local_addr = listener.local_addr().context("Failed to retrieve local address after binding TCP listener")?; + + let ping_config = jsonrpsee::server::PingConfig::new() + .ping_interval(Duration::from_secs(30)) + .inactive_limit(Duration::from_secs(60)) + .max_failures(3); let http_middleware = tower::ServiceBuilder::new() - .option_layer(host_filter) - // Proxy `GET /health` requests to internal `system_health` method. - // .layer(ProxyGetRequestLayer::new("/health", "system_health")?) - .layer(VersionMiddlewareLayer) - .layer(try_into_cors(cors.as_ref())?); + .option_layer(host_filtering(cors.is_some(), local_addr)) + .layer(try_into_cors(cors.as_ref())?); let builder = jsonrpsee::server::Server::builder() .max_request_body_size(max_payload_in_mb.saturating_mul(MEGABYTE)) .max_response_body_size(max_payload_out_mb.saturating_mul(MEGABYTE)) .max_connections(max_connections) .max_subscriptions_per_connection(max_subs_per_conn) - .enable_ws_ping( - PingConfig::new() - .ping_interval(Duration::from_secs(30)) - .inactive_limit(Duration::from_secs(60)) - .max_failures(3), - ) - .set_http_middleware(http_middleware) + .enable_ws_ping(ping_config) .set_message_buffer_capacity(message_buffer_capacity) .set_batch_request_config(batch_config) - .set_id_provider(RandomStringIdProvider::new(16)); + .set_http_middleware(http_middleware) + .set_id_provider(jsonrpsee::server::RandomStringIdProvider::new(16)); - let (stop_handle, server_handle) = stop_channel(); + let (stop_handle, server_handle) = jsonrpsee::server::stop_channel(); let cfg = PerConnection { methods: build_rpc_api(rpc_api).into(), - service_builder: builder.to_service_builder(), - metrics, stop_handle: stop_handle.clone(), + metrics, + service_builder: builder.to_service_builder(), }; - let make_service = make_service_fn(move |addr: &AddrStream| { + let make_service = hyper::service::make_service_fn(move |_| { let cfg = cfg.clone(); - let rate_limit_whitelisted_ips = rate_limit_whitelisted_ips.clone(); - let ip = addr.remote_addr().ip(); async move { let cfg = cfg.clone(); - let rate_limit_whitelisted_ips = rate_limit_whitelisted_ips.clone(); - - Ok::<_, Infallible>(service_fn(move |req| { - let proxy_ip = if rate_limit_trust_proxy_headers { get_proxy_ip(&req) } else { None }; - - let rate_limit_cfg = if rate_limit_whitelisted_ips - .iter() - .any(|ips| ips.contains(proxy_ip.unwrap_or(ip))) - { - log::debug!(target: "rpc", "ip={ip}, proxy_ip={:?} is trusted, disabling rate-limit", proxy_ip); - None - } else { - if !rate_limit_whitelisted_ips.is_empty() { - log::debug!(target: "rpc", "ip={ip}, proxy_ip={:?} is not trusted, rate-limit enabled", proxy_ip); - } - rate_limit - }; + Ok::<_, Infallible>(hyper::service::service_fn(move |req| { let PerConnection { service_builder, metrics, stop_handle, methods } = cfg.clone(); - let is_websocket = ws::is_upgrade_request(&req); + let is_websocket = jsonrpsee::server::ws::is_upgrade_request(&req); let transport_label = if is_websocket { "ws" } else { "http" }; + let path = req.uri().path().to_string(); + let metrics_layer = RpcMiddlewareLayerMetrics::new(Metrics::new(metrics, transport_label)); - let middleware_layer = match rate_limit_cfg { - None => MiddlewareLayer::new().with_metrics(Metrics::new(metrics, transport_label)), - Some(rate_limit) => MiddlewareLayer::new() - .with_metrics(Metrics::new(metrics, transport_label)) - .with_rate_limit_per_minute(rate_limit), - }; - - let rpc_middleware = RpcServiceBuilder::new().layer(middleware_layer.clone()); + let rpc_middleware = jsonrpsee::server::RpcServiceBuilder::new() + .layer_fn(move |service| RpcMiddlewareServiceVersion::new(service, path.clone())) + .option_layer(rate_limit.map(RpcMiddlewareLayerRateLimit::new)) + .layer(metrics_layer.clone()); let mut svc = service_builder.set_rpc_middleware(rpc_middleware).build(methods, stop_handle); async move { if req.uri().path() == "/health" { - Ok(Response::builder().status(StatusCode::OK).body(Body::from("OK"))?) + Ok(hyper::Response::builder().status(hyper::StatusCode::OK).body(hyper::Body::from("OK"))?) } else { if is_websocket { + // Utilize the session close future to know when the actual WebSocket + // session was closed. let on_disconnect = svc.on_session_closed(); // Spawn a task to handle when the connection is closed. tokio::spawn(async move { let now = std::time::Instant::now(); - middleware_layer.ws_connect(); + metrics_layer.ws_connect(); on_disconnect.await; - middleware_layer.ws_disconnect(now); + metrics_layer.ws_disconnect(now); }); } @@ -183,16 +143,12 @@ pub async fn start_server( } }); - let server = hyper::Server::from_tcp(std_listener) + let server = hyper::Server::from_tcp(listener.into_std()?) .with_context(|| format!("Creating hyper server at: {addr}"))? .serve(make_service); join_set.spawn(async move { - log::info!( - "📱 Running JSON-RPC server at {} (allowed origins={})", - local_addr.map_or_else(|| "unknown".to_string(), |a| a.to_string()), - format_cors(cors.as_ref()) - ); + log::info!("📱 Running JSON-RPC server at {} (allowed origins={})", local_addr, format_cors(cors.as_ref())); server .with_graceful_shutdown(async { wait_or_graceful_shutdown(stop_handle.shutdown()).await; @@ -204,31 +160,49 @@ pub async fn start_server( Ok(server_handle) } -const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for"); -const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip"); -const FORWARDED: HeaderName = HeaderName::from_static("forwarded"); - -pub(crate) fn host_filtering(enabled: bool, addr: Option) -> Option { - // If the local_addr failed, fallback to wildcard. - let port = addr.map_or("*".to_string(), |p| p.port().to_string()); - +// Copied from https://github.com/paritytech/polkadot-sdk/blob/a0aefc6b233ace0a82a8631d67b6854e6aeb014b/substrate/client/rpc-servers/src/utils.rs#L192 +pub(crate) fn host_filtering( + enabled: bool, + addr: SocketAddr, +) -> Option { if enabled { // NOTE: The listening addresses are whitelisted by default. - let hosts = [format!("localhost:{port}"), format!("127.0.0.1:{port}"), format!("[::1]:{port}")]; - Some(HostFilterLayer::new(hosts).expect("Invalid host filter")) + + let mut hosts = Vec::new(); + + if addr.is_ipv4() { + hosts.push(format!("localhost:{}", addr.port())); + hosts.push(format!("127.0.0.1:{}", addr.port())); + } else { + hosts.push(format!("[::1]:{}", addr.port())); + } + + Some(jsonrpsee::server::middleware::http::HostFilterLayer::new(hosts).expect("Valid hosts; qed")) } else { None } } -pub(crate) fn build_rpc_api(mut rpc_api: RpcModule) -> RpcModule { - let mut available_methods = rpc_api.method_names().collect::>(); +pub(crate) fn build_rpc_api(mut rpc_api: jsonrpsee::RpcModule) -> jsonrpsee::RpcModule { + let mut available_methods = rpc_api + .method_names() + .map(|name| { + let mut split = name.split("_"); + let namespace = split.next().expect("Should not be empty"); + let major = split.next().expect("Should not be empty"); + let minor = split.next().expect("Should not be empty"); + let patch = split.next().expect("Should not be empty"); + let method = split.next().expect("Should not be empty"); + + format!("rpc/{major}_{minor}_{patch}/{namespace}_{method}") + }) + .collect::>(); // The "rpc_methods" is defined below and we want it to be part of the reported methods. // The available methods will be prefixed by their version, example: - // * starknet_V0_7_1_blockNumber, - // * starknet_V0_8_0_blockNumber (...) - available_methods.push("rpc_methods"); + // * rpc/v0_7_1/starknet_blockNumber, + // * rpc/v0_8_0/starknet_blockNumber (...) + available_methods.push("rpc/rpc_methods".to_string()); available_methods.sort(); rpc_api @@ -242,16 +216,15 @@ pub(crate) fn build_rpc_api(mut rpc_api: RpcModule) rpc_api } -pub(crate) fn try_into_cors(maybe_cors: Option<&Vec>) -> anyhow::Result { +pub(crate) fn try_into_cors(maybe_cors: Option<&Vec>) -> anyhow::Result { if let Some(cors) = maybe_cors { let mut list = Vec::new(); for origin in cors { - list.push(HeaderValue::from_str(origin)?); + list.push(hyper::header::HeaderValue::from_str(origin)?); } - Ok(CorsLayer::new().allow_origin(AllowOrigin::list(list))) + Ok(tower_http::cors::CorsLayer::new().allow_origin(tower_http::cors::AllowOrigin::list(list))) } else { - // allow all cors - Ok(CorsLayer::permissive()) + Ok(tower_http::cors::CorsLayer::permissive()) } } @@ -262,38 +235,3 @@ pub(crate) fn format_cors(maybe_cors: Option<&Vec>) -> String { format!("{:?}", ["*"]) } } - -/// Extracts the IP addr from the HTTP request. -/// -/// It is extracted in the following order: -/// 1. `Forwarded` header. -/// 2. `X-Forwarded-For` header. -/// 3. `X-Real-Ip`. -pub(crate) fn get_proxy_ip(req: &Request) -> Option { - if let Some(ip) = req - .headers() - .get(&FORWARDED) - .and_then(|v| v.to_str().ok()) - .and_then(|v| ForwardedHeaderValue::from_forwarded(v).ok()) - .and_then(|v| v.remotest_forwarded_for_ip()) - { - return Some(ip); - } - - if let Some(ip) = req - .headers() - .get(&X_FORWARDED_FOR) - .and_then(|v| v.to_str().ok()) - .and_then(|v| ForwardedHeaderValue::from_x_forwarded_for(v).ok()) - .and_then(|v| v.remotest_forwarded_for_ip()) - { - return Some(ip); - } - - if let Some(ip) = req.headers().get(&X_REAL_IP).and_then(|v| v.to_str().ok()).and_then(|v| IpAddr::from_str(v).ok()) - { - return Some(ip); - } - - None -} diff --git a/crates/primitives/chain_config/Cargo.toml b/crates/primitives/chain_config/Cargo.toml index e42898818..0df76791b 100644 --- a/crates/primitives/chain_config/Cargo.toml +++ b/crates/primitives/chain_config/Cargo.toml @@ -26,6 +26,7 @@ mp-utils.workspace = true # Other anyhow.workspace = true lazy_static.workspace = true +log.workspace = true primitive-types.workspace = true serde = { workspace = true, features = ["derive"] } serde_json.workspace = true diff --git a/crates/primitives/chain_config/src/rpc_version.rs b/crates/primitives/chain_config/src/rpc_version.rs index f60bc51dd..f86c21d65 100644 --- a/crates/primitives/chain_config/src/rpc_version.rs +++ b/crates/primitives/chain_config/src/rpc_version.rs @@ -4,6 +4,7 @@ use std::str::FromStr; lazy_static::lazy_static! { pub static ref SUPPORTED_RPC_VERSIONS: Vec = vec![ RpcVersion::RPC_VERSION_0_7_1, + RpcVersion::RPC_VERSION_0_8_0, ]; } @@ -30,28 +31,38 @@ impl RpcVersion { } pub fn from_request_path(path: &str) -> Result { + log::debug!(target: "rpc_version", "extracting rpc version from request: {path}"); + let path = path.to_ascii_lowercase(); let parts: Vec<&str> = path.split('/').collect(); + log::debug!(target: "rpc_version", "version parts are: {parts:?}"); + // If we have an empty path or just "/", fallback to latest rpc version if parts.len() == 1 || (parts.len() == 2 && parts[1].is_empty()) { + log::debug!(target: "rpc_version", "no version, defaulting to latest"); return Ok(Self::RPC_VERSION_LATEST); } // Check if the path follows the correct format, i.e. /rpc/v[version]. // If not, fallback to the latest version if parts.len() != 3 || parts[1] != "rpc" || !parts[2].starts_with('v') { + log::debug!(target: "rpc_version", "invalid version format, defaulting to latest"); return Ok(Self::RPC_VERSION_LATEST); } + log::debug!(target: "rpc_version", "looking for matching version..."); let version_str = &parts[2][1..]; // without the 'v' prefix if let Ok(version) = RpcVersion::from_str(version_str) { if SUPPORTED_RPC_VERSIONS.contains(&version) { + log::debug!(target: "rpc_version", "found matching version: {version}"); Ok(version) } else { + log::debug!(target: "rpc_version", "no matching version"); Err(RpcVersionError::UnsupportedVersion) } } else { + log::debug!(target: "rpc_version", "invalid version format: {version_str}"); Err(RpcVersionError::InvalidVersion) } } @@ -69,6 +80,7 @@ impl RpcVersion { } pub const RPC_VERSION_0_7_1: RpcVersion = RpcVersion([0, 7, 1]); + pub const RPC_VERSION_0_8_0: RpcVersion = RpcVersion([0, 8, 0]); pub const RPC_VERSION_LATEST: RpcVersion = Self::RPC_VERSION_0_7_1; } diff --git a/crates/proc-macros/src/lib.rs b/crates/proc-macros/src/lib.rs index 83abd1df3..432861eb9 100644 --- a/crates/proc-macros/src/lib.rs +++ b/crates/proc-macros/src/lib.rs @@ -3,7 +3,7 @@ //! This macro is a wrapper around the "rpc" macro supplied by the jsonrpsee library that generates //! a server and client traits from a given trait definition. The wrapper gets a version id and -//! prepend the version id to the trait name and to every method name (note method name refers to +//! prepends the version id to the trait name and to every method name (note method name refers to //! the name the API has for the function not the actual function name). We need this in order to be //! able to merge multiple versions of jsonrpc APIs into one server and not have a clash in method //! resolution. @@ -12,9 +12,9 @@ //! //! Given this code: //! ```rust,ignore -//! #[versioned_starknet_rpc("V0_7_1")] +//! #[versioned_rpc("V0_7_1", "starknet")] //! pub trait JsonRpc { -//! #[method(name = "blockNumber")] +//! #[method(name = "blockNumber", aliases = ["block_number"])] //! fn block_number(&self) -> anyhow::Result; //! } //! ``` @@ -23,15 +23,19 @@ //! ```rust,ignore //! #[rpc(server, namespace = "starknet")] //! pub trait JsonRpcV0_7_1 { -//! #[method(name = "V0_7_1_blockNumber")] +//! #[method(name = "V0_7_1_blockNumber", aliases = ["block_number"])] //! fn block_number(&self) -> anyhow::Result; //! } //! ``` +//! +//! > [!NOTE] +//! > This macro _will not_ override any other jsonrpsee attribute, meaning +//! > it does not currently support renaming `aliases` or `unsubscribe_aliases` use proc_macro::TokenStream; use proc_macro2::Span; use quote::quote; -use syn::{parse::Parse, parse_macro_input, Attribute, Ident, ItemTrait, LitStr, TraitItem}; +use syn::spanned::Spanned; #[derive(Debug)] struct VersionedRpcAttr { @@ -39,11 +43,11 @@ struct VersionedRpcAttr { namespace: String, } -impl Parse for VersionedRpcAttr { +impl syn::parse::Parse for VersionedRpcAttr { fn parse(input: syn::parse::ParseStream) -> syn::Result { - let version = input.parse::()?.value(); + let version = input.parse::()?.value(); input.parse::()?; - let namespace = input.parse::()?.value(); + let namespace = input.parse::()?.value(); if !version.starts_with('V') { return Err(syn::Error::new(Span::call_site(), "Version must start with 'V'")); @@ -68,11 +72,12 @@ impl Parse for VersionedRpcAttr { return Err(syn::Error::new( Span::call_site(), indoc::indoc!( - " + r#" Namespace cannot be empty. Please provide a non-empty namespace string. - Example: #[versioned_rpc(\"V0_7_1\", \"starknet\")] - " + + ex: #[versioned_rpc("V0_7_1", "starknet")] + "# ), )); } @@ -81,52 +86,129 @@ impl Parse for VersionedRpcAttr { } } -fn version_method_name(attr: &Attribute, version: &str) -> syn::Result { - let mut new_attr = attr.clone(); - attr.parse_nested_meta(|meta| { - if meta.path.is_ident("name") { - let value = meta.value()?; - let method_name: LitStr = value.parse()?; - let new_name = format!("{version}_{}", method_name.value()); - new_attr.meta = syn::parse_quote!(method(name = #new_name)); - } - Ok(()) - })?; - Ok(new_attr) +enum CallType { + Method, + Subscribe, } #[proc_macro_attribute] pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream { - let VersionedRpcAttr { version, namespace } = parse_macro_input!(attr as VersionedRpcAttr); - let mut item_trait = parse_macro_input!(input as ItemTrait); + let VersionedRpcAttr { version, namespace } = syn::parse_macro_input!(attr as VersionedRpcAttr); + let mut item_trait = syn::parse_macro_input!(input as syn::ItemTrait); let trait_name = &item_trait.ident; - let versioned_trait_name = Ident::new(&format!("{trait_name}{version}"), trait_name.span()); - - for item in &mut item_trait.items { - if let TraitItem::Fn(method) = item { - method.attrs = method - .attrs - .iter() - .filter_map(|attr| { - if attr.path().is_ident("method") { - version_method_name(attr, &version).ok() - } else { - Some(attr.clone()) + let train_name_with_version = syn::Ident::new(&format!("{trait_name}{version}"), trait_name.span()); + + // This next section is reponsible for versioning the method name declared + // with jsonrpsee + let err = item_trait.items.iter_mut().try_fold((), |_, item| { + let syn::TraitItem::Fn(method) = item else { + return Err(syn::Error::new( + item.span(), + indoc::indoc! {r#" + Traits marked with `versioned_rpc` can only contain methods + + ex: + + #[versioned_rpc("V0_7_0", "starknet")] + trait MyTrait { + #[method(name = "foo", blocking)] + fn foo(); + } + "#}, + )); + }; + + method.attrs.iter_mut().try_fold((), |_, attr| { + // We leave these errors to be handled by jsonrpsee + let path = attr.path(); + let ident = if path.is_ident("method") { + CallType::Method + } else if path.is_ident("subscription") { + CallType::Subscribe + } else { + return Ok(()); + }; + + let syn::Meta::List(meta_list) = &attr.meta else { + return Ok(()); + }; + + // This convoluted section is just the way by which we traverse + // the macro attribute list. We are looking for: + // + // - An assignment + // - With lvalue a Path expression with literal value `name` or + // 'unsubscribe' + // - With rvalue a literal + // + // Any other attribute is skipped over and is not overwritten + let attr_args = meta_list + .parse_args_with(syn::punctuated::Punctuated::::parse_terminated) + .map_err(|_| { + syn::Error::new( + meta_list.span(), + indoc::indoc! {r#" + The `method` and `subscription` attributes expect comma-separated values. + + ex: `#[method(name = "foo", blocking)]` + "#}, + ) + })? + .into_iter() + .map(|expr| { + // There isn't really a more elegant way of doing this as + // `left` and `right` are boxed values and therefore cannot + // be pattern matched without being de-referenced + let syn::Expr::Assign(expr) = expr else { return expr }; + + let syn::Expr::Path(syn::ExprPath { path, .. }) = *expr.left.clone() else { + return syn::Expr::Assign(expr); + }; + let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(right), attrs }) = *expr.right.clone() else { + return syn::Expr::Assign(expr); + }; + + if !path.is_ident("name") && !path.is_ident("unsubscribe") { + return syn::Expr::Assign(expr); } + + let method_with_version = format!("{version}_{}", right.value()); + syn::Expr::Assign(syn::ExprAssign { + right: Box::new(syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(syn::LitStr::new(&method_with_version, right.span())), + attrs, + })), + ..expr + }) }) - .collect(); - } + .collect::>(); + + // This is the part where we actually replace the attribute with + // its versioned alternative. Note that the syntax #(#foo),* + // indicates a pattern repetition here, where all the elements in + // attr_args are expanded into rust code + attr.meta = match ident { + CallType::Method => syn::parse_quote!(method(#(#attr_args),*)), + CallType::Subscribe => syn::parse_quote!(subscription(#(#attr_args),*)), + }; + + Ok(()) + }) + }); + + if let Err(e) = err { + return e.into_compile_error().into(); } - let versioned_trait = ItemTrait { - attrs: vec![syn::parse_quote!(#[rpc(server, namespace = #namespace)])], - ident: versioned_trait_name, + let trait_with_version = syn::ItemTrait { + attrs: vec![syn::parse_quote!(#[jsonrpsee::proc_macros::rpc(server, namespace = #namespace)])], + ident: train_name_with_version, ..item_trait }; quote! { - #versioned_trait + #trait_with_version } .into() } @@ -134,7 +216,7 @@ pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream { #[cfg(test)] mod tests { use super::*; - use quote::{quote, ToTokens}; + use quote::quote; use syn::parse_quote; #[test] @@ -165,19 +247,13 @@ mod tests { assert_eq!( result.unwrap_err().to_string(), indoc::indoc!( - " + r#" Namespace cannot be empty. Please provide a non-empty namespace string. - Example: #[versioned_rpc(\"V0_7_1\", \"starknet\")] - " + + ex: #[versioned_rpc("V0_7_1", "starknet")] + "# ) ); } - - #[test] - fn test_version_method_name() { - let attr: Attribute = parse_quote!(#[method(name = "blockNumber")]); - let result = version_method_name(&attr, "V0_7_1").unwrap(); - assert_eq!(result.to_token_stream().to_string(), "# [method (name = \"V0_7_1_blockNumber\")]"); - } }