Skip to content

Commit

Permalink
feat: Adds datetime and span primitives to the policy expression syntax.
Browse files Browse the repository at this point in the history
  • Loading branch information
mchernicoff committed Sep 18, 2024
1 parent b49ef7b commit 36356e9
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 2 deletions.
26 changes: 26 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions hipcheck/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ indexmap = "2.5.0"
indextree = "4.6.1"
indicatif = { version = "0.17.8", features = ["rayon"] }
itertools = "0.13.0"
jiff = "0.1.13"
kdl = "4.6.0"
log = "0.4.22"
logos = "0.14.0"
Expand Down
135 changes: 135 additions & 0 deletions hipcheck/src/policy_exprs/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::policy_exprs::{eval, Error, Expr, Ident, Primitive, Result, F64};
use itertools::Itertools as _;
use jiff::{Span, Zoned};
use std::{cmp::Ordering, collections::HashMap, ops::Not as _};
use Expr::*;
use Primitive::*;
Expand Down Expand Up @@ -228,6 +229,12 @@ enum ArrayType {
/// An array of bools.
Bool(Vec<bool>),

/// An array of datetimes.
DateTime(Vec<Zoned>),

/// An array of time spans.
Span(Vec<Span>),

/// An empty array (no type hints).
Empty,
}
Expand Down Expand Up @@ -272,6 +279,29 @@ fn array_type(arr: &[Primitive]) -> Result<ArrayType> {
}
Ok(ArrayType::Bool(result))
}
DateTime(_) => {
let mut result: Vec<Zoned> = Vec::with_capacity(arr.len());
for elem in arr {
if let DateTime(val) = elem {
result.push(val.clone());
} else {
return Err(Error::InconsistentArrayTypes);
}
}
Ok(ArrayType::DateTime(result))
}
Span(_) => {
let mut result: Vec<Span> = Vec::with_capacity(arr.len());
for elem in arr {
if let Span(val) = elem {
result.push(*val);
} else {
return Err(Error::InconsistentArrayTypes);
}
}
Ok(ArrayType::Span(result))
}

Identifier(_) => unimplemented!("we don't currently support idents in arrays"),
}
}
Expand Down Expand Up @@ -455,6 +485,8 @@ fn not(env: &Env, args: &[Expr]) -> Result<Expr> {
Int(_) => Err(Error::BadType(name)),
Float(_) => Err(Error::BadType(name)),
Bool(arg) => Ok(Primitive::Bool(arg.not())),
DateTime(_) => Err(Error::BadType(name)),
Span(_) => Err(Error::BadType(name)),
Identifier(_) => unreachable!("no idents should be here"),
};

Expand All @@ -480,6 +512,8 @@ fn max(env: &Env, args: &[Expr]) -> Result<Expr> {
.map(|m| Primitive(Float(m))),

ArrayType::Bool(_) => Err(Error::BadType(name)),
ArrayType::DateTime(_) => Err(Error::BadType(name)),
ArrayType::Span(_) => Err(Error::BadType(name)),
ArrayType::Empty => Err(Error::NoMax),
};

Expand All @@ -505,6 +539,8 @@ fn min(env: &Env, args: &[Expr]) -> Result<Expr> {
.map(|m| Primitive(Float(m))),

ArrayType::Bool(_) => Err(Error::BadType(name)),
ArrayType::DateTime(_) => Err(Error::BadType(name)),
ArrayType::Span(_) => Err(Error::BadType(name)),
ArrayType::Empty => Err(Error::NoMin),
};

Expand All @@ -528,6 +564,8 @@ fn avg(env: &Env, args: &[Expr]) -> Result<Expr> {
}

ArrayType::Bool(_) => Err(Error::BadType(name)),
ArrayType::DateTime(_) => Err(Error::BadType(name)),
ArrayType::Span(_) => Err(Error::BadType(name)),
ArrayType::Empty => Err(Error::NoAvg),
};

