Skip to content

Commit

Permalink
chore: darft change for AE
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSun90 committed Sep 27, 2024
1 parent 62dd02c commit 1775071
Show file tree
Hide file tree
Showing 9 changed files with 1,682 additions and 84 deletions.
14 changes: 7 additions & 7 deletions src/always-encrypted/key-crypto.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) 2019 Microsoft Corporation

import { type CryptoMetadata, type EncryptionKeyInfo } from './types';
import { type InternalConnectionOptions as ConnectionOptions } from '../connection';
import { type ParserOptions } from '../token/stream-parser';
import SymmetricKey from './symmetric-key';
import { getKey } from './symmetric-key-cache';
import { AeadAes256CbcHmac256Algorithm, algorithmName } from './aead-aes-256-cbc-hmac-algorithm';
Expand All @@ -16,7 +16,7 @@ export const validateAndGetEncryptionAlgorithmName = (cipherAlgorithmId: number,
return algorithmName;
};

export const encryptWithKey = async (plaintext: Buffer, md: CryptoMetadata, options: ConnectionOptions): Promise<Buffer> => {
export const encryptWithKey = async (plaintext: Buffer, md: CryptoMetadata, options: ParserOptions): Promise<Buffer> => {
if (!options.trustedServerNameAE) {
throw new Error('Server name should not be null in EncryptWithKey');
}
Expand All @@ -38,14 +38,14 @@ export const encryptWithKey = async (plaintext: Buffer, md: CryptoMetadata, opti
return cipherText;
};

export const decryptWithKey = (cipherText: Buffer, md: CryptoMetadata, options: ConnectionOptions): Buffer => {
export const decryptWithKey = async (cipherText: Buffer, md: CryptoMetadata, options: ParserOptions): Promise<Buffer> => {
if (!options.trustedServerNameAE) {
throw new Error('Server name should not be null in DecryptWithKey');
}

// if (!md.cipherAlgorithm) {
// await decryptSymmetricKey(md, options);
// }
if (!md.cipherAlgorithm) {
await decryptSymmetricKey(md, options);
}

if (!md.cipherAlgorithm) {
throw new Error('Cipher Algorithm should not be null in DecryptWithKey');
Expand All @@ -60,7 +60,7 @@ export const decryptWithKey = (cipherText: Buffer, md: CryptoMetadata, options:
return plainText;
};

export const decryptSymmetricKey = async (md: CryptoMetadata, options: ConnectionOptions): Promise<void> => {
export const decryptSymmetricKey = async (md: CryptoMetadata, options: ParserOptions): Promise<void> => {
if (!md) {
throw new Error('md should not be null in DecryptSymmetricKey.');
}
Expand Down
4 changes: 2 additions & 2 deletions src/always-encrypted/symmetric-key-cache.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import { type EncryptionKeyInfo } from './types';
import SymmetricKey from './symmetric-key';
import { type InternalConnectionOptions as ConnectionOptions } from '../connection';
import { type ParserOptions } from '../token/stream-parser';
import LRU from 'lru-cache';

const cache = new LRU<string, SymmetricKey>(0);

export const getKey = async (keyInfo: EncryptionKeyInfo, options: ConnectionOptions): Promise<SymmetricKey> => {
export const getKey = async (keyInfo: EncryptionKeyInfo, options: ParserOptions): Promise<SymmetricKey> => {
if (!options.trustedServerNameAE) {
throw new Error('Server name should not be null in getKey');
}
Expand Down
99 changes: 98 additions & 1 deletion src/connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ import Message from './message';
import { type Metadata } from './metadata-parser';
import { createNTLMRequest } from './ntlm';
import { ColumnEncryptionAzureKeyVaultProvider } from './always-encrypted/keystore-provider-azure-key-vault';
import { shouldHonorAE } from './always-encrypted/utils';
import { getParameterEncryptionMetadata } from './always-encrypted/get-parameter-encryption-metadata';

import { type Parameter, TYPES } from './data-type';
import { BulkLoadPayload } from './bulk-load-payload';
Expand Down Expand Up @@ -397,6 +399,11 @@ interface KeyStoreProviderMap {
[key: string]: ColumnEncryptionAzureKeyVaultProvider;
}

interface KeyStoreProvider {
key: string;
value: ColumnEncryptionAzureKeyVaultProvider;
}

/**
* @private
*/
Expand Down Expand Up @@ -520,6 +527,10 @@ export interface ConnectionOptions {
*/
cancelTimeout?: number;

columnEncryptionKeyCacheTTL?: number;

columnEncryptionSetting?: boolean;

/**
* A function with parameters `(columnName, index, columnMetaData)` and returning a string. If provided,
* this will be called once per column per result-set. The returned value will be used instead of the SQL-provided
Expand Down Expand Up @@ -675,6 +686,8 @@ export interface ConnectionOptions {
*/
encrypt?: string | boolean;

encryptionKeyStoreProviders?: KeyStoreProvider[];

/**
* By default, if the database requested by [[database]] cannot be accessed,
* the connection will fail with an error. However, if [[fallbackToDefaultDb]] is
Expand Down Expand Up @@ -871,6 +884,7 @@ interface RoutingData {
port: number;
}


/**
* A [[Connection]] instance represents a single connection to a database server.
*
Expand Down Expand Up @@ -1690,6 +1704,26 @@ class Connection extends EventEmitter {
this.config.options.useUTC = config.options.useUTC;
}

if (config.options.columnEncryptionSetting !== undefined) {
if (typeof config.options.columnEncryptionSetting !== 'boolean') {
throw new TypeError('The "config.options.columnEncryptionSetting" property must be of type boolean.');
}

this.config.options.columnEncryptionSetting = config.options.columnEncryptionSetting;
}

if (config.options.columnEncryptionKeyCacheTTL !== undefined) {
if (typeof config.options.columnEncryptionKeyCacheTTL !== 'number') {
throw new TypeError('The "config.options.columnEncryptionKeyCacheTTL" property must be of type number.');
}

if (config.options.columnEncryptionKeyCacheTTL <= 0) {
throw new TypeError('The "config.options.columnEncryptionKeyCacheTTL" property must be greater than 0.');
}

this.config.options.columnEncryptionKeyCacheTTL = config.options.columnEncryptionKeyCacheTTL;
}

if (config.options.workstationId !== undefined) {
if (typeof config.options.workstationId !== 'string') {
throw new TypeError('The "config.options.workstationId" property must be of type string.');
Expand All @@ -1705,6 +1739,51 @@ class Connection extends EventEmitter {

this.config.options.lowerCaseGuids = config.options.lowerCaseGuids;
}

if (config.options.encryptionKeyStoreProviders) {
for (const entry of config.options.encryptionKeyStoreProviders) {
const providerName = entry.key;

if (!providerName || providerName.length === 0) {
throw new TypeError('Invalid key store provider name specified. Key store provider names cannot be null or empty.');
}

if (providerName.substring(0, 6).toUpperCase().localeCompare('MSSQL_') === 0) {
throw new TypeError(`Invalid key store provider name ${providerName}. MSSQL_ prefix is reserved for system key store providers.`);
}

if (!entry.value) {
throw new TypeError(`Null reference specified for key store provider ${providerName}. Expecting a non-null value.`);
}

if (!this.config.options.encryptionKeyStoreProviders) {
this.config.options.encryptionKeyStoreProviders = {};
}

this.config.options.encryptionKeyStoreProviders[providerName] = entry.value;
}
}
}

let serverName = this.config.server;
if (!serverName) {
serverName = 'localhost';
}

const px = serverName.indexOf('\\');

if (px > 0) {
serverName = serverName.substring(0, px);
}

this.config.options.trustedServerNameAE = serverName;

if (this.config.options.instanceName) {
this.config.options.trustedServerNameAE = `${this.config.options.trustedServerNameAE}:${this.config.options.instanceName}`;
}

if (this.config.options.port) {
this.config.options.trustedServerNameAE = `${this.config.options.trustedServerNameAE}:${this.config.options.port}`;
}

this.secureContextOptions = this.config.options.cryptoCredentialsDetails;
Expand Down Expand Up @@ -2592,7 +2671,7 @@ class Connection extends EventEmitter {
*
* @param request A [[Request]] object representing the request.
*/
execSql(request: Request) {
_execSql(request: Request) {
try {
request.validateParameters(this.databaseCollation);
} catch (error: any) {
Expand Down Expand Up @@ -2635,6 +2714,24 @@ class Connection extends EventEmitter {
this.makeRequest(request, TYPE.RPC_REQUEST, new RpcRequestPayload(Procedures.Sp_ExecuteSql, parameters, this.currentTransactionDescriptor(), this.config.options, this.databaseCollation));
}

execSql(request: Request) {
request.shouldHonorAE = shouldHonorAE(request.statementColumnEncryptionSetting, this.config.options.columnEncryptionSetting);
if (request.shouldHonorAE && request.cryptoMetadataLoaded === false && (request.parameters && request.parameters.length > 0)) {
getParameterEncryptionMetadata(this, request, (error?: Error) => {
if (error != null) {
process.nextTick(() => {
this.transitionTo(this.STATE.LOGGED_IN);
this.debug.log(error.message);
request.callback(error);
});
return;
}
this._execSql(request);
});
} else {
this._execSql(request);
}
}
/**
* Creates a new BulkLoad instance.
*
Expand Down
7 changes: 5 additions & 2 deletions src/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ export interface ParameterOptions {
length?: number;
precision?: number;
scale?: number;

forceEncrypt?: boolean;
}

interface RequestOptions {
Expand Down Expand Up @@ -399,7 +401,7 @@ class Request extends EventEmitter {
*/
// TODO: `type` must be a valid TDS value type
addParameter(name: string, type: DataType, value?: unknown, options?: Readonly<ParameterOptions> | null) {
const { output = false, length, precision, scale } = options ?? {};
const { output = false, length, precision, scale, forceEncrypt = false } = options ?? {};

const parameter: Parameter = {
type: type,
Expand All @@ -408,7 +410,8 @@ class Request extends EventEmitter {
output: output,
length: length,
precision: precision,
scale: scale
scale: scale,
forceEncrypt: forceEncrypt
};

this.parameters.push(parameter);
Expand Down
41 changes: 5 additions & 36 deletions src/token/row-token-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import Parser from './stream-parser';
import { type ColumnMetadata } from './colmetadata-token-parser';

import { RowToken } from './token';
import * as iconv from 'iconv-lite';
// import * as iconv from 'iconv-lite';

import { isPLPStream, readPLPStream, readValue } from '../value-parser';
import { NotEnoughDataError } from './helpers';
import { readValue, readDecrypt } from '../value-parser';

Check failure on line 9 in src/token/row-token-parser.ts

View workflow job for this annotation

GitHub Actions / Linting

'readValue' is defined but never used
// import { NotEnoughDataError } from './helpers';

interface Column {
value: unknown;
Expand All @@ -16,40 +16,9 @@ interface Column {

async function rowParser(parser: Parser): Promise<RowToken> {
const columns: Column[] = [];

for (const metadata of parser.colMetadata) {
while (true) {
if (isPLPStream(metadata)) {
const chunks = await readPLPStream(parser);

if (chunks === null) {
columns.push({ value: chunks, metadata });
} else if (metadata.type.name === 'NVarChar' || metadata.type.name === 'Xml') {
columns.push({ value: Buffer.concat(chunks).toString('ucs2'), metadata });
} else if (metadata.type.name === 'VarChar') {
columns.push({ value: iconv.decode(Buffer.concat(chunks), metadata.collation?.codepage ?? 'utf8'), metadata });
} else if (metadata.type.name === 'VarBinary' || metadata.type.name === 'UDT') {
columns.push({ value: Buffer.concat(chunks), metadata });
}
} else {
let result;
try {
result = readValue(parser.buffer, parser.position, metadata, parser.options);
} catch (err) {
if (err instanceof NotEnoughDataError) {
await parser.waitForChunk();
continue;
}

throw err;
}

parser.position = result.offset;
columns.push({ value: result.value, metadata });
}

break;
}
const result = await readDecrypt(parser, metadata, parser.options);
columns.push({ value: result.value, metadata });
}

if (parser.options.useColumnNames) {
Expand Down
8 changes: 7 additions & 1 deletion src/token/stream-parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import returnValueParser from './returnvalue-token-parser';
import rowParser from './row-token-parser';
import nbcRowParser from './nbcrow-token-parser';
import sspiParser from './sspi-token-parser';
import { decryptSymmetricKey } from '../always-encrypted/key-crypto';
import { NotEnoughDataError } from './helpers';

export type ParserOptions = Pick<InternalConnectionOptions, 'useUTC' | 'lowerCaseGuids' | 'tdsVersion' | 'useColumnNames' | 'columnNameReplacer' | 'camelCaseColumns' | 'serverSupportsColumnEncryption'>;
export type ParserOptions = Pick<InternalConnectionOptions, 'useUTC' | 'lowerCaseGuids' | 'tdsVersion' | 'useColumnNames' | 'columnNameReplacer' | 'camelCaseColumns' | 'serverSupportsColumnEncryption' | 'trustedServerNameAE' | 'encryptionKeyStoreProviders' | 'columnEncryptionKeyCacheTTL'>;

class Parser {
debug: Debug;
Expand Down Expand Up @@ -158,6 +159,11 @@ class Parser {
async readColMetadataToken(): Promise<ColMetadataToken> {
const token = await colMetadataParser(this);
this.colMetadata = token.columns;
await Promise.all(this.colMetadata.map(async (metadata) => {
if (metadata.cryptoMetadata) {
await decryptSymmetricKey(metadata.cryptoMetadata, this.options);
}
}));
return token;
}

Expand Down
Loading

0 comments on commit 1775071

Please sign in to comment.