diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index cb68e2a15..587a4aa16 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -28,7 +28,7 @@ use std::{ marker::Unpin, mem, pin::Pin, - sync::Arc, + sync::{Arc, Mutex}, task::{self, Poll}, }; @@ -102,12 +102,9 @@ where .send(Message { cmd: CmdArg::Cmd { cmd: Arc::new(cmd.clone()), // TODO Remove this clone? - func: |mut conn, cmd| { - Box::pin(async move { - conn.req_packed_command(&cmd).await.map(Response::Single) - }) - }, - routing: routing.or_else(|| RoutingInfo::for_routable(cmd)), + routing: CommandRouting::Route( + routing.or_else(|| RoutingInfo::for_routable(cmd)), + ), response_policy: RoutingInfo::response_policy(cmd), }, sender, @@ -148,13 +145,6 @@ where pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone? offset, count, - func: |mut conn, pipeline, offset, count| { - Box::pin(async move { - conn.req_packed_commands(&pipeline, offset, count) - .await - .map(Response::Multiple) - }) - }, route: route.or_else(|| route_pipeline(pipeline)), }, sender, @@ -178,6 +168,7 @@ type ConnectionMap = HashMap>; struct InnerCore { conn_lock: RwLock<(ConnectionMap, SlotMap)>, cluster_params: ClusterParams, + pending_requests: Mutex>>, } type Core = Arc>; @@ -192,22 +183,28 @@ struct ClusterConnInner { >, >, refresh_error: Option, - pending_requests: Vec>, +} + +#[derive(Clone)] +enum CommandRouting { + Route(Option), + Connection { + addr: String, + conn: ConnectionFuture, + }, } #[derive(Clone)] enum CmdArg { Cmd { cmd: Arc, - func: fn(C, Arc) -> RedisFuture<'static, Response>, - routing: Option, + routing: CommandRouting, response_policy: Option, }, Pipeline { pipeline: Arc, offset: usize, count: usize, - func: fn(C, Arc, usize, usize) -> RedisFuture<'static, Response>, route: Option, }, } @@ -330,7 +327,6 @@ enum Next { impl Future for Request where F: Future)>, - C: ConnectionLike, { type Output = Next; @@ -428,7 +424,6 @@ where impl Request where F: Future)>, - C: ConnectionLike, { fn respond(self: Pin<&mut Self>, msg: RedisResult) { // If `send` errors the receiver has dropped and thus does not care about the message @@ -454,12 +449,12 @@ where let inner = Arc::new(InnerCore { conn_lock: RwLock::new((connections, Default::default())), cluster_params, + pending_requests: Mutex::new(Vec::new()), }); let mut connection = ClusterConnInner { inner, in_flight_requests: Default::default(), refresh_error: None, - pending_requests: Vec::new(), state: ConnectionState::PollComplete, }; connection.refresh_slots().await?; @@ -584,67 +579,43 @@ where } } - async fn execute_on_multiple_nodes<'a>( - func: fn(C, Arc) -> RedisFuture<'static, Response>, - cmd: &'a Arc, - routing: &'a MultipleNodeRoutingInfo, - core: Core, + async fn aggregate_results( + receivers: Vec<(String, oneshot::Receiver>)>, + routing: &MultipleNodeRoutingInfo, response_policy: Option, - ) -> (OperationTarget, RedisResult) { - let read_guard = core.conn_lock.read().await; - let connections: Vec<_> = read_guard - .1 - .addresses_for_multi_routing(routing) - .into_iter() - .enumerate() - .filter_map(|(index, addr)| { - read_guard.0.get(addr).cloned().map(|conn| { - let cmd = match routing { - MultipleNodeRoutingInfo::MultiSlot(vec) => { - let mut new_cmd = Cmd::new(); - new_cmd.arg(cmd.arg_idx(0)); - let (_, indices) = vec.get(index).unwrap(); - for index in indices { - new_cmd.arg(cmd.arg_idx(*index)); - } - Arc::new(new_cmd) - } - _ => cmd.clone(), - }; - (addr.to_string(), conn, cmd) - }) - }) - .collect(); - 
drop(read_guard); - + ) -> RedisResult { let extract_result = |response| match response { Response::Single(value) => value, Response::Multiple(_) => unreachable!(), }; - let run_func = |(_, conn, cmd)| { - Box::pin(async move { - let conn = conn.await; - Ok(extract_result(func(conn, cmd).await?)) - }) + let convert_result = |res: Result, _>| { + res.map_err(|_| RedisError::from((ErrorKind::ResponseError, "request wasn't handled due to internal failure"))) // this happens only if the result sender is dropped before usage. + .and_then(|res| res.map(extract_result)) + }; + + let get_receiver = |(_, receiver): (_, oneshot::Receiver>)| async { + convert_result(receiver.await) }; // TODO - once Value::Error will be merged, these will need to be updated to handle this new value. - let result = match response_policy { + match response_policy { Some(ResponsePolicy::AllSucceeded) => { - future::try_join_all(connections.into_iter().map(run_func)) + future::try_join_all(receivers.into_iter().map(get_receiver)) .await .map(|mut results| results.pop().unwrap()) // unwrap is safe, since at least one function succeeded } - Some(ResponsePolicy::OneSucceeded) => { - future::select_ok(connections.into_iter().map(run_func)) - .await - .map(|(result, _)| result) - } + Some(ResponsePolicy::OneSucceeded) => future::select_ok( + receivers + .into_iter() + .map(|tuple| Box::pin(get_receiver(tuple))), + ) + .await + .map(|(result, _)| result), Some(ResponsePolicy::OneSucceededNonEmpty) => { - future::select_ok(connections.into_iter().map(|tuple| { + future::select_ok(receivers.into_iter().map(|(_, receiver)| { Box::pin(async move { - let result = run_func(tuple).await?; + let result = convert_result(receiver.await)?; match result { Value::Nil => Err((ErrorKind::ResponseError, "no value found").into()), _ => Ok(result), @@ -655,17 +626,17 @@ where .map(|(result, _)| result) } Some(ResponsePolicy::Aggregate(op)) => { - future::try_join_all(connections.into_iter().map(run_func)) + future::try_join_all(receivers.into_iter().map(get_receiver)) .await .and_then(|results| crate::cluster_routing::aggregate(results, op)) } Some(ResponsePolicy::AggregateLogical(op)) => { - future::try_join_all(connections.into_iter().map(run_func)) + future::try_join_all(receivers.into_iter().map(get_receiver)) .await .and_then(|results| crate::cluster_routing::logical_aggregate(results, op)) } Some(ResponsePolicy::CombineArrays) => { - future::try_join_all(connections.into_iter().map(run_func)) + future::try_join_all(receivers.into_iter().map(get_receiver)) .await .and_then(|results| match routing { MultipleNodeRoutingInfo::MultiSlot(vec) => { @@ -681,51 +652,122 @@ where // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. // TODO - once RESP3 is merged, return a map value here. // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. 
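// No response policy: pair each node's reply with the address it came from, and hand the
// caller a Value::Bulk whose entries are [address, value] pairs.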
- future::try_join_all(connections.into_iter().map(|(addr, conn, cmd)| async move { - let conn = conn.await; - Ok(Value::Bulk(vec![ - Value::Data(addr.into_bytes()), - extract_result(func(conn, cmd).await?), - ])) + future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move { + let result = convert_result(receiver.await)?; + Ok(Value::Bulk(vec![Value::Data(addr.into_bytes()), result])) })) .await .map(Value::Bulk) } } - .map(Response::Single); + } + + async fn execute_on_multiple_nodes<'a>( + cmd: &'a Arc, + routing: &'a MultipleNodeRoutingInfo, + core: Core, + response_policy: Option, + ) -> (OperationTarget, RedisResult) { + let read_guard = core.conn_lock.read().await; + let (receivers, requests): (Vec<_>, Vec<_>) = read_guard + .1 + .addresses_for_multi_routing(routing) + .into_iter() + .enumerate() + .filter_map(|(index, addr)| { + read_guard.0.get(addr).cloned().map(|conn| { + let cmd = match routing { + MultipleNodeRoutingInfo::MultiSlot(vec) => { + let mut new_cmd = Cmd::new(); + new_cmd.arg(cmd.arg_idx(0)); + let (_, indices) = vec.get(index).unwrap(); + for index in indices { + new_cmd.arg(cmd.arg_idx(*index)); + } + Arc::new(new_cmd) + } + _ => cmd.clone(), + }; + let (sender, receiver) = oneshot::channel(); + let addr = addr.to_string(); + ( + (addr.clone(), receiver), + PendingRequest { + retry: 0, + sender, + info: RequestInfo { + cmd: CmdArg::Cmd { + cmd, + routing: CommandRouting::Connection { addr, conn }, + response_policy: None, + }, + redirect: None, + }, + }, + ) + }) + }) + .unzip(); + drop(read_guard); + core.pending_requests.lock().unwrap().extend(requests); + + let result = Self::aggregate_results(receivers, routing, response_policy) + .await + .map(Response::Single); (OperationTarget::FanOut, result) } async fn try_cmd_request( cmd: Arc, - func: fn(C, Arc) -> RedisFuture<'static, Response>, redirect: Option, - routing: Option, + routing: CommandRouting, response_policy: Option, core: Core, asking: bool, ) -> (OperationTarget, RedisResult) { - let route_option = match routing - .as_ref() - .unwrap_or(&RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - { - RoutingInfo::MultiNode(multi_node_routing) => { - return Self::execute_on_multiple_nodes( - func, - &cmd, - multi_node_routing, - core, - response_policy, - ) - .await + let route_option = if redirect.is_some() { + // if we have a redirect, we don't take info from `routing`. + // TODO - combine the info in `routing` and `redirect` and `asking` into a single structure, so there won't be this question of which field takes precedence. + None + } else { + match routing { + // commands that are sent to multiple nodes are handled here. + CommandRouting::Route(Some(RoutingInfo::MultiNode(multi_node_routing))) => { + assert!(!asking); + assert!(redirect.is_none()); + return Self::execute_on_multiple_nodes( + &cmd, + &multi_node_routing, + core, + response_policy, + ) + .await; + } + + // commands that have concrete connections, and don't require redirection, are handled here. 
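// (These sub-requests are the ones queued by execute_on_multiple_nodes above, which has
// already resolved a connection per target address, so no routing lookup is needed here.)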
+ CommandRouting::Connection { addr, conn } => { + let mut conn = conn.await; + let result = conn.req_packed_command(&cmd).await.map(Response::Single); + return (addr.into(), result); + } + + CommandRouting::Route(Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(route), + ))) => Some(route), + + CommandRouting::Route(Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::Random, + ))) => None, + + CommandRouting::Route(None) => None, } - RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random) => None, - RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route)) => Some(route), }; - let (addr, conn) = Self::get_connection(redirect, route_option, core, asking).await; - let result = func(conn, cmd).await; + // if we reached this point, we're sending the command only to single node, and we need to find the + // right connection to the node. + let (addr, mut conn) = Self::get_connection(redirect, route_option, core, asking).await; + let result = conn.req_packed_command(&cmd).await.map(Response::Single); (addr.into(), result) } @@ -733,11 +775,13 @@ where pipeline: Arc, offset: usize, count: usize, - func: fn(C, Arc, usize, usize) -> RedisFuture<'static, Response>, conn: impl Future, ) -> (OperationTarget, RedisResult) { - let (addr, conn) = conn.await; - let result = func(conn, pipeline, offset, count).await; + let (addr, mut conn) = conn.await; + let result = conn + .req_packed_commands(&pipeline, offset, count) + .await + .map(Response::Multiple); (OperationTarget::Node { address: addr }, result) } @@ -750,34 +794,23 @@ where match info.cmd { CmdArg::Cmd { cmd, - func, routing, response_policy, } => { - Self::try_cmd_request( - cmd, - func, - info.redirect, - routing, - response_policy, - core, - asking, - ) - .await + Self::try_cmd_request(cmd, info.redirect, routing, response_policy, core, asking) + .await } CmdArg::Pipeline { pipeline, offset, count, - func, route, } => { Self::try_pipeline_request( pipeline, offset, count, - func, - Self::get_connection(info.redirect, route.as_ref(), core, asking), + Self::get_connection(info.redirect, route, core, asking), ) .await } @@ -786,7 +819,7 @@ where async fn get_connection( mut redirect: Option, - route: Option<&Route>, + route: Option, core: Core, asking: bool, ) -> (String, C) { @@ -872,8 +905,9 @@ where fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { let mut poll_flush_action = PollFlushAction::None; - if !self.pending_requests.is_empty() { - let mut pending_requests = mem::take(&mut self.pending_requests); + let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap(); + if !pending_requests_guard.is_empty() { + let mut pending_requests = mem::take(&mut *pending_requests_guard); for request in pending_requests.drain(..) { // Drop the request if noone is waiting for a response to free up resources for // requests callers care about (load shedding). 
It will be ambigous whether the @@ -889,8 +923,9 @@ where future: RequestState::Future { future }, })); } - self.pending_requests = pending_requests; + *pending_requests_guard = pending_requests; } + drop(pending_requests_guard); loop { let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { @@ -931,7 +966,7 @@ where poll_flush_action.change_state(PollFlushAction::RebuildSlots) } }; - self.pending_requests.push(request); + self.inner.pending_requests.lock().unwrap().push(request); } } } @@ -958,7 +993,7 @@ where (*request) .as_mut() .respond(Err(self.refresh_error.take().unwrap())); - } else if let Some(request) = self.pending_requests.pop() { + } else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() { let _ = request.sender.send(Err(self.refresh_error.take().unwrap())); } } @@ -1038,18 +1073,22 @@ where } } - fn start_send(mut self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { + fn start_send(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { trace!("start_send"); let Message { cmd, sender } = msg; let redirect = None; let info = RequestInfo { cmd, redirect }; - self.pending_requests.push(PendingRequest { - retry: 0, - sender, - info, - }); + self.inner + .pending_requests + .lock() + .unwrap() + .push(PendingRequest { + retry: 0, + sender, + info, + }); Ok(()) } diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 2ae8b6872..4f49be5ad 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -1,7 +1,7 @@ #![cfg(feature = "cluster-async")] mod support; use std::sync::{ - atomic::{self, AtomicI32}, + atomic::{self, AtomicI32, AtomicU16}, atomic::{AtomicBool, Ordering}, Arc, }; @@ -1176,6 +1176,53 @@ fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() { assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]); } +#[test] +fn test_cluster_handle_asking_error_in_split_multi_shard_command() { + let name = "test_cluster_handle_asking_error_in_split_multi_shard_command"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let asking_called = Arc::new(AtomicU16::new(0)); + let asking_called_cloned = asking_called.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + if cmd_str.contains("ASKING") && port == 6382 { + asking_called_cloned.fetch_add(1, Ordering::Relaxed); + } + if port == 6380 && cmd_str.contains("baz") { + return Err(parse_redis_value( + format!("-ASK 14000 {name}:6382\r\n").as_bytes(), + )); + } + let results = ["foo", "bar", "baz"] + .iter() + .filter_map(|expected_key| { + if cmd_str.contains(expected_key) { + Some(Value::Data(format!("{expected_key}-{port}").into_bytes())) + } else { + None + } + }) + .collect(); + Err(Ok(Value::Bulk(results))) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap(); + assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6382"]); + assert_eq!(asking_called.load(Ordering::Relaxed), 1); +} + #[test] fn test_async_cluster_with_username_and_password() { let cluster = TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| {