Skip to content

Commit

Permalink
Feature/channel prompts (#44)
Browse files Browse the repository at this point in the history
* Add support for custom prompts for query relaxation and generation
  - Uses DB config rules to target specific workspaces, channels, or Slack apps
  • Loading branch information
bdb-dd authored May 13, 2024
1 parent c627d92 commit fab3571
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 45 deletions.
2 changes: 1 addition & 1 deletion apps/admin/src/components/BotReplyMetadata.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const BotReplyDetails: React.FC<BotReplyDetailsProps> = ({ message }) => {
{message?.reactions?.map((query: Reaction, index: number) => (
<li key={query.name}>
<pre>
{query.name} ({query.count})
{query.name} ({query.count})
</pre>
</li>
))}
Expand Down
59 changes: 59 additions & 0 deletions apps/admin/src/components/RagPromptView.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import React from "react";
import { Box } from "@mui/material";
import ReactMarkdown from "react-markdown";
import { LightAsync as SyntaxHighlighter } from "react-syntax-highlighter";
import { github } from "react-syntax-highlighter/dist/esm/styles/hljs";
import { DocsBotReplyMessage } from "../models/Models";
import ErrorBoundary from "./ErrorBoundary";

export type Params = DocsBotReplyMessage & {};

/**
 * Renders the prompts attached to a bot reply (message.content.prompts) as a
 * numbered list. Each prompt value is rendered as Markdown; fenced code blocks
 * that carry a `language-*` class are syntax-highlighted, while inline code
 * falls back to a plain <code> element.
 *
 * NOTE(review): component was previously misnamed `RagSourceView` (copy-paste
 * from RagSourceView.tsx); renamed to match the file and its default export.
 */
const RagPromptView: React.FC<Params> = ({ message }) => {
  // Custom renderer map for ReactMarkdown: intercept <code> so fenced blocks
  // are highlighted instead of rendered as preformatted text.
  const components = {
    code({ node, inline, className, children, ...props }) {
      // react-markdown puts the fence language in the class, e.g. "language-ts".
      const match = /language-(\w+)/.exec(className || "");
      return !inline && match ? (
        // Pass the code as JSX children (not a `children` prop) and strip the
        // trailing newline the Markdown parser leaves on fenced blocks.
        <SyntaxHighlighter
          style={github}
          language={match[1]}
          PreTag="div"
          {...props}
        >
          {String(children).replace(/\n$/, "")}
        </SyntaxHighlighter>
      ) : (
        <code className={className} {...props}>
          {children}
        </code>
      );
    },
  };

  return (
    <Box sx={{ flexWrap: "wrap" }}>
      <ErrorBoundary>
        <ul>
          {Object.entries(message?.content?.prompts || {}).map(
            ([key, value], index) => (
              <li key={key}>
                <Box flexDirection="column">
                  <span>
                    Prompt #{index + 1}: {key}
                  </span>
                  <ReactMarkdown components={components}>
                    {value}
                  </ReactMarkdown>
                </Box>
                <hr />
              </li>
            ),
          )}
        </ul>
      </ErrorBoundary>
    </Box>
  );
};

export default RagPromptView;
6 changes: 5 additions & 1 deletion apps/admin/src/components/ThreadViewPane.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { Link, List, ListItem, Box, Tab } from "@mui/material";
import { TabContext, TabList, TabPanel } from "@mui/lab";
import { Props, Message } from "../models/Models";
import { useThreadReplies } from "../hooks/useThreadReplies";
import { RagPipelineResult } from "@digdir/assistants";
import BotReplyContent from "./BotReplyContent";
import BotReplyMetadata from "./BotReplyMetadata";
import RagSourceView from "./RagSourceView";
import RagPromptView from "./RagPromptView";

const ThreadViewPane: React.FC<Props> = ({
channelId,
Expand Down Expand Up @@ -55,6 +55,7 @@ const ThreadViewPane: React.FC<Props> = ({
<Tab label="English" value="english" />
<Tab label="Original" value="original" />
<Tab label="Sources" value="sources" />
<Tab label="Prompts" value="prompts" />
</TabList>
</Box>
<TabPanel value="english" style={{ padding: "0px 8px" }}>
Expand Down Expand Up @@ -94,6 +95,9 @@ const ThreadViewPane: React.FC<Props> = ({
<TabPanel value="sources" style={{ padding: "0px 8px" }}>
<RagSourceView message={threadMessages[0]} />
</TabPanel>
<TabPanel value="prompts" style={{ padding: "0px 8px" }}>
<RagPromptView message={threadMessages[0]} />
</TabPanel>
</TabContext>
</Box>
);
Expand Down
18 changes: 1 addition & 17 deletions apps/admin/src/models/Models.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { z } from "zod";
import { RagPipelineResult } from "@digdir/assistant-lib"

const ReactionSchema = z.object({
name: z.string(),
Expand Down Expand Up @@ -40,23 +41,6 @@ const DocsUserQuerySchema = z.object({
});
export type DocsUserQuery = z.infer<typeof DocsUserQuerySchema>;

const RagPipelineResultSchema = z.object({
original_user_query: z.string(),
english_user_query: z.string(),
user_query_language_name: z.string(),
english_answer: z.string(),
translated_answer: z.string(),
rag_success: z.boolean(),
search_queries: z.array(z.string()),
source_urls: z.array(z.string()),
source_documents: z.array(z.any()), // Assuming we don't have a specific structure for documents
relevant_urls: z.array(z.string()),
not_loaded_urls: z.array(z.string()),
durations: z.record(z.string(), z.number()), // Assuming durations is an object with string keys and number values
});

export type RagPipelineResult = z.infer<typeof RagPipelineResultSchema>;

export type RagPipelineMessage = Message & {
content: RagPipelineResult;
};
Expand Down
2 changes: 1 addition & 1 deletion apps/slack-app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"codestyle:check": "prettier src/ --check",
"codestyle:fix": "prettier src/ --write",
"build": "tsc -p .",
"watch": "nodemon src/app.ts",
"dev": "nodemon src/app.ts",
"start": "node dist/src/app.js",
"go": "yarn run build && yarn start"
},
Expand Down
36 changes: 32 additions & 4 deletions apps/slack-app/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ import {
import { lookupConfig } from './utils/bot-config';
import { envVar, round, lapTimer, timeoutPromise } from '@digdir/assistant-lib';
import { userInputAnalysis, UserQueryAnalysis } from '@digdir/assistant-lib';
import { ragPipeline, RagPipelineResult } from '@digdir/assistant-lib';
import { ragPipeline, qaTemplate } from '@digdir/assistant-lib';
import { splitToSections, isNullOrEmpty } from '@digdir/assistant-lib';
import { botLog, BotLogEntry, updateReactions } from './utils/bot-log';
import OpenAI from 'openai';
import { isNumber } from 'remeda';
import { RagPipelineResult } from '@digdir/assistant-lib';

const expressReceiver = new ExpressReceiver({
signingSecret: envVar('SLACK_BOT_SIGNING_SECRET'),
Expand Down Expand Up @@ -92,6 +93,27 @@ app.message(async ({ message, say }) => {
true,
);

const queryRelaxCustom = await lookupConfig(slackApp, srcEvtContext, 'prompt.rag.queryRelax', '');

const promptRagQueryRelax =
`You have access to a search API that returns relevant documentation.
Your task is to generate an array of up to 7 search queries that are relevant to this question.
Use a variation of related keywords and synonyms for the queries, trying to be as general as possible.
Include as many queries as you can think of, including and excluding terms.
For example, include queries like ['keyword_1 keyword_2', 'keyword_1', 'keyword_2'].
Be creative. The more queries you include, the more likely you are to find relevant results.
` + queryRelaxCustom;

const promptRagGenerateCustom = await lookupConfig(
slackApp,
srcEvtContext,
'prompt.rag.generate',
'',
);
const promptRagGenerate = qaTemplate(promptRagGenerateCustom || '');

if (envVar('LOG_LEVEL') == 'debug') {
console.log(`slackApp:\n${JSON.stringify(slackApp)}`);
console.log(`slackContext:\n${JSON.stringify(srcEvtContext)}`);
Expand Down Expand Up @@ -259,6 +281,8 @@ app.message(async ({ message, say }) => {
ragResponse = await ragPipeline(
stage1Result.questionTranslatedToEnglish,
stage1Result.userInputLanguageName,
promptRagQueryRelax || '',
promptRagGenerate || '',
updateSlackMsgCallback(app, firstThreadTs),
translatedMsgCallback,
);
Expand All @@ -279,6 +303,10 @@ app.message(async ({ message, say }) => {
relevant_urls: ragResponse.relevant_urls,
not_loaded_urls: ragResponse.not_loaded_urls || [],
rag_success: !!ragResponse.rag_success,
prompts: {
queryRelax: promptRagQueryRelax || '',
generate: promptRagGenerate || '',
},
};
} catch (e) {
if (e instanceof OpenAI.APIConnectionError) {
Expand Down Expand Up @@ -355,8 +383,8 @@ app.message(async ({ message, say }) => {
durations: ragResponse.durations,
step_name: 'rag_with_typesense',
content: {
error: finalizeError,
...payload,
error: finalizeError,
},
content_type: 'docs_bot_error',
};
Expand Down Expand Up @@ -450,7 +478,7 @@ async function handleReactionEvents(eventBody: any) {
channel: itemContext.channel_id,
timestamp: itemContext.ts_date + '.' + itemContext.ts_time,
};

const botInfo = await app.client.auth.test();
const botId = botInfo.user_id;

Expand All @@ -459,7 +487,7 @@ async function handleReactionEvents(eventBody: any) {
if (botId === eventUserId) {
console.log('Reaction on message from Assistant, will update reactions in DB');
} else {
console.log('Reaction was for something else, ignoring.');
console.log(`Reaction was for something else, ignoring. Bot ID: ${botId}, reaction was on item with item_user: ${eventUserId}`);
return;
}
}
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"build:assistant-lib": "cd packages/assistant-lib && yarn codestyle:fix && yarn build && cd ../..",
"build:slack-app": "cd apps/slack-app && yarn codestyle:fix && yarn build && cd ../..",
"build:admin": "cd apps/admin && yarn codestyle:fix && yarn build && cd ../..",
"build": "export && yarn clean && yarn build:assistant-lib && yarn build:admin && yarn build:slack-app",
"build": "yarn clean && yarn build:assistant-lib && yarn build:admin && yarn build:slack-app",
"run:slack-app": "node ./apps/slack-app/dist/src/app.js"
},
"devDependencies": {
Expand Down
3 changes: 2 additions & 1 deletion packages/assistant-lib/src/docs/prompts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ Catch all category if none of the above categories matches well.
3. Finally, return the category, original language code and name.
`;

export function qaTemplate() {
export function qaTemplate(promptRagGenerate: string = "") {
const translate_hint =
"\nOnly return the helpful answer below, along with relevant source code examples when possible.\n";

const prompt_text =
`Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
${promptRagGenerate}
Context: {context}
Expand Down
13 changes: 6 additions & 7 deletions packages/assistant-lib/src/docs/query-relaxation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,18 @@ const openaiClientInstance = Instructor({

const QueryRelaxationSchema = z.object({
searchQueries: z.array(z.string()),
promptUsed: z.string(),
});

export type QueryRelaxation = z.infer<typeof QueryRelaxationSchema> | null;

export async function queryRelaxation(
user_input: string,
promptRagQueryRelax: string = "",
): Promise<QueryRelaxation> {
let query_result: QueryRelaxation | null = null;

const prompt = `You have access to a search API that returns relevant documentation.
Your task is to generate an array of up to 7 search queries that are relevant to this question.
Use a variation of related keywords and synonyms for the queries, trying to be as general as possible.
Include as many queries as you can think of, including and excluding terms.
For example, include queries like ['keyword_1 keyword_2', 'keyword_1', 'keyword_2'].
Be creative. The more queries you include, the more likely you are to find relevant results.`;
const prompt = promptRagQueryRelax;

if (envVar("USE_AZURE_OPENAI_API", false) == "true") {
// query_result = await azureClient.chat.completions.create({
Expand All @@ -52,6 +48,9 @@ export async function queryRelaxation(
console.log(
`${stage_name} model name: ${envVar("OPENAI_API_MODEL_NAME", "")}`,
);
if (envVar("LOG_LEVEL") == "debug") {
console.log(`prompt.rag.queryRelax: \n${prompt}`);
}
query_result = await openaiClientInstance.chat.completions.create({
model: envVar("OPENAI_API_MODEL_NAME"),
response_model: {
Expand Down
32 changes: 26 additions & 6 deletions packages/assistant-lib/src/docs/rag.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { qaTemplate } from "./prompts";
import { queryRelaxation } from "./query-relaxation";
import {
lookupSearchPhrasesSimilar,
Expand All @@ -24,12 +23,18 @@ const RagContextRefsSchema = z.object({
source: z.string().min(1),
});

const RagPromptReplySchema = z.object({
const RagGenerateResultSchema = z.object({
helpful_answer: z.string(),
i_dont_know: z.boolean(),
relevant_contexts: z.array(RagContextRefsSchema),
});

const RagPromptSchema = z.object({
queryRelax: z.string(),
generate: z.string(),
});
export type RagPrompt = z.infer<typeof RagPromptSchema>;

const RagPipelineResultSchema = z.object({
original_user_query: z.string(),
english_user_query: z.string(),
Expand All @@ -39,17 +44,20 @@ const RagPipelineResultSchema = z.object({
rag_success: z.boolean(),
search_queries: z.array(z.string()),
source_urls: z.array(z.string()),
source_documents: z.array(z.any()), // Assuming we don't have a specific structure for documents
source_documents: z.array(z.any()),
relevant_urls: z.array(z.string()),
not_loaded_urls: z.array(z.string()),
durations: z.record(z.string(), z.number()), // Assuming durations is an object with string keys and number values
durations: z.record(z.string(), z.number()),
prompts: RagPromptSchema.optional(),
});

export type RagPipelineResult = z.infer<typeof RagPipelineResultSchema>;

export async function ragPipeline(
user_input: string,
user_query_language_name: string,
promptRagQueryRelax: string,
promptRagGenerate: string,
stream_callback_msg1: any = null,
stream_callback_msg2: any = null,
): Promise<RagPipelineResult> {
Expand All @@ -70,7 +78,10 @@ export async function ragPipeline(
const total_start = performance.now();
var start = total_start;

const extract_search_queries = await queryRelaxation(user_input);
const extract_search_queries = await queryRelaxation(
user_input,
promptRagQueryRelax,
);
durations.generate_searches = round(lapTimer(total_start));

if (envVar("LOG_LEVEL") === "debug") {
Expand Down Expand Up @@ -260,10 +271,15 @@ export async function ragPipeline(
let relevant_sources: string[] = [];

const contextYaml = yaml.dump(loadedDocs);
const fullPrompt = qaTemplate()
const partialPrompt = promptRagGenerate;
const fullPrompt = partialPrompt
.replace("{context}", contextYaml)
.replace("{question}", user_input);

if (envVar("LOG_LEVEL") == "debug") {
console.log(`rag prompt:\n${partialPrompt}`);
}

if (typeof stream_callback_msg1 !== "function") {
if (envVar("USE_AZURE_OPENAI_API") === "true") {
// const chatResponse = await azureClient.chat.completions.create({
Expand Down Expand Up @@ -348,6 +364,10 @@ export async function ragPipeline(
relevant_urls: relevant_sources,
not_loaded_urls: notLoadedUrls,
durations,
prompts: {
queryRelax: promptRagQueryRelax || "",
generate: promptRagGenerate || "",
},
};

return response;
Expand Down
Loading

0 comments on commit fab3571

Please sign in to comment.