Skip to content

Commit

Permalink
Changed refresh_slots to query multiple nodes and return the most fre…
Browse files Browse the repository at this point in the history
…quent topology
  • Loading branch information
barshaul committed Aug 3, 2023
1 parent e474b7e commit 043c363
Show file tree
Hide file tree
Showing 9 changed files with 837 additions and 284 deletions.
3 changes: 2 additions & 1 deletion redis/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ rand = { version = "0.8", optional = true }
# Only needed for async_std support
async-std = { version = "1.8.0", optional = true}
async-trait = { version = "0.1.24", optional = true }
derivative = { version = "2.2.0", optional = true }

# Only needed for native tls
native-tls = { version = "0.2", optional = true }
Expand Down Expand Up @@ -86,7 +87,7 @@ acl = []
aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "tokio/sync", "combine/tokio", "async-trait", "futures-time"]
geospatial = []
json = ["serde", "serde/derive", "serde_json"]
cluster = ["crc16", "rand"]
cluster = ["crc16", "rand", "derivative"]
script = ["sha1_smol"]
tls-native-tls = ["native-tls"]
tls-rustls = ["rustls", "rustls-native-certs"]
Expand Down
109 changes: 4 additions & 105 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ use std::str::FromStr;
use std::thread;
use std::time::Duration;

use log::trace;
use rand::{seq::IteratorRandom, thread_rng, Rng};

use crate::cluster_pipeline::UNROUTABLE_ERROR;
use crate::cluster_routing::{MultipleNodeRoutingInfo, SingleNodeRoutingInfo, SlotAddr};
use crate::cluster_topology::{build_slot_map, parse_slots, SlotMap, SLOT_SIZE};
use crate::cmd::{cmd, Cmd};
use crate::connection::{
connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo,
Expand All @@ -55,7 +55,7 @@ use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, Value};
use crate::IntoConnectionInfo;
use crate::{
cluster_client::ClusterParams,
cluster_routing::{Redirect, Routable, Route, RoutingInfo, Slot, SlotMap, SLOT_SIZE},
cluster_routing::{Redirect, Routable, Route, RoutingInfo},
};

pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder};
Expand Down Expand Up @@ -298,7 +298,7 @@ where
)));
for conn in samples.iter_mut() {
let value = conn.req_command(&slot_cmd())?;
match parse_slots(value, self.cluster_params.tls).and_then(|v| {
match parse_slots(&value, self.cluster_params.tls).and_then(|v| {
build_slot_map(&mut new_slots, v, self.cluster_params.read_from_replicas)
}) {
Ok(_) => {
Expand Down Expand Up @@ -705,107 +705,6 @@ fn get_random_connection<C: ConnectionLike + Connect + Sized>(
(addr, con)
}

// Parse slot data from raw redis value.
pub(crate) fn parse_slots(raw_slot_resp: Value, tls: Option<TlsMode>) -> RedisResult<Vec<Slot>> {
// Parse response.
let mut result = Vec::with_capacity(2);

if let Value::Bulk(items) = raw_slot_resp {
let mut iter = items.into_iter();
while let Some(Value::Bulk(item)) = iter.next() {
if item.len() < 3 {
continue;
}

let start = if let Value::Int(start) = item[0] {
start as u16
} else {
continue;
};

let end = if let Value::Int(end) = item[1] {
end as u16
} else {
continue;
};

let mut nodes: Vec<String> = item
.into_iter()
.skip(2)
.filter_map(|node| {
if let Value::Bulk(node) = node {
if node.len() < 2 {
return None;
}

let ip = if let Value::Data(ref ip) = node[0] {
String::from_utf8_lossy(ip)
} else {
return None;
};
if ip.is_empty() {
return None;
}

let port = if let Value::Int(port) = node[1] {
port as u16
} else {
return None;
};
Some(get_connection_addr(ip.into_owned(), port, tls).to_string())
} else {
None
}
})
.collect();

if nodes.is_empty() {
continue;
}

let replicas = nodes.split_off(1);
result.push(Slot::new(start, end, nodes.pop().unwrap(), replicas));
}
}

Ok(result)
}

pub(crate) fn build_slot_map(
slot_map: &mut SlotMap,
mut slots_data: Vec<Slot>,
read_from_replicas: bool,
) -> RedisResult<()> {
slots_data.sort_by_key(|slot_data| slot_data.start());
let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
if prev_end != slot_data.start() {
return Err(RedisError::from((
ErrorKind::ResponseError,
"Slot refresh error.",
format!(
"Received overlapping slots {} and {}..{}",
prev_end,
slot_data.start(),
slot_data.end()
),
)));
}
Ok(slot_data.end() + 1)
})?;

if last_slot != SLOT_SIZE {
return Err(RedisError::from((
ErrorKind::ResponseError,
"Slot refresh error.",
format!("Lacks the slots >= {last_slot}"),
)));
}
slot_map.clear();
slot_map.fill_slots(&slots_data, read_from_replicas);
trace!("{:?}", slot_map);
Ok(())
}

// The node string passed to this function will always be in the format host:port as it is either:
// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode)
// - Returned from redis via the ASK/MOVED response
Expand Down Expand Up @@ -834,7 +733,7 @@ pub(crate) fn get_connection_info(
})
}

fn get_connection_addr(host: String, port: u16, tls: Option<TlsMode>) -> ConnectionAddr {
pub(crate) fn get_connection_addr(host: String, port: u16, tls: Option<TlsMode>) -> ConnectionAddr {
match tls {
Some(TlsMode::Secure) => ConnectionAddr::TcpTls {
host,
Expand Down
109 changes: 65 additions & 44 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ use std::{
marker::Unpin,
mem,
pin::Pin,
sync::{Arc, Mutex},
sync::{atomic, Arc, Mutex},
task::{self, Poll},
};

use crate::{
aio::{ConnectionLike, MultiplexedConnection},
cluster::{build_slot_map, get_connection_info, parse_slots, slot_cmd},
cluster::{get_connection_info, slot_cmd},
cluster_client::{ClusterParams, RetryParams},
cluster_routing::{
MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, RoutingInfo,
SingleNodeRoutingInfo, SlotMap,
SingleNodeRoutingInfo,
},
cluster_topology::{calculate_topology, SlotMap},
Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult,
Value,
};
Expand Down Expand Up @@ -457,6 +458,7 @@ where
refresh_error: None,
state: ConnectionState::PollComplete,
};
// TODO: add retries
connection.refresh_slots().await?;
Ok(connection)
}
Expand Down Expand Up @@ -525,56 +527,74 @@ where

// Query a node to discover slot-> master mappings.
fn refresh_slots(&mut self) -> impl Future<Output = RedisResult<()>> {
self.refresh_slots_with_retries(None)
}
// Query a node to discover slot-> master mappings with retries
fn refresh_slots_with_retries(
&mut self,
retries: Option<Arc<atomic::AtomicUsize>>,
) -> impl Future<Output = RedisResult<()>> {
let inner = self.inner.clone();

async move {
let mut write_guard = inner.conn_lock.write().await;
let mut connections = mem::take(&mut write_guard.0);
let slots = &mut write_guard.1;
let mut result = Ok(());
for (_, conn) in connections.iter_mut() {
let mut conn = conn.clone().await;
let value = match conn.req_packed_command(&slot_cmd()).await {
Ok(value) => value,
Err(err) => {
result = Err(err);
continue;
}
};
match parse_slots(value, inner.cluster_params.tls)
.and_then(|v| build_slot_map(slots, v, inner.cluster_params.read_from_replicas))
{
Ok(_) => {
result = Ok(());
break;
}
Err(err) => result = Err(err),
}
}
result?;

let mut nodes = write_guard.1.values().flatten().collect::<Vec<_>>();
let read_guard = inner.conn_lock.read().await;
let num_of_nodes = read_guard.0.len();
const MAX_REQUESTED_NODES: usize = 50;
let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES);
let mut requested_nodes = {
let mut rng = thread_rng();
read_guard
.0
.values()
.choose_multiple(&mut rng, num_of_nodes_to_query)
};
let topology_join_results =
futures::future::join_all(requested_nodes.iter_mut().map(|conn| async move {
let mut conn: C = conn.clone().await;
conn.req_packed_command(&slot_cmd()).await
}))
.await;
let topology_values: Vec<_> = topology_join_results
.into_iter()
.filter_map(|r| r.ok())
.collect();
let new_slots = calculate_topology(
topology_values,
retries.clone(),
inner.cluster_params.tls,
inner.cluster_params.read_from_replicas,
num_of_nodes_to_query,
)?;

let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.remove(addr)));

write_guard.0 = stream::iter(addresses_and_connections_iter)
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn =
Self::get_or_create_conn(addr, connection, &inner.cluster_params).await;
if let Ok(conn) = conn {
connections.insert(addr.to_string(), async { conn }.boxed().shared());
}
connections
},
)
.await;
.map(|addr| (addr, connections.get(addr).cloned()));
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn =
Self::get_or_create_conn(addr, connection, &inner.cluster_params)
.await;
if let Ok(conn) = conn {
connections
.insert(addr.to_string(), async { conn }.boxed().shared());
}
connections
},
)
.await;

drop(read_guard);
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Ok(())
}
}
Expand Down Expand Up @@ -1227,6 +1247,7 @@ async fn check_connection<C>(conn: &mut C, timeout: futures_time::time::Duration
where
C: ConnectionLike + Send + 'static,
{
// TODO: Add a check to re-resolve DNS addresses to verify we that we have a connection to the right node
crate::cmd("PING")
.query_async::<_, String>(conn)
.timeout(timeout)
Expand Down
Loading

0 comments on commit 043c363

Please sign in to comment.