diff --git a/.github/workflows/ci-prover-e2e.yml b/.github/workflows/ci-prover-e2e.yml index 6076874c3710..7d037e0ab73c 100644 --- a/.github/workflows/ci-prover-e2e.yml +++ b/.github/workflows/ci-prover-e2e.yml @@ -86,7 +86,7 @@ jobs: ci_run zkstack prover run --component=witness-generator --round=all-rounds --docker=false &>prover_logs/witness-generator.log & - name: Run Circuit Prover run: | - ci_run zkstack prover run --component=circuit-prover --witness-vector-generator-count=10 --docker=false &>prover_logs/circuit_prover.log & + ci_run zkstack prover run --component=circuit-prover -l=23 -h=3 --docker=false &>prover_logs/circuit_prover.log & - name: Wait for prover jobs to finish env: DATABASE_URL: postgres://postgres:notsecurepassword@localhost:5432/zksync_prover_localhost_proving_chain diff --git a/core/lib/basic_types/src/prover_dal.rs b/core/lib/basic_types/src/prover_dal.rs index d86f79ba77aa..d2af75fe2ff5 100644 --- a/core/lib/basic_types/src/prover_dal.rs +++ b/core/lib/basic_types/src/prover_dal.rs @@ -1,5 +1,5 @@ //! Types exposed by the prover DAL for general-purpose use. -use std::{net::IpAddr, ops::Add, str::FromStr}; +use std::{net::IpAddr, ops::Add, str::FromStr, time::Instant}; use chrono::{DateTime, Duration, NaiveDateTime, NaiveTime, Utc}; use serde::{Deserialize, Serialize}; @@ -18,6 +18,23 @@ pub struct FriProverJobMetadata { pub sequence_number: usize, pub depth: u16, pub is_node_final_proof: bool, + pub pick_time: Instant, +} + +impl FriProverJobMetadata { + /// Checks whether the metadata corresponds to a scheduler proof or not. + pub fn is_scheduler_proof(&self) -> anyhow::Result { + if self.aggregation_round == AggregationRound::Scheduler { + if self.circuit_id != 1 { + return Err(anyhow::anyhow!( + "Invalid circuit id {} for Scheduler proof", + self.circuit_id + )); + } + return Ok(true); + } + Ok(false) + } } #[derive(Debug, Clone, Copy, Default)] diff --git a/prover/Cargo.lock b/prover/Cargo.lock index a60f77d44dd7..af249b435a6b 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -6511,9 +6511,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" dependencies = [ "futures-core", "pin-project-lite", @@ -7868,6 +7868,7 @@ dependencies = [ "tracing", "vise", "zkevm_test_harness", + "zksync_circuit_prover_service", "zksync_config", "zksync_core_leftovers", "zksync_env_config", @@ -7875,12 +7876,33 @@ dependencies = [ "zksync_prover_dal", "zksync_prover_fri_types", "zksync_prover_fri_utils", + "zksync_prover_job_processor", "zksync_prover_keystore", "zksync_queued_job_processor", "zksync_types", "zksync_utils", ] +[[package]] +name = "zksync_circuit_prover_service" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "shivini", + "tokio", + "tokio-util", + "tracing", + "vise", + "zkevm_test_harness", + "zksync_object_store", + "zksync_prover_dal", + "zksync_prover_fri_types", + "zksync_prover_job_processor", + "zksync_prover_keystore", + "zksync_types", +] + [[package]] name = "zksync_concurrency" version = "0.5.0" @@ -8533,6 +8555,21 @@ dependencies = [ "zksync_vlog", ] +[[package]] +name = "zksync_prover_job_processor" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "futures 0.3.30", + "strum", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "vise", +] + [[package]] name = "zksync_prover_keystore" version = "0.1.0" diff --git a/prover/Cargo.toml b/prover/Cargo.toml index e53efaae1968..15e819d77f7d 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -58,6 +58,7 @@ strum_macros = "0.26" tempfile = "3" tokio = "1" tokio-util = "0.7.11" +tokio-stream = "0.1.16" toml_edit = "0.14.4" tracing = "0.1" tracing-subscriber = "0.3" @@ -100,6 +101,8 @@ zksync_prover_fri_types = { path = "crates/lib/prover_fri_types" } zksync_prover_fri_utils = { path = "crates/lib/prover_fri_utils" } zksync_prover_keystore = { path = "crates/lib/keystore" } zksync_vk_setup_data_generator_server_fri = { path = "crates/bin/vk_setup_data_generator_server_fri" } +zksync_prover_job_processor = { path = "crates/lib/prover_job_processor" } +zksync_circuit_prover_service = { path = "crates/lib/circuit_prover_service" } zksync_prover_job_monitor = { path = "crates/bin/prover_job_monitor" } # for `perf` profiling diff --git a/prover/crates/bin/circuit_prover/Cargo.toml b/prover/crates/bin/circuit_prover/Cargo.toml index a5751a4cd9a6..d7b7a8ca80fd 100644 --- a/prover/crates/bin/circuit_prover/Cargo.toml +++ b/prover/crates/bin/circuit_prover/Cargo.toml @@ -1,5 +1,6 @@ [package] name = "zksync_circuit_prover" +description = "ZKsync circuit prover binary implementation" version.workspace = true edition.workspace = true authors.workspace = true @@ -8,6 +9,7 @@ repository.workspace = true license.workspace = true keywords.workspace = true categories.workspace = true +publish = false [dependencies] tokio = { workspace = true, features = ["macros", "time"] } @@ -29,6 +31,8 @@ zksync_prover_keystore = { workspace = true, features = ["gpu"] } zksync_env_config.workspace = true zksync_core_leftovers.workspace = true zksync_utils.workspace = true +zksync_circuit_prover_service.workspace = true +zksync_prover_job_processor.workspace = true vise.workspace = true shivini = { workspace = true, features = [ diff --git a/prover/crates/bin/circuit_prover/src/circuit_prover.rs b/prover/crates/bin/circuit_prover/src/circuit_prover.rs deleted file mode 100644 index 1a5f8aa0d974..000000000000 --- a/prover/crates/bin/circuit_prover/src/circuit_prover.rs +++ /dev/null @@ -1,397 +0,0 @@ -use std::{sync::Arc, time::Instant}; - -use anyhow::Context; -use shivini::{ - gpu_proof_config::GpuProofConfig, gpu_prove_from_external_witness_data, ProverContext, - ProverContextConfig, -}; -use tokio::{sync::mpsc::Receiver, task::JoinHandle}; -use tokio_util::sync::CancellationToken; -use zkevm_test_harness::prover_utils::{verify_base_layer_proof, verify_recursion_layer_proof}; -use zksync_object_store::ObjectStore; -use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; -use zksync_prover_fri_types::{ - circuit_definitions::{ - base_layer_proof_config, - boojum::{ - cs::implementations::{pow::NoPow, witness::WitnessVec}, - field::goldilocks::GoldilocksField, - worker::Worker, - }, - circuit_definitions::{ - base_layer::ZkSyncBaseLayerProof, recursion_layer::ZkSyncRecursionLayerProof, - }, - recursion_layer_proof_config, - }, - CircuitWrapper, FriProofWrapper, ProverArtifacts, WitnessVectorArtifactsTemp, -}; -use zksync_prover_keystore::GoldilocksGpuProverSetupData; -use zksync_types::protocol_version::ProtocolSemanticVersion; -use zksync_utils::panic_extractor::try_extract_panic_message; - -use crate::{ - metrics::CIRCUIT_PROVER_METRICS, - types::{DefaultTranscript, DefaultTreeHasher, Proof, VerificationKey}, - SetupDataCache, -}; - -/// In charge of proving circuits, given a Witness Vector source. -/// Both job runner & job executor. -#[derive(Debug)] -pub struct CircuitProver { - connection_pool: ConnectionPool, - object_store: Arc, - protocol_version: ProtocolSemanticVersion, - /// Witness Vector source receiver - receiver: Receiver, - /// Setup Data used for proving & proof verification - setup_data_cache: SetupDataCache, -} - -impl CircuitProver { - pub fn new( - connection_pool: ConnectionPool, - object_store: Arc, - protocol_version: ProtocolSemanticVersion, - receiver: Receiver, - max_allocation: Option, - setup_data_cache: SetupDataCache, - ) -> anyhow::Result<(Self, ProverContext)> { - // VRAM allocation - let prover_context = match max_allocation { - Some(max_allocation) => ProverContext::create_with_config( - ProverContextConfig::default().with_maximum_device_allocation(max_allocation), - ) - .context("failed initializing fixed gpu prover context")?, - None => ProverContext::create().context("failed initializing gpu prover context")?, - }; - Ok(( - Self { - connection_pool, - object_store, - protocol_version, - receiver, - setup_data_cache, - }, - prover_context, - )) - } - - /// Continuously polls `receiver` for Witness Vectors and proves them. - /// All job executions are persisted. - pub async fn run(mut self, cancellation_token: CancellationToken) -> anyhow::Result<()> { - while !cancellation_token.is_cancelled() { - let time = Instant::now(); - - let artifact = self - .receiver - .recv() - .await - .context("no Witness Vector Generators are available")?; - tracing::info!( - "Circuit Prover received job {:?} after: {:?}", - artifact.prover_job.job_id, - time.elapsed() - ); - CIRCUIT_PROVER_METRICS.job_wait_time.observe(time.elapsed()); - - self.prove(artifact, cancellation_token.clone()) - .await - .context("failed to prove circuit proof")?; - } - tracing::info!("Circuit Prover shut down."); - Ok(()) - } - - /// Proves a job, with persistence of execution. - async fn prove( - &self, - artifact: WitnessVectorArtifactsTemp, - cancellation_token: CancellationToken, - ) -> anyhow::Result<()> { - let time = Instant::now(); - let block_number = artifact.prover_job.block_number; - let job_id = artifact.prover_job.job_id; - let job_start_time = artifact.time; - let setup_data_key = artifact.prover_job.setup_data_key.crypto_setup_key(); - let setup_data = self - .setup_data_cache - .get(&setup_data_key) - .context(format!( - "failed to get setup data for key {setup_data_key:?}" - ))? - .clone(); - let task = tokio::task::spawn_blocking(move || { - let _span = tracing::info_span!("prove_circuit_proof", %block_number).entered(); - Self::prove_circuit_proof(artifact, setup_data).context("failed to prove circuit") - }); - - self.finish_task( - job_id, - time, - job_start_time, - task, - cancellation_token.clone(), - ) - .await?; - tracing::info!( - "Circuit Prover finished job {:?} in: {:?}", - job_id, - time.elapsed() - ); - CIRCUIT_PROVER_METRICS - .job_finished_time - .observe(time.elapsed()); - CIRCUIT_PROVER_METRICS - .full_proving_time - .observe(job_start_time.elapsed()); - Ok(()) - } - - /// Proves a job using crypto primitives (proof generation & proof verification). - #[tracing::instrument( - name = "Prover::prove_circuit_proof", - skip_all, - fields(l1_batch = % witness_vector_artifacts.prover_job.block_number) - )] - pub fn prove_circuit_proof( - witness_vector_artifacts: WitnessVectorArtifactsTemp, - setup_data: Arc, - ) -> anyhow::Result { - let time = Instant::now(); - let WitnessVectorArtifactsTemp { - witness_vector, - prover_job, - .. - } = witness_vector_artifacts; - - let job_id = prover_job.job_id; - let circuit_wrapper = prover_job.circuit_wrapper; - let block_number = prover_job.block_number; - - let (proof, circuit_id) = - Self::generate_proof(&circuit_wrapper, witness_vector, &setup_data) - .context(format!("failed to generate proof for job id {job_id}"))?; - - Self::verify_proof(&circuit_wrapper, &proof, &setup_data.vk).context(format!( - "failed to verify proof with job_id {job_id}, circuit_id: {circuit_id}" - ))?; - - let proof_wrapper = match &circuit_wrapper { - CircuitWrapper::Base(_) => { - FriProofWrapper::Base(ZkSyncBaseLayerProof::from_inner(circuit_id, proof)) - } - CircuitWrapper::Recursive(_) => { - FriProofWrapper::Recursive(ZkSyncRecursionLayerProof::from_inner(circuit_id, proof)) - } - CircuitWrapper::BasePartial(_) => { - return Self::partial_proof_error(); - } - }; - CIRCUIT_PROVER_METRICS - .crypto_primitives_time - .observe(time.elapsed()); - Ok(ProverArtifacts::new(block_number, proof_wrapper)) - } - - /// Generates a proof from crypto primitives. - fn generate_proof( - circuit_wrapper: &CircuitWrapper, - witness_vector: WitnessVec, - setup_data: &Arc, - ) -> anyhow::Result<(Proof, u8)> { - let time = Instant::now(); - - let worker = Worker::new(); - - let (gpu_proof_config, proof_config, circuit_id) = match circuit_wrapper { - CircuitWrapper::Base(circuit) => ( - GpuProofConfig::from_base_layer_circuit(circuit), - base_layer_proof_config(), - circuit.numeric_circuit_type(), - ), - CircuitWrapper::Recursive(circuit) => ( - GpuProofConfig::from_recursive_layer_circuit(circuit), - recursion_layer_proof_config(), - circuit.numeric_circuit_type(), - ), - CircuitWrapper::BasePartial(_) => { - return Self::partial_proof_error(); - } - }; - - let proof = - gpu_prove_from_external_witness_data::( - &gpu_proof_config, - &witness_vector, - proof_config, - &setup_data.setup, - &setup_data.vk, - (), - &worker, - ) - .context("crypto primitive: failed to generate proof")?; - CIRCUIT_PROVER_METRICS - .generate_proof_time - .observe(time.elapsed()); - Ok((proof.into(), circuit_id)) - } - - /// Verifies a proof from crypto primitives - fn verify_proof( - circuit_wrapper: &CircuitWrapper, - proof: &Proof, - verification_key: &VerificationKey, - ) -> anyhow::Result<()> { - let time = Instant::now(); - - let is_valid = match circuit_wrapper { - CircuitWrapper::Base(base_circuit) => { - verify_base_layer_proof::(base_circuit, proof, verification_key) - } - CircuitWrapper::Recursive(recursive_circuit) => { - verify_recursion_layer_proof::(recursive_circuit, proof, verification_key) - } - CircuitWrapper::BasePartial(_) => { - return Self::partial_proof_error(); - } - }; - - CIRCUIT_PROVER_METRICS - .verify_proof_time - .observe(time.elapsed()); - - if !is_valid { - return Err(anyhow::anyhow!("crypto primitive: failed to verify proof")); - } - Ok(()) - } - - /// This code path should never trigger. All proofs are hydrated during Witness Vector Generator. - /// If this triggers, it means that proof hydration in Witness Vector Generator was not done -- logic bug. - fn partial_proof_error() -> anyhow::Result { - Err(anyhow::anyhow!("received unexpected dehydrated proof")) - } - - /// Runs task to completion and persists result. - /// NOTE: Task may be cancelled mid-flight. - async fn finish_task( - &self, - job_id: u32, - time: Instant, - job_start_time: Instant, - task: JoinHandle>, - cancellation_token: CancellationToken, - ) -> anyhow::Result<()> { - tokio::select! { - _ = cancellation_token.cancelled() => { - tracing::info!("Stop signal received, shutting down Circuit Prover..."); - return Ok(()) - } - result = task => { - let error_message = match result { - Ok(Ok(prover_artifact)) => { - tracing::info!("Circuit Prover executed job {:?} in: {:?}", job_id, time.elapsed()); - CIRCUIT_PROVER_METRICS.execution_time.observe(time.elapsed()); - self - .save_result(job_id, job_start_time, prover_artifact) - .await.context("failed to save result")?; - return Ok(()) - } - Ok(Err(error)) => error.to_string(), - Err(error) => try_extract_panic_message(error), - }; - tracing::error!( - "Circuit Prover failed on job {:?} with error {:?}", - job_id, - error_message - ); - - self.save_failure(job_id, error_message).await.context("failed to save failure")?; - } - } - - Ok(()) - } - - /// Persists proof generated. - /// Job metadata is saved to database, whilst artifacts go to object store. - async fn save_result( - &self, - job_id: u32, - job_start_time: Instant, - artifacts: ProverArtifacts, - ) -> anyhow::Result<()> { - let time = Instant::now(); - let mut connection = self - .connection_pool - .connection() - .await - .context("failed to get db connection")?; - let proof = artifacts.proof_wrapper; - - let (_circuit_type, is_scheduler_proof) = match &proof { - FriProofWrapper::Base(base) => (base.numeric_circuit_type(), false), - FriProofWrapper::Recursive(recursive_circuit) => match recursive_circuit { - ZkSyncRecursionLayerProof::SchedulerCircuit(_) => { - (recursive_circuit.numeric_circuit_type(), true) - } - _ => (recursive_circuit.numeric_circuit_type(), false), - }, - }; - - let upload_time = Instant::now(); - let blob_url = self - .object_store - .put(job_id, &proof) - .await - .context("failed to upload to object store")?; - CIRCUIT_PROVER_METRICS - .artifact_upload_time - .observe(upload_time.elapsed()); - - let mut transaction = connection - .start_transaction() - .await - .context("failed to start db transaction")?; - transaction - .fri_prover_jobs_dal() - .save_proof(job_id, job_start_time.elapsed(), &blob_url) - .await; - if is_scheduler_proof { - transaction - .fri_proof_compressor_dal() - .insert_proof_compression_job( - artifacts.block_number, - &blob_url, - self.protocol_version, - ) - .await; - } - transaction - .commit() - .await - .context("failed to commit db transaction")?; - - tracing::info!( - "Circuit Prover saved job {:?} after {:?}", - job_id, - time.elapsed() - ); - CIRCUIT_PROVER_METRICS.save_time.observe(time.elapsed()); - - Ok(()) - } - - /// Persists job execution error to database. - async fn save_failure(&self, job_id: u32, error: String) -> anyhow::Result<()> { - self.connection_pool - .connection() - .await - .context("failed to get db connection")? - .fri_prover_jobs_dal() - .save_proof_error(job_id, error) - .await; - Ok(()) - } -} diff --git a/prover/crates/bin/circuit_prover/src/lib.rs b/prover/crates/bin/circuit_prover/src/lib.rs index 7d7ce1d96686..c25afe6e9b3b 100644 --- a/prover/crates/bin/circuit_prover/src/lib.rs +++ b/prover/crates/bin/circuit_prover/src/lib.rs @@ -1,13 +1,5 @@ -#![allow(incomplete_features)] // We have to use generic const exprs. -#![feature(generic_const_exprs)] -pub use backoff::Backoff; -pub use circuit_prover::CircuitProver; pub use metrics::PROVER_BINARY_METRICS; pub use types::{FinalizationHintsCache, SetupDataCache}; -pub use witness_vector_generator::WitnessVectorGenerator; -mod backoff; -mod circuit_prover; mod metrics; mod types; -mod witness_vector_generator; diff --git a/prover/crates/bin/circuit_prover/src/main.rs b/prover/crates/bin/circuit_prover/src/main.rs index e26f29ca995d..e115d1510657 100644 --- a/prover/crates/bin/circuit_prover/src/main.rs +++ b/prover/crates/bin/circuit_prover/src/main.rs @@ -6,11 +6,10 @@ use std::{ use anyhow::Context as _; use clap::Parser; +use shivini::{ProverContext, ProverContextConfig}; use tokio_util::sync::CancellationToken; -use zksync_circuit_prover::{ - Backoff, CircuitProver, FinalizationHintsCache, SetupDataCache, WitnessVectorGenerator, - PROVER_BINARY_METRICS, -}; +use zksync_circuit_prover::{FinalizationHintsCache, SetupDataCache, PROVER_BINARY_METRICS}; +use zksync_circuit_prover_service::job_runner::{circuit_prover_runner, WvgRunnerBuilder}; use zksync_config::{ configs::{FriProverConfig, ObservabilityConfig}, ObjectStoreConfig, @@ -22,82 +21,105 @@ use zksync_prover_fri_types::PROVER_PROTOCOL_SEMANTIC_VERSION; use zksync_prover_keystore::keystore::Keystore; use zksync_utils::wait_for_tasks::ManagedTasks; +/// On most commodity hardware, WVG can take ~30 seconds to complete. +/// GPU processing is ~1 second. +/// Typical setup is ~25 WVGs & 1 GPU. +/// Worst case scenario, you just picked all 25 WVGs (so you need 30 seconds to finish) +/// and another 25 for the GPU. +const GRACEFUL_SHUTDOWN_DURATION: Duration = Duration::from_secs(55); + +/// With current setup, only a single job is expected to be in flight. +/// This guarantees memory consumption is going to be fixed (1 job in memory, no more). +/// Additionally, helps with estimating graceful shutdown time. +/// Free side effect, if the machine dies, only 1 job is in "pending" state. +const CHANNEL_SIZE: usize = 1; + #[derive(Debug, Parser)] #[command(author = "Matter Labs", version)] struct Cli { - #[arg(long)] + /// Path to file configuration + #[arg(short = 'c', long)] pub(crate) config_path: Option, - #[arg(long)] + /// Path to file secrets + #[arg(short = 's', long)] pub(crate) secrets_path: Option, - /// Number of WVG jobs to run in parallel. - /// Default value is 1. - #[arg(long, default_value_t = 1)] - pub(crate) witness_vector_generator_count: usize, + /// Number of light witness vector generators to run in parallel. + /// Corresponds to 1 CPU thread & ~2GB of RAM. + #[arg(short = 'l', long, default_value_t = 1)] + light_wvg_count: usize, + /// Number of heavy witness vector generators to run in parallel. + /// Corresponds to 1 CPU thread & ~9GB of RAM. + #[arg(short = 'h', long, default_value_t = 1)] + heavy_wvg_count: usize, /// Max VRAM to allocate. Useful if you want to limit the size of VRAM used. /// None corresponds to allocating all available VRAM. - #[arg(long)] + #[arg(short = 'm', long)] pub(crate) max_allocation: Option, } #[tokio::main] async fn main() -> anyhow::Result<()> { - let time = Instant::now(); + let start_time = Instant::now(); let opt = Cli::parse(); let (observability_config, prover_config, object_store_config) = load_configs(opt.config_path)?; - let _observability_guard = observability_config .install() .context("failed to install observability")?; - let wvg_count = opt.witness_vector_generator_count as u32; - - let (connection_pool, object_store, setup_data_cache, hints) = load_resources( + let (connection_pool, object_store, prover_context, setup_data_cache, hints) = load_resources( opt.secrets_path, + opt.max_allocation, object_store_config, prover_config.setup_data_path.into(), - wvg_count, ) .await .context("failed to load configs")?; - PROVER_BINARY_METRICS.start_up.observe(time.elapsed()); + PROVER_BINARY_METRICS + .startup_time + .observe(start_time.elapsed()); let cancellation_token = CancellationToken::new(); - let backoff = Backoff::new(Duration::from_secs(5), Duration::from_secs(30)); let mut tasks = vec![]; - let (sender, receiver) = tokio::sync::mpsc::channel(5); - - tracing::info!("Starting {wvg_count} Witness Vector Generators."); - - for _ in 0..wvg_count { - let wvg = WitnessVectorGenerator::new( - object_store.clone(), - connection_pool.clone(), - PROVER_PROTOCOL_SEMANTIC_VERSION, - sender.clone(), - hints.clone(), - ); - tasks.push(tokio::spawn( - wvg.run(cancellation_token.clone(), backoff.clone()), - )); - } + let (witness_vector_sender, witness_vector_receiver) = tokio::sync::mpsc::channel(CHANNEL_SIZE); + + tracing::info!( + "Starting {} light WVGs and {} heavy WVGs.", + opt.light_wvg_count, + opt.heavy_wvg_count + ); + + let builder = WvgRunnerBuilder::new( + connection_pool.clone(), + object_store.clone(), + PROVER_PROTOCOL_SEMANTIC_VERSION, + hints.clone(), + witness_vector_sender, + cancellation_token.clone(), + ); + + let light_wvg_runner = builder.light_wvg_runner(opt.light_wvg_count); + let heavy_wvg_runner = builder.heavy_wvg_runner(opt.heavy_wvg_count); - // NOTE: Prover Context is the way VRAM is allocated. If it is dropped, the claim on VRAM allocation is dropped as well. - // It has to be kept until prover dies. Whilst it may be kept in prover struct, during cancellation, prover can `drop`, but the thread doing the processing can still be alive. - // This setup prevents segmentation faults and other nasty behavior during shutdown. - let (prover, _prover_context) = CircuitProver::new( + tasks.extend(light_wvg_runner.run()); + tasks.extend(heavy_wvg_runner.run()); + + // necessary as it has a connection_pool which will keep 1 connection active by default + drop(builder); + + let circuit_prover_runner = circuit_prover_runner( connection_pool, object_store, PROVER_PROTOCOL_SEMANTIC_VERSION, - receiver, - opt.max_allocation, setup_data_cache, - ) - .context("failed to create circuit prover")?; - tasks.push(tokio::spawn(prover.run(cancellation_token.clone()))); + witness_vector_receiver, + prover_context, + ); + + tasks.extend(circuit_prover_runner.run()); let mut tasks = ManagedTasks::new(tasks); tokio::select! { @@ -114,12 +136,15 @@ async fn main() -> anyhow::Result<()> { } } } - PROVER_BINARY_METRICS.run_time.observe(time.elapsed()); - tasks.complete(Duration::from_secs(5)).await; + let shutdown_time = Instant::now(); + tasks.complete(GRACEFUL_SHUTDOWN_DURATION).await; + PROVER_BINARY_METRICS + .shutdown_time + .observe(shutdown_time.elapsed()); + PROVER_BINARY_METRICS.run_time.observe(start_time.elapsed()); Ok(()) } - /// Loads configs necessary for proving. /// - observability config - for observability setup /// - prover config - necessary for setup data @@ -143,20 +168,21 @@ fn load_configs( tracing::info!("Loaded configs."); Ok((observability_config, prover_config, object_store_config)) } - /// Loads resources necessary for proving. /// - connection pool - necessary to pick & store jobs from database /// - object store - necessary for loading and storing artifacts to object store +/// - prover context - necessary for circuit proving; VRAM allocation /// - setup data - necessary for circuit proving /// - finalization hints - necessary for generating witness vectors async fn load_resources( secrets_path: Option, + max_gpu_vram_allocation: Option, object_store_config: ObjectStoreConfig, setup_data_path: PathBuf, - wvg_count: u32, ) -> anyhow::Result<( ConnectionPool, Arc, + ProverContext, SetupDataCache, FinalizationHintsCache, )> { @@ -165,9 +191,8 @@ async fn load_resources( let database_url = database_secrets .prover_url .context("no prover DB URl present")?; - - // 1 connection for the prover and one for each vector generator - let max_connections = 1 + wvg_count; + // 2 connections for the witness vector generator job pickers (1 each) and 1 for gpu circuit prover job saver + let max_connections = 3; let connection_pool = ConnectionPool::::builder(database_url, max_connections) .build() .await @@ -178,23 +203,34 @@ async fn load_resources( .await .context("failed to create object store")?; - tracing::info!("Loading mappings from disk..."); + let prover_context = match max_gpu_vram_allocation { + Some(max_allocation) => ProverContext::create_with_config( + ProverContextConfig::default().with_maximum_device_allocation(max_allocation), + ) + .context("failed initializing fixed gpu prover context")?, + None => ProverContext::create().context("failed initializing gpu prover context")?, + }; + + tracing::info!("Loading setup data from disk..."); let keystore = Keystore::locate().with_setup_path(Some(setup_data_path)); let setup_data_cache = keystore .load_all_setup_key_mapping() .await .context("failed to load setup key mapping")?; + + tracing::info!("Loading finalization hints from disk..."); let finalization_hints = keystore .load_all_finalization_hints_mapping() .await .context("failed to load finalization hints mapping")?; - tracing::info!("Loaded mappings from disk."); + tracing::info!("Finished loading mappings from disk."); Ok(( connection_pool, object_store, + prover_context, setup_data_cache, finalization_hints, )) diff --git a/prover/crates/bin/circuit_prover/src/metrics.rs b/prover/crates/bin/circuit_prover/src/metrics.rs index e9f445914795..f9b8c38e3e34 100644 --- a/prover/crates/bin/circuit_prover/src/metrics.rs +++ b/prover/crates/bin/circuit_prover/src/metrics.rs @@ -2,79 +2,20 @@ use std::time::Duration; use vise::{Buckets, Histogram, Metrics}; +/// Instrument prover binary lifecycle #[derive(Debug, Metrics)] #[metrics(prefix = "prover_binary")] pub struct ProverBinaryMetrics { /// How long does it take for prover to load data before it can produce proofs? #[metrics(buckets = Buckets::LATENCIES)] - pub start_up: Histogram, - /// How long has the prover been running? + pub startup_time: Histogram, + /// How long did the prover binary run for? #[metrics(buckets = Buckets::LATENCIES)] pub run_time: Histogram, -} - -#[vise::register] -pub static PROVER_BINARY_METRICS: vise::Global = vise::Global::new(); - -#[derive(Debug, Metrics)] -#[metrics(prefix = "witness_vector_generator")] -pub struct WitnessVectorGeneratorMetrics { - /// How long does witness vector generator waits before a job is available? - #[metrics(buckets = Buckets::LATENCIES)] - pub job_wait_time: Histogram, - /// How long does it take to load object store artifacts for a witness vector job? - #[metrics(buckets = Buckets::LATENCIES)] - pub artifact_download_time: Histogram, - /// How long does the crypto witness generation primitive take? - #[metrics(buckets = Buckets::LATENCIES)] - pub crypto_primitive_time: Histogram, - /// How long does it take for a job to be executed, from the moment it's loaded? - #[metrics(buckets = Buckets::LATENCIES)] - pub execution_time: Histogram, - /// How long does it take to send a job to prover? - /// This is relevant because prover queue can apply back-pressure. - #[metrics(buckets = Buckets::LATENCIES)] - pub send_time: Histogram, - /// How long does it take for a job to be considered finished, from the moment it's been loaded? - #[metrics(buckets = Buckets::LATENCIES)] - pub job_finished_time: Histogram, -} - -#[vise::register] -pub static WITNESS_VECTOR_GENERATOR_METRICS: vise::Global = - vise::Global::new(); - -#[derive(Debug, Metrics)] -#[metrics(prefix = "circuit_prover")] -pub struct CircuitProverMetrics { - /// How long does circuit prover wait before a job is available? - #[metrics(buckets = Buckets::LATENCIES)] - pub job_wait_time: Histogram, - /// How long does the crypto primitives (proof generation & verification) take? - #[metrics(buckets = Buckets::LATENCIES)] - pub crypto_primitives_time: Histogram, - /// How long does proof generation (crypto primitive) take? - #[metrics(buckets = Buckets::LATENCIES)] - pub generate_proof_time: Histogram, - /// How long does verify proof (crypto primitive) take? + /// How long does it take prover to gracefully shutdown? #[metrics(buckets = Buckets::LATENCIES)] - pub verify_proof_time: Histogram, - /// How long does it take for a job to be executed, from the moment it's loaded? - #[metrics(buckets = Buckets::LATENCIES)] - pub execution_time: Histogram, - /// How long does it take to upload proof to object store? - #[metrics(buckets = Buckets::LATENCIES)] - pub artifact_upload_time: Histogram, - /// How long does it take to save a job? - #[metrics(buckets = Buckets::LATENCIES)] - pub save_time: Histogram, - /// How long does it take for a job to be considered finished, from the moment it's been loaded? - #[metrics(buckets = Buckets::LATENCIES)] - pub job_finished_time: Histogram, - /// How long does it take a job to go from witness generation to having the proof saved? - #[metrics(buckets = Buckets::LATENCIES)] - pub full_proving_time: Histogram, + pub shutdown_time: Histogram, } #[vise::register] -pub static CIRCUIT_PROVER_METRICS: vise::Global = vise::Global::new(); +pub static PROVER_BINARY_METRICS: vise::Global = vise::Global::new(); diff --git a/prover/crates/bin/circuit_prover/src/types.rs b/prover/crates/bin/circuit_prover/src/types.rs index 52cdd48b6b50..e4e1fdc13b8f 100644 --- a/prover/crates/bin/circuit_prover/src/types.rs +++ b/prover/crates/bin/circuit_prover/src/types.rs @@ -1,31 +1,12 @@ use std::{collections::HashMap, sync::Arc}; use zksync_prover_fri_types::{ - circuit_definitions::boojum::{ - algebraic_props::{ - round_function::AbsorptionModeOverwrite, sponge::GoldilocksPoseidon2Sponge, - }, - cs::implementations::{ - proof::Proof as CryptoProof, setup::FinalizationHintsForProver, - transcript::GoldilocksPoisedon2Transcript, - verifier::VerificationKey as CryptoVerificationKey, - }, - field::goldilocks::{GoldilocksExt2, GoldilocksField}, - }, + circuit_definitions::boojum::cs::implementations::setup::FinalizationHintsForProver, ProverServiceDataKey, }; use zksync_prover_keystore::GoldilocksGpuProverSetupData; -// prover types -pub type DefaultTranscript = GoldilocksPoisedon2Transcript; -pub type DefaultTreeHasher = GoldilocksPoseidon2Sponge; - -type F = GoldilocksField; -type H = GoldilocksPoseidon2Sponge; -type Ext = GoldilocksExt2; -pub type Proof = CryptoProof; -pub type VerificationKey = CryptoVerificationKey; - +// TODO: To be moved to circuit_prover_service lib & adjusted to new type idiom // cache types pub type SetupDataCache = HashMap>; pub type FinalizationHintsCache = HashMap>; diff --git a/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs b/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs deleted file mode 100644 index cb2d2a256df9..000000000000 --- a/prover/crates/bin/circuit_prover/src/witness_vector_generator.rs +++ /dev/null @@ -1,345 +0,0 @@ -use std::{collections::HashMap, sync::Arc, time::Instant}; - -use anyhow::Context; -use tokio::{sync::mpsc::Sender, task::JoinHandle}; -use tokio_util::sync::CancellationToken; -use zksync_object_store::ObjectStore; -use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; -use zksync_prover_fri_types::{ - circuit_definitions::{ - boojum::{ - cs::implementations::setup::FinalizationHintsForProver, - field::goldilocks::GoldilocksField, - gadgets::queue::full_state_queue::FullStateCircuitQueueRawWitness, - }, - circuit_definitions::base_layer::ZkSyncBaseLayerCircuit, - }, - get_current_pod_name, - keys::RamPermutationQueueWitnessKey, - CircuitAuxData, CircuitWrapper, ProverJob, ProverServiceDataKey, RamPermutationQueueWitness, - WitnessVectorArtifactsTemp, -}; -use zksync_types::{protocol_version::ProtocolSemanticVersion, L1BatchNumber}; -use zksync_utils::panic_extractor::try_extract_panic_message; - -use crate::{metrics::WITNESS_VECTOR_GENERATOR_METRICS, Backoff, FinalizationHintsCache}; - -/// In charge of generating Witness Vectors and sending them to Circuit Prover. -/// Both job runner & job executor. -#[derive(Debug)] -pub struct WitnessVectorGenerator { - object_store: Arc, - connection_pool: ConnectionPool, - protocol_version: ProtocolSemanticVersion, - /// Finalization Hints used for Witness Vector generation - finalization_hints_cache: FinalizationHintsCache, - /// Witness Vector sender for Circuit Prover - sender: Sender, - pod_name: String, -} - -impl WitnessVectorGenerator { - pub fn new( - object_store: Arc, - connection_pool: ConnectionPool, - protocol_version: ProtocolSemanticVersion, - sender: Sender, - finalization_hints: HashMap>, - ) -> Self { - Self { - object_store, - connection_pool, - protocol_version, - finalization_hints_cache: finalization_hints, - sender, - pod_name: get_current_pod_name(), - } - } - - /// Continuously polls database for new prover jobs and generates witness vectors for them. - /// All job executions are persisted. - pub async fn run( - self, - cancellation_token: CancellationToken, - mut backoff: Backoff, - ) -> anyhow::Result<()> { - let mut get_job_timer = Instant::now(); - while !cancellation_token.is_cancelled() { - if let Some(prover_job) = self - .get_job() - .await - .context("failed to get next witness generation job")? - { - tracing::info!( - "Witness Vector Generator received job {:?} after: {:?}", - prover_job.job_id, - get_job_timer.elapsed() - ); - WITNESS_VECTOR_GENERATOR_METRICS - .job_wait_time - .observe(get_job_timer.elapsed()); - if let e @ Err(_) = self.generate(prover_job, cancellation_token.clone()).await { - // this means that the witness vector receiver is closed, no need to report the error, just return - if cancellation_token.is_cancelled() { - return Ok(()); - } - e.context("failed to generate witness")? - } - - // waiting for a job timer starts as soon as the other is finished - get_job_timer = Instant::now(); - backoff.reset(); - continue; - }; - self.backoff(&mut backoff, cancellation_token.clone()).await; - } - tracing::info!("Witness Vector Generator shut down."); - Ok(()) - } - - /// Retrieves a prover job from database, loads artifacts from object store and hydrates them. - async fn get_job(&self) -> anyhow::Result> { - let mut connection = self - .connection_pool - .connection() - .await - .context("failed to get db connection")?; - let prover_job_metadata = match connection - .fri_prover_jobs_dal() - .get_job(self.protocol_version, &self.pod_name) - .await - { - None => return Ok(None), - Some(job) => job, - }; - - let time = Instant::now(); - let circuit_wrapper = self - .object_store - .get(prover_job_metadata.into()) - .await - .context("failed to get circuit_wrapper from object store")?; - let artifact = match circuit_wrapper { - a @ CircuitWrapper::Base(_) => a, - a @ CircuitWrapper::Recursive(_) => a, - CircuitWrapper::BasePartial((circuit, aux_data)) => self - .fill_witness(circuit, aux_data, prover_job_metadata.block_number) - .await - .context("failed to fill witness")?, - }; - WITNESS_VECTOR_GENERATOR_METRICS - .artifact_download_time - .observe(time.elapsed()); - - let setup_data_key = ProverServiceDataKey { - circuit_id: prover_job_metadata.circuit_id, - round: prover_job_metadata.aggregation_round, - } - .crypto_setup_key(); - let prover_job = ProverJob::new( - prover_job_metadata.block_number, - prover_job_metadata.id, - artifact, - setup_data_key, - ); - Ok(Some(prover_job)) - } - - /// Prover artifact hydration. - async fn fill_witness( - &self, - circuit: ZkSyncBaseLayerCircuit, - aux_data: CircuitAuxData, - l1_batch_number: L1BatchNumber, - ) -> anyhow::Result { - if let ZkSyncBaseLayerCircuit::RAMPermutation(circuit_instance) = circuit { - let sorted_witness_key = RamPermutationQueueWitnessKey { - block_number: l1_batch_number, - circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, - is_sorted: true, - }; - let sorted_witness: RamPermutationQueueWitness = self - .object_store - .get(sorted_witness_key) - .await - .context("failed to load sorted witness key")?; - - let unsorted_witness_key = RamPermutationQueueWitnessKey { - block_number: l1_batch_number, - circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, - is_sorted: false, - }; - let unsorted_witness: RamPermutationQueueWitness = self - .object_store - .get(unsorted_witness_key) - .await - .context("failed to load unsorted witness key")?; - - let mut witness = circuit_instance.witness.take().unwrap(); - witness.unsorted_queue_witness = FullStateCircuitQueueRawWitness { - elements: unsorted_witness.witness.into(), - }; - witness.sorted_queue_witness = FullStateCircuitQueueRawWitness { - elements: sorted_witness.witness.into(), - }; - circuit_instance.witness.store(Some(witness)); - - return Ok(CircuitWrapper::Base( - ZkSyncBaseLayerCircuit::RAMPermutation(circuit_instance), - )); - } - Err(anyhow::anyhow!( - "unexpected circuit received with partial witness, expected RAM permutation, got {:?}", - circuit.short_description() - )) - } - - /// Generates witness vector, with persistence of execution. - async fn generate( - &self, - prover_job: ProverJob, - cancellation_token: CancellationToken, - ) -> anyhow::Result<()> { - let start_time = Instant::now(); - let finalization_hints = self - .finalization_hints_cache - .get(&prover_job.setup_data_key) - .context(format!( - "failed to get finalization hints for key {:?}", - &prover_job.setup_data_key - ))? - .clone(); - let job_id = prover_job.job_id; - let task = tokio::task::spawn_blocking(move || { - let block_number = prover_job.block_number; - let _span = tracing::info_span!("witness_vector_generator", %block_number).entered(); - Self::generate_witness_vector(prover_job, finalization_hints) - }); - - self.finish_task(job_id, start_time, task, cancellation_token.clone()) - .await?; - - tracing::info!( - "Witness Vector Generator finished job {:?} in: {:?}", - job_id, - start_time.elapsed() - ); - WITNESS_VECTOR_GENERATOR_METRICS - .job_finished_time - .observe(start_time.elapsed()); - Ok(()) - } - - /// Generates witness vector using crypto primitives. - #[tracing::instrument( - skip_all, - fields(l1_batch = % prover_job.block_number) - )] - pub fn generate_witness_vector( - prover_job: ProverJob, - finalization_hints: Arc, - ) -> anyhow::Result { - let time = Instant::now(); - let cs = match prover_job.circuit_wrapper.clone() { - CircuitWrapper::Base(base_circuit) => { - base_circuit.synthesis::(&finalization_hints) - } - CircuitWrapper::Recursive(recursive_circuit) => { - recursive_circuit.synthesis::(&finalization_hints) - } - // circuit must be hydrated during `get_job` - CircuitWrapper::BasePartial(_) => { - return Err(anyhow::anyhow!("received unexpected dehydrated proof")); - } - }; - WITNESS_VECTOR_GENERATOR_METRICS - .crypto_primitive_time - .observe(time.elapsed()); - Ok(WitnessVectorArtifactsTemp::new( - cs.witness.unwrap(), - prover_job, - time, - )) - } - - /// Runs task to completion and persists result. - /// NOTE: Task may be cancelled mid-flight. - async fn finish_task( - &self, - job_id: u32, - time: Instant, - task: JoinHandle>, - cancellation_token: CancellationToken, - ) -> anyhow::Result<()> { - tokio::select! { - _ = cancellation_token.cancelled() => { - tracing::info!("Stop signal received, shutting down Witness Vector Generator..."); - return Ok(()) - } - result = task => { - let error_message = match result { - Ok(Ok(witness_vector)) => { - tracing::info!("Witness Vector Generator executed job {:?} in: {:?}", job_id, time.elapsed()); - WITNESS_VECTOR_GENERATOR_METRICS.execution_time.observe(time.elapsed()); - self - .save_result(witness_vector, job_id) - .await - .context("failed to save result")?; - return Ok(()) - } - Ok(Err(error)) => error.to_string(), - Err(error) => try_extract_panic_message(error), - }; - tracing::error!("Witness Vector Generator failed on job {job_id:?} with error {error_message:?}"); - - self.save_failure(job_id, error_message).await.context("failed to save failure")?; - } - } - - Ok(()) - } - - /// Sends proof to Circuit Prover. - async fn save_result( - &self, - artifacts: WitnessVectorArtifactsTemp, - job_id: u32, - ) -> anyhow::Result<()> { - let time = Instant::now(); - self.sender - .send(artifacts) - .await - .context("failed to send witness vector to prover")?; - tracing::info!( - "Witness Vector Generator sent job {:?} after {:?}", - job_id, - time.elapsed() - ); - WITNESS_VECTOR_GENERATOR_METRICS - .send_time - .observe(time.elapsed()); - Ok(()) - } - - /// Persists job execution error to database - async fn save_failure(&self, job_id: u32, error: String) -> anyhow::Result<()> { - self.connection_pool - .connection() - .await - .context("failed to get db connection")? - .fri_prover_jobs_dal() - .save_proof_error(job_id, error) - .await; - Ok(()) - } - - /// Backs off, whilst being cancellation aware. - async fn backoff(&self, backoff: &mut Backoff, cancellation_token: CancellationToken) { - let backoff_duration = backoff.delay(); - tracing::info!("Backing off for {:?}...", backoff_duration); - // Error here corresponds to a timeout w/o receiving task cancel; we're OK with this. - tokio::time::timeout(backoff_duration, cancellation_token.cancelled()) - .await - .ok(); - } -} diff --git a/prover/crates/lib/circuit_prover_service/Cargo.toml b/prover/crates/lib/circuit_prover_service/Cargo.toml new file mode 100644 index 000000000000..ca7d1ede02f1 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "zksync_circuit_prover_service" +description = "ZKsync circuit prover service implementation" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +zksync_prover_job_processor.workspace = true +zksync_prover_fri_types.workspace = true +zksync_prover_keystore.workspace = true +zksync_prover_dal.workspace = true +zksync_types.workspace = true +zksync_object_store.workspace = true + +async-trait.workspace = true +anyhow.workspace = true +tokio = { workspace = true, features = ["macros", "time"] } +tokio-util.workspace = true +tracing.workspace = true + +shivini = { workspace = true, features = [ + "circuit_definitions", +] } +zkevm_test_harness.workspace = true +vise.workspace = true diff --git a/prover/crates/lib/circuit_prover_service/README.md b/prover/crates/lib/circuit_prover_service/README.md new file mode 100644 index 000000000000..3cc8a80e966d --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/README.md @@ -0,0 +1,96 @@ +# Circuit Prover Service + +This crate provides the building blocks for running circuit provers. Circuit proving is the heaviest part of the proving +process, being both the most time intensive and resource heavy part. + +The primitives exported by this lib are job runners, namely: + +- light_wvg_runner +- heavy_wvg_runner +- circuit_prover_runner + +The rest of the codebase simply covers the internals of creating a runner, which is an implementation of +`ProverJobProcessor`. + +## Witness Vector Generator Runner + +Runners related to synthesizing Witness Vector (the CPU heavy part of circuit proving). They are tied to +`prover_jobs_fri` table and operate over `ProverJobsFri` object storage bucket. + +Witness Vector Generators have big gaps in resource usages. Node proofs are the heavy jobs (~9GB RAM), whilst all others +are rather light (~2GB RAM). + +There are 2 ways to deal with this: + +1. run RAM left over / 9 which will result in RAM under utilization but simplify implementation +2. run multiple light WVG jobs, with a small amount of heavy WVG jobs. + +This implementation favors number 2. As such, `MetadataLoader` abstraction was introduced to force loading lighter and +heavier jobs. Heavier picker will try to prioritize nodes. If none are available, it falls back to light jobs in order +to maximize usage. + +### Job Picker + +Interacts with the database to get a job (as described above), loads the data from object store and then hydrates the +circuit. In current implementation, Ram Permutation circuits are sent separately in order to save RAM in basic witness +generation & reduce the amount of storage used by object store. A further optimization will be introduced later on, +which will remove the necessity of witness hydration on circuits. + +### Executor + +Straight forward, synthesizes witness vector from circuit. + +### Job Saver + +If successful, will provide data to GPU circuit prover over a channel. If it fails, will mark the database as such and +will later be retried (as marked by Prover Job Monitor). + +## GPU Circuit Prover + +Runners related to generating the circuit proof & verifying it. They are tied to `prover_jobs_fri` table and operate +over `ProverJobs` object storage bucket. + +### Job Picker + +Waits on information from (multiple) WVGs sent via a channel. + +### Executor + +Generates & verifies the circuit proof (on GPU). + +### Job Saver + +Persists information back to `prover_jobs_fri` table. Note that a job is picked by WVG & finished by CP. + +## Diagram + +```mermaid +sequenceDiagram + box Resources + participant db as Database + participant os as Object Store + end + box Heavy/Light Witness Vector Generator + participant wvg_p as Job Picker + participant wvg_e as Executor + participant wvg_s as Job Saver + end + box Circuit Prover + participant cp_p as Job Picker + participant cp_e as Executor + participant cp_s as Job Saver + end + wvg_p-->>db: Get job metadata + wvg_p-->>os: Get circuit + wvg_p-->>wvg_p: Hydrate circuit & get finalization hints + wvg_p-->>wvg_e: Provide metadata & circuit + wvg_e-->>wvg_e: Synthesize witness vector + wvg_e-->>wvg_s: Provide metadata & witness vector & circuit + wvg_s-->>cp_p: Provide metadata & witness vector & circuit + cp_p-->>cp_p: Get setup data + cp_p-->>cp_e: Provide metadata & witness vector & circuit + cp_e-->>cp_e: Prove & verify circuit proof + cp_e-->>cp_s: Provide metadata & proof + cp_s-->>os: Save proof + cp_s-->>db: Update job metadata +``` diff --git a/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_executor.rs b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_executor.rs new file mode 100644 index 000000000000..043232a5003c --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_executor.rs @@ -0,0 +1,73 @@ +use std::time::Instant; + +use anyhow::Context; +use shivini::ProverContext; +use zksync_prover_fri_types::FriProofWrapper; +use zksync_prover_job_processor::Executor; +use zksync_types::prover_dal::FriProverJobMetadata; + +use crate::{ + metrics::CIRCUIT_PROVER_METRICS, types::circuit_prover_payload::GpuCircuitProverPayload, +}; + +/// GpuCircuitProver executor implementation. +/// Generates circuit proof & verifies it. +/// NOTE: It requires prover context, which is the way Shivini allocates VRAM. +pub struct GpuCircuitProverExecutor { + _prover_context: ProverContext, +} + +impl GpuCircuitProverExecutor { + pub fn new(prover_context: ProverContext) -> Self { + Self { + _prover_context: prover_context, + } + } +} + +impl Executor for GpuCircuitProverExecutor { + type Input = GpuCircuitProverPayload; + type Output = FriProofWrapper; + type Metadata = FriProverJobMetadata; + + #[tracing::instrument( + name = "gpu_circuit_prover_executor", + skip_all, + fields(l1_batch = % metadata.block_number) + )] + fn execute( + &self, + input: Self::Input, + metadata: Self::Metadata, + ) -> anyhow::Result { + let start_time = Instant::now(); + tracing::info!( + "Started executing gpu circuit prover job {}, on batch {}, for circuit {}, at round {}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round + ); + let GpuCircuitProverPayload { + circuit, + witness_vector, + setup_data, + } = input; + + let proof_wrapper = circuit + .prove(witness_vector, setup_data) + .context("failed to gpu prove circuit")?; + tracing::info!( + "Finished executing gpu circuit prover job {}, on batch {}, for circuit {}, at round {} after {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + CIRCUIT_PROVER_METRICS + .prove_and_verify_time + .observe(start_time.elapsed()); + Ok(proof_wrapper) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_picker.rs b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_picker.rs new file mode 100644 index 000000000000..76dc0cda66d3 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_picker.rs @@ -0,0 +1,92 @@ +use std::{collections::HashMap, sync::Arc, time::Instant}; + +use anyhow::Context; +use async_trait::async_trait; +use zksync_prover_fri_types::ProverServiceDataKey; +use zksync_prover_job_processor::JobPicker; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; +use zksync_types::prover_dal::FriProverJobMetadata; + +use crate::{ + gpu_circuit_prover::GpuCircuitProverExecutor, + metrics::CIRCUIT_PROVER_METRICS, + types::{ + circuit_prover_payload::GpuCircuitProverPayload, + witness_vector_generator_execution_output::WitnessVectorGeneratorExecutionOutput, + }, +}; + +/// GpuCircuitProver job picker implementation. +/// Retrieves job & data from WVG job saver. +#[derive(Debug)] +pub struct GpuCircuitProverJobPicker { + receiver: + tokio::sync::mpsc::Receiver<(WitnessVectorGeneratorExecutionOutput, FriProverJobMetadata)>, + setup_data_cache: HashMap>, +} + +impl GpuCircuitProverJobPicker { + pub fn new( + receiver: tokio::sync::mpsc::Receiver<( + WitnessVectorGeneratorExecutionOutput, + FriProverJobMetadata, + )>, + setup_data_cache: HashMap>, + ) -> Self { + Self { + receiver, + setup_data_cache, + } + } +} + +#[async_trait] +impl JobPicker for GpuCircuitProverJobPicker { + type ExecutorType = GpuCircuitProverExecutor; + + async fn pick_job( + &mut self, + ) -> anyhow::Result> { + let start_time = Instant::now(); + tracing::info!("Started picking gpu circuit prover job"); + + let (wvg_output, metadata) = self + .receiver + .recv() + .await + .context("no witness vector generators are available, stopping...")?; + let WitnessVectorGeneratorExecutionOutput { + circuit, + witness_vector, + } = wvg_output; + + let key = ProverServiceDataKey { + circuit_id: metadata.circuit_id, + round: metadata.aggregation_round, + } + .crypto_setup_key(); + let setup_data = self + .setup_data_cache + .get(&key) + .context("failed to retrieve setup data from cache")? + .clone(); + + let payload = GpuCircuitProverPayload { + circuit, + witness_vector, + setup_data, + }; + tracing::info!( + "Finished picking gpu circuit prover job {}, on batch {}, for circuit {}, at round {} in {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + CIRCUIT_PROVER_METRICS + .load_time + .observe(start_time.elapsed()); + Ok(Some((payload, metadata))) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_saver.rs b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_saver.rs new file mode 100644 index 000000000000..0ba28a0d9f5a --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/gpu_circuit_prover_job_saver.rs @@ -0,0 +1,126 @@ +use std::{sync::Arc, time::Instant}; + +use anyhow::Context; +use async_trait::async_trait; +use zksync_object_store::ObjectStore; +use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; +use zksync_prover_fri_types::FriProofWrapper; +use zksync_prover_job_processor::JobSaver; +use zksync_types::{protocol_version::ProtocolSemanticVersion, prover_dal::FriProverJobMetadata}; + +use crate::{gpu_circuit_prover::GpuCircuitProverExecutor, metrics::CIRCUIT_PROVER_METRICS}; + +/// GpuCircuitProver job saver implementation. +/// Persists the job execution to database. In case of success, artifacts are uploaded to object store. +#[derive(Debug)] +pub struct GpuCircuitProverJobSaver { + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, +} + +impl GpuCircuitProverJobSaver { + pub fn new( + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + ) -> Self { + Self { + connection_pool, + object_store, + protocol_version, + } + } +} + +#[async_trait] +impl JobSaver for GpuCircuitProverJobSaver { + type ExecutorType = GpuCircuitProverExecutor; + + #[tracing::instrument( + name = "gpu_circuit_prover_job_saver", + skip_all, + fields(l1_batch = % data.1.block_number) + )] + async fn save_job_result( + &self, + data: (anyhow::Result, FriProverJobMetadata), + ) -> anyhow::Result<()> { + let start_time = Instant::now(); + let (result, metadata) = data; + tracing::info!( + "Started saving gpu circuit prover job {}, on batch {}, for circuit {}, at round {}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round + ); + + match result { + Ok(proof_wrapper) => { + let mut connection = self + .connection_pool + .connection() + .await + .context("failed to get db connection")?; + + let is_scheduler_proof = metadata.is_scheduler_proof()?; + + let blob_url = self + .object_store + .put(metadata.id, &proof_wrapper) + .await + .context("failed to upload to object store")?; + + let mut transaction = connection + .start_transaction() + .await + .context("failed to start db transaction")?; + transaction + .fri_prover_jobs_dal() + .save_proof(metadata.id, metadata.pick_time.elapsed(), &blob_url) + .await; + if is_scheduler_proof { + transaction + .fri_proof_compressor_dal() + .insert_proof_compression_job( + metadata.block_number, + &blob_url, + self.protocol_version, + ) + .await; + } + transaction + .commit() + .await + .context("failed to commit db transaction")?; + } + Err(error) => { + let error_message = error.to_string(); + tracing::error!("GPU circuit prover failed: {:?}", error_message); + self.connection_pool + .connection() + .await + .context("failed to get db connection")? + .fri_prover_jobs_dal() + .save_proof_error(metadata.id, error_message) + .await; + } + }; + tracing::info!( + "Finished saving gpu circuit prover job {}, on batch {}, for circuit {}, at round {} after {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + CIRCUIT_PROVER_METRICS + .save_time + .observe(start_time.elapsed()); + CIRCUIT_PROVER_METRICS + .full_time + .observe(metadata.pick_time.elapsed()); + Ok(()) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/mod.rs b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/mod.rs new file mode 100644 index 000000000000..7dff12aa2cc6 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/gpu_circuit_prover/mod.rs @@ -0,0 +1,8 @@ +pub use gpu_circuit_prover_executor::GpuCircuitProverExecutor; +pub use gpu_circuit_prover_job_picker::GpuCircuitProverJobPicker; +pub use gpu_circuit_prover_job_saver::GpuCircuitProverJobSaver; + +mod gpu_circuit_prover_executor; + +mod gpu_circuit_prover_job_picker; +mod gpu_circuit_prover_job_saver; diff --git a/prover/crates/lib/circuit_prover_service/src/job_runner.rs b/prover/crates/lib/circuit_prover_service/src/job_runner.rs new file mode 100644 index 000000000000..2e102fd40e33 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/job_runner.rs @@ -0,0 +1,144 @@ +use std::{collections::HashMap, sync::Arc}; + +use shivini::ProverContext; +use tokio_util::sync::CancellationToken; +use zksync_object_store::ObjectStore; +use zksync_prover_dal::{ConnectionPool, Prover}; +use zksync_prover_fri_types::{ + circuit_definitions::boojum::cs::implementations::setup::FinalizationHintsForProver, + get_current_pod_name, ProverServiceDataKey, +}; +use zksync_prover_job_processor::{Backoff, BackoffAndCancellable, JobRunner}; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; +use zksync_types::{protocol_version::ProtocolSemanticVersion, prover_dal::FriProverJobMetadata}; + +use crate::{ + gpu_circuit_prover::{ + GpuCircuitProverExecutor, GpuCircuitProverJobPicker, GpuCircuitProverJobSaver, + }, + types::witness_vector_generator_execution_output::WitnessVectorGeneratorExecutionOutput, + witness_vector_generator::{ + HeavyWitnessVectorMetadataLoader, LightWitnessVectorMetadataLoader, + WitnessVectorGeneratorExecutor, WitnessVectorGeneratorJobPicker, + WitnessVectorGeneratorJobSaver, WitnessVectorMetadataLoader, + }, +}; + +/// Convenience struct helping with building Witness Vector Generator runners. +#[derive(Debug)] +pub struct WvgRunnerBuilder { + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + finalization_hints_cache: HashMap>, + sender: + tokio::sync::mpsc::Sender<(WitnessVectorGeneratorExecutionOutput, FriProverJobMetadata)>, + cancellation_token: CancellationToken, + pod_name: String, +} + +impl WvgRunnerBuilder { + pub fn new( + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + finalization_hints_cache: HashMap>, + sender: tokio::sync::mpsc::Sender<( + WitnessVectorGeneratorExecutionOutput, + FriProverJobMetadata, + )>, + cancellation_token: CancellationToken, + ) -> Self { + Self { + connection_pool, + object_store, + protocol_version, + finalization_hints_cache, + sender, + cancellation_token, + pod_name: get_current_pod_name(), + } + } + + /// Witness Vector Generator runner implementation for light jobs. + pub fn light_wvg_runner( + &self, + count: usize, + ) -> JobRunner< + WitnessVectorGeneratorExecutor, + WitnessVectorGeneratorJobPicker, + WitnessVectorGeneratorJobSaver, + > { + let metadata_loader = + LightWitnessVectorMetadataLoader::new(self.pod_name.clone(), self.protocol_version); + + self.wvg_runner(count, metadata_loader) + } + + /// Witness Vector Generator runner implementation that prioritizes heavy jobs over light jobs. + pub fn heavy_wvg_runner( + &self, + count: usize, + ) -> JobRunner< + WitnessVectorGeneratorExecutor, + WitnessVectorGeneratorJobPicker, + WitnessVectorGeneratorJobSaver, + > { + let metadata_loader = + HeavyWitnessVectorMetadataLoader::new(self.pod_name.clone(), self.protocol_version); + + self.wvg_runner(count, metadata_loader) + } + + /// Creates a Witness Vector Generator job runner with specified MetadataLoader. + /// The MetadataLoader makes the difference between heavy & light WVG runner. + fn wvg_runner( + &self, + count: usize, + metadata_loader: ML, + ) -> JobRunner< + WitnessVectorGeneratorExecutor, + WitnessVectorGeneratorJobPicker, + WitnessVectorGeneratorJobSaver, + > { + let executor = WitnessVectorGeneratorExecutor; + let job_picker = WitnessVectorGeneratorJobPicker::new( + self.connection_pool.clone(), + self.object_store.clone(), + self.finalization_hints_cache.clone(), + metadata_loader, + ); + let job_saver = + WitnessVectorGeneratorJobSaver::new(self.connection_pool.clone(), self.sender.clone()); + let backoff = Backoff::default(); + + JobRunner::new( + executor, + job_picker, + job_saver, + count, + Some(BackoffAndCancellable::new( + backoff, + self.cancellation_token.clone(), + )), + ) + } +} + +/// Circuit Prover runner implementation. +pub fn circuit_prover_runner( + connection_pool: ConnectionPool, + object_store: Arc, + protocol_version: ProtocolSemanticVersion, + setup_data_cache: HashMap>, + receiver: tokio::sync::mpsc::Receiver<( + WitnessVectorGeneratorExecutionOutput, + FriProverJobMetadata, + )>, + prover_context: ProverContext, +) -> JobRunner { + let executor = GpuCircuitProverExecutor::new(prover_context); + let job_picker = GpuCircuitProverJobPicker::new(receiver, setup_data_cache); + let job_saver = GpuCircuitProverJobSaver::new(connection_pool, object_store, protocol_version); + JobRunner::new(executor, job_picker, job_saver, 1, None) +} diff --git a/prover/crates/lib/circuit_prover_service/src/lib.rs b/prover/crates/lib/circuit_prover_service/src/lib.rs new file mode 100644 index 000000000000..0d7b146cc43b --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/lib.rs @@ -0,0 +1,7 @@ +#![allow(incomplete_features)] // Crypto code uses generic const exprs +#![feature(generic_const_exprs)] +mod gpu_circuit_prover; +pub mod job_runner; +mod metrics; +mod types; +mod witness_vector_generator; diff --git a/prover/crates/lib/circuit_prover_service/src/metrics.rs b/prover/crates/lib/circuit_prover_service/src/metrics.rs new file mode 100644 index 000000000000..c102422c4771 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/metrics.rs @@ -0,0 +1,46 @@ +use std::time::Duration; + +use vise::{Buckets, Histogram, Metrics}; + +/// Metrics for witness vector generator execution +#[derive(Debug, Metrics)] +#[metrics(prefix = "witness_vector_generator")] +pub struct WitnessVectorGeneratorMetrics { + /// How long does it take to load witness vector inputs? + #[metrics(buckets = Buckets::LATENCIES)] + pub pick_time: Histogram, + /// How long does it take to synthesize witness vector? + #[metrics(buckets = Buckets::LATENCIES)] + pub synthesize_time: Histogram, + /// How long does it take to send witness vectors to gpu prover? + #[metrics(buckets = Buckets::LATENCIES)] + pub transfer_time: Histogram, + /// How long does it take to save witness vector failure? + #[metrics(buckets = Buckets::LATENCIES)] + pub save_time: Histogram, +} + +#[vise::register] +pub static WITNESS_VECTOR_GENERATOR_METRICS: vise::Global = + vise::Global::new(); + +/// Metrics for GPU circuit prover execution +#[derive(Debug, Metrics)] +#[metrics(prefix = "circuit_prover")] +pub struct CircuitProverMetrics { + /// How long does it take to load prover inputs? + #[metrics(buckets = Buckets::LATENCIES)] + pub load_time: Histogram, + /// How long does it take to prove & verify? + #[metrics(buckets = Buckets::LATENCIES)] + pub prove_and_verify_time: Histogram, + /// How long does it take to save prover results? + #[metrics(buckets = Buckets::LATENCIES)] + pub save_time: Histogram, + /// How long does it take finish a prover job from witness vector to circuit prover? + #[metrics(buckets = Buckets::LATENCIES)] + pub full_time: Histogram, +} + +#[vise::register] +pub static CIRCUIT_PROVER_METRICS: vise::Global = vise::Global::new(); diff --git a/prover/crates/lib/circuit_prover_service/src/types/circuit.rs b/prover/crates/lib/circuit_prover_service/src/types/circuit.rs new file mode 100644 index 000000000000..19c05666b2c5 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/types/circuit.rs @@ -0,0 +1,152 @@ +use std::sync::Arc; + +use anyhow::Context; +use shivini::{gpu_proof_config::GpuProofConfig, gpu_prove_from_external_witness_data}; +use zkevm_test_harness::{ + boojum::cs::implementations::setup::FinalizationHintsForProver, + prover_utils::{verify_base_layer_proof, verify_recursion_layer_proof}, +}; +use zksync_prover_fri_types::{ + circuit_definitions::{ + base_layer_proof_config, + boojum::{ + algebraic_props::{ + round_function::AbsorptionModeOverwrite, sponge::GoldilocksPoseidon2Sponge, + }, + cs::implementations::{ + pow::NoPow, proof::Proof as CryptoProof, transcript::GoldilocksPoisedon2Transcript, + witness::WitnessVec, + }, + field::goldilocks::{GoldilocksExt2, GoldilocksField}, + worker::Worker, + }, + circuit_definitions::{ + base_layer::{ZkSyncBaseLayerCircuit, ZkSyncBaseLayerProof}, + recursion_layer::{ZkSyncRecursionLayerProof, ZkSyncRecursiveLayerCircuit}, + }, + recursion_layer_proof_config, + }, + FriProofWrapper, +}; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; + +type Transcript = GoldilocksPoisedon2Transcript; +type Field = GoldilocksField; +type Hasher = GoldilocksPoseidon2Sponge; +type Extension = GoldilocksExt2; +type Proof = CryptoProof; + +/// Hydrated circuit. +/// Circuits are currently dehydrated for memory and storage reasons. +/// Circuits are hydrated on the flight where necessary. +// TODO: This enum will be merged with CircuitWrapper once BWG changes are done. +#[allow(clippy::large_enum_variant)] +pub enum Circuit { + Base(ZkSyncBaseLayerCircuit), + Recursive(ZkSyncRecursiveLayerCircuit), +} + +impl Circuit { + /// Generates proof for given witness vector. + /// Expects setup_data to match witness vector. + pub(crate) fn prove( + &self, + witness_vector: WitnessVec, + setup_data: Arc, + ) -> anyhow::Result { + let worker = Worker::new(); + + match self { + Circuit::Base(circuit) => { + let proof = Self::prove_base(circuit, witness_vector, setup_data, worker)?; + let circuit_id = circuit.numeric_circuit_type(); + Ok(FriProofWrapper::Base(ZkSyncBaseLayerProof::from_inner( + circuit_id, proof, + ))) + } + Circuit::Recursive(circuit) => { + let proof = Self::prove_recursive(circuit, witness_vector, setup_data, worker)?; + let circuit_id = circuit.numeric_circuit_type(); + Ok(FriProofWrapper::Recursive( + ZkSyncRecursionLayerProof::from_inner(circuit_id, proof), + )) + } + } + } + + /// Prove & verify base circuit. + fn prove_base( + circuit: &ZkSyncBaseLayerCircuit, + witness_vector: WitnessVec, + setup_data: Arc, + worker: Worker, + ) -> anyhow::Result { + let span = tracing::info_span!("prove_base_circuit").entered(); + let gpu_proof_config = GpuProofConfig::from_base_layer_circuit(circuit); + let boojum_proof_config = base_layer_proof_config(); + let proof = gpu_prove_from_external_witness_data::( + &gpu_proof_config, + &witness_vector, + boojum_proof_config, + &setup_data.setup, + &setup_data.vk, + (), + &worker, + ) + .context("failed to generate base proof")? + .into(); + drop(span); + let _span = tracing::info_span!("verify_base_circuit").entered(); + if !verify_base_layer_proof::(circuit, &proof, &setup_data.vk) { + return Err(anyhow::anyhow!("failed to verify base proof")); + } + Ok(proof) + } + + /// Prove & verify recursive circuit. + fn prove_recursive( + circuit: &ZkSyncRecursiveLayerCircuit, + witness_vector: WitnessVec, + setup_data: Arc, + worker: Worker, + ) -> anyhow::Result { + let span = tracing::info_span!("prove_recursive_circuit").entered(); + let gpu_proof_config = GpuProofConfig::from_recursive_layer_circuit(circuit); + let boojum_proof_config = recursion_layer_proof_config(); + let proof = gpu_prove_from_external_witness_data::( + &gpu_proof_config, + &witness_vector, + boojum_proof_config, + &setup_data.setup, + &setup_data.vk, + (), + &worker, + ) + .context("failed to generate recursive proof")? + .into(); + drop(span); + let _span = tracing::info_span!("verify_recursive_circuit").entered(); + if !verify_recursion_layer_proof::(circuit, &proof, &setup_data.vk) { + return Err(anyhow::anyhow!("failed to verify recursive proof")); + } + Ok(proof) + } + + /// Synthesize vector for a given circuit. + /// Expects finalization hints to match circuit. + pub(crate) fn synthesize_vector( + &self, + finalization_hints: Arc, + ) -> anyhow::Result> { + let _span = tracing::info_span!("synthesize_vector").entered(); + + let cs = match self { + Circuit::Base(circuit) => circuit.synthesis::(&finalization_hints), + Circuit::Recursive(circuit) => { + circuit.synthesis::(&finalization_hints) + } + }; + cs.witness + .context("circuit is missing witness post synthesis") + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/types/circuit_prover_payload.rs b/prover/crates/lib/circuit_prover_service/src/types/circuit_prover_payload.rs new file mode 100644 index 000000000000..925b7b318ccc --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/types/circuit_prover_payload.rs @@ -0,0 +1,15 @@ +use std::sync::Arc; + +use zksync_prover_fri_types::circuit_definitions::boojum::{ + cs::implementations::witness::WitnessVec, field::goldilocks::GoldilocksField, +}; +use zksync_prover_keystore::GoldilocksGpuProverSetupData; + +use crate::types::circuit::Circuit; + +/// Payload used as input for GPU circuit prover. +pub struct GpuCircuitProverPayload { + pub circuit: Circuit, + pub witness_vector: WitnessVec, + pub setup_data: Arc, +} diff --git a/prover/crates/lib/circuit_prover_service/src/types/mod.rs b/prover/crates/lib/circuit_prover_service/src/types/mod.rs new file mode 100644 index 000000000000..cbbf0d885f7a --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/types/mod.rs @@ -0,0 +1,4 @@ +pub mod circuit; +pub mod circuit_prover_payload; +pub mod witness_vector_generator_execution_output; +pub mod witness_vector_generator_payload; diff --git a/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_execution_output.rs b/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_execution_output.rs new file mode 100644 index 000000000000..593f825f8f99 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_execution_output.rs @@ -0,0 +1,11 @@ +use zksync_prover_fri_types::circuit_definitions::boojum::{ + cs::implementations::witness::WitnessVec, field::goldilocks::GoldilocksField, +}; + +use crate::types::circuit::Circuit; + +/// Witness vector generator output. Used as input for GPU circuit provers. +pub struct WitnessVectorGeneratorExecutionOutput { + pub circuit: Circuit, + pub witness_vector: WitnessVec, +} diff --git a/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_payload.rs b/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_payload.rs new file mode 100644 index 000000000000..409e178ac61a --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/types/witness_vector_generator_payload.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use zksync_prover_fri_types::circuit_definitions::boojum::cs::implementations::setup::FinalizationHintsForProver; + +use crate::types::circuit::Circuit; + +/// Payload used as input for Witness Vector Generator. +pub struct WitnessVectorGeneratorPayload { + pub circuit: Circuit, + pub finalization_hints: Arc, +} diff --git a/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/mod.rs b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/mod.rs new file mode 100644 index 000000000000..d5b140dac94f --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/mod.rs @@ -0,0 +1,11 @@ +pub use witness_vector_generator_executor::WitnessVectorGeneratorExecutor; +pub use witness_vector_generator_job_picker::WitnessVectorGeneratorJobPicker; +pub use witness_vector_generator_job_saver::WitnessVectorGeneratorJobSaver; +pub use witness_vector_generator_metadata_loader::{ + HeavyWitnessVectorMetadataLoader, LightWitnessVectorMetadataLoader, WitnessVectorMetadataLoader, +}; + +mod witness_vector_generator_executor; +mod witness_vector_generator_job_picker; +mod witness_vector_generator_job_saver; +mod witness_vector_generator_metadata_loader; diff --git a/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_executor.rs b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_executor.rs new file mode 100644 index 000000000000..e9dd7e31fd63 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_executor.rs @@ -0,0 +1,66 @@ +use std::time::Instant; + +use anyhow::Context; +use zksync_prover_job_processor::Executor; +use zksync_types::prover_dal::FriProverJobMetadata; + +use crate::{ + metrics::WITNESS_VECTOR_GENERATOR_METRICS, + types::{ + witness_vector_generator_execution_output::WitnessVectorGeneratorExecutionOutput, + witness_vector_generator_payload::WitnessVectorGeneratorPayload, + }, +}; + +/// WitnessVectorGenerator executor implementation. +/// Synthesizes witness vectors to be later be used in GPU circuit proving. +#[derive(Debug)] +pub struct WitnessVectorGeneratorExecutor; + +impl Executor for WitnessVectorGeneratorExecutor { + type Input = WitnessVectorGeneratorPayload; + type Output = WitnessVectorGeneratorExecutionOutput; + type Metadata = FriProverJobMetadata; + + #[tracing::instrument( + name = "witness_vector_generator_executor", + skip_all, + fields(l1_batch = % metadata.block_number) + )] + fn execute( + &self, + input: Self::Input, + metadata: Self::Metadata, + ) -> anyhow::Result { + let start_time = Instant::now(); + tracing::info!( + "Started executing witness vector generator job {}, on batch {}, for circuit {}, at round {}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round + ); + let WitnessVectorGeneratorPayload { + circuit, + finalization_hints, + } = input; + let witness_vector = circuit + .synthesize_vector(finalization_hints) + .context("failed to generate witness vector")?; + tracing::info!( + "Finished executing witness vector generator job {}, on batch {}, for circuit {}, at round {} in {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .synthesize_time + .observe(start_time.elapsed()); + Ok(WitnessVectorGeneratorExecutionOutput { + circuit, + witness_vector, + }) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_picker.rs b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_picker.rs new file mode 100644 index 000000000000..76e0f151c7ca --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_picker.rs @@ -0,0 +1,167 @@ +use std::{collections::HashMap, sync::Arc, time::Instant}; + +use anyhow::Context; +use async_trait::async_trait; +use zksync_object_store::ObjectStore; +use zksync_prover_dal::{ConnectionPool, Prover}; +use zksync_prover_fri_types::{ + circuit_definitions::{ + boojum::{ + cs::implementations::setup::FinalizationHintsForProver, + gadgets::queue::full_state_queue::FullStateCircuitQueueRawWitness, + }, + circuit_definitions::base_layer::ZkSyncBaseLayerCircuit, + }, + keys::RamPermutationQueueWitnessKey, + CircuitAuxData, CircuitWrapper, ProverServiceDataKey, RamPermutationQueueWitness, +}; +use zksync_prover_job_processor::JobPicker; +use zksync_types::{prover_dal::FriProverJobMetadata, L1BatchNumber}; + +use crate::{ + metrics::WITNESS_VECTOR_GENERATOR_METRICS, + types::{circuit::Circuit, witness_vector_generator_payload::WitnessVectorGeneratorPayload}, + witness_vector_generator::{ + witness_vector_generator_metadata_loader::WitnessVectorMetadataLoader, + WitnessVectorGeneratorExecutor, + }, +}; + +/// WitnessVectorGenerator job picker implementation. +/// Picks job from database (via MetadataLoader) and gets data from object store. +#[derive(Debug)] +pub struct WitnessVectorGeneratorJobPicker { + connection_pool: ConnectionPool, + object_store: Arc, + finalization_hints_cache: HashMap>, + metadata_loader: ML, +} + +impl WitnessVectorGeneratorJobPicker { + pub fn new( + connection_pool: ConnectionPool, + object_store: Arc, + finalization_hints_cache: HashMap>, + metadata_loader: ML, + ) -> Self { + Self { + connection_pool, + object_store, + finalization_hints_cache, + metadata_loader, + } + } + + /// Hydrates job data with witness information which is stored separately. + /// This is done in order to save RAM & storage. + // TODO: Once new BWG is done, this won't be necessary. + async fn fill_witness( + &self, + circuit: ZkSyncBaseLayerCircuit, + aux_data: CircuitAuxData, + l1_batch_number: L1BatchNumber, + ) -> anyhow::Result { + if let ZkSyncBaseLayerCircuit::RAMPermutation(circuit_instance) = circuit { + let sorted_witness_key = RamPermutationQueueWitnessKey { + block_number: l1_batch_number, + circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, + is_sorted: true, + }; + let sorted_witness: RamPermutationQueueWitness = self + .object_store + .get(sorted_witness_key) + .await + .context("failed to load sorted witness key")?; + + let unsorted_witness_key = RamPermutationQueueWitnessKey { + block_number: l1_batch_number, + circuit_subsequence_number: aux_data.circuit_subsequence_number as usize, + is_sorted: false, + }; + let unsorted_witness: RamPermutationQueueWitness = self + .object_store + .get(unsorted_witness_key) + .await + .context("failed to load unsorted witness key")?; + + let mut witness = circuit_instance.witness.take().unwrap(); + witness.unsorted_queue_witness = FullStateCircuitQueueRawWitness { + elements: unsorted_witness.witness.into(), + }; + witness.sorted_queue_witness = FullStateCircuitQueueRawWitness { + elements: sorted_witness.witness.into(), + }; + circuit_instance.witness.store(Some(witness)); + + return Ok(Circuit::Base(ZkSyncBaseLayerCircuit::RAMPermutation( + circuit_instance, + ))); + } + Err(anyhow::anyhow!( + "unexpected circuit received with partial witness, expected RAM permutation, got {:?}", + circuit.short_description() + )) + } +} + +#[async_trait] +impl JobPicker for WitnessVectorGeneratorJobPicker { + type ExecutorType = WitnessVectorGeneratorExecutor; + async fn pick_job( + &mut self, + ) -> anyhow::Result> { + let start_time = Instant::now(); + tracing::info!("Started picking witness vector generator job"); + let connection = self + .connection_pool + .connection() + .await + .context("failed to get db connection")?; + let metadata = match self.metadata_loader.load_metadata(connection).await { + None => return Ok(None), + Some(metadata) => metadata, + }; + + let circuit_wrapper = self + .object_store + .get(metadata.into()) + .await + .context("failed to get circuit_wrapper from object store")?; + let circuit = match circuit_wrapper { + CircuitWrapper::Base(circuit) => Circuit::Base(circuit), + CircuitWrapper::Recursive(circuit) => Circuit::Recursive(circuit), + CircuitWrapper::BasePartial((circuit, aux_data)) => self + .fill_witness(circuit, aux_data, metadata.block_number) + .await + .context("failed to fill witness")?, + }; + + let key = ProverServiceDataKey { + circuit_id: metadata.circuit_id, + round: metadata.aggregation_round, + } + .crypto_setup_key(); + let finalization_hints = self + .finalization_hints_cache + .get(&key) + .context("failed to retrieve finalization key from cache")? + .clone(); + + let payload = WitnessVectorGeneratorPayload { + circuit, + finalization_hints, + }; + tracing::info!( + "Finished picking witness vector generator job {}, on batch {}, for circuit {}, at round {} in {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .pick_time + .observe(start_time.elapsed()); + Ok(Some((payload, metadata))) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_saver.rs b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_saver.rs new file mode 100644 index 000000000000..86e04472b299 --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_job_saver.rs @@ -0,0 +1,114 @@ +use std::time::Instant; + +use anyhow::Context; +use async_trait::async_trait; +use zksync_prover_dal::{ConnectionPool, Prover, ProverDal}; +use zksync_prover_job_processor::JobSaver; +use zksync_types::prover_dal::FriProverJobMetadata; + +use crate::{ + metrics::WITNESS_VECTOR_GENERATOR_METRICS, + types::witness_vector_generator_execution_output::WitnessVectorGeneratorExecutionOutput, + witness_vector_generator::WitnessVectorGeneratorExecutor, +}; + +/// WitnessVectorGenerator job saver implementation. +/// On successful execution, sends data further to gpu circuit prover. +/// On error, marks the job as failed in database. +#[derive(Debug)] +pub struct WitnessVectorGeneratorJobSaver { + connection_pool: ConnectionPool, + sender: + tokio::sync::mpsc::Sender<(WitnessVectorGeneratorExecutionOutput, FriProverJobMetadata)>, +} + +impl WitnessVectorGeneratorJobSaver { + pub fn new( + connection_pool: ConnectionPool, + sender: tokio::sync::mpsc::Sender<( + WitnessVectorGeneratorExecutionOutput, + FriProverJobMetadata, + )>, + ) -> Self { + Self { + connection_pool, + sender, + } + } +} + +#[async_trait] +impl JobSaver for WitnessVectorGeneratorJobSaver { + type ExecutorType = WitnessVectorGeneratorExecutor; + + #[tracing::instrument( + name = "witness_vector_generator_save_job", + skip_all, + fields(l1_batch = % data.1.block_number) + )] + async fn save_job_result( + &self, + data: ( + anyhow::Result, + FriProverJobMetadata, + ), + ) -> anyhow::Result<()> { + let start_time = Instant::now(); + let (result, metadata) = data; + match result { + Ok(payload) => { + tracing::info!( + "Started transferring witness vector generator job {}, on batch {}, for circuit {}, at round {}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round + ); + if self.sender.send((payload, metadata)).await.is_err() { + tracing::warn!("circuit prover shut down prematurely"); + return Ok(()); + } + tracing::info!( + "Finished transferring witness vector generator job {}, on batch {}, for circuit {}, at round {} in {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .transfer_time + .observe(start_time.elapsed()); + } + Err(err) => { + tracing::error!("Witness vector generation failed: {:?}", err); + tracing::info!( + "Started saving failure for witness vector generator job {}, on batch {}, for circuit {}, at round {}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round + ); + self.connection_pool + .connection() + .await + .context("failed to get db connection")? + .fri_prover_jobs_dal() + .save_proof_error(metadata.id, err.to_string()) + .await; + tracing::info!( + "Finished saving failure for witness vector generator job {}, on batch {}, for circuit {}, at round {} in {:?}", + metadata.id, + metadata.block_number, + metadata.circuit_id, + metadata.aggregation_round, + start_time.elapsed() + ); + WITNESS_VECTOR_GENERATOR_METRICS + .save_time + .observe(start_time.elapsed()); + } + } + Ok(()) + } +} diff --git a/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_metadata_loader.rs b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_metadata_loader.rs new file mode 100644 index 000000000000..bb0b6ec6e94c --- /dev/null +++ b/prover/crates/lib/circuit_prover_service/src/witness_vector_generator/witness_vector_generator_metadata_loader.rs @@ -0,0 +1,83 @@ +use async_trait::async_trait; +use zksync_prover_dal::{Connection, Prover, ProverDal}; +use zksync_types::{protocol_version::ProtocolSemanticVersion, prover_dal::FriProverJobMetadata}; + +/// Trait responsible for describing the job loading interface. +/// This is necessary as multiple strategies are necessary for loading jobs (which require different implementations). +#[async_trait] +pub trait WitnessVectorMetadataLoader: Sync + Send + 'static { + async fn load_metadata( + &self, + connection: Connection<'_, Prover>, + ) -> Option; +} + +/// Light job MetadataLoader. +/// +/// Most jobs are light, apart from nodes. This loader will only pick non nodes jobs. +#[derive(Debug)] +pub struct LightWitnessVectorMetadataLoader { + pod_name: String, + protocol_version: ProtocolSemanticVersion, +} + +impl LightWitnessVectorMetadataLoader { + pub fn new(pod_name: String, protocol_version: ProtocolSemanticVersion) -> Self { + Self { + pod_name, + protocol_version, + } + } +} + +#[async_trait] +impl WitnessVectorMetadataLoader for LightWitnessVectorMetadataLoader { + async fn load_metadata( + &self, + mut connection: Connection<'_, Prover>, + ) -> Option { + connection + .fri_prover_jobs_dal() + .get_light_job(self.protocol_version, &self.pod_name) + .await + } +} + +/// Heavy job MetadataLoader. +/// +/// Most jobs are light, apart from nodes. This loader will only prioritize node jobs. +/// If none are available, it will fall back to light jobs. +#[derive(Debug)] +pub struct HeavyWitnessVectorMetadataLoader { + pod_name: String, + protocol_version: ProtocolSemanticVersion, +} + +impl HeavyWitnessVectorMetadataLoader { + pub fn new(pod_name: String, protocol_version: ProtocolSemanticVersion) -> Self { + Self { + pod_name, + protocol_version, + } + } +} + +#[async_trait] +impl WitnessVectorMetadataLoader for HeavyWitnessVectorMetadataLoader { + async fn load_metadata( + &self, + mut connection: Connection<'_, Prover>, + ) -> Option { + let metadata = connection + .fri_prover_jobs_dal() + .get_heavy_job(self.protocol_version, &self.pod_name) + .await; + if metadata.is_some() { + return metadata; + } + connection + .fri_prover_jobs_dal() + .get_light_job(self.protocol_version, &self.pod_name) + .await + } +} diff --git a/prover/crates/lib/prover_dal/.sqlx/query-3b3193bfac70b5fe69bf3bb7ba5a234c19578572973094b21ddbb3876da6bb95.json b/prover/crates/lib/prover_dal/.sqlx/query-4d89c375af2c211a8a896cad7c99d2c9ff0d28f4662913ef7c2cf6fa1aa430d4.json similarity index 65% rename from prover/crates/lib/prover_dal/.sqlx/query-3b3193bfac70b5fe69bf3bb7ba5a234c19578572973094b21ddbb3876da6bb95.json rename to prover/crates/lib/prover_dal/.sqlx/query-4d89c375af2c211a8a896cad7c99d2c9ff0d28f4662913ef7c2cf6fa1aa430d4.json index 962979344b4b..f84489dd6523 100644 --- a/prover/crates/lib/prover_dal/.sqlx/query-3b3193bfac70b5fe69bf3bb7ba5a234c19578572973094b21ddbb3876da6bb95.json +++ b/prover/crates/lib/prover_dal/.sqlx/query-4d89c375af2c211a8a896cad7c99d2c9ff0d28f4662913ef7c2cf6fa1aa430d4.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE prover_jobs_fri\n SET\n status = 'in_progress',\n attempts = attempts + 1,\n updated_at = NOW(),\n processing_started_at = NOW(),\n picked_by = $3\n WHERE\n id = (\n SELECT\n id\n FROM\n prover_jobs_fri\n WHERE\n status = 'queued'\n AND protocol_version = $1\n AND protocol_version_patch = $2\n ORDER BY\n l1_batch_number ASC,\n aggregation_round ASC,\n circuit_id ASC,\n id ASC\n LIMIT\n 1\n FOR UPDATE\n SKIP LOCKED\n )\n RETURNING\n prover_jobs_fri.id,\n prover_jobs_fri.l1_batch_number,\n prover_jobs_fri.circuit_id,\n prover_jobs_fri.aggregation_round,\n prover_jobs_fri.sequence_number,\n prover_jobs_fri.depth,\n prover_jobs_fri.is_node_final_proof\n ", + "query": "\n UPDATE prover_jobs_fri\n SET\n status = 'in_progress',\n attempts = attempts + 1,\n updated_at = NOW(),\n processing_started_at = NOW(),\n picked_by = $3\n WHERE\n id = (\n SELECT\n id\n FROM\n prover_jobs_fri\n WHERE\n status = 'queued'\n AND protocol_version = $1\n AND protocol_version_patch = $2\n AND aggregation_round = $4\n ORDER BY\n l1_batch_number ASC,\n circuit_id ASC,\n id ASC\n LIMIT\n 1\n FOR UPDATE\n SKIP LOCKED\n )\n RETURNING\n prover_jobs_fri.id,\n prover_jobs_fri.l1_batch_number,\n prover_jobs_fri.circuit_id,\n prover_jobs_fri.aggregation_round,\n prover_jobs_fri.sequence_number,\n prover_jobs_fri.depth,\n prover_jobs_fri.is_node_final_proof\n ", "describe": { "columns": [ { @@ -43,7 +43,8 @@ "Left": [ "Int4", "Int4", - "Text" + "Text", + "Int2" ] }, "nullable": [ @@ -56,5 +57,5 @@ false ] }, - "hash": "3b3193bfac70b5fe69bf3bb7ba5a234c19578572973094b21ddbb3876da6bb95" + "hash": "4d89c375af2c211a8a896cad7c99d2c9ff0d28f4662913ef7c2cf6fa1aa430d4" } diff --git a/prover/crates/lib/prover_dal/.sqlx/query-79b5ad4ef1ba888c3ffdb27cf2203367ae4cf57703c532fe3dfe18924c3c9492.json b/prover/crates/lib/prover_dal/.sqlx/query-79b5ad4ef1ba888c3ffdb27cf2203367ae4cf57703c532fe3dfe18924c3c9492.json new file mode 100644 index 000000000000..d1db20fbdbea --- /dev/null +++ b/prover/crates/lib/prover_dal/.sqlx/query-79b5ad4ef1ba888c3ffdb27cf2203367ae4cf57703c532fe3dfe18924c3c9492.json @@ -0,0 +1,61 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE prover_jobs_fri\n SET\n status = 'in_progress',\n attempts = attempts + 1,\n updated_at = NOW(),\n processing_started_at = NOW(),\n picked_by = $3\n WHERE\n id = (\n SELECT\n id\n FROM\n prover_jobs_fri\n WHERE\n status = 'queued'\n AND protocol_version = $1\n AND protocol_version_patch = $2\n AND aggregation_round != $4\n ORDER BY\n l1_batch_number ASC,\n aggregation_round ASC,\n circuit_id ASC,\n id ASC\n LIMIT\n 1\n FOR UPDATE\n SKIP LOCKED\n )\n RETURNING\n prover_jobs_fri.id,\n prover_jobs_fri.l1_batch_number,\n prover_jobs_fri.circuit_id,\n prover_jobs_fri.aggregation_round,\n prover_jobs_fri.sequence_number,\n prover_jobs_fri.depth,\n prover_jobs_fri.is_node_final_proof\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "l1_batch_number", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "circuit_id", + "type_info": "Int2" + }, + { + "ordinal": 3, + "name": "aggregation_round", + "type_info": "Int2" + }, + { + "ordinal": 4, + "name": "sequence_number", + "type_info": "Int4" + }, + { + "ordinal": 5, + "name": "depth", + "type_info": "Int4" + }, + { + "ordinal": 6, + "name": "is_node_final_proof", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [ + "Int4", + "Int4", + "Text", + "Int2" + ] + }, + "nullable": [ + false, + false, + false, + false, + false, + false, + false + ] + }, + "hash": "79b5ad4ef1ba888c3ffdb27cf2203367ae4cf57703c532fe3dfe18924c3c9492" +} diff --git a/prover/crates/lib/prover_dal/src/fri_prover_dal.rs b/prover/crates/lib/prover_dal/src/fri_prover_dal.rs index a0420b056125..8efa8e2f6837 100644 --- a/prover/crates/lib/prover_dal/src/fri_prover_dal.rs +++ b/prover/crates/lib/prover_dal/src/fri_prover_dal.rs @@ -1,5 +1,10 @@ #![doc = include_str!("../doc/FriProverDal.md")] -use std::{collections::HashMap, convert::TryFrom, str::FromStr, time::Duration}; +use std::{ + collections::HashMap, + convert::TryFrom, + str::FromStr, + time::{Duration, Instant}, +}; use zksync_basic_types::{ basic_fri_types::{ @@ -60,8 +65,11 @@ impl FriProverDal<'_, '_> { /// - within the lowest batch, look at the lowest aggregation level (move up the proof tree) /// - pick the same type of circuit for as long as possible, this maximizes GPU cache reuse /// - /// NOTE: Most of this function is a duplicate of `get_next_job()`. Get next job will be deleted together with old prover. - pub async fn get_job( + /// Most of this function is similar to `get_light_job()`. + /// The 2 differ in the type of jobs they will load. Node jobs are heavy in resource utilization. + /// + /// NOTE: This function retrieves only node jobs. + pub async fn get_heavy_job( &mut self, protocol_version: ProtocolSemanticVersion, picked_by: &str, @@ -85,6 +93,84 @@ impl FriProverDal<'_, '_> { status = 'queued' AND protocol_version = $1 AND protocol_version_patch = $2 + AND aggregation_round = $4 + ORDER BY + l1_batch_number ASC, + circuit_id ASC, + id ASC + LIMIT + 1 + FOR UPDATE + SKIP LOCKED + ) + RETURNING + prover_jobs_fri.id, + prover_jobs_fri.l1_batch_number, + prover_jobs_fri.circuit_id, + prover_jobs_fri.aggregation_round, + prover_jobs_fri.sequence_number, + prover_jobs_fri.depth, + prover_jobs_fri.is_node_final_proof + "#, + protocol_version.minor as i32, + protocol_version.patch.0 as i32, + picked_by, + AggregationRound::NodeAggregation as i64, + ) + .fetch_optional(self.storage.conn()) + .await + .expect("failed to get prover job") + .map(|row| FriProverJobMetadata { + id: row.id as u32, + block_number: L1BatchNumber(row.l1_batch_number as u32), + circuit_id: row.circuit_id as u8, + aggregation_round: AggregationRound::try_from(i32::from(row.aggregation_round)) + .unwrap(), + sequence_number: row.sequence_number as usize, + depth: row.depth as u16, + is_node_final_proof: row.is_node_final_proof, + pick_time: Instant::now(), + }) + } + + /// Retrieves the next prover job to be proven. Called by WVGs. + /// + /// Prover jobs must be thought of as ordered. + /// Prover must prioritize proving such jobs that will make the chain move forward the fastest. + /// Current ordering: + /// - pick the lowest batch + /// - within the lowest batch, look at the lowest aggregation level (move up the proof tree) + /// - pick the same type of circuit for as long as possible, this maximizes GPU cache reuse + /// + /// Most of this function is similar to `get_heavy_job()`. + /// The 2 differ in the type of jobs they will load. Node jobs are heavy in resource utilization. + /// + /// NOTE: This function retrieves all jobs but nodes. + pub async fn get_light_job( + &mut self, + protocol_version: ProtocolSemanticVersion, + picked_by: &str, + ) -> Option { + sqlx::query!( + r#" + UPDATE prover_jobs_fri + SET + status = 'in_progress', + attempts = attempts + 1, + updated_at = NOW(), + processing_started_at = NOW(), + picked_by = $3 + WHERE + id = ( + SELECT + id + FROM + prover_jobs_fri + WHERE + status = 'queued' + AND protocol_version = $1 + AND protocol_version_patch = $2 + AND aggregation_round != $4 ORDER BY l1_batch_number ASC, aggregation_round ASC, @@ -107,6 +193,7 @@ impl FriProverDal<'_, '_> { protocol_version.minor as i32, protocol_version.patch.0 as i32, picked_by, + AggregationRound::NodeAggregation as i64 ) .fetch_optional(self.storage.conn()) .await @@ -120,6 +207,7 @@ impl FriProverDal<'_, '_> { sequence_number: row.sequence_number as usize, depth: row.depth as u16, is_node_final_proof: row.is_node_final_proof, + pick_time: Instant::now(), }) } @@ -181,9 +269,9 @@ impl FriProverDal<'_, '_> { sequence_number: row.sequence_number as usize, depth: row.depth as u16, is_node_final_proof: row.is_node_final_proof, + pick_time: Instant::now(), }) } - pub async fn get_next_job_for_circuit_id_round( &mut self, circuits_to_pick: &[CircuitIdRoundTuple], @@ -271,6 +359,7 @@ impl FriProverDal<'_, '_> { sequence_number: row.sequence_number as usize, depth: row.depth as u16, is_node_final_proof: row.is_node_final_proof, + pick_time: Instant::now(), }) } @@ -359,6 +448,7 @@ impl FriProverDal<'_, '_> { sequence_number: row.sequence_number as usize, depth: row.depth as u16, is_node_final_proof: row.is_node_final_proof, + pick_time: Instant::now(), }) .unwrap() } diff --git a/prover/crates/lib/prover_job_processor/Cargo.toml b/prover/crates/lib/prover_job_processor/Cargo.toml new file mode 100644 index 000000000000..5197b33b1f95 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "zksync_prover_job_processor" +description = "ZKsync Prover Job Processor" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +tokio.workspace = true +tokio-stream.workspace = true +tokio-util.workspace = true +tracing.workspace = true +vise.workspace = true +strum.workspace = true diff --git a/prover/crates/lib/prover_job_processor/README.md b/prover/crates/lib/prover_job_processor/README.md new file mode 100644 index 000000000000..5eea5476d05d --- /dev/null +++ b/prover/crates/lib/prover_job_processor/README.md @@ -0,0 +1,152 @@ +# Prover Job Processor + +Prover Job Processor aims to be a small "framework" that allows building prover components at break-neck speeds. + +## Context + +Previously, prover components were hand tailored and had similar issues spread across the codebase. The "framework"'s +purpose is to standardize implementations and lift the undifferentiated work from prover component developers. + +## How it works + +The "framework" exports 4 main primitives: + +- executor +- job_picker +- job_saver +- job_runner + +### Executor + +This is the most important trait. It is meant to execute the crypto primitives (or any other payloads) and defines what +the inputs are, what is the metadata that has to travel with it and what the output will be. Executors will receive +information from Job Picker and will provide it further to Job Saver. + +For example, this could witness vector generator (synthesis of witness vector) or circuit prover (GPU circuit proving & +verification). Each would define what they need as input to operate and what they'll output. + +### Job Picker + +The starting point of the process. This trait is tied to Executor and will pick a metadata & input that corresponds to +the Executor. Job Picker picks information and provides it to Executor. + +As examples, for witness vector generator it would be a query to the database & a query to object storage. For circuit +prover, it would be waiting on the communication channel between witness vector generator and circuit prover. + +### Job Saver + +The final point of the process. This trait is tied to Executor and will receive metadata & output that corresponds to +the Executor. Job Saver receives information from Executor and saves it. + +Continuing with the same examples, for witness vector generator it would send the information to the communication +channel between witness vector generator & circuit prover. For circuit prover example, it would simply store the +information to database & object store. + +### Job Runner + +A wrapper over all 3 traits above, ensuring they communicate to each other as expected & they are spawned as +long-running threads. + +## Diagram + +```mermaid +sequenceDiagram + participant p as Job Picker + participant e as Executor + participant s as Job Saver + + p-->>p: Get metadata & input + p-->>e: Provide metadata & input + e-->>e: Execute + e-->>s: Provide metadata & output + s-->>s: Save output +``` + +## How to use it + +If you want to create a new prover component, you'd need to first define what are the communication boundaries: + +- metadata +- input +- output + +With these out of the way, you can specify the Executor and even integrate the crypto primitive. At this point in time +you could fully cover it with unit tests to make sure the functionality works as intended. + +Moving forward, you'll need to understand where you get this information and where you store it. These are your Job +Picker & Job saver. NOTE: Just like the executor, you need to implement the logic of executing/picking/saving a single +job, the "framework" will take care of looping it over and transmitting the details from one end to another. + +Once done, provide them as arguments to JobRunner, call `your_job_runner.run()` and you're good to go. + +TODO: Add example once testing is in place. + +## More (internal) details + +There are a few things that we've glossed over, let's get into details: + +### Back-offs & cancelling + +As you might've guessed, from a production point of view, you need to make sure that the process can die gracefully (k8s +sigterm), without being a nuisance to your dependencies (think DB or object store). As such, job picker can have an +optional component responsible for back-off & cancelling. + +### How do components communicate + +Internally, `JobRunner` wraps all 3 primitives into a task that are looping in a `while channel.recv() {}`. Each task is +slightly special, but the logic is far from complex. + +### Limitations + +Back off & cancelling is implemented only for job picker. Whilst it might sound inconvenient, in practice it works +great. When the cancel is received, the job picker will stop picking jobs, the executor will keep executing until there +are no more jobs in the receiver and the saver will save all jobs until there are no more jobs received from executor. + +Backoff is currently hardcoded, but it is trivial to make it more configurable. + +Whilst not a limitation, the first version is applied only to `circuit_provers`. It's very likely that more enhancements +will be needed to accommodate the rest of the codebase. Treat this as work in progress. + +## Objectives + +The "framework" wants to achieve the following: + +1. Reduce code complexity & technical debt (modularize the codebase) +2. Empower testability of the prover codebase +3. Optimize prover components for speed and multi-datacenter/multi-cloud setups +4. Increase speed of delivery of prover components +5. Enable external shops to implement their own flavors of prover components + +### 1. Reduce code complexity & technical debt (modularize the codebase) + +Previously, most prover components were custom written. This meant that the same logic was reimplemented across multiple +components. Whilst the "framework" doesn't fully solve the problem, it drastically reduces the amount of code needed to +start a new components. + +The rest of the code duplication can be tackled in the future as part of the node framework. + +### 2. Empower testability of the prover codebase + +Due to the entangled nature of the code, prover codebase was difficult to test. Current modular setup enables testing in +isolation each component. (not exactly true, given cryptography dependencies are too heavy - but will be true in the new +prover implementation) + +### 3. Optimize prover components for speed and multi-datacenter/multi-cloud setups + +Previously, provers were running "sync". Load job, once loaded, execute it, once executed, save its result. Whilst this +is fine, all steps can be done in parallel. This becomes super important when database and running machine are far away +and the round trip to database can cause up to 50% of the entire time. In a multi-cloud (read as future) setup, this +becomes even more painful. For free, we remove the current bottleneck from database (which was previous bottleneck, due +to # of connections). + +### 4. Increase speed of delivery of prover components + +Boojum release was rather slow and even releasing the current `circuit_prover` took longer than anticipated. Given +upcoming prover updates, this release sets us for success going forward. Furthermore, experimenting with different +setups becomes a matter of days, rather than months. + +### 5. Enable external shops to implement their own flavors of prover components + +Most external folks have to fork zksync-era and keep an up-to-date fork if anything needs to be modified. The framework +allows using the executors, whilst defining custom pickers/savers. This will be a massive time-save for any external +shop that wants to innovate on top of zksync-era's provers. diff --git a/prover/crates/bin/circuit_prover/src/backoff.rs b/prover/crates/lib/prover_job_processor/src/backoff_and_cancellable.rs similarity index 60% rename from prover/crates/bin/circuit_prover/src/backoff.rs rename to prover/crates/lib/prover_job_processor/src/backoff_and_cancellable.rs index 6ddb3d94be35..15d80404dc71 100644 --- a/prover/crates/bin/circuit_prover/src/backoff.rs +++ b/prover/crates/lib/prover_job_processor/src/backoff_and_cancellable.rs @@ -1,5 +1,24 @@ use std::{ops::Mul, time::Duration}; +use tokio_util::sync::CancellationToken; + +/// Utility struct that provides cancellation awareness & backoff capabilities. +/// They usually go hand in hand, having a wrapper over both simplifies implementation. +#[derive(Debug, Clone)] +pub struct BackoffAndCancellable { + pub(crate) backoff: Backoff, + pub(crate) cancellation_token: CancellationToken, +} + +impl BackoffAndCancellable { + pub fn new(backoff: Backoff, cancellation_token: CancellationToken) -> Self { + Self { + backoff, + cancellation_token, + } + } +} + /// Backoff - convenience structure that takes care of backoff timings. #[derive(Debug, Clone)] pub struct Backoff { @@ -7,12 +26,10 @@ pub struct Backoff { current_delay: Duration, max_delay: Duration, } - impl Backoff { /// The delay multiplication coefficient. // Currently it's hardcoded, but could be provided in the constructor. const DELAY_MULTIPLIER: u32 = 2; - /// Create a backoff with base_delay (first delay) and max_delay (maximum delay possible). pub fn new(base_delay: Duration, max_delay: Duration) -> Self { Backoff { @@ -37,3 +54,10 @@ impl Backoff { self.current_delay = self.base_delay; } } + +impl Default for Backoff { + /// Sensible database specific delays. + fn default() -> Self { + Self::new(Duration::from_secs(1), Duration::from_secs(5)) + } +} diff --git a/prover/crates/lib/prover_job_processor/src/executor.rs b/prover/crates/lib/prover_job_processor/src/executor.rs new file mode 100644 index 000000000000..80b019960e3e --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/executor.rs @@ -0,0 +1,11 @@ +/// Executor trait, responsible for defining what a job's execution will look like. +/// The trait covers what it expects as input, what it'll offer as output and what metadata needs to travel together with the input. +/// This is the backbone of the `prover_job_processor` from a user's point of view. +pub trait Executor: Send + Sync + 'static { + type Input: Send; + type Output: Send; + type Metadata: Send + Clone; + + fn execute(&self, input: Self::Input, metadata: Self::Metadata) + -> anyhow::Result; +} diff --git a/prover/crates/lib/prover_job_processor/src/job_picker.rs b/prover/crates/lib/prover_job_processor/src/job_picker.rs new file mode 100644 index 000000000000..74ecbcde5d74 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/job_picker.rs @@ -0,0 +1,18 @@ +use async_trait::async_trait; + +use crate::Executor; + +/// Job Picker trait, in charge of getting a new job for executor. +/// NOTE: Job Pickers are tied to an executor, which ensures input/output/metadata types match. +#[async_trait] +pub trait JobPicker: Send + Sync + 'static { + type ExecutorType: Executor; + async fn pick_job( + &mut self, + ) -> anyhow::Result< + Option<( + ::Input, + ::Metadata, + )>, + >; +} diff --git a/prover/crates/lib/prover_job_processor/src/job_runner.rs b/prover/crates/lib/prover_job_processor/src/job_runner.rs new file mode 100644 index 000000000000..2a2d803e206d --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/job_runner.rs @@ -0,0 +1,69 @@ +use tokio::task::JoinHandle; + +use crate::{ + task_wiring::{JobPickerTask, JobSaverTask, Task, WorkerPool}, + BackoffAndCancellable, Executor, JobPicker, JobSaver, +}; + +/// It's preferred to have a minimal amount of jobs in flight at any given time. +/// This ensures that memory usage is minimized, in case of failures, a small amount of jobs is lost and +/// components can apply back pressure to each other in case of misconfiguration. +const CHANNEL_SIZE: usize = 1; + +/// The "framework" wrapper that runs the entire machinery. +/// Job Runner is responsible for tying together tasks (picker, executor, saver) and starting them. +#[derive(Debug)] +pub struct JobRunner +where + E: Executor, + P: JobPicker, + S: JobSaver, +{ + executor: E, + picker: P, + saver: S, + num_workers: usize, + picker_backoff_and_cancellable: Option, +} + +impl JobRunner +where + E: Executor, + P: JobPicker, + S: JobSaver, +{ + pub fn new( + executor: E, + picker: P, + saver: S, + num_workers: usize, + picker_backoff_and_cancellable: Option, + ) -> Self { + Self { + executor, + picker, + saver, + num_workers, + picker_backoff_and_cancellable, + } + } + + /// Runs job runner tasks. + pub fn run(self) -> Vec>> { + let (input_tx, input_rx) = + tokio::sync::mpsc::channel::<(E::Input, E::Metadata)>(CHANNEL_SIZE); + let (result_tx, result_rx) = + tokio::sync::mpsc::channel::<(anyhow::Result, E::Metadata)>(CHANNEL_SIZE); + + let picker_task = + JobPickerTask::new(self.picker, input_tx, self.picker_backoff_and_cancellable); + let worker_pool = WorkerPool::new(self.executor, self.num_workers, input_rx, result_tx); + let saver_task = JobSaverTask::new(self.saver, result_rx); + + vec![ + tokio::spawn(picker_task.run()), + tokio::spawn(worker_pool.run()), + tokio::spawn(saver_task.run()), + ] + } +} diff --git a/prover/crates/lib/prover_job_processor/src/job_saver.rs b/prover/crates/lib/prover_job_processor/src/job_saver.rs new file mode 100644 index 000000000000..4c0833dd77a4 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/job_saver.rs @@ -0,0 +1,19 @@ +use async_trait::async_trait; + +use crate::Executor; + +/// Job Saver trait, in charge of getting the result from the executor and dispatching it. +/// Dispatch could be storing it, or sending to a separate component. +/// NOTE: Job Savers are tied to an executor, which ensures input/output/metadata types match. +#[async_trait] +pub trait JobSaver: Send + Sync + 'static { + type ExecutorType: Executor; + + async fn save_job_result( + &self, + data: ( + anyhow::Result<::Output>, + ::Metadata, + ), + ) -> anyhow::Result<()>; +} diff --git a/prover/crates/lib/prover_job_processor/src/lib.rs b/prover/crates/lib/prover_job_processor/src/lib.rs new file mode 100644 index 000000000000..02847be533ff --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/lib.rs @@ -0,0 +1,19 @@ +pub use backoff_and_cancellable::{Backoff, BackoffAndCancellable}; +pub use executor::Executor; +pub use job_picker::JobPicker; +pub use job_runner::JobRunner; +pub use job_saver::JobSaver; + +mod backoff_and_cancellable; +mod executor; +mod job_picker; +mod job_runner; +mod job_saver; +mod task_wiring; + +// convenience aliases to simplify declarations +type Input