Expand All @@ -549,6 +587,8 @@ fn median(env: &Env, args: &[Expr]) -> Result<Expr> {
Ok(Primitive(Float(floats[mid])))
}
ArrayType::Bool(_) => Err(Error::BadType(name)),
ArrayType::DateTime(_) => Err(Error::BadType(name)),
ArrayType::Span(_) => Err(Error::BadType(name)),
ArrayType::Empty => Err(Error::NoMedian),
};

Expand All @@ -562,6 +602,8 @@ fn count(env: &Env, args: &[Expr]) -> Result<Expr> {
ArrayType::Int(ints) => Ok(Primitive(Int(ints.len() as i64))),
ArrayType::Float(floats) => Ok(Primitive(Int(floats.len() as i64))),
ArrayType::Bool(bools) => Ok(Primitive(Int(bools.len() as i64))),
ArrayType::DateTime(dts) => Ok(Primitive(Int(dts.len() as i64))),
ArrayType::Span(spans) => Ok(Primitive(Int(spans.len() as i64))),
ArrayType::Empty => Ok(Primitive(Int(0))),
};

Expand Down Expand Up @@ -591,6 +633,18 @@ fn all(env: &Env, args: &[Expr]) -> Result<Expr> {
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()))
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone()))
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::Empty => true,
};

Expand Down Expand Up @@ -623,6 +677,18 @@ fn nall(env: &Env, args: &[Expr]) -> Result<Expr> {
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()))
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone()))
.process_results(|mut iter| {
iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::Empty => false,
};

Expand Down Expand Up @@ -655,6 +721,18 @@ fn some(env: &Env, args: &[Expr]) -> Result<Expr> {
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()))
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone()))
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true))))
})?,
ArrayType::Empty => false,
};

Expand Down Expand Up @@ -687,6 +765,18 @@ fn none(env: &Env, args: &[Expr]) -> Result<Expr> {
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()))
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone()))
.process_results(|mut iter| {
iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not()
})?,
ArrayType::Empty => true,
};

Expand Down Expand Up @@ -734,6 +824,33 @@ fn filter(env: &Env, args: &[Expr]) -> Result<Expr> {
}
})
.collect::<Result<Vec<_>>>()?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| {
Ok((
val,
eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()),
))
})
.filter_map_ok(|(val, expr)| {
if let Ok(Primitive(Bool(true))) = expr {
Some(Primitive::DateTime(val.clone()))
} else {
None
}
})
.collect::<Result<Vec<_>>>()?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| Ok((val, eval_lambda(env, &ident, Span(*val), (*body).clone()))))
.filter_map_ok(|(val, expr)| {
if let Ok(Primitive(Bool(true))) = expr {
Some(Primitive::Span(*val))
} else {
None
}
})
.collect::<Result<Vec<_>>>()?,
ArrayType::Empty => Vec::new(),
};

Expand Down Expand Up @@ -775,6 +892,24 @@ fn foreach(env: &Env, args: &[Expr]) -> Result<Expr> {
Err(err) => Err(err),
})
.collect::<Result<Vec<_>>>()?,
ArrayType::DateTime(dts) => dts
.iter()
.map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()))
.map(|expr| match expr {
Ok(Primitive(inner)) => Ok(inner),
Ok(_) => Err(Error::BadType(name)),
Err(err) => Err(err),
})
.collect::<Result<Vec<_>>>()?,
ArrayType::Span(spans) => spans
.iter()
.map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone()))
.map(|expr| match expr {
Ok(Primitive(inner)) => Ok(inner),
Ok(_) => Err(Error::BadType(name)),
Err(err) => Err(err),
})
.collect::<Result<Vec<_>>>()?,
ArrayType::Empty => Vec::new(),
};

Expand Down
22 changes: 22 additions & 0 deletions hipcheck/src/policy_exprs/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// SPDX-License-Identifier: Apache-2.0

use crate::policy_exprs::{Expr, Ident, LexingError};
use jiff::Error as JError;
use nom::{error::ErrorKind, Needed};
use ordered_float::FloatIsNan;
use std::fmt;

