diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt index 6d693de786..61c6e5db8e 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/UserSessionScope.kt @@ -1051,6 +1051,8 @@ class UserSessionScope internal constructor( userConfigRepository, featureSupport, clientIdProvider, + mlsClientProvider, + mlsPublicKeysRepository, userScopedLogger ) diff --git a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSelfUserSupportedProtocolsUseCase.kt b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSelfUserSupportedProtocolsUseCase.kt index 2e4fd59975..f48ceee115 100644 --- a/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSelfUserSupportedProtocolsUseCase.kt +++ b/logic/src/commonMain/kotlin/com/wire/kalium/logic/feature/user/UpdateSelfUserSupportedProtocolsUseCase.kt @@ -23,11 +23,14 @@ import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.client.Client import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.client.isActive import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel import com.wire.kalium.logic.data.featureConfig.Status import com.wire.kalium.logic.data.id.CurrentClientIdProvider +import com.wire.kalium.logic.data.mls.CipherSuite +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded @@ -35,6 +38,8 @@ import com.wire.kalium.logic.featureFlags.FeatureSupport import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.functional.flatMap import com.wire.kalium.logic.functional.flatMapLeft +import com.wire.kalium.logic.functional.isLeft +import com.wire.kalium.logic.functional.isRight import com.wire.kalium.logic.functional.map import kotlinx.datetime.Instant @@ -51,12 +56,23 @@ internal class UpdateSelfUserSupportedProtocolsUseCaseImpl( private val userConfigRepository: UserConfigRepository, private val featureSupport: FeatureSupport, private val currentClientIdProvider: CurrentClientIdProvider, + private val mlsClientProvider: MLSClientProvider, + private val mlsPublicKeysRepository: MLSPublicKeysRepository, private val logger: KaliumLogger ) : UpdateSelfUserSupportedProtocolsUseCase { override suspend operator fun invoke(): Either { - return if (!featureSupport.isMLSSupported) { - logger.d("Skip updating supported protocols, since MLS is not supported.") + val mlsKey = mlsClientProvider.getMLSClient().flatMap { mlsClient -> + val cipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite()) + mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite) + } + + return if (!featureSupport.isMLSSupported || mlsKey.isLeft()) { + logger.d( + "Skip updating supported protocols, since MLS is not supported. " + + "Feature flag: ${featureSupport.isMLSSupported} " + + "Has mls key: ${mlsKey.isRight()}" + ) Either.Right(false) } else { (userRepository.getSelfUser()?.let { selfUser -> diff --git a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt index c13215c798..82ef208e00 100644 --- a/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt +++ b/logic/src/commonTest/kotlin/com/wire/kalium/logic/feature/user/UpdateSupportedProtocolsUseCaseTest.kt @@ -17,14 +17,20 @@ */ package com.wire.kalium.logic.feature.user +import com.wire.kalium.cryptography.MLSClient import com.wire.kalium.logger.KaliumLogger +import com.wire.kalium.logic.MLSFailure import com.wire.kalium.logic.StorageFailure import com.wire.kalium.logic.configuration.UserConfigRepository import com.wire.kalium.logic.data.client.Client import com.wire.kalium.logic.data.client.ClientRepository +import com.wire.kalium.logic.data.client.MLSClientProvider import com.wire.kalium.logic.data.conversation.ClientId import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel import com.wire.kalium.logic.data.featureConfig.Status +import com.wire.kalium.logic.data.mls.CipherSuite +import com.wire.kalium.logic.data.mls.MLSPublicKeys +import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository import com.wire.kalium.logic.data.user.SupportedProtocol import com.wire.kalium.logic.data.user.UserRepository import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION @@ -37,6 +43,7 @@ import com.wire.kalium.logic.functional.Either import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl import com.wire.kalium.logic.util.shouldSucceed +import io.ktor.util.decodeBase64Bytes import io.mockative.Mock import io.mockative.any import io.mockative.coEvery @@ -53,10 +60,27 @@ import kotlin.test.Test class UpdateSupportedProtocolsUseCaseTest { @Test - fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + fun givenMlsFeatureDisabledAndMlsKeyPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { val (arrangement, useCase) = Arrangement() .arrange { withIsMLSSupported(false) + withKeyForCipherSuite() + } + + useCase.invoke().shouldSucceed() + + coVerify { + arrangement.userRepository.updateSupportedProtocols(any()) + }.wasNotInvoked() + } + + + @Test + fun givenMlsFeatureEnabledAndMlsKeyNotPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest { + val (arrangement, useCase) = Arrangement() + .arrange { + withIsMLSSupported(true) + withoutKeyForCipherSuite() } useCase.invoke().shouldSucceed() @@ -72,6 +96,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -92,6 +117,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -112,6 +138,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -132,6 +159,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) @@ -152,6 +180,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -176,6 +205,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -201,6 +231,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -226,6 +257,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS)) withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION) @@ -251,6 +283,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS)) withGetMigrationConfigurationFailing(StorageFailure.DataNotFound) @@ -275,6 +308,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful() withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS)) withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION) @@ -299,6 +333,7 @@ class UpdateSupportedProtocolsUseCaseTest { .arrange { withCurrentClientIdSuccess(ClientId("1")) withIsMLSSupported(true) + withKeyForCipherSuite() withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS)) withGetSupportedProtocolsFailing(StorageFailure.DataNotFound) withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION) @@ -326,6 +361,15 @@ class UpdateSupportedProtocolsUseCaseTest { @Mock val featureSupport = mock(FeatureSupport::class) + @Mock + val mlsPublicKeysRepository = mock(MLSPublicKeysRepository::class) + + @Mock + val mlsClientProvider = mock(MLSClientProvider::class) + + @Mock + val mlsClient = mock(MLSClient::class) + private var kaliumLogger = KaliumLogger.disabled() fun withIsMLSSupported(supported: Boolean) = apply { @@ -334,6 +378,30 @@ class UpdateSupportedProtocolsUseCaseTest { }.returns(supported) } + suspend fun withKeyForCipherSuite() = apply { + coEvery { + mlsClientProvider.getMLSClient(any()) + }.returns(Either.Right(mlsClient)) + every { + mlsClient.getDefaultCipherSuite() + }.returns(CIPHER_SUITE.tag.toUShort()) + coEvery { + mlsPublicKeysRepository.getKeyForCipherSuite(any()) + }.returns(Either.Right(CRYPTO_MLS_PUBLIC_KEY)) + } + + suspend fun withoutKeyForCipherSuite() = apply { + coEvery { + mlsClientProvider.getMLSClient(any()) + }.returns(Either.Right(mlsClient)) + every { + mlsClient.getDefaultCipherSuite() + }.returns(CIPHER_SUITE.tag.toUShort()) + coEvery { + mlsPublicKeysRepository.getKeyForCipherSuite(any()) + }.returns(Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite")))) + } + suspend fun withGetSelfUserSuccessful(supportedProtocols: Set? = null) = apply { coEvery { userRepository.getSelfUser() } .returns(TestUser.SELF.copy(supportedProtocols = supportedProtocols)) @@ -383,6 +451,8 @@ class UpdateSupportedProtocolsUseCaseTest { userConfigRepository, featureSupport, currentClientIdProvider, + mlsClientProvider, + mlsPublicKeysRepository, kaliumLogger ) } @@ -397,6 +467,13 @@ class UpdateSupportedProtocolsUseCaseTest { .copy(endTime = Instant.DISTANT_PAST) val DISABLED_MIGRATION_CONFIGURATION = ONGOING_MIGRATION_CONFIGURATION .copy(status = Status.DISABLED) + val CIPHER_SUITE = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256 + val MLS_PUBLIC_KEY = MLSPublicKeys( + removal = mapOf( + "ed25519" to "gRNvFYReriXbzsGu7zXiPtS8kaTvhU1gUJEV9rdFHVw=" + ) + ) + val CRYPTO_MLS_PUBLIC_KEY: ByteArray = MLS_PUBLIC_KEY.removal?.get("ed25519")!!.decodeBase64Bytes() } } }