Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Initialize mls only according to backend feature flag #WPB-10117 #3076

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,8 @@ class UserSessionScope internal constructor(
userConfigRepository,
featureSupport,
clientIdProvider,
mlsClientProvider,
mlsPublicKeysRepository,
userScopedLogger
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,23 @@ import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.client.isActive
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel
import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.mlsmigration.hasMigrationEnded
import com.wire.kalium.logic.featureFlags.FeatureSupport
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.flatMapLeft
import com.wire.kalium.logic.functional.isLeft
import com.wire.kalium.logic.functional.isRight
import com.wire.kalium.logic.functional.map
import kotlinx.datetime.Instant

Expand All @@ -51,12 +56,23 @@ internal class UpdateSelfUserSupportedProtocolsUseCaseImpl(
private val userConfigRepository: UserConfigRepository,
private val featureSupport: FeatureSupport,
private val currentClientIdProvider: CurrentClientIdProvider,
private val mlsClientProvider: MLSClientProvider,
private val mlsPublicKeysRepository: MLSPublicKeysRepository,
private val logger: KaliumLogger
) : UpdateSelfUserSupportedProtocolsUseCase {

override suspend operator fun invoke(): Either<CoreFailure, Boolean> {
return if (!featureSupport.isMLSSupported) {
logger.d("Skip updating supported protocols, since MLS is not supported.")
val mlsKey = mlsClientProvider.getMLSClient().flatMap { mlsClient ->
val cipherSuite: CipherSuite = CipherSuite.fromTag(mlsClient.getDefaultCipherSuite())
mlsPublicKeysRepository.getKeyForCipherSuite(cipherSuite)
}

return if (!featureSupport.isMLSSupported || mlsKey.isLeft()) {
logger.d(
"Skip updating supported protocols, since MLS is not supported. " +
"Feature flag: ${featureSupport.isMLSSupported} " +
"Has mls key: ${mlsKey.isRight()}"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: not sure what this changes from the code before?
Creating an object of MLS client will create some local files/db for MLS and if there is no need to create it uus better to not init mls untill it is needed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • looking at the Jira ticket i think the app is already doing what it described in the ticket

Either.Right(false)
} else {
(userRepository.getSelfUser()?.let { selfUser ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
*/
package com.wire.kalium.logic.feature.user

import com.wire.kalium.cryptography.MLSClient
import com.wire.kalium.logger.KaliumLogger
import com.wire.kalium.logic.MLSFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.client.Client
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.featureConfig.MLSMigrationModel
import com.wire.kalium.logic.data.featureConfig.Status
import com.wire.kalium.logic.data.mls.CipherSuite
import com.wire.kalium.logic.data.mls.MLSPublicKeys
import com.wire.kalium.logic.data.mlspublickeys.MLSPublicKeysRepository
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.feature.user.UpdateSupportedProtocolsUseCaseTest.Arrangement.Companion.COMPLETED_MIGRATION_CONFIGURATION
Expand All @@ -37,6 +43,7 @@ import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangement
import com.wire.kalium.logic.util.arrangement.provider.CurrentClientIdProviderArrangementImpl
import com.wire.kalium.logic.util.shouldSucceed
import io.ktor.util.decodeBase64Bytes
import io.mockative.Mock
import io.mockative.any
import io.mockative.coEvery
Expand All @@ -53,10 +60,27 @@ import kotlin.test.Test
class UpdateSupportedProtocolsUseCaseTest {

@Test
fun givenMLSIsNotSupported_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
fun givenMlsFeatureDisabledAndMlsKeyPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.arrange {
withIsMLSSupported(false)
withKeyForCipherSuite()
}

useCase.invoke().shouldSucceed()

coVerify {
arrangement.userRepository.updateSupportedProtocols(any())
}.wasNotInvoked()
}


@Test
fun givenMlsFeatureEnabledAndMlsKeyNotPresent_whenInvokingUseCase_thenSupportedProtocolsAreNotUpdated() = runTest {
val (arrangement, useCase) = Arrangement()
.arrange {
withIsMLSSupported(true)
withoutKeyForCipherSuite()
}

useCase.invoke().shouldSucceed()
Expand All @@ -72,6 +96,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -92,6 +117,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -112,6 +138,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -132,6 +159,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -152,6 +180,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -176,6 +205,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -201,6 +231,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand All @@ -226,6 +257,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.MLS))
withGetMigrationConfigurationSuccessful(COMPLETED_MIGRATION_CONFIGURATION)
Expand All @@ -251,6 +283,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS, SupportedProtocol.MLS))
withGetMigrationConfigurationFailing(StorageFailure.DataNotFound)
Expand All @@ -275,6 +308,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful()
withGetSupportedProtocolsSuccessful(setOf(SupportedProtocol.PROTEUS))
withGetMigrationConfigurationSuccessful(DISABLED_MIGRATION_CONFIGURATION)
Expand All @@ -299,6 +333,7 @@ class UpdateSupportedProtocolsUseCaseTest {
.arrange {
withCurrentClientIdSuccess(ClientId("1"))
withIsMLSSupported(true)
withKeyForCipherSuite()
withGetSelfUserSuccessful(supportedProtocols = setOf(SupportedProtocol.PROTEUS))
withGetSupportedProtocolsFailing(StorageFailure.DataNotFound)
withGetMigrationConfigurationSuccessful(ONGOING_MIGRATION_CONFIGURATION)
Expand Down Expand Up @@ -326,6 +361,15 @@ class UpdateSupportedProtocolsUseCaseTest {
@Mock
val featureSupport = mock(FeatureSupport::class)

@Mock
val mlsPublicKeysRepository = mock(MLSPublicKeysRepository::class)

@Mock
val mlsClientProvider = mock(MLSClientProvider::class)

@Mock
val mlsClient = mock(MLSClient::class)

private var kaliumLogger = KaliumLogger.disabled()

fun withIsMLSSupported(supported: Boolean) = apply {
Expand All @@ -334,6 +378,30 @@ class UpdateSupportedProtocolsUseCaseTest {
}.returns(supported)
}

suspend fun withKeyForCipherSuite() = apply {
coEvery {
mlsClientProvider.getMLSClient(any())
}.returns(Either.Right(mlsClient))
every {
mlsClient.getDefaultCipherSuite()
}.returns(CIPHER_SUITE.tag.toUShort())
coEvery {
mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.returns(Either.Right(CRYPTO_MLS_PUBLIC_KEY))
}

suspend fun withoutKeyForCipherSuite() = apply {
coEvery {
mlsClientProvider.getMLSClient(any())
}.returns(Either.Right(mlsClient))
every {
mlsClient.getDefaultCipherSuite()
}.returns(CIPHER_SUITE.tag.toUShort())
coEvery {
mlsPublicKeysRepository.getKeyForCipherSuite(any())
}.returns(Either.Left(MLSFailure.Generic(IllegalStateException("No key found for cipher suite"))))
}

suspend fun withGetSelfUserSuccessful(supportedProtocols: Set<SupportedProtocol>? = null) = apply {
coEvery { userRepository.getSelfUser() }
.returns(TestUser.SELF.copy(supportedProtocols = supportedProtocols))
Expand Down Expand Up @@ -383,6 +451,8 @@ class UpdateSupportedProtocolsUseCaseTest {
userConfigRepository,
featureSupport,
currentClientIdProvider,
mlsClientProvider,
mlsPublicKeysRepository,
kaliumLogger
)
}
Expand All @@ -397,6 +467,13 @@ class UpdateSupportedProtocolsUseCaseTest {
.copy(endTime = Instant.DISTANT_PAST)
val DISABLED_MIGRATION_CONFIGURATION = ONGOING_MIGRATION_CONFIGURATION
.copy(status = Status.DISABLED)
val CIPHER_SUITE = CipherSuite.MLS_128_DHKEMP256_AES128GCM_SHA256_P256
val MLS_PUBLIC_KEY = MLSPublicKeys(
removal = mapOf(
"ed25519" to "gRNvFYReriXbzsGu7zXiPtS8kaTvhU1gUJEV9rdFHVw="
)
)
val CRYPTO_MLS_PUBLIC_KEY: ByteArray = MLS_PUBLIC_KEY.removal?.get("ed25519")!!.decodeBase64Bytes()
}
}
}
Loading