Skip to content

Commit

Permalink
feat: Initialize mls only according to backend feature flag #WPB-10117
Browse files Browse the repository at this point in the history
  • Loading branch information
m-zagorski committed Oct 25, 2024
1 parent ebcf95a commit 23d517e
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,8 @@ class UserSessionScope internal constructor(
userConfigRepository,
featureSupport,
clientIdProvider,
mlsClientProvider,
mlsPublicKeysRepository,
userScopedLogger
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,23 @@ import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.client.isActive
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel
import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.isLeft
import com.wire.kalium.logic.functional.isRight
import com.wire.kalium.logic.functional.map
import kotlinx.datetime.Instant

Expand All @@ -51,12 +56,23 @@ internal class UpdateSelfUserSupportedProtocolsUseCaseImpl(
private val userConfigRepository: UserConfigRepository,
private val featureSupport: FeatureSupport,
private val currentClientIdProvider: CurrentClientIdProvider,
private val mlsClientProvider: MLSClientProvider,
private val mlsPublicKeysRepository: MLSPublicKeysRepository,
private val logger: KaliumLogger
) : UpdateSelfUserSupportedProtocolsUseCase {

override suspend operator fun invoke(): Either<CoreFailure, Boolean> {
return if (!featureSupport.isMLSSupported) {
logger.d("Skip updating supported protocols, since MLS is not supported.")
val mlsKey = mlsClientProvider.getMLSClient().flatMap { mlsClient ->
val cipherSuite: CipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)
}

return if (!featureSupport.isMLSSupported || mlsKey.isLeft()) {
logger.d(
"Skip updating supported protocols, since MLS is not supported. " +
"Feature flag: ${featureSupport.isMLSSupported} " +
"Has mls key: ${mlsKey.isRight()}"
)
Either.Right(false)
} else {
(userRepository.getSelfUser()?.let { selfUser ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
*/
package com.wire.kalium.logic.feature.user

import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel
import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION
Expand All @@ -37,6 +43,7 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl
import com.wire.kalium.logic.util.shouldSucceed
import io.ktor.util.decodeBase64Bytes
import io.mockative.Mock
import io.mockative.any
import io.mockative.coEvery
Expand All @@ -53,10 +60,27 @@ import kotlin.test.Test
class UpdateSupportedProtocolsUseCaseTest {

@Test
fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
fun givenMlsFeatureDisabledAndMlsKeyPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.arrange {
withIsMLSSupported(false)
withKeyForCipherSuite()
}

useCase.invoke().shouldSucceed()

coVerify {
arrangement.userRepository.updateSupportedProtocols(any())
}.wasNotInvoked()
}


@Test
fun givenMlsFeatureEnabledAndMlsKeyNotPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.arrange {
withIsMLSSupported(true)
withoutKeyForCipherSuite()
}

useCase.invoke().shouldSucceed()
Expand All @@ -72,6 +96,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -92,6 +117,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -112,6 +138,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -132,6 +159,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -152,6 +180,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -176,6 +205,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -201,6 +231,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -226,6 +257,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -251,6 +283,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS))
withGetMigrationConfigurationFailing(StorageFailure.DataNotFound)
Expand All @@ -275,6 +308,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION)
Expand All @@ -299,6 +333,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
withGetSupportedProtocolsFailing(StorageFailure.DataNotFound)
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand Down Expand Up @@ -326,6 +361,15 @@ class UpdateSupportedProtocolsUseCaseTest {
@Mock
val featureSupport = mock(FeatureSupport::class)

@Mock
val mlsPublicKeysRepository = mock(MLSPublicKeysRepository::class)

@Mock
val mlsClientProvider = mock(MLSClientProvider::class)

@Mock
val mlsClient = mock(MLSClient::class)

private var kaliumLogger = KaliumLogger.disabled()

fun withIsMLSSupported(supported: Boolean) = apply {
Expand All @@ -334,6 +378,30 @@ class UpdateSupportedProtocolsUseCaseTest {
}.returns(supported)
}

suspend fun withKeyForCipherSuite() = apply {
coEvery {
mlsClientProvider.getMLSClient(any())
}.returns(Either.Right(mlsClient))
every {
mlsClient.getDefaultCipherSuite()
}.returns(CIPHER_SUITE.tag.toUShort())
coEvery {
mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.returns(Either.Right(CRYPTO_MLS_PUBLIC_KEY))
}

suspend fun withoutKeyForCipherSuite() = apply {
coEvery {
mlsClientProvider.getMLSClient(any())
}.returns(Either.Right(mlsClient))
every {
mlsClient.getDefaultCipherSuite()
}.returns(CIPHER_SUITE.tag.toUShort())
coEvery {
mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.returns(Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite"))))
}

suspend fun withGetSelfUserSuccessful(supportedProtocols: Set<SupportedProtocol>? = null) = apply {
coEvery { userRepository.getSelfUser() }
.returns(TestUser.SELF.copy(supportedProtocols = supportedProtocols))
Expand Down Expand Up @@ -383,6 +451,8 @@ class UpdateSupportedProtocolsUseCaseTest {
userConfigRepository,
featureSupport,
currentClientIdProvider,
mlsClientProvider,
mlsPublicKeysRepository,
kaliumLogger
)
}
Expand All @@ -397,6 +467,13 @@ class UpdateSupportedProtocolsUseCaseTest {
.copy(endTime = Instant.DISTANT_PAST)
val DISABLED_MIGRATION_CONFIGURATION = ONGOING_MIGRATION_CONFIGURATION
.copy(status = Status.DISABLED)
val CIPHER_SUITE = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
val MLS_PUBLIC_KEY = MLSPublicKeys(
removal = mapOf(
"ed25519" to "gRNvFYReriXbzsGu7zXiPtS8kaTvhU1gUJEV9rdFHVw="
)
)
val CRYPTO_MLS_PUBLIC_KEY: ByteArray = MLS_PUBLIC_KEY.removal?.get("ed25519")!!.decodeBase64Bytes()
}
}
}

0 comments on commit 23d517e

Please sign in to comment.