Changed refresh topology to run with retries
barshaul committed Aug 3, 2023
1 parent 043c363 commit 48cf685
Showing 5 changed files with 190 additions and 142 deletions.
12 changes: 9 additions & 3 deletions redis/Cargo.toml
@@ -54,10 +54,16 @@ r2d2 = { version = "0.8.8", optional = true }
# Only needed for cluster
crc16 = { version = "0.4", optional = true }
rand = { version = "0.8", optional = true }
derivative = { version = "2.2.0", optional = true }

# Only needed for async_std support
async-std = { version = "1.8.0", optional = true}
async-trait = { version = "0.1.24", optional = true }
derivative = { version = "2.2.0", optional = true }
# To avoid conflicts, backoff-std-async.version != backoff-tokio.version, so that tests can run with --all-features
backoff-std-async = { package = "backoff", version = "0.3.0", optional = true, features = ["async-std"] }

# Only needed for tokio support
backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = ["tokio"] }

# Only needed for native tls
native-tls = { version = "0.2", optional = true }
@@ -93,10 +99,10 @@ tls-native-tls = ["native-tls"]
tls-rustls = ["rustls", "rustls-native-certs"]
tls-rustls-insecure = ["tls-rustls", "rustls/dangerous_configuration"]
tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"]
async-std-comp = ["aio", "async-std"]
async-std-comp = ["aio", "async-std", "backoff-std-async"]
async-std-native-tls-comp = ["async-std-comp", "async-native-tls", "tls-native-tls"]
async-std-rustls-comp = ["async-std-comp", "futures-rustls", "tls-rustls"]
tokio-comp = ["aio", "tokio", "tokio/net"]
tokio-comp = ["aio", "tokio", "tokio/net", "backoff-tokio"]
tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"]
tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"]
connection-manager = ["arc-swap", "futures", "aio", "tokio-retry"]
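
Editor's note: the two `package = "backoff"` entries above alias one upstream crate (backoff) under two local names, letting the async-std feature set pin 0.3.0 while the tokio feature set pins 0.4.0 without a Cargo resolver conflict, exactly as the inline comment explains. A minimal sketch of how such aliases are consumed, assuming the feature names from this manifest (the real imports added in cluster_async/mod.rs below follow the same shape):

#[cfg(feature = "tokio-comp")]
use backoff_tokio::{future::retry, Error, ExponentialBackoff};

// Fall back to the async-std flavor only when tokio is not enabled, so a
// build with --all-features still picks exactly one set of imports.
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use backoff_std_async::{future::retry, Error, ExponentialBackoff};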
185 changes: 105 additions & 80 deletions redis/src/cluster_async/mod.rs
@@ -28,7 +28,10 @@ use std::{
marker::Unpin,
mem,
pin::Pin,
sync::{atomic, Arc, Mutex},
sync::{
atomic::{self, AtomicUsize},
Arc, Mutex,
},
task::{self, Poll},
};

@@ -40,13 +43,26 @@ use crate::{
MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, RoutingInfo,
SingleNodeRoutingInfo,
},
cluster_topology::{calculate_topology, SlotMap},
cluster_topology::{
calculate_topology, SlotMap, DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL,
DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
},
Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult,
Value,
};

#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use crate::aio::{async_std::AsyncStd, RedisRuntime};
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use backoff_std_async::future::retry;
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use backoff_std_async::{Error, ExponentialBackoff};

#[cfg(feature = "tokio-comp")]
use backoff_tokio::future::retry;
#[cfg(feature = "tokio-comp")]
use backoff_tokio::{Error, ExponentialBackoff};

use futures::{
future::{self, BoxFuture},
prelude::*,
@@ -458,8 +474,7 @@ where
refresh_error: None,
state: ConnectionState::PollComplete,
};
// TODO: add retries
connection.refresh_slots().await?;
connection.refresh_slots_with_retries().await?;
Ok(connection)
}

@@ -525,80 +540,6 @@
}
}

// Query a node to discover slot-> master mappings.
fn refresh_slots(&mut self) -> impl Future<Output = RedisResult<()>> {
self.refresh_slots_with_retries(None)
}
// Query a node to discover slot-> master mappings with retries
fn refresh_slots_with_retries(
&mut self,
retries: Option<Arc<atomic::AtomicUsize>>,
) -> impl Future<Output = RedisResult<()>> {
let inner = self.inner.clone();

async move {
let read_guard = inner.conn_lock.read().await;
let num_of_nodes = read_guard.0.len();
const MAX_REQUESTED_NODES: usize = 50;
let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES);
let mut requested_nodes = {
let mut rng = thread_rng();
read_guard
.0
.values()
.choose_multiple(&mut rng, num_of_nodes_to_query)
};
let topology_join_results =
futures::future::join_all(requested_nodes.iter_mut().map(|conn| async move {
let mut conn: C = conn.clone().await;
conn.req_packed_command(&slot_cmd()).await
}))
.await;
let topology_values: Vec<_> = topology_join_results
.into_iter()
.filter_map(|r| r.ok())
.collect();
let new_slots = calculate_topology(
topology_values,
retries.clone(),
inner.cluster_params.tls,
inner.cluster_params.read_from_replicas,
num_of_nodes_to_query,
)?;

let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.get(addr).cloned()));
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn =
Self::get_or_create_conn(addr, connection, &inner.cluster_params)
.await;
if let Ok(conn) = conn {
connections
.insert(addr.to_string(), async { conn }.boxed().shared());
}
connections
},
)
.await;

drop(read_guard);
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Ok(())
}
}

async fn aggregate_results(
receivers: Vec<(String, oneshot::Receiver<RedisResult<Response>>)>,
routing: &MultipleNodeRoutingInfo,
@@ -682,6 +623,90 @@
}
}

// Query a node to discover slot -> master mappings, with retries
fn refresh_slots_with_retries(&mut self) -> impl Future<Output = RedisResult<()>> {
let inner = self.inner.clone();
async move {
let retry_strategy = ExponentialBackoff {
initial_interval: DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL,
max_interval: DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT,
..Default::default()
};
let retries_counter = AtomicUsize::new(0);
retry(retry_strategy, || {
retries_counter.fetch_add(1, atomic::Ordering::Relaxed);
Self::refresh_slots(
inner.clone(),
retries_counter.load(atomic::Ordering::Relaxed),
)
.map_err(Error::from)
})
.await?;
Ok(())
}
}
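
Editor's note: for readers unfamiliar with the backoff crate, here is a minimal, self-contained sketch of the retry pattern used by refresh_slots_with_retries above. The fail-twice refresh_once operation is a hypothetical stand-in for the real slot refresh, and a tokio runtime plus backoff = { version = "0.4", features = ["tokio"] } are assumed:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;

use backoff::{future::retry, Error, ExponentialBackoff};

// Hypothetical stand-in for a single topology-refresh attempt.
async fn refresh_once(attempt: usize) -> Result<(), Error<String>> {
    if attempt < 3 {
        // A transient error tells `retry` to sleep and call us again.
        Err(Error::transient(format!("attempt {attempt} failed")))
    } else {
        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let strategy = ExponentialBackoff {
        initial_interval: Duration::from_millis(100), // cf. DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL
        max_interval: Duration::from_secs(1),         // cf. DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT
        ..Default::default()
    };
    // Shared attempt counter, mirroring retries_counter above.
    let retries = AtomicUsize::new(0);
    retry(strategy, || {
        let attempt = retries.fetch_add(1, Ordering::Relaxed) + 1;
        refresh_once(attempt)
    })
    .await
}

The production code instead wraps every failure with map_err(Error::from); in backoff, converting a plain error into Error<E> marks it transient, so each refresh failure is retried until the strategy's max_elapsed_time is exhausted.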

// Query a node to discover slot -> master mappings
async fn refresh_slots(inner: Arc<InnerCore<C>>, curr_retry: usize) -> RedisResult<()> {
let read_guard = inner.conn_lock.read().await;
let num_of_nodes = read_guard.0.len();
const MAX_REQUESTED_NODES: usize = 50;
let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES);
let mut requested_nodes = {
let mut rng = thread_rng();
read_guard
.0
.values()
.choose_multiple(&mut rng, num_of_nodes_to_query)
};
let topology_join_results =
futures::future::join_all(requested_nodes.iter_mut().map(|conn| async move {
let mut conn: C = conn.clone().await;
conn.req_packed_command(&slot_cmd()).await
}))
.await;
let topology_values: Vec<_> = topology_join_results
.into_iter()
.filter_map(|r| r.ok())
.collect();
let new_slots = calculate_topology(
topology_values,
curr_retry,
inner.cluster_params.tls,
inner.cluster_params.read_from_replicas,
num_of_nodes_to_query,
)?;

let connections: &ConnectionMap<C> = &read_guard.0;
let mut nodes = new_slots.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.get(addr).cloned()));
let new_connections: HashMap<String, ConnectionFuture<C>> =
stream::iter(addresses_and_connections_iter)
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn =
Self::get_or_create_conn(addr, connection, &inner.cluster_params).await;
if let Ok(conn) = conn {
connections.insert(addr.to_string(), async { conn }.boxed().shared());
}
connections
},
)
.await;

drop(read_guard);
let mut write_guard = inner.conn_lock.write().await;
write_guard.1 = new_slots;
write_guard.0 = new_connections;
Ok(())
}
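
Editor's note: two details of refresh_slots above are worth calling out. First, the fan-out is bounded: rather than querying every node, it samples at most MAX_REQUESTED_NODES (50) random connections and asks each for CLUSTER SLOTS. A small sketch of that sampling step using rand's IteratorRandom, as the code above does (sample_addrs is a hypothetical helper):

use rand::seq::IteratorRandom;
use rand::thread_rng;

const MAX_REQUESTED_NODES: usize = 50;

// Pick up to MAX_REQUESTED_NODES addresses uniformly at random.
fn sample_addrs<'a>(addrs: impl Iterator<Item = &'a str>) -> Vec<&'a str> {
    let mut rng = thread_rng();
    addrs.choose_multiple(&mut rng, MAX_REQUESTED_NODES)
}

fn main() {
    let nodes = ["node1:6379", "node2:6380", "node3:6381"];
    // With fewer than 50 nodes, every address is selected.
    assert_eq!(sample_addrs(nodes.iter().copied()).len(), 3);
}

Second, the new connection map is built while holding only the read lock, which in-flight requests can share; the guard is dropped and the write lock is taken just for the final swap of slots and connections.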

async fn execute_on_multiple_nodes<'a>(
cmd: &'a Arc<Cmd>,
routing: &'a MultipleNodeRoutingInfo,
@@ -902,7 +927,7 @@
}
Poll::Ready(Err(err)) => {
self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
self.refresh_slots(),
self.refresh_slots_with_retries(),
)));
Poll::Ready(Err(err))
}
@@ -1142,7 +1167,7 @@
PollFlushAction::None => return Poll::Ready(Ok(())),
PollFlushAction::RebuildSlots => {
self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(
Box::pin(self.refresh_slots()),
Box::pin(self.refresh_slots_with_retries()),
));
}
PollFlushAction::Reconnect(addrs) => {
41 changes: 20 additions & 21 deletions redis/src/cluster_topology.rs
@@ -1,3 +1,5 @@
//! This module provides the functionality to refresh and calculate the cluster topology for Redis Cluster.

use crate::cluster::get_connection_addr;
use crate::cluster_routing::MultipleNodeRoutingInfo;
use crate::cluster_routing::Route;
@@ -8,10 +10,17 @@ use derivative::Derivative;
use log::trace;
use std::collections::hash_map::DefaultHasher;
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::{collections::HashMap, sync::atomic};
use std::time::Duration;

/// The default number of refresh topology retries
pub const DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES: usize = 3;
/// The default maximum interval between topology refresh retries
pub const DEFAULT_REFRESH_SLOTS_RETRY_TIMEOUT: Duration = Duration::from_secs(1);
/// The default initial interval for retrying topology refresh
pub const DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL: Duration = Duration::from_millis(100);

pub(crate) const SLOT_SIZE: u16 = 16384;

@@ -229,7 +238,7 @@ fn calculate_hash<T: Hash>(t: &T) -> u64 {

pub(crate) fn calculate_topology(
topology_views: Vec<Value>,
retries: Option<Arc<atomic::AtomicUsize>>, // TODO: change to usize
curr_retry: usize,
tls_mode: Option<TlsMode>,
read_from_replicas: bool,
num_of_queried_nodes: usize,
@@ -278,10 +287,8 @@ pub(crate) fn calculate_topology(
};
if has_more_than_a_single_max {
// More than a single most frequent view was found
// If it's the last retry, or if we it's a 2-nodes cluster, we'll return all found topologies to be checked by the caller
if (retries.is_some() && retries.unwrap().fetch_sub(1, atomic::Ordering::SeqCst) == 1)
|| num_of_queried_nodes < 3
{
// If we've reached the last retry, or if it's a 2-node cluster, return all found topologies to be checked by the caller
if curr_retry >= DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES || num_of_queried_nodes < 3 {
for (idx, topology_view) in hash_view_map.values().enumerate() {
match parse_slots(&topology_view.topology_value, tls_mode)
.and_then(|v| build_slot_map(&mut new_slots, v, read_from_replicas))
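
Editor's note: calculate_topology now takes the current retry number instead of a shared atomic counter. Each queried node's view is hashed and identical views are tallied; a unique most-frequent view wins outright, and only on a tie does the retry threshold (or a cluster with fewer than three queried nodes) decide whether to fall back to examining every view. A condensed sketch of that majority vote, assuming plain string views in place of parsed CLUSTER SLOTS values:

use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

const DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES: usize = 3;

// Tally identical views by hash; return a winner, or None on a tie that
// we are not yet allowed to resolve.
fn pick_view(views: &[&str], curr_retry: usize) -> Option<u64> {
    let mut counts: HashMap<u64, usize> = HashMap::new();
    for view in views {
        let mut hasher = DefaultHasher::new();
        view.hash(&mut hasher);
        *counts.entry(hasher.finish()).or_insert(0) += 1;
    }
    let max = *counts.values().max()?;
    let winners: Vec<u64> = counts
        .iter()
        .filter(|(_, &count)| count == max)
        .map(|(&hash, _)| hash)
        .collect();
    match winners.as_slice() {
        [single] => Some(*single),
        // Tie: past the retry budget the real code validates every view
        // instead of giving up; here we simply return one of them.
        _ if curr_retry >= DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES => winners.first().copied(),
        _ => None, // let the caller retry with fresh views
    }
}

fn main() {
    assert!(pick_view(&["a", "a", "b"], 1).is_some()); // clear majority
    assert!(pick_view(&["a", "b"], 1).is_none()); // tie: retry again
    assert!(pick_view(&["a", "b"], 3).is_some()); // last retry: fall back
}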
@@ -329,7 +336,6 @@
mod tests {
use super::*;
use crate::cluster_routing::SlotAddrs;
use std::sync::atomic::AtomicUsize;

#[test]
fn test_get_hashtag() {
@@ -415,7 +421,7 @@ mod tests {
get_view(&ViewType::TwoNodesViewFullCoverage),
];
let topology_view =
calculate_topology(topology_results, None, None, false, queried_nodes).unwrap();
calculate_topology(topology_results, 1, None, false, queried_nodes).unwrap();
let res: Vec<_> = topology_view.values().collect();
let node_1 = get_node_addr("node1", 6379);
let expected: Vec<&SlotAddrs> = vec![&node_1];
@@ -431,7 +437,7 @@ mod tests {
get_view(&ViewType::TwoNodesViewFullCoverage),
get_view(&ViewType::TwoNodesViewMissingSlots),
];
let topology_view = calculate_topology(topology_results, None, None, false, queried_nodes);
let topology_view = calculate_topology(topology_results, 1, None, false, queried_nodes);
assert!(topology_view.is_err());
}

@@ -444,14 +450,8 @@ mod tests {
get_view(&ViewType::TwoNodesViewFullCoverage),
get_view(&ViewType::TwoNodesViewMissingSlots),
];
let topology_view = calculate_topology(
topology_results,
Some(Arc::new(AtomicUsize::new(1))),
None,
false,
queried_nodes,
)
.unwrap();
let topology_view =
calculate_topology(topology_results, 3, None, false, queried_nodes).unwrap();
let res: Vec<_> = topology_view.values().collect();
let node_1 = get_node_addr("node1", 6379);
let node_2 = get_node_addr("node2", 6380);
@@ -468,7 +468,7 @@ mod tests {
get_view(&ViewType::TwoNodesViewMissingSlots),
];
let topology_view =
calculate_topology(topology_results, None, None, false, queried_nodes).unwrap();
calculate_topology(topology_results, 1, None, false, queried_nodes).unwrap();
let res: Vec<_> = topology_view.values().collect();
let node_1 = get_node_addr("node1", 6379);
let node_2 = get_node_addr("node2", 6380);
@@ -484,8 +484,7 @@ mod tests {
get_view(&ViewType::SingleNodeViewMissingSlots),
get_view(&ViewType::TwoNodesViewMissingSlots),
];
let topology_view_res =
calculate_topology(topology_results, None, None, false, queried_nodes);
let topology_view_res = calculate_topology(topology_results, 1, None, false, queried_nodes);
assert!(topology_view_res.is_err());
}
}
3 changes: 2 additions & 1 deletion redis/src/lib.rs
@@ -448,7 +448,8 @@ mod cluster_pipeline;
pub mod cluster_routing;

#[cfg(feature = "cluster")]
mod cluster_topology;
#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))]
pub mod cluster_topology;

#[cfg(feature = "r2d2")]
#[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))]