Skip to content

Commit

Permalink
Convert unit tests to Typescript
Browse files Browse the repository at this point in the history
Signed-off-by: Levko Kravets <[email protected]>
  • Loading branch information
kravets-levko committed May 20, 2024
1 parent 3eed509 commit d178684
Show file tree
Hide file tree
Showing 149 changed files with 6,718 additions and 6,303 deletions.
27 changes: 17 additions & 10 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext, { ClientConfig } from './contracts/IClientContext';
import IThriftClient from './contracts/IThriftClient';
import HiveDriver from './hive/HiveDriver';
import DBSQLSession from './DBSQLSession';
import IDBSQLSession from './contracts/IDBSQLSession';
Expand Down Expand Up @@ -43,26 +44,28 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export type ThriftLibrary = Pick<typeof thrift, 'createClient'>;

export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private static defaultLogger?: IDBSQLLogger;

private readonly config: ClientConfig;

private connectionProvider?: IConnectionProvider;
protected connectionProvider?: IConnectionProvider;

private authProvider?: IAuthentication;
protected authProvider?: IAuthentication;

private client?: TCLIService.Client;
protected client?: IThriftClient;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
protected thrift: ThriftLibrary = thrift;

private sessions = new CloseableCollection<DBSQLSession>();
protected sessions = new CloseableCollection<DBSQLSession>();

private static getDefaultLogger(): IDBSQLLogger {
if (!this.defaultLogger) {
Expand Down Expand Up @@ -99,7 +102,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
this.logger.log(LogLevel.info, 'Created DBSQLClient');
}

private getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
protected getConnectionOptions(options: ConnectionOptions): IConnectionOptions {
return {
host: options.host,
port: options.port || 443,
Expand All @@ -113,7 +116,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
};
}

private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
protected createAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand Down Expand Up @@ -143,6 +146,10 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
}
}

protected createConnectionProvider(options: ConnectionOptions): IConnectionProvider {
return new HttpConnection(this.getConnectionOptions(options), this);
}

