Skip to content

Commit

Permalink
feat(mls-migration): force migration when migration deadline arrives …
Browse files Browse the repository at this point in the history
…#10 (#1831)

* feat: end migration regardless when migration deadline arrives

* refactor: generalise methods for fetching conversations ids

* refactor: better naming
  • Loading branch information
typfel committed Aug 16, 2023
1 parent af35f17 commit 1f01908
Show file tree
Hide file tree
Showing 19 changed files with 113 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ private fun ConversationEntity.AccessRole.toDAO(): Conversation.AccessRole = whe
ConversationEntity.AccessRole.EXTERNAL -> Conversation.AccessRole.EXTERNAL
}

private fun Conversation.Type.toDAO(): ConversationEntity.Type = when (this) {
internal fun Conversation.Type.toDAO(): ConversationEntity.Type = when (this) {
Conversation.Type.SELF -> ConversationEntity.Type.SELF
Conversation.Type.ONE_ON_ONE -> ConversationEntity.Type.ONE_ON_ONE
Conversation.Type.GROUP -> ConversationEntity.Type.GROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,12 @@ interface ConversationRepository {

suspend fun getConversationList(): Either<StorageFailure, Flow<List<Conversation>>>
suspend fun observeConversationList(): Flow<List<Conversation>>
suspend fun getProteusTeamConversations(teamId: TeamId): Either<StorageFailure, List<QualifiedID>>
suspend fun getProteusTeamConversationsReadyForFinalisation(teamId: TeamId): Either<StorageFailure, List<QualifiedID>>
suspend fun getConversationIds(
type: Conversation.Type,
protocol: Conversation.Protocol,
teamId: TeamId? = null
): Either<StorageFailure, List<QualifiedID>>
suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: TeamId): Either<StorageFailure, List<QualifiedID>>
suspend fun observeConversationListDetails(): Flow<List<ConversationDetails>>
suspend fun observeConversationDetailsById(conversationID: ConversationId): Flow<Either<StorageFailure, ConversationDetails>>
suspend fun fetchConversation(conversationID: ConversationId): Either<CoreFailure, Unit>
Expand All @@ -129,7 +133,6 @@ interface ConversationRepository {
suspend fun getRecipientById(conversationId: ConversationId, userIDList: List<UserId>): Either<StorageFailure, List<Recipient>>
suspend fun getConversationRecipientsForCalling(conversationId: ConversationId): Either<CoreFailure, List<Recipient>>
suspend fun getConversationProtocolInfo(conversationId: ConversationId): Either<StorageFailure, Conversation.ProtocolInfo>
suspend fun getGroupConversationIdsByProtocol(protocol: Conversation.Protocol): Either<StorageFailure, List<ConversationId>>
suspend fun observeConversationMembers(conversationID: ConversationId): Flow<List<Conversation.Member>>

/**
Expand Down Expand Up @@ -398,15 +401,18 @@ internal class ConversationDataSource internal constructor(
return conversationDAO.getAllConversations().map { it.map(conversationMapper::fromDaoModel) }
}

override suspend fun getProteusTeamConversations(teamId: TeamId): Either<StorageFailure, List<QualifiedID>> =
override suspend fun getConversationIds(
type: Conversation.Type,
protocol: Conversation.Protocol,
teamId: TeamId?
): Either<StorageFailure, List<QualifiedID>> =
wrapStorageRequest {
conversationDAO.getAllProteusTeamConversations(teamId.value)
conversationDAO.getConversationIds(type.toDAO(), protocol.toDao(), teamId?.value)
.map { it.toModel() }
}

override suspend fun getProteusTeamConversationsReadyForFinalisation(teamId: TeamId): Either<StorageFailure, List<QualifiedID>> =
override suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: TeamId): Either<StorageFailure, List<QualifiedID>> =
wrapStorageRequest {
conversationDAO.getAllProteusTeamConversationsReadyToBeFinalised(teamId.value)
conversationDAO.getTeamConversationIdsReadyToCompleteMigration(teamId.value)
.map { it.toModel() }
}

Expand Down Expand Up @@ -508,11 +514,6 @@ internal class ConversationDataSource internal constructor(
}
}

override suspend fun getGroupConversationIdsByProtocol(protocol: Conversation.Protocol): Either<StorageFailure, List<ConversationId>> =
wrapStorageRequest {
conversationDAO.getGroupConversationIdsByProtocol(protocol.toDao()).map(QualifiedIDEntity::toModel)
}

override suspend fun observeConversationMembers(conversationID: ConversationId): Flow<List<Conversation.Member>> =
memberDAO.observeConversationMembers(conversationID.toDao()).map { members ->
members.map(memberMapper::fromDaoModel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ internal class SyncConversationsUseCaseImpl(
private val systemMessageInserter: SystemMessageInserter
) : SyncConversationsUseCase {
override suspend operator fun invoke(): Either<CoreFailure, Unit> =
conversationRepository.getGroupConversationIdsByProtocol(Conversation.Protocol.PROTEUS)
conversationRepository.getConversationIds(Conversation.Type.GROUP, Conversation.Protocol.PROTEUS)
.flatMap { proteusConversationIds ->
conversationRepository.fetchConversations()
.flatMap {
Expand All @@ -49,7 +49,7 @@ internal class SyncConversationsUseCaseImpl(
private suspend fun reportConversationsWithPotentialHistoryLoss(
proteusConversationIds: List<ConversationId>
): Either<StorageFailure, Unit> =
conversationRepository.getGroupConversationIdsByProtocol(Conversation.Protocol.MLS)
conversationRepository.getConversationIds(Conversation.Type.GROUP, Conversation.Protocol.MLS)
.flatMap { mlsConversationIds ->
val conversationsWithUpgradedProtocol = mlsConversationIds.intersect(proteusConversationIds)
for (conversationId in conversationsWithUpgradedProtocol) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ class MLSMigrationWorkerImpl(
kaliumLogger.i("Running proteus to MLS migration")
updateSupportedProtocols().flatMap {
mlsMigrator.migrateProteusConversations().flatMap {
if (configuration.hasMigrationEnded()) {
mlsMigrator.finaliseAllProteusConversations()
} else {
mlsMigrator.finaliseProteusConversations()
}
}
}
} else {
kaliumLogger.i("MLS migration is not enabled")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import com.wire.kalium.logic.kaliumLogger
interface MLSMigrator {
suspend fun migrateProteusConversations(): Either<CoreFailure, Unit>
suspend fun finaliseProteusConversations(): Either<CoreFailure, Unit>
suspend fun finaliseAllProteusConversations(): Either<CoreFailure, Unit>
}
internal class MLSMigratorImpl(
private val selfUserId: UserId,
Expand All @@ -52,21 +53,33 @@ internal class MLSMigratorImpl(
selfTeamIdProvider().flatMap {
it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound)
}.flatMap { teamId ->
conversationRepository.getProteusTeamConversations(teamId)
conversationRepository.getConversationIds(Conversation.Type.GROUP, Protocol.PROTEUS, teamId)
.flatMap {
it.foldToEitherWhileRight(Unit) { conversationId, _ ->
migrate(conversationId)
}
}
}

override suspend fun finaliseAllProteusConversations(): Either<CoreFailure, Unit> =
selfTeamIdProvider().flatMap {
it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound)
}.flatMap { teamId ->
conversationRepository.getConversationIds(Conversation.Type.GROUP, Protocol.MIXED, teamId)
.flatMap {
it.foldToEitherWhileRight(Unit) { conversationId, _ ->
finalise(conversationId)
}
}
}

override suspend fun finaliseProteusConversations(): Either<CoreFailure, Unit> =
selfTeamIdProvider().flatMap {
it?.let { Either.Right(it) } ?: Either.Left(StorageFailure.DataNotFound)
}.flatMap { teamId ->
userRepository.fetchAllOtherUsers()
.flatMap {
conversationRepository.getProteusTeamConversationsReadyForFinalisation(teamId)
conversationRepository.getTeamConversationIdsReadyToCompleteMigration(teamId)
.flatMap {
it.foldToEitherWhileRight(Unit) { conversationId, _ ->
finalise(conversationId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.wire.kalium.logic.util.shouldSucceed
import com.wire.kalium.network.api.base.authenticated.conversation.ConvProtocol
import com.wire.kalium.network.api.base.authenticated.featureConfigs.AppLockConfigDTO
import com.wire.kalium.network.api.base.authenticated.featureConfigs.ClassifiedDomainsConfigDTO
import com.wire.kalium.network.api.base.authenticated.featureConfigs.E2EIConfigDTO
import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigApi
import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigData
import com.wire.kalium.network.api.base.authenticated.featureConfigs.FeatureConfigResponse
Expand Down Expand Up @@ -167,6 +168,10 @@ class MLSMigrationRepositoryTest {
emptyList(),
1
), FeatureFlagStatusDTO.ENABLED),
FeatureConfigData.E2EI(
E2EIConfigDTO(null),
FeatureFlagStatusDTO.ENABLED
),
FeatureConfigData.MLSMigration(
MLSMigrationConfigDTO(Instant.DISTANT_FUTURE, Instant.DISTANT_FUTURE),
FeatureFlagStatusDTO.ENABLED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import com.wire.kalium.logic.data.id.ConversationId
import com.wire.kalium.logic.data.id.TeamId
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.OtherUser
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserAssetId
import com.wire.kalium.logic.data.user.UserAvailabilityStatus
import com.wire.kalium.logic.data.user.UserId
Expand Down Expand Up @@ -166,7 +167,8 @@ class EndCallOnConversationChangeUseCaseTest {
userType = UserType.INTERNAL,
botService = null,
deleted = true,
defederated = false
defederated = false,
supportedProtocols = setOf(SupportedProtocol.PROTEUS)
)

private val groupConversationDetail = ConversationDetails.Group(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class GetConversationVerificationStatusUseCaseTest {
val MLS_CONVERSATION1 = TestConversation.GROUP(
Conversation.ProtocolInfo.MLS(
GROUP_ID1,
Conversation.ProtocolInfo.MLS.GroupState.PENDING_JOIN,
Conversation.ProtocolInfo.MLSCapable.GroupState.PENDING_JOIN,
epoch = 1UL,
keyingMaterialLastUpdate = DateTimeUtil.currentInstant(),
cipherSuite = Conversation.CipherSuite.MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class SyncConversationsUseCaseTest {
protocol: Conversation.Protocol? = null
) = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::getGroupConversationIdsByProtocol)
.whenInvokedWith(protocol?.let { eq(it) } ?: any())
.suspendFunction(conversationRepository::getConversationIds)
.whenInvokedWith(eq(Conversation.Type.GROUP), protocol?.let { eq(it) } ?: any(), eq(null))
.thenReturn(Either.Right(conversationIds))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,14 @@ class MLSMigratorTest {

fun withGetProteusTeamConversationsReturning(conversationsIds: List<ConversationId>) = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::getProteusTeamConversations)
.whenInvokedWith(anything())
.suspendFunction(conversationRepository::getConversationIds)
.whenInvokedWith(eq(Conversation.Type.GROUP), eq(Conversation.Protocol.PROTEUS), anything())
.thenReturn(Either.Right(conversationsIds))
}

fun withGetProteusTeamConversationsReadyForFinalisationReturning(conversationsIds: List<ConversationId>) = apply {
given(conversationRepository)
.suspendFunction(conversationRepository::getProteusTeamConversationsReadyForFinalisation)
.suspendFunction(conversationRepository::getTeamConversationIdsReadyToCompleteMigration)
.whenInvokedWith(anything())
.thenReturn(Either.Right(conversationsIds))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.asset.UploadedAssetId
import com.wire.kalium.logic.data.user.ConnectionState
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.SupportedProtocol
import com.wire.kalium.logic.data.user.UserAssetId
import com.wire.kalium.logic.data.user.UserAvailabilityStatus
import com.wire.kalium.logic.data.user.UserId
Expand Down Expand Up @@ -130,7 +131,8 @@ class UploadUserAvatarUseCaseTest {
UserAssetId("value1", "domain"),
UserAssetId("value2", "domain"),
UserAvailabilityStatus.NONE,
null
null,
setOf(SupportedProtocol.PROTEUS)
)

fun withStoredData(data: ByteArray, dataNamePath: Path): Arrangement {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ internal open class SelfApiV0 internal constructor(
override suspend fun updateSupportedProtocols(protocols: List<SupportedProtocolDTO>): NetworkResponse<Unit> =
getApiNotSupportError(::updateSupportedProtocols.name, 4)

private companion object {
companion object {
const val PATH_SELF = "self"
const val PATH_HANDLE = "handle"
const val PATH_ACCESS = "access"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,13 @@ ORDER BY lastModifiedDate DESC, name COLLATE NOCASE ASC;
selectAllConversations:
SELECT * FROM ConversationDetails WHERE type IS NOT 'CONNECTION_PENDING' ORDER BY last_modified_date DESC, name ASC;

selectAllTeamProteusConversations:
SELECT qualified_id FROM Conversation WHERE type IS 'GROUP' AND protocol IS 'PROTEUS' AND team_id = ?;

selectAllTeamProteusConversationsReadyForMigration:
SELECT
qualified_id,
(SELECT count(*) FROM Member WHERE conversation = qualified_id) AS memberCount,
(SELECT count(*) FROM Member LEFT JOIN User ON User.qualified_id = Member.user WHERE Member.conversation = Conversation.qualified_id AND (User.supported_protocols = 'MLS' OR User.supported_protocols = 'MLS,PROTEUS' OR User.supported_protocols = 'PROTEUS,MLS')) AS mlsCapableMemberCount
FROM Conversation
WHERE type IS 'GROUP' AND protocol IS 'PROTEUS' AND team_id = ? AND memberCount = mlsCapableMemberCount;
WHERE type IS 'GROUP' AND protocol IS 'MIXED' AND team_id = ? AND memberCount = mlsCapableMemberCount;

selectByQualifiedId:
SELECT * FROM ConversationDetails WHERE qualifiedId = ?;
Expand All @@ -273,8 +270,8 @@ SELECT * FROM ConversationDetails WHERE mls_group_state = ? AND (protocol IS "ML
getConversationIdByGroupId:
SELECT qualified_id FROM Conversation WHERE mls_group_id = ?;

selectGroupConversationIdsByProtocol:
SELECT qualified_id FROM Conversation WHERE protocol = ? AND type IS 'GROUP';
selectConversationIds:
SELECT qualified_id FROM Conversation WHERE protocol = :protocol AND type = :type AND (:teamId IS NULL OR team_id = :teamId);

updateConversationMutingStatus:
UPDATE Conversation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import com.wire.kalium.persistence.dao.ConversationEntity;
import com.wire.kalium.persistence.dao.conversation.ConversationEntity;
import com.wire.kalium.persistence.dao.QualifiedIDEntity;
import com.wire.kalium.persistence.dao.message.MessageEntity.ContentType;
import com.wire.kalium.persistence.dao.message.MessageEntity.FederationType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ interface ConversationDAO {
suspend fun updateAllConversationsNotificationDate()
suspend fun getAllConversations(): Flow<List<ConversationViewEntity>>
suspend fun getAllConversationDetails(): Flow<List<ConversationViewEntity>>
suspend fun getAllProteusTeamConversations(teamId: String): List<QualifiedIDEntity>
suspend fun getAllProteusTeamConversationsReadyToBeFinalised(teamId: String): List<QualifiedIDEntity>
suspend fun getConversationIds(
type: ConversationEntity.Type,
protocol: ConversationEntity.Protocol,
teamId: String? = null
): List<QualifiedIDEntity>
suspend fun getTeamConversationIdsReadyToCompleteMigration(teamId: String): List<QualifiedIDEntity>
suspend fun observeGetConversationByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationViewEntity?>
suspend fun observeGetConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): Flow<ConversationEntity?>
suspend fun getConversationBaseInfoByQualifiedID(qualifiedID: QualifiedIDEntity): ConversationEntity?
Expand All @@ -51,7 +55,6 @@ interface ConversationDAO {
suspend fun getConversationProtocolInfo(qualifiedID: QualifiedIDEntity): ConversationEntity.ProtocolInfo?
suspend fun getConversationByGroupID(groupID: String): Flow<ConversationViewEntity?>
suspend fun getConversationIdByGroupID(groupID: String): QualifiedIDEntity?
suspend fun getGroupConversationIdsByProtocol(protocol: ConversationEntity.Protocol): List<QualifiedIDEntity>
suspend fun getConversationsByGroupState(groupState: ConversationEntity.GroupState): List<ConversationViewEntity>
suspend fun deleteConversationByQualifiedID(qualifiedID: QualifiedIDEntity)

Expand All @@ -75,7 +78,7 @@ interface ConversationDAO {
suspend fun whoDeletedMeInConversation(conversationId: QualifiedIDEntity, selfUserIdString: String): UserIDEntity?
suspend fun updateConversationName(conversationId: QualifiedIDEntity, conversationName: String, timestamp: String)
suspend fun updateConversationType(conversationID: QualifiedIDEntity, type: ConversationEntity.Type)
suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: com.wire.kalium.persistence.dao.ConversationEntity.Protocol): Boolean
suspend fun updateConversationProtocol(conversationId: QualifiedIDEntity, protocol: ConversationEntity.Protocol): Boolean
suspend fun revokeOneOnOneConversationsWithDeletedUser(userId: UserIDEntity)
suspend fun getConversationIdsByUserId(userId: UserIDEntity): List<QualifiedIDEntity>
suspend fun updateConversationReceiptMode(conversationID: QualifiedIDEntity, receiptMode: ConversationEntity.ReceiptMode)
Expand Down
Loading

0 comments on commit 1f01908

Please sign in to comment.