Skip to content

Commit

Permalink
chore: comments update
Browse files Browse the repository at this point in the history
  • Loading branch information
ICavlek committed Aug 26, 2024
1 parent a2c6795 commit 7611957
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 36 deletions.
19 changes: 13 additions & 6 deletions src/bin/beerus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{sync::Arc, time::Duration};
use beerus::config::Config;
use clap::Parser;
use tokio::sync::RwLock;
use validator::Validate;

const RPC_SPEC_VERSION: &str = "0.6.0";

Expand All @@ -11,8 +12,7 @@ async fn main() -> eyre::Result<()> {
tracing_subscriber::fmt::init();

let args = Args::parse();
let config = get_config(&args)?;
config.check(args.skip_chain_id_validation).await?;
let config = get_config(&args).await?;

let beerus = beerus::client::Client::new(&config).await?;
beerus.start().await?;
Expand Down Expand Up @@ -66,10 +66,17 @@ struct Args {
skip_chain_id_validation: bool,
}

fn get_config(args: &Args) -> eyre::Result<Config> {
Ok(if let Some(path) = args.config.as_ref() {
async fn get_config(args: &Args) -> eyre::Result<Config> {
let config = if let Some(path) = args.config.as_ref() {
Config::from_file(path)?
} else {
Config::from_env()?
})
Config::from_env()
};
config.validate()?;
if args.skip_chain_id_validation {
tracing::warn!("Skipping chain id validation");
return Ok(config);
}
config.check().await?;
Ok(config)
}
39 changes: 9 additions & 30 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ fn default_rpc_addr() -> SocketAddr {
}

impl Config {
pub fn from_env() -> Result<Self> {
Ok(Self {
pub fn from_env() -> Self {
Self {
network: Network::from_str(
&std::env::var("NETWORK").unwrap_or_default(),
)
Expand All @@ -68,7 +68,7 @@ impl Config {
.ok()
.and_then(|rpc_addr| rpc_addr.parse::<SocketAddr>().ok())
.unwrap_or_else(default_rpc_addr),
})
}
}

pub fn from_file(path: &str) -> Result<Self> {
Expand All @@ -86,9 +86,7 @@ impl Config {
}
}

pub async fn check(&self, skip_chain_id_validation: bool) -> Result<()> {
self.validate()?;

pub async fn check(&self) -> Result<()> {
let expected_chain_id = match self.network {
Network::MAINNET => MAINNET_ETHEREUM_CHAINID,
Network::SEPOLIA => SEPOLIA_ETHEREUM_CHAINID,
Expand All @@ -102,7 +100,6 @@ impl Config {
expected_chain_id,
&self.eth_execution_rpc,
"eth_chainId",
skip_chain_id_validation,
)
.await?;

Expand All @@ -119,7 +116,6 @@ impl Config {
expected_chain_id,
&self.starknet_rpc,
"starknet_chainId",
skip_chain_id_validation,
)
.await?;

Expand All @@ -146,17 +142,12 @@ async fn check_chain_id(
expected_chain_id: &str,
url: &str,
method: &str,
skip_chain_id_validation: bool,
) -> Result<()> {
let chain_id = call_method(url, method).await?;
let message = format!(
"Invalid chain id: expected {expected_chain_id} but got {chain_id}"
);
if chain_id != expected_chain_id {
if !skip_chain_id_validation {
eyre::bail!(message);
}
tracing::warn!(message);
eyre::bail!(format!(
"Invalid chain id: expected {expected_chain_id} but got {chain_id}"
));
}
Ok(())
}
Expand Down Expand Up @@ -207,8 +198,7 @@ mod tests {
poll_secs: 300,
rpc_addr: SocketAddr::from(([0, 0, 0, 0], 3030)),
};
let skip_chain_id_validation = false;
let response = config.check(skip_chain_id_validation).await;
let response = config.validate();

assert!(response.is_err());
assert!(response
Expand All @@ -221,13 +211,11 @@ mod tests {
async fn correct_eth_url() {
let response = serde_json::json!({"jsonrpc":"2.0","id":0,"result":MAINNET_ETHEREUM_CHAINID});
let server = setup_server_with_response(response).await;
let skip_chain_id_validation = false;

let result = check_chain_id(
MAINNET_ETHEREUM_CHAINID,
&server.uri(),
"eth_chainId",
skip_chain_id_validation,
)
.await;
assert!(result.is_ok());
Expand All @@ -238,13 +226,11 @@ mod tests {
let response =
serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"});
let server = setup_server_with_response(response).await;
let skip_chain_id_validation = false;

let result = check_chain_id(
MAINNET_ETHEREUM_CHAINID,
&server.uri(),
"eth_chainId",
skip_chain_id_validation,
)
.await;

Expand All @@ -260,13 +246,11 @@ mod tests {
let response =
serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"});
let server = setup_server_with_response(response).await;
let skip_chain_id_validation = false;

let result = check_chain_id(
MAINNET_STARKNET_CHAINID,
&server.uri(),
"eth_chainId",
skip_chain_id_validation,
)
.await;

Expand All @@ -282,13 +266,11 @@ mod tests {
let response =
serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"});
let server = setup_server_with_response(response).await;
let skip_chain_id_validation = false;

let result = check_chain_id(
MAINNET_STARKNET_CHAINID,
&server.uri(),
"starknet_chainId",
skip_chain_id_validation,
)
.await;

Expand All @@ -304,13 +286,11 @@ mod tests {
let response =
serde_json::json!({"jsonrpc":"2.0","id":0,"error":"foo"});
let server = setup_server_with_response(response).await;
let skip_chain_id_validation = false;

let result = check_chain_id(
MAINNET_STARKNET_CHAINID,
&server.uri(),
"starknet_chainId",
skip_chain_id_validation,
)
.await;

Expand All @@ -331,8 +311,7 @@ mod tests {
poll_secs: 9999,
rpc_addr: SocketAddr::from(([127, 0, 0, 1], 3030)),
};
let skip_chain_id_validation = false;
let response = config.check(skip_chain_id_validation).await;
let response = config.validate();

assert!(response.is_err());
assert!(response.unwrap_err().to_string().contains("poll_secs"));
Expand Down

0 comments on commit 7611957

Please sign in to comment.