Skip to content

Commit

Permalink
Changed the cluster initialization to retrieve all IP entries from th…
Browse files Browse the repository at this point in the history
…e initial nodes and use all resolved IPs.
  • Loading branch information
barshaul committed Aug 10, 2023
1 parent 63adbe0 commit 67b92c3
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
2 changes: 1 addition & 1 deletion redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ where
}
}

async fn get_socket_addrs(
pub(crate) async fn get_socket_addrs(
host: &str,
port: u16,
) -> RedisResult<impl Iterator<Item = SocketAddr> + Send + '_> {
Expand Down
84 changes: 77 additions & 7 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::{
};

use crate::{
aio::{ConnectionLike, MultiplexedConnection},
aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection},
cluster::{get_connection_info, slot_cmd},
cluster_client::{ClusterParams, RetryParams},
cluster_routing::{
Expand Down Expand Up @@ -478,15 +478,54 @@ where
Ok(connection)
}

/// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
/// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
pub(crate) async fn try_to_expand_initial_nodes(
initial_nodes: &[ConnectionInfo],
) -> Vec<String> {
stream::iter(initial_nodes)
.fold(
Vec::with_capacity(initial_nodes.len()),
|mut acc, info| async {
let (host, port) = match &info.addr {
crate::ConnectionAddr::Tcp(host, port) => (host, port),
crate::ConnectionAddr::TcpTls {
host,
port,
insecure: _,
} => (host, port),
crate::ConnectionAddr::Unix(_) => {
// We don't support multiple addresses for a Unix address. Store the initial node address and continue
acc.push(info.addr.to_string());
return acc;
}
};
match get_socket_addrs(host, *port).await {
Ok(socket_addrs) => {
for addr in socket_addrs {
acc.push(addr.to_string());
}
}
Err(_) => {
// Couldn't find socket addresses, store the initial node address and continue
acc.push(info.addr.to_string());
}
};
acc
},
)
.await
}

async fn create_initial_connections(
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
) -> RedisResult<ConnectionMap<C>> {
let initial_nodes: Vec<String> = Self::try_to_expand_initial_nodes(initial_nodes).await;
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|info| {
.map(|addr| {
let params = params.clone();
async move {
let addr = info.addr.to_string();
let result = connect_and_check(&addr, params).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Expand Down Expand Up @@ -675,15 +714,33 @@ where
inner.cluster_params.tls,
num_of_nodes_to_query,
)?;

// Create a new connection vector of the found nodes
let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.get(addr).cloned()));
let addresses_and_connections_iter = nodes.into_iter().map(|addr| async move {
if let Some(conn) = connections.get(addr).cloned() {
return (addr, Some(conn));
}
// If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name.
// We shall check if a connection is already exists under the resolved IP name.
let (host, port) = match get_host_and_port_from_addr(addr) {
Some((host, port)) => (host, port),
None => return (addr, None),
};
let conn = get_socket_addrs(host, port)
.await
.ok()
.map(|mut socket_addresses| {
socket_addresses.find_map(|addr| connections.get(&addr.to_string()).cloned())
})
.unwrap_or(None);
(addr, conn)
});
let addresses_and_connections_iter =
futures::future::join_all(addresses_and_connections_iter).await;
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
Expand All @@ -700,6 +757,7 @@ where
.await;

drop(read_guard);
// Replace the current slot map and connection vector with the new ones
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Expand Down Expand Up @@ -1301,6 +1359,18 @@ where
(addr, conn)
}

/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned.
/// [addr] should be in the following format: "<host>:<port>".
fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> {
let parts: Vec<&str> = addr.split(':').collect();
if parts.len() != 2 {
return None;
}
let host = parts.first().unwrap();
let port = parts.get(1).unwrap();
port.parse::<u16>().ok().map(|port| (*host, port))
}

#[cfg(test)]
mod pipeline_routing_tests {
use super::route_pipeline;
Expand Down

0 comments on commit 67b92c3

Please sign in to comment.