diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs index 1e1a0542a..488c42878 100644 --- a/iris-mpc-cpu/src/network/grpc.rs +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -98,6 +98,10 @@ impl OutgoingStreams { )) .map(|s| s.value().clone()) } + + fn contains_session(&self, session_id: SessionId) -> bool { + self.streams.iter().any(|v| v.key().0 == session_id) + } } #[derive(Default, Clone)] @@ -139,7 +143,7 @@ impl GrpcNetworking { } pub async fn create_session(&self, session_id: SessionId) -> eyre::Result<()> { - if self.message_queues.contains_key(&session_id) { + if self.outgoing_streams.contains_session(session_id) { return Err(eyre!( "Player {:?} has already created session {session_id:?}", self.party_id @@ -403,6 +407,7 @@ mod tests { // Each party sending and receiving messages to each other { + let players = players.clone(); jobs.spawn(async move { let session_id = SessionId::from(1); @@ -455,6 +460,18 @@ mod tests { }); } + // Parties create a session consecutively + { + let players = players.clone(); + jobs.spawn(async move { + let session_id = SessionId::from(2); + + for player in players.iter() { + player.create_session(session_id).await.unwrap(); + } + }); + } + jobs.join_all().await; Ok(())