Skip to content

Commit

Permalink
feat: Adds query sdk proc-macro.
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Chernicoff committed Sep 27, 2024
1 parent a3f9994 commit 3d20690
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 43 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ members = [
"plugins/dummy_sha256",
"plugins/dummy_sha256_sdk",
"sdk/rust",
"hipcheck-sdk-macros",
]

# Make sure Hipcheck is run with `cargo run`.
Expand Down
17 changes: 17 additions & 0 deletions hipcheck-sdk-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "hipcheck-sdk-macros"
description = "Helper macros for the `hipcheck-sdk` crate."
repository = "https://github.com/mitre/hipcheck"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"

[lib]
proc-macro = true

[dependencies]
anyhow = "1.0.89"
convert_case = "0.6.0"
proc-macro2 = "1.0.86"
quote = "1.0.37"
syn = { version = "2.0.77", features = ["full", "printing"] }
240 changes: 240 additions & 0 deletions hipcheck-sdk-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// SPDX-License-Identifier: Apache-2.0

use convert_case::Casing;
use proc_macro::TokenStream;
use proc_macro2::Span;
use std::ops::Not;
use std::sync::{LazyLock, Mutex};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Error, Ident, ItemFn, Meta, PatType};

static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![]));

#[allow(unused)]
#[derive(Debug, Clone)]
struct NamedQuerySpec {
pub struct_name: String,
pub function: String,
pub default: bool,
}

struct QuerySpec {
pub function: Ident,
pub input_type: syn::Type,
pub output_type: syn::Type,
pub default: bool,
}

/// Parse Path to confirm that it represents a Result<T: Serialize> and return the type T
fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> {
use syn::GenericArgument;
use syn::PathArguments;
// Assert it is a Result
// Panic: Safe to unwrap because there should be at least one element in the sequence
let last = p.segments.last().unwrap();
if last.ident != "Result" {
return Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
));
}
match &last.arguments {
PathArguments::AngleBracketed(x) => {
let Some(GenericArgument::Type(ty)) = x.args.first() else {
return Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
));
};
Ok(ty.clone())
}
_ => Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
)),
}
}

/// Parse PatType to confirm that it contains a &mut PluginEngine
fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> {
if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() {
if type_reference.mutability.is_some() {
if let syn::Type::Path(type_path) = type_reference.elem.as_ref() {
let last = type_path.path.segments.last().unwrap();
if last.ident == "PluginEngine" {
return Ok(());
}
}
}
}

Err(Error::new(
engine_arg.span(),
"The first argument of the query function must be a &mut PluginEngine",
))
}

fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> {
use syn::Meta::*;
use syn::ReturnType;
let sig = &item_fn.sig;

let function = sig.ident.clone();

let input_type: syn::Type = {
let inputs = &sig.inputs;
if inputs.len() != 2 {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
}
// Validate that the first arg is type &mut PluginEngine
if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) {
parse_plugin_engine(engine_arg)?;
}

if let Some(input_arg) = inputs.get(1) {
let syn::FnArg::Typed(input_arg_info) = input_arg else {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
};
input_arg_info.ty.as_ref().clone()
} else {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
}
};

let output_type = match &sig.output {
ReturnType::Default => {
return Err(Error::new(
item_fn.span(),
"Query function must return Result<T: Serialize>",
));
}
ReturnType::Type(_, b_type) => {
use syn::Type;
match b_type.as_ref() {
Type::Path(p) => parse_result_generic(&p.path)?,
_ => {
return Err(Error::new(
item_fn.span(),
"Query function must return Result<T: Serialize>",
))
}
}
}
};

let default = match opt_meta {
Some(NameValue(nv)) => {
// Panic: Safe to unwrap because there should be at least one element in the sequence
if nv.path.segments.first().unwrap().ident == "default" {
match nv.value {
syn::Expr::Lit(e) => match e.lit {
syn::Lit::Bool(s) => s.value,
_ => {
return Err(Error::new(
item_fn.span(),
"Default field on query function options must have a Boolean value",
));
}
},
_ => {
return Err(Error::new(
item_fn.span(),
"Default field on query function options must have a Boolean value",
));
}
}
} else {
return Err(Error::new(
item_fn.span(),
"Default field must be set if options are included for the query function",
));
}
}
Some(Path(p)) => {
let seg: &syn::PathSegment = p.segments.first().unwrap();
if seg.ident == "default" {
match seg.arguments {
syn::PathArguments::None => true,
_ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")),
}
} else {
return Err(Error::new(
item_fn.span(),
"Default field must be set if options are included for the query function",
));
}
}
None => false,
_ => {
return Err(Error::new(
item_fn.span(),
"Cannot parse query function options",
));
}
};

Ok(QuerySpec {
function,
default,
input_type,
output_type,
})
}

#[proc_macro_attribute]
pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut to_return = proc_macro2::TokenStream::from(item.clone());
let item_fn = parse_macro_input!(item as ItemFn);
let opt_meta: Option<Meta> = if attr.is_empty().not() {
Some(parse_macro_input!(attr as Meta))
} else {
None
};
let spec = match parse_named_query_spec(opt_meta, item_fn) {
Ok(span) => span,
Err(err) => return err.to_compile_error().into(),
};

