From 7ef502a33e87313a343c6d39100fad3030376f2b Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Sun, 9 Jun 2024 12:11:36 -0700 Subject: [PATCH] Refactor chat (#60) --- examples/assistant.py | 5 ++-- examples/futures.py | 20 +++++++------- src/agnext/chat/agents/__init__.py | 4 +++ src/agnext/chat/agents/base.py | 14 ---------- .../chat/agents/chat_completion_agent.py | 4 +-- src/agnext/chat/agents/oai_assistant.py | 9 +++---- src/agnext/chat/patterns/group_chat.py | 16 ++++++------ src/agnext/chat/patterns/orchestrator_chat.py | 11 ++++---- src/agnext/chat/patterns/two_agent_chat.py | 10 +++---- src/agnext/chat/types.py | 4 +-- src/agnext/chat/utils.py | 14 +++++----- src/agnext/components/_type_routed_agent.py | 8 +++--- src/agnext/core/_agent.py | 9 ++++++- src/agnext/core/_base_agent.py | 12 ++++++--- tests/test_cancellation.py | 20 +++++++------- tests/test_intervention.py | 26 +++++++++---------- tests/test_runtime.py | 8 +++--- tests/test_state.py | 8 +++--- 18 files changed, 99 insertions(+), 103 deletions(-) delete mode 100644 src/agnext/chat/agents/base.py diff --git a/examples/assistant.py b/examples/assistant.py index 7b82583cc341..676239b39f81 100644 --- a/examples/assistant.py +++ b/examples/assistant.py @@ -10,7 +10,6 @@ import aiofiles import openai from agnext.application import SingleThreadedAgentRuntime -from agnext.chat.agents.base import BaseChatAgent from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent from agnext.chat.patterns.group_chat import GroupChatOutput from agnext.chat.patterns.two_agent_chat import TwoAgentChat @@ -38,7 +37,7 @@ def reset(self) -> None: sep = "-" * 50 -class UserProxyAgent(BaseChatAgent, TypeRoutedAgent): # type: ignore +class UserProxyAgent(TypeRoutedAgent): # type: ignore def __init__( self, name: str, @@ -52,7 +51,7 @@ def __init__( name=name, description="A human user", runtime=runtime, - ) + ) # type: ignore self._client = client self._assistant_id = assistant_id self._thread_id = thread_id diff --git a/examples/futures.py b/examples/futures.py index 765fabfdce50..84325b4cb79c 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -12,22 +12,22 @@ class MessageType: sender: str -class Inner(TypeRoutedAgent): - def __init__(self, name: str, router: AgentRuntime) -> None: - super().__init__(name, router) +class Inner(TypeRoutedAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore + super().__init__(name, "The inner agent", router) - @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + @message_handler() # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore return MessageType(body=f"Inner: {message.body}", sender=self.name) -class Outer(TypeRoutedAgent): - def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: - super().__init__(name, router) +class Outer(TypeRoutedAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: # type: ignore + super().__init__(name, "The outter agent", router) self._inner = inner - @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + @message_handler() # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore inner_response = self._send_message(message, self._inner) inner_message = await inner_response assert isinstance(inner_message, MessageType) diff --git a/src/agnext/chat/agents/__init__.py b/src/agnext/chat/agents/__init__.py index e69de29bb2d1..5846b5a20cb5 100644 --- a/src/agnext/chat/agents/__init__.py +++ b/src/agnext/chat/agents/__init__.py @@ -0,0 +1,4 @@ +from .chat_completion_agent import ChatCompletionAgent +from .oai_assistant import OpenAIAssistantAgent + +__all__ = ["ChatCompletionAgent", "OpenAIAssistantAgent"] diff --git a/src/agnext/chat/agents/base.py b/src/agnext/chat/agents/base.py deleted file mode 100644 index 5ee3605287b0..000000000000 --- a/src/agnext/chat/agents/base.py +++ /dev/null @@ -1,14 +0,0 @@ -from agnext.core import AgentRuntime, BaseAgent - - -class BaseChatAgent(BaseAgent): - """The BaseAgent class for the chat API.""" - - def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: - super().__init__(name, runtime) - self._description = description - - @property - def description(self) -> str: - """The description of the agent.""" - return self._description diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index a6a45e035554..c46d78d572f2 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -25,10 +25,9 @@ TextMessage, ) from ..utils import convert_messages_to_llm_messages -from .base import BaseChatAgent -class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): +class ChatCompletionAgent(TypeRoutedAgent): def __init__( self, name: str, @@ -40,6 +39,7 @@ def __init__( tools: Sequence[Tool] = [], ) -> None: super().__init__(name, description, runtime) + self._description = description self._system_messages = system_messages self._client = model_client self._memory = memory diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 7a75059e1a8a..bd569f663c0e 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -4,13 +4,12 @@ from openai import AsyncAssistantEventHandler from openai.types.beta import AssistantResponseFormatParam -from agnext.chat.agents.base import BaseChatAgent -from agnext.chat.types import Reset, RespondNow, ResponseFormat, TextMessage -from agnext.components import TypeRoutedAgent, message_handler -from agnext.core import AgentRuntime, CancellationToken +from ...components import TypeRoutedAgent, message_handler +from ...core import AgentRuntime, CancellationToken +from ..types import Reset, RespondNow, ResponseFormat, TextMessage -class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent): +class OpenAIAssistantAgent(TypeRoutedAgent): def __init__( self, name: str, diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 33d3c8c68273..55e8bd07dd55 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -1,8 +1,7 @@ from typing import Any, List, Protocol, Sequence from ...components import TypeRoutedAgent, message_handler -from ...core import AgentRuntime, CancellationToken -from ..agents.base import BaseChatAgent +from ...core import Agent, AgentRuntime, CancellationToken from ..types import Reset, RespondNow, TextMessage @@ -14,17 +13,18 @@ def get_output(self) -> Any: ... def reset(self) -> None: ... -class GroupChat(BaseChatAgent, TypeRoutedAgent): +class GroupChat(TypeRoutedAgent): def __init__( self, name: str, description: str, runtime: AgentRuntime, - agents: Sequence[BaseChatAgent], + participants: Sequence[Agent], num_rounds: int, output: GroupChatOutput, ) -> None: - self._agents = agents + self._description = description + self._participants = participants self._num_rounds = num_rounds self._history: List[Any] = [] self._output = output @@ -32,7 +32,7 @@ def __init__( @property def subscriptions(self) -> Sequence[type]: - agent_sublists = [agent.subscriptions for agent in self._agents] + agent_sublists = [agent.subscriptions for agent in self._participants] return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] @message_handler() @@ -55,10 +55,10 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel while round < self._num_rounds: # TODO: add support for advanced speaker selection. # Select speaker (round-robin for now). - speaker = self._agents[round % len(self._agents)] + speaker = self._participants[round % len(self._participants)] # Send the last message to all agents except the previous speaker. - for agent in [agent for agent in self._agents if agent is not prev_speaker]: + for agent in [agent for agent in self._participants if agent is not prev_speaker]: # TODO gather and await _ = await self._send_message( self._history[-1], diff --git a/src/agnext/chat/patterns/orchestrator_chat.py b/src/agnext/chat/patterns/orchestrator_chat.py index 6bc9999ffcd9..427750332f9f 100644 --- a/src/agnext/chat/patterns/orchestrator_chat.py +++ b/src/agnext/chat/patterns/orchestrator_chat.py @@ -2,22 +2,21 @@ from typing import Any, Sequence, Tuple from ...components import TypeRoutedAgent, message_handler -from ...core import AgentRuntime, CancellationToken -from ..agents.base import BaseChatAgent +from ...core import Agent, AgentRuntime, CancellationToken from ..types import Reset, RespondNow, ResponseFormat, TextMessage __all__ = ["OrchestratorChat"] -class OrchestratorChat(BaseChatAgent, TypeRoutedAgent): +class OrchestratorChat(TypeRoutedAgent): def __init__( self, name: str, description: str, runtime: AgentRuntime, - orchestrator: BaseChatAgent, - planner: BaseChatAgent, - specialists: Sequence[BaseChatAgent], + orchestrator: Agent, + planner: Agent, + specialists: Sequence[Agent], max_turns: int = 30, max_stalled_turns_before_retry: int = 2, max_retry_attempts: int = 1, diff --git a/src/agnext/chat/patterns/two_agent_chat.py b/src/agnext/chat/patterns/two_agent_chat.py index a1d45a804e43..69cf54648de7 100644 --- a/src/agnext/chat/patterns/two_agent_chat.py +++ b/src/agnext/chat/patterns/two_agent_chat.py @@ -1,7 +1,5 @@ -from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput - -from ...core import AgentRuntime -from ..agents.base import BaseChatAgent +from ...core import Agent, AgentRuntime +from .group_chat import GroupChat, GroupChatOutput # TODO: rewrite this with a new message type calling for add to message @@ -12,8 +10,8 @@ def __init__( name: str, description: str, runtime: AgentRuntime, - first_speaker: BaseChatAgent, - second_speaker: BaseChatAgent, + first_speaker: Agent, + second_speaker: Agent, num_rounds: int, output: GroupChatOutput, ) -> None: diff --git a/src/agnext/chat/types.py b/src/agnext/chat/types.py index 308f135db446..a338d85b32dc 100644 --- a/src/agnext/chat/types.py +++ b/src/agnext/chat/types.py @@ -4,8 +4,8 @@ from enum import Enum from typing import List, Union -from agnext.components import FunctionCall, Image -from agnext.components.models import FunctionExecutionResultMessage +from ..components import FunctionCall, Image +from ..components.models import FunctionExecutionResultMessage @dataclass(kw_only=True) diff --git a/src/agnext/chat/utils.py b/src/agnext/chat/utils.py index cfee683a77c5..d07924666fcd 100644 --- a/src/agnext/chat/utils.py +++ b/src/agnext/chat/utils.py @@ -2,19 +2,19 @@ from typing_extensions import Literal -from agnext.chat.types import ( - FunctionCallMessage, - Message, - MultiModalMessage, - TextMessage, -) -from agnext.components.models import ( +from ..components.models import ( AssistantMessage, FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, UserMessage, ) +from .types import ( + FunctionCallMessage, + Message, + MultiModalMessage, + TextMessage, +) def convert_content_message_to_assistant_message( diff --git a/src/agnext/components/_type_routed_agent.py b/src/agnext/components/_type_routed_agent.py index 88bbdb35f872..2d62bc735a4c 100644 --- a/src/agnext/components/_type_routed_agent.py +++ b/src/agnext/components/_type_routed_agent.py @@ -21,8 +21,8 @@ runtime_checkable, ) -from agnext.core import AgentRuntime, BaseAgent, CancellationToken -from agnext.core.exceptions import CantHandleException +from ..core import AgentRuntime, BaseAgent, CancellationToken +from ..core.exceptions import CantHandleException logger = logging.getLogger("agnext") @@ -132,7 +132,7 @@ async def wrapper(self: Any, message: ReceivesT, cancellation_token: Cancellatio class TypeRoutedAgent(BaseAgent): - def __init__(self, name: str, router: AgentRuntime) -> None: + def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None: # Self is already bound to the handlers self._handlers: Dict[ Type[Any], @@ -147,7 +147,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: for target_type in message_handler.target_types: self._handlers[target_type] = message_handler - super().__init__(name, router) + super().__init__(name, description, runtime) @property def subscriptions(self) -> Sequence[Type[Any]]: diff --git a/src/agnext/core/_agent.py b/src/agnext/core/_agent.py index 8869669f7ca8..c4189e683f58 100644 --- a/src/agnext/core/_agent.py +++ b/src/agnext/core/_agent.py @@ -1,6 +1,6 @@ from typing import Any, Mapping, Protocol, Sequence, runtime_checkable -from agnext.core._cancellation_token import CancellationToken +from ._cancellation_token import CancellationToken @runtime_checkable @@ -14,6 +14,13 @@ def name(self) -> str: """ ... + @property + def description(self) -> str: + """Description of the agent. + + A human-readable description of the agent.""" + ... + @property def subscriptions(self) -> Sequence[type]: """Types of messages that this agent can receive.""" diff --git a/src/agnext/core/_base_agent.py b/src/agnext/core/_base_agent.py index f570f2b3a8b2..7b3aa74f1c63 100644 --- a/src/agnext/core/_base_agent.py +++ b/src/agnext/core/_base_agent.py @@ -3,10 +3,9 @@ from asyncio import Future from typing import Any, Mapping, Sequence, TypeVar -from agnext.core._agent_runtime import AgentRuntime -from agnext.core._cancellation_token import CancellationToken - from ._agent import Agent +from ._agent_runtime import AgentRuntime +from ._cancellation_token import CancellationToken ConsumesT = TypeVar("ConsumesT") ProducesT = TypeVar("ProducesT", covariant=True) @@ -16,8 +15,9 @@ class BaseAgent(ABC, Agent): - def __init__(self, name: str, router: AgentRuntime) -> None: + def __init__(self, name: str, description: str, router: AgentRuntime) -> None: self._name = name + self._description = description self._router = router router.add_agent(self) @@ -25,6 +25,10 @@ def __init__(self, name: str, router: AgentRuntime) -> None: def name(self) -> str: return self._name + @property + def description(self) -> str: + return self._description + @property @abstractmethod def subscriptions(self) -> Sequence[type]: diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 2f888554f725..ab0a61bc62ab 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -15,14 +15,14 @@ class MessageType: # To do cancellation, only the token should be interacted with as a user # If you cancel a future, it may not work as you expect. -class LongRunningAgent(TypeRoutedAgent): - def __init__(self, name: str, router: AgentRuntime) -> None: - super().__init__(name, router) +class LongRunningAgent(TypeRoutedAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore + super().__init__(name, "A long running agent", router) self.called = False self.cancelled = False - @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + @message_handler() # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) cancellation_token.link_future(sleep) @@ -33,15 +33,15 @@ async def on_new_message(self, message: MessageType, cancellation_token: Cancell self.cancelled = True raise -class NestingLongRunningAgent(TypeRoutedAgent): - def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None: - super().__init__(name, router) +class NestingLongRunningAgent(TypeRoutedAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None: # type: ignore + super().__init__(name, "A nesting long running agent", router) self.called = False self.cancelled = False self._nested_agent = nested_agent - @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + @message_handler() # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore self.called = True response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token) try: diff --git a/tests/test_intervention.py b/tests/test_intervention.py index b0ccd83f9aba..c23c415bd834 100644 --- a/tests/test_intervention.py +++ b/tests/test_intervention.py @@ -12,25 +12,25 @@ class MessageType: ... -class LoopbackAgent(TypeRoutedAgent): - def __init__(self, name: str, router: AgentRuntime) -> None: - super().__init__(name, router) +class LoopbackAgent(TypeRoutedAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore + super().__init__(name, "A loop back agent.", router) self.num_calls = 0 - @message_handler() - async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: + @message_handler() # type: ignore + async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore self.num_calls += 1 return message @pytest.mark.asyncio async def test_intervention_count_messages() -> None: - class DebugInterventionHandler(DefaultInterventionHandler): + class DebugInterventionHandler(DefaultInterventionHandler): # type: ignore def __init__(self) -> None: self.num_messages = 0 - async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: + async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType: # type: ignore self.num_messages += 1 return message @@ -49,9 +49,9 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient @pytest.mark.asyncio async def test_intervention_drop_send() -> None: - class DropSendInterventionHandler(DefaultInterventionHandler): - async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: - return DropMessage + class DropSendInterventionHandler(DefaultInterventionHandler): # type: ignore + async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: # type: ignore + return DropMessage # type: ignore handler = DropSendInterventionHandler() router = SingleThreadedAgentRuntime(before_send=handler) @@ -71,9 +71,9 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient @pytest.mark.asyncio async def test_intervention_drop_response() -> None: - class DropResponseInterventionHandler(DefaultInterventionHandler): - async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: - return DropMessage + class DropResponseInterventionHandler(DefaultInterventionHandler): # type: ignore + async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: # type: ignore + return DropMessage # type: ignore handler = DropResponseInterventionHandler() router = SingleThreadedAgentRuntime(before_send=handler) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 5acf11ed5694..5f2edaafd50c 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -5,15 +5,15 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken -class NoopAgent(BaseAgent): - def __init__(self, name: str, router: AgentRuntime) -> None: - super().__init__(name, router) +class NoopAgent(BaseAgent): # type: ignore + def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore + super().__init__(name, "A no op agent", router) @property def subscriptions(self) -> Sequence[type]: return [] - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore raise NotImplementedError diff --git a/tests/test_state.py b/tests/test_state.py index 52e6bff4c29e..63c170468d9f 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -5,16 +5,16 @@ from agnext.core import AgentRuntime, BaseAgent, CancellationToken -class StatefulAgent(BaseAgent): - def __init__(self, name: str, runtime: AgentRuntime) -> None: - super().__init__(name, runtime) +class StatefulAgent(BaseAgent): # type: ignore + def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore + super().__init__(name, "A stateful agent", runtime) self.state = 0 @property def subscriptions(self) -> Sequence[type]: return [] - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # type: ignore raise NotImplementedError def save_state(self) -> Mapping[str, Any]: