Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: asset restriction #2831

Merged
merged 6 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ class UserSessionScope internal constructor(
protoContentMapper,
observeSelfDeletingMessages,
messageMetadataRepository,
observeFileSharingStatus,
this
)
val users: UserScope
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.wire.kalium.cryptography.utils.AES256Key
import com.wire.kalium.cryptography.utils.SHA256Key
import com.wire.kalium.cryptography.utils.generateRandomAES256Key
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.configuration.FileSharingStatus
import com.wire.kalium.logic.data.asset.AssetRepository
import com.wire.kalium.logic.data.asset.UploadedAssetId
import com.wire.kalium.logic.data.asset.isAudioMimeType
Expand All @@ -42,6 +43,7 @@ import com.wire.kalium.logic.feature.CurrentClientIdProvider
import com.wire.kalium.logic.feature.message.MessageSendFailureHandler
import com.wire.kalium.logic.feature.message.MessageSender
import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTimerSettingsForConversationUseCase
import com.wire.kalium.logic.feature.user.ObserveFileSharingStatusUseCase
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
Expand Down Expand Up @@ -106,6 +108,8 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
private val userPropertyRepository: UserPropertyRepository,
private val selfDeleteTimer: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val scope: CoroutineScope,
private val observeFileSharingStatus: ObserveFileSharingStatusUseCase,
private val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase,
private val dispatcher: KaliumDispatcher,
) : ScheduleNewAssetMessageUseCase {

Expand All @@ -122,6 +126,16 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
assetHeight: Int?,
audioLengthInMs: Long
): ScheduleNewAssetMessageResult {
observeFileSharingStatus().first().also {
when(it.state) {
FileSharingStatus.Value.Disabled -> return ScheduleNewAssetMessageResult.Failure.DisabledByTeam
FileSharingStatus.Value.EnabledAll -> { /* no-op*/ }
is FileSharingStatus.Value.EnabledSome -> if(!validateAssetMimeTypeUseCase(assetMimeType, it.state.allowedType)) {
return ScheduleNewAssetMessageResult.Failure.RestrictedFileType
}
}
}

slowSyncRepository.slowSyncStatus.first {
it is SlowSyncStatus.Complete
}
Expand Down Expand Up @@ -174,7 +188,7 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
}
}
}.fold({
ScheduleNewAssetMessageResult.Failure(it)
ScheduleNewAssetMessageResult.Failure.Generic(it)
}, { (_, message) ->
ScheduleNewAssetMessageResult.Success(message.id)
})
Expand Down Expand Up @@ -345,9 +359,13 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
}
}

