-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
292 additions
and
43 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
[package] | ||
name = "hipcheck-sdk-macros" | ||
description = "Helper macros for the `hipcheck-sdk` crate." | ||
repository = "https://github.com/mitre/hipcheck" | ||
version = "0.1.0" | ||
edition = "2021" | ||
license = "Apache-2.0" | ||
|
||
[lib] | ||
proc-macro = true | ||
|
||
[dependencies] | ||
anyhow = "1.0.89" | ||
convert_case = "0.6.0" | ||
proc-macro2 = "1.0.86" | ||
quote = "1.0.37" | ||
syn = { version = "2.0.77", features = ["full", "printing"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
use convert_case::Casing; | ||
use proc_macro::TokenStream; | ||
use proc_macro2::Span; | ||
use std::ops::Not; | ||
use std::sync::{LazyLock, Mutex}; | ||
use syn::spanned::Spanned; | ||
use syn::{parse_macro_input, Error, Ident, ItemFn, Meta, PatType}; | ||
|
||
static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![])); | ||
|
||
#[allow(unused)] | ||
#[derive(Debug, Clone)] | ||
struct NamedQuerySpec { | ||
pub struct_name: String, | ||
pub function: String, | ||
pub default: bool, | ||
} | ||
|
||
struct QuerySpec { | ||
pub function: Ident, | ||
pub input_type: syn::Type, | ||
pub output_type: syn::Type, | ||
pub default: bool, | ||
} | ||
|
||
/// Parse Path to confirm that it represents a Result<T: Serialize> and return the type T | ||
fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> { | ||
use syn::GenericArgument; | ||
use syn::PathArguments; | ||
// Assert it is a Result | ||
// Panic: Safe to unwrap because there should be at least one element in the sequence | ||
let last = p.segments.last().unwrap(); | ||
if last.ident != "Result" { | ||
return Err(Error::new( | ||
p.span(), | ||
"Expected return type to be a Result<T: Serialize>", | ||
)); | ||
} | ||
match &last.arguments { | ||
PathArguments::AngleBracketed(x) => { | ||
let Some(GenericArgument::Type(ty)) = x.args.first() else { | ||
return Err(Error::new( | ||
p.span(), | ||
"Expected return type to be a Result<T: Serialize>", | ||
)); | ||
}; | ||
Ok(ty.clone()) | ||
} | ||
_ => Err(Error::new( | ||
p.span(), | ||
"Expected return type to be a Result<T: Serialize>", | ||
)), | ||
} | ||
} | ||
|
||
/// Parse PatType to confirm that it contains a &mut PluginEngine | ||
fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> { | ||
if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() { | ||
if type_reference.mutability.is_some() { | ||
if let syn::Type::Path(type_path) = type_reference.elem.as_ref() { | ||
let last = type_path.path.segments.last().unwrap(); | ||
if last.ident == "PluginEngine" { | ||
return Ok(()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
Err(Error::new( | ||
engine_arg.span(), | ||
"The first argument of the query function must be a &mut PluginEngine", | ||
)) | ||
} | ||
|
||
fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> { | ||
use syn::Meta::*; | ||
use syn::ReturnType; | ||
let sig = &item_fn.sig; | ||
|
||
let function = sig.ident.clone(); | ||
|
||
let input_type: syn::Type = { | ||
let inputs = &sig.inputs; | ||
if inputs.len() != 2 { | ||
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize")); | ||
} | ||
// Validate that the first arg is type &mut PluginEngine | ||
if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) { | ||
parse_plugin_engine(engine_arg)?; | ||
} | ||
|
||
if let Some(input_arg) = inputs.get(1) { | ||
let syn::FnArg::Typed(input_arg_info) = input_arg else { | ||
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize")); | ||
}; | ||
input_arg_info.ty.as_ref().clone() | ||
} else { | ||
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize")); | ||
} | ||
}; | ||
|
||
let output_type = match &sig.output { | ||
ReturnType::Default => { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Query function must return Result<T: Serialize>", | ||
)); | ||
} | ||
ReturnType::Type(_, b_type) => { | ||
use syn::Type; | ||
match b_type.as_ref() { | ||
Type::Path(p) => parse_result_generic(&p.path)?, | ||
_ => { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Query function must return Result<T: Serialize>", | ||
)) | ||
} | ||
} | ||
} | ||
}; | ||
|
||
let default = match opt_meta { | ||
Some(NameValue(nv)) => { | ||
// Panic: Safe to unwrap because there should be at least one element in the sequence | ||
if nv.path.segments.first().unwrap().ident == "default" { | ||
match nv.value { | ||
syn::Expr::Lit(e) => match e.lit { | ||
syn::Lit::Bool(s) => s.value, | ||
_ => { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Default field on query function options must have a Boolean value", | ||
)); | ||
} | ||
}, | ||
_ => { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Default field on query function options must have a Boolean value", | ||
)); | ||
} | ||
} | ||
} else { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Default field must be set if options are included for the query function", | ||
)); | ||
} | ||
} | ||
Some(Path(p)) => { | ||
let seg: &syn::PathSegment = p.segments.first().unwrap(); | ||
if seg.ident == "default" { | ||
match seg.arguments { | ||
syn::PathArguments::None => true, | ||
_ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")), | ||
} | ||
} else { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Default field must be set if options are included for the query function", | ||
)); | ||
} | ||
} | ||
None => false, | ||
_ => { | ||
return Err(Error::new( | ||
item_fn.span(), | ||
"Cannot parse query function options", | ||
)); | ||
} | ||
}; | ||
|
||
Ok(QuerySpec { | ||
function, | ||
default, | ||
input_type, | ||
output_type, | ||
}) | ||
} | ||
|
||
#[proc_macro_attribute] | ||
pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream { | ||
let mut to_return = proc_macro2::TokenStream::from(item.clone()); | ||
let item_fn = parse_macro_input!(item as ItemFn); | ||
let opt_meta: Option<Meta> = if attr.is_empty().not() { | ||
Some(parse_macro_input!(attr as Meta)) | ||
} else { | ||
None | ||
}; | ||
let spec = match parse_named_query_spec(opt_meta, item_fn) { | ||
Ok(span) => span, | ||
Err(err) => return err.to_compile_error().into(), | ||
}; | ||
|
||
let struct_name = Ident::new( | ||
spec.function | ||
.to_string() | ||
.to_case(convert_case::Case::Pascal) | ||
.as_str(), | ||
Span::call_site(), | ||
); | ||
let ident = &spec.function; | ||
let input_type = spec.input_type; | ||
let output_type = spec.output_type; | ||
|
||
let to_follow = quote::quote! { | ||
struct #struct_name {} | ||
|
||
#[hipcheck_sdk::prelude::async_trait] | ||
impl hipcheck_sdk::prelude::Query for #struct_name { | ||
fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema { | ||
hipcheck_sdk::prelude::schema_for!(#input_type).schema | ||
} | ||
|
||
fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema { | ||
hipcheck_sdk::prelude::schema_for!(#output_type).schema | ||
} | ||
|
||
async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> { | ||
let input = hipcheck_sdk::prelude::from_value(input).map_err(|_| | ||
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?; | ||
let output = #ident(engine, input).await?; | ||
hipcheck_sdk::prelude::to_value(output).map_err(|_| | ||
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat) | ||
} | ||
} | ||
}; | ||
|
||
QUERIES.lock().unwrap().push(NamedQuerySpec { | ||
struct_name: struct_name.to_string(), | ||
function: spec.function.to_string(), | ||
default: spec.default, | ||
}); | ||
|
||
to_return.extend(to_follow); | ||
proc_macro::TokenStream::from(to_return) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.