Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: keep client initialization state internal [WPB-12004] #749

Open
wants to merge 8 commits into
base: feat/transaction
Choose a base branch
from
2 changes: 1 addition & 1 deletion crypto/benches/utils/mls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ pub(crate) async fn setup_mls(
in_memory: bool,
) -> (CoreCrypto, ConversationId) {
let (central, _) = new_central(ciphersuite, credential, in_memory).await;
let core_crypto = CoreCrypto::from(central);
let core_crypto = central;
let context = core_crypto.new_transaction().await.unwrap();
let id = conversation_id();
context
Expand Down
13 changes: 3 additions & 10 deletions crypto/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! This module contains the primitives to enable transactional support on a higher level within the
//! [MlsCentral]. All mutating operations need to be done through a [CentralContext].

use async_lock::{Mutex, RwLock, RwLockReadGuardArc, RwLockWriteGuardArc};

Check warning on line 4 in crypto/src/context.rs

View workflow job for this annotation

GitHub Actions / hack

unused import: `Mutex`

Check warning on line 4 in crypto/src/context.rs

View workflow job for this annotation

GitHub Actions / hack

unused import: `Mutex`
use mls_crypto_provider::{CryptoKeystore, MlsCryptoProvider};
use std::{ops::Deref, sync::Arc};

Expand Down Expand Up @@ -33,7 +33,7 @@
Valid {
provider: MlsCryptoProvider,
callbacks: Arc<RwLock<Option<std::sync::Arc<dyn CoreCryptoCallbacks + 'static>>>>,
mls_client: Arc<RwLock<Option<Client>>>,
mls_client: Client,
mls_groups: Arc<RwLock<GroupStore<MlsConversation>>>,
#[cfg(feature = "proteus")]
proteus_central: Arc<Mutex<Option<ProteusCentral>>>,
Expand Down Expand Up @@ -79,16 +79,9 @@
})
}