= <

::ExecutorType as Executor>::Input; +type PickerMetadata

= <

::ExecutorType as Executor>::Metadata; + +type Output = <::ExecutorType as Executor>::Output; +type SaverMetadata = <::ExecutorType as Executor>::Metadata; diff --git a/prover/crates/lib/prover_job_processor/src/task_wiring/job_picker_task.rs b/prover/crates/lib/prover_job_processor/src/task_wiring/job_picker_task.rs new file mode 100644 index 000000000000..f3e5e3ea4686 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/task_wiring/job_picker_task.rs @@ -0,0 +1,77 @@ +use anyhow::Context; +use async_trait::async_trait; + +use crate::{task_wiring::task::Task, BackoffAndCancellable, Input, JobPicker, PickerMetadata}; + +/// Wrapper over JobPicker. Makes it a continuous task, picking tasks until cancelled. +#[derive(Debug)] +pub struct JobPickerTask { + picker: P, + input_tx: tokio::sync::mpsc::Sender<(Input

, PickerMetadata

)>, + backoff_and_cancellable: Option, +} + +impl JobPickerTask

{ + pub fn new( + picker: P, + input_tx: tokio::sync::mpsc::Sender<(Input

, PickerMetadata

)>, + backoff_and_cancellable: Option, + ) -> Self { + Self { + picker, + input_tx, + backoff_and_cancellable, + } + } + + /// Backs off for the specified amount of time or until cancel is received, if available. + async fn backoff(&mut self) { + if let Some(backoff_and_cancellable) = &mut self.backoff_and_cancellable { + let backoff_duration = backoff_and_cancellable.backoff.delay(); + tracing::info!("Backing off for {:?}...", backoff_duration); + // Error here corresponds to a timeout w/o receiving task_wiring cancel; we're OK with this. + tokio::time::timeout( + backoff_duration, + backoff_and_cancellable.cancellation_token.cancelled(), + ) + .await + .ok(); + } + } + + /// Resets backoff to initial state, if available. + fn reset_backoff(&mut self) { + if let Some(backoff_and_cancellable) = &mut self.backoff_and_cancellable { + backoff_and_cancellable.backoff.reset(); + } + } + + /// Checks if the task is cancelled, if available. + fn is_cancelled(&self) -> bool { + if let Some(backoff_and_cancellable) = &self.backoff_and_cancellable { + return backoff_and_cancellable.cancellation_token.is_cancelled(); + } + false + } +} + +#[async_trait] +impl Task for JobPickerTask

{ + async fn run(mut self) -> anyhow::Result<()> { + while !self.is_cancelled() { + match self.picker.pick_job().await.context("failed to pick job")? { + Some((input, metadata)) => { + self.input_tx.send((input, metadata)).await.map_err(|err| { + anyhow::anyhow!("job picker failed to pass job to executor: {}", err) + })?; + self.reset_backoff(); + } + None => { + self.backoff().await; + } + } + } + tracing::info!("Stop signal received, shutting down JobPickerTask..."); + Ok(()) + } +} diff --git a/prover/crates/lib/prover_job_processor/src/task_wiring/job_saver_task.rs b/prover/crates/lib/prover_job_processor/src/task_wiring/job_saver_task.rs new file mode 100644 index 000000000000..8573821bc902 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/task_wiring/job_saver_task.rs @@ -0,0 +1,33 @@ +use anyhow::Context; +use async_trait::async_trait; + +use crate::{task_wiring::task::Task, JobSaver, Output, SaverMetadata}; + +/// Wrapper over JobSaver. Makes it a continuous task, picking tasks until execution channel is closed. +#[derive(Debug)] +pub struct JobSaverTask { + saver: S, + result_rx: tokio::sync::mpsc::Receiver<(anyhow::Result>, SaverMetadata)>, +} + +impl JobSaverTask { + pub fn new( + saver: S, + result_rx: tokio::sync::mpsc::Receiver<(anyhow::Result>, SaverMetadata)>, + ) -> Self { + Self { saver, result_rx } + } +} + +#[async_trait] +impl Task for JobSaverTask { + async fn run(mut self) -> anyhow::Result<()> { + while let Some(data) = self.result_rx.recv().await { + self.saver + .save_job_result(data) + .await + .context("failed to save result")?; + } + Ok(()) + } +} diff --git a/prover/crates/lib/prover_job_processor/src/task_wiring/mod.rs b/prover/crates/lib/prover_job_processor/src/task_wiring/mod.rs new file mode 100644 index 000000000000..4b1ded605f50 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/task_wiring/mod.rs @@ -0,0 +1,9 @@ +pub use job_picker_task::JobPickerTask; +pub use job_saver_task::JobSaverTask; +pub use task::Task; +pub use worker_pool::WorkerPool; + +mod job_picker_task; +mod job_saver_task; +mod task; +mod worker_pool; diff --git a/prover/crates/lib/prover_job_processor/src/task_wiring/task.rs b/prover/crates/lib/prover_job_processor/src/task_wiring/task.rs new file mode 100644 index 000000000000..68f8156b67c1 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/task_wiring/task.rs @@ -0,0 +1,7 @@ +use async_trait::async_trait; + +/// Convenience trait to tie together all task wrappers. +#[async_trait] +pub trait Task { + async fn run(mut self) -> anyhow::Result<()>; +} diff --git a/prover/crates/lib/prover_job_processor/src/task_wiring/worker_pool.rs b/prover/crates/lib/prover_job_processor/src/task_wiring/worker_pool.rs new file mode 100644 index 000000000000..2f788ae99746 --- /dev/null +++ b/prover/crates/lib/prover_job_processor/src/task_wiring/worker_pool.rs @@ -0,0 +1,64 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use futures::stream::StreamExt; +use tokio_stream::wrappers::ReceiverStream; + +use crate::{executor::Executor, task_wiring::Task}; + +/// Wrapper over Executor. Makes it a continuous task, picking tasks until picker channel is closed. +/// It can execute multiple concurrent executors, up to specified limit. +#[derive(Debug)] +pub struct WorkerPool +where + E: Executor, +{ + executor: E, + num_workers: usize, + input_rx: tokio::sync::mpsc::Receiver<(E::Input, E::Metadata)>, + result_tx: tokio::sync::mpsc::Sender<(anyhow::Result, E::Metadata)>, +} + +impl WorkerPool { + pub fn new( + executor: E, + num_workers: usize, + input_rx: tokio::sync::mpsc::Receiver<(E::Input, E::Metadata)>, + result_tx: tokio::sync::mpsc::Sender<(anyhow::Result, E::Metadata)>, + ) -> Self { + Self { + executor, + num_workers, + input_rx, + result_tx, + } + } +} + +#[async_trait] +impl Task for WorkerPool { + async fn run(mut self) -> anyhow::Result<()> { + let executor = Arc::new(self.executor); + let num_workers = self.num_workers; + let stream = ReceiverStream::new(self.input_rx); + + stream + .for_each_concurrent(num_workers, move |(input, metadata)| { + let executor = executor.clone(); + let result_tx = self.result_tx.clone(); + let exec_metadata = metadata.clone(); + async move { + let payload = + tokio::task::spawn_blocking(move || executor.execute(input, exec_metadata)) + .await + .expect("failed executing"); + result_tx + .send((payload, metadata)) + .await + .expect("job saver channel has been closed unexpectedly"); + } + }) + .await; + Ok(()) + } +} diff --git a/zkstack_cli/crates/zkstack/completion/_zkstack.zsh b/zkstack_cli/crates/zkstack/completion/_zkstack.zsh index f0e10b465b6a..23d7ff2802c7 100644 --- a/zkstack_cli/crates/zkstack/completion/_zkstack.zsh +++ b/zkstack_cli/crates/zkstack/completion/_zkstack.zsh @@ -1922,7 +1922,11 @@ _arguments "${_arguments_options[@]}" : \ '--round=[]:ROUND:(all-rounds basic-circuits leaf-aggregation node-aggregation recursion-tip scheduler)' \ '--threads=[]:THREADS:_default' \ '--max-allocation=[Memory allocation limit in bytes (for prover component)]:MAX_ALLOCATION:_default' \ -'--witness-vector-generator-count=[]:WITNESS_VECTOR_GENERATOR_COUNT:_default' \ +'-l+[]:LIGHT_WVG_COUNT:_default' \ +'--light-wvg-count=[]:LIGHT_WVG_COUNT:_default' \ +'-h+[]:HEAVY_WVG_COUNT:_default' \ +'--heavy-wvg-count=[]:HEAVY_WVG_COUNT:_default' \ +'-m+[]:MAX_ALLOCATION:_default' \ '--max-allocation=[]:MAX_ALLOCATION:_default' \ '--docker=[]:DOCKER:(true false)' \ '--tag=[]:TAG:_default' \ diff --git a/zkstack_cli/crates/zkstack/completion/zkstack.fish b/zkstack_cli/crates/zkstack/completion/zkstack.fish index dacc27d88089..ef3e689e4292 100644 --- a/zkstack_cli/crates/zkstack/completion/zkstack.fish +++ b/zkstack_cli/crates/zkstack/completion/zkstack.fish @@ -500,8 +500,9 @@ complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_ complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l round -r -f -a "{all-rounds\t'',basic-circuits\t'',leaf-aggregation\t'',node-aggregation\t'',recursion-tip\t'',scheduler\t''}" complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l threads -r complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l max-allocation -d 'Memory allocation limit in bytes (for prover component)' -r -complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l witness-vector-generator-count -r -complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l max-allocation -r +complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -s l -l light-wvg-count -r +complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -s h -l heavy-wvg-count -r +complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -s m -l max-allocation -r complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l docker -r -f -a "{true\t'',false\t''}" complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l tag -r complete -c zkstack -n "__fish_zkstack_using_subcommand prover; and __fish_seen_subcommand_from run" -l chain -d 'Chain to use' -r diff --git a/zkstack_cli/crates/zkstack/completion/zkstack.sh b/zkstack_cli/crates/zkstack/completion/zkstack.sh index 0cf89ed4ef3f..125e080f6761 100644 --- a/zkstack_cli/crates/zkstack/completion/zkstack.sh +++ b/zkstack_cli/crates/zkstack/completion/zkstack.sh @@ -7338,7 +7338,7 @@ _zkstack() { return 0 ;; zkstack__prover__run) - opts="-v -h --component --round --threads --max-allocation --witness-vector-generator-count --max-allocation --docker --tag --verbose --chain --ignore-prerequisites --help" + opts="-l -h -m -v -h --component --round --threads --max-allocation --light-wvg-count --heavy-wvg-count --max-allocation --docker --tag --verbose --chain --ignore-prerequisites --help" if [[ ${cur} == -* || ${COMP_CWORD} -eq 3 ]] ; then COMPREPLY=( $(compgen -W "${opts}" -- "${cur}") ) return 0 @@ -7360,7 +7360,19 @@ _zkstack() { COMPREPLY=($(compgen -f "${cur}")) return 0 ;; - --witness-vector-generator-count) + --light-wvg-count) + COMPREPLY=($(compgen -f "${cur}")) + return 0 + ;; + -l) + COMPREPLY=($(compgen -f "${cur}")) + return 0 + ;; + --heavy-wvg-count) + COMPREPLY=($(compgen -f "${cur}")) + return 0 + ;; + -h) COMPREPLY=($(compgen -f "${cur}")) return 0 ;; @@ -7368,6 +7380,10 @@ _zkstack() { COMPREPLY=($(compgen -f "${cur}")) return 0 ;; + -m) + COMPREPLY=($(compgen -f "${cur}")) + return 0 + ;; --docker) COMPREPLY=($(compgen -W "true false" -- "${cur}")) return 0 diff --git a/zkstack_cli/crates/zkstack/src/commands/prover/args/run.rs b/zkstack_cli/crates/zkstack/src/commands/prover/args/run.rs index b79af777673c..4b3a16a38fca 100644 --- a/zkstack_cli/crates/zkstack/src/commands/prover/args/run.rs +++ b/zkstack_cli/crates/zkstack/src/commands/prover/args/run.rs @@ -176,16 +176,16 @@ impl ProverComponent { args.fri_prover_args.max_allocation.unwrap() )); }; - if args - .circuit_prover_args - .witness_vector_generator_count - .is_some() - { + if args.circuit_prover_args.light_wvg_count.is_some() { additional_args.push(format!( - "--witness-vector-generator-count={}", - args.circuit_prover_args - .witness_vector_generator_count - .unwrap() + "--light-wvg-count={}", + args.circuit_prover_args.light_wvg_count.unwrap() + )); + }; + if args.circuit_prover_args.heavy_wvg_count.is_some() { + additional_args.push(format!( + "--heavy-wvg-count={}", + args.circuit_prover_args.heavy_wvg_count.unwrap() )); }; } @@ -242,9 +242,11 @@ impl WitnessVectorGeneratorArgs { #[derive(Debug, Clone, Parser, Default)] pub struct CircuitProverArgs { - #[clap(long)] - pub witness_vector_generator_count: Option, - #[clap(long)] + #[clap(short = 'l', long)] + pub light_wvg_count: Option, + #[clap(short = 'h', long)] + pub heavy_wvg_count: Option, + #[clap(short = 'm', long)] pub max_allocation: Option, } @@ -257,15 +259,21 @@ impl CircuitProverArgs { return Ok(Self::default()); } - let witness_vector_generator_count = - self.witness_vector_generator_count.unwrap_or_else(|| { - Prompt::new("Number of WVG jobs to run in parallel") - .default("1") - .ask() - }); + let light_wvg_count = self.light_wvg_count.unwrap_or_else(|| { + Prompt::new("Number of light WVG jobs to run in parallel") + .default("8") + .ask() + }); + + let heavy_wvg_count = self.heavy_wvg_count.unwrap_or_else(|| { + Prompt::new("Number of heavy WVG jobs to run in parallel") + .default("2") + .ask() + }); Ok(CircuitProverArgs { - witness_vector_generator_count: Some(witness_vector_generator_count), + light_wvg_count: Some(light_wvg_count), + heavy_wvg_count: Some(heavy_wvg_count), max_allocation: self.max_allocation, }) }