diff --git a/crates/sui-core/src/authority_server.rs b/crates/sui-core/src/authority_server.rs index abe16af083743..04e33529ce090 100644 --- a/crates/sui-core/src/authority_server.rs +++ b/crates/sui-core/src/authority_server.rs @@ -51,6 +51,7 @@ use crate::{ use crate::{ authority::AuthorityState, consensus_adapter::{ConsensusAdapter, ConsensusAdapterMetrics}, + traffic_controller::parse_ip, traffic_controller::policies::TrafficTally, traffic_controller::TrafficController, }; @@ -332,7 +333,7 @@ impl ValidatorService { consensus_adapter, metrics: validator_metrics, traffic_controller: policy_config.clone().map(|policy| { - Arc::new(TrafficController::spawn( + Arc::new(TrafficController::init( policy, traffic_controller_metrics, firewall_config, @@ -1009,17 +1010,9 @@ impl ValidatorService { ); return None; }; - client_ip.parse::<IpAddr>().ok().or_else(|| { - client_ip.parse::<SocketAddr>().ok().map(|socket_addr| socket_addr.ip()).or_else(|| { - self.metrics.forwarded_header_parse_error.inc(); - error!( - "Failed to parse x-forwarded-for header value of {:?} to ip address or socket. 
\ - Please ensure that your proxy is configured to resolve client domains to an \ - IP address before writing header", - client_ip, - ); - None - }) + parse_ip(client_ip).or_else(|| { + self.metrics.forwarded_header_parse_error.inc(); + None }) } Err(e) => { diff --git a/crates/sui-core/src/traffic_controller/mod.rs b/crates/sui-core/src/traffic_controller/mod.rs index 1dd940c5d4e04..cacec53b5cc7c 100644 --- a/crates/sui-core/src/traffic_controller/mod.rs +++ b/crates/sui-core/src/traffic_controller/mod.rs @@ -10,7 +10,7 @@ use dashmap::DashMap; use fs::File; use prometheus::IntGauge; use std::fs; -use std::net::{IpAddr, Ipv4Addr}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::ops::Add; use std::sync::Arc; @@ -39,10 +39,20 @@ struct Blocklists { proxied_clients: Blocklist, } +#[derive(Clone)] +enum Acl { + Blocklists(Blocklists), + /// If this variant is set, then we do no tallying or running + /// of background tasks, and instead simply block all IPs not + /// in the allowlist on calls to `check`. The allowlist should + /// only be populated once at initialization. 
+ Allowlist(Vec<IpAddr>), +} + #[derive(Clone)] pub struct TrafficController { - tally_channel: mpsc::Sender<TrafficTally>, - blocklists: Blocklists, + tally_channel: Option<mpsc::Sender<TrafficTally>>, + acl: Acl, metrics: Arc<TrafficControllerMetrics>, dry_run_mode: bool, } @@ -68,7 +78,33 @@ impl Debug for TrafficController { } impl TrafficController { - pub fn spawn( + pub fn init( + policy_config: PolicyConfig, + metrics: TrafficControllerMetrics, + fw_config: Option<RemoteFirewallConfig>, + ) -> Self { + match policy_config.allow_list { + Some(allow_list) => { + let allowlist = allow_list + .into_iter() + .map(|ip_str| { + parse_ip(&ip_str).unwrap_or_else(|| { + panic!("Failed to parse allowlist IP address: {:?}", ip_str) + }) + }) + .collect(); + Self { + tally_channel: None, + acl: Acl::Allowlist(allowlist), + metrics: Arc::new(metrics), + dry_run_mode: policy_config.dry_run, + } + } + None => Self::spawn(policy_config, metrics, fw_config), + } + } + + fn spawn( policy_config: PolicyConfig, metrics: TrafficControllerMetrics, fw_config: Option<RemoteFirewallConfig>, ) -> Self { @@ -86,20 +122,15 @@ impl TrafficController { metrics .deadmans_switch_enabled .set(mem_drainfile_present as i64); - - let ret = Self { - tally_channel: tx, - blocklists: Blocklists { - clients: Arc::new(DashMap::new()), - proxied_clients: Arc::new(DashMap::new()), - }, - metrics: metrics.clone(), - dry_run_mode: policy_config.dry_run, + let blocklists = Blocklists { + clients: Arc::new(DashMap::new()), + proxied_clients: Arc::new(DashMap::new()), }; - let tally_loop_blocklists = ret.blocklists.clone(); - let clear_loop_blocklists = ret.blocklists.clone(); + let tally_loop_blocklists = blocklists.clone(); + let clear_loop_blocklists = blocklists.clone(); let tally_loop_metrics = metrics.clone(); let clear_loop_metrics = metrics.clone(); + let dry_run_mode = policy_config.dry_run; spawn_monitored_task!(run_tally_loop( rx, policy_config, @@ -112,74 +143,91 @@ impl TrafficController { clear_loop_blocklists, clear_loop_metrics, )); - ret + Self { + tally_channel: Some(tx), + acl: Acl::Blocklists(blocklists), + 
metrics: metrics.clone(), + dry_run_mode, + } } - pub fn spawn_for_test( + pub fn init_for_test( policy_config: PolicyConfig, fw_config: Option<RemoteFirewallConfig>, ) -> Self { let metrics = TrafficControllerMetrics::new(&prometheus::Registry::new()); - Self::spawn(policy_config, metrics, fw_config) + Self::init(policy_config, metrics, fw_config) } pub fn tally(&self, tally: TrafficTally) { - // Use try_send rather than send mainly to avoid creating backpressure - // on the caller if the channel is full, which may slow down the critical - // path. Dropping the tally on the floor should be ok, as in this case - // we are effectively sampling traffic, which we would need to do anyway - // if we are overloaded - match self.tally_channel.try_send(tally) { - Err(TrySendError::Full(_)) => { - warn!("TrafficController tally channel full, dropping tally"); - self.metrics.tally_channel_overflow.inc(); - // TODO: once we've verified this doesn't happen under normal - // conditions, we can consider dropping the request itself given - // that clearly the system is overloaded - } - Err(TrySendError::Closed(_)) => { - panic!("TrafficController tally channel closed unexpectedly"); + if let Some(channel) = self.tally_channel.as_ref() { + // Use try_send rather than send mainly to avoid creating backpressure + // on the caller if the channel is full, which may slow down the critical + // path. 
Dropping the tally on the floor should be ok, as in this case + // we are effectively sampling traffic, which we would need to do anyway + // if we are overloaded + match channel.try_send(tally) { + Err(TrySendError::Full(_)) => { + warn!("TrafficController tally channel full, dropping tally"); + self.metrics.tally_channel_overflow.inc(); + // TODO: once we've verified this doesn't happen under normal + // conditions, we can consider dropping the request itself given + // that clearly the system is overloaded + } + Err(TrySendError::Closed(_)) => { + panic!("TrafficController tally channel closed unexpectedly"); + } + Ok(_) => {} } - Ok(_) => {} } } /// Handle check with dry-run mode considered pub async fn check(&self, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>) -> bool { - match ( - self.check_impl(client, proxied_client).await, - self.dry_run_mode(), - ) { - // check succeeded - (true, _) => true, - // check failed while in dry-run mode - (false, true) => { - debug!( - "Dry run mode: Blocked request from client {:?}, proxied client: {:?}", - client, proxied_client - ); - self.metrics.num_dry_run_blocked_requests.inc(); - true + let check_with_dry_run_maybe = |allowed| -> bool { + match (allowed, self.dry_run_mode()) { + // check succeeded + (true, _) => true, + // check failed while in dry-run mode + (false, true) => { + debug!("Dry run mode: Blocked request from client {:?}", client); + self.metrics.num_dry_run_blocked_requests.inc(); + true + } + // check failed + (false, false) => false, + } + }; + + match &self.acl { + Acl::Allowlist(allowlist) => { + let allowed = client.is_none() || allowlist.contains(&client.unwrap()); + check_with_dry_run_maybe(allowed) + } + Acl::Blocklists(blocklists) => { + let allowed = self + .check_blocklists(blocklists, client, proxied_client) + .await; + check_with_dry_run_maybe(allowed) } - // check failed - (false, false) => false, } } - /// Returns true if the connection is allowed, false if it is blocked - pub async fn 
check_impl( + /// Returns true if the connection is allowed, false if it is blocked + async fn check_blocklists( &self, + blocklists: &Blocklists, client: &Option<IpAddr>, proxied_client: &Option<IpAddr>, ) -> bool { let client_check = self.check_and_clear_blocklist( client, - self.blocklists.clients.clone(), + blocklists.clients.clone(), &self.metrics.connection_ip_blocklist_len, ); let proxied_client_check = self.check_and_clear_blocklist( proxied_client, - self.blocklists.proxied_clients.clone(), + blocklists.proxied_clients.clone(), &self.metrics.proxy_ip_blocklist_len, ); let (client_check, proxied_client_check) = @@ -604,7 +652,7 @@ impl TrafficSim { assert!(per_client_tps > 0); assert!(duration.as_secs() > 0); - let controller = TrafficController::spawn_for_test(policy.clone(), None); + let controller = TrafficController::init_for_test(policy.clone(), None); let tasks = (0..num_clients).map(|task_num| { tokio::spawn(Self::run_single_client( controller.clone(), @@ -772,3 +820,15 @@ impl TrafficSim { ); } } + +pub fn parse_ip(ip: &str) -> Option<IpAddr> { + ip.parse::<IpAddr>().ok().or_else(|| { + ip.parse::<SocketAddr>() + .ok() + .map(|socket_addr| socket_addr.ip()) + .or_else(|| { + error!("Failed to parse value of {:?} to ip address or socket.", ip,); + None + }) + }) +} diff --git a/crates/sui-e2e-tests/tests/traffic_control_tests.rs b/crates/sui-e2e-tests/tests/traffic_control_tests.rs index 4ae24ddcc7ea6..f96e59099c183 100644 --- a/crates/sui-e2e-tests/tests/traffic_control_tests.rs +++ b/crates/sui-e2e-tests/tests/traffic_control_tests.rs @@ -421,7 +421,7 @@ async fn test_validator_traffic_control_error_delegated() -> Result<(), anyhow:: let mut server = NodeFwTestServer::new(); server.start(port).await; // await for the server to start - tokio::time::sleep(tokio::time::Duration::from_secs(3)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; // it should take no more than 4 requests to be added to the blocklist for _ in 0..n { @@ -547,7 +547,7 @@ async fn 
test_traffic_control_dead_mans_switch() -> Result<(), anyhow::Error> { // NOTE: we need to hold onto this tc handle to ensure we don't inadvertently close // the receive channel (this would cause traffic controller to exit the loop and thus // we will never engage the dead mans switch) - let _tc = TrafficController::spawn_for_test(policy_config, Some(firewall_config)); + let _tc = TrafficController::init_for_test(policy_config, Some(firewall_config)); assert!( !drain_path.exists(), "Expected drain file to not exist after startup unless previously set", @@ -704,6 +704,32 @@ async fn test_traffic_sketch_with_sampled_spam() { assert!(metrics.num_blocked > (expected_requests / 5) - 1000); } +#[sim_test] +async fn test_traffic_sketch_allowlist_mode() { + let policy_config = PolicyConfig { + connection_blocklist_ttl_sec: 1, + proxy_blocklist_ttl_sec: 1, + // first two clients allowlisted, rest blocked + allow_list: Some(vec![String::from("127.0.0.0"), String::from("127.0.0.1")]), + dry_run: false, + ..Default::default() + }; + let metrics = TrafficSim::run( + policy_config, + 4, // num_clients + 10_000, // per_client_tps + Duration::from_secs(10), + true, // report + ) + .await; + + let expected_requests = 10_000 * 10 * 4; + // ~half of all requests blocked + assert!(metrics.num_blocked >= expected_requests / 2 - 1000); + assert!(metrics.num_requests > expected_requests - 1_000); + assert!(metrics.num_requests < expected_requests + 200); +} + async fn assert_traffic_control_ok(mut test_cluster: TestCluster) -> Result<(), anyhow::Error> { let context = &mut test_cluster.wallet; let jsonrpc_client = &test_cluster.fullnode_handle.rpc_client; diff --git a/crates/sui-json-rpc/src/axum_router.rs b/crates/sui-json-rpc/src/axum_router.rs index 0fb6be835c835..aa2e1e4edcadb 100644 --- a/crates/sui-json-rpc/src/axum_router.rs +++ b/crates/sui-json-rpc/src/axum_router.rs @@ -22,7 +22,7 @@ use jsonrpsee::types::{ErrorObject, Id, InvalidRequest, Params, Request}; use 
jsonrpsee::{core::server::rpc_module::Methods, server::logger::Logger}; use serde_json::value::RawValue; use sui_core::traffic_controller::{ - metrics::TrafficControllerMetrics, policies::TrafficTally, TrafficController, + metrics::TrafficControllerMetrics, parse_ip, policies::TrafficTally, TrafficController, }; use sui_json_rpc_api::TRANSACTION_EXECUTION_CLIENT_ERROR_CODE; use sui_types::traffic_control::ClientIdSource; @@ -63,7 +63,7 @@ impl<L: Logger> JsonRpcService<L> { logger, id_provider: Arc::new(RandomIntegerIdProvider), traffic_controller: policy_config.clone().map(|policy| { - Arc::new(TrafficController::spawn( + Arc::new(TrafficController::init( policy, traffic_controller_metrics, remote_fw_config, @@ -183,17 +183,7 @@ async fn process_raw_request<L: Logger>( ); return None; }; - client_ip.parse::<IpAddr>().ok().or_else(|| { - client_ip.parse::<SocketAddr>().ok().map(|socket_addr| socket_addr.ip()).or_else(|| { - error!( - "Failed to parse x-forwarded-for header value of {:?} to ip address or socket. \ - Please ensure that your proxy is configured to resolve client domains to an \ - IP address before writing header", - client_ip, - ); - None - }) - }) + parse_ip(client_ip) } Err(e) => { error!("Invalid UTF-8 in x-forwarded-for header: {:?}", e); diff --git a/crates/sui-types/src/traffic_control.rs b/crates/sui-types/src/traffic_control.rs index 95e5ee534887b..6a70ea3ad28ab 100644 --- a/crates/sui-types/src/traffic_control.rs +++ b/crates/sui-types/src/traffic_control.rs @@ -251,6 +251,11 @@ pub struct PolicyConfig { pub spam_sample_rate: Weight, #[serde(default = "default_dry_run")] pub dry_run: bool, + /// List of String which should all parse to type IpAddr. + /// If set, only requests from provided IPs will be allowed, + /// and any blocklist related configuration will be ignored. 
+ #[serde(default)] + pub allow_list: Option<Vec<String>>, } impl Default for PolicyConfig { @@ -264,6 +269,7 @@ impl Default for PolicyConfig { channel_capacity: 100, spam_sample_rate: default_spam_sample_rate(), dry_run: default_dry_run(), + allow_list: None, } } }