Skip to content

Commit

Permalink
feat: migrate latest fixes (#18)
Browse files Browse the repository at this point in the history
* Display json error responses from addons.

(cherry picked from commit 757bf49cf3a82da49a0d4c468df6e988c436b774)

* Remove logger from __init__.py.

(cherry picked from commit 50518302f9bbef7fdcfa3b6f31f141267af5ffb8)

* Propagate reason length

(cherry picked from commit 53b88642c257d26252eb667ca36f830b76ea5046)

* Fix openai error handling.

(cherry picked from commit 3c7fc3a6017ac0871eabc64e60a9d3675baf2c96)
  • Loading branch information
Oleksii-Klimov authored Nov 14, 2023
1 parent 4afb0eb commit d73fd86
Show file tree
Hide file tree
Showing 20 changed files with 202 additions and 225 deletions.
3 changes: 0 additions & 3 deletions aidial_assistant/application/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
37 changes: 26 additions & 11 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import logging
from pathlib import Path

from aidial_sdk import HTTPException
from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Request
from aidial_sdk.chat_completion.response import Response
from aiohttp import hdrs
from openai import InvalidRequestError, OpenAIError

from aidial_assistant.application import logger
from aidial_assistant.application.args import parse_args
from aidial_assistant.application.prompts import (
MAIN_SYSTEM_DIALOG_MESSAGE,
RESP_DIALOG_PROMPT,
)
from aidial_assistant.application.server_callback import ServerChainCallback
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.model_client import ModelClient, UsagePublisher
from aidial_assistant.chain.model_client import (
ModelClient,
ReasonLengthException,
UsagePublisher,
)
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.utils.open_ai_plugin import (
Expand All @@ -25,6 +30,8 @@
)
from aidial_assistant.utils.state import get_system_prefix, parse_history

logger = logging.getLogger(__name__)


def get_request_args(request: Request) -> dict[str, str]:
args = {
Expand Down Expand Up @@ -112,19 +119,27 @@ async def chat_completion(
request.messages,
system_message,
)
with response.create_single_choice() as choice:
callback = ServerChainCallback(choice)
try:
await chain.run_chat(history, callback, usage_publisher)
except OpenAIError as e:
logger.exception("Request processing has failed.")
choice = response.create_single_choice()
choice.open()

callback = ServerChainCallback(choice)
finish_reason = FinishReason.STOP
try:
await chain.run_chat(history, callback, usage_publisher)
except ReasonLengthException:
finish_reason = FinishReason.LENGTH
except OpenAIError as e:
if e.error:
raise HTTPException(
str(e),
e.error.message,
status_code=e.http_status or 500,
code=e.code,
code=e.error.code,
)

choice.set_state(callback.state)
raise

choice.set_state(callback.state)
choice.close(finish_reason)

response.set_usage(
usage_publisher.prompt_tokens, usage_publisher.completion_tokens
Expand Down
3 changes: 0 additions & 3 deletions aidial_assistant/chain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
40 changes: 15 additions & 25 deletions aidial_assistant/chain/command_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import logging
from typing import Any, AsyncIterator, Callable, List

from aidial_sdk.chat_completion.request import Message, Role
from jinja2 import Template

from aidial_assistant.chain import logger
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
Expand All @@ -19,15 +19,15 @@
CommandsReader,
)
from aidial_assistant.commands.base import Command, FinalCommand
from aidial_assistant.json_stream.json_node import (
JsonNode,
JsonParsingException,
)
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_object import JsonObject
from aidial_assistant.json_stream.json_parser import JsonParser
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.json_stream.tokenator import AsyncPeekable, Tokenator

logger = logging.getLogger(__name__)

MAX_MESSAGE_COUNT = 20
MAX_RETRY_COUNT = 2

Expand Down Expand Up @@ -70,7 +70,7 @@ async def run_chat(
history: List[Message],
callback: ChainCallback,
usage_publisher: UsagePublisher,
) -> str:
):
for message in history[:-1]:
self._log_message(message.role, message.content)

Expand All @@ -95,20 +95,17 @@ async def run_chat(
if isinstance(command, FinalCommand):
if len(responses) > 0:
continue
arg = await anext(args)
result = await CommandChain._to_result(
arg
if isinstance(arg, JsonString)
else arg.to_string_tokens(),
message = await anext(args)
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_string_tokens(),
# Some relatively large number to avoid CxSAST warning about potential DoS attack.
# Later, the upper limit will be provided by the DIAL Core (proxy).
32000,
callback.result_callback(),
)
self._log_message(
Role.ASSISTANT, json.dumps(root_node.value())
)
return result
return
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
Expand All @@ -118,11 +115,7 @@ async def run_chat(
responses.append(response)

if len(responses) == 0:
# Assume the model has nothing to say
self._log_message(
Role.ASSISTANT, json.dumps(root_node.value())
)
return ""
return

normalized_model_response = json.dumps({"commands": commands})
history.append(
Expand Down Expand Up @@ -205,19 +198,16 @@ async def _to_result(
arg: AsyncIterator[str],
max_model_completion_tokens: int,
callback: ResultCallback,
) -> str:
result = ""
):
try:
for _ in range(max_model_completion_tokens):
token = await anext(arg)
callback.on_result(token)
result += token
logger.warn(
logger.warning(
f"Max token count of {max_model_completion_tokens} exceeded in the reply"
)
except StopAsyncIteration:
pass
return result

@staticmethod
async def _execute_command(
Expand Down
10 changes: 9 additions & 1 deletion aidial_assistant/chain/model_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from aiohttp import ClientSession


class ReasonLengthException(Exception):
    """Raised when the model stops generating because the completion hit the
    token limit (i.e. the streamed chunk reports finish_reason == "length")."""

    pass


class UsagePublisher:
def __init__(self):
self.total_usage = defaultdict(int)
Expand Down Expand Up @@ -55,6 +59,10 @@ async def agenerate(
if usage:
usage_publisher.publish(usage)

text = chunk["choices"][0]["delta"].get("content")
choice = chunk["choices"][0]
text = choice["delta"].get("content")
if text:
yield text

if choice.get("finish_reason") == "length":
raise ReasonLengthException()
11 changes: 10 additions & 1 deletion aidial_assistant/commands/plugin_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,15 @@ def on_result(self, token):
class PluginChainCallback(ChainCallback):
def __init__(self, callback: Callable[[str], None]):
    """Wrap the external token sink *callback* and start an empty accumulator
    that collects every result token forwarded through this chain callback."""
    self.callback = callback
    self._result = ""

@override
def command_callback(self) -> PluginCommandCallback:
return PluginCommandCallback(self.callback)

@override
def result_callback(self) -> ResultCallback:
return PluginResultCallback(self.callback)
return PluginResultCallback(self._on_result)

@override
def on_state(self, request: str, response: str):
Expand All @@ -79,3 +80,11 @@ def on_state(self, request: str, response: str):
@override
def on_error(self, title: str, error: Exception):
pass

@property
def result(self) -> str:
    # Full text accumulated so far from all result tokens seen by this callback.
    return self._result

def _on_result(self, token):
    # Accumulate the token locally first, then forward it to the external
    # callback, so `result` stays complete even if the caller aborts later.
    self._result += token
    self.callback(token)
21 changes: 12 additions & 9 deletions aidial_assistant/commands/run_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@
CommandChain,
CommandConstructor,
)
from aidial_assistant.chain.model_client import ModelClient, UsagePublisher
from aidial_assistant.chain.model_client import (
ModelClient,
ReasonLengthException,
UsagePublisher,
)
from aidial_assistant.commands.base import (
Command,
ExecutionCallback,
JsonResult,
ResultObject,
TextResult,
)
from aidial_assistant.commands.open_api import OpenAPIChatCommand
from aidial_assistant.commands.plugin_callback import PluginChainCallback
Expand Down Expand Up @@ -111,10 +115,9 @@ def create_command(op: APIOperation):
command_dict=command_dict,
)

return JsonResult(
await chat.run_chat(
init_messages,
PluginChainCallback(execution_callback),
usage_publisher,
)
)
callback = PluginChainCallback(execution_callback)
try:
await chat.run_chat(init_messages, callback, usage_publisher)
return TextResult(callback.result)
except ReasonLengthException:
return TextResult(callback.result)
3 changes: 0 additions & 3 deletions aidial_assistant/json_stream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
import logging

logger = logging.getLogger(__name__)
10 changes: 10 additions & 0 deletions aidial_assistant/json_stream/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class JsonParsingException(Exception):
    """Error type for failures while parsing a streamed JSON document."""

    pass


def unexpected_symbol_error(
    char: str, char_position: int
) -> JsonParsingException:
    """Build (but do not raise) a JsonParsingException for an unexpected character.

    :param char: the offending character from the stream.
    :param char_position: the character's position as reported by the tokenizer.
    :return: an exception instance ready for the caller to raise.
    """
    message = (
        "Failed to parse json string: "
        f"unexpected symbol {char} at position {char_position}"
    )
    return JsonParsingException(message)
65 changes: 30 additions & 35 deletions aidial_assistant/json_stream/json_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from typing_extensions import override

from aidial_assistant.json_stream.exceptions import unexpected_symbol_error
from aidial_assistant.json_stream.json_node import (
ComplexNode,
JsonNode,
NodeResolver,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_normalizer import JsonNormalizer
from aidial_assistant.json_stream.tokenator import Tokenator
Expand All @@ -17,7 +17,7 @@
class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]):
def __init__(self, char_position: int):
super().__init__(char_position)
self.listener = Queue[JsonNode | None | BaseException]()
self.listener = Queue[JsonNode | None]()
self.array: list[JsonNode] = []

@override
Expand All @@ -34,7 +34,7 @@ def __aiter__(self) -> AsyncIterator[JsonNode]:

@override
async def __anext__(self) -> JsonNode:
result = ComplexNode.throw_if_exception(await self.listener.get())
result = await self.listener.get()
if result is None:
raise StopAsyncIteration

Expand All @@ -43,38 +43,33 @@ async def __anext__(self) -> JsonNode:

@override
async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver):
try:
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(
char, stream.char_position
)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)
except BaseException as e:
await self.listener.put(e)
normalised_stream = JsonNormalizer(stream)
char = await anext(normalised_stream)
self._char_position = stream.char_position
if not char == JsonArray.token():
raise unexpected_symbol_error(char, stream.char_position)

separate = False
while True:
char = await normalised_stream.apeek()
if char == "]":
await anext(normalised_stream)
break

if char == ",":
if not separate:
raise unexpected_symbol_error(char, stream.char_position)

await anext(normalised_stream)
separate = False
else:
value = await dependency_resolver.resolve(stream)
await self.listener.put(value)
if isinstance(value, ComplexNode):
await value.parse(stream, dependency_resolver)
separate = True

await self.listener.put(None)

@override
async def to_string_tokens(self) -> AsyncIterator[str]:
Expand Down
Loading

0 comments on commit d73fd86

Please sign in to comment.