From 2f28a8075fe116a5f5233bfc65628f89e16c3ed3 Mon Sep 17 00:00:00 2001 From: Nathaniel Cook Date: Tue, 14 Jan 2025 10:55:08 -0700 Subject: [PATCH] refactor: add simpler shutdown handling This change adds a shutdown crate that simplifies shutdown handling. The API removes the need to create as many async move blocks and clone broadcast channels. --- Cargo.lock | 13 +++++++ Cargo.toml | 2 ++ anchor-service/Cargo.toml | 3 ++ anchor-service/src/anchor_batch.rs | 34 ++++++------------ api/Cargo.toml | 1 + api/src/server.rs | 7 ++-- api/src/tests.rs | 5 +-- flight/Cargo.toml | 8 +++-- one/Cargo.toml | 1 + one/src/daemon.rs | 55 ++++++++++-------------------- one/src/lib.rs | 9 +++-- shutdown/Cargo.toml | 11 ++++++ shutdown/src/lib.rs | 39 +++++++++++++++++++++ 13 files changed, 115 insertions(+), 73 deletions(-) create mode 100644 shutdown/Cargo.toml create mode 100644 shutdown/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 1db6193c..db17f490 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2100,6 +2100,7 @@ dependencies = [ "multibase 0.9.1", "multihash-codetable", "serde", + "shutdown", "sqlx", "tokio", "tracing", @@ -2131,6 +2132,7 @@ dependencies = [ "serde", "serde_ipld_dagcbor", "serde_json", + "shutdown", "swagger", "test-log", "tikv-jemalloc-ctl", @@ -2317,11 +2319,13 @@ dependencies = [ "http 1.1.0", "mockall", "object_store", + "shutdown", "test-log", "tokio", "tokio-stream", "tonic 0.12.3", "tracing", + "tracing-subscriber", ] [[package]] @@ -2496,6 +2500,7 @@ dependencies = [ "prometheus-client", "recon", "serde_ipld_dagcbor", + "shutdown", "signal-hook", "signal-hook-tokio", "swagger", @@ -10249,6 +10254,14 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "shutdown" +version = "0.47.3" +dependencies = [ + "futures", + "tokio", +] + [[package]] name = "signal-hook" version = "0.3.17" diff --git a/Cargo.toml b/Cargo.toml index 1d5864af..ba3eb6f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ members = [ "peer-svc", "pipeline", "recon", + "shutdown", "sql", "validation", "beetle/iroh-bitswap", @@ -182,6 +183,7 @@ serde_qs = "0.10.1" serde_with = "2.1" sha2 = { version = "0.10", default-features = false } sha3 = "0.10" +shutdown = { path = "./shutdown/" } smallvec = "1.10" # pragma optimize hangs forver on 0.8, possibly due to libsqlite-sys upgrade sqlx = { version = "0.7", features = ["sqlite", "runtime-tokio", "chrono"] } diff --git a/anchor-service/Cargo.toml b/anchor-service/Cargo.toml index b8686156..5199c5f4 100644 --- a/anchor-service/Cargo.toml +++ b/anchor-service/Cargo.toml @@ -29,3 +29,6 @@ chrono.workspace = true [features] test-network = [] + +[dev-dependencies] +shutdown.workspace = true diff --git a/anchor-service/src/anchor_batch.rs b/anchor-service/src/anchor_batch.rs index 14272631..283a958b 100644 --- a/anchor-service/src/anchor_batch.rs +++ b/anchor-service/src/anchor_batch.rs @@ -72,7 +72,7 @@ impl AnchorService { /// - Store the TimeEvents using the AnchorClient /// /// This function will run indefinitely, or until the process is shutdown. - pub async fn run(&mut self, shutdown_signal: impl Future) { + pub async fn run(mut self, shutdown_signal: impl Future) { let shutdown_signal = shutdown_signal.fuse(); pin_mut!(shutdown_signal); @@ -235,7 +235,8 @@ mod tests { use ceramic_core::NodeKey; use ceramic_sql::sqlite::SqlitePool; use expect_test::expect_file; - use tokio::{sync::broadcast, time::sleep}; + use shutdown::Shutdown; + use tokio::time::sleep; use super::AnchorService; use crate::{MockAnchorEventService, MockCas}; @@ -248,7 +249,7 @@ mod tests { let node_id = NodeKey::random().id(); let anchor_interval = Duration::from_millis(5); let anchor_batch_size = 1000000; - let mut anchor_service = AnchorService::new( + let anchor_service = AnchorService::new( tx_manager, event_service.clone(), pool, @@ -256,20 +257,14 @@ mod tests { anchor_interval, anchor_batch_size, ); - let (shutdown_signal_tx, mut shutdown_signal) = broadcast::channel::<()>(1); - tokio::spawn(async move { - anchor_service - .run(async move { - let _ = shutdown_signal.recv().await; - }) - .await - }); + let shutdown = Shutdown::new(); + tokio::spawn(anchor_service.run(shutdown.wait_fut())); while event_service.events.lock().unwrap().is_empty() { sleep(Duration::from_millis(1)).await; } expect_file!["./test-data/test_anchor_service_run.txt"] .assert_debug_eq(&event_service.events.lock().unwrap()); - shutdown_signal_tx.send(()).unwrap(); + shutdown.shutdown(); } #[tokio::test] @@ -280,7 +275,7 @@ mod tests { let node_id = NodeKey::random().id(); let anchor_interval = Duration::from_millis(5); let anchor_batch_size = 1000000; - let mut anchor_service = AnchorService::new( + let anchor_service = AnchorService::new( tx_manager, event_service.clone(), pool, @@ -288,20 +283,13 @@ mod tests { anchor_interval, anchor_batch_size, ); - let (shutdown_signal_tx, mut shutdown_signal) = broadcast::channel::<()>(1); - // let mut shutdown_signal = shutdown_signal_rx.resubscribe(); - tokio::spawn(async move { - anchor_service - .run(async move { - let _ = shutdown_signal.recv().await; - }) - .await - }); + let shutdown = Shutdown::new(); + tokio::spawn(anchor_service.run(shutdown.wait_fut())); while event_service.events.lock().unwrap().is_empty() { sleep(Duration::from_millis(1)).await; } expect_file!["./test-data/test_anchor_service_run_1.txt"] .assert_debug_eq(&event_service.events.lock().unwrap()); - shutdown_signal_tx.send(()).unwrap(); + shutdown.shutdown(); } } diff --git a/api/Cargo.toml b/api/Cargo.toml index d95b690d..c84837ff 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -26,6 +26,7 @@ recon.workspace = true serde.workspace = true serde_ipld_dagcbor.workspace = true serde_json.workspace = true +shutdown.workspace = true swagger.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/api/src/server.rs b/api/src/server.rs index 1f5d4a73..53bf827e 100644 --- a/api/src/server.rs +++ b/api/src/server.rs @@ -52,6 +52,7 @@ use datafusion::logical_expr::{col, lit, BuiltInWindowFunction, Expr, ExprFuncti use futures::TryFutureExt; use multiaddr::Protocol; use recon::Key; +use shutdown::Shutdown; use swagger::{ApiError, ByteArray}; #[cfg(not(target_env = "msvc"))] use tikv_jemalloc_ctl::epoch; @@ -401,7 +402,7 @@ where model: Arc, p2p: P, pipeline: Option, - shutdown_signal: broadcast::Receiver<()>, + shutdown_signal: Shutdown, ) -> Self { let (tx, event_rx) = tokio::sync::mpsc::channel::(1024); let event_store = model.clone(); @@ -433,7 +434,7 @@ where event_store: Arc, mut event_rx: tokio::sync::mpsc::Receiver, node_id: NodeId, - mut shutdown_signal: broadcast::Receiver<()>, + shutdown_signal: Shutdown, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_millis(FLUSH_INTERVAL_MS)); @@ -455,7 +456,7 @@ where events.extend(buf); } } - _ = shutdown_signal.recv() => { + _ = shutdown_signal.wait_fut() => { tracing::debug!("Insert many task got shutdown signal"); shutdown = true; } diff --git a/api/src/tests.rs b/api/src/tests.rs index f145a27b..bf090a1f 100644 --- a/api/src/tests.rs +++ b/api/src/tests.rs @@ -32,6 +32,7 @@ use mockall::{mock, predicate}; use multiaddr::Multiaddr; use multibase::Base; use recon::Key; +use shutdown::Shutdown; use test_log::test; use tokio::join; @@ -202,8 +203,8 @@ where M: EventService + 'static, P: P2PService, { - let (_, rx) = tokio::sync::broadcast::channel(1); - Server::new(node_id, network, interest, model, p2p, pipeline, rx) + let shutdown = Shutdown::new(); + Server::new(node_id, network, interest, model, p2p, pipeline, shutdown) } #[test(tokio::test)] diff --git a/flight/Cargo.toml b/flight/Cargo.toml index 46bb4090..30030598 100644 --- a/flight/Cargo.toml +++ b/flight/Cargo.toml @@ -26,12 +26,14 @@ tracing.workspace = true ceramic-arrow-test.workspace = true ceramic-pipeline.workspace = true expect-test.workspace = true -tokio = { workspace = true, features = ["macros", "rt"] } -test-log.workspace = true http.workspace = true -tokio-stream = { workspace = true, features = ["net"] } mockall.workspace = true object_store.workspace = true +shutdown.workspace = true +test-log.workspace = true +tokio = { workspace = true, features = ["macros", "rt"] } +tokio-stream = { workspace = true, features = ["net"] } +tracing-subscriber.workspace = true [package.metadata.cargo-machete] ignored = [ diff --git a/one/Cargo.toml b/one/Cargo.toml index c84f58ac..b9449ced 100644 --- a/one/Cargo.toml +++ b/one/Cargo.toml @@ -51,6 +51,7 @@ object_store.workspace = true prometheus-client.workspace = true recon.workspace = true serde_ipld_dagcbor.workspace = true +shutdown.workspace = true signal-hook = "0.3.17" signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } swagger.workspace = true diff --git a/one/src/daemon.rs b/one/src/daemon.rs index f5ac7d9c..8ce1fe40 100644 --- a/one/src/daemon.rs +++ b/one/src/daemon.rs @@ -20,11 +20,11 @@ use clap::Args; use object_store::aws::AmazonS3Builder; use object_store::local::LocalFileSystem; use recon::{Recon, ReconInterestProvider}; +use shutdown::{Shutdown, ShutdownSignal}; use signal_hook::consts::signal::*; use signal_hook_tokio::Signals; use std::sync::Arc; use swagger::{auth::MakeAllowAllAuthenticator, EmptyContext}; -use tokio::sync::broadcast; use tracing::{debug, error, info, warn}; #[derive(Args, Debug)] @@ -338,14 +338,14 @@ async fn get_eth_rpc_providers( fn spawn_database_optimizer( sqlite_pool: SqlitePool, - mut shutdown: tokio::sync::broadcast::Receiver<()>, + mut shutdown: ShutdownSignal, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut duration = std::time::Duration::from_secs(60 * 60 * 24); // once daily loop { // recreate interval in case it's been shortened due to error tokio::select! { - _ = shutdown.recv() => { + _ = &mut shutdown => { break; } _ = tokio::time::sleep(duration) => { @@ -408,11 +408,11 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { debug!(dir = %opts.p2p_key_dir.display(), "using p2p key directory"); // Setup shutdown signal - let (shutdown_signal_tx, mut shutdown_signal) = broadcast::channel::<()>(1); + let shutdown = Shutdown::new(); let signals = Signals::new([SIGHUP, SIGTERM, SIGINT, SIGQUIT])?; let handle = signals.handle(); debug!("starting signal handler task"); - let signals_handle = tokio::spawn(handle_signals(signals, shutdown_signal_tx)); + let signals_handle = tokio::spawn(handle_signals(signals, shutdown.clone())); // Construct sqlite_pool let sqlite_pool = opts @@ -424,8 +424,10 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { // spawn (and run) optimize right before we start using the database (e.g. ordering events) info!("running initial sqlite database optimize, this may take quite a while on large databases."); sqlite_pool.optimize(true).await?; - let ss = shutdown_signal.resubscribe(); - Some(spawn_database_optimizer(sqlite_pool.clone(), ss)) + Some(spawn_database_optimizer( + sqlite_pool.clone(), + shutdown.wait_fut(), + )) } else { None }; @@ -436,14 +438,11 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { let peer_svc = Arc::new(PeerService::new(sqlite_pool.clone())); let interest_svc = Arc::new(InterestService::new(sqlite_pool.clone())); let event_validation = opts.event_validation.unwrap_or(true); - let mut ss = shutdown_signal.resubscribe(); let event_svc = Arc::new( EventService::try_new( sqlite_pool.clone(), ceramic_event_svc::UndeliveredEventReview::Process { - shutdown_signal: Box::new(async move { - let _ = ss.recv().await; - }), + shutdown_signal: Box::new(shutdown.wait_fut()), }, event_validation, rpc_providers, @@ -636,14 +635,10 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { // Start aggregator let aggregator_handle = if opts.aggregator.unwrap_or_default() { - let mut ss = shutdown_signal.resubscribe(); let ctx = ctx.clone(); + let s = shutdown.wait_fut(); Some(tokio::spawn(async move { - if let Err(err) = ceramic_pipeline::aggregator::run(ctx, async move { - let _ = ss.recv().await; - }) - .await - { + if let Err(err) = ceramic_pipeline::aggregator::run(ctx, s).await { error!(%err, "aggregator task failed"); } })) @@ -651,14 +646,9 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { None }; - let mut ss = shutdown_signal.resubscribe(); let pipeline_ctx = ctx.clone(); - let flight_handle = tokio::spawn(async move { - ceramic_flight::server::run(ctx, addr, async move { - let _ = ss.recv().await; - }) - .await - }); + let flight_handle = + tokio::spawn(ceramic_flight::server::run(ctx, addr, shutdown.wait_fut())); (Some(pipeline_ctx), aggregator_handle, Some(flight_handle)) } else { (None, None, None) @@ -679,7 +669,7 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { Duration::from_secs(opts.anchor_poll_interval), opts.anchor_poll_retry_count, ); - let mut anchor_service = AnchorService::new( + let anchor_service = AnchorService::new( Arc::new(remote_cas), event_svc.clone(), sqlite_pool.clone(), @@ -688,14 +678,7 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { opts.anchor_batch_size, ); - let mut shutdown_signal = shutdown_signal.resubscribe(); - Some(tokio::spawn(async move { - anchor_service - .run(async move { - let _ = shutdown_signal.recv().await; - }) - .await - })) + Some(tokio::spawn(anchor_service.run(shutdown.wait_fut()))) } else { None }; @@ -708,7 +691,7 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { Arc::new(model_svc), ipfs.client(), pipeline_ctx, - shutdown_signal.resubscribe(), + shutdown.clone(), ); if opts.authentication { ceramic_server.with_authentication(true); @@ -740,9 +723,7 @@ pub async fn run(opts: DaemonOpts) -> Result<()> { hyper::server::Server::try_bind(&opts.bind_address.parse()?) .map_err(|e| anyhow!("Failed to bind address: {}. {}", opts.bind_address, e))? .serve(service) - .with_graceful_shutdown(async move { - let _ = shutdown_signal.recv().await; - }) + .with_graceful_shutdown(shutdown.wait_fut()) .await?; debug!("api server finished, starting shutdown..."); diff --git a/one/src/lib.rs b/one/src/lib.rs index a9ccaa0b..dd68ece3 100644 --- a/one/src/lib.rs +++ b/one/src/lib.rs @@ -20,10 +20,11 @@ use multibase::Base; use multihash::Multihash; use multihash_codetable::Code; use multihash_derive::Hasher; +use shutdown::Shutdown; use signal_hook_tokio::Signals; use std::str::FromStr; use std::{env, path::PathBuf}; -use tokio::{io::AsyncReadExt, sync::broadcast}; +use tokio::io::AsyncReadExt; use tracing::{debug, error, info, warn}; #[derive(Parser, Debug)] @@ -343,15 +344,13 @@ impl DBOpts { } } -async fn handle_signals(mut signals: Signals, shutdown: broadcast::Sender<()>) { +async fn handle_signals(mut signals: Signals, shutdown: Shutdown) { let mut shutdown = Some(shutdown); while let Some(signal) = signals.next().await { debug!(?signal, "signal received"); if let Some(shutdown) = shutdown.take() { info!("sending shutdown message"); - shutdown - .send(()) - .expect("should be able to send shutdown message"); + shutdown.shutdown(); } } } diff --git a/shutdown/Cargo.toml b/shutdown/Cargo.toml new file mode 100644 index 00000000..fd84fcc7 --- /dev/null +++ b/shutdown/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "shutdown" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true + +[dependencies] +tokio.workspace = true +futures.workspace = true diff --git a/shutdown/src/lib.rs b/shutdown/src/lib.rs new file mode 100644 index 00000000..b5f6d9a7 --- /dev/null +++ b/shutdown/src/lib.rs @@ -0,0 +1,39 @@ +use futures::future::BoxFuture; +use tokio::sync::broadcast; + +/// A shutdown signal is a future that resolve to unit. +pub type ShutdownSignal = BoxFuture<'static, ()>; + +/// Shutdown can be used to signal shutdown across many different tasks. +/// Shutdown is cheaply clonable so it can be shared with as many tasks as needed. +#[derive(Clone)] +pub struct Shutdown { + tx: broadcast::Sender<()>, +} + +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} + +impl Shutdown { + pub fn new() -> Self { + let (tx, _rx) = broadcast::channel(1); + Self { tx } + } + /// Signal that all listeners should shutdown. + /// Shutdown can be called from any clone. + pub fn shutdown(&self) { + let _ = self.tx.send(()); + } + /// Construct a future that resolves when the shutdown signal is sent. + /// + /// The future is cancel safe. + pub fn wait_fut(&self) -> ShutdownSignal { + let mut sub = self.tx.subscribe(); + Box::pin(async move { + let _ = sub.recv().await; + }) + } +}