Skip to content

Commit

Permalink
Merge pull request #3 from gabrielh-silvestre/feature/add-flexibility
Browse files Browse the repository at this point in the history
Feature/add flexibility
  • Loading branch information
gabrielh-silvestre authored Jan 12, 2024
2 parents 979bd45 + 6c47d70 commit 1f2fd71
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 47 deletions.
7 changes: 0 additions & 7 deletions __tests__/bun/agents/base-agent.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
OUTPUT_TOOL_RESPONSE,
RETRIEVE_RUN_RESPONSE,
mockFunction,
mockOpenAI,
mockOpenAIRun,
} from '__tests__/mocks/openai-mock';
import { Run } from 'openai/resources/beta/threads/runs/runs';
Expand Down Expand Up @@ -52,7 +51,6 @@ const { id: runId, thread_id: threadId } = CREATE_AND_RUN_RESPONSE;
describe('[Unit] Tests for AgentOpenAI', () => {
let agent: TestAgent;

const mockedOpenAI = mockOpenAI();
const mockedOpenAIRun = mockOpenAIRun();
const mockedFunction = mockFunction();

Expand All @@ -68,7 +66,6 @@ describe('[Unit] Tests for AgentOpenAI', () => {
beforeEach(() => {
agent = new TestAgent({
agentId: 'agent-123',
openai: mockedOpenAI as any,
functions: [],
});

Expand All @@ -95,7 +92,6 @@ describe('[Unit] Tests for AgentOpenAI', () => {
['agentId', 'empty', ''],
['agentId', 'null', null],
['agentId', 'undefined', undefined],
['openai', 'empty', {}],
['poolingInterval', 'zero', 0],
[
'functions',
Expand All @@ -106,7 +102,6 @@ describe('[Unit] Tests for AgentOpenAI', () => {
try {
const options = {
agentId: 'agent-123',
openai: mockOpenAI() as any,
[prop]: val,
};

Expand Down Expand Up @@ -189,7 +184,6 @@ describe('[Unit] Tests for AgentOpenAI', () => {
it('should execute a function', async () => {
const localAgent = new TestAgent({
agentId: 'agent-123',
openai: mockedOpenAI as any,
functions: [mockedFunction],
});

Expand Down Expand Up @@ -254,7 +248,6 @@ describe('[Unit] Tests for AgentOpenAI', () => {

expect(spyOpenaiCreateAndRun).toHaveBeenCalled();
expect(spyOpenaiRetrieveRun).toHaveBeenCalled();
expect(spyOpenaiListMessages).toHaveBeenCalled();
}
);
});
11 changes: 3 additions & 8 deletions __tests__/bun/agents/sns-functionts.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import { GuardError } from 'src/errors/guard-error';

import { MockSnsHandler } from '__tests__/mocks/openai-mock';

const TOPIC_ARN = 'arn:aws:sns:us-east-1:000000000000:TestTopic';

class TestSnsFunction extends SnsPublishFunction {
constructor() {
super({
Expand All @@ -26,10 +28,7 @@ class TestSnsFunction extends SnsPublishFunction {
},
},

sns: {
handler: MockSnsHandler as any,
topicArn: 'arn:aws:sns:us-east-1:000000000000:TestTopic',
},
sns: { topicArn: TOPIC_ARN },
});
}
}
Expand Down Expand Up @@ -75,7 +74,6 @@ describe('[Unit] Tests for SnsPublishFunction', () => {
},

sns: {
handler: MockSnsHandler as any,
topicArn: 'arn:aws:sns:us-east-1:000000000000:TestTopic',
},
});
Expand All @@ -90,9 +88,6 @@ describe('[Unit] Tests for SnsPublishFunction', () => {
['sns', 'empty', {}],
['sns', 'null', null],
['sns', 'undefined', undefined],
['sns.handler', 'empty', { handler: {} }],
['sns.handler', 'null', { handler: null }],
['sns.handler', 'undefined', { handler: undefined }],
['sns.topicArn', 'empty', { topicArn: '' }],
['sns.topicArn', 'null', { topicArn: null }],
['sns.topicArn', 'undefined', { topicArn: undefined }],
Expand Down
21 changes: 17 additions & 4 deletions __tests__/bun/setup.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import { afterAll, beforeEach, jest } from 'bun:test';
import { afterEach, beforeEach, jest, mock } from 'bun:test';

import SNS from '@aws-sdk/client-sns';
import OpenAI from 'openai';

import { MockSnsHandler, mockOpenAI } from '__tests__/mocks/openai-mock';

beforeEach(() => {
jest.restoreAllMocks();
mock.module('@aws-sdk/client-sns', () => ({
...SNS,
SNSClient: mock(() => MockSnsHandler),
}));

mock.module('openai', () => ({
...OpenAI,
OpenAI: mock(() => mockOpenAI()),
}));
});

afterAll(() => {
afterEach(() => {
jest.restoreAllMocks();
});
})
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { describe, expect, it } from 'bun:test';
import { GuardError } from 'src/errors/guard-error';
import { InternalError } from 'src/errors';
import { Validator } from 'src/utils/validator';

describe('[Unit] Tests for InternalError', () => {
describe('[Unit] Tests for Validator', () => {
it.each([
['condition is falsy', false],
['condition is null', null],
Expand All @@ -11,7 +11,7 @@ describe('[Unit] Tests for InternalError', () => {
['condition is empty', ''],
])('should throw an error when "%s"', (_, arg) => {
try {
InternalError.guard(arg, 'message');
Validator.guard(arg, 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand All @@ -23,7 +23,7 @@ describe('[Unit] Tests for InternalError', () => {
['undefined', undefined],
])('should throw an error when value is %s', (_, arg) => {
try {
InternalError.notNull(arg, 'message');
Validator.notNull(arg, 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand All @@ -35,7 +35,7 @@ describe('[Unit] Tests for InternalError', () => {
['blank', ' '],
])('should throw an error when value is %s', (_, arg) => {
try {
InternalError.notEmpty(arg, 'message');
Validator.notEmpty(arg, 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand All @@ -44,7 +44,7 @@ describe('[Unit] Tests for InternalError', () => {

it('should throw an error when value array is empty', () => {
try {
InternalError.notEmptyArray([], 'message');
Validator.notEmptyArray([], 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand All @@ -53,7 +53,7 @@ describe('[Unit] Tests for InternalError', () => {

it('should throw an error when value object is empty', () => {
try {
InternalError.notEmptyObject({}, 'message');
Validator.notEmptyObject({}, 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand All @@ -65,7 +65,7 @@ describe('[Unit] Tests for InternalError', () => {
['string is duplicated', ['a', 'a']],
])('should throw an error when %s', (_, arg) => {
try {
InternalError.duplicatedArray<any>(arg, 'message');
Validator.duplicatedArray<any>(arg, 'message');
expect().fail('should throw an error');
} catch (error) {
expect(error).toBeInstanceOf(GuardError);
Expand Down
14 changes: 7 additions & 7 deletions src/agents/base-agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { Run } from 'openai/resources/beta/threads/runs/runs';

import { Agent, AgentOptions, AgentProps } from '../types/agent';
import { AgentFunction } from './function';
import { InternalError } from 'src/errors';
import { Validator } from '../utils/validator';

export class AgentOpenAI implements Agent {
private currThreadId: string | null = null;
Expand All @@ -20,11 +20,11 @@ export class AgentOpenAI implements Agent {
private static guardProps(props: AgentProps) {
const { agentId, poolingInterval } = props;

InternalError.notEmpty(agentId, 'Agent ID is required');
InternalError.notEmptyObject(props.openai, 'OpenAI instance is required');
Validator.notEmpty(agentId, 'Agent ID is required');
Validator.notEmptyObject(props.openai, 'OpenAI instance is required');

const isPoolingIntervalValid = !!poolingInterval && poolingInterval > 500;
InternalError.guard(
Validator.guard(
isPoolingIntervalValid,
'Pooling interval must be greater than 500ms'
);
Expand All @@ -34,15 +34,15 @@ export class AgentOpenAI implements Agent {
const hasAnyFunction = functions?.length > 0;
if (!hasAnyFunction) return;

InternalError.duplicatedArray(
Validator.duplicatedArray(
functions?.map((fn) => fn.name),
'Functions must have unique names'
);

const invalidFunction = functions?.find(
(fn) => !(fn instanceof AgentFunction)
);
InternalError.guard(
Validator.guard(
!invalidFunction,
`Function "${invalidFunction?.name}" is invalid`
);
Expand All @@ -53,7 +53,7 @@ export class AgentOpenAI implements Agent {
this._props.log = opts.log ?? this._props.log;
this._props.poolingInterval =
opts.poolingInterval ?? this._props.poolingInterval;
this._props.openai = opts.openai ?? this._props.openai;
this._props.openai = new OpenAI(opts.openai) ?? this._props.openai;

this.guardFunctions(opts);

Expand Down
9 changes: 4 additions & 5 deletions src/agents/sns-function.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import { SNSClient, PublishCommand } from '@aws-sdk/client-sns';

import { AgentFunction } from './function';
import { InternalError } from '../errors';
import { Validator } from '../utils/validator';
import { SnsPublishFunctionOptions } from '../types/function';

export class SnsPublishFunction extends AgentFunction {
protected handler: SNSClient;
protected topicArn: string;

private static guardSnsOptions({ sns }: SnsPublishFunctionOptions) {
InternalError.notEmptyObject(sns?.handler, 'SNS handler is required');
InternalError.notEmpty(sns?.topicArn, 'SNS topic ARN is required');
Validator.notEmpty(sns?.topicArn, 'SNS topic ARN is required');
}

private static guardFunctionName(name: string) {
InternalError.guard(
Validator.guard(
name.startsWith('cloud.'),
'Function name must start with "cloud."'
);
Expand All @@ -26,7 +25,7 @@ export class SnsPublishFunction extends AgentFunction {

super(opts);

this.handler = opts.sns.handler;
this.handler = new SNSClient(opts.sns.client);
this.topicArn = opts.sns.topicArn;
}

Expand Down
11 changes: 8 additions & 3 deletions src/types/agent.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import OpenAI from 'openai';
import { ClientOptions, OpenAI } from 'openai';
import { AgentFunction } from '..';

export type Agent = {
Expand All @@ -10,14 +10,19 @@ export type RequiredAgentOptions = {
};

export type OptionalAgentOptions = {
openai?: OpenAI;
/**
* The OpenAI instance configuration.
* @see https://github.com/openai/openai-node?tab=readme-ov-file#usage
*/
openai?: ClientOptions;
functions?: AgentFunction[];
log?: boolean;
poolingInterval?: number; // in milliseconds
};

export type AgentOptions = RequiredAgentOptions & OptionalAgentOptions;

export type AgentProps = Omit<Required<AgentOptions>, 'functions'> & {
export type AgentProps = Omit<Required<AgentOptions>, 'functions' | 'openai'> & {
functions: Map<string, AgentFunction>;
openai: OpenAI;
};
7 changes: 4 additions & 3 deletions src/types/function.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { SNSClient } from '@aws-sdk/client-sns';
import { SNSClientConfig } from '@aws-sdk/client-sns';

export type IFunction = {
execute(args: object): Promise<any>;
Expand Down Expand Up @@ -78,9 +78,10 @@ export type OpenaiFunctionSchema = {
export type SnsPublishFunctionOptions = FunctionOptions & {
sns: {
/**
* The handler function for the SNS publish.
* The SNS client configuration.
* @see https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/clients/client-sns/interfaces/snsclientconfig.html
*/
handler: SNSClient;
client?: SNSClientConfig;
/**
* The ARN (Amazon Resource Name) of the SNS topic.
*/
Expand Down
2 changes: 2 additions & 0 deletions src/utils/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
export * from './validator';

export const deepClone = <T = any>(obj: object): T =>
JSON.parse(JSON.stringify(obj));
4 changes: 2 additions & 2 deletions src/errors/index.ts → src/utils/validator.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { GuardError } from './guard-error';
import { GuardError } from '../errors/guard-error';

/**
* Represents an internal error of the library.
*/
export class InternalError {
export class Validator {
/**
* Guards against a specified condition and throws a `GuardError` if the condition is not met.
* @param condition - The condition to be checked.
Expand Down

0 comments on commit 1f2fd71

Please sign in to comment.