Skip to content

Commit

Permalink
feat(host): impl API "/admin/pause" (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
keroro520 authored Dec 27, 2024
1 parent 5804f23 commit ddba6b0
Show file tree
Hide file tree
Showing 11 changed files with 487 additions and 7 deletions.
7 changes: 7 additions & 0 deletions host/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ pub enum HostError {
/// For task manager errors.
#[error("There was an error with the task manager: {0}")]
TaskManager(#[from] TaskManagerError),

/// For system paused state.
#[error("System is paused")]
SystemPaused,
}

impl IntoResponse for HostError {
Expand All @@ -91,6 +95,7 @@ impl IntoResponse for HostError {
HostError::Anyhow(e) => ("anyhow_error", e.to_string()),
HostError::HandleDropped => ("handle_dropped", "".to_owned()),
HostError::CapacityFull => ("capacity_full", "".to_owned()),
HostError::SystemPaused => ("system_paused", "".to_owned()),
};
let status = Status::Error {
error: error.to_owned(),
Expand Down Expand Up @@ -130,6 +135,7 @@ impl From<HostError> for TaskStatus {
HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
HostError::SystemPaused => TaskStatus::SystemPaused,
}
}
}
Expand All @@ -151,6 +157,7 @@ impl From<&HostError> for TaskStatus {
HostError::RPC(e) => TaskStatus::NetworkFailure(e.to_string()),
HostError::Guest(e) => TaskStatus::GuestProverFailure(e.to_string()),
HostError::TaskManager(e) => TaskStatus::TaskDbCorruption(e.to_string()),
HostError::SystemPaused => TaskStatus::SystemPaused,
}
}
}
39 changes: 35 additions & 4 deletions host/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::{alloc, path::PathBuf};

use anyhow::Context;
Expand Down Expand Up @@ -108,10 +110,11 @@ impl Opts {

/// Read the options from a file and merge it with the current options.
pub fn merge_from_file(&mut self) -> HostResult<()> {
let file = std::fs::File::open(&self.config_path)?;
let file = std::fs::File::open(&self.config_path).context("Failed to open config file")?;
let reader = std::io::BufReader::new(file);
let mut config: Value = serde_json::from_reader(reader)?;
let this = serde_json::to_value(&self)?;
let mut config: Value =
serde_json::from_reader(reader).context("Failed to read config file")?;
let this = serde_json::to_value(&self).context("Failed to deserialize Opts")?;
merge(&mut config, &this);

*self = serde_json::from_value(config)?;
Expand Down Expand Up @@ -150,15 +153,17 @@ pub struct ProverState {
pub opts: Opts,
pub chain_specs: SupportedChainSpecs,
pub task_channel: mpsc::Sender<Message>,
pause_flag: Arc<AtomicBool>,
}

#[derive(Debug, Serialize)]
#[derive(Debug)]
pub enum Message {
Cancel(ProofTaskDescriptor),
Task(ProofRequest),
TaskComplete(ProofRequest),
CancelAggregate(AggregationOnlyRequest),
Aggregate(AggregationOnlyRequest),
SystemPause(tokio::sync::oneshot::Sender<HostResult<()>>),
}

impl ProverState {
Expand Down Expand Up @@ -188,6 +193,7 @@ impl ProverState {
}

let (task_channel, receiver) = mpsc::channel::<Message>(opts.concurrency_limit);
let pause_flag = Arc::new(AtomicBool::new(false));

let opts_clone = opts.clone();
let chain_specs_clone = chain_specs.clone();
Expand All @@ -202,6 +208,7 @@ impl ProverState {
opts,
chain_specs,
task_channel,
pause_flag,
})
}

Expand All @@ -212,6 +219,30 @@ impl ProverState {
pub fn request_config(&self) -> ProofRequestOpt {
self.opts.proof_request_opt.clone()
}

pub fn is_paused(&self) -> bool {
self.pause_flag.load(Ordering::SeqCst)
}

/// Set the pause flag and notify the task manager to pause, then wait for the task manager to
/// finish the pause process.
///
/// Note that this function is blocking until the task manager finishes the pause process.
pub async fn set_pause(&self, paused: bool) -> HostResult<()> {
self.pause_flag.store(paused, Ordering::SeqCst);
if paused {
// Notify task manager to start pause process
let (sender, receiver) = tokio::sync::oneshot::channel();
self.task_channel
.try_send(Message::SystemPause(sender))
.context("Failed to send pause message")?;

// Wait for the pause message to be processed
let result = receiver.await.context("Failed to receive pause message")?;
return result;
}
Ok(())
}
}

#[global_allocator]
Expand Down
253 changes: 253 additions & 0 deletions host/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ impl ProofActor {
.expect("Couldn't acquire permit");
self.run_aggregate(request, permit).await;
}
Message::SystemPause(notifier) => {
let result = self.handle_system_pause().await;
let _ = notifier.send(result);
}
}
}
}
Expand Down Expand Up @@ -384,6 +388,90 @@ impl ProofActor {

Ok(())
}

async fn cancel_all_running_tasks(&mut self) -> HostResult<()> {
info!("Cancelling all running tasks");

// Clone all tasks to avoid holding locks to avoid deadlock, they will be locked by other
// internal functions.
let running_tasks = {
let running_tasks = self.running_tasks.lock().await;
(*running_tasks).clone()
};

// Cancel all running tasks, don't stop even if any task fails.
let mut final_result = Ok(());
for proof_task_descriptor in running_tasks.keys() {
match self.cancel_task(proof_task_descriptor.clone()).await {
Ok(()) => {
info!(
"Cancel task during system pause, task: {:?}",
proof_task_descriptor
);
}
Err(e) => {
error!(
"Failed to cancel task during system pause: {}, task: {:?}",
e, proof_task_descriptor
);
final_result = final_result.and(Err(e));
}
}
}
final_result
}

