Skip to content

Commit

Permalink
KeysCollector is now Send / Sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Sajjon committed Sep 20, 2024
1 parent 33e93b2 commit a54c1e1
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 21 deletions.
20 changes: 14 additions & 6 deletions src/derivation/collector/key_ring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@ use crate::prelude::*;

/// A collection of `HierarchicalDeterministicFactorInstance` derived from a
/// factor source.
#[derive(Clone, Debug)]
#[derive(Debug)]
pub(crate) struct Keyring {
pub(crate) factor_source_id: FactorSourceIDFromHash,
pub(crate) paths: IndexSet<DerivationPath>,
derived: RefCell<IndexSet<HierarchicalDeterministicFactorInstance>>,
derived: RwLock<IndexSet<HierarchicalDeterministicFactorInstance>>,
}

impl Keyring {
pub fn clone_snapshot(&self) -> Self {
Self {
factor_source_id: self.factor_source_id,
paths: self.paths.clone(),
derived: RwLock::new(self.derived.try_read().unwrap().clone()),
}
}
pub(crate) fn new(
factor_source_id: FactorSourceIDFromHash,
paths: IndexSet<DerivationPath>,
) -> Self {
Self {
factor_source_id,
paths,
derived: RefCell::new(IndexSet::new()),
derived: RwLock::new(IndexSet::new()),
}
}
pub(crate) fn factors(&self) -> IndexSet<HierarchicalDeterministicFactorInstance> {
self.derived.borrow().clone()
self.derived.try_read().unwrap().clone()
}

pub(crate) fn process_response(
Expand All @@ -33,10 +40,11 @@ impl Keyring {
.all(|f| f.factor_source_id == self.factor_source_id
&& !self
.derived
.borrow()
.try_read()
.unwrap()
.iter()
.any(|x| x.public_key == f.public_key)));

self.derived.borrow_mut().extend(response)
self.derived.try_write().unwrap().extend(response)
}
}
17 changes: 12 additions & 5 deletions src/derivation/collector/keys_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct KeysCollector {

/// Mutable internal state of the collector which builds up the list
/// of public keys from each used factor source.
state: RefCell<KeysCollectorState>,
state: RwLock<KeysCollectorState>,
}

impl KeysCollector {
Expand Down Expand Up @@ -47,7 +47,7 @@ impl KeysCollector {

Ok(Self {
dependencies,
state: RefCell::new(state),
state: RwLock::new(state),
})
}
}
Expand All @@ -60,7 +60,7 @@ impl KeysCollector {
.derive_with_factors() // in decreasing "friction order"
.await
.inspect_err(|e| error!("Failed to use factor sources: {:#?}", e));
self.state.into_inner().outcome()
self.state.into_inner().unwrap().outcome()
}
}

