diff --git a/spec/integ/crypto/megolm-backup.spec.ts b/spec/integ/crypto/megolm-backup.spec.ts index 11e1c2eb705..18ef5f73eb3 100644 --- a/spec/integ/crypto/megolm-backup.spec.ts +++ b/spec/integ/crypto/megolm-backup.spec.ts @@ -45,6 +45,7 @@ import { KeyBackupInfo, KeyBackupSession } from "../../../src/crypto-api/keyback import { IKeyBackup } from "../../../src/crypto/backup"; import { flushPromises } from "../../test-utils/flushPromises"; import { defer, IDeferred } from "../../../src/utils"; +import { ImportRoomKeysOpts } from "../../../src/crypto-api"; const ROOM_ID = testData.TEST_ROOM_ID; @@ -298,6 +299,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe describe("recover from backup", () => { let aliceCrypto: CryptoApi; + let importMockImpl: jest.Mock; beforeEach(async () => { fetchMock.get("path:/_matrix/client/v3/room_keys/version", testData.SIGNED_BACKUP_DATA); @@ -309,6 +311,20 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe // tell Alice to trust the dummy device that signed the backup await waitForDeviceList(); await aliceCrypto.setDeviceVerified(testData.TEST_USER_ID, testData.TEST_DEVICE_ID); + + importMockImpl = jest.fn().mockImplementation((keys: IMegolmSessionData[], opts?: ImportRoomKeysOpts) => { + // need to report progress + if (opts?.progressCallback) { + opts.progressCallback({ + stage: "load_keys", + successes: keys.length, + failures: 0, + total: keys.length, + }); + } + }); + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = importMockImpl; }); it("can restore from backup (Curve25519 version)", async function () { @@ -384,10 +400,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe } it("Should import full backup in chunks", async function () { - const importMockImpl = jest.fn(); - // @ts-ignore - mock a private method for testing purpose - aliceCrypto.importBackedUpRoomKeys = importMockImpl; - // We need several rooms with several sessions to test chunking const { response, expectedTotal } = createBackupDownloadResponse([45, 300, 345, 12, 130]); @@ -446,7 +458,7 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe throw new Error("test error"); }) // Ok for other chunks - .mockResolvedValue(undefined); + .mockImplementation(importMockImpl); const { response, expectedTotal } = createBackupDownloadResponse([100, 300]); @@ -485,9 +497,6 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe }); it("Should continue if some keys fails to decrypt", async function () { - // @ts-ignore - mock a private method for testing purpose - aliceCrypto.importBackedUpRoomKeys = jest.fn(); - const decryptionFailureCount = 2; const mockDecryptor = { @@ -527,6 +536,45 @@ describe.each(Object.entries(CRYPTO_BACKENDS))("megolm-keys backup (%s)", (backe expect(result.imported).toStrictEqual(expectedTotal - decryptionFailureCount); }); + it("Should report failures when decryption works but import fails", async function () { + // @ts-ignore - mock a private method for testing purpose + aliceCrypto.importBackedUpRoomKeys = jest + .fn() + .mockImplementationOnce((keys: IMegolmSessionData[], opts?: ImportRoomKeysOpts) => { + // report 10 failures to import + opts!.progressCallback!({ + stage: "load_keys", + successes: 20, + failures: 10, + total: 30, + }); + return Promise.resolve(); + }) + // Ok for other chunks + .mockResolvedValue(importMockImpl); + + const { response, expectedTotal } = createBackupDownloadResponse([30]); + + fetchMock.get("express:/_matrix/client/v3/room_keys/keys", response); + + const check = await aliceCrypto.checkKeyBackupAndEnable(); + + const progressCallback = jest.fn(); + const result = await aliceClient.restoreKeyBackupWithRecoveryKey( + testData.BACKUP_DECRYPTION_KEY_BASE58, + undefined, + undefined, + check!.backupInfo!, + { + progressCallback, + }, + ); + + expect(result.total).toStrictEqual(expectedTotal); + // A chunk failed to import + expect(result.imported).toStrictEqual(20); + }); + it("recover specific session from backup", async function () { fetchMock.get( "express:/_matrix/client/v3/room_keys/keys/:room_id/:session_id", diff --git a/src/client.ts b/src/client.ts index 3192e144924..398ee2254b2 100644 --- a/src/client.ts +++ b/src/client.ts @@ -219,7 +219,13 @@ import { LocalNotificationSettings } from "./@types/local_notifications"; import { buildFeatureSupportMap, Feature, ServerSupport } from "./feature"; import { BackupDecryptor, CryptoBackend } from "./common-crypto/CryptoBackend"; import { RUST_SDK_STORE_PREFIX } from "./rust-crypto/constants"; -import { BootstrapCrossSigningOpts, CrossSigningKeyInfo, CryptoApi, ImportRoomKeysOpts } from "./crypto-api"; +import { + BootstrapCrossSigningOpts, + CrossSigningKeyInfo, + CryptoApi, + ImportRoomKeyProgressData, + ImportRoomKeysOpts, +} from "./crypto-api"; import { DeviceInfoMap } from "./crypto/DeviceList"; import { AddSecretStorageKeyOpts, @@ -3923,10 +3929,18 @@ export class MatrixClient extends TypedEventEmitter { // We have a chunk of decrypted keys: import them try { + let success = 0; + let failures = 0; + const partialProgress = (stage: ImportRoomKeyProgressData): void => { + success = stage.successes ?? 0; + failures = stage.failures ?? 0; + }; await this.cryptoBackend!.importBackedUpRoomKeys(chunk, { untrusted, + progressCallback: partialProgress, }); - totalImported += chunk.length; + totalImported += success; + totalFailures += failures; } catch (e) { totalFailures += chunk.length; // We failed to import some keys, but we should still try to import the rest? @@ -3953,11 +3967,25 @@ export class MatrixClient extends TypedEventEmitter { + success = stage.successes ?? 0; + failures = stage.failures ?? 0; + }; + await this.cryptoBackend!.importBackedUpRoomKeys(chunk, { + untrusted, + progressCallback: partialProgress, + }); + totalImported += success; + totalFailures += failures; + } catch (e) { + totalFailures += keys.length; + // We failed to import some keys, but we should still try to import the rest? + // Log the error and continue + logger.error("Error importing keys from backup", e); + } } else { totalKeyCount = 1; try { @@ -3973,6 +4001,7 @@ export class MatrixClient extends TypedEventEmitter { const importOpt: ImportRoomKeyProgressData = { @@ -235,6 +235,17 @@ export class RustBackupManager extends TypedEventEmitter | null = null;