async fn cancel_all_aggregation_tasks(&mut self) -> HostResult<()> {
info!("Cancelling all aggregation tasks");

// Clone all tasks to avoid holding locks to avoid deadlock, they will be locked by other
// internal functions.
let aggregate_tasks = {
let aggregate_tasks = self.aggregate_tasks.lock().await;
(*aggregate_tasks).clone()
};

// Cancel all aggregation tasks, don't stop even if any task fails.
let mut final_result = Ok(());
for request in aggregate_tasks.keys() {
match self.cancel_aggregation_task(request.clone()).await {
Ok(()) => {
info!(
"Cancel aggregation task during system pause, task: {}",
request
);
}
Err(e) => {
error!(
"Failed to cancel aggregation task during system pause: {}, task: {}",
e, request
);
final_result = final_result.and(Err(e));
}
}
}
final_result
}

async fn handle_system_pause(&mut self) -> HostResult<()> {
info!("System pausing");

let mut final_result = Ok(());

self.pending_tasks.lock().await.clear();

if let Err(e) = self.cancel_all_running_tasks().await {
final_result = final_result.and(Err(e));
}

if let Err(e) = self.cancel_all_aggregation_tasks().await {
final_result = final_result.and(Err(e));
}

// TODO(Kero): make sure all tasks are saved to database, including pending tasks.

final_result
}
}

pub async fn handle_proof(
Expand Down Expand Up @@ -483,3 +571,168 @@ pub async fn handle_proof(

Ok(proof)
}

#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;

#[tokio::test]
async fn test_handle_system_pause_happy_path() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

let result = actor.handle_system_pause().await;
assert!(result.is_ok());
}

#[tokio::test]
async fn test_handle_system_pause_with_pending_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some pending tasks
actor.pending_tasks.lock().await.push_back(ProofRequest {
block_number: 1,
l1_inclusion_block_number: 1,
network: "test".to_string(),
l1_network: "test".to_string(),
graffiti: B256::ZERO,
prover: Default::default(),
proof_type: Default::default(),
blob_proof_type: Default::default(),
prover_args: HashMap::new(),
});

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify pending tasks were cleared
assert_eq!(actor.pending_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_running_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some running tasks
let task_descriptor = ProofTaskDescriptor::default();
let cancellation_token = CancellationToken::new();
actor
.running_tasks
.lock()
.await
.insert(task_descriptor.clone(), cancellation_token.clone());

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify running tasks were cancelled
assert!(cancellation_token.is_cancelled());

// TODO(Kero): Cancelled tasks should be removed from running_tasks
// assert_eq!(actor.running_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_aggregation_tasks() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some aggregation tasks
let request = AggregationOnlyRequest::default();
let cancellation_token = CancellationToken::new();
actor
.aggregate_tasks
.lock()
.await
.insert(request.clone(), cancellation_token.clone());

let result = actor.handle_system_pause().await;
assert!(result.is_ok());

// Verify aggregation tasks were cancelled
assert!(cancellation_token.is_cancelled());
// TODO(Kero): Cancelled tasks should be removed from aggregate_tasks
// assert_eq!(actor.aggregate_tasks.lock().await.len(), 0);
}

#[tokio::test]
async fn test_handle_system_pause_with_failures() {
let (tx, rx) = mpsc::channel(100);
let mut actor = setup_actor_with_tasks(tx, rx);

// Add some pending tasks
{
actor.pending_tasks.lock().await.push_back(ProofRequest {
block_number: 1,
l1_inclusion_block_number: 1,
network: "test".to_string(),
l1_network: "test".to_string(),
graffiti: B256::ZERO,
prover: Default::default(),
proof_type: Default::default(),
blob_proof_type: Default::default(),
prover_args: HashMap::new(),
});
}

let good_running_task_token = {
// Add some running tasks
let task_descriptor = ProofTaskDescriptor::default();
let cancellation_token = CancellationToken::new();
actor
.running_tasks
.lock()
.await
.insert(task_descriptor.clone(), cancellation_token.clone());
cancellation_token
};

let good_aggregation_task_token = {
// Add some aggregation tasks
let request = AggregationOnlyRequest::default();
let cancellation_token = CancellationToken::new();
actor
.aggregate_tasks
.lock()
.await
.insert(request.clone(), cancellation_token.clone());
cancellation_token
};

// Setup tasks that will fail to cancel
{
let task_descriptor_should_fail_cause_not_supported_error = ProofTaskDescriptor {
proof_system: ProofType::Risc0,
..Default::default()
};
actor.running_tasks.lock().await.insert(
task_descriptor_should_fail_cause_not_supported_error.clone(),
CancellationToken::new(),
);
}

let result = actor.handle_system_pause().await;

// Verify error contains all accumulated errors
assert!(matches!(
result,
Err(HostError::Core(RaikoError::FeatureNotSupportedError(..)))
));
assert!(good_running_task_token.is_cancelled());
assert!(good_aggregation_task_token.is_cancelled());
assert!(actor.pending_tasks.lock().await.is_empty());
}

// Helper function to setup actor with common test configuration
fn setup_actor_with_tasks(tx: Sender<Message>, rx: Receiver<Message>) -> ProofActor {
let opts = Opts {
concurrency_limit: 4,
..Default::default()
};

ProofActor::new(tx, rx, opts, SupportedChainSpecs::default())
}
}
Loading

0 comments on commit ddba6b0

Please sign in to comment.