sealed class ScheduleNewAssetMessageResult {
class Success(val messageId: String) : ScheduleNewAssetMessageResult()
class Failure(val coreFailure: CoreFailure) : ScheduleNewAssetMessageResult()
sealed interface ScheduleNewAssetMessageResult {
data class Success(val messageId: String) : ScheduleNewAssetMessageResult
sealed interface Failure : ScheduleNewAssetMessageResult {
data class Generic(val coreFailure: CoreFailure) : Failure
data object DisabledByTeam: Failure
data object RestrictedFileType: Failure
}
}

private data class AssetMessageMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package com.wire.kalium.logic.feature.message

import com.wire.kalium.logic.cache.SelfConversationIdProvider
import com.wire.kalium.logic.configuration.UserConfigDataSource
import com.wire.kalium.logic.configuration.UserConfigRepository
import com.wire.kalium.logic.data.asset.AssetRepository
import com.wire.kalium.logic.data.client.ClientRepository
import com.wire.kalium.logic.data.client.MLSClientProvider
Expand Down Expand Up @@ -48,6 +50,8 @@ import com.wire.kalium.logic.feature.asset.UpdateAssetMessageDownloadStatusUseCa
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageDownloadStatusUseCaseImpl
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageUploadStatusUseCase
import com.wire.kalium.logic.feature.asset.UpdateAssetMessageUploadStatusUseCaseImpl
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCase
import com.wire.kalium.logic.feature.asset.ValidateAssetMimeTypeUseCaseImpl
import com.wire.kalium.logic.feature.message.composite.SendButtonActionMessageUseCase
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsReceiverUseCaseImpl
import com.wire.kalium.logic.feature.message.ephemeral.DeleteEphemeralMessageForSelfUserAsSenderUseCaseImpl
Expand All @@ -58,6 +62,8 @@ import com.wire.kalium.logic.feature.message.ephemeral.EphemeralMessageDeletionH
import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTimerSettingsForConversationUseCase
import com.wire.kalium.logic.feature.sessionreset.ResetSessionUseCase
import com.wire.kalium.logic.feature.sessionreset.ResetSessionUseCaseImpl
import com.wire.kalium.logic.feature.user.ObserveFileSharingStatusUseCase
import com.wire.kalium.logic.feature.user.ObserveFileSharingStatusUseCaseImpl
import com.wire.kalium.logic.sync.SyncManager
import com.wire.kalium.logic.util.MessageContentEncoder
import com.wire.kalium.util.KaliumDispatcher
Expand Down Expand Up @@ -89,6 +95,7 @@ class MessageScope internal constructor(
private val protoContentMapper: ProtoContentMapper,
private val observeSelfDeletingMessages: ObserveSelfDeletionTimerSettingsForConversationUseCase,
private val messageMetadataRepository: MessageMetadataRepository,
private val observeFileSharingStatusUseCase: ObserveFileSharingStatusUseCase,
private val scope: CoroutineScope,
internal val dispatcher: KaliumDispatcher = KaliumDispatcherImpl
) {
Expand All @@ -113,6 +120,9 @@ class MessageScope internal constructor(
protoContentMapper = protoContentMapper
)

private val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase
get() = ValidateAssetMimeTypeUseCaseImpl()

private val messageContentEncoder = MessageContentEncoder()
private val messageSendingInterceptor: MessageSendingInterceptor
get() = MessageSendingInterceptorImpl(messageContentEncoder, messageRepository)
Expand Down Expand Up @@ -204,6 +214,8 @@ class MessageScope internal constructor(
userPropertyRepository,
observeSelfDeletingMessages,
scope,
observeFileSharingStatusUseCase,
validateAssetMimeTypeUseCase,
dispatcher
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.wire.kalium.cryptography.utils.SHA256Key
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.NetworkFailure
import com.wire.kalium.logic.StorageFailure
import com.wire.kalium.logic.configuration.FileSharingStatus
import com.wire.kalium.logic.data.asset.AssetRepository
import com.wire.kalium.logic.data.asset.FakeKaliumFileSystem
import com.wire.kalium.logic.data.asset.UploadedAssetId
Expand All @@ -40,6 +41,7 @@ import com.wire.kalium.logic.feature.message.MessageSendFailureHandler
import com.wire.kalium.logic.feature.message.MessageSender
import com.wire.kalium.logic.feature.selfDeletingMessages.ObserveSelfDeletionTimerSettingsForConversationUseCase
import com.wire.kalium.logic.feature.selfDeletingMessages.SelfDeletionTimer
import com.wire.kalium.logic.feature.user.ObserveFileSharingStatusUseCase
import com.wire.kalium.logic.framework.TestAsset.dummyUploadedAssetId
import com.wire.kalium.logic.framework.TestAsset.mockedLongAssetData
import com.wire.kalium.logic.functional.Either
Expand Down Expand Up @@ -92,6 +94,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -125,6 +128,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withDeleteAssetLocally()
.withObserveMessageVisibility()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -158,6 +162,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -198,6 +203,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -248,6 +254,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -298,6 +305,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -338,6 +346,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -386,6 +395,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -441,6 +451,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -486,6 +497,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand Down Expand Up @@ -529,6 +541,7 @@ class ScheduleNewAssetMessageUseCaseTest {
.withSelfDeleteTimer(SelfDeletionTimer.Enabled(expectedDuration))
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledAll)
.arrange()

// When
Expand All @@ -555,6 +568,111 @@ class ScheduleNewAssetMessageUseCaseTest {
})
}

@Test
fun givenFileSendingRestrictedByTeam_whenSending_thenReturnDisabledByTeam() = runTest {
// Given
val assetToSend = mockedLongAssetData()
val assetName = "some-asset.txt"
val inputDataPath = fakeKaliumFileSystem.providePersistentAssetPath(assetName)
val conversationId = ConversationId("some-convo-id", "some-domain-id")
val (_, sendAssetUseCase) = Arrangement(this)
.withStoredData(assetToSend, inputDataPath)
.withObserveFileSharingStatusResult(FileSharingStatus.Value.Disabled)
.arrange()

// When
val result = sendAssetUseCase.invoke(
conversationId = conversationId,
assetDataPath = inputDataPath,
assetDataSize = assetToSend.size.toLong(),
assetName = assetName,
assetMimeType = "text/plain",
assetWidth = null,
assetHeight = null,
audioLengthInMs = 0
)
advanceUntilIdle()

// Then
assertTrue(result is ScheduleNewAssetMessageResult.Failure.DisabledByTeam)
}

@Test
fun givenAseetMimeTypeRestricted_whenSending_thenReturnRestrictedFileType() = runTest {
// Given
val assetToSend = mockedLongAssetData()
val assetName = "some-asset.txt"
val inputDataPath = fakeKaliumFileSystem.providePersistentAssetPath(assetName)
val conversationId = ConversationId("some-convo-id", "some-domain-id")
val (arrangement, sendAssetUseCase) = Arrangement(this)
.withStoredData(assetToSend, inputDataPath)
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledSome(listOf("png")))
.withValidateAsseMimeTypeResult(false)
.arrange()

// When
val result = sendAssetUseCase.invoke(
conversationId = conversationId,
assetDataPath = inputDataPath,
assetDataSize = assetToSend.size.toLong(),
assetName = assetName,
assetMimeType = "text/plain",
assetWidth = null,
assetHeight = null,
audioLengthInMs = 0
)
advanceUntilIdle()

// Then
assertTrue(result is ScheduleNewAssetMessageResult.Failure.RestrictedFileType)

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("text/plain"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

@Test
fun givenAssetMimeTypeRestrictedAndFileAllowed_whenSending_thenReturnSendTheFile() = runTest(testDispatcher.default) {
// Given
val assetToSend = mockedLongAssetData()
val assetName = "some-asset.txt"
val inputDataPath = fakeKaliumFileSystem.providePersistentAssetPath(assetName)
val expectedAssetId = dummyUploadedAssetId
val expectedAssetSha256 = SHA256Key("some-asset-sha-256".toByteArray())
val conversationId = ConversationId("some-convo-id", "some-domain-id")
val (arrangement, sendAssetUseCase) = Arrangement(this)
.withStoredData(assetToSend, inputDataPath)
.withSuccessfulResponse(expectedAssetId, expectedAssetSha256)
.withObserveFileSharingStatusResult(FileSharingStatus.Value.EnabledSome(listOf("png")))
.withValidateAsseMimeTypeResult(true)
.withSelfDeleteTimer(SelfDeletionTimer.Disabled)
.withObserveMessageVisibility()
.withDeleteAssetLocally()
.arrange()

// When
val result = sendAssetUseCase.invoke(
conversationId = conversationId,
assetDataPath = inputDataPath,
assetDataSize = assetToSend.size.toLong(),
assetName = assetName,
assetMimeType = "image/png",
assetWidth = null,
assetHeight = null,
audioLengthInMs = 0
)
advanceUntilIdle()

// Then
assertTrue(result is ScheduleNewAssetMessageResult.Success)

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("image/png"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

private class Arrangement(val coroutineScope: CoroutineScope) {

@Mock
Expand Down Expand Up @@ -587,6 +705,12 @@ class ScheduleNewAssetMessageUseCaseTest {
@Mock
private val messageRepository: MessageRepository = mock(MessageRepository::class)

@Mock
val validateAssetMimeTypeUseCase: ValidateAssetMimeTypeUseCase = mock(ValidateAssetMimeTypeUseCase::class)

@Mock
val observerFileSharingStatusUseCase: ObserveFileSharingStatusUseCase = mock(ObserveFileSharingStatusUseCase::class)

val someClientId = ClientId("some-client-id")

val completeStateFlow = MutableStateFlow<SlowSyncStatus>(SlowSyncStatus.Complete).asStateFlow()
Expand All @@ -596,6 +720,20 @@ class ScheduleNewAssetMessageUseCaseTest {
withToggleReadReceiptsStatus()
}

fun withValidateAsseMimeTypeResult(result: Boolean) = apply {
given(validateAssetMimeTypeUseCase)
.function(validateAssetMimeTypeUseCase::invoke)
.whenInvokedWith(any(), any())
.thenReturn(result)
}

fun withObserveFileSharingStatusResult(result: FileSharingStatus.Value) = apply {
given(observerFileSharingStatusUseCase)
.function(observerFileSharingStatusUseCase::invoke)
.whenInvoked()
.thenReturn(flowOf(FileSharingStatus(result, false)))
}

fun withToggleReadReceiptsStatus(enabled: Boolean = false) = apply {
given(userPropertyRepository)
.suspendFunction(userPropertyRepository::getReadReceiptsStatus)
Expand Down Expand Up @@ -785,6 +923,8 @@ class ScheduleNewAssetMessageUseCaseTest {
userPropertyRepository,
observeSelfDeletionTimerSettingsForConversation,
coroutineScope,
observerFileSharingStatusUseCase,
validateAssetMimeTypeUseCase,
testDispatcher
)
}
Expand Down
Loading