From 2c1c8e4461ee3536a2d1608d90a92f97f4236e86 Mon Sep 17 00:00:00 2001 From: barshaul Date: Tue, 8 Aug 2023 14:34:21 +0000 Subject: [PATCH] Changed the cluster initialization to retrieve all IP entries from the initial nodes and use all resolved IPs. --- redis/src/aio/connection.rs | 2 +- redis/src/cluster_async/mod.rs | 84 +++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs index 2bcd7cb9c..cdca0801c 100644 --- a/redis/src/aio/connection.rs +++ b/redis/src/aio/connection.rs @@ -361,7 +361,7 @@ where } } -async fn get_socket_addrs( +pub(crate) async fn get_socket_addrs( host: &str, port: u16, ) -> RedisResult + Send + '_> { diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 02912bc47..19755fc81 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -36,7 +36,7 @@ use std::{ }; use crate::{ - aio::{ConnectionLike, MultiplexedConnection}, + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection}, cluster::{get_connection_info, slot_cmd}, cluster_client::{ClusterParams, RetryParams}, cluster_routing::{ @@ -487,15 +487,54 @@ where Ok(connection) } + /// Go through each of the initial nodes and attempt to retrieve all IP entries from them. + /// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list. + pub(crate) async fn try_to_expand_initial_nodes( + initial_nodes: &[ConnectionInfo], + ) -> Vec { + stream::iter(initial_nodes) + .fold( + Vec::with_capacity(initial_nodes.len()), + |mut acc, info| async { + let (host, port) = match &info.addr { + crate::ConnectionAddr::Tcp(host, port) => (host, port), + crate::ConnectionAddr::TcpTls { + host, + port, + insecure: _, + } => (host, port), + crate::ConnectionAddr::Unix(_) => { + // We don't support multiple addresses for a Unix address. Store the initial node address and continue + acc.push(info.addr.to_string()); + return acc; + } + }; + match get_socket_addrs(host, *port).await { + Ok(socket_addrs) => { + for addr in socket_addrs { + acc.push(addr.to_string()); + } + } + Err(_) => { + // Couldn't find socket addresses, store the initial node address and continue + acc.push(info.addr.to_string()); + } + }; + acc + }, + ) + .await + } + async fn create_initial_connections( initial_nodes: &[ConnectionInfo], params: &ClusterParams, ) -> RedisResult> { + let initial_nodes: Vec = Self::try_to_expand_initial_nodes(initial_nodes).await; let connections = stream::iter(initial_nodes.iter().cloned()) - .map(|info| { + .map(|addr| { let params = params.clone(); async move { - let addr = info.addr.to_string(); let result = connect_and_check(&addr, params).await; match result { Ok(conn) => Some((addr, async { conn }.boxed().shared())), @@ -684,15 +723,33 @@ where inner.cluster_params.tls, num_of_nodes_to_query, )?; - + // Create a new connection vector of the found nodes let connections: &ConnectionMap = &read_guard.0; let mut nodes = new_slots.values().flatten().collect::>(); nodes.sort_unstable(); nodes.dedup(); let nodes_len = nodes.len(); - let addresses_and_connections_iter = nodes - .into_iter() - .map(|addr| (addr, connections.get(addr).cloned())); + let addresses_and_connections_iter = nodes.into_iter().map(|addr| async move { + if let Some(conn) = connections.get(addr).cloned() { + return (addr, Some(conn)); + } + // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name. + // We shall check if a connection is already exists under the resolved IP name. + let (host, port) = match get_host_and_port_from_addr(addr) { + Some((host, port)) => (host, port), + None => return (addr, None), + }; + let conn = get_socket_addrs(host, port) + .await + .ok() + .map(|mut socket_addresses| { + socket_addresses.find_map(|addr| connections.get(&addr.to_string()).cloned()) + }) + .unwrap_or(None); + (addr, conn) + }); + let addresses_and_connections_iter = + futures::future::join_all(addresses_and_connections_iter).await; let new_connections: HashMap> = stream::iter(addresses_and_connections_iter) .fold( @@ -709,6 +766,7 @@ where .await; drop(read_guard); + // Replace the current slot map and connection vector with the new ones let mut write_guard = inner.conn_lock.write().await; write_guard.1 = new_slots; write_guard.0 = new_connections; @@ -1306,6 +1364,18 @@ where (addr, conn) } +/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned. +/// [addr] should be in the following format: ":". +fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> { + let parts: Vec<&str> = addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + let host = parts.first().unwrap(); + let port = parts.get(1).unwrap(); + port.parse::().ok().map(|port| (*host, port)) +} + #[cfg(test)] mod pipeline_routing_tests { use super::route_pipeline;