Skip to content

Commit

Permalink
Merge pull request #109 from Kuadrant/multiple-policies
Browse files Browse the repository at this point in the history
Store multiple policies in the index
  • Loading branch information
adam-cattermole authored Oct 14, 2024
2 parents 2ae5866 + a20e8b2 commit 6b196bf
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 45 deletions.
16 changes: 9 additions & 7 deletions src/configuration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ impl TryFrom<PluginConfiguration> for FilterConfig {
fn try_from(config: PluginConfiguration) -> Result<Self, Self::Error> {
let mut index = PolicyIndex::new();

for rlp in config.policies.iter() {
for rule in &rlp.rules {
for policy in config.policies.iter() {
for rule in &policy.rules {
for condition in &rule.conditions {
for pe in &condition.all_of {
let result = pe.compile();
Expand All @@ -486,8 +486,8 @@ impl TryFrom<PluginConfiguration> for FilterConfig {
}
}

for hostname in rlp.hostnames.iter() {
index.insert(hostname, rlp.clone());
for hostname in policy.hostnames.iter() {
index.insert(hostname, Rc::new(policy.clone()));
}
}

Expand Down Expand Up @@ -1163,15 +1163,17 @@ mod test {

let result = FilterConfig::try_from(res.unwrap());
let filter_config = result.expect("That didn't work");
let rlp_option = filter_config.index.get_longest_match_policy("example.com");
let rlp_option = filter_config
.index
.get_longest_match_policies("example.com");
assert!(rlp_option.is_some());

let rlp_option = filter_config
.index
.get_longest_match_policy("test.toystore.com");
.get_longest_match_policies("test.toystore.com");
assert!(rlp_option.is_some());

let rlp_option = filter_config.index.get_longest_match_policy("unknown");
let rlp_option = filter_config.index.get_longest_match_policies("unknown");
assert!(rlp_option.is_none());
}

Expand Down
17 changes: 10 additions & 7 deletions src/filter/http_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ impl Filter {
}
}

fn process_policy(&self, policy: &Policy) -> Action {
if let Some(rule) = policy.find_rule_that_applies() {
#[allow(clippy::manual_inspect)]
fn process_policies(&self, policies: &[Rc<Policy>]) -> Action {
if let Some(rule) = policies.iter().find_map(|policy| {
policy.find_rule_that_applies().map(|rule| {
debug!("#{} policy selected {}", self.context_id, policy.name);
rule
})
}) {
self.operation_dispatcher
.borrow_mut()
.build_operations(rule);
Expand Down Expand Up @@ -66,7 +72,7 @@ impl HttpContext for Filter {
match self
.config
.index
.get_longest_match_policy(self.request_authority().as_str())
.get_longest_match_policies(self.request_authority().as_str())
{
None => {
debug!(
Expand All @@ -75,10 +81,7 @@ impl HttpContext for Filter {
);
Action::Continue
}
Some(policy) => {
debug!("#{} policy selected {}", self.context_id, policy.name);
self.process_policy(policy)
}
Some(policies) => self.process_policies(policies),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct Rule {
pub actions: Vec<Action>,
}

#[derive(Deserialize, Debug, Clone)]
#[derive(Default, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Policy {
pub name: String,
Expand Down
56 changes: 32 additions & 24 deletions src/policy_index.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use radix_trie::Trie;
use std::rc::Rc;

use crate::policy::Policy;

pub struct PolicyIndex {
raw_tree: Trie<String, Policy>,
raw_tree: Trie<String, Vec<Rc<Policy>>>,
}

impl PolicyIndex {
Expand All @@ -13,12 +14,18 @@ impl PolicyIndex {
}
}

pub fn insert(&mut self, subdomain: &str, policy: Policy) {
pub fn insert(&mut self, subdomain: &str, policy: Rc<Policy>) {
let rev = Self::reverse_subdomain(subdomain);
self.raw_tree.insert(rev, policy);
self.raw_tree.map_with_default(
rev,
|policies| {
policies.push(Rc::clone(&policy));
},
vec![Rc::clone(&policy)],
);
}

pub fn get_longest_match_policy(&self, subdomain: &str) -> Option<&Policy> {
pub fn get_longest_match_policies(&self, subdomain: &str) -> Option<&Vec<Rc<Policy>>> {
let rev = Self::reverse_subdomain(subdomain);
self.raw_tree.get_ancestor_value(&rev)
}
Expand All @@ -39,6 +46,7 @@ impl PolicyIndex {
mod tests {
use crate::policy::Policy;
use crate::policy_index::PolicyIndex;
use std::rc::Rc;

fn build_ratelimit_policy(name: &str) -> Policy {
Policy::new(name.to_owned(), Vec::new(), Vec::new())
Expand All @@ -48,29 +56,29 @@ mod tests {
fn not_wildcard_subdomain() {
let mut index = PolicyIndex::new();
let rlp1 = build_ratelimit_policy("rlp1");
index.insert("example.com", rlp1);
index.insert("example.com", Rc::new(rlp1));

let val = index.get_longest_match_policy("test.example.com");
let val = index.get_longest_match_policies("test.example.com");
assert!(val.is_none());

let val = index.get_longest_match_policy("other.com");
let val = index.get_longest_match_policies("other.com");
assert!(val.is_none());

let val = index.get_longest_match_policy("net");
let val = index.get_longest_match_policies("net");
assert!(val.is_none());

let val = index.get_longest_match_policy("example.com");
let val = index.get_longest_match_policies("example.com");
assert!(val.is_some());
assert_eq!(val.unwrap().name, "rlp1");
assert_eq!(val.unwrap()[0].name, "rlp1");
}

#[test]
fn wildcard_subdomain_does_not_match_domain() {
let mut index = PolicyIndex::new();
let rlp1 = build_ratelimit_policy("rlp1");

index.insert("*.example.com", rlp1);
let val = index.get_longest_match_policy("example.com");
index.insert("*.example.com", Rc::new(rlp1));
let val = index.get_longest_match_policies("example.com");
assert!(val.is_none());
}

Expand All @@ -79,38 +87,38 @@ mod tests {
let mut index = PolicyIndex::new();
let rlp1 = build_ratelimit_policy("rlp1");

index.insert("*.example.com", rlp1);
let val = index.get_longest_match_policy("test.example.com");
index.insert("*.example.com", Rc::new(rlp1));
let val = index.get_longest_match_policies("test.example.com");

assert!(val.is_some());
assert_eq!(val.unwrap().name, "rlp1");
assert_eq!(val.unwrap()[0].name, "rlp1");
}

#[test]
fn longest_domain_match() {
let mut index = PolicyIndex::new();
let rlp1 = build_ratelimit_policy("rlp1");
index.insert("*.com", rlp1);
index.insert("*.com", Rc::new(rlp1));
let rlp2 = build_ratelimit_policy("rlp2");
index.insert("*.example.com", rlp2);
index.insert("*.example.com", Rc::new(rlp2));

let val = index.get_longest_match_policy("test.example.com");
let val = index.get_longest_match_policies("test.example.com");
assert!(val.is_some());
assert_eq!(val.unwrap().name, "rlp2");
assert_eq!(val.unwrap()[0].name, "rlp2");

let val = index.get_longest_match_policy("example.com");
let val = index.get_longest_match_policies("example.com");
assert!(val.is_some());
assert_eq!(val.unwrap().name, "rlp1");
assert_eq!(val.unwrap()[0].name, "rlp1");
}

#[test]
fn global_wildcard_match_all() {
let mut index = PolicyIndex::new();
let rlp1 = build_ratelimit_policy("rlp1");
index.insert("*", rlp1);
index.insert("*", Rc::new(rlp1));

let val = index.get_longest_match_policy("test.example.com");
let val = index.get_longest_match_policies("test.example.com");
assert!(val.is_some());
assert_eq!(val.unwrap().name, "rlp1");
assert_eq!(val.unwrap()[0].name, "rlp1");
}
}
4 changes: 2 additions & 2 deletions tests/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ fn it_auths() {
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
// retrieving properties for conditions
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -111,6 +110,7 @@ fn it_auths() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for CheckRequest
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
Expand Down Expand Up @@ -244,7 +244,6 @@ fn it_denies() {
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
// retrieving properties for conditions
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -263,6 +262,7 @@ fn it_denies() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for CheckRequest
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
Expand Down
4 changes: 2 additions & 2 deletions tests/multi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ fn it_performs_authenticated_rate_limiting() {
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
// retrieving properties for conditions
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -129,6 +128,7 @@ fn it_performs_authenticated_rate_limiting() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for CheckRequest
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
Expand Down Expand Up @@ -280,7 +280,6 @@ fn unauthenticated_does_not_ratelimit() {
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
// retrieving properties for conditions
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -299,6 +298,7 @@ fn unauthenticated_does_not_ratelimit() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for CheckRequest
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
Expand Down
10 changes: 8 additions & 2 deletions tests/rate_limited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ fn it_limits() {
.expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for conditions
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -178,6 +178,8 @@ fn it_limits() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving tracing headers
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate"))
Expand Down Expand Up @@ -315,7 +317,7 @@ fn it_passes_additional_headers() {
.expect_log(Some(LogLevel::Debug), Some("#2 on_http_request_headers"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("cars.toystore.com"))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving properties for conditions
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: request.url_path path: [\"request\", \"url_path\"]"),
Expand All @@ -334,6 +336,8 @@ fn it_passes_additional_headers() {
)
.expect_get_property(Some(vec!["request", "method"]))
.returning(Some("POST".as_bytes()))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving tracing headers
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate"))
Expand Down Expand Up @@ -467,6 +471,7 @@ fn it_rate_limits_with_empty_conditions() {
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":authority"))
.returning(Some("a.com"))
.expect_log(Some(LogLevel::Debug), Some("#2 policy selected some-name"))
// retrieving tracing headers
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("traceparent"))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("tracestate"))
Expand Down Expand Up @@ -588,6 +593,7 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value
Some(LogLevel::Debug),
Some("#2 policy selected some-name"),
)
// retrieving properties for RateLimitRequest
.expect_log(
Some(LogLevel::Debug),
Some("get_property: selector: unknown.path path: Path { tokens: [\"unknown\", \"path\"] }"),
Expand Down

0 comments on commit 6b196bf

Please sign in to comment.