From a65f7d965e4cffdeb84c1a3edda6eeefbb8d9a9b Mon Sep 17 00:00:00 2001 From: jlanson Date: Wed, 28 Aug 2024 09:24:53 -0400 Subject: [PATCH] fix: dummy plugin refactored to use "Session" tracking so to overcome hanging --- Cargo.lock | 27 + Cargo.toml | 2 +- hipcheck/src/cli.rs | 12 +- hipcheck/src/engine.rs | 14 + hipcheck/src/main.rs | 120 ++-- hipcheck/src/plugin/manager.rs | 13 +- hipcheck/src/plugin/mod.rs | 1 + .../dummy_rand_data/src/hipcheck_transport.rs | 265 ++++--- plugins/dummy_rand_data/src/main.rs | 139 ++-- plugins/dummy_sha256/Cargo.toml | 17 + plugins/dummy_sha256/src/hipcheck.rs | 679 ++++++++++++++++++ .../dummy_sha256/src/hipcheck_transport.rs | 226 ++++++ plugins/dummy_sha256/src/main.rs | 181 +++++ .../dummy_sha256/src/query_schema_sha256.json | 3 + 14 files changed, 1505 insertions(+), 194 deletions(-) create mode 100644 plugins/dummy_sha256/Cargo.toml create mode 100644 plugins/dummy_sha256/src/hipcheck.rs create mode 100644 plugins/dummy_sha256/src/hipcheck_transport.rs create mode 100644 plugins/dummy_sha256/src/main.rs create mode 100644 plugins/dummy_sha256/src/query_schema_sha256.json diff --git a/Cargo.lock b/Cargo.lock index 2ebdb168..415eaf32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -709,6 +709,22 @@ dependencies = [ "tonic", ] +[[package]] +name = "dummy_sha256" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "indexmap 2.4.0", + "prost", + "rand", + "serde_json", + "sha2", + "tokio", + "tokio-stream", + "tonic", +] + [[package]] name = "dyn-clone" version = "1.0.17" @@ -2533,6 +2549,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sharded-slab" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index c4411856..f00e8d84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ resolver = "2" # Members of the workspace. -members = ["hipcheck", "hipcheck-macros", "xtask", "plugins/dummy_rand_data"] +members = ["hipcheck", "hipcheck-macros", "xtask", "plugins/dummy_rand_data", "plugins/dummy_sha256"] # Make sure Hipcheck is run with `cargo run`. # diff --git a/hipcheck/src/cli.rs b/hipcheck/src/cli.rs index d0e1d706..876222ce 100644 --- a/hipcheck/src/cli.rs +++ b/hipcheck/src/cli.rs @@ -398,7 +398,7 @@ pub enum FullCommands { Ready, Update(UpdateArgs), Cache(CacheArgs), - Plugin, + Plugin(PluginArgs), PrintConfig, PrintData, PrintCache, @@ -415,7 +415,7 @@ impl From<&Commands> for FullCommands { Commands::Scoring => FullCommands::Scoring, Commands::Update(args) => FullCommands::Update(args.clone()), Commands::Cache(args) => FullCommands::Cache(args.clone()), - Commands::Plugin => FullCommands::Plugin, + Commands::Plugin(args) => FullCommands::Plugin(args.clone()), } } } @@ -446,7 +446,7 @@ pub enum Commands { Cache(CacheArgs), /// Execute temporary code for exercising plugin engine #[command(hide = true)] - Plugin, + Plugin(PluginArgs), } // If no subcommand matched, default to use of '-t > for RepoCacheDeleteScope { } } +#[derive(Debug, Clone, clap::Args)] +pub struct PluginArgs { + #[arg(long = "async")] + pub asynch: bool, +} + /// Test CLI commands #[cfg(test)] mod tests { diff --git a/hipcheck/src/engine.rs b/hipcheck/src/engine.rs index 8559f840..b0fae8be 100644 --- a/hipcheck/src/engine.rs +++ b/hipcheck/src/engine.rs @@ -37,6 +37,7 @@ fn query( }; // Initiate the query. 
If remote closed or we got our response immediately, // return + println!("Querying {plugin}::{query} with key {key:?}"); let mut ar = match runtime.block_on(p_handle.query(query, key))? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -48,12 +49,14 @@ fn query( // (with salsa memo-ization) to get the needed data, and resume our // current query by providing the plugin the answer. loop { + println!("Query needs more info, recursing..."); let answer = db.query( ar.publisher.clone(), ar.plugin.clone(), ar.query.clone(), ar.key.clone(), )?; + println!("Got answer {answer:?}, resuming"); ar = match runtime.block_on(p_handle.resume_query(ar, answer))? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -79,6 +82,7 @@ pub fn async_query( }; // Initiate the query. If remote closed or we got our response immediately, // return + println!("Querying: {query}, key: {key:?}"); let mut ar = match p_handle.query(query, key).await? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -92,6 +96,7 @@ pub fn async_query( // (with salsa memo-ization) to get the needed data, and resume our // current query by providing the plugin the answer. loop { + println!("Awaiting result, now recursing"); let answer = async_query( Arc::clone(&core), ar.publisher.clone(), @@ -100,6 +105,7 @@ pub fn async_query( ar.key.clone(), ) .await?; + println!("Resuming query with answer {answer:?}"); ar = match p_handle.resume_query(ar, answer).await? { PluginResponse::RemoteClosed => { return Err(hc_error!("Plugin channel closed unexpected")); @@ -120,6 +126,14 @@ pub struct HcEngineImpl { impl salsa::Database for HcEngineImpl {} +impl salsa::ParallelDatabase for HcEngineImpl { + fn snapshot(&self) -> salsa::Snapshot { + salsa::Snapshot::new(HcEngineImpl { + storage: self.storage.snapshot(), + }) + } +} + impl HcEngineImpl { // Really HcEngineImpl and HcPluginCore do the same thing right now, except HcPluginCore // has an async constructor. 
If we can manipulate salsa to accept async functions, we diff --git a/hipcheck/src/main.rs b/hipcheck/src/main.rs index 7c72b1fd..deccb6bb 100644 --- a/hipcheck/src/main.rs +++ b/hipcheck/src/main.rs @@ -50,6 +50,7 @@ use cli::CacheOp; use cli::CheckArgs; use cli::CliConfig; use cli::FullCommands; +use cli::PluginArgs; use cli::SchemaArgs; use cli::SchemaCommand; use cli::SetupArgs; @@ -111,7 +112,7 @@ fn main() -> ExitCode { Some(FullCommands::Ready) => cmd_ready(&config), Some(FullCommands::Update(args)) => cmd_update(&args), Some(FullCommands::Cache(args)) => return cmd_cache(args, &config), - Some(FullCommands::Plugin) => cmd_plugin(), + Some(FullCommands::Plugin(args)) => cmd_plugin(args), Some(FullCommands::PrintConfig) => cmd_print_config(config.config()), Some(FullCommands::PrintData) => cmd_print_data(config.data()), Some(FullCommands::PrintCache) => cmd_print_home(config.cache()), @@ -603,16 +604,21 @@ fn check_github_token() -> StdResult<(), EnvVarCheckError> { }) } -fn cmd_plugin() { +fn cmd_plugin(args: PluginArgs) { use crate::engine::{async_query, HcEngine, HcEngineImpl}; use std::sync::Arc; use tokio::task::JoinSet; let tgt_dir = "./target/debug"; - let entrypoint = pathbuf![tgt_dir, "dummy_rand_data"]; - let plugin = Plugin { + let entrypoint1 = pathbuf![tgt_dir, "dummy_rand_data"]; + let entrypoint2 = pathbuf![tgt_dir, "dummy_sha256"]; + let plugin1 = Plugin { name: "rand_data".to_owned(), - entrypoint: entrypoint.display().to_string(), + entrypoint: entrypoint1.display().to_string(), + }; + let plugin2 = Plugin { + name: "sha256".to_owned(), + entrypoint: entrypoint2.display().to_string(), }; let plugin_executor = PluginExecutor::new( /* max_spawn_attempts */ 3, @@ -624,7 +630,10 @@ fn cmd_plugin() { .unwrap(); let engine = match HcEngineImpl::new( plugin_executor, - vec![PluginWithConfig(plugin, serde_json::json!(null))], + vec![ + PluginWithConfig(plugin1, serde_json::json!(null)), + PluginWithConfig(plugin2, serde_json::json!(null)), + ], ) { Ok(e) => e, Err(e) => { @@ -632,49 +641,62 @@ fn cmd_plugin() { return; } }; - let core = engine.core(); - let handle = HcEngineImpl::runtime(); - // @Note - how to initiate multiple queries with async calls - handle.block_on(async move { - let mut futs = JoinSet::new(); - for i in 1..10 { - let arc_core = Arc::clone(&core); - println!("Spawning"); - futs.spawn(async_query( - arc_core, - "MITRE".to_owned(), - "rand_data".to_owned(), - "rand_data".to_owned(), - serde_json::json!(i), - )); - } - while let Some(res) = futs.join_next().await { - println!("res: {res:?}"); - } - }); - // @Note - how to initiate multiple queries with sync calls - // let conc: Vec> = vec![]; - // for i in 0..10 { - // let fut = thread::spawn(|| { - // let res = match engine.query( - // "MITRE".to_owned(), - // "rand_data".to_owned(), - // "rand_data".to_owned(), - // serde_json::json!(i), - // ) { - // Ok(r) => r, - // Err(e) => { - // println!("{i}: Query failed: {e}"); - // return; - // } - // }; - // println!("{i}: Result: {res}"); - // }); - // conc.push(fut); - // } - // while let Some(x) = conc.pop() { - // x.join().unwrap(); - // } + if args.asynch { + // @Note - how to initiate multiple queries with async calls + let core = engine.core(); + let handle = HcEngineImpl::runtime(); + handle.block_on(async move { + let mut futs = JoinSet::new(); + for i in 1..10 { + let arc_core = Arc::clone(&core); + println!("Spawning"); + futs.spawn(async_query( + arc_core, + "MITRE".to_owned(), + "rand_data".to_owned(), + "rand_data".to_owned(), + 
serde_json::json!(i), + )); + } + while let Some(res) = futs.join_next().await { + println!("res: {res:?}"); + } + }); + } else { + let res = engine.query( + "MITRE".to_owned(), + "rand_data".to_owned(), + "rand_data".to_owned(), + serde_json::json!(1), + ); + println!("res: {res:?}"); + // @Note - how to initiate multiple queries with sync calls + // Currently does not work, compiler complains need Sync impl + // use std::thread; + // let conc: Vec> = vec![]; + // for i in 0..10 { + // let snapshot = engine.snapshot(); + // let fut = thread::spawn(|| { + // let res = match snapshot.query( + // "MITRE".to_owned(), + // "rand_data".to_owned(), + // "rand_data".to_owned(), + // serde_json::json!(i), + // ) { + // Ok(r) => r, + // Err(e) => { + // println!("{i}: Query failed: {e}"); + // return; + // } + // }; + // println!("{i}: Result: {res}"); + // }); + // conc.push(fut); + // } + // while let Some(x) = conc.pop() { + // x.join().unwrap(); + // } + } } fn cmd_ready(config: &CliConfig) { diff --git a/hipcheck/src/plugin/manager.rs b/hipcheck/src/plugin/manager.rs index cc129bef..5e1af99b 100644 --- a/hipcheck/src/plugin/manager.rs +++ b/hipcheck/src/plugin/manager.rs @@ -42,8 +42,15 @@ impl PluginExecutor { } fn get_available_port(&self) -> Result { for i in self.port_range.start..self.port_range.end { - if std::net::TcpListener::bind(format!("127.0.0.1:{i}")).is_ok() { - return Ok(i); + // @Todo - either TcpListener::bind returns Ok even if port is bound + // or we have a race condition. For now just have OS assign a port + // if std::net::TcpListener::bind(format!("127.0.0.1:{i}")).is_ok() { + // return Ok(i); + // } + if let Ok(addr) = std::net::TcpListener::bind("127.0.0.1:0") { + if let Ok(local_addr) = addr.local_addr() { + return Ok(local_addr.port()); + } } } Err(hc_error!("Failed to find available port")) @@ -60,6 +67,7 @@ impl PluginExecutor { // on the cmdline is not already in use, but it is still possible for that // port to become unavailable between our check and the plugin's bind attempt. // Hence the need for subsequent attempts if we get unlucky + eprintln!("Starting plugin '{}'", plugin.name); let mut spawn_attempts: usize = 0; while spawn_attempts < self.max_spawn_attempts { // Find free port for process. 
Don't retry if we fail since this means all @@ -67,6 +75,7 @@ impl PluginExecutor { let port = self.get_available_port()?; let port_str = port.to_string(); // Spawn plugin process + eprintln!("Spawning '{}' on port {}", &plugin.entrypoint, port_str); let Ok(mut proc) = Command::new(&plugin.entrypoint) .args(["--port", port_str.as_str()]) // @Temporary - directly forward stdout/stderr from plugin to shell diff --git a/hipcheck/src/plugin/mod.rs b/hipcheck/src/plugin/mod.rs index 444405cc..ca7490e0 100644 --- a/hipcheck/src/plugin/mod.rs +++ b/hipcheck/src/plugin/mod.rs @@ -88,6 +88,7 @@ impl ActivePlugin { key: serde_json::json!(null), output, }; + eprintln!("Resuming query with answer {query:?}"); Ok(self.channel.query(query).await?.into()) } } diff --git a/plugins/dummy_rand_data/src/hipcheck_transport.rs b/plugins/dummy_rand_data/src/hipcheck_transport.rs index 5d50c0f4..e206860b 100644 --- a/plugins/dummy_rand_data/src/hipcheck_transport.rs +++ b/plugins/dummy_rand_data/src/hipcheck_transport.rs @@ -1,15 +1,12 @@ use crate::hipcheck::{Query as PluginQuery, QueryState}; use anyhow::{anyhow, Result}; -use indexmap::map::IndexMap; use serde_json::Value; -use std::collections::VecDeque; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex}; +use std::collections::{HashMap, VecDeque}; +use tokio::sync::mpsc; use tonic::{codec::Streaming, Status}; #[derive(Debug)] pub struct Query { - pub id: usize, // if false, response pub request: bool, pub publisher: String, @@ -18,10 +15,13 @@ pub struct Query { pub key: Value, pub output: Value, } + impl TryFrom for Query { type Error = anyhow::Error; + fn try_from(value: PluginQuery) -> Result { use QueryState::*; + let request = match TryInto::::try_into(value.state)? { QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), QueryReplyInProgress => { @@ -32,10 +32,11 @@ impl TryFrom for Query { QueryReplyComplete => false, QuerySubmit => true, }; + let key: Value = serde_json::from_str(value.key.as_str())?; let output: Value = serde_json::from_str(value.output.as_str())?; + Ok(Query { - id: value.id as usize, request, publisher: value.publisher_name, plugin: value.plugin_name, @@ -45,17 +46,35 @@ impl TryFrom for Query { }) } } -impl TryFrom for PluginQuery { - type Error = anyhow::Error; - fn try_from(value: Query) -> Result { + +type SessionTracker = HashMap>>; + +pub struct QuerySession { + id: usize, + tx: mpsc::Sender>, + rx: mpsc::Receiver>, + // So that we can remove ourselves when we get dropped + drop_tx: mpsc::Sender, +} + +impl QuerySession { + pub fn id(&self) -> usize { + self.id + } + + // Roughly equivalent to TryFrom, but the `id` field value + // comes from the QuerySession + fn convert(&self, value: Query) -> Result { let state_enum = match value.request { true => QueryState::QuerySubmit, false => QueryState::QueryReplyComplete, }; + let key = serde_json::to_string(&value.key)?; let output = serde_json::to_string(&value.output)?; + Ok(PluginQuery { - id: value.id as i32, + id: self.id() as i32, state: state_enum as i32, publisher_name: value.publisher, plugin_name: value.plugin, @@ -64,41 +83,65 @@ impl TryFrom for PluginQuery { output, }) } -} -#[derive(Clone, Debug)] -pub struct HcTransport { - tx: mpsc::Sender>, - rx: Arc>, -} -impl HcTransport { - pub fn new(rx: Streaming, tx: mpsc::Sender>) -> Self { - HcTransport { - rx: Arc::new(Mutex::new(MultiplexedQueryReceiver::new(rx))), - tx, + async fn recv_raw(&mut self) -> Result>> { + let mut out = VecDeque::new(); + + eprintln!("RAND-session: awaiting raw rx recv"); 
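+ // Block for the first message addressed to this session, then
+ // opportunistically drain anything else already queued so that a chunked
+ // reply can be reassembled in a single recv() pass. Note the two distinct
+ // "closed" signals below: the mpsc channel itself closing is an error,
+ // while an Ok(None) value forwarded by the listener means the underlying
+ // gRPC stream was closed by the remote.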
+ + let opt_first = self + .rx + .recv() + .await + .ok_or(anyhow!("session channel closed unexpectedly"))?; + + let Some(first) = opt_first else { + // Underlying gRPC channel closed + return Ok(None); + }; + eprintln!("RAND-session: got first msg"); + out.push_back(first); + + // If more messages in the queue, opportunistically read more + loop { + eprintln!("RAND-session: trying to get additional msg"); + + match self.rx.try_recv() { + Ok(Some(msg)) => { + out.push_back(msg); + } + Ok(None) => { + eprintln!("warning: None received, gRPC channel closed. we may not close properly if None is not returned again"); + break; + } + // Whether empty or disconnected, we return what we have + Err(_) => { + break; + } + } } + + eprintln!("RAND-session: got {} msgs", out.len()); + Ok(Some(out)) } + pub async fn send(&self, query: Query) -> Result<()> { - let query: PluginQuery = query.try_into()?; + eprintln!("RAND-session: sending query"); + let query: PluginQuery = self.convert(query)?; self.tx.send(Ok(query)).await?; Ok(()) } - pub async fn recv_new(&self) -> Result> { - let mut rx_handle = self.rx.lock().await; - match rx_handle.recv_new().await? { - Some(msg) => msg.try_into().map(Some), - None => Ok(None), - } - } - pub async fn recv(&self, id: usize) -> Result> { + + pub async fn recv(&mut self) -> Result> { use QueryState::*; - let id = id as i32; - let mut rx_handle = self.rx.lock().await; - let Some(mut msg_chunks) = rx_handle.recv(id).await? else { + + eprintln!("RAND-session: calling recv_raw"); + let Some(mut msg_chunks) = self.recv_raw().await? else { return Ok(None); }; - drop(rx_handle); let mut raw = msg_chunks.pop_front().unwrap(); + eprintln!("RAND-session: recv got raw {raw:?}"); + let mut state: QueryState = raw.state.try_into()?; // If response is the first of a set of chunks, handle @@ -110,10 +153,8 @@ impl HcTransport { Some(msg) => msg, None => { // We ran out of messages, get a new batch - let mut rx_handle = self.rx.lock().await; - match rx_handle.recv(id).await? { + match self.recv_raw().await? 
{ Some(x) => { - drop(rx_handle); msg_chunks = x; } None => { @@ -123,6 +164,7 @@ impl HcTransport { msg_chunks.pop_front().unwrap() } }; + // By now we have our "next" message state = next.state.try_into()?; match state { @@ -137,89 +179,122 @@ impl HcTransport { } }; } + // Sanity check - after we've left this loop, there should be no left over message if !msg_chunks.is_empty() { return Err(anyhow!( "received additional messages for id '{}' after QueryComplete status message", - id + self.id )); } } + raw.try_into().map(Some) } } +impl Drop for QuerySession { + // Notify to have self removed from session tracker + fn drop(&mut self) { + use mpsc::error::TrySendError; + let raw_id = self.id as i32; + + while let Err(e) = self.drop_tx.try_send(self.id as i32) { + match e { + TrySendError::Closed(_) => { + break; + } + TrySendError::Full(_) => (), + } + } + } +} + #[derive(Debug)] -pub struct MultiplexedQueryReceiver { +pub struct HcSessionSocket { + tx: mpsc::Sender>, rx: Streaming, - // Unlike in HipCheck, backlog is an IndexMap to ensure the earliest received - // requests are handled first - backlog: IndexMap>, + drop_tx: mpsc::Sender, + drop_rx: mpsc::Receiver, + sessions: SessionTracker, } -impl MultiplexedQueryReceiver { - pub fn new(rx: Streaming) -> Self { + +impl HcSessionSocket { + pub fn new(tx: mpsc::Sender>, rx: Streaming) -> Self { + // channel for QuerySession objects to notify us they dropped + // @Todo - make this configurable + let (drop_tx, drop_rx) = mpsc::channel(10); Self { + tx, rx, - backlog: IndexMap::new(), + drop_tx, + drop_rx, + sessions: HashMap::new(), } } - pub async fn recv_new(&mut self) -> Result> { - let opt_unhandled = self.backlog.iter().find(|(k, v)| { - if let Some(req) = v.front() { - return req.state() == QueryState::QuerySubmit; - } - false - }); - if let Some((k, v)) = opt_unhandled { - let id: i32 = *k; - let mut vec = self.backlog.shift_remove(&id).unwrap(); - // @Note - for now QuerySubmit doesn't chunk so we shouldn't expect - // multiple messages in the backlog for a new request - assert!(vec.len() == 1); - return Ok(vec.pop_front()); - } - // No backlog message, need to operate the receiver - loop { - let Some(raw) = self.rx.message().await? else { - // gRPC channel was closed - return Ok(None); - }; - if raw.state() == QueryState::QuerySubmit { - return Ok(Some(raw)); - } - match self.backlog.get_mut(&raw.id) { - Some(vec) => { - vec.push_back(raw); - } - None => { - self.backlog.insert(raw.id, VecDeque::from([raw])); - } + + fn cleanup_sessions(&mut self) { + // Pull off all existing drop notifications + while let Ok(id) = self.drop_rx.try_recv() { + if self.sessions.remove(&id).is_none() { + eprintln!( + "WARNING: HcSessionSocket got request to drop a session that does not exist" + ); + } else { + eprintln!("Cleaned up session {id}"); } } } - // @Invariant - this function will never return an empty VecDeque - pub async fn recv(&mut self, id: i32) -> Result>> { - // If we have 1+ messages on backlog for `id`, return them all, - // no need to waste time with successive calls - if let Some(msgs) = self.backlog.shift_remove(&id) { - return Ok(Some(msgs)); - } - // No backlog message, need to operate the receiver + + pub async fn listen(&mut self) -> Result> { loop { + eprintln!("RAND: listening"); + let Some(raw) = self.rx.message().await? 
else { - // gRPC channel was closed return Ok(None); }; - if raw.id == id { - return Ok(Some(VecDeque::from([raw]))); - } - match self.backlog.get_mut(&raw.id) { - Some(vec) => { - vec.push_back(raw); - } - None => { - self.backlog.insert(raw.id, VecDeque::from([raw])); - } + + // While we were waiting for a message, some session objects may have + // dropped, handle them before we look at the ID of this message. + // The downside of this strategy is that once we receive our last message, + // we won't clean up any sessions that close after + self.cleanup_sessions(); + + let id = raw.id; + + // If there is already a session with this ID, forward msg + if let Some(tx) = self.sessions.get_mut(&id) { + eprintln!("RAND-listen: forwarding message to session {id}"); + + if let Err(e) = tx.send(Some(raw)).await { + eprintln!("Error forwarding msg to session {id}"); + self.sessions.remove(&id); + }; + // If got a new query ID, create session + } else if raw.state() == QueryState::QuerySubmit { + eprintln!("RAND-listen: creating new session {id}"); + + let (in_tx, rx) = mpsc::channel::>(10); + let tx = self.tx.clone(); + + let session = QuerySession { + id: id as usize, + tx, + rx, + drop_tx: self.drop_tx.clone(), + }; + + in_tx + .send(Some(raw)) + .await + .expect("Failed sending message to newly created Session, should never happen"); + + eprintln!("RAND-listen: adding new session {id} to tracker"); + self.sessions.insert(id, in_tx); + + return Ok(Some(session)); + } else { + eprintln!("Got query with id {}, does not match existing session and is not new QuerySubmit", id); } } } diff --git a/plugins/dummy_rand_data/src/main.rs b/plugins/dummy_rand_data/src/main.rs index ee8d9ee2..494b61fc 100644 --- a/plugins/dummy_rand_data/src/main.rs +++ b/plugins/dummy_rand_data/src/main.rs @@ -11,7 +11,6 @@ use hipcheck::{ Configuration, ConfigurationResult, ConfigurationStatus, Empty, PolicyExpression, Query as PluginQuery, Schema, }; -use rand::Rng; use serde_json::{json, Value}; use std::pin::Pin; use tokio::sync::mpsc; @@ -21,18 +20,45 @@ use tonic::{transport::Server, Request, Response, Status, Streaming}; static GET_RAND_KEY_SCHEMA: &str = include_str!("query_schema_get_rand.json"); static GET_RAND_OUTPUT_SCHEMA: &str = include_str!("query_schema_get_rand.json"); -fn get_rand(num_bytes: usize) -> Vec { - let mut vec = vec![0u8; num_bytes]; - let mut rng = rand::thread_rng(); - rng.fill(vec.as_mut_slice()); - vec +fn reduce(input: u64) -> u64 { + input % 7 } -pub async fn handle_rand_data(channel: HcTransport, id: usize, key: u64) -> Result<()> { - let res = get_rand(key as usize); - let output = serde_json::to_value(res)?; +pub async fn handle_rand_data(mut session: QuerySession, key: u64) -> Result<()> { + let id = session.id(); + let sha_input = reduce(key); + eprintln!("RAND-{id}: key: {key}, reduced: {sha_input}"); + + let sha_req = Query { + request: true, + publisher: "MITRE".to_owned(), + plugin: "sha256".to_owned(), + query: "sha256".to_owned(), + key: json!(vec![sha_input]), + output: json!(null), + }; + + session.send(sha_req).await?; + let Some(res) = session.recv().await? 
else { + return Err(anyhow!("channel closed prematurely by remote")); + }; + + if res.request { + return Err(anyhow!("expected response from remote")); + } + + let mut sha_vec: Vec = serde_json::from_value(res.output)?; + eprintln!("RAND-{id}: hash: {sha_vec:02x?}"); + let key_vec = key.to_le_bytes().to_vec(); + + for (i, val) in key_vec.into_iter().enumerate() { + *sha_vec.get_mut(i).unwrap() += val; + } + + eprintln!("RAND-{id}: output: {sha_vec:02x?}"); + let output = serde_json::to_value(sha_vec)?; + let resp = Query { - id, request: false, publisher: "".to_owned(), plugin: "".to_owned(), @@ -40,51 +66,67 @@ pub async fn handle_rand_data(channel: HcTransport, id: usize, key: u64) -> Resu key: json!(null), output, }; - channel.send(resp).await?; + + session.send(resp).await?; + Ok(()) } + +async fn handle_session(mut session: QuerySession) -> Result<()> { + let Some(query) = session.recv().await? else { + eprintln!("session closed by remote"); + return Ok(()); + }; + + if !query.request { + return Err(anyhow!("Expected request from remote")); + } + + let name = query.query; + let key = query.key; + + if name == "rand_data" { + let Value::Number(num_size) = &key else { + return Err(anyhow!("get_rand argument must be a number")); + }; + + let Some(size) = num_size.as_u64() else { + return Err(anyhow!("get_rand argument must be an unsigned integer")); + }; + + handle_rand_data(session, size).await?; + + Ok(()) + } else { + Err(anyhow!("unrecognized query '{}'", name)) + } +} + struct RandDataRunner { - channel: HcTransport, + channel: HcSessionSocket, } + impl RandDataRunner { - pub fn new(channel: HcTransport) -> Self { + pub fn new(channel: HcSessionSocket) -> Self { RandDataRunner { channel } } - async fn handle_query(channel: HcTransport, id: usize, name: String, key: Value) -> Result<()> { - if name == "rand_data" { - let Value::Number(num_size) = &key else { - return Err(anyhow!("get_rand argument must be a number")); - }; - let Some(size) = num_size.as_u64() else { - return Err(anyhow!("get_rand argument must be an unsigned integer")); - }; - handle_rand_data(channel, id, size).await?; - Ok(()) - } else { - Err(anyhow!("unrecognized query '{}'", name)) - } - } - pub async fn run(self) -> Result<()> { + + pub async fn run(mut self) -> Result<()> { loop { - eprintln!("Looping"); - let Some(msg) = self.channel.recv_new().await? else { + eprintln!("RAND: Looping"); + + let Some(session) = self.channel.listen().await? 
else { eprintln!("Channel closed by remote"); break; }; - if msg.request { - let child_channel = self.channel.clone(); - tokio::spawn(async move { - if let Err(e) = - RandDataRunner::handle_query(child_channel, msg.id, msg.query, msg.key) - .await - { - eprintln!("handle_query failed: {e}"); - }; - }); - } else { - return Err(anyhow!("Did not expect a response-type message here")); - } + + tokio::spawn(async move { + if let Err(e) = handle_session(session).await { + eprintln!("handle_session failed: {e}"); + }; + }); } + Ok(()) } } @@ -93,6 +135,7 @@ impl RandDataRunner { struct RandDataPlugin { pub schema: Schema, } + impl RandDataPlugin { pub fn new() -> Self { let schema = Schema { @@ -109,6 +152,7 @@ impl Plugin for RandDataPlugin { type GetQuerySchemasStream = Pin> + Send + 'static>>; type InitiateQueryProtocolStream = ReceiverStream>; + async fn get_query_schemas( &self, _request: Request, @@ -117,6 +161,7 @@ impl Plugin for RandDataPlugin { .schema .clone())])))) } + async fn set_configuration( &self, request: Request, @@ -126,6 +171,7 @@ impl Plugin for RandDataPlugin { message: "".to_owned(), })) } + async fn get_default_policy_expression( &self, request: Request, @@ -134,18 +180,21 @@ impl Plugin for RandDataPlugin { policy_expression: "".to_owned(), })) } + async fn initiate_query_protocol( &self, request: Request>, ) -> Result, Status> { let rx = request.into_inner(); let (tx, out_rx) = mpsc::channel::>(4); + tokio::spawn(async move { - let channel = HcTransport::new(rx, tx); + let channel = HcSessionSocket::new(tx, rx); if let Err(e) = RandDataRunner::new(channel).run().await { eprintln!("rand_data plugin ended in error: {e}"); } }); + Ok(Response::new(ReceiverStream::new(out_rx))) } } @@ -162,9 +211,11 @@ async fn main() -> Result<(), Box> { let addr = format!("127.0.0.1:{}", args.port); let plugin = RandDataPlugin::new(); let svc = PluginServer::new(plugin); + Server::builder() .add_service(svc) .serve(addr.parse().unwrap()) .await?; + Ok(()) } diff --git a/plugins/dummy_sha256/Cargo.toml b/plugins/dummy_sha256/Cargo.toml new file mode 100644 index 00000000..409cb015 --- /dev/null +++ b/plugins/dummy_sha256/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "dummy_sha256" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +anyhow = "1.0.86" +clap = { version = "4.5.16", features = ["derive"] } +indexmap = "2.4.0" +prost = "0.13.1" +rand = "0.8.5" +serde_json = "1.0.125" +sha2 = "0.10.8" +tokio = { version = "1.39.2", features = ["rt"] } +tokio-stream = "0.1.15" +tonic = "0.12.1" diff --git a/plugins/dummy_sha256/src/hipcheck.rs b/plugins/dummy_sha256/src/hipcheck.rs new file mode 100644 index 00000000..50ce6bf2 --- /dev/null +++ b/plugins/dummy_sha256/src/hipcheck.rs @@ -0,0 +1,679 @@ +#![allow(clippy::enum_variant_names)] + +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Configuration { + /// JSON string containing configuration data expected by the plugin, + /// pulled from the user's policy file. + #[prost(string, tag = "1")] + pub configuration: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ConfigurationResult { + /// The status of the configuration call. + #[prost(enumeration = "ConfigurationStatus", tag = "1")] + pub status: i32, + /// An optional error message, if there was an error. 
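+ /// In practice a successful call leaves this empty; for example, the
+ /// dummy plugins in this patch reply with `status:
+ /// ConfigurationStatus::ErrorNone as i32` and an empty `message`.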
+ #[prost(string, tag = "2")] + pub message: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PolicyExpression { + /// A policy expression, if the plugin has a default policy. + /// This MUST be filled in with any default values pulled from the plugin's + /// configuration. Hipcheck will only request the default policy _after_ + /// configuring the plugin. + #[prost(string, tag = "1")] + pub policy_expression: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Schema { + /// The name of the query being described by the schemas provided. + /// + /// If either the key and/or output schemas result in a message which is + /// too big, they may be chunked across multiple replies in the stream. + /// Replies with matching query names should have their fields concatenated + /// in the order received to reconstruct the chunks. + #[prost(string, tag = "1")] + pub query_name: ::prost::alloc::string::String, + /// The key schema, in JSON Schema format. + #[prost(string, tag = "2")] + pub key_schema: ::prost::alloc::string::String, + /// The output schema, in JSON Schema format. + #[prost(string, tag = "3")] + pub output_schema: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Query { + /// The ID of the request, used to associate requests and replies. + /// Odd numbers = initiated by `hc`. + /// Even numbers = initiated by a plugin. + #[prost(int32, tag = "1")] + pub id: i32, + /// The state of the query, indicating if this is a request or a reply, + /// and if it's a reply whether it's the end of the reply. + #[prost(enumeration = "QueryState", tag = "2")] + pub state: i32, + /// Publisher name and plugin name, when sent from Hipcheck to a plugin + /// to initiate a fresh query, are used by the receiving plugin to validate + /// that the query was intended for them. + /// + /// When a plugin is making a query to another plugin through Hipcheck, it's + /// used to indicate the destination plugin, and to indicate the plugin that + /// is replying when Hipcheck sends back the reply. + #[prost(string, tag = "3")] + pub publisher_name: ::prost::alloc::string::String, + #[prost(string, tag = "4")] + pub plugin_name: ::prost::alloc::string::String, + /// The name of the query being made, so the responding plugin knows what + /// to do with the provided data. + #[prost(string, tag = "5")] + pub query_name: ::prost::alloc::string::String, + /// The key for the query, as a JSON object. This is the data that Hipcheck's + /// incremental computation system will use to cache the response. + #[prost(string, tag = "6")] + pub key: ::prost::alloc::string::String, + /// The response for the query, as a JSON object. This will be cached by + /// Hipcheck for future queries matching the publisher name, plugin name, + /// query name, and key. + #[prost(string, tag = "7")] + pub output: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Empty {} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum ConfigurationStatus { + /// An unknown error occured. + ErrorUnknown = 0, + /// No error; the operation was successful. 
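+ /// Note that in proto3 an unset enum field decodes as 0, i.e.
+ /// `ErrorUnknown` above, so success must be signaled explicitly with this
+ /// nonzero value.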
+ ErrorNone = 1, + /// The user failed to provide a required configuration item. + ErrorMissingRequiredConfiguration = 2, + /// The user provided a configuration item whose name was not recognized. + ErrorUnrecognizedConfiguration = 3, + /// The user provided a configuration item whose value is invalid. + ErrorInvalidConfigurationValue = 4, +} +impl ConfigurationStatus { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + ConfigurationStatus::ErrorUnknown => "ERROR_UNKNOWN", + ConfigurationStatus::ErrorNone => "ERROR_NONE", + ConfigurationStatus::ErrorMissingRequiredConfiguration => { + "ERROR_MISSING_REQUIRED_CONFIGURATION" + } + ConfigurationStatus::ErrorUnrecognizedConfiguration => { + "ERROR_UNRECOGNIZED_CONFIGURATION" + } + ConfigurationStatus::ErrorInvalidConfigurationValue => { + "ERROR_INVALID_CONFIGURATION_VALUE" + } + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "ERROR_UNKNOWN" => Some(Self::ErrorUnknown), + "ERROR_NONE" => Some(Self::ErrorNone), + "ERROR_MISSING_REQUIRED_CONFIGURATION" => Some(Self::ErrorMissingRequiredConfiguration), + "ERROR_UNRECOGNIZED_CONFIGURATION" => Some(Self::ErrorUnrecognizedConfiguration), + "ERROR_INVALID_CONFIGURATION_VALUE" => Some(Self::ErrorInvalidConfigurationValue), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum QueryState { + /// Something has gone wrong. + QueryUnspecified = 0, + /// We are submitting a new query. + QuerySubmit = 1, + /// We are replying to a query and expect more chunks. + QueryReplyInProgress = 2, + /// We are closing a reply to a query. If a query response is in one chunk, + /// just send this. If a query is in more than one chunk, send this with + /// the last message in the reply. This tells the receiver that all chunks + /// have been received. + QueryReplyComplete = 3, +} +impl QueryState { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + QueryState::QueryUnspecified => "QUERY_UNSPECIFIED", + QueryState::QuerySubmit => "QUERY_SUBMIT", + QueryState::QueryReplyInProgress => "QUERY_REPLY_IN_PROGRESS", + QueryState::QueryReplyComplete => "QUERY_REPLY_COMPLETE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "QUERY_UNSPECIFIED" => Some(Self::QueryUnspecified), + "QUERY_SUBMIT" => Some(Self::QuerySubmit), + "QUERY_REPLY_IN_PROGRESS" => Some(Self::QueryReplyInProgress), + "QUERY_REPLY_COMPLETE" => Some(Self::QueryReplyComplete), + _ => None, + } + } +} +/// Generated client implementations. 
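+///
+/// A minimal connection sketch (illustrative only: the address and port are
+/// hypothetical, and the dummy plugins in this patch are servers, so nothing
+/// here actually uses the generated client):
+///
+/// ```ignore
+/// use crate::hipcheck::{plugin_client::PluginClient, Empty};
+///
+/// async fn dump_schemas() -> Result<(), Box<dyn std::error::Error>> {
+///     let mut client = PluginClient::connect("http://127.0.0.1:50051").await?;
+///     // Server-streaming call; chunked Schema replies with matching query
+///     // names would need their fields concatenated, per the Schema docs above.
+///     let mut stream = client.get_query_schemas(Empty {}).await?.into_inner();
+///     while let Some(schema) = stream.message().await? {
+///         println!("{}: {}", schema.query_name, schema.key_schema);
+///     }
+///     Ok(())
+/// }
+/// ```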
+pub mod plugin_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::http::Uri; + use tonic::codegen::*; + #[derive(Debug, Clone)] + pub struct PluginClient { + inner: tonic::client::Grpc, + } + impl PluginClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl PluginClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> PluginClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + >>::Error: + Into + Send + Sync, + { + PluginClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /// * + /// Get schemas for all supported queries by the plugin. + /// + /// This is used by Hipcheck to validate that: + /// + /// - The plugin supports a default query taking a `target` type if used + /// as a top-level plugin in the user's policy file. + /// - That requests sent to the plugin and data returned by the plugin + /// match the schema during execution. + pub async fn get_query_schemas( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetQuerySchemas"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("hipcheck.Plugin", "GetQuerySchemas")); + self.inner.server_streaming(req, path, codec).await + } + /// * + /// Hipcheck sends all child nodes for the plugin from the user's policy + /// file to configure the plugin. 
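+ ///
+ /// Illustrative call (the JSON payload shape is plugin-defined; this
+ /// particular value is hypothetical):
+ ///
+ /// ```ignore
+ /// let conf = Configuration { configuration: r#"{"seed": 7}"#.to_owned() };
+ /// let res = client.set_configuration(conf).await?.into_inner();
+ /// assert_eq!(res.status, ConfigurationStatus::ErrorNone as i32);
+ /// ```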
+ pub async fn set_configuration( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static("/hipcheck.Plugin/SetConfiguration"); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("hipcheck.Plugin", "SetConfiguration")); + self.inner.unary(req, path, codec).await + } + /// * + /// Get the default policy for a plugin, which may additionally depend on + /// the plugin's configuration. + pub async fn get_default_policy_expression( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result, tonic::Status> { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/hipcheck.Plugin/GetDefaultPolicyExpression"); + let mut req = request.into_request(); + req.extensions_mut().insert(GrpcMethod::new( + "hipcheck.Plugin", + "GetDefaultPolicyExpression", + )); + self.inner.unary(req, path, codec).await + } + /// * + /// Open a bidirectional streaming RPC to enable a request/response + /// protocol between Hipcheck and a plugin, where Hipcheck can issue + /// queries to the plugin, and the plugin may issue queries to _other_ + /// plugins through Hipcheck. + /// + /// Queries are cached by the publisher name, plugin name, query name, + /// and key, and if a match is found for those four values, then + /// Hipcheck will respond with the cached result of that prior matching + /// query rather than running the query again. + pub async fn initiate_query_protocol( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner.ready().await.map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = + http::uri::PathAndQuery::from_static("/hipcheck.Plugin/InitiateQueryProtocol"); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("hipcheck.Plugin", "InitiateQueryProtocol")); + self.inner.streaming(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod plugin_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with PluginServer. + #[async_trait] + pub trait Plugin: Send + Sync + 'static { + /// Server streaming response type for the GetQuerySchemas method. + type GetQuerySchemasStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + Send + + 'static; + /// * + /// Get schemas for all supported queries by the plugin. + /// + /// This is used by Hipcheck to validate that: + /// + /// - The plugin supports a default query taking a `target` type if used + /// as a top-level plugin in the user's policy file. + /// - That requests sent to the plugin and data returned by the plugin + /// match the schema during execution. 
+ async fn get_query_schemas( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// * + /// Hipcheck sends all child nodes for the plugin from the user's policy + /// file to configure the plugin. + async fn set_configuration( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// * + /// Get the default policy for a plugin, which may additionally depend on + /// the plugin's configuration. + async fn get_default_policy_expression( + &self, + request: tonic::Request, + ) -> std::result::Result, tonic::Status>; + /// Server streaming response type for the InitiateQueryProtocol method. + type InitiateQueryProtocolStream: tonic::codegen::tokio_stream::Stream< + Item = std::result::Result, + > + Send + + 'static; + /// * + /// Open a bidirectional streaming RPC to enable a request/response + /// protocol between Hipcheck and a plugin, where Hipcheck can issue + /// queries to the plugin, and the plugin may issue queries to _other_ + /// plugins through Hipcheck. + /// + /// Queries are cached by the publisher name, plugin name, query name, + /// and key, and if a match is found for those four values, then + /// Hipcheck will respond with the cached result of that prior matching + /// query rather than running the query again. + async fn initiate_query_protocol( + &self, + request: tonic::Request>, + ) -> std::result::Result, tonic::Status>; + } + #[derive(Debug)] + pub struct PluginServer { + inner: Arc, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + impl PluginServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. 
+ /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for PluginServer + where + T: Plugin, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + match req.uri().path() { + "/hipcheck.Plugin/GetQuerySchemas" => { + #[allow(non_camel_case_types)] + struct GetQuerySchemasSvc(pub Arc); + impl tonic::server::ServerStreamingService for GetQuerySchemasSvc { + type Response = super::Schema; + type ResponseStream = T::GetQuerySchemasStream; + type Future = + BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_query_schemas(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = GetQuerySchemasSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.server_streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/hipcheck.Plugin/SetConfiguration" => { + #[allow(non_camel_case_types)] + struct SetConfigurationSvc(pub Arc); + impl tonic::server::UnaryService for SetConfigurationSvc { + type Response = super::ConfigurationResult; + type Future = BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::set_configuration(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = SetConfigurationSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/hipcheck.Plugin/GetDefaultPolicyExpression" => { + #[allow(non_camel_case_types)] + struct GetDefaultPolicyExpressionSvc(pub Arc); + impl tonic::server::UnaryService for GetDefaultPolicyExpressionSvc { + type Response = super::PolicyExpression; + type Future = BoxFuture, tonic::Status>; + fn call(&mut self, request: tonic::Request) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_default_policy_expression(&inner, request).await + 
}; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = GetDefaultPolicyExpressionSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/hipcheck.Plugin/InitiateQueryProtocol" => { + #[allow(non_camel_case_types)] + struct InitiateQueryProtocolSvc(pub Arc); + impl tonic::server::StreamingService for InitiateQueryProtocolSvc { + type Response = super::Query; + type ResponseStream = T::InitiateQueryProtocolStream; + type Future = + BoxFuture, tonic::Status>; + fn call( + &mut self, + request: tonic::Request>, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::initiate_query_protocol(&inner, request).await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let method = InitiateQueryProtocolSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.streaming(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + _ => Box::pin(async move { + Ok(http::Response::builder() + .status(200) + .header("grpc-status", tonic::Code::Unimplemented as i32) + .header( + http::header::CONTENT_TYPE, + tonic::metadata::GRPC_CONTENT_TYPE, + ) + .body(empty_body()) + .unwrap()) + }), + } + } + } + impl Clone for PluginServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl tonic::server::NamedService for PluginServer { + const NAME: &'static str = "hipcheck.Plugin"; + } +} diff --git a/plugins/dummy_sha256/src/hipcheck_transport.rs b/plugins/dummy_sha256/src/hipcheck_transport.rs new file mode 100644 index 00000000..5d50c0f4 --- /dev/null +++ b/plugins/dummy_sha256/src/hipcheck_transport.rs @@ -0,0 +1,226 @@ +use crate::hipcheck::{Query as PluginQuery, QueryState}; +use anyhow::{anyhow, Result}; +use indexmap::map::IndexMap; +use serde_json::Value; +use std::collections::VecDeque; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tonic::{codec::Streaming, Status}; + +#[derive(Debug)] +pub struct Query { + pub id: usize, + // if false, response + pub request: bool, + pub publisher: String, + pub plugin: String, + pub query: String, + pub key: Value, + pub output: Value, +} 
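+
+// On the wire (`PluginQuery`) the `key` and `output` fields travel as JSON
+// strings, while this in-memory form holds them as parsed
+// `serde_json::Value`s. A hand-built request (field values here are
+// illustrative) looks like:
+//
+// let q = Query {
+//     id: 2,
+//     request: true,
+//     publisher: "MITRE".to_owned(),
+//     plugin: "sha256".to_owned(),
+//     query: "sha256".to_owned(),
+//     key: serde_json::json!([1, 2, 3]),
+//     output: serde_json::json!(null),
+// };
+//
+// `PluginQuery::try_from(q)` then serializes `key` to the string "[1,2,3]"
+// and maps `request: true` to `QueryState::QuerySubmit`.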
+impl TryFrom for Query { + type Error = anyhow::Error; + fn try_from(value: PluginQuery) -> Result { + use QueryState::*; + let request = match TryInto::::try_into(value.state)? { + QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), + QueryReplyInProgress => { + return Err(anyhow!( + "invalid state QueryReplyInProgress for conversion to Query" + )) + } + QueryReplyComplete => false, + QuerySubmit => true, + }; + let key: Value = serde_json::from_str(value.key.as_str())?; + let output: Value = serde_json::from_str(value.output.as_str())?; + Ok(Query { + id: value.id as usize, + request, + publisher: value.publisher_name, + plugin: value.plugin_name, + query: value.query_name, + key, + output, + }) + } +} +impl TryFrom for PluginQuery { + type Error = anyhow::Error; + fn try_from(value: Query) -> Result { + let state_enum = match value.request { + true => QueryState::QuerySubmit, + false => QueryState::QueryReplyComplete, + }; + let key = serde_json::to_string(&value.key)?; + let output = serde_json::to_string(&value.output)?; + Ok(PluginQuery { + id: value.id as i32, + state: state_enum as i32, + publisher_name: value.publisher, + plugin_name: value.plugin, + query_name: value.query, + key, + output, + }) + } +} + +#[derive(Clone, Debug)] +pub struct HcTransport { + tx: mpsc::Sender>, + rx: Arc>, +} +impl HcTransport { + pub fn new(rx: Streaming, tx: mpsc::Sender>) -> Self { + HcTransport { + rx: Arc::new(Mutex::new(MultiplexedQueryReceiver::new(rx))), + tx, + } + } + pub async fn send(&self, query: Query) -> Result<()> { + let query: PluginQuery = query.try_into()?; + self.tx.send(Ok(query)).await?; + Ok(()) + } + pub async fn recv_new(&self) -> Result> { + let mut rx_handle = self.rx.lock().await; + match rx_handle.recv_new().await? { + Some(msg) => msg.try_into().map(Some), + None => Ok(None), + } + } + pub async fn recv(&self, id: usize) -> Result> { + use QueryState::*; + let id = id as i32; + let mut rx_handle = self.rx.lock().await; + let Some(mut msg_chunks) = rx_handle.recv(id).await? else { + return Ok(None); + }; + drop(rx_handle); + let mut raw = msg_chunks.pop_front().unwrap(); + let mut state: QueryState = raw.state.try_into()?; + + // If response is the first of a set of chunks, handle + if matches!(state, QueryReplyInProgress) { + while matches!(state, QueryReplyInProgress) { + // We expect another message. Pull it off the existing queue, + // or get a new one if we have run out + let next = match msg_chunks.pop_front() { + Some(msg) => msg, + None => { + // We ran out of messages, get a new batch + let mut rx_handle = self.rx.lock().await; + match rx_handle.recv(id).await? 
{ + Some(x) => { + drop(rx_handle); + msg_chunks = x; + } + None => { + return Ok(None); + } + }; + msg_chunks.pop_front().unwrap() + } + }; + // By now we have our "next" message + state = next.state.try_into()?; + match state { + QueryUnspecified => return Err(anyhow!("unspecified error from plugin")), + QuerySubmit => { + return Err(anyhow!( + "plugin sent QuerySubmit state when reply chunk expected" + )) + } + QueryReplyInProgress | QueryReplyComplete => { + raw.output.push_str(next.output.as_str()); + } + }; + } + // Sanity check - after we've left this loop, there should be no left over message + if !msg_chunks.is_empty() { + return Err(anyhow!( + "received additional messages for id '{}' after QueryComplete status message", + id + )); + } + } + raw.try_into().map(Some) + } +} + +#[derive(Debug)] +pub struct MultiplexedQueryReceiver { + rx: Streaming, + // Unlike in HipCheck, backlog is an IndexMap to ensure the earliest received + // requests are handled first + backlog: IndexMap>, +} +impl MultiplexedQueryReceiver { + pub fn new(rx: Streaming) -> Self { + Self { + rx, + backlog: IndexMap::new(), + } + } + pub async fn recv_new(&mut self) -> Result> { + let opt_unhandled = self.backlog.iter().find(|(k, v)| { + if let Some(req) = v.front() { + return req.state() == QueryState::QuerySubmit; + } + false + }); + if let Some((k, v)) = opt_unhandled { + let id: i32 = *k; + let mut vec = self.backlog.shift_remove(&id).unwrap(); + // @Note - for now QuerySubmit doesn't chunk so we shouldn't expect + // multiple messages in the backlog for a new request + assert!(vec.len() == 1); + return Ok(vec.pop_front()); + } + // No backlog message, need to operate the receiver + loop { + let Some(raw) = self.rx.message().await? else { + // gRPC channel was closed + return Ok(None); + }; + if raw.state() == QueryState::QuerySubmit { + return Ok(Some(raw)); + } + match self.backlog.get_mut(&raw.id) { + Some(vec) => { + vec.push_back(raw); + } + None => { + self.backlog.insert(raw.id, VecDeque::from([raw])); + } + } + } + } + // @Invariant - this function will never return an empty VecDeque + pub async fn recv(&mut self, id: i32) -> Result>> { + // If we have 1+ messages on backlog for `id`, return them all, + // no need to waste time with successive calls + if let Some(msgs) = self.backlog.shift_remove(&id) { + return Ok(Some(msgs)); + } + // No backlog message, need to operate the receiver + loop { + let Some(raw) = self.rx.message().await? 
else { + // gRPC channel was closed + return Ok(None); + }; + if raw.id == id { + return Ok(Some(VecDeque::from([raw]))); + } + match self.backlog.get_mut(&raw.id) { + Some(vec) => { + vec.push_back(raw); + } + None => { + self.backlog.insert(raw.id, VecDeque::from([raw])); + } + } + } + } +} diff --git a/plugins/dummy_sha256/src/main.rs b/plugins/dummy_sha256/src/main.rs new file mode 100644 index 00000000..88d87d48 --- /dev/null +++ b/plugins/dummy_sha256/src/main.rs @@ -0,0 +1,181 @@ +#![allow(unused_variables)] + +mod hipcheck; +mod hipcheck_transport; + +use crate::hipcheck_transport::*; +use anyhow::{anyhow, Result}; +use clap::Parser; +use hipcheck::plugin_server::{Plugin, PluginServer}; +use hipcheck::{ + Configuration, ConfigurationResult, ConfigurationStatus, Empty, PolicyExpression, + Query as PluginQuery, Schema, +}; +use serde_json::{json, Value}; +use sha2::{Digest, Sha256}; +use std::pin::Pin; +use tokio::sync::mpsc; +use tokio_stream::{wrappers::ReceiverStream, Stream}; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +static SHA256_KEY_SCHEMA: &str = include_str!("query_schema_sha256.json"); +static SHA256_OUTPUT_SCHEMA: &str = include_str!("query_schema_sha256.json"); + +fn sha256(content: Vec) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(content); + hasher.finalize().to_vec() +} + +pub async fn handle_sha256(channel: HcTransport, id: usize, key: Vec) -> Result<()> { + println!("SHA256-{id}: Key: {key:02x?}"); + let res = sha256(key); + println!("SHA256-{id}: Hash: {res:02x?}"); + let output = serde_json::to_value(res)?; + let resp = Query { + id, + request: false, + publisher: "".to_owned(), + plugin: "".to_owned(), + query: "".to_owned(), + key: json!(null), + output, + }; + channel.send(resp).await?; + Ok(()) +} +struct Sha256Runner { + channel: HcTransport, +} +impl Sha256Runner { + pub fn new(channel: HcTransport) -> Self { + Sha256Runner { channel } + } + async fn handle_query(channel: HcTransport, id: usize, name: String, key: Value) -> Result<()> { + if name == "sha256" { + let Value::Array(val_vec) = &key else { + return Err(anyhow!("get_rand argument must be a number")); + }; + let byte_vec = val_vec + .iter() + .map(|x| { + let Value::Number(val_byte) = x else { + return Err(anyhow!("expected all integers")); + }; + let Some(byte) = val_byte.as_u64() else { + return Err(anyhow!( + "sha256 input array must contain only unsigned integers" + )); + }; + Ok(byte as u8) + }) + .collect::>>()?; + handle_sha256(channel, id, byte_vec).await?; + Ok(()) + } else { + Err(anyhow!("unrecognized query '{}'", name)) + } + } + pub async fn run(self) -> Result<()> { + loop { + eprintln!("SHA256: Looping"); + let Some(msg) = self.channel.recv_new().await? 
else { + eprintln!("Channel closed by remote"); + break; + }; + if msg.request { + let child_channel = self.channel.clone(); + tokio::spawn(async move { + if let Err(e) = + Sha256Runner::handle_query(child_channel, msg.id, msg.query, msg.key).await + { + eprintln!("handle_query failed: {e}"); + }; + }); + } else { + return Err(anyhow!("Did not expect a response-type message here")); + } + } + Ok(()) + } +} + +#[derive(Debug)] +struct RandDataPlugin { + pub schema: Schema, +} +impl RandDataPlugin { + pub fn new() -> Self { + let schema = Schema { + query_name: "sha256".to_owned(), + key_schema: SHA256_KEY_SCHEMA.to_owned(), + output_schema: SHA256_OUTPUT_SCHEMA.to_owned(), + }; + RandDataPlugin { schema } + } +} + +#[tonic::async_trait] +impl Plugin for RandDataPlugin { + type GetQuerySchemasStream = + Pin> + Send + 'static>>; + type InitiateQueryProtocolStream = ReceiverStream>; + async fn get_query_schemas( + &self, + _request: Request, + ) -> Result, Status> { + Ok(Response::new(Box::pin(tokio_stream::iter(vec![Ok(self + .schema + .clone())])))) + } + async fn set_configuration( + &self, + request: Request, + ) -> Result, Status> { + Ok(Response::new(ConfigurationResult { + status: ConfigurationStatus::ErrorNone as i32, + message: "".to_owned(), + })) + } + async fn get_default_policy_expression( + &self, + request: Request, + ) -> Result, Status> { + Ok(Response::new(PolicyExpression { + policy_expression: "".to_owned(), + })) + } + async fn initiate_query_protocol( + &self, + request: Request>, + ) -> Result, Status> { + let rx = request.into_inner(); + let (tx, out_rx) = mpsc::channel::>(4); + tokio::spawn(async move { + let channel = HcTransport::new(rx, tx); + if let Err(e) = Sha256Runner::new(channel).run().await { + eprintln!("sha256 plugin ended in error: {e}"); + } + }); + Ok(Response::new(ReceiverStream::new(out_rx))) + } +} + +#[derive(Parser, Debug)] +struct Args { + #[arg(long)] + port: u16, +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let args = Args::try_parse().map_err(Box::new)?; + let addr = format!("127.0.0.1:{}", args.port); + let plugin = RandDataPlugin::new(); + let svc = PluginServer::new(plugin); + Server::builder() + .add_service(svc) + .serve(addr.parse().unwrap()) + .await?; + Ok(()) +} diff --git a/plugins/dummy_sha256/src/query_schema_sha256.json b/plugins/dummy_sha256/src/query_schema_sha256.json new file mode 100644 index 00000000..8b50ea30 --- /dev/null +++ b/plugins/dummy_sha256/src/query_schema_sha256.json @@ -0,0 +1,3 @@ +{ + "type": "integer" +}
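
For reference, the round trip implemented by the two dummy plugins can be
checked without the gRPC machinery. The sketch below mirrors
`handle_rand_data` above (reduce the key mod 7, hash the single reduced byte
with SHA-256 as the sha256 plugin does, then add the key's little-endian
bytes into the first eight hash bytes); it is a standalone re-derivation for
sanity-checking outputs, not part of the commit:

    use sha2::{Digest, Sha256};

    fn expected_rand_data_output(key: u64) -> Vec<u8> {
        // rand_data reduces its u64 key mod 7 and queries sha256 with that byte.
        let reduced = (key % 7) as u8;
        let mut hash = Sha256::digest([reduced]).to_vec();
        // It then adds the key's little-endian bytes into the first 8 of the
        // 32 hash bytes. (The plugin code uses plain `+=`; `wrapping_add`
        // here avoids the debug-mode overflow panic that addition could hit.)
        for (i, b) in key.to_le_bytes().into_iter().enumerate() {
            hash[i] = hash[i].wrapping_add(b);
        }
        hash
    }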