Skip to content

Commit

Permalink
Merge pull request #53 from boldare/fix/run-resolver
Browse files Browse the repository at this point in the history
Fix: run resolver for multiple function callings
  • Loading branch information
sebastianmusial authored Apr 11, 2024
2 parents 65d0afb + 2cbdffa commit 70683f2
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 76 deletions.
1 change: 1 addition & 0 deletions .env.dist
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
OPENAI_API_KEY=
# Assistant ID - leave it empty if you don't have an assistant yet
ASSISTANT_ID=
ASSISTANT_IS_LOGGER_ENABLED=

# Agents:
# -------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion apps/api/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async function bootstrap() {
const globalPrefix = 'api';
const config = new DocumentBuilder()
.setTitle('@boldare/openai-assistant')
.setVersion('1.0.1')
.setVersion('1.0.2')
.build();
const document = SwaggerModule.createDocument(app, config);

Expand Down
2 changes: 1 addition & 1 deletion libs/openai-assistant/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@boldare/openai-assistant",
"description": "NestJS library for building chatbot solutions based on the OpenAI Assistant API",
"version": "1.0.1",
"version": "1.0.2",
"private": false,
"dependencies": {
"tslib": "^2.3.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ describe('AssistantService', () => {
jest
.spyOn(aiService.provider.beta.assistants, 'update')
.mockRejectedValueOnce('error');
jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);

await assistantService.init();

Expand All @@ -97,7 +97,7 @@ describe('AssistantService', () => {
.spyOn(configService, 'get')
.mockReturnValue({ ...assistantConfigMock, id: '' });

jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);

await assistantService.init();

Expand Down
9 changes: 6 additions & 3 deletions libs/openai-assistant/src/lib/assistant/assistant.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export class AssistantService {
};
}

async init(): Promise<void> {
async init(): Promise<Assistant> {
const { id, options } = this.assistantConfig.get();

if (!id) {
Expand All @@ -43,16 +43,17 @@ export class AssistantService {
this.getParams(),
options,
);
return this.assistant;
} catch (e) {
await this.create();
return await this.create();
}
}

async update(params: Partial<AssistantCreateParams>): Promise<void> {
this.assistant = await this.assistants.update(this.assistant.id, params);
}

async create(): Promise<void> {
async create(): Promise<Assistant> {
const { options } = this.assistantConfig.get();
const params = this.getParams();
this.assistant = await this.assistants.create(params, options);
Expand All @@ -63,6 +64,8 @@ export class AssistantService {

this.logger.log(`Created new assistant (${this.assistant.id})`);
await this.assistantMemoryService.saveAssistantId(this.assistant.id);

return this.assistant;
}

async updateFiles(fileNames?: string[]): Promise<Assistant> {
Expand Down
96 changes: 54 additions & 42 deletions libs/openai-assistant/src/lib/chat/chat.gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,54 @@ export class ChatGateway implements OnGatewayConnection {
this.logger = new Logger(ChatGateway.name);
}

log(message: string): void {
try {
const isLoggerEnabled: string = JSON.parse(
(process.env['ASSISTANT_IS_LOGGER_ENABLED'] || 'false').toLowerCase(),
);

if (isLoggerEnabled) {
this.logger.log(message);
}
} catch (error) {
this.logger.error('"ASSISTANT_IS_LOGGER_ENABLED" should be boolean');
}
}

async handleConnection() {
this.logger.log('Client connected');
this.log('Client connected');
}

getCallbacks(socketId: string): ChatCallCallbacks {
return {
[ChatEvents.MessageCreated]: this.emitMessageCreated.bind(this, socketId),
[ChatEvents.MessageDelta]: this.emitMessageDelta.bind(this, socketId),
[ChatEvents.MessageDone]: this.emitMessageDone.bind(this, socketId),
[ChatEvents.TextCreated]: this.emitTextCreated.bind(this, socketId),
[ChatEvents.TextDelta]: this.emitTextDelta.bind(this, socketId),
[ChatEvents.TextDone]: this.emitTextDone.bind(this, socketId),
[ChatEvents.MessageCreated]: eventData =>
this.emitMessageCreated(socketId, eventData),
[ChatEvents.MessageDelta]: eventData =>
this.emitMessageDelta(socketId, eventData),
[ChatEvents.MessageDone]: eventData =>
this.emitMessageDone(socketId, eventData),
[ChatEvents.TextCreated]: eventData =>
this.emitTextCreated(socketId, eventData),
[ChatEvents.TextDelta]: eventData =>
this.emitTextDelta(socketId, eventData),
[ChatEvents.TextDone]: eventData =>
this.emitTextDone(socketId, eventData),
[ChatEvents.ToolCallCreated]: this.emitToolCallCreated.bind(
this,
socketId,
),
[ChatEvents.ToolCallDelta]: this.emitToolCallDelta.bind(this, socketId),
[ChatEvents.ToolCallDone]: this.emitToolCallDone.bind(this, socketId),
[ChatEvents.ImageFileDone]: this.emitImageFileDone.bind(this, socketId),
[ChatEvents.RunStepCreated]: this.emitRunStepCreated.bind(this, socketId),
[ChatEvents.RunStepDelta]: this.emitRunStepDelta.bind(this, socketId),
[ChatEvents.RunStepDone]: this.emitRunStepDone.bind(this, socketId),
[ChatEvents.ToolCallDelta]: eventData =>
this.emitToolCallDelta(socketId, eventData),
[ChatEvents.ToolCallDone]: eventData =>
this.emitToolCallDone(socketId, eventData),
[ChatEvents.ImageFileDone]: eventData =>
this.emitImageFileDone(socketId, eventData),
[ChatEvents.RunStepCreated]: eventData =>
this.emitRunStepCreated(socketId, eventData),
[ChatEvents.RunStepDelta]: eventData =>
this.emitRunStepDelta(socketId, eventData),
[ChatEvents.RunStepDone]: eventData =>
this.emitRunStepDone(socketId, eventData),
};
}

Expand All @@ -69,15 +95,15 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() request: ChatCallDto,
@ConnectedSocket() socket: Socket,
) {
this.logger.log(
this.log(
`Socket "${ChatEvents.CallStart}" | threadId ${request.threadId} | files: ${request?.file_ids?.join(', ')} | content: ${request.content}`,
);

const callbacks: ChatCallCallbacks = this.getCallbacks(socket.id);
const message = await this.chatsService.call(request, callbacks);

this.server?.to(socket.id).emit(ChatEvents.CallDone, message);
this.logger.log(
this.log(
`Socket "${ChatEvents.CallDone}" | threadId ${message.threadId} | content: ${message.content}`,
);
}
Expand All @@ -87,7 +113,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageCreated, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageCreated}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -97,7 +123,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageDeltaPayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageDelta, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageDelta}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -107,7 +133,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: MessageDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.MessageDone, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.MessageDone}" | threadId: ${data.message.thread_id}`,
);
}
Expand All @@ -117,19 +143,17 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: TextCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.TextCreated, data);
this.logger.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
this.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
}