/**
* Connects DBSQLClient to endpoint
* @public
Expand All @@ -153,9 +160,9 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.initAuthProvider(options, authProvider);
this.authProvider = this.createAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options), this);
this.connectionProvider = this.createConnectionProvider(options);

const thriftConnection = await this.connectionProvider.getThriftConnection();

Expand Down Expand Up @@ -238,7 +245,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
public async getClient(): Promise<IThriftClient> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
Expand Down
20 changes: 10 additions & 10 deletions lib/DBSQLOperation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,23 @@ async function delay(ms?: number): Promise<void> {
}

export default class DBSQLOperation implements IOperation {
private readonly context: IClientContext;
protected readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;
protected readonly operationHandle: TOperationHandle;

public onClose?: () => void;
protected readonly _data: RowSetProvider;

private readonly _data: RowSetProvider;
protected readonly closeOperation?: TCloseOperationResp;

private readonly closeOperation?: TCloseOperationResp;
protected closed: boolean = false;

private closed: boolean = false;
protected cancelled: boolean = false;

private cancelled: boolean = false;
protected metadata?: TGetResultSetMetadataResp;

private metadata?: TGetResultSetMetadataResp;
protected state: number = TOperationState.INITIALIZED_STATE;

private state: number = TOperationState.INITIALIZED_STATE;
public onClose?: () => void;

// Once operation is finished or fails - cache status response, because subsequent calls
// to `getOperationStatus()` may fail with irrelevant errors, e.g. HTTP 404
Expand Down Expand Up @@ -376,7 +376,7 @@ export default class DBSQLOperation implements IOperation {
return this.metadata;
}

private async getResultHandler(): Promise<ResultSlicer<any>> {
protected async getResultHandler(): Promise<ResultSlicer<any>> {
const metadata = await this.fetchMetadata();
const resultFormat = definedOrError(metadata.resultFormat);

Expand Down
6 changes: 3 additions & 3 deletions lib/DBSQLSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ export default class DBSQLSession implements IDBSQLSession {

private readonly sessionHandle: TSessionHandle;

private isOpen = true;
protected isOpen = true;

public onClose?: () => void;
protected operations = new CloseableCollection<DBSQLOperation>();

private operations = new CloseableCollection<DBSQLOperation>();
public onClose?: () => void;

constructor({ handle, context }: DBSQLSessionConstructorOptions) {
this.sessionHandle = handle;
Expand Down
113 changes: 65 additions & 48 deletions lib/connection/auth/DatabricksOAuth/AuthorizationCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,26 @@ import { OAuthScopes, scopeDelimiter } from './OAuthScope';
import IClientContext from '../../../contracts/IClientContext';
import AuthenticationError from '../../../errors/AuthenticationError';

export type DefaultOpenAuthUrlFunction = (authUrl: string) => Promise<void>;

export type CustomOpenAuthUrlFunction = (
authUrl: string,
defaultOpenAuthUrl: DefaultOpenAuthUrlFunction,
) => Promise<void>;

export interface AuthorizationCodeOptions {
client: BaseClient;
ports: Array<number>;
context: IClientContext;
openAuthUrl?: CustomOpenAuthUrlFunction;
}

async function startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = http.createServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
async function defaultOpenAuthUrl(authUrl: string) {
await open(authUrl);
}

async function stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
async function openAuthUrl(authUrl: string, defaultOpenUrl: DefaultOpenAuthUrlFunction) {
return defaultOpenUrl(authUrl);
}

export interface AuthorizationCodeFetchResult {
Expand All @@ -65,16 +41,12 @@ export default class AuthorizationCode {

private readonly host: string = 'localhost';

private readonly ports: Array<number>;
private readonly options: AuthorizationCodeOptions;

constructor(options: AuthorizationCodeOptions) {
this.client = options.client;
this.ports = options.ports;
this.context = options.context;
}

private async openUrl(url: string) {
return open(url);
this.options = options;
}

public async fetch(scopes: OAuthScopes): Promise<AuthorizationCodeFetchResult> {
Expand All @@ -84,7 +56,7 @@ export default class AuthorizationCode {

let receivedParams: CallbackParamsType | undefined;

const server = await this.startServer((req, res) => {
const server = await this.createServer((req, res) => {
const params = this.client.callbackParams(req);
if (params.state === state) {
receivedParams = params;
Expand All @@ -108,7 +80,8 @@ export default class AuthorizationCode {
redirect_uri: redirectUri,
});

await this.openUrl(authUrl);
const openUrl = this.options.openAuthUrl ?? openAuthUrl;
await openUrl(authUrl, defaultOpenAuthUrl);
await server.stopped();

if (!receivedParams || !receivedParams.code) {
Expand All @@ -122,11 +95,11 @@ export default class AuthorizationCode {
return { code: receivedParams.code, verifier: verifierString, redirectUri };
}

private async startServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.ports) {
private async createServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
for (const port of this.options.ports) {
const host = this.host; // eslint-disable-line prefer-destructuring
try {
const server = await startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
const server = await this.startServer(host, port, requestHandler); // eslint-disable-line no-await-in-loop
this.context.getLogger().log(LogLevel.info, `Listening for OAuth authorization callback at ${host}:${port}`);

let resolveStopped: () => void;
Expand All @@ -140,7 +113,7 @@ export default class AuthorizationCode {
host,
port,
server,
stop: () => stopServer(server).then(resolveStopped).catch(rejectStopped),
stop: () => this.stopServer(server).then(resolveStopped).catch(rejectStopped),
stopped: () => stoppedPromise,
};
} catch (error) {
Expand All @@ -156,6 +129,50 @@ export default class AuthorizationCode {
throw new AuthenticationError('Failed to start server: all ports are in use');
}

protected createHttpServer(requestHandler: (req: IncomingMessage, res: ServerResponse) => void) {
return http.createServer(requestHandler);
}

private async startServer(
host: string,
port: number,
requestHandler: (req: IncomingMessage, res: ServerResponse) => void,
): Promise<Server> {
const server = this.createHttpServer(requestHandler);

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.listen(port, host, () => {
server.off('error', errorListener);
resolve(server);
});
});
}

private async stopServer(server: Server): Promise<void> {
if (!server.listening) {
return;
}

return new Promise((resolve, reject) => {
const errorListener = (error: Error) => {
server.off('error', errorListener);
reject(error);
};

server.on('error', errorListener);
server.close(() => {
server.off('error', errorListener);
resolve();
});
});
}

private renderCallbackResponse(): string {
const applicationName = 'Databricks Sql Connector';

Expand Down
12 changes: 8 additions & 4 deletions lib/connection/auth/DatabricksOAuth/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import IClientContext from '../../../contracts/IClientContext';

export { OAuthFlow };

interface DatabricksOAuthOptions extends OAuthManagerOptions {
export interface DatabricksOAuthOptions extends OAuthManagerOptions {
scopes?: OAuthScopes;
persistence?: OAuthPersistence;
headers?: HeadersInit;
Expand All @@ -16,16 +16,16 @@ interface DatabricksOAuthOptions extends OAuthManagerOptions {
export default class DatabricksOAuth implements IAuthentication {
private readonly context: IClientContext;

private readonly options: DatabricksOAuthOptions;
protected readonly options: DatabricksOAuthOptions;

private readonly manager: OAuthManager;
protected readonly manager: OAuthManager;

private readonly defaultPersistence = new OAuthPersistenceCache();

constructor(options: DatabricksOAuthOptions) {
this.context = options.context;
this.options = options;
this.manager = OAuthManager.getManager(this.options);
this.manager = this.createManager(this.options);
}

public async authenticate(): Promise<HeadersInit> {
Expand All @@ -46,4 +46,8 @@ export default class DatabricksOAuth implements IAuthentication {
Authorization: `Bearer ${token.accessToken}`,
};
}

protected createManager(options: OAuthManagerOptions): OAuthManager {
return OAuthManager.getManager(options);
}
}
8 changes: 4 additions & 4 deletions lib/connection/auth/PlainHttpAuthentication.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ interface PlainHttpAuthenticationOptions {
}

export default class PlainHttpAuthentication implements IAuthentication {
private readonly context: IClientContext;
protected readonly context: IClientContext;

private readonly username: string;
protected readonly username: string;

private readonly password: string;
protected readonly password: string;

private readonly headers: HeadersInit;
protected readonly headers: HeadersInit;

constructor(options: PlainHttpAuthenticationOptions) {
this.context = options.context;
Expand Down
4 changes: 2 additions & 2 deletions lib/connection/connections/HttpConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export default class HttpConnection implements IConnectionProvider {

private readonly context: IClientContext;

private headers: HeadersInit = {};
protected headers: HeadersInit = {};

private connection?: ThriftHttpConnection;

Expand All @@ -36,7 +36,7 @@ export default class HttpConnection implements IConnectionProvider {
});
}

public async getAgent(): Promise<http.Agent> {
public async getAgent(): Promise<http.Agent | undefined> {
if (!this.agent) {
if (this.options.proxy !== undefined) {
this.agent = this.createProxyAgent(this.options.proxy);
Expand Down
Loading

0 comments on commit d178684

Please sign in to comment.