Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle PubSub commands routing #176

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,13 @@ where
_ => crate::cluster_routing::combine_array_results(results),
}
}
Some(ResponsePolicy::CombineMaps) => {
let results = results
.into_iter()
.map(|res| res.map(|(_, val)| val))
.collect::<RedisResult<Vec<_>>>()?;
crate::cluster_routing::combine_map_results(results)
}
Some(ResponsePolicy::Special) | None => {
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
Expand Down
5 changes: 5 additions & 0 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,11 @@ where
_ => crate::cluster_routing::combine_array_results(results),
})
}
Some(ResponsePolicy::CombineMaps) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
.and_then(crate::cluster_routing::combine_map_results)
}
Some(ResponsePolicy::Special) | None => {
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
Expand Down
125 changes: 113 additions & 12 deletions redis/src/cluster_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub enum ResponsePolicy {
CombineArrays,
/// Handling is not defined by the Redis standard. Will receive a special case
Special,
/// Combines multiple map responses into a single map.
CombineMaps,
}

/// Defines whether a request should be routed to a single node, or multiple ones.
Expand Down Expand Up @@ -187,8 +189,42 @@ pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> RedisRes
.collect(),
))
}
/// Aggregate array responses into a single map.
pub fn combine_map_results(values: Vec<Value>) -> RedisResult<Value> {
let mut map: HashMap<Vec<u8>, i64> = HashMap::new();

/// Aggreagte arrau responses into a single array.
for value in values {
match value {
Value::Array(elements) => {
let mut iter = elements.into_iter();

while let Some(key) = iter.next() {
if let Value::BulkString(key_bytes) = key {
if let Some(Value::Int(value)) = iter.next() {
*map.entry(key_bytes).or_insert(0) += value;
} else {
return Err((ErrorKind::TypeError, "expected integer value").into());
}
} else {
return Err((ErrorKind::TypeError, "expected string key").into());
}
}
}
_ => {
return Err((ErrorKind::TypeError, "expected array of values as response").into());
}
}
}

let result_vec: Vec<(Value, Value)> = map
.into_iter()
.map(|(k, v)| (Value::BulkString(k), Value::Int(v)))
.collect();

Ok(Value::Map(result_vec))
}

/// Aggregate array responses into a single array.
pub fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
let mut results = Vec::new();

Expand Down Expand Up @@ -302,7 +338,9 @@ impl ResponsePolicy {
b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)),

b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
| b"LATENCY RESET" => Some(ResponsePolicy::Aggregate(AggregateOp::Sum)),
| b"LATENCY RESET" | b"PUBSUB NUMPAT" => {
Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
}

b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)),

Expand All @@ -314,7 +352,10 @@ impl ResponsePolicy {
Some(ResponsePolicy::AllSucceeded)
}

b"KEYS" | b"MGET" | b"SLOWLOG GET" => Some(ResponsePolicy::CombineArrays),
b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => {
Some(ResponsePolicy::CombineArrays)
}
b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),

b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),

Expand Down Expand Up @@ -354,11 +395,30 @@ enum RouteBy {

fn base_routing(cmd: &[u8]) -> RouteBy {
match cmd {
b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" | b"CLIENT SETINFO"
| b"SLOWLOG GET" | b"SLOWLOG LEN" | b"SLOWLOG RESET" | b"CONFIG SET"
| b"CONFIG RESETSTAT" | b"CONFIG REWRITE" | b"SCRIPT FLUSH" | b"SCRIPT LOAD"
| b"LATENCY RESET" | b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY"
| b"LATENCY DOCTOR" | b"LATENCY LATEST" => RouteBy::AllNodes,
b"ACL SETUSER"
| b"ACL DELUSER"
| b"ACL SAVE"
| b"CLIENT SETNAME"
| b"CLIENT SETINFO"
| b"SLOWLOG GET"
| b"SLOWLOG LEN"
| b"SLOWLOG RESET"
| b"CONFIG SET"
| b"CONFIG RESETSTAT"
| b"CONFIG REWRITE"
| b"SCRIPT FLUSH"
| b"SCRIPT LOAD"
| b"LATENCY RESET"
| b"LATENCY GRAPH"
| b"LATENCY HISTOGRAM"
| b"LATENCY HISTORY"
| b"LATENCY DOCTOR"
| b"LATENCY LATEST"
| b"PUBSUB NUMPAT"
| b"PUBSUB CHANNELS"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"PUBSUB SHARDNUMSUB" => RouteBy::AllNodes,

b"DBSIZE"
| b"FLUSHALL"
Expand Down Expand Up @@ -463,10 +523,6 @@ fn base_routing(cmd: &[u8]) -> RouteBy {
| b"MODULE LOAD"
| b"MODULE LOADEX"
| b"MODULE UNLOAD"
| b"PUBSUB CHANNELS"
| b"PUBSUB NUMPAT"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"READONLY"
| b"READWRITE"
| b"SAVE"
Expand Down Expand Up @@ -1233,4 +1289,49 @@ mod tests {
])
);
}

#[test]
fn test_combine_map_results() {
let input = vec![];
let result = super::combine_map_results(input).unwrap();
assert_eq!(result, Value::Map(vec![]));

let input = vec![
Value::Array(vec![
Value::BulkString(b"key1".to_vec()),
Value::Int(5),
Value::BulkString(b"key2".to_vec()),
Value::Int(10),
]),
Value::Array(vec![
Value::BulkString(b"key1".to_vec()),
Value::Int(3),
Value::BulkString(b"key3".to_vec()),
Value::Int(15),
]),
];
let result = super::combine_map_results(input).unwrap();
let mut expected = vec![
(Value::BulkString(b"key1".to_vec()), Value::Int(8)),
(Value::BulkString(b"key2".to_vec()), Value::Int(10)),
(Value::BulkString(b"key3".to_vec()), Value::Int(15)),
];
expected.sort_unstable_by(|a, b| match (&a.0, &b.0) {
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
_ => std::cmp::Ordering::Equal,
});
let mut result_vec = match result {
Value::Map(v) => v,
_ => panic!("Expected Map"),
};
result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) {
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
_ => std::cmp::Ordering::Equal,
});
assert_eq!(result_vec, expected);

let input = vec![Value::Int(5)];
let result = super::combine_map_results(input);
assert!(result.is_err());
}
}