Skip to content

Commit

Permalink
Convert unit tests to Typescript (#258)
Browse files Browse the repository at this point in the history
* Convert unit tests to Typescript

Signed-off-by: Levko Kravets <[email protected]>

* Polish & cleanup

Signed-off-by: Levko Kravets <[email protected]>

* Polish & cleanup

Signed-off-by: Levko Kravets <[email protected]>

---------

Signed-off-by: Levko Kravets <[email protected]>
  • Loading branch information
kravets-levko authored May 27, 2024
1 parent 3eed509 commit 3c29fe2
Show file tree
Hide file tree
Showing 142 changed files with 6,089 additions and 5,851 deletions.
21 changes: 14 additions & 7 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import IThriftClient from './contracts/IThriftClient';
import HiveDriver from './hive/HiveDriver';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
Expand Down Expand Up @@ -43,6 +44,8 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export type ThriftLibrary = Pick<typeof thrift, 'createClient'>;

export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private static defaultLogger?: IDBSQLLogger;

Expand All @@ -52,17 +55,17 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I

private authProvider?: IAuthentication;

private client?: TCLIService.Client;
private client?: IThriftClient;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
private thrift: ThriftLibrary = thrift;

private sessions = new CloseableCollection<DBSQLSession>();
private readonly sessions = new CloseableCollection<DBSQLSession>();

private static getDefaultLogger(): IDBSQLLogger {
if (!this.defaultLogger) {
Expand Down Expand Up @@ -113,7 +116,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
};
}

private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
private createAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand Down Expand Up @@ -143,6 +146,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
}
}

private createConnectionProvider(options: ConnectionOptions): IConnectionProvider {
return new HttpConnection(this.getConnectionOptions(options), this);
}

