Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RustCrypto.getCrossSigningStatus: check the client is not stopped #3682

Merged
merged 2 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions spec/unit/rust-crypto/rust-crypto.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,44 @@ describe("RustCrypto", () => {
await expect(rustCrypto.getCrossSigningKeyId()).resolves.toBe(null);
});

describe("getCrossSigningStatus", () => {
it("returns sensible values on a default client", async () => {
const secretStorage = {
isStored: jest.fn().mockResolvedValue(null),
} as unknown as Mocked<ServerSideSecretStorage>;
const rustCrypto = await makeTestRustCrypto(undefined, undefined, undefined, secretStorage);

const result = await rustCrypto.getCrossSigningStatus();

expect(secretStorage.isStored).toHaveBeenCalledWith("m.cross_signing.master");
expect(result).toEqual({
privateKeysCachedLocally: {
masterKey: false,
selfSigningKey: false,
userSigningKey: false,
},
privateKeysInSecretStorage: false,
publicKeysOnDevice: false,
});
});

it("throws if `stop` is called mid-call", async () => {
const secretStorage = {
isStored: jest.fn().mockResolvedValue(null),
} as unknown as Mocked<ServerSideSecretStorage>;
const rustCrypto = await makeTestRustCrypto(undefined, undefined, undefined, secretStorage);

// start the call off
const result = rustCrypto.getCrossSigningStatus();

// call `.stop`
rustCrypto.stop();

// getCrossSigningStatus should abort
await expect(result).rejects.toEqual(new Error("MatrixClient has been stopped"));
});
});

it("bootstrapCrossSigning delegates to CrossSigningIdentity", async () => {
const rustCrypto = await makeTestRustCrypto();
const mockCrossSigningIdentity = {
Expand Down
7 changes: 6 additions & 1 deletion src/crypto-api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ export interface CryptoApi {
* return true.
*
* @returns True if cross-signing is ready to be used on this device
*
* @throws May throw {@link ClientStoppedError} if the `MatrixClient` is stopped before or during the call.
*/
isCrossSigningReady(): Promise<boolean>;

Expand Down Expand Up @@ -234,7 +236,10 @@ export interface CryptoApi {
/**
* Get the status of our cross-signing keys.
*
* @returns The current status of cross-signing keys: whether we have public and private keys cached locally, and whether the private keys are in secret storage.
* @returns The current status of cross-signing keys: whether we have public and private keys cached locally, and
* whether the private keys are in secret storage.
*
* @throws May throw {@link ClientStoppedError} if the `MatrixClient` is stopped before or during the call.
*/
getCrossSigningStatus(): Promise<CrossSigningStatus>;

Expand Down
14 changes: 14 additions & 0 deletions src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,17 @@ export class KeySignatureUploadError extends Error {
super(message);
}
}

/**
* It is invalid to call most methods once {@link MatrixClient#stopClient} has been called.
*
* This error will be thrown if you attempt to do so.
*
* {@link MatrixClient#stopClient} itself is an exception to this: it may safely be called multiple times on the same
* instance.
*/
export class ClientStoppedError extends Error {
public constructor() {
super("MatrixClient has been stopped");
}
}
20 changes: 18 additions & 2 deletions src/rust-crypto/rust-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import { TypedEventEmitter } from "../models/typed-event-emitter";
import { RustBackupCryptoEventMap, RustBackupCryptoEvents, RustBackupManager } from "./backup";
import { TypedReEmitter } from "../ReEmitter";
import { randomString } from "../randomstring";
import { ClientStoppedError } from "../errors";

const ALL_VERIFICATION_METHODS = ["m.sas.v1", "m.qr_code.scan.v1", "m.qr_code.show.v1", "m.reciprocate.v1"];

Expand Down Expand Up @@ -138,6 +139,20 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
);
}

/**
* Return the OlmMachine only if {@link RustCrypto#stop} has not been called.
*
* This allows us to better handle race conditions where the client is stopped before or during a crypto API call.
*
* @throws ClientStoppedError if {@link RustCrypto#stop} has been called.
*/
private getOlmMachineOrThrow(): RustSdkCryptoJs.OlmMachine {
if (this.stopped) {
throw new ClientStoppedError();
}
return this.olmMachine;
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
// CryptoBackend implementation
Expand Down Expand Up @@ -635,16 +650,17 @@ export class RustCrypto extends TypedEventEmitter<RustCryptoEvents, RustCryptoEv
* Implementation of {@link CryptoApi#getCrossSigningStatus}
*/
public async getCrossSigningStatus(): Promise<CrossSigningStatus> {
const userIdentity: RustSdkCryptoJs.OwnUserIdentity | null = await this.olmMachine.getIdentity(
const userIdentity: RustSdkCryptoJs.OwnUserIdentity | null = await this.getOlmMachineOrThrow().getIdentity(
new RustSdkCryptoJs.UserId(this.userId),
);

const publicKeysOnDevice =
Boolean(userIdentity?.masterKey) &&
Boolean(userIdentity?.selfSigningKey) &&
Boolean(userIdentity?.userSigningKey);
const privateKeysInSecretStorage = await secretStorageContainsCrossSigningKeys(this.secretStorage);
const crossSigningStatus: RustSdkCryptoJs.CrossSigningStatus | null =
await this.olmMachine.crossSigningStatus();
await this.getOlmMachineOrThrow().crossSigningStatus();

return {
publicKeysOnDevice,
Expand Down
Loading