pub(crate) async fn mls_client(&self) -> CryptoResult<RwLockReadGuardArc<Option<Client>>> {
pub(crate) async fn mls_client(&self) -> CryptoResult<Client> {
match self.state.read().await.deref() {
ContextState::Valid { mls_client, .. } => Ok(mls_client.read_arc().await),
ContextState::Invalid => Err(CryptoError::InvalidContext),
}
}

pub(crate) async fn mls_client_mut(&self) -> CryptoResult<RwLockWriteGuardArc<Option<Client>>> {
match self.state.read().await.deref() {
ContextState::Valid { mls_client, .. } => Ok(mls_client.write_arc().await),
ContextState::Valid { mls_client, .. } => Ok(mls_client.clone()),
ContextState::Invalid => Err(CryptoError::InvalidContext),
}
}
Expand Down
13 changes: 5 additions & 8 deletions crypto/src/e2e_identity/conversation_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,13 @@ mod tests {
let x509_test_chain = x509_test_chain_arc.as_ref().as_ref().unwrap();

// That way the conversation creator (Alice) will have a different credential type than Bob
let mut alice_client_guard = alice_central.context.mls_client_mut().await.unwrap();
let alice_client = alice_client_guard.as_mut().unwrap();
let alice_client = alice_central.context.mls_client().await.unwrap();
let alice_provider = alice_central.context.mls_provider().await.unwrap();
let creator_ct = match case.credential_type {
MlsCredentialType::Basic => {
let intermediate_ca = x509_test_chain.find_local_intermediate_ca();
let cert_bundle = CertificateBundle::rand(alice_client.id(), intermediate_ca);
let cert_bundle =
CertificateBundle::rand(&alice_client.id().await.unwrap(), intermediate_ca);
alice_client
.init_x509_credential_bundle_if_missing(
&alice_provider,
Expand All @@ -326,7 +326,6 @@ mod tests {
MlsCredentialType::Basic
}
};
drop(alice_client_guard);

alice_central
.context
Expand Down Expand Up @@ -392,16 +391,14 @@ mod tests {
.await
.unwrap();

let mut alice_client_guard = alice_central.context.mls_client_mut().await.unwrap();
let alice_client = alice_client_guard.as_mut().unwrap();
let alice_client = alice_central.context.mls_client().await.unwrap();
let alice_provider = alice_central.context.mls_provider().await.unwrap();
// Needed because 'e2ei_rotate' does not do it directly and it's required for 'get_group_info'
alice_client
.save_new_x509_credential_bundle(&alice_provider.keystore(), case.signature_scheme(), cert)
.await
.unwrap();

drop(alice_client_guard);
// Need to fetch it before it becomes invalid & expires
let gi = alice_central.get_group_info(&id).await;

Expand Down Expand Up @@ -458,7 +455,7 @@ mod tests {
alice_central.context.e2ei_rotate(&id, Some(&cb)).await.unwrap();
alice_central.context.commit_accepted(&id).await.unwrap();

let mut alice_client = alice_central.client().await;
let alice_client = alice_central.client().await;
let alice_provider = alice_central.context.mls_provider().await.unwrap();

// Needed because 'e2ei_rotate' does not do it directly and it's required for 'get_group_info'
Expand Down
19 changes: 8 additions & 11 deletions crypto/src/e2e_identity/enabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,31 @@ use openmls_traits::types::SignatureScheme;
impl CentralContext {
/// See [MlsCentral::e2ei_is_enabled]
pub async fn e2ei_is_enabled(&self, signature_scheme: SignatureScheme) -> CryptoResult<bool> {
let client_guard = self.mls_client().await?;
let client = client_guard.as_ref().ok_or(CryptoError::MlsNotInitialized)?;
let client = self.mls_client().await?;
client.e2ei_is_enabled(signature_scheme).await
}
}

impl MlsCentral {
/// Returns true when end-to-end-identity is enabled for the given SignatureScheme
pub async fn e2ei_is_enabled(&self, signature_scheme: SignatureScheme) -> CryptoResult<bool> {
let client_guard = self.mls_client().await;
let client = client_guard.as_ref().ok_or(CryptoError::MlsNotInitialized)?;
client.e2ei_is_enabled(signature_scheme).await
self.mls_client.e2ei_is_enabled(signature_scheme).await
}
}

impl Client {
async fn e2ei_is_enabled(&self, signature_scheme: SignatureScheme) -> CryptoResult<bool> {
let maybe_x509 = self
let x509_result = self
.find_most_recent_credential_bundle(signature_scheme, MlsCredentialType::X509)
.await;
match maybe_x509 {
None => {
match x509_result {
Err(CryptoError::CredentialNotFound(MlsCredentialType::X509)) => {
self.find_most_recent_credential_bundle(signature_scheme, MlsCredentialType::Basic)
.await
.ok_or(CryptoError::CredentialNotFound(MlsCredentialType::Basic))?;
.await?;
Ok(false)
}
Some(_) => Ok(true),
Err(e) => Err(e),
Ok(_) => Ok(true),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crypto/src/e2e_identity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ pub(crate) mod tests {
use crate::{
e2e_identity::{id::QualifiedE2eiClientId, tests::x509::X509TestChain},
prelude::*,
test_utils::{central::TEAM, *},
test_utils::{context::TEAM, *},
CryptoResult,
};

Expand Down
43 changes: 19 additions & 24 deletions crypto/src/e2e_identity/rotate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ impl CentralContext {
expiry_sec: u32,
ciphersuite: MlsCiphersuite,
) -> CryptoResult<E2eiEnrollment> {
let client_guard = self.mls_client().await?;
let client = client_guard.as_ref().ok_or(CryptoError::MlsNotInitialized)?;
let mls_provider = self.mls_provider().await?;
// look for existing credential of type basic. If there isn't, then this method has been misused
let cb = client
let cb = self
.mls_client()
.await?
.find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::Basic)
.await
.ok_or(E2eIdentityError::MissingExistingClient(MlsCredentialType::Basic))?;
.map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::Basic))?;
let client_id = cb.credential().identity().into();

let sign_keypair = Some((&cb.signature_key).try_into()?);
Expand Down Expand Up @@ -73,14 +73,14 @@ impl CentralContext {
expiry_sec: u32,
ciphersuite: MlsCiphersuite,
) -> CryptoResult<E2eiEnrollment> {
let client_guard = self.mls_client().await?;
let client = client_guard.as_ref().ok_or(CryptoError::MlsNotInitialized)?;
let mls_provider = self.mls_provider().await?;
// look for existing credential of type x509. If there isn't, then this method has been misused
let cb = client
let cb = self
.mls_client()
.await?
.find_most_recent_credential_bundle(ciphersuite.signature_algorithm(), MlsCredentialType::X509)
.await
.ok_or(E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?;
.map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?;
let client_id = cb.credential().identity().into();
let sign_keypair = Some((&cb.signature_key).try_into()?);
let existing_identity = cb
Expand Down Expand Up @@ -141,9 +141,8 @@ impl CentralContext {
certificate_chain,
private_key,
};
let client = &self.mls_client().await?;

let mut client_guard = self.mls_client_mut().await?;
let client = client_guard.as_mut().ok_or(CryptoError::MlsNotInitialized)?;
let new_cb = client
.save_new_x509_credential_bundle(
&self.mls_provider().await?.keystore(),
Expand Down Expand Up @@ -222,8 +221,7 @@ impl CentralContext {
id: &crate::prelude::ConversationId,
cb: Option<&CredentialBundle>,
) -> CryptoResult<MlsCommitBundle> {
let client_guard = self.mls_client().await?;
let client = client_guard.as_ref().ok_or(CryptoError::MlsNotInitialized)?;
let client = &self.mls_client().await?;
self.get_conversation(id)
.await?
.write()
Expand All @@ -246,7 +244,7 @@ impl MlsConversation {
None => &client
.find_most_recent_credential_bundle(self.ciphersuite().signature_algorithm(), MlsCredentialType::X509)
.await
.ok_or(E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?,
.map_err(|_| E2eIdentityError::MissingExistingClient(MlsCredentialType::X509))?,
};
let mut leaf_node = self.group.own_leaf().ok_or(CryptoError::InternalMlsError)?.clone();
leaf_node.set_credential_with_key(cb.to_mls_credential_with_key());
Expand Down Expand Up @@ -340,7 +338,7 @@ pub(crate) mod tests {
pub(crate) mod all {
use openmls_traits::types::SignatureScheme;

use crate::test_utils::central::TEAM;
use crate::test_utils::context::TEAM;

use super::*;

Expand Down Expand Up @@ -637,10 +635,10 @@ pub(crate) mod tests {
assert_eq!(old_cb, old_cb_found);
let (cid, all_credentials, scs, old_nb_identities) = {
let alice_client = alice_central.client().await;
let old_nb_identities = alice_client.identities.as_vec().await.len();
let old_nb_identities = alice_client.identities_count().await.unwrap();

// Let's simulate an app crash, client gets deleted and restored from keystore
let cid = alice_client.id().clone();
let cid = alice_client.id().await.unwrap();
let scs = HashSet::from([case.signature_scheme()]);
let all_credentials = alice_central
.context
Expand All @@ -665,30 +663,27 @@ pub(crate) mod tests {
backend.keystore().commit_transaction().await.unwrap();
backend.keystore().new_transaction().await.unwrap();

let client = Client::load(backend, &cid, all_credentials, scs).await.unwrap();
let mut alice_client_guard = alice_central.context.mls_client_mut().await.unwrap();
*alice_client_guard = Some(client);
drop(alice_client_guard);
let new_client = Client::default();

let alice_client = alice_central.client().await;
new_client.load(backend, &cid, all_credentials, scs).await.unwrap();

// Verify that Alice has the same credentials
let cb = alice_central
let cb = new_client
.find_most_recent_credential_bundle(case.signature_scheme(), MlsCredentialType::X509)
.await
.unwrap();
let identity = cb
.to_mls_credential_with_key()
.extract_identity(case.ciphersuite(), None)
.unwrap();
// backend.keystore().commit_transaction().await.unwrap();

assert_eq!(identity.x509_identity.as_ref().unwrap().display_name, NEW_DISPLAY_NAME);
assert_eq!(
identity.x509_identity.as_ref().unwrap().handle,
format!("wireapp://%40{NEW_HANDLE}@world.com")
);

assert_eq!(alice_client.identities.as_vec().await.len(), old_nb_identities);
assert_eq!(new_client.identities_count().await.unwrap(), old_nb_identities);
})
})
.await
Expand Down
Loading
Loading