Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds query sdk proc-macro #452

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ members = [
"plugins/dummy_sha256",
"plugins/dummy_sha256_sdk",
"sdk/rust",
"hipcheck-sdk-macros",
]

# Make sure Hipcheck is run with `cargo run`.
Expand Down
17 changes: 17 additions & 0 deletions hipcheck-sdk-macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "hipcheck-sdk-macros"
description = "Helper macros for the `hipcheck-sdk` crate."
repository = "https://github.com/mitre/hipcheck"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"

[lib]
proc-macro = true

[dependencies]
anyhow = "1.0.89"
convert_case = "0.6.0"
proc-macro2 = "1.0.86"
quote = "1.0.37"
syn = { version = "2.0.77", features = ["full", "printing"] }
240 changes: 240 additions & 0 deletions hipcheck-sdk-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// SPDX-License-Identifier: Apache-2.0

use convert_case::Casing;
use proc_macro::TokenStream;
use proc_macro2::Span;
use std::ops::Not;
use std::sync::{LazyLock, Mutex};
use syn::spanned::Spanned;
use syn::{parse_macro_input, Error, Ident, ItemFn, Meta, PatType};

static QUERIES: LazyLock<Mutex<Vec<NamedQuerySpec>>> = LazyLock::new(|| Mutex::new(vec![]));

#[allow(unused)]
#[derive(Debug, Clone)]
struct NamedQuerySpec {
pub struct_name: String,
pub function: String,
pub default: bool,
}

struct QuerySpec {
pub function: Ident,
pub input_type: syn::Type,
pub output_type: syn::Type,
pub default: bool,
}

/// Parse Path to confirm that it represents a Result<T: Serialize> and return the type T
fn parse_result_generic(p: &syn::Path) -> Result<syn::Type, Error> {
use syn::GenericArgument;
use syn::PathArguments;
// Assert it is a Result
// Panic: Safe to unwrap because there should be at least one element in the sequence
let last = p.segments.last().unwrap();
if last.ident != "Result" {
return Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
));
}
match &last.arguments {
PathArguments::AngleBracketed(x) => {
let Some(GenericArgument::Type(ty)) = x.args.first() else {
return Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
));
};
Ok(ty.clone())
}
_ => Err(Error::new(
p.span(),
"Expected return type to be a Result<T: Serialize>",
)),
}
}

/// Parse PatType to confirm that it contains a &mut PluginEngine
fn parse_plugin_engine(engine_arg: &PatType) -> Result<(), Error> {
if let syn::Type::Reference(type_reference) = engine_arg.ty.as_ref() {
if type_reference.mutability.is_some() {
if let syn::Type::Path(type_path) = type_reference.elem.as_ref() {
let last = type_path.path.segments.last().unwrap();
if last.ident == "PluginEngine" {
return Ok(());
}
}
}
}

Err(Error::new(
engine_arg.span(),
"The first argument of the query function must be a &mut PluginEngine",
))
}

fn parse_named_query_spec(opt_meta: Option<Meta>, item_fn: ItemFn) -> Result<QuerySpec, Error> {
use syn::Meta::*;
use syn::ReturnType;
let sig = &item_fn.sig;

let function = sig.ident.clone();

let input_type: syn::Type = {
let inputs = &sig.inputs;
if inputs.len() != 2 {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
}
// Validate that the first arg is type &mut PluginEngine
if let Some(syn::FnArg::Typed(engine_arg)) = inputs.get(0) {
parse_plugin_engine(engine_arg)?;
}

if let Some(input_arg) = inputs.get(1) {
let syn::FnArg::Typed(input_arg_info) = input_arg else {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
};
input_arg_info.ty.as_ref().clone()
} else {
return Err(Error::new(item_fn.span(), "Query function must take two arguments: &mut PluginEngine, and an input type that implements Serialize"));
}
};

let output_type = match &sig.output {
ReturnType::Default => {
return Err(Error::new(
item_fn.span(),
"Query function must return Result<T: Serialize>",
));
}
ReturnType::Type(_, b_type) => {
use syn::Type;
match b_type.as_ref() {
Type::Path(p) => parse_result_generic(&p.path)?,
_ => {
return Err(Error::new(
item_fn.span(),
"Query function must return Result<T: Serialize>",
))
}
}
}
};

let default = match opt_meta {
Some(NameValue(nv)) => {
// Panic: Safe to unwrap because there should be at least one element in the sequence
if nv.path.segments.first().unwrap().ident == "default" {
match nv.value {
syn::Expr::Lit(e) => match e.lit {
syn::Lit::Bool(s) => s.value,
_ => {
return Err(Error::new(
item_fn.span(),
"Default field on query function options must have a Boolean value",
));
}
},
_ => {
return Err(Error::new(
item_fn.span(),
"Default field on query function options must have a Boolean value",
));
}
}
} else {
return Err(Error::new(
item_fn.span(),
"Default field must be set if options are included for the query function",
));
}
}
Some(Path(p)) => {
let seg: &syn::PathSegment = p.segments.first().unwrap();
if seg.ident == "default" {
match seg.arguments {
syn::PathArguments::None => true,
_ => return Err(Error::new(item_fn.span(), "Default field in query options path cannot have any parenthized or bracketed arguments")),
}
} else {
return Err(Error::new(
item_fn.span(),
"Default field must be set if options are included for the query function",
));
}
}
None => false,
_ => {
return Err(Error::new(
item_fn.span(),
"Cannot parse query function options",
));
}
};

Ok(QuerySpec {
function,
default,
input_type,
output_type,
})
}