async emitTextDelta(socketId: string, @MessageBody() data: TextDeltaPayload) {
this.server.to(socketId).emit(ChatEvents.TextDelta, data);
this.logger.log(
`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`,
);
this.log(`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`);
}

async emitTextDone(socketId: string, @MessageBody() data: TextDonePayload) {
this.server.to(socketId).emit(ChatEvents.TextDone, data);
this.logger.log(
this.log(
`Socket "${ChatEvents.TextDone}" | threadId: ${data.message?.thread_id} | ${data.text.value}`,
);
}
Expand All @@ -139,9 +163,7 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: ToolCallCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.ToolCallCreated, data);
this.logger.log(
`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`,
);
this.log(`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`);
}

codeInterpreterHandler(
Expand Down Expand Up @@ -185,9 +207,7 @@ export class ChatGateway implements OnGatewayConnection {
socketId: string,
@MessageBody() data: ToolCallDeltaPayload,
) {
this.logger.log(
`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`,
);
this.log(`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`);

switch (data.toolCallDelta.type) {
case 'code_interpreter':
Expand All @@ -211,46 +231,38 @@ export class ChatGateway implements OnGatewayConnection {
@MessageBody() data: ToolCallDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.ToolCallDone, data);
this.logger.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
this.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
}

async emitImageFileDone(
socketId: string,
@MessageBody() data: ImageFileDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.ImageFileDone, data);
this.logger.log(
`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`,
);
this.log(`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`);
}

async emitRunStepCreated(
socketId: string,
@MessageBody() data: RunStepCreatedPayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepCreated, data);
this.logger.log(
`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`);
}

async emitRunStepDelta(
socketId: string,
@MessageBody() data: RunStepDeltaPayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepDelta, data);
this.logger.log(
`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`);
}

async emitRunStepDone(
socketId: string,
@MessageBody() data: RunStepDonePayload,
) {
this.server.to(socketId).emit(ChatEvents.RunStepDone, data);
this.logger.log(
`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`,
);
this.log(`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`);
}
}
15 changes: 12 additions & 3 deletions libs/openai-assistant/src/lib/chat/chat.service.spec.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import { Test } from '@nestjs/testing';
import { APIPromise } from 'openai/core';
import { Message, Run } from 'openai/resources/beta/threads';
import { AssistantStream } from 'openai/lib/AssistantStream';
import { AiModule } from './../ai/ai.module';
import { ChatModule } from './chat.module';
import { ChatService } from './chat.service';
import { ChatHelpers } from './chat.helpers';
import { ChatCallDto } from './chat.model';
import { AssistantStream } from 'openai/lib/AssistantStream';
import { RunService } from '../run/run.service';

jest.mock('../stream/stream.utils', () => ({
assistantStreamEventHandler: jest.fn(),
}));

describe('ChatService', () => {
let chatService: ChatService;
let chatbotHelpers: ChatHelpers;
let runService: RunService;

beforeEach(async () => {
const moduleRef = await Test.createTestingModule({
Expand All @@ -19,18 +25,21 @@ describe('ChatService', () => {

chatService = moduleRef.get<ChatService>(ChatService);
chatbotHelpers = moduleRef.get<ChatHelpers>(ChatHelpers);
runService = moduleRef.get<RunService>(RunService);

jest.spyOn(runService, 'resolve').mockReturnThis();

jest
.spyOn(chatbotHelpers, 'getAnswer')
.mockReturnValue(Promise.resolve('Hello response') as Promise<string>);


jest
.spyOn(chatService.threads.messages, 'create')
.mockReturnValue({} as APIPromise<Message>);

jest.spyOn(chatService, 'assistantStream').mockReturnValue({
jest.spyOn(chatService, 'getAssistantStream').mockReturnValue({
finalRun: jest.fn(),
on: () => jest.fn(),
} as unknown as Promise<AssistantStream>);
});

Expand Down
Loading

0 comments on commit 70683f2

Please sign in to comment.