/**
* Connects DBSQLClient to endpoint
* @public
Expand All @@ -153,9 +160,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.initAuthProvider(options, authProvider);
this.authProvider = this.createAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options), this);
this.connectionProvider = this.createConnectionProvider(options);

const thriftConnection = await this.connectionProvider.getThriftConnection();

Expand Down Expand Up @@ -238,7 +245,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
public async getClient(): Promise<IThriftClient> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
Expand Down
2 changes: 1 addition & 1 deletion lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ export default class DBSQLOperation implements IOperation {

private metadata?: TGetResultSetMetadataResp;

private state: number = TOperationState.INITIALIZED_STATE;
private state: TOperationState = TOperationState.INITIALIZED_STATE;

// Once operation is finished or fails - cache status response, because subsequent calls
// to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404
Expand Down
110 changes: 60 additions & 50 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,19 @@ import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export type DefaultOpenAuthUrlCallback = (authUrl: string) => Promise<void>;

export type OpenAuthUrlCallback = (authUrl: string, defaultOpenAuthUrl: DefaultOpenAuthUrlCallback) => Promise<void>;

export interface AuthorizationCodeOptions {
client: BaseClient;
ports: Array<number>;
context: IClientContext;
openAuthUrl?: OpenAuthUrlCallback;
}

async function startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = http.createServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

async function stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
async function defaultOpenAuthUrl(authUrl: string): Promise<void> {
await open(authUrl);
}

export interface AuthorizationCodeFetchResult {
Expand All @@ -65,16 +34,12 @@ export default class AuthorizationCode {

private readonly host: string = 'localhost';

private readonly ports: Array<number>;
private readonly options: AuthorizationCodeOptions;

constructor(options: AuthorizationCodeOptions) {
this.client = options.client;
this.ports = options.ports;
this.context = options.context;
}

private async openUrl(url: string) {
return open(url);
this.options = options;
}

public async fetch(scopes: OAuthScopes): Promise<AuthorizationCodeFetchResult> {
Expand All @@ -84,7 +49,7 @@ export default class AuthorizationCode {

let receivedParams: CallbackParamsType | undefined;

const server = await this.startServer((req, res) => {
const server = await this.createServer((req, res) => {
const params = this.client.callbackParams(req);
if (params.state === state) {
receivedParams = params;
Expand All @@ -108,7 +73,8 @@ export default class AuthorizationCode {
redirect_uri: redirectUri,
});

await this.openUrl(authUrl);
const openAuthUrl = this.options.openAuthUrl ?? defaultOpenAuthUrl;
await openAuthUrl(authUrl, defaultOpenAuthUrl);
await server.stopped();

if (!receivedParams || !receivedParams.code) {
Expand All @@ -122,11 +88,11 @@ export default class AuthorizationCode {
return { code: receivedParams.code, verifier: verifierString, redirectUri };
}

private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.ports) {
private async createServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.options.ports) {
const host = this.host; // eslint-disable-line prefer-destructuring
try {
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
const server = await this.startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
this.context.getLogger().log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);

let resolveStopped: () => void;
Expand All @@ -140,7 +106,7 @@ export default class AuthorizationCode {
host,
port,
server,
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
stop: () => this.stopServer(server).then(resolveStopped).catch(rejectStopped),
stopped: () => stoppedPromise,
};
} catch (error) {
Expand All @@ -156,6 +122,50 @@ export default class AuthorizationCode {
throw new AuthenticationError('Failed to start server: all ports are in use');
}

private createHttpServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
return http.createServer(requestHandler);
}

private async startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = this.createHttpServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

private async stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
}

private renderCallbackResponse(): string {
const applicationName = 'Databricks Sql Connector';

Expand Down
16 changes: 11 additions & 5 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
export interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
headers?: HeadersInit;
Expand All @@ -18,14 +18,13 @@ export default class DatabricksOAuth implements IAuthentication {

private readonly options: DatabricksOAuthOptions;

private readonly manager: OAuthManager;
private manager?: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.context = options.context;
this.options = options;
this.manager = OAuthManager.getManager(this.options);
}

public async authenticate(): Promise<HeadersInit> {
Expand All @@ -35,15 +34,22 @@ export default class DatabricksOAuth implements IAuthentication {

let token = await persistence.read(host);
if (!token) {
token = await this.manager.getToken(scopes ?? defaultOAuthScopes);
token = await this.getManager().getToken(scopes ?? defaultOAuthScopes);
}

token = await this.manager.refreshAccessToken(token);
token = await this.getManager().refreshAccessToken(token);
await persistence.persist(host, token);

return {
...headers,
Authorization: `Bearer ${token.accessToken}`,
};
}

private getManager(): OAuthManager {
if (!this.manager) {
this.manager = OAuthManager.getManager(this.options);
}
return this.manager;
}
}
2 changes: 1 addition & 1 deletion lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ export default class HttpConnection implements IConnectionProvider {
});
}

public async getAgent(): Promise<http.Agent> {
public async getAgent(): Promise<http.Agent | undefined> {
if (!this.agent) {
if (this.options.proxy !== undefined) {
this.agent = this.createProxyAgent(this.options.proxy);
Expand Down
2 changes: 1 addition & 1 deletion lib/connection/connections/HttpRetryPolicy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function delay(milliseconds: number): Promise<void> {
export default class HttpRetryPolicy implements IRetryPolicy<HttpTransactionDetails> {
private context: IClientContext;

private readonly startTime: number; // in milliseconds
private startTime: number; // in milliseconds

private attempt: number;

Expand Down
2 changes: 1 addition & 1 deletion lib/connection/contracts/IConnectionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export interface HttpTransactionDetails {
export default interface IConnectionProvider {
getThriftConnection(): Promise<any>;

getAgent(): Promise<http.Agent>;
getAgent(): Promise<http.Agent | undefined>;

setHeaders(headers: HeadersInit): void;

Expand Down
4 changes: 2 additions & 2 deletions lib/contracts/IClientContext.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import IDBSQLLogger from './IDBSQLLogger';
import IDriver from './IDriver';
import IConnectionProvider from '../connection/contracts/IConnectionProvider';
import TCLIService from '../../thrift/TCLIService';
import IThriftClient from './IThriftClient';

export interface ClientConfig {
directResultsDefaultMaxRows: number;
Expand Down Expand Up @@ -29,7 +29,7 @@ export default interface IClientContext {

getConnectionProvider(): Promise<IConnectionProvider>;

getClient(): Promise<TCLIService.Client>;
getClient(): Promise<IThriftClient>;

getDriver(): Promise<IDriver>;
}
9 changes: 9 additions & 0 deletions lib/contracts/IThriftClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import TCLIService from '../../thrift/TCLIService';

type ThriftClient = TCLIService.Client;

type ThriftClientMethods = {
[K in keyof ThriftClient]: ThriftClient[K];
};

export default interface IThriftClient extends ThriftClientMethods {}
7 changes: 3 additions & 4 deletions lib/hive/Commands/BaseCommand.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import { Response } from 'node-fetch';
import TCLIService from '../../../thrift/TCLIService';
import HiveDriverError from '../../errors/HiveDriverError';
import RetryError, { RetryErrorCode } from '../../errors/RetryError';
import IClientContext from '../../contracts/IClientContext';

export default abstract class BaseCommand {
protected client: TCLIService.Client;
export default abstract class BaseCommand<ClientType> {
protected client: ClientType;

protected context: IClientContext;

constructor(client: TCLIService.Client, context: IClientContext) {
constructor(client: ClientType, context: IClientContext) {
this.client = client;
this.context = context;
}
Expand Down
5 changes: 4 additions & 1 deletion lib/hive/Commands/CancelDelegationTokenCommand.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import BaseCommand from './BaseCommand';
import { TCancelDelegationTokenReq, TCancelDelegationTokenResp } from '../../../thrift/TCLIService_types';
import IThriftClient from '../../contracts/IThriftClient';

export default class CancelDelegationTokenCommand extends BaseCommand {
type Client = Pick<IThriftClient, 'CancelDelegationToken'>;

export default class CancelDelegationTokenCommand extends BaseCommand<Client> {
execute(data: TCancelDelegationTokenReq): Promise<TCancelDelegationTokenResp> {
const request = new TCancelDelegationTokenReq(data);

Expand Down
Loading

0 comments on commit 3c29fe2

Please sign in to comment.