#[proc_macro_attribute]
pub fn query(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut to_return = proc_macro2::TokenStream::from(item.clone());
let item_fn = parse_macro_input!(item as ItemFn);
let opt_meta: Option<Meta> = if attr.is_empty().not() {
Some(parse_macro_input!(attr as Meta))
} else {
None
};
let spec = match parse_named_query_spec(opt_meta, item_fn) {
Ok(span) => span,
Err(err) => return err.to_compile_error().into(),
};

let struct_name = Ident::new(
spec.function
.to_string()
.to_case(convert_case::Case::Pascal)
.as_str(),
Span::call_site(),
);
let ident = &spec.function;
let input_type = spec.input_type;
let output_type = spec.output_type;

let to_follow = quote::quote! {
struct #struct_name {}

#[hipcheck_sdk::prelude::async_trait]
impl hipcheck_sdk::prelude::Query for #struct_name {
fn input_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
hipcheck_sdk::prelude::schema_for!(#input_type).schema
}

fn output_schema(&self) -> hipcheck_sdk::prelude::JsonSchema {
hipcheck_sdk::prelude::schema_for!(#output_type).schema
}

async fn run(&self, engine: &mut hipcheck_sdk::prelude::PluginEngine, input: hipcheck_sdk::prelude::Value) -> hipcheck_sdk::prelude::Result<hipcheck_sdk::prelude::Value> {
let input = hipcheck_sdk::prelude::from_value(input).map_err(|_|
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryInputFormat)?;
let output = #ident(engine, input).await?;
hipcheck_sdk::prelude::to_value(output).map_err(|_|
hipcheck_sdk::prelude::Error::UnexpectedPluginQueryOutputFormat)
}
}
};

QUERIES.lock().unwrap().push(NamedQuerySpec {
struct_name: struct_name.to_string(),
function: spec.function.to_string(),
default: spec.default,
});

to_return.extend(to_follow);
proc_macro::TokenStream::from(to_return)
}
2 changes: 1 addition & 1 deletion plugins/dummy_sha256_sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ publish = false

[dependencies]
clap = { version = "4.5.18", features = ["derive"] }
hipcheck-sdk = { path = "../../sdk/rust" }
hipcheck-sdk = { path = "../../sdk/rust", features = ["macros"]}
sha2 = "0.10.8"
tokio = { version = "1.40.0", features = ["rt"] }
45 changes: 5 additions & 40 deletions plugins/dummy_sha256_sdk/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,17 @@ use clap::Parser;
use hipcheck_sdk::prelude::*;
use sha2::{Digest, Sha256};

static SHA256_KEY_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json");
static SHA256_OUTPUT_SCHEMA: &str = include_str!("../schema/query_schema_sha256.json");

/// calculate sha256 of provided content
fn sha256(content: &[u8]) -> Vec<u8> {
#[query(default = false)]
async fn query_sha256(_engine: &mut PluginEngine, content: Vec<u8>) -> Result<Vec<u8>> {
let mut hasher = Sha256::new();
hasher.update(content);
hasher.finalize().to_vec()
hasher.update(content.as_slice());
Ok(hasher.finalize().to_vec())
}

/// This plugin takes in a Value::Array(Vec<Value::Number>) and calculates its sha256
#[derive(Clone, Debug)]
struct Sha256Plugin;

#[async_trait]
impl Query for Sha256Plugin {
fn input_schema(&self) -> JsonSchema {
from_str(SHA256_KEY_SCHEMA).unwrap()
}

fn output_schema(&self) -> JsonSchema {
from_str(SHA256_OUTPUT_SCHEMA).unwrap()
}

async fn run(
&self,
_engine: &mut PluginEngine,
input: Value,
) -> hipcheck_sdk::error::Result<Value> {
let Value::Array(data) = &input else {
return Err(Error::UnexpectedPluginQueryInputFormat);
};

let data = data
.iter()
.map(|elem| elem.as_u64().map(|num| num as u8))
.collect::<Option<Vec<_>>>()
.ok_or_else(|| Error::UnexpectedPluginQueryInputFormat)?;

let hash = sha256(&data);
// convert to Value
let hash = hash.iter().map(|x| Value::Number((*x).into())).collect();
Ok(hash)
}
}

impl Plugin for Sha256Plugin {
const PUBLISHER: &'static str = "dummy";

Expand All @@ -70,7 +35,7 @@ impl Plugin for Sha256Plugin {
fn queries(&self) -> impl Iterator<Item = NamedQuery> {
vec![NamedQuery {
name: "sha256",
inner: Box::new(Sha256Plugin),
inner: Box::new(QuerySha256 {}),
}]
.into_iter()
}
Expand Down
4 changes: 4 additions & 0 deletions sdk/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ tokio = { version = "1.39.2", features = ["rt"] }
tokio-stream = "0.1.15"
tonic = "0.12.1"
schemars = "0.8.21"
hipcheck-sdk-macros = { path = "../../hipcheck-sdk-macros", version = "0.1.0", optional = true }

[build-dependencies]
anyhow = "1.0.86"
tonic-build = "0.12.1"

[features]
macros = ["hipcheck-sdk-macros"]
4 changes: 4 additions & 0 deletions sdk/rust/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ pub enum Error {
#[error("unexpected JSON value from plugin")]
UnexpectedPluginQueryInputFormat,

/// The `Query::run` function implementation produced an output that cannot be serialized to JSON
#[error("plugin output could not be serialized to JSON")]
UnexpectedPluginQueryOutputFormat,

/// The `PluginEngine` received a request for an unknown query endpoint
#[error("could not determine which plugin query to run")]
UnknownPluginQuery,
Expand Down
Loading