From 761195765c1a4cfbbcb5b56c6407302afe846b57 Mon Sep 17 00:00:00 2001 From: ICavlek Date: Mon, 26 Aug 2024 17:51:33 +0200 Subject: [PATCH] chore: comments update --- src/bin/beerus.rs | 19 +++++++++++++------ src/config.rs | 39 +++++++++------------------------------ 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/src/bin/beerus.rs b/src/bin/beerus.rs index 6c26e943e..c37e90993 100644 --- a/src/bin/beerus.rs +++ b/src/bin/beerus.rs @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration}; use beerus::config::Config; use clap::Parser; use tokio::sync::RwLock; +use validator::Validate; const RPC_SPEC_VERSION: &str = "0.6.0"; @@ -11,8 +12,7 @@ async fn main() -> eyre::Result<()> { tracing_subscriber::fmt::init(); let args = Args::parse(); - let config = get_config(&args)?; - config.check(args.skip_chain_id_validation).await?; + let config = get_config(&args).await?; let beerus = beerus::client::Client::new(&config).await?; beerus.start().await?; @@ -66,10 +66,17 @@ struct Args { skip_chain_id_validation: bool, } -fn get_config(args: &Args) -> eyre::Result { - Ok(if let Some(path) = args.config.as_ref() { +async fn get_config(args: &Args) -> eyre::Result { + let config = if let Some(path) = args.config.as_ref() { Config::from_file(path)? } else { - Config::from_env()? - }) + Config::from_env() + }; + config.validate()?; + if args.skip_chain_id_validation { + tracing::warn!("Skipping chain id validation"); + return Ok(config); + } + config.check().await?; + Ok(config) } diff --git a/src/config.rs b/src/config.rs index 688a90c5f..4229305f3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -48,8 +48,8 @@ fn default_rpc_addr() -> SocketAddr { } impl Config { - pub fn from_env() -> Result { - Ok(Self { + pub fn from_env() -> Self { + Self { network: Network::from_str( &std::env::var("NETWORK").unwrap_or_default(), ) @@ -68,7 +68,7 @@ impl Config { .ok() .and_then(|rpc_addr| rpc_addr.parse::().ok()) .unwrap_or_else(default_rpc_addr), - }) + } } pub fn from_file(path: &str) -> Result { @@ -86,9 +86,7 @@ impl Config { } } - pub async fn check(&self, skip_chain_id_validation: bool) -> Result<()> { - self.validate()?; - + pub async fn check(&self) -> Result<()> { let expected_chain_id = match self.network { Network::MAINNET => MAINNET_ETHEREUM_CHAINID, Network::SEPOLIA => SEPOLIA_ETHEREUM_CHAINID, @@ -102,7 +100,6 @@ impl Config { expected_chain_id, &self.eth_execution_rpc, "eth_chainId", - skip_chain_id_validation, ) .await?; @@ -119,7 +116,6 @@ impl Config { expected_chain_id, &self.starknet_rpc, "starknet_chainId", - skip_chain_id_validation, ) .await?; @@ -146,17 +142,12 @@ async fn check_chain_id( expected_chain_id: &str, url: &str, method: &str, - skip_chain_id_validation: bool, ) -> Result<()> { let chain_id = call_method(url, method).await?; - let message = format!( - "Invalid chain id: expected {expected_chain_id} but got {chain_id}" - ); if chain_id != expected_chain_id { - if !skip_chain_id_validation { - eyre::bail!(message); - } - tracing::warn!(message); + eyre::bail!(format!( + "Invalid chain id: expected {expected_chain_id} but got {chain_id}" + )); } Ok(()) } @@ -207,8 +198,7 @@ mod tests { poll_secs: 300, rpc_addr: SocketAddr::from(([0, 0, 0, 0], 3030)), }; - let skip_chain_id_validation = false; - let response = config.check(skip_chain_id_validation).await; + let response = config.validate(); assert!(response.is_err()); assert!(response @@ -221,13 +211,11 @@ mod tests { async fn correct_eth_url() { let response = serde_json::json!({"jsonrpc":"2.0","id":0,"result":MAINNET_ETHEREUM_CHAINID}); let server = setup_server_with_response(response).await; - let skip_chain_id_validation = false; let result = check_chain_id( MAINNET_ETHEREUM_CHAINID, &server.uri(), "eth_chainId", - skip_chain_id_validation, ) .await; assert!(result.is_ok()); @@ -238,13 +226,11 @@ mod tests { let response = serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"}); let server = setup_server_with_response(response).await; - let skip_chain_id_validation = false; let result = check_chain_id( MAINNET_ETHEREUM_CHAINID, &server.uri(), "eth_chainId", - skip_chain_id_validation, ) .await; @@ -260,13 +246,11 @@ mod tests { let response = serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"}); let server = setup_server_with_response(response).await; - let skip_chain_id_validation = false; let result = check_chain_id( MAINNET_STARKNET_CHAINID, &server.uri(), "eth_chainId", - skip_chain_id_validation, ) .await; @@ -282,13 +266,11 @@ mod tests { let response = serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"}); let server = setup_server_with_response(response).await; - let skip_chain_id_validation = false; let result = check_chain_id( MAINNET_STARKNET_CHAINID, &server.uri(), "starknet_chainId", - skip_chain_id_validation, ) .await; @@ -304,13 +286,11 @@ mod tests { let response = serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"}); let server = setup_server_with_response(response).await; - let skip_chain_id_validation = false; let result = check_chain_id( MAINNET_STARKNET_CHAINID, &server.uri(), "starknet_chainId", - skip_chain_id_validation, ) .await; @@ -331,8 +311,7 @@ mod tests { poll_secs: 9999, rpc_addr: SocketAddr::from(([127, 0, 0, 1], 3030)), }; - let skip_chain_id_validation = false; - let response = config.check(skip_chain_id_validation).await; + let response = config.validate(); assert!(response.is_err()); assert!(response.unwrap_err().to_string().contains("poll_secs"));