Expand Down Expand Up @@ -123,7 +123,11 @@ impl KeysCollector {
&self,
factor_source_id: &FactorSourceIDFromHash,
) -> Result<MonoFactorKeyDerivationRequest> {
let keyring = self.state.borrow().keyring_for(factor_source_id)?;
let keyring = self
.state
.try_read()
.unwrap()
.keyring_for(factor_source_id)?;
assert_eq!(keyring.factors().len(), 0);
let paths = keyring.paths.clone();
Ok(MonoFactorKeyDerivationRequest::new(
Expand Down Expand Up @@ -156,6 +160,9 @@ impl KeysCollector {
}

fn process_batch_response(&self, response: KeyDerivationResponse) -> Result<()> {
self.state.borrow_mut().process_batch_response(response)
self.state
.try_write()
.unwrap()
.process_batch_response(response)
}
}
13 changes: 7 additions & 6 deletions src/derivation/collector/keys_collector_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::prelude::*;
///
/// Holds a collection of keyrings derived from various factor sources.
pub(crate) struct KeysCollectorState {
pub(super) keyrings: RefCell<IndexMap<FactorSourceIDFromHash, Keyring>>,
pub(super) keyrings: RwLock<IndexMap<FactorSourceIDFromHash, Keyring>>,
}

impl KeysCollectorState {
Expand All @@ -22,12 +22,12 @@ impl KeysCollectorState {
})
.collect::<IndexMap<FactorSourceIDFromHash, Keyring>>();
Self {
keyrings: RefCell::new(keyrings),
keyrings: RwLock::new(keyrings),
}
}

pub(crate) fn outcome(self) -> KeyDerivationOutcome {
let key_rings = self.keyrings.into_inner();
let key_rings = self.keyrings.into_inner().unwrap();
KeyDerivationOutcome::new(
key_rings
.into_iter()
Expand All @@ -38,16 +38,17 @@ impl KeysCollectorState {

pub(crate) fn keyring_for(&self, factor_source_id: &FactorSourceIDFromHash) -> Result<Keyring> {
self.keyrings
.borrow()
.try_read()
.unwrap()
.get(factor_source_id)
.cloned()
.map(|x| x.clone_snapshot())
.inspect(|k| assert_eq!(k.factor_source_id, *factor_source_id))
.ok_or(CommonError::UnknownFactorSource)
}

pub(crate) fn process_batch_response(&self, response: KeyDerivationResponse) -> Result<()> {
for (factor_source_id, factors) in response.per_factor_source.into_iter() {
let mut rings = self.keyrings.borrow_mut();
let mut rings = self.keyrings.try_write().unwrap();
let keyring = rings
.get_mut(&factor_source_id)
.ok_or(CommonError::UnknownFactorSource)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::prelude::*;
/// which it will use to update it internal state and continue with the next
/// factor source, or in case of failure the whole process will be aborted.
#[async_trait::async_trait]
pub trait MonoFactorKeyDerivationInteractor {
pub trait MonoFactorKeyDerivationInteractor: Send + Sync {
async fn derive(
&self,
request: MonoFactorKeyDerivationRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::prelude::*;
/// which it will use to update it internal state and continue with the next
/// factor source, or in case of failure the whole process will be aborted.
#[async_trait::async_trait]
pub trait PolyFactorKeyDerivationInteractor {
pub trait PolyFactorKeyDerivationInteractor: Send + Sync {
async fn derive(
&self,
request: PolyFactorKeyDerivationRequest,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::prelude::*;

/// A collection of "interactors" which can derive keys.
pub trait KeysDerivationInteractors {
pub trait KeysDerivationInteractors: Sync + Send {
fn interactor_for(&self, kind: FactorSourceKind) -> KeyDerivationInteractor;
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,30 @@ impl DeriveAndAnalyzeAccountRecoveryScanInput {
}
}
}

impl From<DeriveAndAnalyzeAccountRecoveryScanInput> for DeriveAndAnalyzeInput {
#[allow(clippy::diverging_sub_expression)]
fn from(value: DeriveAndAnalyzeAccountRecoveryScanInput) -> Self {
let initial_derivation_requests = IndexSet::<DerivationRequest>::new();
let unfactored_derivation_requests = AnyFactorDerivationRequest::many_for_each_on(
NetworkID::Mainnet,
[CAP26EntityKind::Account],
[CAP26KeyKind::TransactionSigning],
[KeySpace::Securified, KeySpace::Unsecurified],
);

let initial_derivation_requests = value
.factor_sources
.clone()
.into_iter()
.flat_map(|f| {
let factor_source_id = f.factor_source_id();
unfactored_derivation_requests
.clone()
.into_iter()
.map(move |u| u.derivation_request_with_factor_source_id(factor_source_id))
})
.collect::<IndexSet<_>>();

let factor_instances_provider: Arc<dyn IsFactorInstancesProvider> = { unreachable!() };
let analyze_factor_instances: Arc<dyn IsIntermediaryDerivationAnalyzer> =
{ unreachable!() };
Expand All @@ -43,3 +63,42 @@ impl From<DeriveAndAnalyzeAccountRecoveryScanInput> for DeriveAndAnalyzeInput {
)
}
}

pub struct UncachedFactorInstanceProvider {
factor_sources: IndexSet<HDFactorSource>,
derivation_index_ranges_start_values:
IndexMap<FactorSourceIDFromHash, IndexMap<DerivationRequest, HDPathValue>>,
interactors: Arc<dyn KeysDerivationInteractors>,
}

impl UncachedFactorInstanceProvider {
fn derivation_paths_for_requests(
&self,
derivation_requests: IndexSet<DerivationRequest>,
) -> IndexMap<FactorSourceIDFromHash, IndexSet<DerivationPath>> {
todo!()
}
async fn derive_instances(
&self,
derivation_requests: IndexSet<DerivationRequest>,
) -> Result<IndexSet<HierarchicalDeterministicFactorInstance>> {
let derivation_paths = self.derivation_paths_for_requests(derivation_requests);
let keys_collector = KeysCollector::new(
self.factor_sources.clone(),
derivation_paths,
self.interactors.clone(),
)?;
let derived = keys_collector.collect_keys().await;
Ok(derived.all_factors())
}
}

#[async_trait::async_trait]
impl IsFactorInstancesProvider for UncachedFactorInstanceProvider {
async fn provide_instances(
&self,
derivation_requests: IndexSet<DerivationRequest>,
) -> Result<IndexSet<HierarchicalDeterministicFactorInstance>> {
self.derive_instances(derivation_requests).await
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,81 @@ pub struct DerivationRequest {
pub key_space: KeySpace,
pub key_kind: CAP26KeyKind,
}

impl DerivationRequest {
pub fn new(
factor_source_id: FactorSourceIDFromHash,
network_id: NetworkID,
entity_kind: CAP26EntityKind,
key_space: KeySpace,
key_kind: CAP26KeyKind,
) -> Self {
Self {
factor_source_id,
network_id,
entity_kind,
key_space,
key_kind,
}
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AnyFactorDerivationRequest {
pub network_id: NetworkID,
pub entity_kind: CAP26EntityKind,
pub key_space: KeySpace,
pub key_kind: CAP26KeyKind,
}

impl AnyFactorDerivationRequest {
pub fn new(
network_id: NetworkID,
entity_kind: CAP26EntityKind,
key_space: KeySpace,
key_kind: CAP26KeyKind,
) -> Self {
Self {
network_id,
entity_kind,
key_kind,
key_space,
}
}

pub fn derivation_request_with_factor_source_id(
&self,
factor_source_id: FactorSourceIDFromHash,
) -> DerivationRequest {
DerivationRequest::new(
factor_source_id,
self.network_id,
self.entity_kind,
self.key_space,
self.key_kind,
)
}

pub fn many_for_each_on(
network_id: NetworkID,
entity_kinds: impl IntoIterator<Item = CAP26EntityKind>,
key_kinds: impl IntoIterator<Item = CAP26KeyKind>,
key_spaces: impl IntoIterator<Item = KeySpace>,
) -> IndexSet<Self> {
let entity_kinds = entity_kinds.into_iter().collect::<IndexSet<_>>();
let key_kinds = key_kinds.into_iter().collect::<IndexSet<_>>();
let key_spaces = key_spaces.into_iter().collect::<IndexSet<_>>();

let mut requests = IndexSet::<Self>::new();

for entity_kind in entity_kinds.into_iter() {
for key_kind in key_kinds.clone().into_iter() {
for key_space in key_spaces.clone().into_iter() {
let request = Self::new(network_id, entity_kind, key_space, key_kind);
requests.insert(request);
}
}
}
requests
}
}

0 comments on commit a54c1e1

Please sign in to comment.