Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert unit tests to Typescript #258

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import IThriftClient from './contracts/IThriftClient';
import HiveDriver from './hive/HiveDriver';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
Expand Down Expand Up @@ -43,6 +44,8 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export type ThriftLibrary = Pick<typeof thrift, 'createClient'>;

export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private static defaultLogger?: IDBSQLLogger;

Expand All @@ -52,17 +55,17 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

private authProvider?: IAuthentication;

private client?: TCLIService.Client;
private client?: IThriftClient;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
private thrift: ThriftLibrary = thrift;

private sessions = new CloseableCollection<DBSQLSession>();
private readonly sessions = new CloseableCollection<DBSQLSession>();

private static getDefaultLogger(): IDBSQLLogger {
if (!this.defaultLogger) {
Expand Down Expand Up @@ -113,7 +116,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
};
}

private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
private createAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand Down Expand Up @@ -143,6 +146,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
}
}

private createConnectionProvider(options: ConnectionOptions): IConnectionProvider {
return new HttpConnection(this.getConnectionOptions(options), this);
}

/**
* Connects DBSQLClient to endpoint
* @public
Expand All @@ -153,9 +160,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.initAuthProvider(options, authProvider);
this.authProvider = this.createAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options), this);
this.connectionProvider = this.createConnectionProvider(options);

const thriftConnection = await this.connectionProvider.getThriftConnection();

Expand Down Expand Up @@ -238,7 +245,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
public async getClient(): Promise<IThriftClient> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
Expand Down
2 changes: 1 addition & 1 deletion lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export default class DBSQLOperation implements IOperation {

private metadata?: TGetResultSetMetadataResp;

private state: number = TOperationState.INITIALIZED_STATE;
private state: TOperationState = TOperationState.INITIALIZED_STATE;

// Once operation is finished or fails - cache status response, because subsequent calls
// to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404
Expand Down
110 changes: 60 additions & 50 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,19 @@ import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export type DefaultOpenAuthUrlCallback = (authUrl: string) => Promise<void>;

export type OpenAuthUrlCallback = (authUrl: string, defaultOpenAuthUrl: DefaultOpenAuthUrlCallback) => Promise<void>;

export interface AuthorizationCodeOptions {
client: BaseClient;
ports: Array<number>;
context: IClientContext;
openAuthUrl?: OpenAuthUrlCallback;
}

async function startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = http.createServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

async function stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
async function defaultOpenAuthUrl(authUrl: string): Promise<void> {
await open(authUrl);
}

export interface AuthorizationCodeFetchResult {
Expand All @@ -65,16 +34,12 @@ export default class AuthorizationCode {

private readonly host: string = 'localhost';

private readonly ports: Array<number>;
private readonly options: AuthorizationCodeOptions;

constructor(options: AuthorizationCodeOptions) {
this.client = options.client;
this.ports = options.ports;
this.context = options.context;
}

private async openUrl(url: string) {
return open(url);
this.options = options;
}

public async fetch(scopes: OAuthScopes): Promise<AuthorizationCodeFetchResult> {
Expand All @@ -84,7 +49,7 @@ export default class AuthorizationCode {

let receivedParams: CallbackParamsType | undefined;

const server = await this.startServer((req, res) => {
const server = await this.createServer((req, res) => {
const params = this.client.callbackParams(req);
if (params.state === state) {
receivedParams = params;
Expand All @@ -108,7 +73,8 @@ export default class AuthorizationCode {
redirect_uri: redirectUri,
});

await this.openUrl(authUrl);
const openAuthUrl = this.options.openAuthUrl ?? defaultOpenAuthUrl;
await openAuthUrl(authUrl, defaultOpenAuthUrl);
await server.stopped();

if (!receivedParams || !receivedParams.code) {
Expand All @@ -122,11 +88,11 @@ export default class AuthorizationCode {
return { code: receivedParams.code, verifier: verifierString, redirectUri };
}

private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.ports) {
private async createServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.options.ports) {
const host = this.host; // eslint-disable-line prefer-destructuring
try {
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
const server = await this.startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
this.context.getLogger().log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);

let resolveStopped: () => void;
Expand All @@ -140,7 +106,7 @@ export default class AuthorizationCode {
host,
port,
server,
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
stop: () => this.stopServer(server).then(resolveStopped).catch(rejectStopped),
stopped: () => stoppedPromise,
};
} catch (error) {
Expand All @@ -156,6 +122,50 @@ export default class AuthorizationCode {
throw new AuthenticationError('Failed to start server: all ports are in use');
}

private createHttpServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
return http.createServer(requestHandler);
}

private async startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = this.createHttpServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

private async stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
}

private renderCallbackResponse(): string {
const applicationName = 'Databricks Sql Connector';

Expand Down
16 changes: 11 additions & 5 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
export interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
headers?: HeadersInit;
Expand All @@ -18,14 +18,13 @@ export default class DatabricksOAuth implements IAuthentication {

private readonly options: DatabricksOAuthOptions;

private readonly manager: OAuthManager;
private manager?: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.context = options.context;
this.options = options;
this.manager = OAuthManager.getManager(this.options);
}

public async authenticate(): Promise<HeadersInit> {
Expand All @@ -35,15 +34,22 @@ export default class DatabricksOAuth implements IAuthentication {

let token = await persistence.read(host);
if (!token) {
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
token = await this.getManager().getToken(scopes ?? defaultOAuthScopes);
}

token = await this.manager.refreshAccessToken(token);
token = await this.getManager().refreshAccessToken(token);
await persistence.persist(host, token);

return {
...headers,
Authorization: `Bearer ${token.accessToken}`,
};
}

private getManager(): OAuthManager {
if (!this.manager) {
this.manager = OAuthManager.getManager(this.options);
}
return this.manager;
}
}
2 changes: 1 addition & 1 deletion lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export default class HttpConnection implements IConnectionProvider {
});
}

public async getAgent(): Promise<http.Agent> {
public async getAgent(): Promise<http.Agent | undefined> {
if (!this.agent) {
if (this.options.proxy !== undefined) {
this.agent = this.createProxyAgent(this.options.proxy);
Expand Down
2 changes: 1 addition & 1 deletion lib/connection/connections/HttpRetryPolicy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function delay(milliseconds: number): Promise<void> {
export default class HttpRetryPolicy implements IRetryPolicy<HttpTransactionDetails> {
private context: IClientContext;

private readonly startTime: number; // in milliseconds
private startTime: number; // in milliseconds

private attempt: number;

Expand Down
2 changes: 1 addition & 1 deletion lib/connection/contracts/IConnectionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export interface HttpTransactionDetails {
export default interface IConnectionProvider {
getThriftConnection(): Promise<any>;

getAgent(): Promise<http.Agent>;
getAgent(): Promise<http.Agent | undefined>;

setHeaders(headers: HeadersInit): void;

Expand Down
4 changes: 2 additions & 2 deletions lib/contracts/IClientContext.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import IDBSQLLogger from './IDBSQLLogger';
import IDriver from './IDriver';
import IConnectionProvider from '../connection/contracts/IConnectionProvider';
import TCLIService from '../../thrift/TCLIService';
import IThriftClient from './IThriftClient';

export interface ClientConfig {
directResultsDefaultMaxRows: number;
Expand Down Expand Up @@ -29,7 +29,7 @@ export default interface IClientContext {

getConnectionProvider(): Promise<IConnectionProvider>;

getClient(): Promise<TCLIService.Client>;
getClient(): Promise<IThriftClient>;

getDriver(): Promise<IDriver>;
}
9 changes: 9 additions & 0 deletions lib/contracts/IThriftClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import TCLIService from '../../thrift/TCLIService';

type ThriftClient = TCLIService.Client;

type ThriftClientMethods = {
[K in keyof ThriftClient]: ThriftClient[K];
};

export default interface IThriftClient extends ThriftClientMethods {}
7 changes: 3 additions & 4 deletions lib/hive/Commands/BaseCommand.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { Response } from 'node-fetch';
import TCLIService from '../../../thrift/TCLIService';
import HiveDriverError from '../../errors/HiveDriverError';
import RetryError, { RetryErrorCode } from '../../errors/RetryError';
import IClientContext from '../../contracts/IClientContext';

export default abstract class BaseCommand {
protected client: TCLIService.Client;
export default abstract class BaseCommand<ClientType> {
protected client: ClientType;

protected context: IClientContext;

constructor(client: TCLIService.Client, context: IClientContext) {
constructor(client: ClientType, context: IClientContext) {
this.client = client;
this.context = context;
}
Expand Down
5 changes: 4 additions & 1 deletion lib/hive/Commands/CancelDelegationTokenCommand.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import BaseCommand from './BaseCommand';
import { TCancelDelegationTokenReq, TCancelDelegationTokenResp } from '../../../thrift/TCLIService_types';
import IThriftClient from '../../contracts/IThriftClient';

export default class CancelDelegationTokenCommand extends BaseCommand {
type Client = Pick<IThriftClient, 'CancelDelegationToken'>;

export default class CancelDelegationTokenCommand extends BaseCommand<Client> {
execute(data: TCancelDelegationTokenReq): Promise<TCancelDelegationTokenResp> {
const request = new TCancelDelegationTokenReq(data);

Expand Down
Loading
Loading