diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index edf731b92a..7111df25e5 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -398,7 +398,7 @@ export class AutoEncrypter { socketOptions: autoSelectSocketOptions(this._client.s.options) }); - return deserialize(await stateMachine.execute(this, context, options.timeoutContext), { + return deserialize(await stateMachine.execute(this, context, options), { promoteValues: false, promoteLongs: false }); @@ -419,11 +419,7 @@ export class AutoEncrypter { socketOptions: autoSelectSocketOptions(this._client.s.options) }); - return await stateMachine.execute( - this, - context, - options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined - ); + return await stateMachine.execute(this, context, options); } /** diff --git a/src/client-side-encryption/client_encryption.ts b/src/client-side-encryption/client_encryption.ts index 7482c513d3..487969cf4d 100644 --- a/src/client-side-encryption/client_encryption.ts +++ b/src/client-side-encryption/client_encryption.ts @@ -225,7 +225,7 @@ export class ClientEncryption { TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })); const dataKey = deserialize( - await stateMachine.execute(this, context, timeoutContext) + await stateMachine.execute(this, context, { timeoutContext }) ) as DataKey; const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( @@ -293,7 +293,9 @@ export class ClientEncryption { resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }) ); - const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v: dataKeys } = deserialize( + await stateMachine.execute(this, context, { timeoutContext }) + ); if (dataKeys.length === 0) { return {}; } @@ -696,7 +698,7 @@ export class ClientEncryption { ? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })) : undefined; - const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext })); return v; } @@ -780,7 +782,7 @@ export class ClientEncryption { this._timeoutMS != null ? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })) : undefined; - const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext)); + const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext })); return v; } } diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index d10776abe7..0541c7e7de 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor'; import { getSocks, type SocksLib } from '../deps'; import { MongoOperationTimeoutError } from '../error'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; +import { type Abortable } from '../mongo_types'; import { Timeout, type TimeoutContext, TimeoutError } from '../timeout'; -import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils'; +import { + addAbortListener, + BufferPool, + kDispose, + MongoDBCollectionNamespace, + promiseWithResolvers +} from '../utils'; import { autoSelectSocketOptions, type DataKey } from './client_encryption'; import { MongoCryptError } from './errors'; import { type MongocryptdManager } from './mongocryptd_manager'; @@ -189,7 +196,7 @@ export class StateMachine { async execute( executor: StateMachineExecutable, context: MongoCryptContext, - timeoutContext?: TimeoutContext + options: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const keyVaultNamespace = executor._keyVaultNamespace; const keyVaultClient = executor._keyVaultClient; @@ -214,7 +221,7 @@ export class StateMachine { metaDataClient, context.ns, filter, - timeoutContext + options ); if (collInfo) { context.addMongoOperationResponse(collInfo); @@ -235,9 +242,9 @@ export class StateMachine { // When we are using the shared library, we don't have a mongocryptd manager. const markedCommand: Uint8Array = mongocryptdManager ? await mongocryptdManager.withRespawn( - this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext) + this.markCommand.bind(this, mongocryptdClient, context.ns, command, options) ) - : await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext); + : await this.markCommand(mongocryptdClient, context.ns, command, options); context.addMongoOperationResponse(markedCommand); context.finishMongoOperation(); @@ -246,12 +253,7 @@ export class StateMachine { case MONGOCRYPT_CTX_NEED_MONGO_KEYS: { const filter = context.nextMongoOperation(); - const keys = await this.fetchKeys( - keyVaultClient, - keyVaultNamespace, - filter, - timeoutContext - ); + const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options); if (keys.length === 0) { // See docs on EMPTY_V @@ -273,7 +275,7 @@ export class StateMachine { } case MONGOCRYPT_CTX_NEED_KMS: { - await Promise.all(this.requests(context, timeoutContext)); + await Promise.all(this.requests(context, options)); context.finishKMSRequests(); break; } @@ -315,11 +317,13 @@ export class StateMachine { * @param kmsContext - A C++ KMS context returned from the bindings * @returns A promise that resolves when the KMS reply has be fully parsed */ - async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise { + async kmsRequest( + request: MongoCryptKMSRequest, + options: { timeoutContext?: TimeoutContext } & Abortable + ): Promise { const parsedUrl = request.endpoint.split(':'); const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT; - const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {}); - const options: tls.ConnectionOptions & { + const socketOptions: tls.ConnectionOptions & { host: string; port: number; autoSelectFamily?: boolean; @@ -328,7 +332,7 @@ export class StateMachine { host: parsedUrl[0], servername: parsedUrl[0], port, - ...socketOptions + ...autoSelectSocketOptions(this.options.socketOptions || {}) }; const message = request.message; const buffer = new BufferPool(); @@ -363,7 +367,7 @@ export class StateMachine { throw error; } try { - await this.setTlsOptions(providerTlsOptions, options); + await this.setTlsOptions(providerTlsOptions, socketOptions); } catch (err) { throw onerror(err); } @@ -380,23 +384,25 @@ export class StateMachine { .once('close', () => rejectOnNetSocketError(onclose())) .once('connect', () => resolveOnNetSocketConnect()); + let abortListener; + try { if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) { const netSocketOptions = { + ...socketOptions, host: this.options.proxyOptions.proxyHost, - port: this.options.proxyOptions.proxyPort || 1080, - ...socketOptions + port: this.options.proxyOptions.proxyPort || 1080 }; netSocket.connect(netSocketOptions); await willConnect; try { socks ??= loadSocks(); - options.socket = ( + socketOptions.socket = ( await socks.SocksClient.createConnection({ existing_socket: netSocket, command: 'connect', - destination: { host: options.host, port: options.port }, + destination: { host: socketOptions.host, port: socketOptions.port }, proxy: { // host and port are ignored because we pass existing_socket host: 'iLoveJavaScript', @@ -412,7 +418,7 @@ export class StateMachine { } } - socket = tls.connect(options, () => { + socket = tls.connect(socketOptions, () => { socket.write(message); }); @@ -422,6 +428,11 @@ export class StateMachine { resolve } = promiseWithResolvers(); + abortListener = addAbortListener(options.signal, error => { + destroySockets(); + rejectOnTlsSocketError(error); + }); + socket .once('error', err => rejectOnTlsSocketError(onerror(err))) .once('close', () => rejectOnTlsSocketError(onclose())) @@ -436,8 +447,11 @@ export class StateMachine { resolve(); } }); - await (timeoutContext?.csotEnabled() - ? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)]) + await (options.timeoutContext?.csotEnabled() + ? Promise.all([ + willResolveKmsRequest, + Timeout.expires(options.timeoutContext?.remainingTimeMS) + ]) : willResolveKmsRequest); } catch (error) { if (error instanceof TimeoutError) @@ -446,16 +460,17 @@ export class StateMachine { } finally { // There's no need for any more activity on this socket at this point. destroySockets(); + abortListener?.[kDispose](); } } - *requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) { + *requests(context: MongoCryptContext, options: { timeoutContext?: TimeoutContext } & Abortable) { for ( let request = context.nextKMSRequest(); request != null; request = context.nextKMSRequest() ) { - yield this.kmsRequest(request, timeoutContext); + yield this.kmsRequest(request, options); } } @@ -516,14 +531,16 @@ export class StateMachine { client: MongoClient, ns: string, filter: Document, - timeoutContext?: TimeoutContext + options: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); const cursor = client.db(db).listCollections(filter, { promoteLongs: false, promoteValues: false, - timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) + timeoutContext: + options.timeoutContext && new CursorTimeoutContext(options.timeoutContext, Symbol()), + signal: options.signal }); // There is always exactly zero or one matching documents, so this should always exhaust the cursor @@ -547,17 +564,30 @@ export class StateMachine { client: MongoClient, ns: string, command: Uint8Array, - timeoutContext?: TimeoutContext + options: { timeoutContext?: TimeoutContext } & Abortable ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); const bsonOptions = { promoteLongs: false, promoteValues: false }; const rawCommand = deserialize(command, bsonOptions); + const commandOptions: { + timeoutMS?: number; + signal?: AbortSignal; + } = { + timeoutMS: undefined, + signal: undefined + }; + + if (options.timeoutContext?.csotEnabled()) { + commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS; + } + if (options.signal) { + commandOptions.signal = options.signal; + } + const response = await client.db(db).command(rawCommand, { ...bsonOptions, - ...(timeoutContext?.csotEnabled() - ? { timeoutMS: timeoutContext?.remainingTimeMS } - : undefined) + ...commandOptions }); return serialize(response, this.bsonOptions); @@ -575,17 +605,30 @@ export class StateMachine { client: MongoClient, keyVaultNamespace: string, filter: Uint8Array, - timeoutContext?: TimeoutContext + options: { timeoutContext?: TimeoutContext } & Abortable ): Promise> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(keyVaultNamespace); + const commandOptions: { + timeoutContext?: CursorTimeoutContext; + signal?: AbortSignal; + } = { + timeoutContext: undefined, + signal: undefined + }; + + if (options.timeoutContext != null) { + commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol()); + } + if (options.signal != null) { + commandOptions.signal = options.signal; + } + return client .db(dbName) .collection(collectionName, { readConcern: { level: 'majority' } }) - .find(deserialize(filter), { - timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol()) - }) + .find(deserialize(filter), commandOptions) .toArray(); } } diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 6df81b34d9..6e3f04e59a 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -33,7 +33,7 @@ import { import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client'; import { type MongoClientAuthProviders } from '../mongo_client_auth_providers'; import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger'; -import { type CancellationToken, TypedEventEmitter } from '../mongo_types'; +import { type Abortable, type CancellationToken, TypedEventEmitter } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import { ServerType } from '../sdam/common'; import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions'; @@ -438,7 +438,7 @@ export class Connection extends TypedEventEmitter { private async *sendWire( message: WriteProtocolMessageType, - options: CommandOptions, + options: CommandOptions & Abortable, responseType?: MongoDBResponseConstructor ): AsyncGenerator { this.throwIfAborted(); @@ -453,7 +453,8 @@ export class Connection extends TypedEventEmitter { await this.writeCommand(message, { agreedCompressor: this.description.compressor ?? 'none', zlibCompressionLevel: this.description.zlibCompressionLevel, - timeoutContext: options.timeoutContext + timeoutContext: options.timeoutContext, + signal: options.signal }); if (options.noResponse || message.moreToCome) { @@ -676,7 +677,7 @@ export class Connection extends TypedEventEmitter { agreedCompressor?: CompressorName; zlibCompressionLevel?: number; timeoutContext?: TimeoutContext; - } + } & Abortable ): Promise { const finalCommand = options.agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command) @@ -701,7 +702,7 @@ export class Connection extends TypedEventEmitter { if (this.socket.write(buffer)) return; - const drainEvent = once(this.socket, 'drain'); + const drainEvent = once(this.socket, 'drain', { signal: options.signal }); const timeout = options?.timeoutContext?.timeoutForSocketWrite; if (timeout) { try { @@ -729,9 +730,11 @@ export class Connection extends TypedEventEmitter { * * Note that `for-await` loops call `return` automatically when the loop is exited. */ - private async *readMany(options: { - timeoutContext?: TimeoutContext; - }): AsyncGenerator { + private async *readMany( + options: { + timeoutContext?: TimeoutContext; + } & Abortable + ): AsyncGenerator { try { this.dataEvents = onData(this.messageStream, options); this.messageStream.resume(); diff --git a/src/cmap/connection_pool.ts b/src/cmap/connection_pool.ts index bb2069de84..dc2b79658d 100644 --- a/src/cmap/connection_pool.ts +++ b/src/cmap/connection_pool.ts @@ -25,10 +25,18 @@ import { MongoRuntimeError, MongoServerError } from '../error'; -import { CancellationToken, TypedEventEmitter } from '../mongo_types'; +import { type Abortable, CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { Server } from '../sdam/server'; import { type TimeoutContext, TimeoutError } from '../timeout'; -import { type Callback, List, makeCounter, now, promiseWithResolvers } from '../utils'; +import { + addAbortListener, + type Callback, + kDispose, + List, + makeCounter, + now, + promiseWithResolvers +} from '../utils'; import { connect } from './connect'; import { Connection, type ConnectionEvents, type ConnectionOptions } from './connection'; import { @@ -316,7 +324,7 @@ export class ConnectionPool extends TypedEventEmitter { * will be held by the pool. This means that if a connection is checked out it MUST be checked back in or * explicitly destroyed by the new owner. */ - async checkOut(options: { timeoutContext: TimeoutContext }): Promise { + async checkOut(options: { timeoutContext: TimeoutContext } & Abortable): Promise { const checkoutTime = now(); this.emitAndLog( ConnectionPool.CONNECTION_CHECK_OUT_STARTED, @@ -334,6 +342,11 @@ export class ConnectionPool extends TypedEventEmitter { checkoutTime }; + const abortListener = addAbortListener(options.signal, error => { + waitQueueMember.cancelled = true; + reject(error); + }); + this.waitQueue.push(waitQueueMember); process.nextTick(() => this.processWaitQueue()); @@ -364,6 +377,7 @@ export class ConnectionPool extends TypedEventEmitter { } throw error; } finally { + abortListener?.[kDispose](); timeout?.clear(); } } diff --git a/src/cmap/wire_protocol/on_data.ts b/src/cmap/wire_protocol/on_data.ts index f673261833..a532e17d52 100644 --- a/src/cmap/wire_protocol/on_data.ts +++ b/src/cmap/wire_protocol/on_data.ts @@ -1,7 +1,8 @@ import { type EventEmitter } from 'events'; +import { type Abortable } from '../../mongo_types'; import { type TimeoutContext } from '../../timeout'; -import { List, promiseWithResolvers } from '../../utils'; +import { addAbortListener, kDispose, List, promiseWithResolvers } from '../../utils'; /** * @internal @@ -21,7 +22,7 @@ type PendingPromises = Omit< */ export function onData( emitter: EventEmitter, - { timeoutContext }: { timeoutContext?: TimeoutContext } + { timeoutContext, signal }: { timeoutContext?: TimeoutContext } & Abortable ) { // Setup pending events and pending promise lists /** @@ -90,6 +91,7 @@ export function onData( // Adding event handlers emitter.on('data', eventHandler); emitter.on('error', errorHandler); + const abortListener = addAbortListener(signal, errorHandler); const timeoutForSocketRead = timeoutContext?.timeoutForSocketRead; timeoutForSocketRead?.throwIfExpired(); @@ -115,6 +117,7 @@ export function onData( // Adding event handlers emitter.off('data', eventHandler); emitter.off('error', errorHandler); + abortListener?.[kDispose](); finished = true; timeoutForSocketRead?.clear(); const doneResult = { value: undefined, done: finished } as const; diff --git a/src/collection.ts b/src/collection.ts index d7cdc12e8e..58a5b2d165 100644 --- a/src/collection.ts +++ b/src/collection.ts @@ -14,6 +14,7 @@ import type { Db } from './db'; import { MongoInvalidArgumentError, MongoOperationTimeoutError } from './error'; import type { MongoClient, PkFactory } from './mongo_client'; import type { + Abortable, Filter, Flatten, OptionalUnlessRequiredId, @@ -505,7 +506,7 @@ export class Collection { async findOne(filter: Filter): Promise | null>; async findOne( filter: Filter, - options: Omit + options: Omit & Abortable ): Promise | null>; // allow an override of the schema. @@ -532,9 +533,15 @@ export class Collection { * @param filter - The filter predicate. If unspecified, then all documents in the collection will match the predicate */ find(): FindCursor>; - find(filter: Filter, options?: FindOptions): FindCursor>; - find(filter: Filter, options?: FindOptions): FindCursor; - find(filter: Filter = {}, options: FindOptions = {}): FindCursor> { + find(filter: Filter, options?: FindOptions & Abortable): FindCursor>; + find( + filter: Filter, + options?: FindOptions & Abortable + ): FindCursor; + find( + filter: Filter = {}, + options: FindOptions & Abortable = {} + ): FindCursor> { return new FindCursor>( this.client, this.s.namespace, diff --git a/src/cursor/abstract_cursor.ts b/src/cursor/abstract_cursor.ts index 8eccdfcf63..1081958e61 100644 --- a/src/cursor/abstract_cursor.ts +++ b/src/cursor/abstract_cursor.ts @@ -12,7 +12,7 @@ import { MongoTailableCursorError } from '../error'; import type { MongoClient } from '../mongo_client'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import { executeOperation } from '../operations/execute_operation'; import { GetMoreOperation } from '../operations/get_more'; import { KillCursorsOperation } from '../operations/kill_cursors'; @@ -22,7 +22,14 @@ import { type AsyncDisposable, configureResourceManagement } from '../resource_m import type { Server } from '../sdam/server'; import { ClientSession, maybeClearPinnedConnection } from '../sessions'; import { type CSOTTimeoutContext, type Timeout, TimeoutContext } from '../timeout'; -import { type MongoDBNamespace, squashError } from '../utils'; +import { + addAbortListener, + type Disposable, + kDispose, + type MongoDBNamespace, + squashError, + throwIfAborted +} from '../utils'; /** * @internal @@ -247,12 +254,14 @@ export abstract class AbstractCursor< /** @internal */ protected deserializationOptions: OnDemandDocumentDeserializeOptions; + protected signal: AbortSignal | undefined; + private abortListener: Disposable | undefined; /** @internal */ protected constructor( client: MongoClient, namespace: MongoDBNamespace, - options: AbstractCursorOptions = {} + options: AbstractCursorOptions & Abortable = {} ) { super(); @@ -352,6 +361,9 @@ export abstract class AbstractCursor< }; this.timeoutContext = options.timeoutContext; + this.signal = options.signal; + // eslint-disable-next-line @typescript-eslint/no-misused-promises + this.abortListener = addAbortListener(this.signal, this.close.bind(this, undefined)); } /** @@ -455,6 +467,8 @@ export abstract class AbstractCursor< } async *[Symbol.asyncIterator](): AsyncGenerator { + throwIfAborted(this.signal); + if (this.closed) { return; } @@ -481,6 +495,7 @@ export abstract class AbstractCursor< } yield document; + throwIfAborted(this.signal); } } finally { // Only close the cursor if it has not already been closed. This finally clause handles @@ -526,6 +541,8 @@ export abstract class AbstractCursor< } async hasNext(): Promise { + throwIfAborted(this.signal); + if (this.cursorId === Long.ZERO) { return false; } @@ -551,6 +568,8 @@ export abstract class AbstractCursor< /** Get the next available document from the cursor, returns null if no more documents are available. */ async next(): Promise { + throwIfAborted(this.signal); + if (this.cursorId === Long.ZERO) { throw new MongoCursorExhaustedError(); } @@ -581,6 +600,8 @@ export abstract class AbstractCursor< * Try to get the next available document from the cursor or `null` if an empty batch is returned */ async tryNext(): Promise { + throwIfAborted(this.signal); + if (this.cursorId === Long.ZERO) { throw new MongoCursorExhaustedError(); } @@ -620,6 +641,8 @@ export abstract class AbstractCursor< * @deprecated - Will be removed in a future release. Use for await...of instead. */ async forEach(iterator: (doc: TSchema) => boolean | void): Promise { + throwIfAborted(this.signal); + if (typeof iterator !== 'function') { throw new MongoInvalidArgumentError('Argument "iterator" must be a function'); } @@ -645,6 +668,8 @@ export abstract class AbstractCursor< * cursor.rewind() can be used to reset the cursor. */ async toArray(): Promise { + throwIfAborted(this.signal); + const array: TSchema[] = []; // at the end of the loop (since readBufferedDocuments is called) the buffer will be empty // then, the 'await of' syntax will run a getMore call @@ -968,6 +993,7 @@ export abstract class AbstractCursor< /** @internal */ private async cleanup(timeoutMS?: number, error?: Error) { + this.abortListener?.[kDispose](); this.isClosed = true; const session = this.cursorSession; const timeoutContextForKillCursors = (): CursorTimeoutContext | undefined => { diff --git a/src/cursor/find_cursor.ts b/src/cursor/find_cursor.ts index 28cb373614..4c89307e66 100644 --- a/src/cursor/find_cursor.ts +++ b/src/cursor/find_cursor.ts @@ -72,7 +72,8 @@ export class FindCursor extends ExplainableCursor { const options = { ...this.findOptions, // NOTE: order matters here, we may need to refine this ...this.cursorOptions, - session + session, + signal: this.signal }; if (options.explain) { diff --git a/src/db.ts b/src/db.ts index 121d6fc4f1..bce2c7d2fe 100644 --- a/src/db.ts +++ b/src/db.ts @@ -8,7 +8,7 @@ import { ListCollectionsCursor } from './cursor/list_collections_cursor'; import { RunCommandCursor, type RunCursorCommandOptions } from './cursor/run_command_cursor'; import { MongoInvalidArgumentError } from './error'; import type { MongoClient, PkFactory } from './mongo_client'; -import type { TODO_NODE_3286 } from './mongo_types'; +import type { Abortable, TODO_NODE_3286 } from './mongo_types'; import type { AggregateOptions } from './operations/aggregate'; import { CollectionsOperation } from './operations/collections'; import { @@ -273,7 +273,7 @@ export class Db { * @param command - The command to run * @param options - Optional settings for the command */ - async command(command: Document, options?: RunCommandOptions): Promise { + async command(command: Document, options?: RunCommandOptions & Abortable): Promise { // Intentionally, we do not inherit options from parent for this operation. return await executeOperation( this.client, @@ -284,7 +284,8 @@ export class Db { ...resolveBSONOptions(options), timeoutMS: options?.timeoutMS ?? this.timeoutMS, session: options?.session, - readPreference: options?.readPreference + readPreference: options?.readPreference, + signal: options?.signal }) ) ); diff --git a/src/error.ts b/src/error.ts index 6d41087e3f..6c7c958dac 100644 --- a/src/error.ts +++ b/src/error.ts @@ -196,6 +196,18 @@ export class MongoError extends Error { } } +/** + * An error thrown when a signal is aborted + * + * @public + * @category Error + */ +export class MongoAbortedError extends MongoError { + override get name(): string { + return 'MongoAbortedError'; + } +} + /** * An error coming from the mongo server * diff --git a/src/mongo_types.ts b/src/mongo_types.ts index be116b3699..98485850fb 100644 --- a/src/mongo_types.ts +++ b/src/mongo_types.ts @@ -474,6 +474,14 @@ export class TypedEventEmitter extends EventEm /** @public */ export class CancellationToken extends TypedEventEmitter<{ cancel(): void }> {} +/** @public */ +export type Abortable = { + /** + * When provided the corresponding `AbortController` can be used to cancel an asynchronous action. + */ + signal?: AbortSignal | undefined; +}; + /** * Helper types for dot-notation filter attributes */ diff --git a/src/operations/execute_operation.ts b/src/operations/execute_operation.ts index 81601a6e16..a2fbf77685 100644 --- a/src/operations/execute_operation.ts +++ b/src/operations/execute_operation.ts @@ -198,7 +198,8 @@ async function tryOperation< let server = await topology.selectServer(selector, { session, operationName: operation.commandName, - timeoutContext + timeoutContext, + signal: operation.options.signal }); const hasReadAspect = operation.hasAspect(Aspect.READ_OPERATION); @@ -260,7 +261,8 @@ async function tryOperation< server = await topology.selectServer(selector, { session, operationName: operation.commandName, - previousServer + previousServer, + signal: operation.options.signal }); if (hasWriteAspect && !supportsRetryableWrites(server)) { diff --git a/src/operations/list_collections.ts b/src/operations/list_collections.ts index 6b3296fcf0..57f8aff45e 100644 --- a/src/operations/list_collections.ts +++ b/src/operations/list_collections.ts @@ -2,6 +2,7 @@ import type { Binary, Document } from '../bson'; import { CursorResponse } from '../cmap/wire_protocol/responses'; import { type CursorTimeoutContext, type CursorTimeoutMode } from '../cursor/abstract_cursor'; import type { Db } from '../db'; +import { type Abortable } from '../mongo_types'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; import { type TimeoutContext } from '../timeout'; @@ -10,7 +11,9 @@ import { CommandOperation, type CommandOperationOptions } from './command'; import { Aspect, defineAspects } from './operation'; /** @public */ -export interface ListCollectionsOptions extends Omit { +export interface ListCollectionsOptions + extends Omit, + Abortable { /** Since 4.0: If true, will only return the collection name in the response, and will omit additional info */ nameOnly?: boolean; /** Since 4.0: If true and nameOnly is true, allows a user without the required privilege (i.e. listCollections action on the database) to run the command when access control is enforced. */ diff --git a/src/operations/operation.ts b/src/operations/operation.ts index 029047543a..190f2a522b 100644 --- a/src/operations/operation.ts +++ b/src/operations/operation.ts @@ -1,4 +1,5 @@ import { type BSONSerializeOptions, type Document, resolveBSONOptions } from '../bson'; +import { type Abortable } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import type { Server } from '../sdam/server'; import type { ClientSession } from '../sessions'; @@ -59,7 +60,7 @@ export abstract class AbstractOperation { // BSON serialization options bsonOptions?: BSONSerializeOptions; - options: OperationOptions; + options: OperationOptions & Abortable; /** Specifies the time an operation will run until it throws a timeout error. */ timeoutMS?: number; @@ -68,7 +69,7 @@ export abstract class AbstractOperation { static aspects?: Set; - constructor(options: OperationOptions = {}) { + constructor(options: OperationOptions & Abortable = {}) { this.readPreference = this.hasAspect(Aspect.WRITE_OPERATION) ? ReadPreference.primary : (ReadPreference.fromOptions(options) ?? ReadPreference.primary); diff --git a/src/sdam/server.ts b/src/sdam/server.ts index 1aa19a3e18..9094c2fc4e 100644 --- a/src/sdam/server.ts +++ b/src/sdam/server.ts @@ -36,7 +36,7 @@ import { needsRetryableWriteLabel } from '../error'; import type { ServerApi } from '../mongo_client'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import type { GetMoreOptions } from '../operations/get_more'; import type { ClientSession } from '../sessions'; import { type TimeoutContext } from '../timeout'; @@ -107,7 +107,7 @@ export type ServerEvents = { /** @internal */ export type ServerCommandOptions = Omit & { timeoutContext: TimeoutContext; -}; +} & Abortable; /** @internal */ export class Server extends TypedEventEmitter { @@ -285,7 +285,7 @@ export class Server extends TypedEventEmitter { public async command( ns: MongoDBNamespace, cmd: Document, - options: ServerCommandOptions, + paramOpts: ServerCommandOptions, responseType?: MongoDBResponseConstructor ): Promise { if (ns.db == null || typeof ns === 'string') { @@ -297,24 +297,25 @@ export class Server extends TypedEventEmitter { } // Clone the options - const finalOptions = Object.assign({}, options, { + const options = { + ...paramOpts, wireProtocolCommand: false, directConnection: this.topology.s.options.directConnection - }); + }; // There are cases where we need to flag the read preference not to get sent in // the command, such as pre-5.0 servers attempting to perform an aggregate write // with a non-primary read preference. In this case the effective read preference // (primary) is not the same as the provided and must be removed completely. - if (finalOptions.omitReadPreference) { - delete finalOptions.readPreference; + if (options.omitReadPreference) { + delete options.readPreference; } if (this.description.iscryptd) { - finalOptions.omitMaxTimeMS = true; + options.omitMaxTimeMS = true; } - const session = finalOptions.session; + const session = options.session; let conn = session?.pinnedConnection; this.incrementOperationCount(); @@ -333,11 +334,11 @@ export class Server extends TypedEventEmitter { try { try { - const res = await conn.command(ns, cmd, finalOptions, responseType); + const res = await conn.command(ns, cmd, options, responseType); throwIfWriteConcernError(res); return res; } catch (commandError) { - throw this.decorateCommandError(conn, cmd, finalOptions, commandError); + throw this.decorateCommandError(conn, cmd, options, commandError); } } catch (operationError) { if ( @@ -346,11 +347,11 @@ export class Server extends TypedEventEmitter { ) { await this.pool.reauthenticate(conn); try { - const res = await conn.command(ns, cmd, finalOptions, responseType); + const res = await conn.command(ns, cmd, options, responseType); throwIfWriteConcernError(res); return res; } catch (commandError) { - throw this.decorateCommandError(conn, cmd, finalOptions, commandError); + throw this.decorateCommandError(conn, cmd, options, commandError); } } else { throw operationError; diff --git a/src/sdam/topology.ts b/src/sdam/topology.ts index b6cad4097e..d509e0f08b 100644 --- a/src/sdam/topology.ts +++ b/src/sdam/topology.ts @@ -31,15 +31,17 @@ import { } from '../error'; import type { MongoClient, ServerApi } from '../mongo_client'; import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger'; -import { TypedEventEmitter } from '../mongo_types'; +import { type Abortable, TypedEventEmitter } from '../mongo_types'; import { ReadPreference, type ReadPreferenceLike } from '../read_preference'; import type { ClientSession } from '../sessions'; import { Timeout, TimeoutContext, TimeoutError } from '../timeout'; import type { Transaction } from '../transactions'; import { + addAbortListener, type Callback, type EventEmitterWithState, HostAddress, + kDispose, List, makeStateMachine, now, @@ -525,7 +527,7 @@ export class Topology extends TypedEventEmitter { */ async selectServer( selector: string | ReadPreference | ServerSelector, - options: SelectServerOptions + options: SelectServerOptions & Abortable ): Promise { let serverSelector; if (typeof selector !== 'function') { @@ -602,6 +604,11 @@ export class Topology extends TypedEventEmitter { previousServer: options.previousServer }; + const abortListener = addAbortListener(options.signal, error => { + waitQueueMember.cancelled = true; + reject(error); + }); + this.waitQueue.push(waitQueueMember); processWaitQueue(this); @@ -647,6 +654,7 @@ export class Topology extends TypedEventEmitter { // Other server selection error throw error; } finally { + abortListener?.[kDispose](); if (options.timeoutContext?.clearServerSelectionTimeout) timeout?.clear(); } } diff --git a/src/utils.ts b/src/utils.ts index c23161612a..2b8173a281 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -18,6 +18,7 @@ import type { FindCursor } from './cursor/find_cursor'; import type { Db } from './db'; import { type AnyError, + MongoAbortedError, MongoAPIError, MongoCompatibilityError, MongoInvalidArgumentError, @@ -27,6 +28,7 @@ import { MongoRuntimeError } from './error'; import type { MongoClient } from './mongo_client'; +import { type Abortable } from './mongo_types'; import type { CommandOperationOptions, OperationParent } from './operations/command'; import type { Hint, OperationOptions } from './operations/operation'; import { ReadConcern } from './read_concern'; @@ -1349,19 +1351,20 @@ export const randomBytes = promisify(crypto.randomBytes); * @param ee - An event emitter that may emit `ev` * @param name - An event name to wait for */ -export async function once(ee: EventEmitter, name: string): Promise { +export async function once(ee: EventEmitter, name: string, options?: Abortable): Promise { const { promise, resolve, reject } = promiseWithResolvers(); const onEvent = (data: T) => resolve(data); const onError = (error: Error) => reject(error); + const abortListener = addAbortListener(options?.signal, reject); ee.once(name, onEvent).once('error', onError); + try { - const res = await promise; - ee.off('error', onError); - return res; - } catch (error) { + return await promise; + } finally { ee.off(name, onEvent); - throw error; + ee.off('error', onError); + abortListener?.[kDispose](); } } @@ -1468,3 +1471,28 @@ export function decorateDecryptionResult( decorateDecryptionResult(decrypted[k], originalValue, false); } } + +export const kDispose: unique symbol = (Symbol.dispose as any) ?? Symbol('dispose'); +export interface Disposable { + [kDispose](): void; +} + +export function addAbortListener( + signal: AbortSignal | undefined | null, + listener: (event: Error) => void +): Disposable | undefined { + if (signal == null) return; + + const convertReasonToError = () => + listener(new MongoAbortedError('Operation was aborted', { cause: signal.reason })); + + signal.addEventListener('abort', convertReasonToError); + + return { [kDispose]: () => signal.removeEventListener('abort', convertReasonToError) }; +} + +export function throwIfAborted(signal?: { aborted?: boolean; reason?: any }): void { + if (signal?.aborted) { + throw new MongoAbortedError('Operation was aborted', { cause: signal.reason }); + } +} diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts new file mode 100644 index 0000000000..88aede0263 --- /dev/null +++ b/test/integration/node-specific/abort_signal.test.ts @@ -0,0 +1,248 @@ +import * as util from 'node:util'; + +import { expect } from 'chai'; + +import { + type Collection, + type Db, + FindCursor, + MongoAbortedError, + type MongoClient, + ReadPreference, + setDifference +} from '../../mongodb'; +import { sleep } from '../../tools/utils'; + +const isAsyncGenerator = (value: any): value is AsyncGenerator => + value[Symbol.toStringTag] === 'AsyncGenerator'; + +const makeDescriptorGetter = value => prop => [prop, Object.getOwnPropertyDescriptor(value, prop)]; + +function getAllOwnProps(value) { + const props = []; + for (let obj = value; obj !== Object.prototype; obj = Object.getPrototypeOf(obj)) { + props.push(...Object.getOwnPropertyNames(obj).map(makeDescriptorGetter(obj))); + props.push(...Object.getOwnPropertySymbols(obj).map(makeDescriptorGetter(obj))); + } + return props; +} + +describe('AbortSignal support', () => { + let client: MongoClient; + let db: Db; + let collection: Collection<{ a: number }>; + const logs = []; + + beforeEach(async function () { + const utilClient = this.configuration.newClient(); + try { + await utilClient.db('abortSignal').collection('support').deleteMany({}); + await utilClient + .db('abortSignal') + .collection('support') + .insertMany([{ a: 1 }, { a: 2 }, { a: 3 }]); + } finally { + await utilClient.close(); + } + + logs.length = 0; + + client = this.configuration.newClient( + {}, + { + __enableMongoLogger: true, + __internalLoggerConfig: { + MONGODB_LOG_SERVER_SELECTION: 'debug' + }, + mongodbLogPath: { + write: log => logs.push(log) + }, + serverSelectionTimeoutMS: 5000 + } + ); + await client.connect(); + db = client.db('abortSignal'); + collection = db.collection('support'); + }); + + afterEach(async function () { + logs.length = 0; + const utilClient = this.configuration.newClient(); + try { + await utilClient.db('abortSignal').collection('support').deleteMany({}); + } finally { + await utilClient.close(); + } + await client.close(); + }); + + describe('when find() is given a signal', () => { + const cursorAPIs = { + tryNext: [], + hasNext: [], + next: [], + toArray: [], + forEach: [async () => true], + [Symbol.asyncIterator]: [] + }; + + function captureCursor(onResolve, onReject) { + return async function (cursor, cursorAPI, args) { + const apiReturnValue = cursor[cursorAPI](...args); + + let result; + if (isAsyncGenerator(apiReturnValue)) { + result = await apiReturnValue.next().then(onResolve, onReject); + } else { + result = await apiReturnValue.then(onResolve, onReject); + } + + return result; + }; + } + + const captureCursorError = captureCursor( + () => { + expect.fail('expected an error'); + }, + error => error + ); + + const captureCursorResult = captureCursor(result => result, undefined); + + it('should test all the async APIs', () => { + const knownNotTested = [ + 'asyncDispose', + 'close', + 'getMore', + 'cursorInit', + 'fetchBatch', + 'cleanup', + 'transformDocument', + Symbol.asyncDispose + ]; + + const allCursorAsyncAPIs = getAllOwnProps(FindCursor.prototype) + .filter(([, { value }]) => util.types.isAsyncFunction(value)) + .map(([key]) => key); + + expect(setDifference(Object.keys(cursorAPIs), allCursorAsyncAPIs)).to.be.empty; + + const notTested = allCursorAsyncAPIs.filter( + fn => knownNotTested.includes(fn) && Object.keys(cursorAPIs).includes(fn) + ); + + expect(notTested, 'new async function found, should respond to signal state or be internal') + .to.be.empty; + }); + + describe('and the signal is already aborted', () => { + let signal: AbortSignal; + let cursor: FindCursor<{ a: number }>; + + beforeEach(() => { + const controller = new AbortController(); + signal = controller.signal; + controller.abort('The operation was aborted'); + + cursor = collection.find({}, { signal }); + }); + + afterEach(async () => { + await cursor.close(); + }); + + for (const [cursorAPI, { value: args }] of getAllOwnProps(cursorAPIs)) { + it(`rejects ${cursorAPI.toString()} with MongoAbortedError and the cause is signal.reason`, async () => { + const result = await captureCursorError(cursor, cursorAPI, args); + expect(result).to.be.instanceOf(MongoAbortedError); + expect(result.cause).to.equal(signal.reason); + }); + } + }); + + describe('and the signal is aborted after use', () => { + let controller: AbortController; + let signal: AbortSignal; + let cursor: FindCursor<{ a: number }>; + + beforeEach(() => { + controller = new AbortController(); + signal = controller.signal; + cursor = collection.find({}, { signal }); + }); + + afterEach(async () => { + await cursor.close(); + }); + + for (const [cursorAPI, { value: args }] of getAllOwnProps(cursorAPIs)) { + it(`resolves ${cursorAPI.toString()} without Error`, async () => { + const result = await captureCursorResult(cursor, cursorAPI, args); + controller.abort('The operation was aborted'); + expect(result).to.not.be.instanceOf(Error); + }); + + it(`rejects ${cursorAPI.toString()} on the subsequent call with MongoAbortedError`, async () => { + const result = await captureCursorResult(cursor, cursorAPI, args); + expect(result).to.not.be.instanceOf(Error); + + controller.abort('The operation was aborted'); + + const error = await captureCursorError(cursor, cursorAPI, args); + expect(error).to.be.instanceOf(MongoAbortedError); + }); + } + }); + + describe('and the signal is aborted during server selection', () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: FindCursor<{ a: number }>; + + beforeEach(() => { + controller = new AbortController(); + signal = controller.signal; + cursor = collection.find( + {}, + { + signal, + // Pick an unselectable server + readPreference: new ReadPreference('secondary', [ + { something: 'that does not exist' } + ]) + } + ); + }); + + afterEach(async () => { + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()} with MongoAbortedError`, async () => { + const willBeResult = captureCursorError(cursor, cursorAPI, args); + + await sleep(10); + expect(logs.findLast(l => l.operation === 'find')).to.have.property( + 'message', + 'Waiting for suitable server to become available' + ); + + controller.abort(new Error('please stop')); + const start = performance.now(); + const result = await willBeResult; + const end = performance.now(); + expect(end - start).to.be.lessThan(1000); // should be way less than 30s server selection timeout + + expect(result).to.be.instanceOf(MongoAbortedError); + expect(result.cause).to.equal(signal.reason); + }); + } + + for (const [cursorAPI, { value: args }] of getAllOwnProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + }); +});