From 36a56145b25c3b18fbcd3af5b1f2ab71b521cba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petar=20Vujovi=C4=87?= Date: Tue, 23 Jul 2024 11:53:27 +0200 Subject: [PATCH] fix(lib,provers,tasks): move from sync to async trait (#328) --- Cargo.lock | 1 + lib/Cargo.toml | 1 + lib/src/prover.rs | 8 +++--- provers/risc0/driver/src/bonsai.rs | 2 +- provers/risc0/driver/src/lib.rs | 4 +-- provers/sp1/driver/src/lib.rs | 14 ++++++----- tasks/src/adv_sqlite.rs | 40 ++++++++++++++---------------- tasks/src/lib.rs | 20 ++++++++------- tasks/src/mem_db.rs | 37 +++++++++++---------------- 9 files changed, 62 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf17c3af3..b5dfebcd0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5800,6 +5800,7 @@ dependencies = [ "alloy-rpc-types", "alloy-sol-types", "anyhow", + "async-trait", "bincode", "cfg-if", "chrono", diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 75cb0d04f..3ce048f59 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -46,6 +46,7 @@ utoipa = { workspace = true } cfg-if = { workspace = true } tracing = { workspace = true } bincode = { workspace = true } +async-trait = { workspace = true } # [target.'cfg(feature = "std")'.dependencies] flate2 = { workspace = true, optional = true } diff --git a/lib/src/prover.rs b/lib/src/prover.rs index e43b511ce..a854409eb 100644 --- a/lib/src/prover.rs +++ b/lib/src/prover.rs @@ -37,14 +37,16 @@ pub struct Proof { pub kzg_proof: Option, } +#[async_trait::async_trait] pub trait IdWrite: Send { - fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>; + async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>; - fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>; + async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>; } +#[async_trait::async_trait] pub trait IdStore: IdWrite { - fn read_id(&self, key: ProofKey) -> ProverResult; + async fn read_id(&self, key: ProofKey) -> ProverResult; } #[allow(async_fn_in_trait)] diff --git a/provers/risc0/driver/src/bonsai.rs b/provers/risc0/driver/src/bonsai.rs index 7be4ea5d4..6f5071678 100644 --- a/provers/risc0/driver/src/bonsai.rs +++ b/provers/risc0/driver/src/bonsai.rs @@ -219,7 +219,7 @@ pub async fn prove_bonsai( )?; if let Some(id_store) = id_store { - id_store.store_id(proof_key, session.uuid.clone())?; + id_store.store_id(proof_key, session.uuid.clone()).await?; } verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await diff --git a/provers/risc0/driver/src/lib.rs b/provers/risc0/driver/src/lib.rs index 0269d8387..69ad6425f 100644 --- a/provers/risc0/driver/src/lib.rs +++ b/provers/risc0/driver/src/lib.rs @@ -112,11 +112,11 @@ impl Prover for Risc0Prover { } async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { - let uuid = id_store.read_id(key)?; + let uuid = id_store.read_id(key).await?; cancel_proof(uuid) .await .map_err(|e| ProverError::GuestError(e.to_string()))?; - id_store.remove_id(key)?; + id_store.remove_id(key).await?; Ok(()) } } diff --git a/provers/sp1/driver/src/lib.rs b/provers/sp1/driver/src/lib.rs index 42937bc42..a8dc96f39 100644 --- a/provers/sp1/driver/src/lib.rs +++ b/provers/sp1/driver/src/lib.rs @@ -69,10 +69,12 @@ impl Prover for Sp1Prover { ProverError::GuestError("Sp1: creating proof failed".to_owned()) })?; if let Some(id_store) = id_store { - id_store.store_id( - (input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE), - proof_id.clone(), - )?; + id_store + .store_id( + (input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE), + proof_id.clone(), + ) + .await?; } let proof = { let mut is_claimed = false; @@ -136,7 +138,7 @@ impl Prover for Sp1Prover { } async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> { - let proof_id = id_store.read_id(key)?; + let proof_id = id_store.read_id(key).await?; let private_key = env::var("SP1_PRIVATE_KEY").map_err(|_| { ProverError::GuestError("SP1_PRIVATE_KEY must be set for remote proving".to_owned()) })?; @@ -145,7 +147,7 @@ impl Prover for Sp1Prover { .unclaim_proof(proof_id, UnclaimReason::Abandoned, "".to_owned()) .await .map_err(|_| ProverError::GuestError("Sp1: couldn't unclaim proof".to_owned()))?; - id_store.remove_id(key)?; + id_store.remove_id(key).await?; Ok(()) } } diff --git a/tasks/src/adv_sqlite.rs b/tasks/src/adv_sqlite.rs index f362bbb74..ea6811807 100644 --- a/tasks/src/adv_sqlite.rs +++ b/tasks/src/adv_sqlite.rs @@ -166,7 +166,7 @@ use raiko_lib::{ use rusqlite::{ named_params, {Connection, OpenFlags}, }; -use tokio::{runtime::Builder, sync::Mutex}; +use tokio::sync::Mutex; use crate::{ TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, @@ -833,34 +833,30 @@ impl TaskDb { } } +#[async_trait::async_trait] impl IdWrite for SqliteTaskManager { - fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let task_db = self.arc_task_db.lock().await; - task_db.store_id(key, id) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + let task_db = self.arc_task_db.lock().await; + task_db + .store_id(key, id) + .map_err(|e| ProverError::StoreError(e.to_string())) } - fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let task_db = self.arc_task_db.lock().await; - task_db.remove_id(key) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + let task_db = self.arc_task_db.lock().await; + task_db + .remove_id(key) + .map_err(|e| ProverError::StoreError(e.to_string())) } } +#[async_trait::async_trait] impl IdStore for SqliteTaskManager { - fn read_id(&self, key: ProofKey) -> ProverResult { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let task_db = self.arc_task_db.lock().await; - task_db.read_id(key) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn read_id(&self, key: ProofKey) -> ProverResult { + let task_db = self.arc_task_db.lock().await; + task_db + .read_id(key) + .map_err(|e| ProverError::StoreError(e.to_string())) } } diff --git a/tasks/src/lib.rs b/tasks/src/lib.rs index 728ff7475..a241f3d16 100644 --- a/tasks/src/lib.rs +++ b/tasks/src/lib.rs @@ -179,27 +179,29 @@ pub struct TaskManagerWrapper { manager: TaskManagerInstance, } +#[async_trait::async_trait] impl IdWrite for TaskManagerWrapper { - fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { match &mut self.manager { - TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id), - TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id), + TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id).await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id).await, } } - fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { match &mut self.manager { - TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key), - TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key), + TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key).await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key).await, } } } +#[async_trait::async_trait] impl IdStore for TaskManagerWrapper { - fn read_id(&self, key: ProofKey) -> ProverResult { + async fn read_id(&self, key: ProofKey) -> ProverResult { match &self.manager { - TaskManagerInstance::InMemory(manager) => manager.read_id(key), - TaskManagerInstance::Sqlite(manager) => manager.read_id(key), + TaskManagerInstance::InMemory(manager) => manager.read_id(key).await, + TaskManagerInstance::Sqlite(manager) => manager.read_id(key).await, } } } diff --git a/tasks/src/mem_db.rs b/tasks/src/mem_db.rs index 13137760f..15832574e 100644 --- a/tasks/src/mem_db.rs +++ b/tasks/src/mem_db.rs @@ -14,7 +14,7 @@ use std::{ use chrono::Utc; use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}; -use tokio::{runtime::Builder, sync::Mutex}; +use tokio::sync::Mutex; use tracing::{debug, info}; use crate::{ @@ -143,34 +143,27 @@ impl InMemoryTaskDb { } } +#[async_trait::async_trait] impl IdWrite for InMemoryTaskManager { - fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let mut db = self.db.lock().await; - db.store_id(key, id) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> { + let mut db = self.db.lock().await; + db.store_id(key, id) + .map_err(|e| ProverError::StoreError(e.to_string())) } - fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let mut db = self.db.lock().await; - db.remove_id(key) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> { + let mut db = self.db.lock().await; + db.remove_id(key) + .map_err(|e| ProverError::StoreError(e.to_string())) } } +#[async_trait::async_trait] impl IdStore for InMemoryTaskManager { - fn read_id(&self, key: ProofKey) -> ProverResult { - let rt = Builder::new_current_thread().enable_all().build()?; - rt.block_on(async move { - let mut db = self.db.lock().await; - db.read_id(key) - }) - .map_err(|e| ProverError::StoreError(e.to_string())) + async fn read_id(&self, key: ProofKey) -> ProverResult { + let mut db = self.db.lock().await; + db.read_id(key) + .map_err(|e| ProverError::StoreError(e.to_string())) } }