From b65e1c635a6308140ee68cbc6dcbf7f46c533faf Mon Sep 17 00:00:00 2001
From: Alex Snaps
Date: Mon, 4 Nov 2024 14:20:50 -0500
Subject: [PATCH] CEL Strings extension macros

Signed-off-by: Alex Snaps
---
 src/data/cel.rs         |  19 ++-
 src/data/cel/strings.rs | 434 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 452 insertions(+), 1 deletion(-)
 create mode 100644 src/data/cel/strings.rs

diff --git a/src/data/cel.rs b/src/data/cel.rs
index fcce88b..09e1aba 100644
--- a/src/data/cel.rs
+++ b/src/data/cel.rs
@@ -44,7 +44,7 @@ impl Expression {
     }
 
     pub fn eval(&self) -> Value {
-        let mut ctx = Context::default();
+        let mut ctx = create_context();
         let Map { map } = self.build_data_map();
 
         ctx.add_function("getHostProperty", get_host_property);
@@ -103,6 +103,23 @@ fn get_host_property(This(this): This) -> ResolveResult {
     }
 }
 
+fn create_context<'a>() -> Context<'a> {
+    let mut ctx = Context::default();
+    ctx.add_function("charAt", strings::char_at);
+    ctx.add_function("indexOf", strings::index_of);
+    ctx.add_function("join", strings::join);
+    ctx.add_function("lastIndexOf", strings::last_index_of);
+    ctx.add_function("lowerAscii", strings::lower_ascii);
+    ctx.add_function("upperAscii", strings::upper_ascii);
+    ctx.add_function("trim", strings::trim);
+    ctx.add_function("replace", strings::replace);
+    ctx.add_function("split", strings::split);
+    ctx.add_function("substring", strings::substring);
+    ctx
+}
+
+mod strings;
+
 #[derive(Clone, Debug)]
 pub struct Predicate {
     expression: Expression,
diff --git a/src/data/cel/strings.rs b/src/data/cel/strings.rs
new file mode 100644
index 0000000..cec034d
--- /dev/null
+++ b/src/data/cel/strings.rs
@@ -0,0 +1,434 @@
+use cel_interpreter::extractors::{Arguments, This};
+use cel_interpreter::{ExecutionError, ResolveResult, Value};
+use std::sync::Arc;
+
+/// `<string>.charAt(<int>) -> <string>`
+///
+/// Returns the character at character position `arg` of `this` as a string, or a
+/// `FunctionError` when `this` has no character at that position.
+pub fn char_at(This(this): This<Arc<String>>, arg: i64) -> ResolveResult {
+    // `try_from` keeps a negative or oversized position from wrapping onto a valid one.
+    let found = usize::try_from(arg).ok().and_then(|idx| this.chars().nth(idx));
+    match found {
+        None => Err(ExecutionError::FunctionError {
+            function: "String.charAt".to_owned(),
+            message: format!("No index {arg} on `{this}`"),
+        }),
+        Some(c) => Ok(c.to_string().into()),
+    }
+}
+
+/// `<string>.indexOf(<string>) -> <int>`
+/// `<string>.indexOf(<string>, <int>) -> <int>`
+///
+/// Returns the byte offset of the first occurrence of `arg` in `this`, or `-1` when there is
+/// none. The optional 2nd argument is the byte offset at which the search starts.
+pub fn index_of(
+    This(this): This<Arc<String>>,
+    arg: Arc<String>,
+    Arguments(args): Arguments,
+) -> ResolveResult {
+    match args.len() {
+        1 => match this.find(&*arg) {
+            None => Ok((-1).into()),
+            Some(idx) => Ok(Value::Int(idx as i64)),
+        },
+        2 => {
+            let base = match args[1] {
+                Value::Int(i) => i as usize,
+                Value::UInt(u) => u as usize,
+                _ => {
+                    return Err(ExecutionError::FunctionError {
+                        function: "String.indexOf".to_owned(),
+                        message: format!(
+                            "Expects 2nd argument to be an Integer, got `{:?}`",
+                            args[1]
+                        ),
+                    })
+                }
+            };
+            if base >= this.len() {
+                return Ok((-1).into());
+            }
+            // `get`, not indexing: an offset inside a multi-byte character must not panic.
+            match this.get(base..) {
+                None => Err(ExecutionError::FunctionError {
+                    function: "String.indexOf".to_owned(),
+                    message: format!("Offset {base} is not a character boundary of `{this}`"),
+                }),
+                // Always an `Int`, like the `-1` of the not-found case.
+                Some(tail) => match tail.find(&*arg) {
+                    None => Ok((-1).into()),
+                    Some(idx) => Ok(Value::Int((base + idx) as i64)),
+                },
+            }
+        }
+        _ => Err(ExecutionError::FunctionError {
+            function: "String.indexOf".to_owned(),
+            message: format!("Expects 2 arguments at most, got `{args:?}`!"),
+        }),
+    }
+}
+
+/// `<string>.lastIndexOf(<string>) -> <int>`
+/// `<string>.lastIndexOf(<string>, <int>) -> <int>`
+///
+/// Returns the byte offset of the last occurrence of `arg` in `this`, or `-1` when there is
+/// none. With a 2nd argument, only occurrences starting at or before that byte offset count.
+pub fn last_index_of(
+    This(this): This<Arc<String>>,
+    arg: Arc<String>,
+    Arguments(args): Arguments,
+) -> ResolveResult {
+    match args.len() {
+        1 => match this.rfind(&*arg) {
+            None => Ok((-1).into()),
+            Some(idx) => Ok(Value::Int(idx as i64)),
+        },
+        2 => {
+            let base = match args[1] {
+                Value::Int(i) => i as usize,
+                Value::UInt(u) => u as usize,
+                _ => {
+                    return Err(ExecutionError::FunctionError {
+                        function: "String.lastIndexOf".to_owned(),
+                        message: format!(
+                            "Expects 2nd argument to be an Integer, got `{:?}`",
+                            args[1]
+                        ),
+                    })
+                }
+            };
+            if base >= this.len() {
+                return Ok((-1).into());
+            }
+            // A match starting at or before `base` ends at or before `base + arg.len()`, so
+            // that prefix holds exactly the candidates. Matches end on character boundaries,
+            // hence rounding the prefix end down loses none of them and keeps slicing safe.
+            let mut end = base.saturating_add(arg.len()).min(this.len());
+            while !this.is_char_boundary(end) {
+                end -= 1;
+            }
+            match this[..end].rfind(&*arg) {
+                None => Ok((-1).into()),
+                Some(idx) => Ok(Value::Int(idx as i64)),
+            }
+        }
+        _ => Err(ExecutionError::FunctionError {
+            function: "String.lastIndexOf".to_owned(),
+            message: format!("Expects 2 arguments at most, got `{args:?}`!"),
+        }),
+    }
+}
+
+/// `<list<string>>.join() -> <string>`
+/// `<list<string>>.join(<string>) -> <string>`
+///
+/// Concatenates the strings of `this`, putting the optional separator between them. Fails
+/// with a `FunctionError` when the separator or any element of the list is not a string.
+pub fn join(This(this): This<Arc<Vec<Value>>>, Arguments(args): Arguments) -> ResolveResult {
+    let separator = args
+        .first()
+        .map(|v| match v {
+            Value::String(s) => Ok(s.as_str()),
+            _ => Err(ExecutionError::FunctionError {
+                function: "List.join".to_owned(),
+                message: format!("Expects separator to be a String, got `{v:?}`!"),
+            }),
+        })
+        .unwrap_or(Ok(""))?;
+    Ok(this
+        .iter()
+        .map(|v| match v {
+            Value::String(s) => Ok(s.as_str().to_string()),
+            _ => Err(ExecutionError::FunctionError {
+                function: "List.join".to_owned(),
+                message: "Expects a list of String values!".to_owned(),
+            }),
+        })
+        .collect::<Result<Vec<_>, _>>()?
+        .join(separator)
+        .into())
+}
+
+/// `<string>.lowerAscii() -> <string>`
+///
+/// Lowercases the ASCII letters of `this`; every other character is kept as is.
+pub fn lower_ascii(This(this): This<Arc<String>>) -> ResolveResult {
+    Ok(this.to_ascii_lowercase().into())
+}
+
+/// `<string>.upperAscii() -> <string>`
+///
+/// Uppercases the ASCII letters of `this`; every other character is kept as is.
+pub fn upper_ascii(This(this): This<Arc<String>>) -> ResolveResult {
+    Ok(this.to_ascii_uppercase().into())
+}
+
+/// `<string>.trim() -> <string>`
+///
+/// Removes the leading and trailing whitespace of `this`.
+pub fn trim(This(this): This<Arc<String>>) -> ResolveResult {
+    Ok(this.trim().into())
+}
+
+/// `<string>.replace(<string>, <string>) -> <string>`
+/// `<string>.replace(<string>, <string>, <int>) -> <string>`
+///
+/// Replaces the occurrences of the 1st argument in `this` with the 2nd one. The optional
+/// 3rd argument caps the number of replacements (`-1` means no cap).
+pub fn replace(This(this): This<Arc<String>>, Arguments(args): Arguments) -> ResolveResult {
+    match args.len() {
+        count @ 2..=3 => {
+            let from = match &args[0] {
+                Value::String(s) => s.as_str(),
+                _ => Err(ExecutionError::FunctionError {
+                    function: "String.replace".to_owned(),
+                    message: format!(
+                        "First argument of type String expected, got `{:?}`",
+                        args[0]
+                    ),
+                })?,
+            };
+            let to = match &args[1] {
+                Value::String(s) => s.as_str(),
+                _ => Err(ExecutionError::FunctionError {
+                    function: "String.replace".to_owned(),
+                    message: format!(
+                        "Second argument of type String expected, got `{:?}`",
+                        args[1]
+                    ),
+                })?,
+            };
+            if count == 3 {
+                let n = match &args[2] {
+                    // The cast wraps `-1` to `usize::MAX`, i.e. "replace them all".
+                    Value::Int(i) => *i as usize,
+                    Value::UInt(u) => *u as usize,
+                    _ => Err(ExecutionError::FunctionError {
+                        function: "String.replace".to_owned(),
+                        message: format!(
+                            "Third argument of type Integer expected, got `{:?}`",
+                            args[2]
+                        ),
+                    })?,
+                };
+                Ok(this.replacen(from, to, n).into())
+            } else {
+                Ok(this.replace(from, to).into())
+            }
+        }
+        _ => Err(ExecutionError::FunctionError {
+            function: "String.replace".to_owned(),
+            message: format!("Expects 2 or 3 arguments, got {args:?}"),
+        }),
+    }
+}
+
+/// `<string>.split(<string>) -> <list<string>>`
+/// `<string>.split(<string>, <int>) -> <list<string>>`
+///
+/// Splits `this` around the occurrences of the separator. The optional 2nd argument caps the
+/// number of substrings returned (`0` yields an empty list, `-1` means no cap).
+pub fn split(This(this): This<Arc<String>>, Arguments(args): Arguments) -> ResolveResult {
+    match args.len() {
+        count @ 1..=2 => {
+            let sep = match &args[0] {
+                Value::String(sep) => sep.as_str(),
+                _ => {
+                    return Err(ExecutionError::FunctionError {
+                        function: "String.split".to_string(),
+                        message: format!(
+                            "Expects a first argument of type String, got `{:?}`",
+                            args[0]
+                        ),
+                    })
+                }
+            };
+            let list = if count == 2 {
+                let pos = match &args[1] {
+                    // The cast wraps `-1` to `usize::MAX`, i.e. "no cap".
+                    Value::UInt(u) => *u as usize,
+                    Value::Int(i) => *i as usize,
+                    _ => Err(ExecutionError::FunctionError {
+                        function: "String.split".to_string(),
+                        message: format!(
+                            "Expects a second argument of type Integer, got `{:?}`",
+                            args[1]
+                        ),
+                    })?,
+                };
+                this.splitn(pos, sep)
+                    .map(|s| Value::String(s.to_owned().into()))
+                    .collect::<Vec<Value>>()
+            } else {
+                this.split(sep)
+                    .map(|s| Value::String(s.to_owned().into()))
+                    .collect::<Vec<Value>>()
+            };
+            Ok(list.into())
+        }
+        _ => Err(ExecutionError::FunctionError {
+            function: "String.split".to_owned(),
+            message: format!("Expects at most 2 arguments, got {args:?}"),
+        }),
+    }
+}
+
+/// `<string>.substring(<int>) -> <string>`
+/// `<string>.substring(<int>, <int>) -> <string>`
+///
+/// Returns the part of `this` from the byte offset given as 1st argument up to, but excluding,
+/// the one given as 2nd argument (the end of `this` when omitted). Fails with a
+/// `FunctionError`, rather than panicking, when the range does not fit `this`.
+pub fn substring(This(this): This<Arc<String>>, Arguments(args): Arguments) -> ResolveResult {
+    match args.len() {
+        count @ 1..=2 => {
+            let start = match &args[0] {
+                Value::Int(i) => *i as usize,
+                Value::UInt(u) => *u as usize,
+                _ => Err(ExecutionError::FunctionError {
+                    function: "String.substring".to_string(),
+                    message: format!(
+                        "Expects a first argument of type Integer, got `{:?}`",
+                        args[0]
+                    ),
+                })?,
+            };
+            // `str::get` yields `None` where slicing would panic: out of bounds or inside a
+            // multi-byte character.
+            let sub = if count == 2 {
+                let end = match &args[1] {
+                    Value::Int(i) => *i as usize,
+                    Value::UInt(u) => *u as usize,
+                    _ => Err(ExecutionError::FunctionError {
+                        function: "String.substring".to_string(),
+                        message: format!(
+                            "Expects a second argument of type Integer, got `{:?}`",
+                            args[1]
+                        ),
+                    })?,
+                };
+                if end < start {
+                    return Err(ExecutionError::FunctionError {
+                        function: "String.substring".to_string(),
+                        message: format!("Can't have end be before the start: `{end} < {start}`"),
+                    });
+                }
+                this.get(start..end)
+            } else {
+                this.get(start..)
+            };
+            match sub {
+                Some(sub) => Ok(sub.to_owned().into()),
+                None => Err(ExecutionError::FunctionError {
+                    function: "String.substring".to_string(),
+                    message: format!("No valid range starting at {start} on `{this}`"),
+                }),
+            }
+        }
+        _ => Err(ExecutionError::FunctionError {
+            function: "String.substring".to_owned(),
+            message: format!("Expects at most 2 arguments, got {args:?}"),
+        }),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::data::Expression;
+    use cel_interpreter::Value;
+
+    #[test]
+    fn extended_string_fn() {
+        let e = Expression::new("'abc'.charAt(1)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "b".into());
+
+        let e = Expression::new("'hello mellow'.indexOf('')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), 0.into());
+        let e = Expression::new("'hello mellow'.indexOf('ello')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), 1.into());
+        let e = Expression::new("'hello mellow'.indexOf('jello')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), (-1).into());
+        let e = Expression::new("'hello mellow'.indexOf('', 2)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), 2.into());
+        let e =
+            Expression::new("'hello mellow'.indexOf('ello', 20)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), (-1).into());
+
+        let e = Expression::new("'hello mellow'.lastIndexOf('')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), 12.into());
+        let e =
+            Expression::new("'hello mellow'.lastIndexOf('ello')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), 7.into());
+        let e =
+            Expression::new("'hello mellow'.lastIndexOf('jello')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), (-1).into());
+        let e = Expression::new("'hello mellow'.lastIndexOf('ello', 6)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), 1.into());
+        let e = Expression::new("'hello mellow'.lastIndexOf('ello', 8)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), 7.into());
+        let e = Expression::new("'hello mellow'.lastIndexOf('ello', 20)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), (-1).into());
+
+        let e = Expression::new("['hello', 'mellow'].join()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "hellomellow".into());
+        let e = Expression::new("[].join()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "".into());
+        let e = Expression::new("['hello', 'mellow'].join(' ')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "hello mellow".into());
+
+        let e = Expression::new("'TacoCat'.lowerAscii()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "tacocat".into());
+        let e = Expression::new("'TacoCÆt Xii'.lowerAscii()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "tacocÆt xii".into());
+
+        let e = Expression::new("'TacoCat'.upperAscii()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "TACOCAT".into());
+        let e = Expression::new("'TacoCÆt Xii'.upperAscii()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "TACOCÆT XII".into());
+
+        let e = Expression::new("' \ttrim\n '.trim()").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "trim".into());
+
+        let e =
+            Expression::new("'hello hello'.replace('he', 'we')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "wello wello".into());
+        let e = Expression::new("'hello hello'.replace('he', 'we', -1)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), "wello wello".into());
+        let e = Expression::new("'hello hello'.replace('he', 'we', 1)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), "wello hello".into());
+        let e = Expression::new("'hello hello'.replace('he', 'we', 0)")
+            .expect("This must be valid CEL");
+        assert_eq!(e.eval(), "hello hello".into());
+        let e = Expression::new("'hello hello'.replace('', '_')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "_h_e_l_l_o_ _h_e_l_l_o_".into());
+        let e = Expression::new("'hello hello'.replace('h', '')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "ello ello".into());
+
+        let e = Expression::new("'hello hello hello'.split(' ')").expect("This must be valid CEL");
+        assert_eq!(e.eval(), vec!["hello", "hello", "hello"].into());
+        let e =
+            Expression::new("'hello hello hello'.split(' ', 0)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), Value::List(vec![].into()));
+        let e =
+            Expression::new("'hello hello hello'.split(' ', 1)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), vec!["hello hello hello"].into());
+        let e =
+            Expression::new("'hello hello hello'.split(' ', 2)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), vec!["hello", "hello hello"].into());
+        let e =
+            Expression::new("'hello hello hello'.split(' ', -1)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), vec!["hello", "hello", "hello"].into());
+
+        let e = Expression::new("'tacocat'.substring(4)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "cat".into());
+        let e = Expression::new("'tacocat'.substring(0, 4)").expect("This must be valid CEL");
+        assert_eq!(e.eval(), "taco".into());
+    }
+}