Skip to content

Commit

Permalink
feat(NODE-6258): add signal support to cursor APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
nbbeeken committed Jan 7, 2025
1 parent c392465 commit e46e198
Show file tree
Hide file tree
Showing 19 changed files with 503 additions and 96 deletions.
8 changes: 2 additions & 6 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return deserialize(await stateMachine.execute(this, context, options.timeoutContext), {
return deserialize(await stateMachine.execute(this, context, options), {
promoteValues: false,
promoteLongs: false
});
Expand All @@ -419,11 +419,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return await stateMachine.execute(
this,
context,
options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined
);
return await stateMachine.execute(this, context, options);
}

/**
Expand Down
10 changes: 6 additions & 4 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export class ClientEncryption {
TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }));

const dataKey = deserialize(
await stateMachine.execute(this, context, timeoutContext)
await stateMachine.execute(this, context, { timeoutContext })
) as DataKey;

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
Expand Down Expand Up @@ -293,7 +293,9 @@ export class ClientEncryption {
resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })
);

const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v: dataKeys } = deserialize(
await stateMachine.execute(this, context, { timeoutContext })
);
if (dataKeys.length === 0) {
return {};
}
Expand Down Expand Up @@ -696,7 +698,7 @@ export class ClientEncryption {
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;

const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));

return v;
}
Expand Down Expand Up @@ -780,7 +782,7 @@ export class ClientEncryption {
this._timeoutMS != null
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
return v;
}
}
Expand Down
115 changes: 79 additions & 36 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import {
addAbortListener,
BufferPool,
kDispose,
MongoDBCollectionNamespace,
promiseWithResolvers
} from '../utils';
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
import { type MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -189,7 +196,7 @@ export class StateMachine {
async execute(
executor: StateMachineExecutable,
context: MongoCryptContext,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const keyVaultNamespace = executor._keyVaultNamespace;
const keyVaultClient = executor._keyVaultClient;
Expand All @@ -214,7 +221,7 @@ export class StateMachine {
metaDataClient,
context.ns,
filter,
timeoutContext
options
);
if (collInfo) {
context.addMongoOperationResponse(collInfo);
Expand All @@ -235,9 +242,9 @@ export class StateMachine {
// When we are using the shared library, we don't have a mongocryptd manager.
const markedCommand: Uint8Array = mongocryptdManager
? await mongocryptdManager.withRespawn(
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
this.markCommand.bind(this, mongocryptdClient, context.ns, command, options)
)
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);
: await this.markCommand(mongocryptdClient, context.ns, command, options);

context.addMongoOperationResponse(markedCommand);
context.finishMongoOperation();
Expand All @@ -246,12 +253,7 @@ export class StateMachine {

case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
const filter = context.nextMongoOperation();
const keys = await this.fetchKeys(
keyVaultClient,
keyVaultNamespace,
filter,
timeoutContext
);
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options);

if (keys.length === 0) {
// See docs on EMPTY_V
Expand All @@ -273,7 +275,7 @@ export class StateMachine {
}

case MONGOCRYPT_CTX_NEED_KMS: {
await Promise.all(this.requests(context, timeoutContext));
await Promise.all(this.requests(context, options));
context.finishKMSRequests();
break;
}
Expand Down Expand Up @@ -315,11 +317,13 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
async kmsRequest(
request: MongoCryptKMSRequest,
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
const options: tls.ConnectionOptions & {
const socketOptions: tls.ConnectionOptions & {
host: string;
port: number;
autoSelectFamily?: boolean;
Expand All @@ -328,7 +332,7 @@ export class StateMachine {
host: parsedUrl[0],
servername: parsedUrl[0],
port,
...socketOptions
...autoSelectSocketOptions(this.options.socketOptions || {})
};
const message = request.message;
const buffer = new BufferPool();
Expand Down Expand Up @@ -363,7 +367,7 @@ export class StateMachine {
throw error;
}
try {
await this.setTlsOptions(providerTlsOptions, options);
await this.setTlsOptions(providerTlsOptions, socketOptions);
} catch (err) {
throw onerror(err);
}
Expand All @@ -380,23 +384,25 @@ export class StateMachine {
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());

let abortListener;

try {
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
const netSocketOptions = {
...socketOptions,
host: this.options.proxyOptions.proxyHost,
port: this.options.proxyOptions.proxyPort || 1080,
...socketOptions
port: this.options.proxyOptions.proxyPort || 1080
};
netSocket.connect(netSocketOptions);
await willConnect;

try {
socks ??= loadSocks();
options.socket = (
socketOptions.socket = (
await socks.SocksClient.createConnection({
existing_socket: netSocket,
command: 'connect',
destination: { host: options.host, port: options.port },
destination: { host: socketOptions.host, port: socketOptions.port },
proxy: {
// host and port are ignored because we pass existing_socket
host: 'iLoveJavaScript',
Expand All @@ -412,7 +418,7 @@ export class StateMachine {
}
}

socket = tls.connect(options, () => {
socket = tls.connect(socketOptions, () => {
socket.write(message);
});

Expand All @@ -422,6 +428,11 @@ export class StateMachine {
resolve
} = promiseWithResolvers<void>();

abortListener = addAbortListener(options.signal, error => {
destroySockets();
rejectOnTlsSocketError(error);
});

socket
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
Expand All @@ -436,8 +447,11 @@ export class StateMachine {
resolve();
}
});
await (timeoutContext?.csotEnabled()
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
await (options.timeoutContext?.csotEnabled()
? Promise.all([
willResolveKmsRequest,
Timeout.expires(options.timeoutContext?.remainingTimeMS)
])
: willResolveKmsRequest);
} catch (error) {
if (error instanceof TimeoutError)
Expand All @@ -446,16 +460,17 @@ export class StateMachine {
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
abortListener?.[kDispose]();
}
}

*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
*requests(context: MongoCryptContext, options: { timeoutContext?: TimeoutContext } & Abortable) {
for (
let request = context.nextKMSRequest();
request != null;
request = context.nextKMSRequest()
) {
yield this.kmsRequest(request, timeoutContext);
yield this.kmsRequest(request, options);
}
}

Expand Down Expand Up @@ -516,14 +531,16 @@ export class StateMachine {
client: MongoClient,
ns: string,
filter: Document,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array | null> {
const { db } = MongoDBCollectionNamespace.fromString(ns);

const cursor = client.db(db).listCollections(filter, {
promoteLongs: false,
promoteValues: false,
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
timeoutContext:
options.timeoutContext && new CursorTimeoutContext(options.timeoutContext, Symbol()),
signal: options.signal
});

// There is always exactly zero or one matching documents, so this should always exhaust the cursor
Expand All @@ -547,17 +564,30 @@ export class StateMachine {
client: MongoClient,
ns: string,
command: Uint8Array,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const { db } = MongoDBCollectionNamespace.fromString(ns);
const bsonOptions = { promoteLongs: false, promoteValues: false };
const rawCommand = deserialize(command, bsonOptions);

const commandOptions: {
timeoutMS?: number;
signal?: AbortSignal;
} = {
timeoutMS: undefined,
signal: undefined
};

if (options.timeoutContext?.csotEnabled()) {
commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS;
}
if (options.signal) {
commandOptions.signal = options.signal;
}

const response = await client.db(db).command(rawCommand, {
...bsonOptions,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS }
: undefined)
...commandOptions
});

return serialize(response, this.bsonOptions);
Expand All @@ -575,17 +605,30 @@ export class StateMachine {
client: MongoClient,
keyVaultNamespace: string,
filter: Uint8Array,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Array<DataKey>> {
const { db: dbName, collection: collectionName } =
MongoDBCollectionNamespace.fromString(keyVaultNamespace);

const commandOptions: {
timeoutContext?: CursorTimeoutContext;
signal?: AbortSignal;
} = {
timeoutContext: undefined,
signal: undefined
};

if (options.timeoutContext != null) {
commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol());
}
if (options.signal != null) {
commandOptions.signal = options.signal;
}

return client
.db(dbName)
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
.find(deserialize(filter), {
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
})
.find(deserialize(filter), commandOptions)
.toArray();
}
}
19 changes: 11 additions & 8 deletions src/cmap/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import {
import type { ServerApi, SupportedNodeConnectionOptions } from '../mongo_client';
import { type MongoClientAuthProviders } from '../mongo_client_auth_providers';
import { MongoLoggableComponent, type MongoLogger, SeverityLevel } from '../mongo_logger';
import { type CancellationToken, TypedEventEmitter } from '../mongo_types';
import { type Abortable, type CancellationToken, TypedEventEmitter } from '../mongo_types';
import { ReadPreference, type ReadPreferenceLike } from '../read_preference';
import { ServerType } from '../sdam/common';
import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions';
Expand Down Expand Up @@ -438,7 +438,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {

private async *sendWire(
message: WriteProtocolMessageType,
options: CommandOptions,
options: CommandOptions & Abortable,
responseType?: MongoDBResponseConstructor
): AsyncGenerator<MongoDBResponse> {
this.throwIfAborted();
Expand All @@ -453,7 +453,8 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
await this.writeCommand(message, {
agreedCompressor: this.description.compressor ?? 'none',
zlibCompressionLevel: this.description.zlibCompressionLevel,
timeoutContext: options.timeoutContext
timeoutContext: options.timeoutContext,
signal: options.signal
});

if (options.noResponse || message.moreToCome) {
Expand Down Expand Up @@ -676,7 +677,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
agreedCompressor?: CompressorName;
zlibCompressionLevel?: number;
timeoutContext?: TimeoutContext;
}
} & Abortable
): Promise<void> {
const finalCommand =
options.agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command)
Expand All @@ -701,7 +702,7 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {

if (this.socket.write(buffer)) return;

const drainEvent = once<void>(this.socket, 'drain');
const drainEvent = once<void>(this.socket, 'drain', { signal: options.signal });
const timeout = options?.timeoutContext?.timeoutForSocketWrite;
if (timeout) {
try {
Expand Down Expand Up @@ -729,9 +730,11 @@ export class Connection extends TypedEventEmitter<ConnectionEvents> {
*
* Note that `for-await` loops call `return` automatically when the loop is exited.
*/
private async *readMany(options: {
timeoutContext?: TimeoutContext;
}): AsyncGenerator<OpMsgResponse | OpReply> {
private async *readMany(
options: {
timeoutContext?: TimeoutContext;
} & Abortable
): AsyncGenerator<OpMsgResponse | OpReply> {
try {
this.dataEvents = onData(this.messageStream, options);
this.messageStream.resume();
Expand Down
Loading

0 comments on commit e46e198

Please sign in to comment.