Skip to content

Commit

Permalink
feat(chat): add onFinish hook to RAGChat for final response handling
Browse files Browse the repository at this point in the history
- Implemented an `onFinish` callback.
- Added a new test.
closes upstash#96
  • Loading branch information
ronal2do committed Dec 21, 2024
1 parent af2e30c commit e3b0e55
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/rag-chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,59 @@ describe("RAG Chat with non-embedding db", () => {
expect(called).toBeTrue();
});
});

describe("RAGChat with onFinish hook", () => {
const namespace = "result-metadata";
const vector = new Index({
token: process.env.UPSTASH_VECTOR_REST_TOKEN!,
url: process.env.UPSTASH_VECTOR_REST_URL!,
});

const ragChat = new RAGChat({
vector,
namespace,
streaming: true,
model: upstash("meta-llama/Meta-Llama-3-8B-Instruct"),
});

afterAll(async () => {
await vector.reset({ namespace });
await vector.deleteNamespace(namespace);
});

test(
"should call onFinish callback with correct output",
async () => {
// Set up test data
await ragChat.context.add({
type: "text",
data: "Tokyo is the Capital of Japan.",
options: { namespace, metadata: { unit: "Samurai" } },
});
await ragChat.context.add({
type: "text",
data: "Shakuhachi is a traditional wind instrument",
options: { namespace, metadata: { unit: "Shakuhachi" } },
});
await awaitUntilIndexed(vector);

// Create a spy for onFinish callback
let onFinishCalled = false;
let capturedOutput = "";

const result = await ragChat.chat<{ unit: string }>("Where is the capital of Japan?", {
namespace,
onFinish: ({ output }) => {
onFinishCalled = true;
capturedOutput = output;
},
});

expect(onFinishCalled).toBe(true);
expect(capturedOutput).toBe(result.output);
expect(result.output.toLowerCase()).toContain("tokyo");
expect(result.metadata).toEqual([{ unit: "Samurai" }, { unit: "Shakuhachi" }]);
},
{ timeout: 30_000 }
);
});
4 changes: 4 additions & 0 deletions src/rag-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ export class RAGChat {
if (!optionsWithDefault.disableHistory) {
await this.addAssistantMessageToHistory(output, optionsWithDefault);
}
if (optionsWithDefault.onFinish) {
optionsWithDefault.onFinish({ output });
}
},
},
this.debug
Expand Down Expand Up @@ -291,6 +294,7 @@ export class RAGChat {
promptFn: isRagDisabledAndPromptFunctionMissing
? DEFAULT_PROMPT_WITHOUT_RAG
: (options?.promptFn ?? this.config.prompt),
onFinish: options?.onFinish,
};
}
}
4 changes: 4 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ export type ChatOptions = {
* Must be provided if the Vector Database doesn't have default embeddings.
*/
embedding?: number[];
/**
* Hook to access the final response and modify as you wish.
*/
onFinish?: ({ output }: { output: string }) => void;
} & CommonChatAndRAGOptions;

export type PrepareChatResult = { data: string; id: string; metadata: unknown }[];
Expand Down

0 comments on commit e3b0e55

Please sign in to comment.