diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 1d855a8f..a9ab3c18 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,7 +1,7 @@ use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; -use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler}; +use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessage, GrpcServiceHandler}; use protobuf::RepeatedField; use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; @@ -38,8 +38,8 @@ pub(crate) struct Operation { result: Result, extension: Rc, procedure: Procedure, - grpc_call: GrpcCall, - get_map_values_bytes: GetMapValuesBytes, + grpc_call_fn: GrpcCallFn, + get_map_values_bytes_fn: GetMapValuesBytesFn, } #[allow(dead_code)] @@ -50,8 +50,8 @@ impl Operation { result: Err(Status::Empty), extension, procedure, - grpc_call, - get_map_values_bytes, + grpc_call_fn, + get_map_values_bytes_fn, } } @@ -59,8 +59,8 @@ impl Operation { if let State::Done = self.state { } else { self.result = self.procedure.0.send( - self.get_map_values_bytes, - self.grpc_call, + self.get_map_values_bytes_fn, + self.grpc_call_fn, self.procedure.1.clone(), ); self.state.next(); @@ -147,7 +147,8 @@ impl OperationDispatcher { let mut operations = self.operations.borrow_mut(); if let Some((i, operation)) = operations.iter_mut().enumerate().next() { if let State::Done = operation.get_state() { - Some(operations.remove(i)) + operations.remove(i); + operations.get(i).cloned() // The next op is now at `i` } else { operation.trigger(); Some(operation.clone()) @@ -158,7 +159,7 @@ impl OperationDispatcher { } } -fn grpc_call( +fn grpc_call_fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -176,7 +177,7 @@ fn grpc_call( ) } -fn get_map_values_bytes(map_type: MapType, key: &str) -> Result, Status> { +fn get_map_values_bytes_fn(map_type: MapType, key: &str) -> Result, Status> { hostcalls::get_map_value_bytes(map_type, key) } @@ -186,7 +187,7 @@ mod tests { use crate::envoy::RateLimitRequest; use std::time::Duration; - fn grpc_call( + fn grpc_call_fn_stub( _upstream_name: &str, _service_name: &str, _method_name: &str, @@ -197,7 +198,10 @@ mod tests { Ok(200) } - fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result, Status> { + fn get_map_values_bytes_fn_stub( + _map_type: MapType, + _key: &str, + ) -> Result, Status> { Ok(Some(Vec::new())) } @@ -218,14 +222,14 @@ mod tests { fn build_operation() -> Operation { Operation { state: State::Pending, - result: Ok(200), + result: Ok(1), extension: Rc::new(Extension::default()), procedure: ( Rc::new(build_grpc_service_handler()), GrpcMessage::RateLimit(build_message()), ), - grpc_call, - get_map_values_bytes, + grpc_call_fn: grpc_call_fn_stub, + get_map_values_bytes_fn: get_map_values_bytes_fn_stub, } } @@ -236,7 +240,7 @@ mod tests { assert_eq!(operation.get_state(), State::Pending); assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Ok(200)); + assert_eq!(operation.get_result(), Ok(1)); } #[test] @@ -272,20 +276,37 @@ mod tests { #[test] fn operation_dispatcher_next() { - let operation = build_operation(); let operation_dispatcher = OperationDispatcher::default(); - operation_dispatcher.push_operations(vec![operation]); + operation_dispatcher.push_operations(vec![build_operation(), build_operation()]); - if let Some(operation) = operation_dispatcher.next() { - assert_eq!(operation.get_result(), Ok(200)); - assert_eq!(operation.get_state(), State::Waiting); - } + assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(1)); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Pending) + ); - if let Some(operation) = operation_dispatcher.next() { - assert_eq!(operation.get_result(), Ok(200)); - assert_eq!(operation.get_state(), State::Done); - } - operation_dispatcher.next(); - assert_eq!(operation_dispatcher.get_current_operation_state(), None); + let mut op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(1)); + assert_eq!(op.unwrap().get_state(), State::Pending); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert!(op.is_none()); + assert!(operation_dispatcher.get_current_operation_state().is_none()); } } diff --git a/src/service.rs b/src/service.rs index aec2aff0..69359ea7 100644 --- a/src/service.rs +++ b/src/service.rs @@ -181,7 +181,7 @@ impl GrpcService { } } -pub type GrpcCall = fn( +pub type GrpcCallFn = fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -190,7 +190,7 @@ pub type GrpcCall = fn( timeout: Duration, ) -> Result; -pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result, Status>; +pub type GetMapValuesBytesFn = fn(map_type: MapType, key: &str) -> Result, Status>; pub struct GrpcServiceHandler { service: Rc, @@ -207,19 +207,19 @@ impl GrpcServiceHandler { pub fn send( &self, - get_map_values_bytes: GetMapValuesBytes, - grpc_call: GrpcCall, + get_map_values_bytes_fn: GetMapValuesBytesFn, + grpc_call_fn: GrpcCallFn, message: GrpcMessage, ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver - .get(get_map_values_bytes) + .get(get_map_values_bytes_fn) .iter() .map(|(header, value)| (*header, value.as_slice())) .collect(); - grpc_call( + grpc_call_fn( self.service.endpoint(), self.service.name(), self.service.method(), @@ -255,7 +255,7 @@ impl HeaderResolver { } } - pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> { + pub fn get(&self, get_map_values_bytes: GetMapValuesBytesFn) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() {