From 9913fbd48caa4f9d06aa26ce71d7c214bec45aa6 Mon Sep 17 00:00:00 2001 From: Toomas Vahter Date: Thu, 1 Aug 2024 15:41:33 +0300 Subject: [PATCH] Fix rare crashes on logout by updating contexts after running batch delete --- CHANGELOG.md | 4 +- .../Database/DatabaseContainer.swift | 46 +++- .../Database/DatabaseContainer_Tests.swift | 197 +++++++++++------- 3 files changed, 159 insertions(+), 88 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec55282a178..4fbe59b665f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). # Upcoming -### 🔄 Changed +## StreamChat +### 🐞 Fixed +- Fix rare crashes when deleting local database content on logout [#3355](https://github.com/GetStream/stream-chat-swift/pull/3355) # [4.61.0](https://github.com/GetStream/stream-chat-swift/releases/tag/4.61.0) _July 30, 2024_ diff --git a/Sources/StreamChat/Database/DatabaseContainer.swift b/Sources/StreamChat/Database/DatabaseContainer.swift index d9c28a9b8cb..f413fe80313 100644 --- a/Sources/StreamChat/Database/DatabaseContainer.swift +++ b/Sources/StreamChat/Database/DatabaseContainer.swift @@ -77,6 +77,7 @@ class DatabaseContainer: NSPersistentContainer { return context }() + private var canWriteData = true private var stateLayerContextRefreshObservers = [NSObjectProtocol]() private var loggerNotificationObserver: NSObjectProtocol? private let localCachingSettings: ChatClientConfig.LocalCaching? @@ -217,6 +218,12 @@ class DatabaseContainer: NSPersistentContainer { func write(_ actions: @escaping (DatabaseSession) throws -> Void, completion: @escaping (Error?) -> Void) { writableContext.perform { log.debug("Starting a database session.", subsystems: .database) + guard self.canWriteData else { + log.debug("Discarding write attempt.", subsystems: .database) + completion(nil) + return + } + do { FetchCache.clear() try actions(self.writableContext) @@ -299,22 +306,39 @@ class DatabaseContainer: NSPersistentContainer { context.reset() } } - - writableContext.performAndWait { [weak self] in - let entityNames = self?.managedObjectModel.entities.compactMap(\.name) - var deleteError: Error? - entityNames?.forEach { [weak self] entityName in - let deleteFetch = NSFetchRequest(entityName: entityName) - let deleteRequest = NSBatchDeleteRequest(fetchRequest: deleteFetch) + + let entityNames = managedObjectModel.entities.compactMap(\.name) + writableContext.perform { [weak self] in + self?.canWriteData = false + let requests = entityNames + .map { NSFetchRequest(entityName: $0) } + .map { fetchRequest in + let batchDelete = NSBatchDeleteRequest(fetchRequest: fetchRequest) + batchDelete.resultType = .resultTypeObjectIDs + return batchDelete + } + var lastEncounteredError: Error? + var deletedObjectIds = [NSManagedObjectID]() + for request in requests { do { - try self?.writableContext.execute(deleteRequest) - try self?.writableContext.save() + let result = try self?.writableContext.execute(request) as? NSBatchDeleteResult + if let objectIds = result?.result as? [NSManagedObjectID] { + deletedObjectIds.append(contentsOf: objectIds) + } } catch { log.error("Batch delete request failed with error \(error)") - deleteError = error + lastEncounteredError = error } } - completion?(deleteError) + if !deletedObjectIds.isEmpty, let contexts = self?.allContext { + log.debug("Merging \(deletedObjectIds.count) deletions to contexts", subsystems: .database) + NSManagedObjectContext.mergeChanges( + fromRemoteContextSave: [NSDeletedObjectsKey: deletedObjectIds], + into: contexts + ) + } + self?.canWriteData = true + completion?(lastEncounteredError) } } diff --git a/Tests/StreamChatTests/Database/DatabaseContainer_Tests.swift b/Tests/StreamChatTests/Database/DatabaseContainer_Tests.swift index 7fb07efe9df..97a01bd3cb6 100644 --- a/Tests/StreamChatTests/Database/DatabaseContainer_Tests.swift +++ b/Tests/StreamChatTests/Database/DatabaseContainer_Tests.swift @@ -62,86 +62,12 @@ final class DatabaseContainer_Tests: XCTestCase { wait(for: [errorPathExpectation], timeout: defaultTimeout) } - + func test_removingAllData() throws { let container = DatabaseContainer(kind: .inMemory) // // Create data for all our entities in the DB - try container.writeSynchronously { session in - let cid = ChannelId.unique - let currentUserId = UserId.unique - try session.saveChannel(payload: self.dummyPayload(with: cid), query: .init(filter: .nonEmpty), cache: nil) - try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) - try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) - try session.saveMember(payload: .dummy(), channelId: cid, query: .init(cid: cid), cache: nil) - try session.saveCurrentUser(payload: .dummy(userId: currentUserId, role: .admin)) - try session.saveCurrentDevice("123") - try session.saveChannelMute(payload: .init( - mutedChannel: .dummy(cid: cid), - user: .dummy(userId: currentUserId), - createdAt: .unique, - updatedAt: .unique - )) - session.saveThreadList( - payload: ThreadListPayload( - threads: [ - self.dummyThreadPayload( - threadParticipants: [self.dummyThreadParticipantPayload(), self.dummyThreadParticipantPayload()], - read: [self.dummyThreadReadPayload(), self.dummyThreadReadPayload()] - ), - self.dummyThreadPayload() - ], - next: nil - ) - ) - try session.saveUser(payload: .dummy(userId: .unique), query: .user(withID: currentUserId), cache: nil) - try session.saveUser(payload: .dummy(userId: .unique)) - let messages: [MessagePayload] = [ - .dummy( - reactionGroups: [ - "like": MessageReactionGroupPayload( - sumScores: 1, - count: 1, - firstReactionAt: .unique, - lastReactionAt: .unique - ) - ], - moderationDetails: .init(originalText: "yo", action: "spam") - ), - .dummy( - poll: self.dummyPollPayload( - createdById: currentUserId, - id: "pollId", - options: [self.dummyPollOptionPayload(id: "test")], - latestVotesByOption: ["test": [self.dummyPollVotePayload(pollId: "pollId")]], - user: .dummy(userId: currentUserId) - ) - ), - .dummy(), - .dummy(), - .dummy() - ] - try messages.forEach { - let message = try session.saveMessage(payload: $0, for: cid, syncOwnReactions: true, cache: nil) - try session.saveReaction( - payload: .dummy(messageId: message.id, user: .dummy(userId: currentUserId)), - query: .init(messageId: message.id, filter: .equal(.authorId, to: currentUserId)), - cache: nil - ) - } - try session.saveMessage( - payload: .dummy(channel: .dummy(cid: cid)), - for: MessageSearchQuery(channelFilter: .noTeam, messageFilter: .withoutAttachments), - cache: nil - ) - try session.savePollVote( - payload: self.dummyPollVotePayload(pollId: "pollId"), - query: .init(pollId: "pollId", optionId: "test", filter: .contains(.pollId, value: "pollId")), - cache: nil - ) - - QueuedRequestDTO.createRequest(date: .unique, endpoint: Data(), context: container.writableContext) - } + try writeALotOfData(to: container) // Fetch the data from all out entities let totalEntities = container.managedObjectModel.entities.count @@ -193,6 +119,45 @@ final class DatabaseContainer_Tests: XCTestCase { } } } + + func test_removingAllData_whileAnotherWrite() throws { + let container = DatabaseContainer(kind: .inMemory) + try writeALotOfData(to: container) + + // Schedule saving just before removing it all + container.write { session in + try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) + } + + let expectation = XCTestExpectation(description: "Remove") + container.removeAllData { error in + XCTAssertNil(error) + expectation.fulfill() + } + + // Save just after triggering remove all + container.write { session in + try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) + } + + wait(for: [expectation], timeout: defaultTimeout) + + let counts = try container.readSynchronously { session in + guard let context = session as? NSManagedObjectContext else { return [String: Int]() } + var counts = [String: Int]() + let requests = container.managedObjectModel.entities + .compactMap(\.name) + .map { NSFetchRequest(entityName: $0) } + for request in requests { + let count = try context.count(for: request) + counts[request.entityName!] = count + } + return counts + } + for count in counts { + XCTAssertEqual(0, count.value, count.key) + } + } func test_databaseContainer_callsResetEphemeralValues_onAllEphemeralValuesContainerEntities() throws { // Create a new on-disc database with the test data model @@ -365,4 +330,84 @@ final class DatabaseContainer_Tests: XCTestCase { XCTAssertEqual(database.backgroundReadOnlyContext.shouldShowShadowedMessages, shouldShowShadowedMessages) } } + + // MARK: - + + private func writeALotOfData(to container: DatabaseContainer) throws { + try container.writeSynchronously { session in + let cid = ChannelId.unique + let currentUserId = UserId.unique + try session.saveChannel(payload: self.dummyPayload(with: cid), query: .init(filter: .nonEmpty), cache: nil) + try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) + try session.saveChannel(payload: self.dummyPayload(with: .unique), query: nil, cache: nil) + try session.saveMember(payload: .dummy(), channelId: cid, query: .init(cid: cid), cache: nil) + try session.saveCurrentUser(payload: .dummy(userId: currentUserId, role: .admin)) + try session.saveCurrentDevice("123") + try session.saveChannelMute(payload: .init( + mutedChannel: .dummy(cid: cid), + user: .dummy(userId: currentUserId), + createdAt: .unique, + updatedAt: .unique + )) + session.saveThreadList( + payload: ThreadListPayload( + threads: [ + self.dummyThreadPayload( + threadParticipants: [self.dummyThreadParticipantPayload(), self.dummyThreadParticipantPayload()], + read: [self.dummyThreadReadPayload(), self.dummyThreadReadPayload()] + ), + self.dummyThreadPayload() + ], + next: nil + ) + ) + try session.saveUser(payload: .dummy(userId: .unique), query: .user(withID: currentUserId), cache: nil) + try session.saveUser(payload: .dummy(userId: .unique)) + let messages: [MessagePayload] = [ + .dummy( + reactionGroups: [ + "like": MessageReactionGroupPayload( + sumScores: 1, + count: 1, + firstReactionAt: .unique, + lastReactionAt: .unique + ) + ], + moderationDetails: .init(originalText: "yo", action: "spam") + ), + .dummy( + poll: self.dummyPollPayload( + createdById: currentUserId, + id: "pollId", + options: [self.dummyPollOptionPayload(id: "test")], + latestVotesByOption: ["test": [self.dummyPollVotePayload(pollId: "pollId")]], + user: .dummy(userId: currentUserId) + ) + ), + .dummy(), + .dummy(), + .dummy() + ] + try messages.forEach { + let message = try session.saveMessage(payload: $0, for: cid, syncOwnReactions: true, cache: nil) + try session.saveReaction( + payload: .dummy(messageId: message.id, user: .dummy(userId: currentUserId)), + query: .init(messageId: message.id, filter: .equal(.authorId, to: currentUserId)), + cache: nil + ) + } + try session.saveMessage( + payload: .dummy(channel: .dummy(cid: cid)), + for: MessageSearchQuery(channelFilter: .noTeam, messageFilter: .withoutAttachments), + cache: nil + ) + try session.savePollVote( + payload: self.dummyPollVotePayload(pollId: "pollId"), + query: .init(pollId: "pollId", optionId: "test", filter: .contains(.pollId, value: "pollId")), + cache: nil + ) + + QueuedRequestDTO.createRequest(date: .unique, endpoint: Data(), context: container.writableContext) + } + } }