From 6f729bf1e27a2d443fa360d2906037eb27b1290b Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 11 Oct 2024 15:23:44 +0100 Subject: [PATCH 1/2] Store multiple policies in the index Signed-off-by: Adam Cattermole --- src/configuration.rs | 16 ++++++----- src/filter/http_context.rs | 16 ++++++----- src/policy.rs | 2 +- src/policy_index.rs | 56 ++++++++++++++++++++++---------------- tests/auth.rs | 4 +-- tests/multi.rs | 4 +-- tests/rate_limited.rs | 10 +++++-- 7 files changed, 63 insertions(+), 45 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index fbe2abed..339cb0e1 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -466,8 +466,8 @@ impl TryFrom for FilterConfig { fn try_from(config: PluginConfiguration) -> Result { let mut index = PolicyIndex::new(); - for rlp in config.policies.iter() { - for rule in &rlp.rules { + for policy in config.policies.iter() { + for rule in &policy.rules { for condition in &rule.conditions { for pe in &condition.all_of { let result = pe.compile(); @@ -486,8 +486,8 @@ impl TryFrom for FilterConfig { } } - for hostname in rlp.hostnames.iter() { - index.insert(hostname, rlp.clone()); + for hostname in policy.hostnames.iter() { + index.insert(hostname, Rc::new(policy.clone())); } } @@ -1163,15 +1163,17 @@ mod test { let result = FilterConfig::try_from(res.unwrap()); let filter_config = result.expect("That didn't work"); - let rlp_option = filter_config.index.get_longest_match_policy("example.com"); + let rlp_option = filter_config + .index + .get_longest_match_policies("example.com"); assert!(rlp_option.is_some()); let rlp_option = filter_config .index - .get_longest_match_policy("test.toystore.com"); + .get_longest_match_policies("test.toystore.com"); assert!(rlp_option.is_some()); - let rlp_option = filter_config.index.get_longest_match_policy("unknown"); + let rlp_option = filter_config.index.get_longest_match_policies("unknown"); assert!(rlp_option.is_none()); } diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index be35779e..47b9ed6f 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -29,8 +29,13 @@ impl Filter { } } - fn process_policy(&self, policy: &Policy) -> Action { - if let Some(rule) = policy.find_rule_that_applies() { + fn process_policies(&self, policies: &[Rc]) -> Action { + if let Some(rule) = policies.iter().find_map(|policy| { + policy.find_rule_that_applies().map(|rule| { + debug!("#{} policy selected {}", self.context_id, policy.name); + rule + }) + }) { self.operation_dispatcher .borrow_mut() .build_operations(rule); @@ -66,7 +71,7 @@ impl HttpContext for Filter { match self .config .index - .get_longest_match_policy(self.request_authority().as_str()) + .get_longest_match_policies(self.request_authority().as_str()) { None => { debug!( @@ -75,10 +80,7 @@ impl HttpContext for Filter { ); Action::Continue } - Some(policy) => { - debug!("#{} policy selected {}", self.context_id, policy.name); - self.process_policy(policy) - } + Some(policies) => self.process_policies(policies), } } diff --git a/src/policy.rs b/src/policy.rs index db9ed9fe..af3d5179 100644 --- a/src/policy.rs +++ b/src/policy.rs @@ -16,7 +16,7 @@ pub struct Rule { pub actions: Vec, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Default, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct Policy { pub name: String, diff --git a/src/policy_index.rs b/src/policy_index.rs index 0c0759a9..1ef684de 100644 --- a/src/policy_index.rs +++ b/src/policy_index.rs @@ -1,9 +1,10 @@ use radix_trie::Trie; +use std::rc::Rc; use crate::policy::Policy; pub struct PolicyIndex { - raw_tree: Trie, + raw_tree: Trie>>, } impl PolicyIndex { @@ -13,12 +14,18 @@ impl PolicyIndex { } } - pub fn insert(&mut self, subdomain: &str, policy: Policy) { + pub fn insert(&mut self, subdomain: &str, policy: Rc) { let rev = Self::reverse_subdomain(subdomain); - self.raw_tree.insert(rev, policy); + self.raw_tree.map_with_default( + rev, + |policies| { + policies.push(Rc::clone(&policy)); + }, + vec![Rc::clone(&policy)], + ); } - pub fn get_longest_match_policy(&self, subdomain: &str) -> Option<&Policy> { + pub fn get_longest_match_policies(&self, subdomain: &str) -> Option<&Vec>> { let rev = Self::reverse_subdomain(subdomain); self.raw_tree.get_ancestor_value(&rev) } @@ -39,6 +46,7 @@ impl PolicyIndex { mod tests { use crate::policy::Policy; use crate::policy_index::PolicyIndex; + use std::rc::Rc; fn build_ratelimit_policy(name: &str) -> Policy { Policy::new(name.to_owned(), Vec::new(), Vec::new()) @@ -48,20 +56,20 @@ mod tests { fn not_wildcard_subdomain() { let mut index = PolicyIndex::new(); let rlp1 = build_ratelimit_policy("rlp1"); - index.insert("example.com", rlp1); + index.insert("example.com", Rc::new(rlp1)); - let val = index.get_longest_match_policy("test.example.com"); + let val = index.get_longest_match_policies("test.example.com"); assert!(val.is_none()); - let val = index.get_longest_match_policy("other.com"); + let val = index.get_longest_match_policies("other.com"); assert!(val.is_none()); - let val = index.get_longest_match_policy("net"); + let val = index.get_longest_match_policies("net"); assert!(val.is_none()); - let val = index.get_longest_match_policy("example.com"); + let val = index.get_longest_match_policies("example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap().name, "rlp1"); + assert_eq!(val.unwrap()[0].name, "rlp1"); } #[test] @@ -69,8 +77,8 @@ mod tests { let mut index = PolicyIndex::new(); let rlp1 = build_ratelimit_policy("rlp1"); - index.insert("*.example.com", rlp1); - let val = index.get_longest_match_policy("example.com"); + index.insert("*.example.com", Rc::new(rlp1)); + let val = index.get_longest_match_policies("example.com"); assert!(val.is_none()); } @@ -79,38 +87,38 @@ mod tests { let mut index = PolicyIndex::new(); let rlp1 = build_ratelimit_policy("rlp1"); - index.insert("*.example.com", rlp1); - let val = index.get_longest_match_policy("test.example.com"); + index.insert("*.example.com", Rc::new(rlp1)); + let val = index.get_longest_match_policies("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap().name, "rlp1"); + assert_eq!(val.unwrap()[0].name, "rlp1"); } #[test] fn longest_domain_match() { let mut index = PolicyIndex::new(); let rlp1 = build_ratelimit_policy("rlp1"); - index.insert("*.com", rlp1); + index.insert("*.com", Rc::new(rlp1)); let rlp2 = build_ratelimit_policy("rlp2"); - index.insert("*.example.com", rlp2); + index.insert("*.example.com", Rc::new(rlp2)); - let val = index.get_longest_match_policy("test.example.com"); + let val = index.get_longest_match_policies("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap().name, "rlp2"); + assert_eq!(val.unwrap()[0].name, "rlp2"); - let val = index.get_longest_match_policy("example.com"); + let val = index.get_longest_match_policies("example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap().name, "rlp1"); + assert_eq!(val.unwrap()[0].name, "rlp1"); } #[test] fn global_wildcard_match_all() { let mut index = PolicyIndex::new(); let rlp1 = build_ratelimit_policy("rlp1"); - index.insert("*", rlp1); + index.insert("*", Rc::new(rlp1)); - let val = index.get_longest_match_policy("test.example.com"); + let val = index.get_longest_match_policies("test.example.com"); assert!(val.is_some()); - assert_eq!(val.unwrap().name, "rlp1"); + assert_eq!(val.unwrap()[0].name, "rlp1"); } } diff --git a/tests/auth.rs b/tests/auth.rs index 85006da8..7e7a8e24 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -92,7 +92,6 @@ fn it_auths() { .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) // retrieving properties for conditions - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -111,6 +110,7 @@ fn it_auths() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) // retrieving properties for CheckRequest .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) @@ -244,7 +244,6 @@ fn it_denies() { .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) // retrieving properties for conditions - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -263,6 +262,7 @@ fn it_denies() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) // retrieving properties for CheckRequest .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) diff --git a/tests/multi.rs b/tests/multi.rs index 3c6525f4..8fa47e96 100644 --- a/tests/multi.rs +++ b/tests/multi.rs @@ -110,7 +110,6 @@ fn it_performs_authenticated_rate_limiting() { .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) // retrieving properties for conditions - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -129,6 +128,7 @@ fn it_performs_authenticated_rate_limiting() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) // retrieving properties for CheckRequest .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) @@ -280,7 +280,6 @@ fn unauthenticated_does_not_ratelimit() { .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) // retrieving properties for conditions - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -299,6 +298,7 @@ fn unauthenticated_does_not_ratelimit() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) // retrieving properties for CheckRequest .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) .returning(None) diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index 7daeb930..813e46d6 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -159,7 +159,7 @@ fn it_limits() { .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) + // retrieving properties for conditions .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -178,6 +178,8 @@ fn it_limits() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) + // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) @@ -315,7 +317,7 @@ fn it_passes_additional_headers() { .expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers")) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("cars.toystore.com")) - .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) + // retrieving properties for conditions .expect_log( Some(LogLevel::Debug), Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"), @@ -334,6 +336,8 @@ fn it_passes_additional_headers() { ) .expect_get_property(Some(vec!["request", "method"])) .returning(Some("POST".as_bytes())) + .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) + // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) @@ -467,6 +471,7 @@ fn it_rate_limits_with_empty_conditions() { .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority")) .returning(Some("a.com")) .expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name")) + // retrieving tracing headers .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent")) .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate")) @@ -588,6 +593,7 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value Some(LogLevel::Debug), Some("#2 policy selected some-name"), ) + // retrieving properties for RateLimitRequest .expect_log( Some(LogLevel::Debug), Some("get_property: selector: unknown.path path: Path { tokens: [\"unknown\", \"path\"] }"), From a20e8b27635a579a4c893a5b713802e2338afc88 Mon Sep 17 00:00:00 2001 From: Adam Cattermole Date: Fri, 11 Oct 2024 15:34:02 +0100 Subject: [PATCH 2/2] Allow manual inspect Signed-off-by: Adam Cattermole --- src/filter/http_context.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 47b9ed6f..a26238ee 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -29,6 +29,7 @@ impl Filter { } } + #[allow(clippy::manual_inspect)] fn process_policies(&self, policies: &[Rc]) -> Action { if let Some(rule) = policies.iter().find_map(|policy| { policy.find_rule_that_applies().map(|rule| {