diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 717563ab..aa02bdc4 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,16 +1,14 @@ use crate::configuration::{FailureMode, FilterConfig}; -use crate::envoy::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}; +use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; use crate::filter::http_context::TracingHeader::{Baggage, Traceparent, Tracestate}; use crate::policy::Policy; +use crate::service::rate_limit::RateLimitService; +use crate::service::Service; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::{Action, Bytes}; use std::rc::Rc; -use std::time::Duration; - -const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; -const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; // tracing headers pub enum TracingHeader { @@ -63,28 +61,16 @@ impl Filter { ); return Action::Continue; } - - let mut rl_req = RateLimitRequest::new(); - rl_req.set_domain(rlp.domain.clone()); - rl_req.set_hits_addend(1); - rl_req.set_descriptors(descriptors); - - let rl_req_serialized = Message::write_to_bytes(&rl_req).unwrap(); // TODO(rahulanand16nov): Error Handling - let rl_tracing_headers = self .tracing_headers .iter() .map(|(header, value)| (header.as_str(), value.as_slice())) .collect(); - match self.dispatch_grpc_call( - rlp.service.as_str(), - RATELIMIT_SERVICE_NAME, - RATELIMIT_METHOD_NAME, - rl_tracing_headers, - Some(&rl_req_serialized), - Duration::from_secs(5), - ) { + let rls = RateLimitService::new(rlp.service.as_str(), rl_tracing_headers); + let message = RateLimitService::message(rlp.domain.clone(), descriptors); + + match rls.send(message) { Ok(call_id) => { debug!( "#{} initiated gRPC call (id# {}) to Limitador", diff --git a/src/lib.rs b/src/lib.rs index 8ee6c311..fb1c60aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod filter; mod glob; mod policy; mod policy_index; +mod service; #[cfg(test)] mod tests { diff --git a/src/service.rs b/src/service.rs new file mode 100644 index 00000000..cd7b945a --- /dev/null +++ b/src/service.rs @@ -0,0 +1,10 @@ +pub(crate) mod rate_limit; + +use protobuf::reflect::ProtobufValue; +use protobuf::{Message, RepeatedField}; +use proxy_wasm::types::Status; + +pub trait Service { + fn message(domain: String, descriptors: RepeatedField) -> M; + fn send(&self, message: M) -> Result; +} diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs new file mode 100644 index 00000000..fd2f4089 --- /dev/null +++ b/src/service/rate_limit.rs @@ -0,0 +1,76 @@ +use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; +use crate::service::Service; +use protobuf::{Message, RepeatedField}; +use proxy_wasm::hostcalls::dispatch_grpc_call; +use proxy_wasm::types::Status; +use std::time::Duration; + +const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; +const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; + +pub struct RateLimitService<'a> { + endpoint: String, + metadata: Vec<(&'a str, &'a [u8])>, +} + +impl<'a> RateLimitService<'a> { + pub fn new(endpoint: &str, metadata: Vec<(&'a str, &'a [u8])>) -> RateLimitService<'a> { + Self { + endpoint: String::from(endpoint), + metadata, + } + } +} + +impl Service for RateLimitService<'_> { + fn message( + domain: String, + descriptors: RepeatedField, + ) -> RateLimitRequest { + RateLimitRequest { + domain, + descriptors, + hits_addend: 1, + unknown_fields: Default::default(), + cached_size: Default::default(), + } + } + fn send(&self, message: RateLimitRequest) -> Result { + let msg = Message::write_to_bytes(&message).unwrap(); // TODO(didierofrivia): Error Handling + dispatch_grpc_call( + self.endpoint.as_str(), + RATELIMIT_SERVICE_NAME, + RATELIMIT_METHOD_NAME, + self.metadata.clone(), + Some(&msg), + Duration::from_secs(5), + ) + } +} + +#[cfg(test)] +mod tests { + use protobuf::{CachedSize, RepeatedField, UnknownFields}; + use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry}; + use crate::service::rate_limit::RateLimitService; + use crate::service::Service; + + #[test] + fn builds_message() { + let domain = "rlp1"; + let mut field = RateLimitDescriptor::new(); + let mut entry = RateLimitDescriptor_Entry::new(); + entry.set_key("key1".to_string()); + entry.set_value("value1".to_string()); + field.set_entries(RepeatedField::from_vec(vec![entry])); + let descriptors = RepeatedField::from_vec(vec![field]); + + let msg = RateLimitService::message(domain.to_string(), descriptors.clone()); + + assert_eq!(msg.hits_addend, 1); + assert_eq!(msg.domain, domain.to_string()); + assert_eq!(msg.descriptors , descriptors); + assert_eq!(msg.unknown_fields , UnknownFields::default()); + assert_eq!(msg.cached_size , CachedSize::default()); + } +}