From d73fd86a721bc7bba7896a5a6c9edb51821aa119 Mon Sep 17 00:00:00 2001 From: Oleksii-Klimov <133792808+Oleksii-Klimov@users.noreply.github.com> Date: Tue, 14 Nov 2023 15:29:49 +0000 Subject: [PATCH] feat: migrate latest fixes (#18) * Display json error responses from addons. (cherry picked from commit 757bf49cf3a82da49a0d4c468df6e988c436b774) * Remove logger from __init__.py. (cherry picked from commit 50518302f9bbef7fdcfa3b6f31f141267af5ffb8) * Propagate reason length (cherry picked from commit 53b88642c257d26252eb667ca36f830b76ea5046) * Fix openai error handling. (cherry picked from commit 3c7fc3a6017ac0871eabc64e60a9d3675baf2c96) --- aidial_assistant/application/__init__.py | 3 - .../application/assistant_application.py | 37 ++++++--- aidial_assistant/chain/__init__.py | 3 - aidial_assistant/chain/command_chain.py | 40 ++++----- aidial_assistant/chain/model_client.py | 10 ++- aidial_assistant/commands/plugin_callback.py | 11 ++- aidial_assistant/commands/run_plugin.py | 21 +++-- aidial_assistant/json_stream/__init__.py | 3 - aidial_assistant/json_stream/exceptions.py | 10 +++ aidial_assistant/json_stream/json_array.py | 65 +++++++-------- aidial_assistant/json_stream/json_node.py | 24 ------ aidial_assistant/json_stream/json_null.py | 6 +- aidial_assistant/json_stream/json_object.py | 83 +++++++++---------- aidial_assistant/json_stream/json_parser.py | 5 ++ aidial_assistant/json_stream/json_root.py | 13 +-- aidial_assistant/json_stream/json_string.py | 57 +++++-------- aidial_assistant/open_api/__init__.py | 3 - aidial_assistant/open_api/requester.py | 26 +++--- aidial_assistant/utils/__init__.py | 3 - aidial_assistant/utils/open_ai_plugin.py | 4 +- 20 files changed, 202 insertions(+), 225 deletions(-) create mode 100644 aidial_assistant/json_stream/exceptions.py diff --git a/aidial_assistant/application/__init__.py b/aidial_assistant/application/__init__.py index eea436a..e69de29 100644 --- a/aidial_assistant/application/__init__.py +++ b/aidial_assistant/application/__init__.py @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 45c9235..60bfea4 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -1,13 +1,14 @@ +import logging from pathlib import Path from aidial_sdk import HTTPException +from aidial_sdk.chat_completion import FinishReason from aidial_sdk.chat_completion.base import ChatCompletion from aidial_sdk.chat_completion.request import Addon, Request from aidial_sdk.chat_completion.response import Response from aiohttp import hdrs from openai import InvalidRequestError, OpenAIError -from aidial_assistant.application import logger from aidial_assistant.application.args import parse_args from aidial_assistant.application.prompts import ( MAIN_SYSTEM_DIALOG_MESSAGE, @@ -15,7 +16,11 @@ ) from aidial_assistant.application.server_callback import ServerChainCallback from aidial_assistant.chain.command_chain import CommandChain, CommandDict -from aidial_assistant.chain.model_client import ModelClient, UsagePublisher +from aidial_assistant.chain.model_client import ( + ModelClient, + ReasonLengthException, + UsagePublisher, +) from aidial_assistant.commands.reply import Reply from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin from aidial_assistant.utils.open_ai_plugin import ( @@ -25,6 +30,8 @@ ) from aidial_assistant.utils.state import get_system_prefix, parse_history +logger = logging.getLogger(__name__) + def get_request_args(request: Request) -> dict[str, str]: args = { @@ -112,19 +119,27 @@ async def chat_completion( request.messages, system_message, ) - with response.create_single_choice() as choice: - callback = ServerChainCallback(choice) - try: - await chain.run_chat(history, callback, usage_publisher) - except OpenAIError as e: - logger.exception("Request processing has failed.") + choice = response.create_single_choice() + choice.open() + + callback = ServerChainCallback(choice) + finish_reason = FinishReason.STOP + try: + await chain.run_chat(history, callback, usage_publisher) + except ReasonLengthException: + finish_reason = FinishReason.LENGTH + except OpenAIError as e: + if e.error: raise HTTPException( - str(e), + e.error.message, status_code=e.http_status or 500, - code=e.code, + code=e.error.code, ) - choice.set_state(callback.state) + raise + + choice.set_state(callback.state) + choice.close(finish_reason) response.set_usage( usage_publisher.prompt_tokens, usage_publisher.completion_tokens diff --git a/aidial_assistant/chain/__init__.py b/aidial_assistant/chain/__init__.py index eea436a..e69de29 100644 --- a/aidial_assistant/chain/__init__.py +++ b/aidial_assistant/chain/__init__.py @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 90b910f..da089c8 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -1,10 +1,10 @@ import json +import logging from typing import Any, AsyncIterator, Callable, List from aidial_sdk.chat_completion.request import Message, Role from jinja2 import Template -from aidial_assistant.chain import logger from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback @@ -19,15 +19,15 @@ CommandsReader, ) from aidial_assistant.commands.base import Command, FinalCommand -from aidial_assistant.json_stream.json_node import ( - JsonNode, - JsonParsingException, -) +from aidial_assistant.json_stream.exceptions import JsonParsingException +from aidial_assistant.json_stream.json_node import JsonNode from aidial_assistant.json_stream.json_object import JsonObject from aidial_assistant.json_stream.json_parser import JsonParser from aidial_assistant.json_stream.json_string import JsonString from aidial_assistant.json_stream.tokenator import AsyncPeekable, Tokenator +logger = logging.getLogger(__name__) + MAX_MESSAGE_COUNT = 20 MAX_RETRY_COUNT = 2 @@ -70,7 +70,7 @@ async def run_chat( history: List[Message], callback: ChainCallback, usage_publisher: UsagePublisher, - ) -> str: + ): for message in history[:-1]: self._log_message(message.role, message.content) @@ -95,20 +95,17 @@ async def run_chat( if isinstance(command, FinalCommand): if len(responses) > 0: continue - arg = await anext(args) - result = await CommandChain._to_result( - arg - if isinstance(arg, JsonString) - else arg.to_string_tokens(), + message = await anext(args) + await CommandChain._to_result( + message + if isinstance(message, JsonString) + else message.to_string_tokens(), # Some relatively large number to avoid CxSAST warning about potential DoS attack. # Later, the upper limit will be provided by the DIAL Core (proxy). 32000, callback.result_callback(), ) - self._log_message( - Role.ASSISTANT, json.dumps(root_node.value()) - ) - return result + return else: response = await CommandChain._execute_command( command_name, command, args, callback @@ -118,11 +115,7 @@ async def run_chat( responses.append(response) if len(responses) == 0: - # Assume the model has nothing to say - self._log_message( - Role.ASSISTANT, json.dumps(root_node.value()) - ) - return "" + return normalized_model_response = json.dumps({"commands": commands}) history.append( @@ -205,19 +198,16 @@ async def _to_result( arg: AsyncIterator[str], max_model_completion_tokens: int, callback: ResultCallback, - ) -> str: - result = "" + ): try: for _ in range(max_model_completion_tokens): token = await anext(arg) callback.on_result(token) - result += token - logger.warn( + logger.warning( f"Max token count of {max_model_completion_tokens} exceeded in the reply" ) except StopAsyncIteration: pass - return result @staticmethod async def _execute_command( diff --git a/aidial_assistant/chain/model_client.py b/aidial_assistant/chain/model_client.py index 2eed550..d04d348 100644 --- a/aidial_assistant/chain/model_client.py +++ b/aidial_assistant/chain/model_client.py @@ -7,6 +7,10 @@ from aiohttp import ClientSession +class ReasonLengthException(Exception): + pass + + class UsagePublisher: def __init__(self): self.total_usage = defaultdict(int) @@ -55,6 +59,10 @@ async def agenerate( if usage: usage_publisher.publish(usage) - text = chunk["choices"][0]["delta"].get("content") + choice = chunk["choices"][0] + text = choice["delta"].get("content") if text: yield text + + if choice.get("finish_reason") == "length": + raise ReasonLengthException() diff --git a/aidial_assistant/commands/plugin_callback.py b/aidial_assistant/commands/plugin_callback.py index 13f0b89..dfc7b6d 100644 --- a/aidial_assistant/commands/plugin_callback.py +++ b/aidial_assistant/commands/plugin_callback.py @@ -62,6 +62,7 @@ def on_result(self, token): class PluginChainCallback(ChainCallback): def __init__(self, callback: Callable[[str], None]): self.callback = callback + self._result = "" @override def command_callback(self) -> PluginCommandCallback: @@ -69,7 +70,7 @@ def command_callback(self) -> PluginCommandCallback: @override def result_callback(self) -> ResultCallback: - return PluginResultCallback(self.callback) + return PluginResultCallback(self._on_result) @override def on_state(self, request: str, response: str): @@ -79,3 +80,11 @@ def on_state(self, request: str, response: str): @override def on_error(self, title: str, error: Exception): pass + + @property + def result(self) -> str: + return self._result + + def _on_result(self, token): + self._result += token + self.callback(token) diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index fdaaed1..5ccf1a9 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -13,12 +13,16 @@ CommandChain, CommandConstructor, ) -from aidial_assistant.chain.model_client import ModelClient, UsagePublisher +from aidial_assistant.chain.model_client import ( + ModelClient, + ReasonLengthException, + UsagePublisher, +) from aidial_assistant.commands.base import ( Command, ExecutionCallback, - JsonResult, ResultObject, + TextResult, ) from aidial_assistant.commands.open_api import OpenAPIChatCommand from aidial_assistant.commands.plugin_callback import PluginChainCallback @@ -111,10 +115,9 @@ def create_command(op: APIOperation): command_dict=command_dict, ) - return JsonResult( - await chat.run_chat( - init_messages, - PluginChainCallback(execution_callback), - usage_publisher, - ) - ) + callback = PluginChainCallback(execution_callback) + try: + await chat.run_chat(init_messages, callback, usage_publisher) + return TextResult(callback.result) + except ReasonLengthException: + return TextResult(callback.result) diff --git a/aidial_assistant/json_stream/__init__.py b/aidial_assistant/json_stream/__init__.py index eea436a..e69de29 100644 --- a/aidial_assistant/json_stream/__init__.py +++ b/aidial_assistant/json_stream/__init__.py @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) diff --git a/aidial_assistant/json_stream/exceptions.py b/aidial_assistant/json_stream/exceptions.py new file mode 100644 index 0000000..e4675fc --- /dev/null +++ b/aidial_assistant/json_stream/exceptions.py @@ -0,0 +1,10 @@ +class JsonParsingException(Exception): + pass + + +def unexpected_symbol_error( + char: str, char_position: int +) -> JsonParsingException: + return JsonParsingException( + f"Failed to parse json string: unexpected symbol {char} at position {char_position}" + ) diff --git a/aidial_assistant/json_stream/json_array.py b/aidial_assistant/json_stream/json_array.py index fe2c812..a6fe1c9 100644 --- a/aidial_assistant/json_stream/json_array.py +++ b/aidial_assistant/json_stream/json_array.py @@ -4,11 +4,11 @@ from typing_extensions import override +from aidial_assistant.json_stream.exceptions import unexpected_symbol_error from aidial_assistant.json_stream.json_node import ( ComplexNode, JsonNode, NodeResolver, - unexpected_symbol_error, ) from aidial_assistant.json_stream.json_normalizer import JsonNormalizer from aidial_assistant.json_stream.tokenator import Tokenator @@ -17,7 +17,7 @@ class JsonArray(ComplexNode[list[Any]], AsyncIterator[JsonNode]): def __init__(self, char_position: int): super().__init__(char_position) - self.listener = Queue[JsonNode | None | BaseException]() + self.listener = Queue[JsonNode | None]() self.array: list[JsonNode] = [] @override @@ -34,7 +34,7 @@ def __aiter__(self) -> AsyncIterator[JsonNode]: @override async def __anext__(self) -> JsonNode: - result = ComplexNode.throw_if_exception(await self.listener.get()) + result = await self.listener.get() if result is None: raise StopAsyncIteration @@ -43,38 +43,33 @@ async def __anext__(self) -> JsonNode: @override async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver): - try: - normalised_stream = JsonNormalizer(stream) - char = await anext(normalised_stream) - self._char_position = stream.char_position - if not char == JsonArray.token(): - raise unexpected_symbol_error(char, stream.char_position) - - separate = False - while True: - char = await normalised_stream.apeek() - if char == "]": - await anext(normalised_stream) - break - - if char == ",": - if not separate: - raise unexpected_symbol_error( - char, stream.char_position - ) - - await anext(normalised_stream) - separate = False - else: - value = await dependency_resolver.resolve(stream) - await self.listener.put(value) - if isinstance(value, ComplexNode): - await value.parse(stream, dependency_resolver) - separate = True - - await self.listener.put(None) - except BaseException as e: - await self.listener.put(e) + normalised_stream = JsonNormalizer(stream) + char = await anext(normalised_stream) + self._char_position = stream.char_position + if not char == JsonArray.token(): + raise unexpected_symbol_error(char, stream.char_position) + + separate = False + while True: + char = await normalised_stream.apeek() + if char == "]": + await anext(normalised_stream) + break + + if char == ",": + if not separate: + raise unexpected_symbol_error(char, stream.char_position) + + await anext(normalised_stream) + separate = False + else: + value = await dependency_resolver.resolve(stream) + await self.listener.put(value) + if isinstance(value, ComplexNode): + await value.parse(stream, dependency_resolver) + separate = True + + await self.listener.put(None) @override async def to_string_tokens(self) -> AsyncIterator[str]: diff --git a/aidial_assistant/json_stream/json_node.py b/aidial_assistant/json_stream/json_node.py index 22f6d8c..5af0da4 100644 --- a/aidial_assistant/json_stream/json_node.py +++ b/aidial_assistant/json_stream/json_node.py @@ -7,18 +7,6 @@ from aidial_assistant.json_stream.tokenator import Tokenator -class JsonParsingException(Exception): - pass - - -def unexpected_symbol_error( - char: str, char_position: int -) -> JsonParsingException: - return JsonParsingException( - f"Failed to parse json string: unexpected symbol {char} at position {char_position}" - ) - - class NodeResolver(ABC): @abstractmethod async def resolve(self, stream: Tokenator) -> "JsonNode": @@ -57,18 +45,6 @@ def __init__(self, char_position: int): async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver): pass - @staticmethod - def throw_if_exception(entry): - if isinstance(entry, StopAsyncIteration): - raise JsonParsingException( - "Failed to parse json: unexpected end of stream." - ) - - if isinstance(entry, BaseException): - raise entry - - return entry - class PrimitiveNode(JsonNode[T], ABC, Generic[T]): @abstractmethod diff --git a/aidial_assistant/json_stream/json_null.py b/aidial_assistant/json_stream/json_null.py index 584a9c4..94757ab 100644 --- a/aidial_assistant/json_stream/json_null.py +++ b/aidial_assistant/json_stream/json_null.py @@ -1,9 +1,7 @@ from typing_extensions import override -from aidial_assistant.json_stream.json_node import ( - PrimitiveNode, - unexpected_symbol_error, -) +from aidial_assistant.json_stream.exceptions import unexpected_symbol_error +from aidial_assistant.json_stream.json_node import PrimitiveNode NULL_STRING = "null" diff --git a/aidial_assistant/json_stream/json_object.py b/aidial_assistant/json_stream/json_object.py index a964dac..792f359 100644 --- a/aidial_assistant/json_stream/json_object.py +++ b/aidial_assistant/json_stream/json_object.py @@ -5,11 +5,11 @@ from typing_extensions import override +from aidial_assistant.json_stream.exceptions import unexpected_symbol_error from aidial_assistant.json_stream.json_node import ( ComplexNode, JsonNode, NodeResolver, - unexpected_symbol_error, ) from aidial_assistant.json_stream.json_normalizer import JsonNormalizer from aidial_assistant.json_stream.json_string import JsonString @@ -22,7 +22,7 @@ class JsonObject( ): def __init__(self, char_position: int): super().__init__(char_position) - self.listener = Queue[Tuple[str, JsonNode] | None | BaseException]() + self.listener = Queue[Tuple[str, JsonNode] | None]() self._object: dict[str, JsonNode] = {} @override @@ -34,7 +34,7 @@ def __aiter__(self) -> AsyncIterator[Tuple[str, JsonNode]]: @override async def __anext__(self) -> Tuple[str, JsonNode]: - result = ComplexNode.throw_if_exception(await self.listener.get()) + result = await self.listener.get() if result is None: raise StopAsyncIteration @@ -57,52 +57,43 @@ async def get(self, key: str) -> JsonNode: @override async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver): - try: - normalised_stream = JsonNormalizer(stream) - char = await anext(normalised_stream) - if not char == JsonObject.token(): - raise unexpected_symbol_error(char, stream.char_position) + normalised_stream = JsonNormalizer(stream) + char = await anext(normalised_stream) + if not char == JsonObject.token(): + raise unexpected_symbol_error(char, stream.char_position) + + separate = False + while True: + char = await normalised_stream.apeek() + + if char == "}": + await normalised_stream.askip() + break - separate = False - while True: - char = await normalised_stream.apeek() - - if char == "}": - await normalised_stream.askip() - break - - if char == ",": - if not separate: - raise unexpected_symbol_error( - char, stream.char_position - ) - - await normalised_stream.askip() - separate = False - elif char == '"': - if separate: - raise unexpected_symbol_error( - char, stream.char_position - ) - - key = await join_string(JsonString.read(stream)) - colon = await anext(normalised_stream) - if not colon == ":": - raise unexpected_symbol_error( - colon, stream.char_position - ) - - value = await dependency_resolver.resolve(stream) - await self.listener.put((key, value)) - if isinstance(value, ComplexNode): - await value.parse(stream, dependency_resolver) - separate = True - else: + if char == ",": + if not separate: raise unexpected_symbol_error(char, stream.char_position) - await self.listener.put(None) - except BaseException as e: - await self.listener.put(e) + await normalised_stream.askip() + separate = False + elif char == '"': + if separate: + raise unexpected_symbol_error(char, stream.char_position) + + key = await join_string(JsonString.read(stream)) + colon = await anext(normalised_stream) + if not colon == ":": + raise unexpected_symbol_error(colon, stream.char_position) + + value = await dependency_resolver.resolve(stream) + await self.listener.put((key, value)) + if isinstance(value, ComplexNode): + await value.parse(stream, dependency_resolver) + separate = True + else: + raise unexpected_symbol_error(char, stream.char_position) + + await self.listener.put(None) @override async def to_string_tokens(self) -> AsyncIterator[str]: diff --git a/aidial_assistant/json_stream/json_parser.py b/aidial_assistant/json_stream/json_parser.py index 27d876d..48c1f83 100644 --- a/aidial_assistant/json_stream/json_parser.py +++ b/aidial_assistant/json_stream/json_parser.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager from typing import Any, AsyncGenerator +from aidial_assistant.json_stream.exceptions import JsonParsingException from aidial_assistant.json_stream.json_array import JsonArray from aidial_assistant.json_stream.json_node import ComplexNode, JsonNode from aidial_assistant.json_stream.json_object import JsonObject @@ -60,6 +61,10 @@ async def _parse_root(root: JsonRoot, stream: Tokenator): node = await root.node() if isinstance(node, ComplexNode): await node.parse(stream, node_resolver) + except StopAsyncIteration: + raise JsonParsingException( + "Failed to parse json: unexpected end of stream." + ) finally: # flush the stream async for _ in stream: diff --git a/aidial_assistant/json_stream/json_root.py b/aidial_assistant/json_stream/json_root.py index 24cc4c0..d889c15 100644 --- a/aidial_assistant/json_stream/json_root.py +++ b/aidial_assistant/json_stream/json_root.py @@ -3,6 +3,7 @@ from typing_extensions import override +from aidial_assistant.json_stream.exceptions import unexpected_symbol_error from aidial_assistant.json_stream.json_array import JsonArray from aidial_assistant.json_stream.json_bool import JsonBoolean from aidial_assistant.json_stream.json_node import ( @@ -10,7 +11,6 @@ JsonNode, NodeResolver, PrimitiveNode, - unexpected_symbol_error, ) from aidial_assistant.json_stream.json_normalizer import JsonNormalizer from aidial_assistant.json_stream.json_null import JsonNull @@ -52,7 +52,7 @@ async def resolve(self, stream: Tokenator) -> JsonNode: class JsonRoot(ComplexNode[Any]): def __init__(self): super().__init__(0) - self._node: JsonNode | BaseException | None = None + self._node: JsonNode | None = None self._event = asyncio.Event() async def node(self) -> JsonNode: @@ -61,7 +61,7 @@ async def node(self) -> JsonNode: # Should never happen raise Exception("Node was not parsed") - return ComplexNode.throw_if_exception(self._node) + return self._node @override def type(self) -> str: @@ -71,8 +71,6 @@ def type(self) -> str: async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver): try: self._node = await dependency_resolver.resolve(stream) - except BaseException as e: - self._node = e finally: self._event.set() @@ -84,7 +82,4 @@ async def to_string_tokens(self) -> AsyncIterator[str]: @override def value(self) -> Any: - if isinstance(self._node, JsonNode): - return self._node.value() - - return None + return self._node.value() if self._node else None diff --git a/aidial_assistant/json_stream/json_string.py b/aidial_assistant/json_stream/json_string.py index fcb45e4..d724d9f 100644 --- a/aidial_assistant/json_stream/json_string.py +++ b/aidial_assistant/json_stream/json_string.py @@ -4,18 +4,15 @@ from typing_extensions import override -from aidial_assistant.json_stream.json_node import ( - ComplexNode, - NodeResolver, - unexpected_symbol_error, -) +from aidial_assistant.json_stream.exceptions import unexpected_symbol_error +from aidial_assistant.json_stream.json_node import ComplexNode, NodeResolver from aidial_assistant.json_stream.tokenator import Tokenator class JsonString(ComplexNode[str], AsyncIterator[str]): def __init__(self, char_position: int): super().__init__(char_position) - self._listener = Queue[str | None | BaseException]() + self._listener = Queue[str | None]() self._buffer = "" @override @@ -31,7 +28,7 @@ def __aiter__(self) -> AsyncIterator[str]: @override async def __anext__(self) -> str: - result = ComplexNode.throw_if_exception(await self._listener.get()) + result = await self._listener.get() if result is None: raise StopAsyncIteration @@ -40,12 +37,9 @@ async def __anext__(self) -> str: @override async def parse(self, stream: Tokenator, dependency_resolver: NodeResolver): - try: - async for token in JsonString.read(stream): - await self._listener.put(token) - await self._listener.put(None) - except BaseException as e: - await self._listener.put(e) + async for token in JsonString.read(stream): + await self._listener.put(token) + await self._listener.put(None) @override async def to_string_tokens(self) -> AsyncIterator[str]: @@ -56,29 +50,24 @@ async def to_string_tokens(self) -> AsyncIterator[str]: @staticmethod async def read(stream: Tokenator) -> AsyncIterator[str]: - try: + char = await anext(stream) + if not char == JsonString.token(): + raise unexpected_symbol_error(char, stream.char_position) + result = "" + token_position = stream.token_position + while True: char = await anext(stream) - if not char == JsonString.token(): - raise unexpected_symbol_error(char, stream.char_position) - result = "" - token_position = stream.token_position - while True: - char = await anext(stream) - if char == JsonString.token(): - break - - result += ( - await JsonString.escape(stream) if char == "\\" else char - ) - if token_position != stream.token_position: - yield result - result = "" - token_position = stream.token_position - - if result: + if char == JsonString.token(): + break + + result += await JsonString.escape(stream) if char == "\\" else char + if token_position != stream.token_position: yield result - except StopAsyncIteration: - pass + result = "" + token_position = stream.token_position + + if result: + yield result @staticmethod async def escape(stream: Tokenator) -> str: diff --git a/aidial_assistant/open_api/__init__.py b/aidial_assistant/open_api/__init__.py index eea436a..e69de29 100644 --- a/aidial_assistant/open_api/__init__.py +++ b/aidial_assistant/open_api/__init__.py @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger(__name__) diff --git a/aidial_assistant/open_api/requester.py b/aidial_assistant/open_api/requester.py index 80864ea..b46665e 100644 --- a/aidial_assistant/open_api/requester.py +++ b/aidial_assistant/open_api/requester.py @@ -1,13 +1,16 @@ import json +import logging from typing import Dict, List, NamedTuple, Optional +import aiohttp.client_exceptions from aiohttp import hdrs from langchain.tools.openapi.utils.api_models import APIOperation from aidial_assistant.commands.base import JsonResult, ResultObject, TextResult -from aidial_assistant.open_api import logger from aidial_assistant.utils.requests import arequest +logger = logging.getLogger(__name__) + class _ParamMapping(NamedTuple): """Mapping from parameter name to parameter value.""" @@ -87,15 +90,18 @@ async def execute( self.operation.method.value, headers=headers, **request_args # type: ignore ) as response: if response.status != 200: - method_str = str(self.operation.method.value) # type: ignore - error_object = { - "reason": response.reason, - "status_code": response.status, - "method:": method_str.upper(), - "url": request_args["url"], - "params": request_args["params"], - } - return JsonResult(json.dumps(error_object)) + try: + return JsonResult(json.dumps(await response.json())) + except aiohttp.ContentTypeError: + method_str = str(self.operation.method.value) # type: ignore + error_object = { + "reason": response.reason, + "status_code": response.status, + "method:": method_str.upper(), + "url": request_args["url"], + "params": request_args["params"], + } + return JsonResult(json.dumps(error_object)) if "text" in response.headers[hdrs.CONTENT_TYPE]: return TextResult(await response.text()) diff --git a/aidial_assistant/utils/__init__.py b/aidial_assistant/utils/__init__.py index e3927dd..e69de29 100644 --- a/aidial_assistant/utils/__init__.py +++ b/aidial_assistant/utils/__init__.py @@ -1,3 +0,0 @@ -import logging.config - -logger = logging.getLogger(__name__) diff --git a/aidial_assistant/utils/open_ai_plugin.py b/aidial_assistant/utils/open_ai_plugin.py index 12c8d6c..1f41e03 100644 --- a/aidial_assistant/utils/open_ai_plugin.py +++ b/aidial_assistant/utils/open_ai_plugin.py @@ -1,3 +1,4 @@ +import logging from typing import Mapping from urllib.parse import urljoin @@ -8,9 +9,10 @@ from pydantic import BaseModel, parse_obj_as from starlette.status import HTTP_401_UNAUTHORIZED -from aidial_assistant.utils import logger from aidial_assistant.utils.requests import aget +logger = logging.getLogger(__name__) + class AuthConf(BaseModel): type: str