/// `Result` which uses [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -126,6 +128,26 @@ pub enum UnrepresentableJSONType {
JSONNull,
}

// Custom error to handle jiff's native error not impl PartialEq
// We exploit the fact that it *does* impl Display
#[derive(Clone, Debug, thiserror::Error, PartialEq)]
pub struct JiffError {
jiff_error: String,
}

impl JiffError {
pub fn new(err: JError) -> Self {
let msg = err.to_string();
JiffError { jiff_error: msg }
}
}

impl fmt::Display for JiffError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.jiff_error)
}
}

fn needed_str(needed: &Needed) -> String {
match needed {
Needed::Unknown => String::from(""),
Expand Down
36 changes: 35 additions & 1 deletion hipcheck/src/policy_exprs/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::policy_exprs::{
Error, Result, Tokens,
};
use itertools::Itertools;
use jiff::{Span, Zoned};
use nom::{
branch::alt,
combinator::{all_consuming, map},
Expand Down Expand Up @@ -49,6 +50,21 @@ pub enum Primitive {

/// Boolean.
Bool(bool),

/// Date-time value with timezone information using the [jiff] crate, which uses a modified version of ISO8601.
/// This must include a date in the format <YYYY>-<MM>-<DD>.
/// An optional time in the format T<HH>:[MM]:[SS] will be accepted after the date.
/// Decimal fractions of hours and minutes are not allowed; use smaller time units instead (e.g. T10:30 instead of T10.5). Decimal fractions of seconds are allowed.
/// The timezone is always set to UTC, but you can set an offeset from UTC by including +{HH}:[MM] or -{HH}:[MM]. The time will be adjusted to the correct UTC time during parsing.
DateTime(Zoned),

/// Span of time using the [jiff] crate, which uses a modified version of ISO8601.
/// Can include years, months, weeks, days, hours, minutes, and seconds (including decimal fractions of a second).
/// Spans are preceded by the letter "P" with any optional time units separated from optional date units by the letter "T".
/// All units of dates and times are represented by single case-agnostic letter abbreviations after the number.
/// For example, a span of one year, one month, one week, one day, one hour, one minute, and one-and-a-tenth seconds would be represented as
/// "P1y1m1w1dT1h1m1.1s"
Span(Span),
}

/// A variable or function identifier.
Expand Down Expand Up @@ -89,6 +105,8 @@ impl Display for Primitive {
Primitive::Int(i) => write!(f, "{}", i),
Primitive::Float(fl) => write!(f, "{}", fl),
Primitive::Bool(b) => write!(f, "{}", if *b { "#t" } else { "#f" }),
Primitive::DateTime(dt) => write!(f, "{}", dt),
Primitive::Span(span) => write!(f, "{}", span),
}
}
}
Expand Down Expand Up @@ -141,6 +159,16 @@ crate::data_variant_parser! {
pattern = Token::Bool(b) => Primitive::Bool(b);
}

crate::data_variant_parser! {
fn parse_datetime(input) -> Result<Primitive>;
pattern = Token::DateTime(dt) => Primitive::DateTime(*dt);
}

crate::data_variant_parser! {
fn parse_span(input) -> Result<Primitive>;
pattern = Token::Span(span) => Primitive::Span(*span);
}

crate::data_variant_parser! {
fn parse_ident(input) -> Result<String>;
pattern = Token::Ident(s) => s.to_owned();
Expand All @@ -156,7 +184,13 @@ pub type Input<'source> = Tokens<'source, Token>;

/// Parse a single piece of primitive data.
fn parse_primitive(input: Input<'_>) -> IResult<Input<'_>, Primitive> {
alt((parse_integer, parse_float, parse_bool))(input)
alt((
parse_integer,
parse_float,
parse_bool,
parse_datetime,
parse_span,
))(input)
}

/// Parse an array.
Expand Down
Loading

0 comments on commit 36356e9

Please sign in to comment.