diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 3d5862d4..da1436cf 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -37,6 +37,9 @@ macro_rules! setup_tracked_fn { // Path to the cycle recovery function to use. cycle_recovery_fn: ($($cycle_recovery_fn:tt)*), + // Path to function to get the initial value to use for cycle recovery. + cycle_recovery_initial: ($($cycle_recovery_initial:tt)*), + // Name of cycle recovery strategy variant to use. cycle_recovery_strategy: $cycle_recovery_strategy:ident, @@ -160,7 +163,7 @@ macro_rules! setup_tracked_fn { const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy; - fn should_backdate_value( + fn values_equal( old_value: &Self::Output<'_>, new_value: &Self::Output<'_>, ) -> bool { @@ -168,7 +171,7 @@ macro_rules! setup_tracked_fn { if $no_eq { false } else { - $zalsa::should_backdate_value(old_value, new_value) + $zalsa::values_equal(old_value, new_value) } } } @@ -179,12 +182,17 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, $($input_id),*) + } + fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - cycle: &$zalsa::Cycle, + value: &Self::Output<$db_lt>, + count: u32, ($($input_id),*): ($($input_ty),*) - ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, cycle, $($input_id),*) + ) -> $zalsa::CycleRecoveryAction> { + $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a8b8122b..a1cd1e73 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,11 +3,18 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $cycle:ident, $($other_inputs:ident),*) => { - { - std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); - panic!("cannot recover from cycle `{:?}`", $cycle) - } - } + ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("cannot recover from cycle") + }}; +} + +#[macro_export] +macro_rules! unexpected_cycle_initial { + ($db:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("no cycle initial value") + }}; } diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 0c9f39bf..8de9446e 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -38,7 +38,8 @@ impl AllowedOptions for Accumulator { const SINGLETON: bool = false; const DATA: bool = false; const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + const CYCLE_INITIAL: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; } diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index e5d67d53..56993bd3 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -49,7 +49,9 @@ impl crate::options::AllowedOptions for InputStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index aea9a4e1..66dbb223 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -50,7 +50,9 @@ impl crate::options::AllowedOptions for InternedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 6f30bb3e..c4175c70 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -44,10 +44,15 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub db_path: Option, - /// The `recovery_fn = ` option is used to indicate the recovery function. + /// The `cycle_fn = ` option is used to indicate the cycle recovery function. /// /// If this is `Some`, the value is the ``. - pub recovery_fn: Option, + pub cycle_fn: Option, + + /// The `cycle_initial = ` option is the initial value for cycle iteration. + /// + /// If this is `Some`, the value is the ``. + pub cycle_initial: Option, /// The `data = ` option is used to define the name of the data type for an interned /// struct. @@ -79,7 +84,8 @@ impl Default for Options { no_debug: Default::default(), no_clone: Default::default(), db_path: Default::default(), - recovery_fn: Default::default(), + cycle_fn: Default::default(), + cycle_initial: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -99,7 +105,8 @@ pub(crate) trait AllowedOptions { const SINGLETON: bool; const DATA: bool; const DB: bool; - const RECOVERY_FN: bool; + const CYCLE_FN: bool; + const CYCLE_INITIAL: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; } @@ -207,20 +214,39 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } - } else if ident == "recovery_fn" { - if A::RECOVERY_FN { + } else if ident == "cycle_fn" { + if A::CYCLE_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.cycle_fn, Some(path)) { + return Err(syn::Error::new( + old.span(), + "option `cycle_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_fn` option not allowed here", + )); + } + } else if ident == "cycle_initial" { + if A::CYCLE_INITIAL { + // TODO(carljm) should it be an error to give cycle_initial without cycle_fn, + // or should we just allow this to fall into potentially infinite iteration, if + // iteration never converges? let _eq = Equals::parse(input)?; let path = syn::Path::parse(input)?; - if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) { + if let Some(old) = std::mem::replace(&mut options.cycle_initial, Some(path)) { return Err(syn::Error::new( old.span(), - "option `recovery_fn` provided twice", + "option `cycle_initial` provided twice", )); } } else { return Err(syn::Error::new( ident.span(), - "`recovery_fn` option not allowed here", + "`cycle_initial` option not allowed here", )); } } else if ident == "data" { diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 57023ef2..74cc3bca 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -39,7 +39,9 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; - const RECOVERY_FN: bool = true; + const CYCLE_FN: bool = true; + + const CYCLE_INITIAL: bool = true; const LRU: bool = true; @@ -68,9 +70,20 @@ impl Macro { let input_ids = self.input_ids(&item); let input_tys = self.input_tys(&item)?; let output_ty = self.output_ty(&db_lt, &item)?; - let (cycle_recovery_fn, cycle_recovery_strategy) = self.cycle_recovery(); + let (cycle_recovery_fn, cycle_recovery_initial, cycle_recovery_strategy) = + self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); - let no_eq = self.args.no_eq.is_some(); + let no_eq = if let Some(token) = &self.args.no_eq { + if self.args.cycle_fn.is_some() { + return Err(syn::Error::new_spanned( + token, + "the `no_eq` option cannot be used with `cycle_fn`", + )); + } + true + } else { + false + }; let mut inner_fn = item.clone(); inner_fn.vis = syn::Visibility::Inherited; @@ -127,6 +140,7 @@ impl Macro { output_ty: #output_ty, inner_fn: #inner_fn, cycle_recovery_fn: #cycle_recovery_fn, + cycle_recovery_initial: #cycle_recovery_initial, cycle_recovery_strategy: #cycle_recovery_strategy, is_specifiable: #is_specifiable, no_eq: #no_eq, @@ -160,14 +174,26 @@ impl Macro { Ok(ValidFn { db_ident, db_path }) } - fn cycle_recovery(&self) -> (TokenStream, TokenStream) { - if let Some(recovery_fn) = &self.args.recovery_fn { - (quote!((#recovery_fn)), quote!(Fallback)) - } else { - ( + fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { + match (&self.args.cycle_fn, &self.args.cycle_initial) { + (Some(cycle_fn), Some(cycle_initial)) => Ok(( + quote!((#cycle_fn)), + quote!((#cycle_initial)), + quote!(Fixpoint), + )), + (None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), - ) + )), + (Some(_), None) => Err(syn::Error::new_spanned( + self.args.cycle_fn.as_ref().unwrap(), + "must provide `cycle_initial` along with `cycle_fn`", + )), + (None, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide `cycle_fn` along with `cycle_initial`", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 1730b340..009f6f8f 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -45,7 +45,9 @@ impl crate::options::AllowedOptions for TrackedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/src/active_query.rs b/src/active_query.rs index aadcf196..9a8e92ed 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; use crate::tracked_struct::IdentityHash; @@ -9,7 +9,7 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, tracked_struct::{Disambiguator, Identity}, zalsa_local::EMPTY_DEPENDENCIES, - Cycle, Id, Revision, + Id, Revision, }; #[derive(Debug)] @@ -35,9 +35,6 @@ pub(crate) struct ActiveQuery { /// True if there was an untracked read. untracked_read: bool, - /// Stores the entire cycle, if one is found and this query is part of it. - pub(crate) cycle: Option, - /// When new tracked structs are created, their data is hashed, and the resulting /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). @@ -54,6 +51,9 @@ pub(crate) struct ActiveQuery { /// Stores the values accumulated to the given ingredient. /// The type of accumulated value is erased but known to the ingredient. pub(crate) accumulated: AccumulatedMap, + + /// Provisional cycle results that this query depends on. + pub(crate) cycle_heads: FxHashSet, } impl ActiveQuery { @@ -64,10 +64,10 @@ impl ActiveQuery { changed_at: Revision::start(), input_outputs: FxIndexSet::default(), untracked_read: false, - cycle: None, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), accumulated: Default::default(), + cycle_heads: Default::default(), } } @@ -76,10 +76,12 @@ impl ActiveQuery { input: DependencyIndex, durability: Durability, revision: Revision, + cycle_heads: &FxHashSet, ) { self.input_outputs.insert((EdgeKind::Input, input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); + self.cycle_heads.extend(cycle_heads); } pub(super) fn add_untracked_read(&mut self, changed_at: Revision) { @@ -125,36 +127,11 @@ impl ActiveQuery { durability: self.durability, tracked_struct_ids: self.tracked_struct_ids, accumulated: self.accumulated, + cycle_ignore: !self.cycle_heads.is_empty(), + cycle_heads: self.cycle_heads, } } - /// Adds any dependencies from `other` into `self`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn add_from(&mut self, other: &ActiveQuery) { - self.changed_at = self.changed_at.max(other.changed_at); - self.durability = self.durability.min(other.durability); - self.untracked_read |= other.untracked_read; - self.input_outputs - .extend(other.input_outputs.iter().copied()); - } - - /// Removes the participants in `cycle` from my dependencies. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { - for p in cycle.participant_keys() { - let p: DependencyIndex = p.into(); - self.input_outputs.shift_remove(&(EdgeKind::Input, p)); - } - } - - /// Copy the changed-at, durability, and dependencies from `cycle_query`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { - self.changed_at = cycle_query.changed_at; - self.durability = cycle_query.durability; - self.input_outputs.clone_from(&cycle_query.input_outputs); - } - pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { let disambiguator = self .disambiguator_map diff --git a/src/cycle.rs b/src/cycle.rs index 6071aa30..c90f2170 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,91 +1,11 @@ -use crate::{key::DatabaseKeyIndex, Database}; -use std::{panic::AssertUnwindSafe, sync::Arc}; - -/// Captures the participants of a cycle that occurred when executing a query. -/// -/// This type is meant to be used to help give meaningful error messages to the -/// user or to help salsa developers figure out why their program is resulting -/// in a computation cycle. -/// -/// It is used in a few ways: -/// -/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), -/// where it is given to the fallback function. -/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants -/// lacks cycle recovery information) occurs. -/// -/// You can read more about cycle handling in -/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Cycle { - participants: CycleParticipants, -} - -pub(crate) type CycleParticipants = Arc>; - -impl Cycle { - pub(crate) fn new(participants: CycleParticipants) -> Self { - Self { participants } - } - - /// True if two `Cycle` values represent the same cycle. - pub(crate) fn is(&self, cycle: &Cycle) -> bool { - Arc::ptr_eq(&self.participants, &cycle.participants) - } - - pub(crate) fn throw(self) -> ! { - tracing::debug!("throwing cycle {:?}", self); - std::panic::resume_unwind(Box::new(self)) - } - - pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { - match std::panic::catch_unwind(AssertUnwindSafe(execute)) { - Ok(v) => Ok(v), - Err(err) => match err.downcast::() { - Ok(cycle) => Err(*cycle), - Err(other) => std::panic::resume_unwind(other), - }, - } - } - - /// Iterate over the [`DatabaseKeyIndex`] for each query participating - /// in the cycle. The start point of this iteration within the cycle - /// is arbitrary but deterministic, but the ordering is otherwise determined - /// by the execution. - pub fn participant_keys(&self) -> impl Iterator + '_ { - self.participants.iter().copied() - } - - /// Returns a vector with the debug information for - /// all the participants in the cycle. - pub fn all_participants(&self, _db: &dyn Database) -> Vec { - self.participant_keys().collect() - } - - /// Returns a vector with the debug information for - /// those participants in the cycle that lacked recovery - /// information. - pub fn unexpected_participants(&self, db: &dyn Database) -> Vec { - self.participant_keys() - .filter(|&d| d.cycle_recovery_strategy(db) == CycleRecoveryStrategy::Panic) - .collect() - } -} - -impl std::fmt::Debug for Cycle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - crate::attach::with_attached_database(|db| { - f.debug_struct("UnexpectedCycle") - .field("all_participants", &self.all_participants(db)) - .field("unexpected_participants", &self.unexpected_participants(db)) - .finish() - }) - .unwrap_or_else(|| { - f.debug_struct("Cycle") - .field("participants", &self.participants) - .finish() - }) - } +/// Return value from a cycle recovery function. +#[derive(Debug)] +pub enum CycleRecoveryAction { + /// Iterate the cycle again to look for a fixpoint. + Iterate, + + /// Cut off iteration and use the given result value for this query. + Fallback(T), } /// Cycle recovery strategy: Is this query capable of recovering from @@ -95,14 +15,12 @@ pub enum CycleRecoveryStrategy { /// Cannot recover from cycles: panic. /// /// This is the default. - /// - /// In the case of a failure due to a cycle, the panic - /// value will be the `Cycle`. Panic, - /// Recovers from cycles by storing a sentinel value. + /// Recovers from cycles by fixpoint iterating and/or falling + /// back to a sentinel value. /// - /// This value is computed by the query's `recovery_fn` - /// function. - Fallback, + /// This choice is computed by the query's `cycle_recovery` + /// function and initial value. + Fixpoint, } diff --git a/src/function.rs b/src/function.rs index 07f13d49..97511d22 100644 --- a/src/function.rs +++ b/src/function.rs @@ -2,14 +2,14 @@ use std::{any::Any, fmt, sync::Arc}; use crate::{ accumulator::accumulated_map::AccumulatedMap, - cycle::CycleRecoveryStrategy, + cycle::{CycleRecoveryAction, CycleRecoveryStrategy}, ingredient::fmt_index, key::DatabaseKeyIndex, plumbing::JarAux, salsa_struct::SalsaStructInDb, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, + Database, Id, Revision, }; use self::delete::DeletedEntries; @@ -49,13 +49,12 @@ pub trait Configuration: Any { /// (and, if so, how). const CYCLE_STRATEGY: CycleRecoveryStrategy; - /// Invokes after a new result `new_value`` has been computed for which an older memoized - /// value existed `old_value`. Returns true if the new value is equal to the older one - /// and hence should be "backdated" (i.e., marked as having last changed in an older revision, - /// even though it was recomputed). + /// Invokes after a new result `new_value`` has been computed for which an older memoized value + /// existed `old_value`, or in fixpoint iteration. Returns true if the new value is equal to + /// the older one. /// - /// This invokes user's code in form of the `Eq` impl. - fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; + /// This invokes user code in form of the `Eq` impl. + fn values_equal(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. @@ -67,15 +66,16 @@ pub trait Configuration: Any { /// This invokes the function the user wrote. fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; - /// If the cycle strategy is `Fallback`, then invoked when `key` is a participant - /// in a cycle to find out what value it should have. - /// - /// This invokes the recovery function given by the user. + /// Get the cycle recovery initial value. + fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + + /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, - cycle: &Cycle, + value: &Self::Output<'db>, + count: u32, input: Self::Input<'db>, - ) -> Self::Output<'db>; + ) -> CycleRecoveryAction>; } /// Function ingredients are the "workhorse" of salsa. @@ -116,9 +116,9 @@ pub struct IngredientImpl { } /// True if `old_value == new_value`. Invoked by the generated -/// code for `should_backdate_value` so as to give a better +/// code for `values_equal` so as to give a better /// error message. -pub fn should_backdate_value(old_value: &V, new_value: &V) -> bool { +pub fn values_equal(old_value: &V, new_value: &V) -> bool { old_value == new_value } diff --git a/src/function/backdate.rs b/src/function/backdate.rs index bfca6f05..7eff1b3d 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -21,7 +21,7 @@ where // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. if revisions.durability >= old_memo.revisions.durability - && C::should_backdate_value(old_value, value) + && C::values_equal(old_value, value) { tracing::debug!( "value is equal, back-dating to {:?}", diff --git a/src/function/execute.rs b/src/function/execute.rs index 4171fe6d..ef43fc2c 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,8 +1,6 @@ use std::sync::Arc; -use crate::{ - zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, -}; +use crate::{zalsa::ZalsaDatabase, Database, DatabaseKeyIndex, Event, EventKind}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -22,12 +20,12 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - active_query: ActiveQueryGuard<'_>, + database_key_index: DatabaseKeyIndex, opt_old_memo: Option>>>, ) -> &'db Memo> { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let revision_now = zalsa.current_revision(); - let database_key_index = active_query.database_key_index; + let id = database_key_index.key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -38,53 +36,125 @@ where }, }); - // If we already executed this query once, then use the tracked-struct ids from the - // previous execution as the starting point for the new one. - if let Some(old_memo) = &opt_old_memo { - active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); - } + let mut iteration_count: u32 = 0; - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let database_key_index = active_query.database_key_index; - let id = database_key_index.key_index; - let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) { - Ok(v) => v, - Err(cycle) => { + // Our provisional value from the previous iteration, when doing fixpoint iteration. + // Initially it's set to None, because the initial provisional value is created lazily, + // only when a cycle is actually encountered. + let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; + + loop { + let active_query = zalsa_local.push_query(database_key_index); + + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. + if let Some(old_memo) = &opt_old_memo { + active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + } + + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let mut new_value = C::execute(db, C::id_to_input(db, id)); + let mut revisions = active_query.pop(); + + // Did the new result we got depend on our own provisional value, in a cycle? + if revisions.cycle_heads.contains(&database_key_index) { + let opt_owned_last_provisional; + let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. + last_provisional + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. + opt_owned_last_provisional = self.get_memo_from_table_for(zalsa, id); + opt_owned_last_provisional + .as_deref() + .expect( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found", + ) + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + }; tracing::debug!( - "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", - C::CYCLE_STRATEGY + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" ); - match C::CYCLE_STRATEGY { - crate::cycle::CycleRecoveryStrategy::Panic => cycle.throw(), - crate::cycle::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == database_key_index)); - cycle.throw() + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + iteration_count = iteration_count.checked_add(1).expect( + "fixpoint iteration of {database_key_index:#?} should \ + converge before u32::MAX iterations", + ); + revisions.cycle_ignore = false; + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + )); + continue; + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: user cycle_fn says to fall back" + ); + new_value = fallback_value; } } } + // This is no longer a provisional result, it's our final result, so remove ourself + // from the cycle heads, and iterate one last time to remove ourself from all other + // results in the cycle as well and turn them into usable cached results. + // TODO Can we avoid doing this? the extra cycle is quite expensive if there is a + // nested cycle. Maybe track the relevant memos and replace them all with the cycle + // head removed? Or just let them keep the cycle head and allow cycle memos to be + // used when we are not actually iterating the cycle for that head? + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value, \ + one more iteration to remove cycle heads from memos" + ); + revisions.cycle_heads.remove(&database_key_index); + revisions.cycle_ignore = false; + self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + ); + continue; } - }; - let mut revisions = active_query.pop(); - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. - if let Some(old_memo) = &opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, old_memo, &mut revisions); - } + tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); - tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + if let Some(old_memo) = &opt_old_memo { + self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); + self.diff_outputs(db, database_key_index, old_memo, &mut revisions); + } - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) + return self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + ); + } } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 07a08d96..9fa08afa 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,4 +1,7 @@ -use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; +use crate::{ + runtime::StampedValue, table::sync::ClaimResult, zalsa::ZalsaDatabase, + zalsa_local::QueryRevisions, AsDynDatabase as _, Id, +}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -21,7 +24,12 @@ where self.evict_value_from_memo_for(zalsa, evicted); } - zalsa_local.report_tracked_read(self.database_key_index(id).into(), durability, changed_at); + zalsa_local.report_tracked_read( + self.database_key_index(id).into(), + durability, + changed_at, + &memo.revisions.cycle_heads, + ); value } @@ -61,28 +69,58 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = zalsa.sync_table_for(id).claim( + let _claim_guard = match zalsa.sync_table_for(id).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; - - // Push the query on the stack. - let active_query = zalsa_local.push_query(database_key_index); + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => { + return self + .initial_value(db, database_key_index.key_index) + .map(|initial_value| { + tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + inserting and returning fixpoint initial value" + ); + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + zalsa.current_revision(), + QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ), + ), + ) + }) + .or_else(|| { + panic!( + "dependency graph cycle querying {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ) + }) + } + ClaimResult::Claimed(guard) => guard, + }; // Now that we've claimed the item, check again to see if there's a "hot" value. - let zalsa = db.zalsa(); let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { - // Unsafety invariant: memo is present in memo_map. - unsafe { - return Some(self.extend_memo_lifetime(old_memo)); + if old_memo.value.is_some() { + let active_query = zalsa_local.push_query(database_key_index); + if self.deep_verify_memo(db, old_memo, &active_query) { + // Unsafety invariant: memo is present in memo_map. + unsafe { + return Some(self.extend_memo_lifetime(old_memo)); + } } } } - Some(self.execute(db, active_query, opt_old_memo)) + Some(self.execute(db, database_key_index, opt_old_memo)) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b1d671a3..e0b33eb8 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,5 +1,7 @@ use crate::{ + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, + table::sync::ClaimResult, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, AsDynDatabase as _, Id, Revision, @@ -53,14 +55,24 @@ where let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); - let _claim_guard = zalsa.sync_table_for(key_index).claim( + let _claim_guard = match zalsa.sync_table_for(key_index).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; - let active_query = zalsa_local.push_query(database_key_index); - + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => panic!( + "dependency graph cycle validating {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ), + // If we hit a cycle in memo validation, but we support fixpoint iteration, just + // consider the memo changed so we'll re-run the iteration in this revision. + CycleRecoveryStrategy::Fixpoint => return Some(true), + }, + ClaimResult::Claimed(guard) => guard, + }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { return Some(true); @@ -73,6 +85,7 @@ where ); // Check if the inputs are still valid and we can just compare `changed_at`. + let active_query = zalsa_local.push_query(database_key_index); if self.deep_verify_memo(db, &old_memo, &active_query) { return Some(old_memo.revisions.changed_at > revision); } @@ -82,7 +95,7 @@ where // backdated. In that case, although we will have computed a new memo, // the value has not logically changed. if old_memo.value.is_some() { - let memo = self.execute(db, active_query, Some(old_memo)); + let memo = self.execute(db, database_key_index, Some(old_memo)); let changed_at = memo.revisions.changed_at; return Some(changed_at > revision); } @@ -101,6 +114,9 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { + if memo.revisions.cycle_ignore { + return false; + } let verified_at = memo.verified_at.load(); let revision_now = zalsa.current_revision(); @@ -131,14 +147,16 @@ where /// /// Takes an [`ActiveQueryGuard`] argument because this function recursively /// walks dependencies of `old_memo` and may even execute them to see if their - /// outputs have changed. As that could lead to cycles, it is important that the - /// query is on the stack. + /// outputs have changed. pub(super) fn deep_verify_memo( &self, db: &C::DbView, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> bool { + if old_memo.revisions.cycle_ignore { + return false; + } let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; @@ -166,7 +184,7 @@ where // in rev 1 but not in rev 2. return false; } - QueryOrigin::BaseInput => { + QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. return true; } diff --git a/src/function/memo.rs b/src/function/memo.rs index f982dc33..8e167ccb 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -7,8 +7,8 @@ use crossbeam::atomic::AtomicCell; use crate::zalsa_local::QueryOrigin; use crate::{ - key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, - Revision, + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, + Event, EventKind, Id, Revision, }; use super::{Configuration, IngredientImpl}; @@ -68,7 +68,8 @@ impl IngredientImpl { match memo.revisions.origin { QueryOrigin::Assigned(_) | QueryOrigin::DerivedUntracked(_) - | QueryOrigin::BaseInput => { + | QueryOrigin::BaseInput + | QueryOrigin::FixpointInitial => { // Careful: Cannot evict memos whose values were // assigned as output of another query // or those with untracked inputs @@ -86,6 +87,17 @@ impl IngredientImpl { } } } + + pub(super) fn initial_value<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option> { + match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), + CycleRecoveryStrategy::Panic => None, + } + } } #[derive(Debug)] diff --git a/src/function/specify.rs b/src/function/specify.rs index 9eccad65..f5803b3d 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -70,6 +70,8 @@ where origin: QueryOrigin::Assigned(active_query_key), tracked_struct_ids: Default::default(), accumulated: Default::default(), + cycle_heads: Default::default(), + cycle_ignore: false, }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { diff --git a/src/input.rs b/src/input.rs index fdad27ac..d37c6ac1 100644 --- a/src/input.rs +++ b/src/input.rs @@ -188,6 +188,7 @@ impl IngredientImpl { }, stamp.durability, stamp.changed_at, + &Default::default(), ); &value.fields } diff --git a/src/interned.rs b/src/interned.rs index 0c6d32cd..40a2a450 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -133,6 +133,7 @@ where DependencyIndex::for_table(self.ingredient_index), Durability::MAX, self.reset_at, + &Default::default(), ); // Optimisation to only get read lock on the map if the data has already diff --git a/src/key.rs b/src/key.rs index 92e63541..de84f710 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,7 +1,4 @@ -use crate::{ - accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, - zalsa::IngredientIndex, Database, Id, -}; +use crate::{accumulator::accumulated_map::AccumulatedMap, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. Fully ordered and @@ -96,10 +93,6 @@ impl DatabaseKeyIndex { self.key_index } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - self.ingredient_index.cycle_recovery_strategy(db) - } - pub(crate) fn accumulated(self, db: &dyn Database) -> Option<&AccumulatedMap> { db.zalsa() .lookup_ingredient(self.ingredient_index) diff --git a/src/lib.rs b/src/lib.rs index 0ee7d3ec..4fe625ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ mod zalsa_local; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; -pub use self::cycle::Cycle; +pub use self::cycle::CycleRecoveryAction; pub use self::database::AsDynDatabase; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; @@ -68,11 +68,11 @@ pub mod plumbing { pub use crate::array::Array; pub use crate::attach::attach; pub use crate::attach::with_attached_database; - pub use crate::cycle::Cycle; + pub use crate::cycle::CycleRecoveryAction; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; - pub use crate::function::should_backdate_value; + pub use crate::function::values_equal; pub use crate::id::AsId; pub use crate::id::FromId; pub use crate::id::Id; @@ -112,6 +112,7 @@ pub mod plumbing { pub use salsa_macro_rules::setup_method_body; pub use salsa_macro_rules::setup_tracked_fn; pub use salsa_macro_rules::setup_tracked_struct; + pub use salsa_macro_rules::unexpected_cycle_initial; pub use salsa_macro_rules::unexpected_cycle_recovery; pub mod accumulator { diff --git a/src/runtime.rs b/src/runtime.rs index ba35f09f..ff413e45 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,16 +1,11 @@ -use std::{ - panic::panic_any, - sync::{atomic::AtomicUsize, Arc}, - thread::ThreadId, -}; +use std::{sync::atomic::AtomicUsize, thread::ThreadId}; use crossbeam::atomic::AtomicCell; use parking_lot::Mutex; use crate::{ - active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, zalsa_local::ZalsaLocal, - Cancelled, Cycle, Database, Event, EventKind, Revision, + durability::Durability, key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, + zalsa_local::ZalsaLocal, Cancelled, Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; @@ -49,7 +44,12 @@ pub struct Runtime { pub(crate) enum WaitResult { Completed, Panicked, - Cycle(Cycle), +} + +#[derive(Clone, Debug)] +pub(crate) enum BlockResult { + Completed, + Cycle, } #[derive(Copy, Clone, Debug)] @@ -154,8 +154,8 @@ impl Runtime { r_new } - /// Block until `other_id` completes executing `database_key`; - /// panic or unwind in the case of a cycle. + /// Block until `other_id` completes executing `database_key`, or return `BlockResult::Cycle` + /// immediately in case of a cycle. /// /// `query_mutex_guard` is the guard for the current query's state; /// it will be dropped after we have successfully registered the @@ -165,34 +165,19 @@ impl Runtime { /// /// If the thread `other_id` panics, then our thread is considered /// cancelled, so this function will panic with a `Cancelled` value. - /// - /// # Cycle handling - /// - /// If the thread `other_id` already depends on the current thread, - /// and hence there is a cycle in the query graph, then this function - /// will unwind instead of returning normally. The method of unwinding - /// depends on the [`Self::mutual_cycle_recovery_strategy`] - /// of the cycle participants: - /// - /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. - /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); - - // If the above fn returns, then (via cycle recovery) it has unblocked the - // cycle, so we can continue. - assert!(!dg.depends_on(other_id, thread_id)); + return BlockResult::Cycle; } db.salsa_event(&|| Event { @@ -217,128 +202,12 @@ impl Runtime { local_state.restore_query_stack(stack); match result { - WaitResult::Completed => (), + WaitResult::Completed => BlockResult::Completed, // If the other thread panicked, then we consider this thread // cancelled. The assumption is that the panic will be detected // by the other thread and responded to appropriately. WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), - - WaitResult::Cycle(c) => c.throw(), - } - } - - /// Handles a cycle in the dependency graph that was detected when the - /// current thread tried to block on `database_key_index` which is being - /// executed by `to_id`. If this function returns, then `to_id` no longer - /// depends on the current thread, and so we should continue executing - /// as normal. Otherwise, the function will throw a `Cycle` which is expected - /// to be caught by some frame on our stack. This occurs either if there is - /// a frame on our stack with cycle recovery (possibly the top one!) or if there - /// is no cycle recovery at all. - fn unblock_cycle_and_maybe_throw( - &self, - db: &dyn Database, - local_state: &ZalsaLocal, - dg: &mut DependencyGraph, - database_key_index: DatabaseKeyIndex, - to_id: ThreadId, - ) { - tracing::debug!( - "unblock_cycle_and_maybe_throw(database_key={:?})", - database_key_index - ); - - let mut from_stack = local_state.take_query_stack(); - let from_id = std::thread::current().id(); - - // Make a "dummy stack frame". As we iterate through the cycle, we will collect the - // inputs from each participant. Then, if we are participating in cycle recovery, we - // will propagate those results to all participants. - let mut cycle_query = ActiveQuery::new(database_key_index); - - // Identify the cycle participants: - let cycle = { - let mut v = vec![]; - dg.for_each_cycle_participant( - from_id, - &mut from_stack, - database_key_index, - to_id, - |aqs| { - aqs.iter_mut().for_each(|aq| { - cycle_query.add_from(aq); - v.push(aq.database_key_index); - }); - }, - ); - - // We want to give the participants in a deterministic order - // (at least for this execution, not necessarily across executions), - // no matter where it started on the stack. Find the minimum - // key and rotate it to the front. - let min = v - .iter() - .map(|key| (key.ingredient_index.debug_name(db), key)) - .min() - .unwrap() - .1; - let index = v.iter().position(|p| p == min).unwrap(); - v.rotate_left(index); - - // No need to store extra memory. - v.shrink_to_fit(); - - Cycle::new(Arc::new(v)) - }; - tracing::debug!("cycle {cycle:?}, cycle_query {cycle_query:#?}"); - - // We can remove the cycle participants from the list of dependencies; - // they are a strongly connected component (SCC) and we only care about - // dependencies to things outside the SCC that control whether it will - // form again. - cycle_query.remove_cycle_participants(&cycle); - - // Mark each cycle participant that has recovery set, along with - // any frames that come after them on the same thread. Those frames - // are going to be unwound so that fallback can occur. - dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aqs| { - aqs.iter_mut() - .skip_while(|aq| { - match db - .zalsa() - .lookup_ingredient(aq.database_key_index.ingredient_index) - .cycle_recovery_strategy() - { - CycleRecoveryStrategy::Panic => true, - CycleRecoveryStrategy::Fallback => false, - } - }) - .for_each(|aq| { - tracing::debug!("marking {:?} for fallback", aq.database_key_index); - aq.take_inputs_from(&cycle_query); - assert!(aq.cycle.is_none()); - aq.cycle = Some(cycle.clone()); - }); - }); - - // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. - // They will throw the cycle, which will be caught by the frame that has - // cycle recovery so that it can execute that recovery. - let (me_recovered, others_recovered) = - dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id); - - local_state.restore_query_stack(from_stack); - - if me_recovered { - // If the current thread has recovery, we want to throw - // so that it can begin. - cycle.throw() - } else if others_recovered { - // If other threads have recovery but we didn't: return and we will block on them. - } else { - // if nobody has recover, then we panic - panic_any(cycle); } } diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 9db1752a..c90e650d 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -31,7 +31,6 @@ pub(super) struct DependencyGraph { #[derive(Debug)] struct Edge { blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, stack: QueryStack, /// Signalled whenever a query with dependents completes. @@ -55,131 +54,6 @@ impl DependencyGraph { p == to_id } - /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. - /// The cycle runs as follows: - /// - /// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... - /// 2. ...but `database_key` is already being executed by `to_id`... - /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. - pub(super) fn for_each_cycle_participant( - &mut self, - from_id: ThreadId, - from_stack: &mut QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - mut closure: impl FnMut(&mut [ActiveQuery]), - ) { - debug_assert!(self.depends_on(to_id, from_id)); - - // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): - // - // database_key = QB2 - // from_id = A - // to_id = B - // from_stack = [QA1, QA2, QA3] - // - // self.edges[B] = { C, QC2, [QB1..QB3] } - // self.edges[C] = { A, QA2, [QC1..QC3] } - // - // The cyclic - // edge we have - // failed to add. - // : - // A : B C - // : - // QA1 v QB1 QC1 - // ┌► QA2 ┌──► QB2 ┌─► QC2 - // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ - // │ │ - // └───────────────────────────────┘ - // - // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] - - let mut id = to_id; - let mut key = database_key; - while id != from_id { - // Looking at the diagram above, the idea is to - // take the edge from `to_id` starting at `key` - // (inclusive) and down to the end. We can then - // load up the next thread (i.e., we start at B/QB2, - // and then load up the dependency on C/QC2). - let edge = self.edges.get_mut(&id).unwrap(); - let prefix = edge - .stack - .iter_mut() - .take_while(|p| p.database_key_index != key) - .count(); - closure(&mut edge.stack[prefix..]); - id = edge.blocked_on_id; - key = edge.blocked_on_key; - } - - // Finally, we copy in the results from `from_stack`. - let prefix = from_stack - .iter_mut() - .take_while(|p| p.database_key_index != key) - .count(); - closure(&mut from_stack[prefix..]); - } - - /// Unblock each blocked runtime (excluding the current one) if some - /// query executing in that runtime is participating in cycle fallback. - /// - /// Returns a boolean (Current, Others) where: - /// * Current is true if the current runtime has cycle participants - /// with fallback; - /// * Others is true if other runtimes were unblocked. - pub(super) fn maybe_unblock_runtimes_in_cycle( - &mut self, - from_id: ThreadId, - from_stack: &QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - ) -> (bool, bool) { - // See diagram in `for_each_cycle_participant`. - let mut id = to_id; - let mut key = database_key; - let mut others_unblocked = false; - while id != from_id { - let edge = self.edges.get(&id).unwrap(); - let prefix = edge - .stack - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - let next_id = edge.blocked_on_id; - let next_key = edge.blocked_on_key; - - if let Some(cycle) = edge.stack[prefix..] - .iter() - .rev() - .find_map(|aq| aq.cycle.clone()) - { - // Remove `id` from the list of runtimes blocked on `next_key`: - self.query_dependents - .get_mut(&next_key) - .unwrap() - .retain(|r| *r != id); - - // Unblock runtime so that it can resume execution once lock is released: - self.unblock_runtime(id, WaitResult::Cycle(cycle)); - - others_unblocked = true; - } - - id = next_id; - key = next_key; - } - - let prefix = from_stack - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some()); - - (this_unblocked, others_unblocked) - } - /// Modifies the graph so that `from_id` is blocked /// on `database_key`, which is being computed by /// `to_id`. @@ -235,7 +109,6 @@ impl DependencyGraph { from_id, Edge { blocked_on_id: to_id, - blocked_on_key: database_key, stack: from_stack, condvar: condvar.clone(), }, diff --git a/src/table/sync.rs b/src/table/sync.rs index dfe78a23..75077927 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -7,7 +7,7 @@ use parking_lot::RwLock; use crate::{ key::DatabaseKeyIndex, - runtime::WaitResult, + runtime::{BlockResult, WaitResult}, zalsa::{MemoIngredientIndex, Zalsa}, zalsa_local::ZalsaLocal, Database, @@ -30,6 +30,12 @@ struct SyncState { anyone_waiting: AtomicBool, } +pub(crate) enum ClaimResult<'a> { + Retry, + Cycle, + Claimed(ClaimGuard<'a>), +} + impl SyncTable { pub(crate) fn claim<'me>( &'me self, @@ -37,7 +43,7 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - ) -> Option> { + ) -> ClaimResult<'me> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); @@ -50,7 +56,7 @@ impl SyncTable { id: thread_id, anyone_waiting: AtomicBool::new(false), }); - Some(ClaimGuard { + ClaimResult::Claimed(ClaimGuard { database_key_index, memo_ingredient_index, zalsa, @@ -68,8 +74,10 @@ impl SyncTable { // boolean is to decide *whether* to acquire the lock, // not to gate future atomic reads. anyone_waiting.store(true, Ordering::Relaxed); - zalsa.block_on_or_unwind(db, zalsa_local, database_key_index, *other_id, syncs); - None + match zalsa.block_on(db, zalsa_local, database_key_index, *other_id, syncs) { + BlockResult::Completed => ClaimResult::Retry, + BlockResult::Cycle => ClaimResult::Cycle, + } } } } diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 540bb765..550ed090 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -561,6 +561,7 @@ where }, data.durability, field_changed_at, + &Default::default(), ); unsafe { self.to_self_ref(&data.fields) } diff --git a/src/zalsa.rs b/src/zalsa.rs index 2f8fa95f..09e51243 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -5,10 +5,9 @@ use std::any::{Any, TypeId}; use std::marker::PhantomData; use std::thread::ThreadId; -use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{Ingredient, Jar, JarAux}; use crate::nonce::{Nonce, NonceGenerator}; -use crate::runtime::{Runtime, WaitResult}; +use crate::runtime::{BlockResult, Runtime, WaitResult}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; use crate::table::Table; @@ -82,18 +81,9 @@ impl IngredientIndex { self.0 as usize } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - db.zalsa().lookup_ingredient(self).cycle_recovery_strategy() - } - pub fn successor(self, index: usize) -> Self { IngredientIndex(self.0 + 1 + index as u32) } - - /// Return the "debug name" of this ingredient (e.g., the name of the tracked struct it represents) - pub(crate) fn debug_name(self, db: &dyn Database) -> &'static str { - db.zalsa().lookup_ingredient(self).debug_name() - } } /// A special secondary index *just* for ingredients that attach @@ -266,16 +256,16 @@ impl Zalsa { } /// See [`Runtime::block_on_or_unwind`][] - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { self.runtime - .block_on_or_unwind(db, local_state, database_key, other_id, query_mutex_guard) + .block_on(db, local_state, database_key, other_id, query_mutex_guard) } /// See [`Runtime::unblock_queries_blocked_on`][] diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 9076ffa9..4ef67f64 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use tracing::debug; use crate::accumulator::accumulated_map::AccumulatedMap; @@ -14,7 +14,6 @@ use crate::tracked_struct::{Disambiguator, Identity, IdentityHash}; use crate::zalsa::IngredientIndex; use crate::Accumulator; use crate::Cancelled; -use crate::Cycle; use crate::Database; use crate::Event; use crate::EventKind; @@ -170,6 +169,7 @@ impl ZalsaLocal { input: DependencyIndex, durability: Durability, changed_at: Revision, + cycle_heads: &FxHashSet, ) { debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", @@ -177,32 +177,7 @@ impl ZalsaLocal { ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at); - - // We are a cycle participant: - // - // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 - // ^ ^ - // : | - // This edge -----+ | - // | - // | - // N0 - // - // In this case, the value we have just read from `Ci+1` - // is actually the cycle fallback value and not especially - // interesting. We unwind now with `CycleParticipant` to avoid - // executing the rest of our query function. This unwinding - // will be caught and our own fallback value will be used. - // - // Note that `Ci+1` may` have *other* callers who are not - // participants in the cycle (e.g., N0 in the graph above). - // They will not have the `cycle` marker set in their - // stack frames, so they will just read the fallback value - // from `Ci+1` and continue on their merry way. - if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() - } + top_query.add_read(input, durability, changed_at, cycle_heads); } }) } @@ -377,9 +352,38 @@ pub(crate) struct QueryRevisions { pub(super) tracked_struct_ids: FxHashMap, pub(super) accumulated: AccumulatedMap, + + /// This result was computed based on provisional values from + /// these cycle heads. The "cycle head" is the query responsible + /// for managing a fixpoint iteration. In a cycle like + /// `--> A --> B --> C --> A`, the cycle head is query `A`: it is + /// the query whose value is requested while it is executing, + /// which must provide the initial provisional value and decide, + /// after each iteration, whether the cycle has converged or must + /// iterate again. + pub(super) cycle_heads: FxHashSet, + + /// True if this result is based on provisional results of other + /// queries, and is not created explicitly by the query managing + /// a fixpoint iteration (the "cycle head"); this should never be + /// treated as a valid cached result. + pub(super) cycle_ignore: bool, } impl QueryRevisions { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self { + let cycle_heads = FxHashSet::from_iter([query]); + Self { + changed_at: revision, + durability: Durability::MAX, + origin: QueryOrigin::FixpointInitial, + tracked_struct_ids: Default::default(), + accumulated: Default::default(), + cycle_heads, + cycle_ignore: false, + } + } + pub(crate) fn stamped_value(&self, value: V) -> StampedValue { self.stamp_template().stamp(value) } @@ -428,6 +432,9 @@ pub enum QueryOrigin { /// The [`QueryEdges`] argument contains a listing of all the inputs we saw /// (but we know there were more). DerivedUntracked(QueryEdges), + + /// The value is an initial provisional value for a query that supports fixpoint iteration. + FixpointInitial, } impl QueryOrigin { @@ -435,7 +442,9 @@ impl QueryOrigin { pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.inputs()) } @@ -444,7 +453,9 @@ impl QueryOrigin { pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.outputs()) } @@ -557,18 +568,8 @@ impl ActiveQueryGuard<'_> { // Extract accumulated inputs. let popped_query = self.complete(); - // If this frame were a cycle participant, it would have unwound. - assert!(popped_query.cycle.is_none()); - popped_query.into_revisions() } - - /// If the active query is registered as a cycle participant, remove and - /// return that cycle. - pub(crate) fn take_cycle(&self) -> Option { - self.local_state - .with_query_stack(|stack| stack.last_mut()?.cycle.take()) - } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/cycle/dataflow.rs b/tests/cycle/dataflow.rs new file mode 100644 index 00000000..d8ef4cf3 --- /dev/null +++ b/tests/cycle/dataflow.rs @@ -0,0 +1,239 @@ +//! Test case for fixpoint iteration cycle resolution. +//! +//! This test case is intended to simulate a (very simplified) version of a real dataflow analysis +//! using fixpoint iteration. +use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use std::collections::BTreeSet; +use std::iter::IntoIterator; + +/// A Use of a symbol. +#[salsa::input] +struct Use { + reaching_definitions: Vec, +} + +/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. +#[salsa::input] +struct Definition { + base: Option, + increment: usize, +} + +#[derive(Eq, PartialEq, Clone, Debug)] +enum Type { + Bottom, + Values(Box<[usize]>), + Top, +} + +impl Type { + fn join(tys: impl IntoIterator) -> Type { + let mut result = Type::Bottom; + for ty in tys.into_iter() { + result = match (result, ty) { + (result, Type::Bottom) => result, + (_, Type::Top) => Type::Top, + (Type::Top, _) => Type::Top, + (Type::Bottom, ty) => ty, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend(a_ints); + set.extend(b_ints); + Type::Values(set.into_iter().collect()) + } + } + } + result + } +} + +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] +fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { + let defs = u.reaching_definitions(db); + match defs[..] { + [] => Type::Bottom, + [def] => infer_definition(db, def), + _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))), + } +} + +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] +fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { + let increment_ty = Type::Values(Box::from([def.increment(db)])); + if let Some(base) = def.base(db) { + let base_ty = infer_use(db, base); + add(&base_ty, &increment_ty) + } else { + increment_ty + } +} + +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { + Type::Bottom +} + +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { + match value { + Type::Bottom => CycleRecoveryAction::Iterate, + Type::Values(_) => { + if count > 4 { + CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + +/// x = 1 +#[test] +fn simple() { + let db = salsa::DatabaseImpl::new(); + + let def = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1]))); +} + +/// x = 1 if flag else 2 +#[test] +fn union() { + let db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 2); + let u = Use::new(&db, vec![def1, def2]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1, 2]))); +} + +/// x = 1 if flag else 2; y = x + 1 +#[test] +fn union_add() { + let db = salsa::DatabaseImpl::new(); + + let x1 = Definition::new(&db, None, 1); + let x2 = Definition::new(&db, None, 2); + let x_use = Use::new(&db, vec![x1, x2]); + let y_def = Definition::new(&db, Some(x_use), 1); + let y_use = Use::new(&db, vec![y_def]); + + let ty = infer_use(&db, y_use); + + assert_eq!(ty, Type::Values(Box::from([2, 3]))); +} + +/// x = 1; loop { x = x + 0 } +#[test] +fn cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 0); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop converges on 1 + assert_eq!(ty, Type::Values(Box::from([1]))); + + // Set the increment on x from 0 to 1 + let new_increment = 1; + def2.set_increment(&mut db).to(new_increment); + + // Now the loop diverges and we fall back to Top + assert_eq!(infer_use(&db, u), Type::Top); +} + +/// x = 1; loop { x = x + 1 } +#[test] +fn cycle_diverges_then_converges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop diverges. Cut it off and fallback to Type::Top + assert_eq!(ty, Type::Top); + + // Set the increment from 1 to 0. + def2.set_increment(&mut db).to(0); + + // Now the loop converges on 1. + assert_eq!(infer_use(&db, u), Type::Values(Box::from([1]))); +} + +/// x = 0; y = 0; loop { x = y + 0; y = x + 0 } +#[test] +fn multi_symbol_cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = Definition::new(&db, None, 0); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + // Both symbols converge on 0 + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x from 0 to 1. + defx1.set_increment(&mut db).to(1); + + // Now the loop diverges and we fall back to Top. + assert_eq!(infer_use(&db, use_x), Type::Top); + assert_eq!(infer_use(&db, use_y), Type::Top); +} diff --git a/tests/cycle/main.rs b/tests/cycle/main.rs new file mode 100644 index 00000000..09cb6e83 --- /dev/null +++ b/tests/cycle/main.rs @@ -0,0 +1,760 @@ +//! Test cases for fixpoint iteration cycle resolution. +//! +//! These test cases use a generic query setup that allows constructing arbitrary dependency +//! graphs, and attempts to achieve good coverage of various cases. +mod dataflow; + +use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; + +/// A vector of inputs a query can evaluate to get an iterator of u8 values to operate on. +/// +/// This allows creating arbitrary query graphs between the four queries below (`min_iterate`, +/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. +#[salsa::input] +struct Inputs { + inputs: Vec, +} + +impl Inputs { + fn values(self, db: &dyn Db) -> impl Iterator + '_ { + self.inputs(db).into_iter().map(|input| input.eval(db)) + } +} + +/// A single input, evaluating to a single u8 value. +#[derive(Clone, Debug)] +enum Input { + /// a simple value + Value(u8), + + /// a simple value, reported as an untracked read + UntrackedRead(u8), + + /// minimum of the given inputs, with fixpoint iteration on cycles + MinIterate(Inputs), + + /// maximum of the given inputs, with fixpoint iteration on cycles + MaxIterate(Inputs), + + /// minimum of the given inputs, panicking on cycles + MinPanic(Inputs), + + /// maximum of the given inputs, panicking on cycles + MaxPanic(Inputs), + + /// value of the given input, plus one + Successor(Box), +} + +impl Input { + fn eval(self, db: &dyn Db) -> u8 { + match self { + Self::Value(value) => value, + Self::UntrackedRead(value) => { + db.report_untracked_read(); + value + } + Self::MinIterate(inputs) => min_iterate(db, inputs), + Self::MaxIterate(inputs) => max_iterate(db, inputs), + Self::MinPanic(inputs) => min_panic(db, inputs), + Self::MaxPanic(inputs) => max_panic(db, inputs), + Self::Successor(input) => input.eval(db) + 1, + } + } + + fn assert(self, db: &dyn Db, expected: u8) { + assert_eq!(self.eval(db), expected) + } +} + +#[salsa::tracked(cycle_fn=min_recover, cycle_initial=min_initial)] +fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).min().expect("empty inputs!") +} + +const MIN_COUNT_FALLBACK: u8 = 100; +const MIN_VALUE_FALLBACK: u8 = 5; +const MIN_VALUE: u8 = 10; + +fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { + if *value < MIN_VALUE { + CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) + } else if count > 10 { + CycleRecoveryAction::Fallback(MIN_COUNT_FALLBACK) + } else { + CycleRecoveryAction::Iterate + } +} + +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { + 255 +} + +#[salsa::tracked(cycle_fn=max_recover, cycle_initial=max_initial)] +fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).max().expect("empty inputs!") +} + +const MAX_COUNT_FALLBACK: u8 = 200; +const MAX_VALUE_FALLBACK: u8 = 250; +const MAX_VALUE: u8 = 245; + +fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { + if *value > MAX_VALUE { + CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) + } else if count > 10 { + CycleRecoveryAction::Fallback(MAX_COUNT_FALLBACK) + } else { + CycleRecoveryAction::Iterate + } +} + +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { + 0 +} + +#[salsa::tracked] +fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).min().expect("empty inputs!") +} + +#[salsa::tracked] +fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).max().expect("empty inputs!") +} + +// Diagram nomenclature for nodes: Each node is represented as a:xx(ii), where `a` is a sequential +// identifier from `a`, `b`, `c`..., xx is one of the four query kinds: +// - `Ni` for `min_iterate` +// - `Xi` for `max_iterate` +// - `Np` for `min_panic` +// - `Xp` for `max_panic` +// +// and `ii` is the inputs for that query, represented as a comma-separated list, with each +// component representing an input: +// - `a`, `b`, `c`... where the input is another node, +// - `uXX` for `UntrackedRead(XX)` +// - `vXX` for `Value(XX)` +// - `sY` for `Successor(Y)` +// +// We always enter from the top left node in the diagram. + +/// a:Np(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Np(u10, a) -+ +/// ^ | +/// +-------------+ +/// +/// Simple self-cycle with untracked read, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_untracked_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db) + .to(vec![Input::UntrackedRead(10), a.clone()]); + + a.eval(&db); +} + +/// a:Ni(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, iteration converges on initial value. +#[test] +fn self_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with iteration, we converge on its initial value. +#[test] +fn two_mixed_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Np(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with no iteration, we panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_mixed_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(b_in); + let b = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Ni(b) --> b:Xi(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we first enter from. +#[test] +fn two_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MaxIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); + b.assert(&db, 255); +} + +/// a:Xi(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we enter from. +/// (Same setup as above test, different query order.) +#[test] +fn two_iterate_converge_initial_value_2() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MinIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 0); + b.assert(&db, 0); +} + +/// a:Np(b) --> b:Ni(c) --> c:Xp(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node with iteration, converge on its initial value. +#[test] +fn two_indirect_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert(&db, 255); +} + +/// a:Xp(b) --> b:Np(c) --> c:Xi(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node without iteration, panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_indirect_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MaxIterate(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.eval(&db); +} + +/// a:Np(b) -> b:Ni(v250,c) -> c:Xp(b) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, converges to non-initial value. +#[test] +fn two_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(250), c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert(&db, 250); +} + +/// a:Xp(b) -> b:Xi(v10,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to >10 iterations. +#[test] +fn two_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(10), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xp(b) -> b:Xi(v241,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 241 and each +/// iteration increments until we reach >245). +#[test] +fn two_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(241), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) +/// ^ | | +/// +----------+------------------------+ +/// +/// Three-query cycle, (b) and (c) both depend on (a). We converge on 25. +#[test] +fn three_fork_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone()]); + + a.assert(&db, 25); +} + +/// a:Ni(b) -> b:Ni(a, c) -> c:Np(v25, b) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We converge on 25. +#[test] +fn layered_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![Input::Value(25), b]); + + a.assert(&db, 25); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v25, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max iterations and fall back. +#[test] +fn layered_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v240, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max value and fall back. +#[test] +fn layered_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(240), Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, a, b) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles. We converge on 25. +#[test] +fn nested_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone(), b]); + + a.assert(&db, 25); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, b, a) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles, inner first. We converge on 25. +#[test] +fn nested_inner_first_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), b, a.clone()]); + + a.assert(&db, 25); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, a, sb) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles. We hit max iterations and fall back. +#[test] +fn nested_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + a.clone(), + Input::Successor(Box::new(b)), + ]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, b, sa) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles, inner first. We hit max iterations and fall back. +#[test] +fn nested_inner_first_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + b, + Input::Successor(Box::new(a.clone())), + ]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, a, sb) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles. We hit max value and fall back. +#[test] +fn nested_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + a.clone(), + Input::Successor(Box::new(b)), + ]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, b, sa) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles, inner first. We hit max value and fall back. +#[test] +fn nested_inner_first_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + b, + Input::Successor(Box::new(a.clone())), + ]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Ni(c, a) -> c:Np(v25, a, b) +/// ^ ^ | | +/// +----------+--------|------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We converge on 25. +#[test_log::test] +fn nested_double_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone(), b]); + + a.assert(&db, 25); +} + +// Multiple-revision cycles + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// a:Ni(b) --> b:Np(v30) +/// +/// Cycle becomes not-a-cycle in next revision. +#[test] +fn cycle_becomes_non_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.clone().assert(&db, 255); + + b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + + a.assert(&db, 30); +} + +/// a:Ni(b) --> b:Np(v30) +/// +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Non-cycle becomes a cycle in next revision. +#[test] +fn non_cycle_becomes_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + + a.clone().assert(&db, 30); + + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Xi(b) -> b:Xi(c, a) -> c:Xp(v25, a, sb) +/// ^ ^ | | +/// +----------+--------|-------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We hit max iterations and fall back, then max value on the next +/// revision, then converge on the next. +#[test] +fn nested_double_multiple_revisions() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert(&db, MAX_COUNT_FALLBACK + 1); + + // next revision, we hit max value instead + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert(&db, MAX_VALUE_FALLBACK + 1); + + // and next revision, we converge + c_in.set_inputs(&mut db) + .to(vec![Input::Value(240), a.clone(), b]); + + a.assert(&db, 240); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Ni(a) +/// ^ | +/// +---------------------------+ +/// +/// In a cycle with some LOW durability and some HIGH durability inputs, changing a LOW durability +/// input still re-executes the full cycle in the next revision. +#[test] +fn cycle_durability() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinIterate(c_in); + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![b.clone()]); + b_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![c]); + c_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![a.clone()]); + + a.clone().assert(&db, 255); + + // next revision, we converge instead + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![Input::Value(45), b]); + + a.assert(&db, 45); +} diff --git a/tests/cycles.rs b/tests/cycles.rs deleted file mode 100644 index f0748418..00000000 --- a/tests/cycles.rs +++ /dev/null @@ -1,436 +0,0 @@ -#![allow(warnings)] - -use std::panic::{RefUnwindSafe, UnwindSafe}; - -use expect_test::expect; -use salsa::DatabaseImpl; -use salsa::Durability; - -// Axes: -// -// Threading -// * Intra-thread -// * Cross-thread -- part of cycle is on one thread, part on another -// -// Recovery strategies: -// * Panic -// * Fallback -// * Mixed -- multiple strategies within cycle participants -// -// Across revisions: -// * N/A -- only one revision -// * Present in new revision, not old -// * Present in old revision, not new -// * Present in both revisions -// -// Dependencies -// * Tracked -// * Untracked -- cycle participant(s) contain untracked reads -// -// Layers -// * Direct -- cycle participant is directly invoked from test -// * Indirect -- invoked a query that invokes the cycle -// -// -// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// | ------ | -------- | -------- | --------- | ------ | --------- | -// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// | Intra | Fallback | New | Tracked | direct | cycle_appears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | - -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -struct Error { - cycle: Vec, -} - -use salsa::Database as Db; -use salsa::Setter; - -#[salsa::input] -struct MyInput {} - -#[salsa::tracked] -fn memoized_a(db: &dyn Db, input: MyInput) { - memoized_b(db, input) -} - -#[salsa::tracked] -fn memoized_b(db: &dyn Db, input: MyInput) { - memoized_a(db, input) -} - -#[salsa::tracked] -fn volatile_a(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_b(db, input) -} - -#[salsa::tracked] -fn volatile_b(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_a(db, input) -} - -/// The queries A, B, and C in `Database` can be configured -/// to invoke one another in arbitrary ways using this -/// enum. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CycleQuery { - None, - A, - B, - C, - AthenC, -} - -#[salsa::input] -struct ABC { - a: CycleQuery, - b: CycleQuery, - c: CycleQuery, -} - -impl CycleQuery { - fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { - match self { - CycleQuery::A => cycle_a(db, abc), - CycleQuery::B => cycle_b(db, abc), - CycleQuery::C => cycle_c(db, abc), - CycleQuery::AthenC => { - let _ = cycle_a(db, abc); - cycle_c(db, abc) - } - CycleQuery::None => Ok(()), - } - } -} - -#[salsa::tracked(recovery_fn=recover_a)] -fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.a(db).invoke(db, abc) -} - -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked(recovery_fn=recover_b)] -fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.b(db).invoke(db, abc) -} - -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked] -fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.c(db).invoke(db, abc) -} - -#[track_caller] -fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { - let v = std::panic::catch_unwind(f); - if let Err(d) = &v { - if let Some(cycle) = d.downcast_ref::() { - return cycle.clone(); - } - } - panic!("unexpected value: {:?}", v) -} - -#[test] -fn cycle_memoized() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| memoized_a(db, input)); - let expected = expect![[r#" - [ - memoized_a(Id(0)), - memoized_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }) -} - -#[test] -fn cycle_volatile() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| volatile_a(db, input)); - let expected = expect![[r#" - [ - volatile_a(Id(0)), - volatile_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }); -} - -#[test] -fn expect_cycle() { - // A --> B - // ^ | - // +-----+ - - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(db, abc).is_err()); - }) -} - -#[test] -fn inner_cycle() { - // A --> B <-- C - // ^ | - // +-----+ - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); - let err = cycle_c(db, abc); - assert!(err.is_err()); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&err.unwrap_err().cycle); - }) -} - -#[test] -fn cycle_revalidate() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - abc.set_b(&mut db).to(CycleQuery::A); // same value as default - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_recovery_unchanged_twice() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - abc.set_c(&mut db).to(CycleQuery::A); // force new revision - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_appears() { - let mut db = salsa::DatabaseImpl::new(); - // A --> B - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); - - // A --> B - // ^ | - // +-----+ - abc.set_b(&mut db).to(CycleQuery::A); - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_disappears() { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - // A --> B - abc.set_b(&mut db).to(CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); -} - -/// A variant on `cycle_disappears` in which the values of -/// `a` and `b` are set with durability values. -/// If we are not careful, this could cause us to overlook -/// the fact that the cycle will no longer occur. -#[test] -fn cycle_disappears_durability() { - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new( - &mut db, - CycleQuery::None, - CycleQuery::None, - CycleQuery::None, - ); - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::B); - abc.set_b(&mut db) - .with_durability(Durability::HIGH) - .to(CycleQuery::A); - - assert!(cycle_a(&db, abc).is_err()); - - // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, - // because `b` participates in the same cycle as `a`, its final durability - // should be `LOW`. - // - // Check that setting a `LOW` input causes us to re-execute `b` query, and - // observe that the cycle goes away. - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::None); - - assert!(cycle_b(&mut db, abc).is_ok()); -} - -#[test] -fn cycle_mixed_1() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> B <-- C - // | ^ - // +-----+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); - - let expected = expect![[r#" - [ - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_mixed_2() { - salsa::DatabaseImpl::new().attach(|db| { - // Configuration: - // - // A --> B --> C - // ^ | - // +-----------+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_deterministic_order() { - // No matter whether we start from A or B, we get the same set of participants: - let f = || { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - (db, abc) - }; - let (db, abc) = f(); - let a = cycle_a(&db, abc); - let (db, abc) = f(); - let b = cycle_b(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -} - -#[test] -fn cycle_multiple() { - // No matter whether we start from A or B, we get the same set of participants: - let mut db = salsa::DatabaseImpl::new(); - - // Configuration: - // - // A --> B <-- C - // ^ | ^ - // +-----+ | - // | | - // +-----+ - // - // Here, conceptually, B encounters a cycle with A and then - // recovers. - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); - - let c = cycle_c(&db, abc); - let b = cycle_b(&db, abc); - let a = cycle_a(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&( - c.unwrap_err().cycle, - b.unwrap_err().cycle, - a.unwrap_err().cycle, - )); -} - -#[test] -fn cycle_recovery_set_but_not_participating() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> C -+ - // ^ | - // +--+ - let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); - - // Here we expect C to panic and A not to recover: - let r = extract_cycle(|| drop(cycle_a(db, abc))); - let expected = expect![[r#" - [ - cycle_c(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&r.all_participants(db)); - }) -} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 578a83cb..b47612c3 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,8 +1,4 @@ mod setup; mod parallel_cancellation; -mod parallel_cycle_all_recover; -mod parallel_cycle_mid_recover; -mod parallel_cycle_none_recover; -mod parallel_cycle_one_recover; mod signal; diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs deleted file mode 100644 index 9dc8c74e..00000000 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a1"); - key.field(db) * 10 + 1 -} - -#[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a2"); - key.field(db) * 10 + 2 -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 1 -} - -#[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b2"); - key.field(db) * 20 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected, recovers) -// | b2 completes, recovers -// | b1 completes, recovers -// a2 sees cycle, recovers -// a1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - assert_eq!(thread_a.join().unwrap(), 11); - assert_eq!(thread_b.join().unwrap(), 21); -} diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs deleted file mode 100644 index 593d46a6..00000000 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // tell thread b we have started - db.signal(1); - - // wait for thread b to block on a1 - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // create the cycle - b1(db, input) -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // wait for thread a to have started - db.wait_for(1); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will encounter a cycle but recover - b3(db, input); - b1(db, input); // hasn't recovered yet - 0 -} - -#[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will block on thread a, signaling stage 2 - a1(db, input) -} - -fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b3"); - key.field(db) * 200 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs deleted file mode 100644 index 89f1ecfb..00000000 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Test a cycle where no queries recover that occurs across threads. -//! See the `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; -use expect_test::expect; -use salsa::Database; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - b(db, input) -} - -#[salsa::tracked] -pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - - // Now try to execute A - a(db, input) -} - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, -1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b(&db, input) - }); - - // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). - // Right now, it panics with a string. - let err_b = thread_b.join().unwrap_err(); - db.attach(|_| { - if let Some(c) = err_b.downcast_ref::() { - let expected = expect![[r#" - [ - a(Id(0)), - b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&c.all_participants(&db)); - } else { - panic!("b failed in an unexpected way: {:?}", err_b); - } - }); - - // We expect A to propagate a panic, which causes us to use the sentinel - // type `Canceled`. - assert!(thread_a - .join() - .unwrap_err() - .downcast_ref::() - .is_some()); -} diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs deleted file mode 100644 index c0378282..00000000 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected) -// a2 recovery fn executes | -// a1 completes normally | -// b2 completes, recovers -// b1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 56d204ee..c266731a 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -9,8 +9,6 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { - fn knobs(&self) -> &Knobs; - fn signal(&self, stage: usize); fn wait_for(&self, stage: usize); @@ -68,10 +66,6 @@ impl salsa::Database for Knobs { #[salsa::db] impl KnobsDatabase for Knobs { - fn knobs(&self) -> &Knobs { - self - } - fn signal(&self, stage: usize) { self.signal.signal(stage); }