let struct_name = Ident::new(
spec.function
.to_string()
.to_case(convert_case::Case::Pascal)
.as_str(),
Span::call_site(),
);
let ident = &spec.function;
let input_type = spec.input_type;
let output_type = spec.output_type;

let to_follow = quote::quote! {
struct #struct_name {}

#[hipcheck_sdk::prelude::async_trait]
impl hipcheck_sdk::prelude::Query for #struct_name {
fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
hipcheck_sdk::prelude::schema_for!(#input_type).schema
}

fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
hipcheck_sdk::prelude::schema_for!(#output_type).schema
}

async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> {
let input = hipcheck_sdk::prelude::from_value(input).map_err(|_|
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?;
let output = #ident(engine, input).await?;
hipcheck_sdk::prelude::to_value(output).map_err(|_|
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat)
}
}
};

QUERIES.lock().unwrap().push(NamedQuerySpec {
struct_name: struct_name.to_string(),
function: spec.function.to_string(),
default: spec.default,
});

to_return.extend(to_follow);
proc_macro::TokenStream::from(to_return)
}
2 changes: 1 addition & 1 deletion plugins/dummy_sha256_sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ publish = false

[dependencies]
clap = { version = "4.5.18", features = ["derive"] }
hipcheck-sdk = { path = "../../sdk/rust" }
hipcheck-sdk = { path = "../../sdk/rust", features = ["macros"]}
sha2 = "0.10.8"
tokio = { version = "1.40.0", features = ["rt"] }
45 changes: 5 additions & 40 deletions plugins/dummy_sha256_sdk/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,17 @@ use clap::Parser;
use hipcheck_sdk::prelude::*;
use sha2::{Digest, Sha256};

static SHA256_KEY_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json");
static SHA256_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json");

/// calculate sha256 of provided content
fn sha256(content: &[u8]) -> Vec<u8> {
#[query(default = false)]
async fn query_sha256(_engine: &mut PluginEngine, content: Vec<u8>) -> Result<Vec<u8>> {
let mut hasher = Sha256::new();
hasher.update(content);
hasher.finalize().to_vec()
hasher.update(content.as_slice());
Ok(hasher.finalize().to_vec())
}

/// This plugin takes in a Value::Array(Vec<Value::Number>) and calculates its sha256
#[derive(Clone, Debug)]
struct Sha256Plugin;

#[async_trait]
impl Query for Sha256Plugin {
fn input_schema(&self) -> JsonSchema {
from_str(SHA256_KEY_SCHEMA).unwrap()
}

fn output_schema(&self) -> JsonSchema {
from_str(SHA256_OUTPUT_SCHEMA).unwrap()
}

async fn run(
&self,
_engine: &mut PluginEngine,
input: Value,
) -> hipcheck_sdk::error::Result<Value> {
let Value::Array(data) = &input else {
return Err(Error::UnexpectedPluginQueryInputFormat);
};

let data = data
.iter()
.map(|elem| elem.as_u64().map(|num| num as u8))
.collect::<Option<Vec<_>>>()
.ok_or_else(|| Error::UnexpectedPluginQueryInputFormat)?;

let hash = sha256(&data);
// convert to Value
let hash = hash.iter().map(|x| Value::Number((*x).into())).collect();
Ok(hash)
}
}

impl Plugin for Sha256Plugin {
const PUBLISHER: &'static str = "dummy";

Expand All @@ -70,7 +35,7 @@ impl Plugin for Sha256Plugin {
fn queries(&self) -> impl Iterator<Item = NamedQuery> {
vec![NamedQuery {
name: "sha256",
inner: Box::new(Sha256Plugin),
inner: Box::new(QuerySha256 {}),
}]
.into_iter()
}
Expand Down
4 changes: 4 additions & 0 deletions sdk/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ tokio = { version = "1.39.2", features = ["rt"] }
tokio-stream = "0.1.15"
tonic = "0.12.1"
schemars = "0.8.21"
hipcheck-sdk-macros = { path = "../../hipcheck-sdk-macros", version = "0.1.0", optional = true }

[build-dependencies]
anyhow = "1.0.86"
tonic-build = "0.12.1"

[features]
macros = ["hipcheck-sdk-macros"]
4 changes: 4 additions & 0 deletions sdk/rust/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ pub enum Error {
#[error("unexpected JSON value from plugin")]
UnexpectedPluginQueryInputFormat,

/// The `Query::run` function implementation produced an output that cannot be serialized to JSON
#[error("plugin output could not be serialized to JSON")]
UnexpectedPluginQueryOutputFormat,

/// The `PluginEngine` received a request for an unknown query endpoint
#[error("could not determine which plugin query to run")]
UnknownPluginQuery,
Expand Down
Loading

0 comments on commit 3d20690

Please sign in to comment.