Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added to the initial_nodes list all retrieved IP entries #19

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion redis/src/aio/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ where
}
}

async fn get_socket_addrs(
pub(crate) async fn get_socket_addrs(
host: &str,
port: u16,
) -> RedisResult<impl Iterator<Item = SocketAddr> + Send + '_> {
Expand Down
84 changes: 77 additions & 7 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::{
};

use crate::{
aio::{ConnectionLike, MultiplexedConnection},
aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection},
cluster::{get_connection_info, slot_cmd},
cluster_client::{ClusterParams, RetryParams},
cluster_routing::{
Expand Down Expand Up @@ -478,15 +478,54 @@ where
Ok(connection)
}

/// Go through each of the initial nodes and attempt to retrieve all IP entries from them.
/// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list.
pub(crate) async fn try_to_expand_initial_nodes(
initial_nodes: &[ConnectionInfo],
) -> Vec<String> {
stream::iter(initial_nodes)
barshaul marked this conversation as resolved.
Show resolved Hide resolved
.fold(
barshaul marked this conversation as resolved.
Show resolved Hide resolved
Vec::with_capacity(initial_nodes.len()),
|mut acc, info| async {
let (host, port) = match &info.addr {
crate::ConnectionAddr::Tcp(host, port) => (host, port),
crate::ConnectionAddr::TcpTls {
host,
port,
insecure: _,
} => (host, port),
crate::ConnectionAddr::Unix(_) => {
// We don't support multiple addresses for a Unix address. Store the initial node address and continue
acc.push(info.addr.to_string());
return acc;
}
};
match get_socket_addrs(host, *port).await {
Ok(socket_addrs) => {
for addr in socket_addrs {
acc.push(addr.to_string());
}
}
Err(_) => {
// Couldn't find socket addresses, store the initial node address and continue
acc.push(info.addr.to_string());
}
};
acc
},
)
.await
}

async fn create_initial_connections(
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
) -> RedisResult<ConnectionMap<C>> {
let initial_nodes: Vec<String> = Self::try_to_expand_initial_nodes(initial_nodes).await;
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|info| {
.map(|addr| {
let params = params.clone();
async move {
let addr = info.addr.to_string();
let result = connect_and_check(&addr, params).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Expand Down Expand Up @@ -675,15 +714,33 @@ where
inner.cluster_params.tls,
num_of_nodes_to_query,
)?;

// Create a new connection vector of the found nodes
let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.get(addr).cloned()));
let addresses_and_connections_iter = nodes.into_iter().map(|addr| async move {
if let Some(conn) = connections.get(addr).cloned() {
return (addr, Some(conn));
}
// If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name.
// We shall check if a connection is already exists under the resolved IP name.
let (host, port) = match get_host_and_port_from_addr(addr) {
Some((host, port)) => (host, port),
None => return (addr, None),
};
let conn = get_socket_addrs(host, port)
.await
.ok()
.map(|mut socket_addresses| {
socket_addresses.find_map(|addr| connections.get(&addr.to_string()).cloned())
})
.unwrap_or(None);
(addr, conn)
});
let addresses_and_connections_iter =
futures::future::join_all(addresses_and_connections_iter).await;
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
Expand All @@ -700,6 +757,7 @@ where
.await;

drop(read_guard);
// Replace the current slot map and connection vector with the new ones
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Expand Down Expand Up @@ -1301,6 +1359,18 @@ where
(addr, conn)
}

/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned.
/// [addr] should be in the following format: "<host>:<port>".
fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> {
let parts: Vec<&str> = addr.split(':').collect();
if parts.len() != 2 {
return None;
}
let host = parts.first().unwrap();
let port = parts.get(1).unwrap();
port.parse::<u16>().ok().map(|port| (*host, port))
}

#[cfg(test)]
mod pipeline_routing_tests {
use super::route_pipeline;
Expand Down
Loading