diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 85582cae0..5c0702d85 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -661,6 +661,13 @@ where _ => crate::cluster_routing::combine_array_results(results), } } + Some(ResponsePolicy::CombineMaps) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::combine_map_results(results) + } Some(ResponsePolicy::Special) | None => { // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 833fc01ce..e0479d55a 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -1288,6 +1288,11 @@ where _ => crate::cluster_routing::combine_array_results(results), }) } + Some(ResponsePolicy::CombineMaps) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(crate::cluster_routing::combine_map_results) + } Some(ResponsePolicy::Special) | None => { // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index 73116c0bc..a37b875be 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -48,6 +48,8 @@ pub enum ResponsePolicy { CombineArrays, /// Handling is not defined by the Redis standard. Will receive a special case Special, + /// Combines multiple map responses into a single map. + CombineMaps, } /// Defines whether a request should be routed to a single node, or multiple ones. @@ -187,8 +189,42 @@ pub fn logical_aggregate(values: Vec, op: LogicalAggregateOp) -> RedisRes .collect(), )) } +/// Aggregate array responses into a single map. +pub fn combine_map_results(values: Vec) -> RedisResult { + let mut map: HashMap, i64> = HashMap::new(); -/// Aggreagte arrau responses into a single array. + for value in values { + match value { + Value::Array(elements) => { + let mut iter = elements.into_iter(); + + while let Some(key) = iter.next() { + if let Value::BulkString(key_bytes) = key { + if let Some(Value::Int(value)) = iter.next() { + *map.entry(key_bytes).or_insert(0) += value; + } else { + return Err((ErrorKind::TypeError, "expected integer value").into()); + } + } else { + return Err((ErrorKind::TypeError, "expected string key").into()); + } + } + } + _ => { + return Err((ErrorKind::TypeError, "expected array of values as response").into()); + } + } + } + + let result_vec: Vec<(Value, Value)> = map + .into_iter() + .map(|(k, v)| (Value::BulkString(k), Value::Int(v))) + .collect(); + + Ok(Value::Map(result_vec)) +} + +/// Aggregate array responses into a single array. pub fn combine_array_results(values: Vec) -> RedisResult { let mut results = Vec::new(); @@ -302,7 +338,9 @@ impl ResponsePolicy { b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)), b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK" - | b"LATENCY RESET" => Some(ResponsePolicy::Aggregate(AggregateOp::Sum)), + | b"LATENCY RESET" | b"PUBSUB NUMPAT" => { + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + } b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)), @@ -314,7 +352,10 @@ impl ResponsePolicy { Some(ResponsePolicy::AllSucceeded) } - b"KEYS" | b"MGET" | b"SLOWLOG GET" => Some(ResponsePolicy::CombineArrays), + b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => { + Some(ResponsePolicy::CombineArrays) + } + b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps), b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded), @@ -354,11 +395,30 @@ enum RouteBy { fn base_routing(cmd: &[u8]) -> RouteBy { match cmd { - b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" | b"CLIENT SETINFO" - | b"SLOWLOG GET" | b"SLOWLOG LEN" | b"SLOWLOG RESET" | b"CONFIG SET" - | b"CONFIG RESETSTAT" | b"CONFIG REWRITE" | b"SCRIPT FLUSH" | b"SCRIPT LOAD" - | b"LATENCY RESET" | b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" - | b"LATENCY DOCTOR" | b"LATENCY LATEST" => RouteBy::AllNodes, + b"ACL SETUSER" + | b"ACL DELUSER" + | b"ACL SAVE" + | b"CLIENT SETNAME" + | b"CLIENT SETINFO" + | b"SLOWLOG GET" + | b"SLOWLOG LEN" + | b"SLOWLOG RESET" + | b"CONFIG SET" + | b"CONFIG RESETSTAT" + | b"CONFIG REWRITE" + | b"SCRIPT FLUSH" + | b"SCRIPT LOAD" + | b"LATENCY RESET" + | b"LATENCY GRAPH" + | b"LATENCY HISTOGRAM" + | b"LATENCY HISTORY" + | b"LATENCY DOCTOR" + | b"LATENCY LATEST" + | b"PUBSUB NUMPAT" + | b"PUBSUB CHANNELS" + | b"PUBSUB NUMSUB" + | b"PUBSUB SHARDCHANNELS" + | b"PUBSUB SHARDNUMSUB" => RouteBy::AllNodes, b"DBSIZE" | b"FLUSHALL" @@ -463,10 +523,6 @@ fn base_routing(cmd: &[u8]) -> RouteBy { | b"MODULE LOAD" | b"MODULE LOADEX" | b"MODULE UNLOAD" - | b"PUBSUB CHANNELS" - | b"PUBSUB NUMPAT" - | b"PUBSUB NUMSUB" - | b"PUBSUB SHARDCHANNELS" | b"READONLY" | b"READWRITE" | b"SAVE" @@ -1233,4 +1289,49 @@ mod tests { ]) ); } + + #[test] + fn test_combine_map_results() { + let input = vec![]; + let result = super::combine_map_results(input).unwrap(); + assert_eq!(result, Value::Map(vec![])); + + let input = vec![ + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(5), + Value::BulkString(b"key2".to_vec()), + Value::Int(10), + ]), + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(3), + Value::BulkString(b"key3".to_vec()), + Value::Int(15), + ]), + ]; + let result = super::combine_map_results(input).unwrap(); + let mut expected = vec![ + (Value::BulkString(b"key1".to_vec()), Value::Int(8)), + (Value::BulkString(b"key2".to_vec()), Value::Int(10)), + (Value::BulkString(b"key3".to_vec()), Value::Int(15)), + ]; + expected.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + let mut result_vec = match result { + Value::Map(v) => v, + _ => panic!("Expected Map"), + }; + result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + assert_eq!(result_vec, expected); + + let input = vec![Value::Int(5)]; + let result = super::combine_map_results(input); + assert!(result.is_err()); + } }