From 810afc2d8335b67a34d2093dde1cd269ff969e7c Mon Sep 17 00:00:00 2001 From: Oleksii-Klimov <133792808+Oleksii-Klimov@users.noreply.github.com> Date: Fri, 5 Jan 2024 13:48:12 +0000 Subject: [PATCH] feat: support addon name override via request (#47) * Support addon name override via request. --- .../application/assistant_application.py | 74 +++++++++++-------- .../application/assistant_callback.py | 42 ++++++++--- aidial_assistant/application/prompts.py | 6 +- aidial_assistant/utils/open_ai_plugin.py | 4 +- tests/unit_tests/application/test_prompts.py | 8 +- 5 files changed, 84 insertions(+), 50 deletions(-) diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 34bb975..69fbc9c 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -5,6 +5,7 @@ from aidial_sdk.chat_completion.base import ChatCompletion from aidial_sdk.chat_completion.request import Addon, Message, Request, Role from aidial_sdk.chat_completion.response import Response +from pydantic import BaseModel from aidial_assistant.application.addons_dialogue_limiter import ( AddonsDialogueLimiter, @@ -39,6 +40,11 @@ logger = logging.getLogger(__name__) +class AddonReference(BaseModel): + name: str | None + url: str + + def _get_request_args(request: Request) -> dict[str, str]: args = { "model": request.model, @@ -51,14 +57,18 @@ def _get_request_args(request: Request) -> dict[str, str]: return {k: v for k, v in args.items() if v is not None} -def _validate_addons(addons: list[Addon] | None): - if addons and any(addon.url is None for addon in addons): - for index, addon in enumerate(addons): - if addon.url is None: - raise RequestParameterValidationError( - f"Missing required addon url at index {index}.", - param="addons", - ) +def _validate_addons(addons: list[Addon] | None) -> list[AddonReference]: + addon_references: list[AddonReference] = [] + for index, addon in enumerate(addons or []): + if addon.url is None: + raise RequestParameterValidationError( + f"Missing required addon url at index {index}.", + param="addons", + ) + + addon_references.append(AddonReference(name=addon.name, url=addon.url)) + + return addon_references def _validate_messages(messages: list[Message]) -> None: @@ -73,11 +83,6 @@ def _validate_messages(messages: list[Message]) -> None: ) -def _validate_request(request: Request) -> None: - _validate_messages(request.messages) - _validate_addons(request.addons) - - class AssistantApplication(ChatCompletion): def __init__(self, config_dir: Path): self.args = parse_args(config_dir) @@ -86,7 +91,8 @@ def __init__(self, config_dir: Path): async def chat_completion( self, request: Request, response: Response ) -> None: - _validate_request(request) + _validate_messages(request.messages) + addon_references = _validate_addons(request.addons) chat_args = self.args.openai_conf.dict() | _get_request_args(request) model = ModelClient( @@ -99,47 +105,53 @@ async def chat_completion( buffer_size=self.args.chat_conf.buffer_size, ) - addons: list[str] = ( - [addon.url for addon in request.addons] if request.addons else [] # type: ignore + token_source = AddonTokenSource( + request.headers, + (addon_reference.url for addon_reference in addon_references), ) - token_source = AddonTokenSource(request.headers, addons) - tools: dict[str, PluginInfo] = {} - tool_descriptions: dict[str, str] = {} - for addon in addons: - info = await get_open_ai_plugin_info(addon) - tools[info.ai_plugin.name_for_model] = PluginInfo( + addons: dict[str, PluginInfo] = {} + # DIAL Core has own names for addons, so in stages we need to map them to the names used by the user + addon_name_mapping: dict[str, str] = {} + for addon_reference in addon_references: + info = await get_open_ai_plugin_info(addon_reference.url) + addons[info.ai_plugin.name_for_model] = PluginInfo( info=info, auth=get_plugin_auth( info.ai_plugin.auth.type, info.ai_plugin.auth.authorization_type, - addon, + addon_reference.url, token_source, ), ) - tool_descriptions[info.ai_plugin.name_for_model] = ( - info.open_api.info.description # type: ignore - or info.ai_plugin.description_for_human - ) + if addon_reference.name: + addon_name_mapping[ + info.ai_plugin.name_for_model + ] = addon_reference.name # TODO: Add max_addons_dialogue_tokens as a request parameter max_addons_dialogue_tokens = 1000 command_dict: CommandDict = { RunPlugin.token(): lambda: RunPlugin( - model, tools, max_addons_dialogue_tokens + model, addons, max_addons_dialogue_tokens ), Reply.token(): Reply, } chain = CommandChain( model_client=model, name="ASSISTANT", command_dict=command_dict ) + addon_descriptions = { + name: addon.info.open_api.info.description + or addon.info.ai_plugin.description_for_human + for name, addon in addons.items() + } history = History( assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build( - tools=tool_descriptions + addons=addon_descriptions ), best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build( - tools=tool_descriptions + addons=addon_descriptions ), scoped_messages=parse_history(request.messages), ) @@ -154,7 +166,7 @@ async def chat_completion( choice = response.create_single_choice() choice.open() - callback = AssistantChainCallback(choice) + callback = AssistantChainCallback(choice, addon_name_mapping) finish_reason = FinishReason.STOP try: model_request_limiter = AddonsDialogueLimiter( diff --git a/aidial_assistant/application/assistant_callback.py b/aidial_assistant/application/assistant_callback.py index 4fbcb8d..b75ecab 100644 --- a/aidial_assistant/application/assistant_callback.py +++ b/aidial_assistant/application/assistant_callback.py @@ -17,23 +17,37 @@ class PluginNameArgCallback(ArgCallback): - def __init__(self, callback: Callable[[str], None]): + def __init__( + self, + callback: Callable[[str], None], + addon_name_mapping: dict[str, str], + ): super().__init__(0, callback) + self.addon_name_mapping = addon_name_mapping + + self._plugin_name = "" @override def on_arg(self, chunk: str): chunk = chunk.replace('"', "") - if len(chunk) > 0: - self.callback(chunk) + self._plugin_name += chunk @override def on_arg_end(self): - self.callback("(") + self.callback( + self.addon_name_mapping.get(self._plugin_name, self._plugin_name) + + "(" + ) class RunPluginArgsCallback(ArgsCallback): - def __init__(self, callback: Callable[[str], None]): + def __init__( + self, + callback: Callable[[str], None], + addon_name_mapping: dict[str, str], + ): super().__init__(callback) + self.addon_name_mapping = addon_name_mapping @override def on_args_start(self): @@ -43,20 +57,24 @@ def on_args_start(self): def arg_callback(self) -> ArgCallback: self.arg_index += 1 if self.arg_index == 0: - return PluginNameArgCallback(self.callback) + return PluginNameArgCallback(self.callback, self.addon_name_mapping) else: return ArgCallback(self.arg_index - 1, self.callback) class AssistantCommandCallback(CommandCallback): - def __init__(self, stage: Stage): + def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]): self.stage = stage + self.addon_name_mapping = addon_name_mapping + self._args_callback = ArgsCallback(self._on_stage_name) @override def on_command(self, command: str): if command == RunPlugin.token(): - self._args_callback = RunPluginArgsCallback(self._on_stage_name) + self._args_callback = RunPluginArgsCallback( + self._on_stage_name, self.addon_name_mapping + ) else: self._on_stage_name(command) @@ -109,15 +127,19 @@ def on_result(self, chunk: str): class AssistantChainCallback(ChainCallback): - def __init__(self, choice: Choice): + def __init__(self, choice: Choice, addon_name_mapping: dict[str, str]): self.choice = choice + self.addon_name_mapping = addon_name_mapping + self._invocations: list[Invocation] = [] self._invocation_index: int = -1 self._discarded_messages: int = 0 @override def command_callback(self) -> CommandCallback: - return AssistantCommandCallback(self.choice.create_stage()) + return AssistantCommandCallback( + self.choice.create_stage(), self.addon_name_mapping + ) @override def on_state(self, request: str, response: str): diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index 194a13f..a1f8e22 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -67,13 +67,13 @@ def build(self, **kwargs) -> Template: {{request_format}} ## Commands -{%- if tools %} +{%- if addons %} * run-addon This command executes a specified addon to address a one-time task described in natural language. Addons do not see current conversation and require all details to be provided in the query to solve the task. Arguments: - NAME is one of the following addons: -{%- for name, description in tools.items() %} +{%- for name, description in addons.items() %} * {{name}} - {{description | decap}} {%- endfor %} - QUERY is the query string. @@ -117,7 +117,7 @@ def build(self, **kwargs) -> Template: You were allowed to use the following addons to answer the query below. === ADDONS === -{% for name, description in tools.items() %} +{% for name, description in addons.items() %} * {{name}} - {{description | decap}} {%- endfor %} diff --git a/aidial_assistant/utils/open_ai_plugin.py b/aidial_assistant/utils/open_ai_plugin.py index 1f41e03..822d197 100644 --- a/aidial_assistant/utils/open_ai_plugin.py +++ b/aidial_assistant/utils/open_ai_plugin.py @@ -1,5 +1,5 @@ import logging -from typing import Mapping +from typing import Iterable, Mapping from urllib.parse import urljoin from aiocache import cached @@ -45,7 +45,7 @@ class OpenAIPluginInfo(BaseModel): class AddonTokenSource: - def __init__(self, headers: Mapping[str, str], urls: list[str]): + def __init__(self, headers: Mapping[str, str], urls: Iterable[str]): self.headers = headers self.urls = { url: f"x-addon-token-{index}" for index, url in enumerate(urls) diff --git a/tests/unit_tests/application/test_prompts.py b/tests/unit_tests/application/test_prompts.py index 5600e6f..2b8058c 100644 --- a/tests/unit_tests/application/test_prompts.py +++ b/tests/unit_tests/application/test_prompts.py @@ -6,7 +6,7 @@ def test_main_best_effort_prompt(): actual = MAIN_BEST_EFFORT_TEMPLATE.build( - tools={"tool name": "Tool description"} + addons={"addon name": "Addon description"} ).render( error="", message="", @@ -19,7 +19,7 @@ def test_main_best_effort_prompt(): === ADDONS === -* tool name - tool description +* addon name - addon description === QUERY === @@ -38,7 +38,7 @@ def test_main_best_effort_prompt(): def test_main_best_effort_prompt_with_empty_dialogue(): actual = MAIN_BEST_EFFORT_TEMPLATE.build( - tools={"tool name": "Tool description"} + addons={"addon name": "Addon description"} ).render( error="", message="", @@ -51,7 +51,7 @@ def test_main_best_effort_prompt_with_empty_dialogue(): === ADDONS === -* tool name - tool description +* addon name - addon description === QUERY ===