Skip to content

Commit

Permalink
feat: support addon name override via request (#47)
Browse files Browse the repository at this point in the history
* Support addon name override via request.
  • Loading branch information
Oleksii-Klimov authored Jan 5, 2024
1 parent 0b08e47 commit 810afc2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 50 deletions.
74 changes: 43 additions & 31 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.response import Response
from pydantic import BaseModel

from aidial_assistant.application.addons_dialogue_limiter import (
AddonsDialogueLimiter,
Expand Down Expand Up @@ -39,6 +40,11 @@
logger = logging.getLogger(__name__)


class AddonReference(BaseModel):
name: str | None
url: str


def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
Expand All @@ -51,14 +57,18 @@ def _get_request_args(request: Request) -> dict[str, str]:
return {k: v for k, v in args.items() if v is not None}


def _validate_addons(addons: list[Addon] | None):
if addons and any(addon.url is None for addon in addons):
for index, addon in enumerate(addons):
if addon.url is None:
raise RequestParameterValidationError(
f"Missing required addon url at index {index}.",
param="addons",
)
def _validate_addons(addons: list[Addon] | None) -> list[AddonReference]:
addon_references: list[AddonReference] = []
for index, addon in enumerate(addons or []):
if addon.url is None:
raise RequestParameterValidationError(
f"Missing required addon url at index {index}.",
param="addons",
)

addon_references.append(AddonReference(name=addon.name, url=addon.url))

return addon_references


def _validate_messages(messages: list[Message]) -> None:
Expand All @@ -73,11 +83,6 @@ def _validate_messages(messages: list[Message]) -> None:
)


def _validate_request(request: Request) -> None:
_validate_messages(request.messages)
_validate_addons(request.addons)


class AssistantApplication(ChatCompletion):
def __init__(self, config_dir: Path):
self.args = parse_args(config_dir)
Expand All @@ -86,7 +91,8 @@ def __init__(self, config_dir: Path):
async def chat_completion(
self, request: Request, response: Response
) -> None:
_validate_request(request)
_validate_messages(request.messages)
addon_references = _validate_addons(request.addons)
chat_args = self.args.openai_conf.dict() | _get_request_args(request)

model = ModelClient(
Expand All @@ -99,47 +105,53 @@ async def chat_completion(
buffer_size=self.args.chat_conf.buffer_size,
)

addons: list[str] = (
[addon.url for addon in request.addons] if request.addons else [] # type: ignore
token_source = AddonTokenSource(
request.headers,
(addon_reference.url for addon_reference in addon_references),
)
token_source = AddonTokenSource(request.headers, addons)

tools: dict[str, PluginInfo] = {}
tool_descriptions: dict[str, str] = {}
for addon in addons:
info = await get_open_ai_plugin_info(addon)
tools[info.ai_plugin.name_for_model] = PluginInfo(
addons: dict[str, PluginInfo] = {}
# DIAL Core has own names for addons, so in stages we need to map them to the names used by the user
addon_name_mapping: dict[str, str] = {}
for addon_reference in addon_references:
info = await get_open_ai_plugin_info(addon_reference.url)
addons[info.ai_plugin.name_for_model] = PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon,
addon_reference.url,
token_source,
),
)

tool_descriptions[info.ai_plugin.name_for_model] = (
info.open_api.info.description # type: ignore
or info.ai_plugin.description_for_human
)
if addon_reference.name:
addon_name_mapping[
info.ai_plugin.name_for_model
] = addon_reference.name

# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000
command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(
model, tools, max_addons_dialogue_tokens
model, addons, max_addons_dialogue_tokens
),
Reply.token(): Reply,
}
chain = CommandChain(
model_client=model, name="ASSISTANT", command_dict=command_dict
)
addon_descriptions = {
name: addon.info.open_api.info.description
or addon.info.ai_plugin.description_for_human
for name, addon in addons.items()
}
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
tools=tool_descriptions
addons=addon_descriptions
),
best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
tools=tool_descriptions
addons=addon_descriptions
),
scoped_messages=parse_history(request.messages),
)
Expand All @@ -154,7 +166,7 @@ async def chat_completion(
choice = response.create_single_choice()
choice.open()

callback = AssistantChainCallback(choice)
callback = AssistantChainCallback(choice, addon_name_mapping)
finish_reason = FinishReason.STOP
try:
model_request_limiter = AddonsDialogueLimiter(
Expand Down
42 changes: 32 additions & 10 deletions aidial_assistant/application/assistant_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,37 @@


class PluginNameArgCallback(ArgCallback):
def __init__(self, callback: Callable[[str], None]):
def __init__(
self,
callback: Callable[[str], None],
addon_name_mapping: dict[str, str],
):
super().__init__(0, callback)
self.addon_name_mapping = addon_name_mapping

self._plugin_name = ""

@override
def on_arg(self, chunk: str):
chunk = chunk.replace('"', "")
if len(chunk) > 0:
self.callback(chunk)
self._plugin_name += chunk

@override
def on_arg_end(self):
self.callback("(")
self.callback(
self.addon_name_mapping.get(self._plugin_name, self._plugin_name)
+ "("
)


class RunPluginArgsCallback(ArgsCallback):
def __init__(self, callback: Callable[[str], None]):
def __init__(
self,
callback: Callable[[str], None],
addon_name_mapping: dict[str, str],
):
super().__init__(callback)
self.addon_name_mapping = addon_name_mapping

@override
def on_args_start(self):
Expand All @@ -43,20 +57,24 @@ def on_args_start(self):
def arg_callback(self) -> ArgCallback:
self.arg_index += 1
if self.arg_index == 0:
return PluginNameArgCallback(self.callback)
return PluginNameArgCallback(self.callback, self.addon_name_mapping)
else:
return ArgCallback(self.arg_index - 1, self.callback)


class AssistantCommandCallback(CommandCallback):
def __init__(self, stage: Stage):
def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):
self.stage = stage
self.addon_name_mapping = addon_name_mapping

self._args_callback = ArgsCallback(self._on_stage_name)

@override
def on_command(self, command: str):
if command == RunPlugin.token():
self._args_callback = RunPluginArgsCallback(self._on_stage_name)
self._args_callback = RunPluginArgsCallback(
self._on_stage_name, self.addon_name_mapping
)
else:
self._on_stage_name(command)

Expand Down Expand Up @@ -109,15 +127,19 @@ def on_result(self, chunk: str):


class AssistantChainCallback(ChainCallback):
def __init__(self, choice: Choice):
def __init__(self, choice: Choice, addon_name_mapping: dict[str, str]):
self.choice = choice
self.addon_name_mapping = addon_name_mapping

self._invocations: list[Invocation] = []
self._invocation_index: int = -1
self._discarded_messages: int = 0

@override
def command_callback(self) -> CommandCallback:
return AssistantCommandCallback(self.choice.create_stage())
return AssistantCommandCallback(
self.choice.create_stage(), self.addon_name_mapping
)

@override
def on_state(self, request: str, response: str):
Expand Down
6 changes: 3 additions & 3 deletions aidial_assistant/application/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ def build(self, **kwargs) -> Template:
{{request_format}}
## Commands
{%- if tools %}
{%- if addons %}
* run-addon
This command executes a specified addon to address a one-time task described in natural language.
Addons do not see current conversation and require all details to be provided in the query to solve the task.
Arguments:
- NAME is one of the following addons:
{%- for name, description in tools.items() %}
{%- for name, description in addons.items() %}
* {{name}} - {{description | decap}}
{%- endfor %}
- QUERY is the query string.
Expand Down Expand Up @@ -117,7 +117,7 @@ def build(self, **kwargs) -> Template:
You were allowed to use the following addons to answer the query below.
=== ADDONS ===
{% for name, description in tools.items() %}
{% for name, description in addons.items() %}
* {{name}} - {{description | decap}}
{%- endfor %}
Expand Down
4 changes: 2 additions & 2 deletions aidial_assistant/utils/open_ai_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Mapping
from typing import Iterable, Mapping
from urllib.parse import urljoin

from aiocache import cached
Expand Down Expand Up @@ -45,7 +45,7 @@ class OpenAIPluginInfo(BaseModel):


class AddonTokenSource:
def __init__(self, headers: Mapping[str, str], urls: list[str]):
def __init__(self, headers: Mapping[str, str], urls: Iterable[str]):
self.headers = headers
self.urls = {
url: f"x-addon-token-{index}" for index, url in enumerate(urls)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/application/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def test_main_best_effort_prompt():
actual = MAIN_BEST_EFFORT_TEMPLATE.build(
tools={"tool name": "Tool description"}
addons={"addon name": "Addon description"}
).render(
error="<error>",
message="<message>",
Expand All @@ -19,7 +19,7 @@ def test_main_best_effort_prompt():
=== ADDONS ===
* tool name - tool description
* addon name - addon description
=== QUERY ===
Expand All @@ -38,7 +38,7 @@ def test_main_best_effort_prompt():

def test_main_best_effort_prompt_with_empty_dialogue():
actual = MAIN_BEST_EFFORT_TEMPLATE.build(
tools={"tool name": "Tool description"}
addons={"addon name": "Addon description"}
).render(
error="<error>",
message="<message>",
Expand All @@ -51,7 +51,7 @@ def test_main_best_effort_prompt_with_empty_dialogue():
=== ADDONS ===
* tool name - tool description
* addon name - addon description
=== QUERY ===
Expand Down

0 comments on commit 810afc2

Please sign in to comment.