diff --git a/testing/matrix-sdk-integration-testing/src/helpers.rs b/testing/matrix-sdk-integration-testing/src/helpers.rs index 8b71f5136fe..3a11cc16f9f 100644 --- a/testing/matrix-sdk-integration-testing/src/helpers.rs +++ b/testing/matrix-sdk-integration-testing/src/helpers.rs @@ -2,6 +2,7 @@ use std::{ collections::HashMap, ops::Deref, option_env, + path::{Path, PathBuf}, sync::{Arc, Mutex as StdMutex}, time::Duration, }; @@ -21,9 +22,14 @@ use tokio::sync::Mutex; static USERS: Lazy>> = Lazy::new(Mutex::default); +enum SqlitePath { + Random, + Path(PathBuf), +} + pub struct TestClientBuilder { username: String, - use_sqlite: bool, + use_sqlite_dir: Option, encryption_settings: EncryptionSettings, http_proxy: Option, } @@ -32,7 +38,7 @@ impl TestClientBuilder { pub fn new(username: impl Into) -> Self { Self { username: username.into(), - use_sqlite: false, + use_sqlite_dir: None, encryption_settings: Default::default(), http_proxy: None, } @@ -45,7 +51,16 @@ impl TestClientBuilder { } pub fn use_sqlite(mut self) -> Self { - self.use_sqlite = true; + self.use_sqlite_dir = Some(SqlitePath::Random); + self + } + + /// Create or re-use a Sqlite store (with no passphrase) in the supplied + /// directory. Note: this path must remain valid throughout the use of + /// the constructed Client, so if you created a TempDir you must hang on + /// to a reference to it throughout the test. + pub fn use_sqlite_dir(mut self, path: &Path) -> Self { + self.use_sqlite_dir = Some(SqlitePath::Path(path.to_owned())); self } @@ -83,10 +98,14 @@ impl TestClientBuilder { client_builder = client_builder.proxy(proxy); } - let client = if self.use_sqlite { - client_builder.sqlite_store(tmp_dir.path(), None).build().await? - } else { - client_builder.build().await? + let client = match self.use_sqlite_dir { + None => client_builder.build().await?, + Some(SqlitePath::Random) => { + client_builder.sqlite_store(tmp_dir.path(), None).build().await? + } + Some(SqlitePath::Path(path_buf)) => { + client_builder.sqlite_store(&path_buf, None).build().await? + } }; // safe to assume we have not registered this user yet, but ignore if we did diff --git a/testing/matrix-sdk-integration-testing/src/tests/e2ee.rs b/testing/matrix-sdk-integration-testing/src/tests/e2ee.rs index 5a3b3d91aaf..100b81279eb 100644 --- a/testing/matrix-sdk-integration-testing/src/tests/e2ee.rs +++ b/testing/matrix-sdk-integration-testing/src/tests/e2ee.rs @@ -1,4 +1,8 @@ -use std::sync::{Arc, Mutex}; +use std::{ + path::Path, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; use anyhow::Result; use assert_matches::assert_matches; @@ -19,9 +23,11 @@ use matrix_sdk::{ SyncRoomMessageEvent, }, }, + EventId, OwnedEventId, OwnedRoomId, RoomId, }, - Client, + Client, Room, }; +use tempfile::tempdir; use tracing::warn; use crate::helpers::{SyncTokenAwareClient, TestClientBuilder}; @@ -289,6 +295,164 @@ async fn test_mutual_sas_verification() -> Result<()> { Ok(()) } +struct ClientWrapper { + pub client: SyncTokenAwareClient, + events: Arc>>, +} + +impl ClientWrapper { + async fn new(username: &str) -> Self { + Self::from_client_builder(TestClientBuilder::new(username).use_sqlite()).await + } + + async fn with_sqlite_dir(username: &str, sqlite_dir: &Path) -> Self { + Self::from_client_builder(TestClientBuilder::new(username).use_sqlite_dir(sqlite_dir)).await + } + + async fn from_client_builder(builder: TestClientBuilder) -> Self { + let events = Arc::new(Mutex::new(Vec::new())); + + let client = SyncTokenAwareClient::new( + builder + .encryption_settings(Self::encryption_settings()) + .build() + .await + .expect("Failed to create client"), + ); + + let events_clone = events.clone(); + client.add_event_handler(|ev: OriginalSyncRoomMessageEvent, _: Client| async move { + events_clone.lock().unwrap().push(ev.event_id.clone()) + }); + + Self { client, events } + } + + fn encryption_settings() -> EncryptionSettings { + EncryptionSettings { auto_enable_cross_signing: true, ..Default::default() } + } + + fn timeout() -> Duration { + Duration::from_secs(10) + } + + async fn create_room(&self, invite: &[&ClientWrapper]) -> OwnedRoomId { + let invite = invite.iter().map(|cw| cw.client.user_id().unwrap().to_owned()).collect(); + + let request = assign!(CreateRoomRequest::new(), { + invite, + is_direct: true, + }); + + let room = self.client.create_room(request).await.expect("Failed to create room"); + room.enable_encryption().await.expect("Failed to enable encryption"); + room.room_id().to_owned() + } + + async fn join(&self, room_id: &RoomId) { + let room = self.wait_until_room_exists(room_id).await; + room.join().await.expect("Unable to join room") + } + + /// Wait (syncing if needed) until the room with supplied ID exists, or time out + async fn wait_until_room_exists(&self, room_id: &RoomId) -> Room { + let end_time = Instant::now() + Self::timeout(); + while Instant::now() < end_time { + let room = self.client.get_room(room_id); + if let Some(room) = room { + return room; + } + self.client.sync_once().await.expect("Sync failed"); + } + panic!("Timed out waiting for room {room_id} to exist"); + } + + /// Wait (syncing if needed) until the user appears in the supplied room, or time out + async fn wait_until_user_in_room(&self, room_id: &RoomId, other: &ClientWrapper) { + let room = self.wait_until_room_exists(room_id).await; + let user_id = other.client.user_id().unwrap(); + + let end_time = Instant::now() + Self::timeout(); + while Instant::now() < end_time { + if room.get_member_no_sync(user_id).await.expect("get_member failed").is_some() { + return; + } + self.client.sync_once().await.expect("Sync failed"); + } + panic!("Timed out waiting for user {user_id} to be in room {room_id}"); + } + + /// Wait (syncing if needed) until the event with this ID appears, or time out + async fn wait_until_received(&self, event_id: &EventId) { + let event_id = event_id.to_owned(); + let end_time = Instant::now() + Self::timeout(); + while Instant::now() < end_time { + if self.events.lock().unwrap().contains(&event_id) { + return; + } + self.client.sync_once().await.expect("Sync failed"); + } + panic!("Timed out waiting for event {event_id} to be received"); + } + + /// Send a text message in the supplied room and return the event ID + async fn send(&self, room_id: &RoomId, message: &str) -> OwnedEventId { + let room = self.wait_until_room_exists(room_id).await; + + room.send(RoomMessageEventContent::text_plain(message.to_owned())) + .await + .expect("Sending message failed") + .event_id + .to_owned() + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_multiple_clients_share_crypto_state() -> Result<()> { + let alice_sqlite_dir = tempdir()?; + let alice1 = ClientWrapper::with_sqlite_dir("alice", alice_sqlite_dir.path()).await; + let alice2 = ClientWrapper::with_sqlite_dir("alice", alice_sqlite_dir.path()).await; + let bob = ClientWrapper::new("bob").await; + + warn!("alice's device: {}", alice1.client.device_id().unwrap()); + warn!("bob's device: {}", bob.client.device_id().unwrap()); + + // TODO: surely both alice clients share the same device ID because they are sharing the same DB? + //assert_eq!(alice1.client.device_id(), alice2.client.device_id()); + + let room_id = alice1.create_room(&[&bob]).await; + + warn!("alice1 has created and enabled encryption in the room"); + + bob.join(&room_id).await; + alice1.wait_until_user_in_room(&room_id, &bob).await; + + warn!("alice1 and bob are both aware of each other in the e2ee room"); + + let msg1 = bob.send(&room_id, "msg1_from_bob").await; + alice1.wait_until_received(&msg1).await; + + warn!("alice1 received msg1 from bob"); + + let msg2 = bob.send(&room_id, "msg2_from_bob").await; + alice2.wait_until_received(&msg2).await; + + warn!("alice2 received msg2 from bob"); + + let msg3 = alice1.send(&room_id, "msg3_from_alice").await; + bob.wait_until_received(&msg3).await; + + warn!("bob received msg3 from alice1"); + + let msg4 = bob.send(&room_id, "msg4_from_bob").await; + alice1.wait_until_received(&msg4).await; + alice2.wait_until_received(&msg4).await; + + warn!("alice1 and alice2 both received msg4 from bob"); + + Ok(()) +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_mutual_qrcode_verification() -> Result<()> { let encryption_settings =