Skip to content

Commit

Permalink
fix(lib,provers,tasks): move from sync to async trait (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
petarvujovic98 authored Jul 23, 2024
1 parent 959bdea commit 36a5614
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 65 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ utoipa = { workspace = true }
cfg-if = { workspace = true }
tracing = { workspace = true }
bincode = { workspace = true }
async-trait = { workspace = true }

# [target.'cfg(feature = "std")'.dependencies]
flate2 = { workspace = true, optional = true }
Expand Down
8 changes: 5 additions & 3 deletions lib/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ pub struct Proof {
pub kzg_proof: Option<String>,
}

#[async_trait::async_trait]
pub trait IdWrite: Send {
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>;
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()>;

fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>;
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()>;
}

#[async_trait::async_trait]
pub trait IdStore: IdWrite {
fn read_id(&self, key: ProofKey) -> ProverResult<String>;
async fn read_id(&self, key: ProofKey) -> ProverResult<String>;
}

#[allow(async_fn_in_trait)]
Expand Down
2 changes: 1 addition & 1 deletion provers/risc0/driver/src/bonsai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ pub async fn prove_bonsai<O: Eq + Debug + DeserializeOwned>(
)?;

if let Some(id_store) = id_store {
id_store.store_id(proof_key, session.uuid.clone())?;
id_store.store_id(proof_key, session.uuid.clone()).await?;
}

verify_bonsai_receipt(image_id, expected_output, session.uuid.clone(), 8).await
Expand Down
4 changes: 2 additions & 2 deletions provers/risc0/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ impl Prover for Risc0Prover {
}

async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> {
let uuid = id_store.read_id(key)?;
let uuid = id_store.read_id(key).await?;
cancel_proof(uuid)
.await
.map_err(|e| ProverError::GuestError(e.to_string()))?;
id_store.remove_id(key)?;
id_store.remove_id(key).await?;
Ok(())
}
}
Expand Down
14 changes: 8 additions & 6 deletions provers/sp1/driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ impl Prover for Sp1Prover {
ProverError::GuestError("Sp1: creating proof failed".to_owned())
})?;
if let Some(id_store) = id_store {
id_store.store_id(
(input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE),
proof_id.clone(),
)?;
id_store
.store_id(
(input.chain_spec.chain_id, output.hash, SP1_PROVER_CODE),
proof_id.clone(),
)
.await?;
}
let proof = {
let mut is_claimed = false;
Expand Down Expand Up @@ -136,7 +138,7 @@ impl Prover for Sp1Prover {
}

async fn cancel(key: ProofKey, id_store: Box<&mut dyn IdStore>) -> ProverResult<()> {
let proof_id = id_store.read_id(key)?;
let proof_id = id_store.read_id(key).await?;
let private_key = env::var("SP1_PRIVATE_KEY").map_err(|_| {
ProverError::GuestError("SP1_PRIVATE_KEY must be set for remote proving".to_owned())
})?;
Expand All @@ -145,7 +147,7 @@ impl Prover for Sp1Prover {
.unclaim_proof(proof_id, UnclaimReason::Abandoned, "".to_owned())
.await
.map_err(|_| ProverError::GuestError("Sp1: couldn't unclaim proof".to_owned()))?;
id_store.remove_id(key)?;
id_store.remove_id(key).await?;
Ok(())
}
}
Expand Down
40 changes: 18 additions & 22 deletions tasks/src/adv_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ use raiko_lib::{
use rusqlite::{
named_params, {Connection, OpenFlags},
};
use tokio::{runtime::Builder, sync::Mutex};
use tokio::sync::Mutex;

use crate::{
TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult,
Expand Down Expand Up @@ -833,34 +833,30 @@ impl TaskDb {
}
}

#[async_trait::async_trait]
impl IdWrite for SqliteTaskManager {
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let task_db = self.arc_task_db.lock().await;
task_db.store_id(key, id)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
let task_db = self.arc_task_db.lock().await;
task_db
.store_id(key, id)
.map_err(|e| ProverError::StoreError(e.to_string()))
}

fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let task_db = self.arc_task_db.lock().await;
task_db.remove_id(key)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
let task_db = self.arc_task_db.lock().await;
task_db
.remove_id(key)
.map_err(|e| ProverError::StoreError(e.to_string()))
}
}

#[async_trait::async_trait]
impl IdStore for SqliteTaskManager {
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let task_db = self.arc_task_db.lock().await;
task_db.read_id(key)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
let task_db = self.arc_task_db.lock().await;
task_db
.read_id(key)
.map_err(|e| ProverError::StoreError(e.to_string()))
}
}

Expand Down
20 changes: 11 additions & 9 deletions tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,27 +179,29 @@ pub struct TaskManagerWrapper {
manager: TaskManagerInstance,
}

#[async_trait::async_trait]
impl IdWrite for TaskManagerWrapper {
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
match &mut self.manager {
TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id),
TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id),
TaskManagerInstance::InMemory(ref mut manager) => manager.store_id(key, id).await,
TaskManagerInstance::Sqlite(ref mut manager) => manager.store_id(key, id).await,
}
}

fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
match &mut self.manager {
TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key),
TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key),
TaskManagerInstance::InMemory(ref mut manager) => manager.remove_id(key).await,
TaskManagerInstance::Sqlite(ref mut manager) => manager.remove_id(key).await,
}
}
}

#[async_trait::async_trait]
impl IdStore for TaskManagerWrapper {
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
match &self.manager {
TaskManagerInstance::InMemory(manager) => manager.read_id(key),
TaskManagerInstance::Sqlite(manager) => manager.read_id(key),
TaskManagerInstance::InMemory(manager) => manager.read_id(key).await,
TaskManagerInstance::Sqlite(manager) => manager.read_id(key).await,
}
}
}
Expand Down
37 changes: 15 additions & 22 deletions tasks/src/mem_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use std::{

use chrono::Utc;
use raiko_lib::prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult};
use tokio::{runtime::Builder, sync::Mutex};
use tokio::sync::Mutex;
use tracing::{debug, info};

use crate::{
Expand Down Expand Up @@ -143,34 +143,27 @@ impl InMemoryTaskDb {
}
}

#[async_trait::async_trait]
impl IdWrite for InMemoryTaskManager {
fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let mut db = self.db.lock().await;
db.store_id(key, id)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn store_id(&mut self, key: ProofKey, id: String) -> ProverResult<()> {
let mut db = self.db.lock().await;
db.store_id(key, id)
.map_err(|e| ProverError::StoreError(e.to_string()))
}

fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let mut db = self.db.lock().await;
db.remove_id(key)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn remove_id(&mut self, key: ProofKey) -> ProverResult<()> {
let mut db = self.db.lock().await;
db.remove_id(key)
.map_err(|e| ProverError::StoreError(e.to_string()))
}
}

#[async_trait::async_trait]
impl IdStore for InMemoryTaskManager {
fn read_id(&self, key: ProofKey) -> ProverResult<String> {
let rt = Builder::new_current_thread().enable_all().build()?;
rt.block_on(async move {
let mut db = self.db.lock().await;
db.read_id(key)
})
.map_err(|e| ProverError::StoreError(e.to_string()))
async fn read_id(&self, key: ProofKey) -> ProverResult<String> {
let mut db = self.db.lock().await;
db.read_id(key)
.map_err(|e| ProverError::StoreError(e.to_string()))
}
}

Expand Down

0 comments on commit 36a5614

Please sign in to comment.